├── misc
│   ├── __init__.py
│   ├── resnet_utils.py
│   ├── loss_wrapper.py
│   ├── resnet.py
│   ├── rewards.py
│   └── utils.py
├── vis
│   ├── imgs
│   │   └── dummy
│   └── index.html
├── scripts
│   ├── copy_model.sh
│   ├── make_bu_data.py
│   ├── prepro_reference_json.py
│   ├── prepro_feats.py
│   ├── prepro_ngrams.py
│   ├── dump_to_lmdb.py
│   ├── build_bpe_subword_nmt.py
│   └── prepro_labels.py
├── test-last.sh
├── .gitmodules
├── test-best.sh
├── ADVANCED.md
├── LICENSE
├── train.sh
├── train-wo-refining.sh
├── models
│   ├── __init__.py
│   ├── AttEnsemble.py
│   ├── ShowTellModel.py
│   ├── FCModel.py
│   ├── AoAModel.py
│   ├── OldModel.py
│   ├── CaptionModel.py
│   └── TransformerModel.py
├── eval.py
├── README.md
├── eval_ensemble.py
├── data
│   └── README.md
├── dataloaderraw.py
├── eval_utils.py
├── train.py
├── dataloader.py
└── opts.py
/misc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vis/imgs/dummy:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/copy_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | cp -r log_$1 log_$2
4 | cd log_$2
5 | mv infos_$1-best.pkl infos_$2-best.pkl
6 | mv infos_$1.pkl infos_$2.pkl
7 | cd ../
--------------------------------------------------------------------------------
/test-last.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$3 python eval.py --dump_images 0 --dump_json 1 --num_images -1 --model $1/model.pth --infos_path $1/infos_$2.pkl --language_eval 1 --beam_size $4 --batch_size 100 --split test
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "cider"]
2 | path = cider
3 | url = https://github.com/ruotianluo/cider.git
4 | [submodule "coco-caption"]
5 | path = coco-caption
6 | url = https://github.com/ruotianluo/coco-caption.git
7 |
--------------------------------------------------------------------------------
/test-best.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$3 python eval.py --dump_images 0 --dump_json 1 --num_images -1 --model $1/model-best.pth --infos_path $1/infos_$2-best.pkl --language_eval 1 --image_root /mnt/hl/dataset/coco/train2014/ --beam_size $4 --batch_size 100 --split test
--------------------------------------------------------------------------------
/ADVANCED.md:
--------------------------------------------------------------------------------
1 | # Advanced
2 |
3 | ## Ensemble
4 |
5 | Currently, ensembling only supports models that are subclasses of AttModel. Below is an example command for evaluating an ensemble. `eval_ensemble.py` assumes that each model is saved under `log_$id` (a short sketch of how the per-model probabilities are combined follows this file).
6 |
7 | ```
8 | python eval_ensemble.py --dump_json 0 --ids model1 model2 model3 --weights 0.3 0.3 0.3 --batch_size 1 --dump_images 0 --num_images 5000 --split test --language_eval 1 --beam_size 5 --temperature 1.0 --sample_method greedy --max_length 30
9 | ```
10 |
11 | ## Batch normalization
12 |
13 | ## Box feature
--------------------------------------------------------------------------------
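A minimal sketch, outside the repository files, of the probability averaging that `eval_ensemble.py` and `models/AttEnsemble.py` perform at each decoding step: the per-model word distributions are averaged with the supplied weights before taking the log. All tensors below are illustrative stand-ins.

```python
import torch

# Stand-in per-model word distributions for one decoding step:
# each has shape (batch_size, vocab_size + 1) and sums to 1 along dim 1.
probs = [torch.softmax(torch.randn(2, 11), dim=1) for _ in range(3)]
weights = torch.tensor([0.3, 0.3, 0.3])

stacked = torch.stack(probs, dim=2)          # (batch, vocab, n_models)
# Weighted average of probabilities (not logits), then log --
# the same reduction used in AttEnsemble.get_logprobs_state.
logprobs = (stacked * weights).sum(-1).div(weights.sum()).log()
print(logprobs.shape)                        # torch.Size([2, 11])
```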
/misc/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class myResnet(nn.Module):
6 | def __init__(self, resnet):
7 | super(myResnet, self).__init__()
8 | self.resnet = resnet
9 |
10 | def forward(self, img, att_size=14):
11 | x = img.unsqueeze(0)
12 |
13 | x = self.resnet.conv1(x)
14 | x = self.resnet.bn1(x)
15 | x = self.resnet.relu(x)
16 | x = self.resnet.maxpool(x)
17 |
18 | x = self.resnet.layer1(x)
19 | x = self.resnet.layer2(x)
20 | x = self.resnet.layer3(x)
21 | x = self.resnet.layer4(x)
22 |
23 | fc = x.mean(3).mean(2).squeeze()
24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
25 |
26 | return fc, att
27 |
28 |
--------------------------------------------------------------------------------
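A short usage sketch for `myResnet`, mirroring how `scripts/prepro_feats.py` and `dataloaderraw.py` drive it; the weight path follows their `./data/imagenet_weights/` convention and is otherwise an assumption. The input is a single normalized image tensor of shape (3, H, W).

```python
import torch
import misc.resnet as resnet
from misc.resnet_utils import myResnet

net = resnet.resnet101()
net.load_state_dict(torch.load('./data/imagenet_weights/resnet101.pth'))  # assumed path
extractor = myResnet(net).cuda().eval()

img = torch.randn(3, 448, 448).cuda()  # stand-in for a preprocessed RGB image
with torch.no_grad():
    fc, att = extractor(img, att_size=14)
print(fc.shape, att.shape)  # torch.Size([2048]) torch.Size([14, 14, 2048])
```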
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ruotian(RT) Luo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/misc/loss_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import misc.utils as utils
3 | from misc.rewards import init_scorer, get_self_critical_reward
4 |
5 | class LossWrapper(torch.nn.Module):
6 | def __init__(self, model, opt):
7 | super(LossWrapper, self).__init__()
8 | self.opt = opt
9 | self.model = model
10 | if opt.label_smoothing > 0:
11 | self.crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
12 | else:
13 | self.crit = utils.LanguageModelCriterion()
14 | self.rl_crit = utils.RewardCriterion()
15 |
16 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
17 | sc_flag):
18 | out = {}
19 | if not sc_flag:
20 | loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
21 | else:
22 | self.model.eval()
23 | with torch.no_grad():
24 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample')
25 | self.model.train()
26 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_method':'sample'}, mode='sample')
27 | gts = [gts[_] for _ in gt_indices.tolist()]
28 | reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
29 | reward = torch.from_numpy(reward).float().to(gen_result.device)
30 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
31 | out['reward'] = reward[:,0].mean()
32 | out['loss'] = loss
33 | return out
34 |
--------------------------------------------------------------------------------
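The self-critical branch above multiplies the per-step log-probabilities of the sampled captions by the baselined reward. `RewardCriterion` itself lives in `misc/utils.py`, which is not included here, so the following is only a stand-alone sketch of that masked policy-gradient loss, not the repository's implementation.

```python
import torch

def reward_criterion_sketch(sample_logprobs, seq, reward):
    # Mask out steps after the first end token (0), but keep the step
    # that produced the end token itself.
    mask = (seq > 0).float()
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
    return -(sample_logprobs * reward * mask).sum() / mask.sum()

logp = torch.log(torch.rand(2, 5))                       # fake per-step log-probs
seq = torch.tensor([[3, 4, 0, 0, 0], [2, 2, 5, 0, 0]])   # fake sampled sequences
r = torch.full((2, 5), 0.1)                              # reward repeated per step
print(reward_criterion_sketch(logp, seq, r))
```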
/scripts/make_bu_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import base64
7 | import numpy as np
8 | import csv
9 | import sys
10 | import zlib
11 | import time
12 | import mmap
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | # output_dir
18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory')
19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files')
20 |
21 | args = parser.parse_args()
22 |
23 | csv.field_size_limit(sys.maxsize)
24 |
25 |
26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv',
28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\
29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \
30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1']
31 |
32 | os.makedirs(args.output_dir+'_att')
33 | os.makedirs(args.output_dir+'_fc')
34 | os.makedirs(args.output_dir+'_box')
35 |
36 | for infile in infiles:
37 | print('Reading ' + infile)
38 | with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file:  # text mode so csv works under Python 3
39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
40 | for item in reader:
41 | item['image_id'] = int(item['image_id'])
42 | item['num_boxes'] = int(item['num_boxes'])
43 | for field in ['boxes', 'features']:
44 | item[field] = np.frombuffer(base64.b64decode(item[field]),
45 | dtype=np.float32).reshape((item['num_boxes'],-1))
46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features'])
47 | np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0))
48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes'])
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
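A quick sanity check on the files the script writes, assuming image id 391895 (the example id used in the prepro docstrings) is present in the downloaded split:

```python
import numpy as np

att = np.load('data/cocobu_att/391895.npz')['feat']  # (num_boxes, 2048) region features
fc = np.load('data/cocobu_fc/391895.npy')            # (2048,) mean-pooled feature
box = np.load('data/cocobu_box/391895.npy')          # (num_boxes, 4) box coordinates
print(att.shape, fc.shape, box.shape)
```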
/vis/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | neuraltalk2 results visualization
7 |
8 |
42 |
43 |
44 |
45 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | id="aoanet"
2 | if [ ! -f log/log_$id/infos_$id.pkl ]; then
3 | start_from=""
4 | else
5 | start_from="--start_from log/log_$id"
6 | fi
7 | python train.py --id $id \
8 | --caption_model aoa \
9 | --refine 1 \
10 | --refine_aoa 1 \
11 | --use_ff 0 \
12 | --decoder_type AoA \
13 | --use_multi_head 2 \
14 | --num_heads 8 \
15 | --multi_head_scale 1 \
16 | --mean_feats 1 \
17 | --ctx_drop 1 \
18 | --dropout_aoa 0.3 \
19 | --label_smoothing 0.2 \
20 | --input_json data/cocotalk.json \
21 | --input_label_h5 data/cocotalk_label.h5 \
22 | --input_fc_dir data/cocobu_fc \
23 | --input_att_dir data/cocobu_att \
24 | --input_box_dir data/cocobu_box \
25 | --seq_per_img 5 \
26 | --batch_size 10 \
27 | --beam_size 1 \
28 | --learning_rate 2e-4 \
29 | --num_layers 2 \
30 | --input_encoding_size 1024 \
31 | --rnn_size 1024 \
32 | --learning_rate_decay_start 0 \
33 | --scheduled_sampling_start 0 \
34 | --checkpoint_path log/log_$id \
35 | $start_from \
36 | --save_checkpoint_every 6000 \
37 | --language_eval 1 \
38 | --val_images_use -1 \
39 | --max_epochs 25 \
40 | --scheduled_sampling_increase_every 5 \
41 | --scheduled_sampling_max_prob 0.5 \
42 | --learning_rate_decay_every 3
43 |
44 | python train.py --id $id \
45 | --caption_model aoa \
46 | --refine 1 \
47 | --refine_aoa 1 \
48 | --use_ff 0 \
49 | --decoder_type AoA \
50 | --use_multi_head 2 \
51 | --num_heads 8 \
52 | --multi_head_scale 1 \
53 | --mean_feats 1 \
54 | --ctx_drop 1 \
55 | --dropout_aoa 0.3 \
56 | --input_json data/cocotalk.json \
57 | --input_label_h5 data/cocotalk_label.h5 \
58 | --input_fc_dir data/cocobu_fc \
59 | --input_att_dir data/cocobu_att \
60 | --input_box_dir data/cocobu_box \
61 | --seq_per_img 5 \
62 | --batch_size 10 \
63 | --beam_size 1 \
64 | --num_layers 2 \
65 | --input_encoding_size 1024 \
66 | --rnn_size 1024 \
67 | --language_eval 1 \
68 | --val_images_use -1 \
69 | --save_checkpoint_every 3000 \
70 | --start_from log/log_$id \
71 | --checkpoint_path log/log_$id"_rl" \
72 | --learning_rate 2e-5 \
73 | --max_epochs 40 \
74 | --self_critical_after 0 \
75 | --learning_rate_decay_start -1 \
76 | --scheduled_sampling_start -1 \
77 | --reduce_on_plateau
--------------------------------------------------------------------------------
/train-wo-refining.sh:
--------------------------------------------------------------------------------
1 | id="aoanet-wo-refinng"
2 | if [ ! -f log/log_$id/infos_$id.pkl ]; then
3 | start_from=""
4 | else
5 | start_from="--start_from log/log_$id"
6 | fi
7 | python train.py --id $id \
8 | --caption_model aoa \
9 | --refine 0 \
10 | --refine_aoa 0 \
11 | --use_ff 0 \
12 | --decoder_type AoA \
13 | --use_multi_head 2 \
14 | --num_heads 8 \
15 | --multi_head_scale 1 \
16 | --mean_feats 1 \
17 | --ctx_drop 0 \
18 | --dropout_aoa 0.3 \
19 | --label_smoothing 0 \
20 | --input_json data/cocotalk.json \
21 | --input_label_h5 data/cocotalk_label.h5 \
22 | --input_fc_dir data/cocobu_fc \
23 | --input_att_dir data/cocobu_att \
24 | --input_box_dir data/cocobu_box \
25 | --seq_per_img 5 \
26 | --batch_size 10 \
27 | --beam_size 1 \
28 | --learning_rate 2e-4 \
29 | --num_layers 2 \
30 | --input_encoding_size 1024 \
31 | --rnn_size 1024 \
32 | --learning_rate_decay_start 0 \
33 | --scheduled_sampling_start 0 \
34 | --checkpoint_path log/log_$id \
35 | $start_from \
36 | --save_checkpoint_every 6000 \
37 | --language_eval 1 \
38 | --val_images_use -1 \
39 | --max_epochs 35 \
40 | --scheduled_sampling_increase_every 5 \
41 | --scheduled_sampling_max_prob 0.5 \
42 | --learning_rate_decay_every 3
43 |
44 | python train.py --id $id \
45 | --caption_model aoa \
46 | --refine 0 \
47 | --refine_aoa 0 \
48 | --use_ff 0 \
49 | --decoder_type AoA \
50 | --use_multi_head 2 \
51 | --num_heads 8 \
52 | --multi_head_scale 1 \
53 | --mean_feats 1 \
54 | --ctx_drop 0 \
55 | --dropout_aoa 0.3 \
56 | --input_json data/cocotalk.json \
57 | --input_label_h5 data/cocotalk_label.h5 \
58 | --input_fc_dir data/cocobu_fc \
59 | --input_att_dir data/cocobu_att \
60 | --input_box_dir data/cocobu_box \
61 | --seq_per_img 5 \
62 | --batch_size 10 \
63 | --beam_size 1 \
64 | --num_layers 2 \
65 | --input_encoding_size 1024 \
66 | --rnn_size 1024 \
67 | --language_eval 1 \
68 | --val_images_use -1 \
69 | --save_checkpoint_every 3000 \
70 | --start_from log/log_$id \
71 | --checkpoint_path log/log_$id"_rl" \
72 | --learning_rate 2e-5 \
73 | --max_epochs 60 \
74 | --self_critical_after 0 \
75 | --learning_rate_decay_start -1 \
76 | --scheduled_sampling_start -1 \
77 | --reduce_on_plateau
78 |
--------------------------------------------------------------------------------
/misc/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models.resnet
4 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
5 | from torch.utils import model_zoo  # needed below for load_url
6 | class ResNet(torchvision.models.resnet.ResNet):
7 | def __init__(self, block, layers, num_classes=1000):
8 | super(ResNet, self).__init__(block, layers, num_classes)
9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
10 | for i in range(2, 5):
11 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
13 |
14 | def resnet18(pretrained=False):
15 | """Constructs a ResNet-18 model.
16 |
17 | Args:
18 | pretrained (bool): If True, returns a model pre-trained on ImageNet
19 | """
20 | model = ResNet(BasicBlock, [2, 2, 2, 2])
21 | if pretrained:
22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
23 | return model
24 |
25 |
26 | def resnet34(pretrained=False):
27 | """Constructs a ResNet-34 model.
28 |
29 | Args:
30 | pretrained (bool): If True, returns a model pre-trained on ImageNet
31 | """
32 | model = ResNet(BasicBlock, [3, 4, 6, 3])
33 | if pretrained:
34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
35 | return model
36 |
37 |
38 | def resnet50(pretrained=False):
39 | """Constructs a ResNet-50 model.
40 |
41 | Args:
42 | pretrained (bool): If True, returns a model pre-trained on ImageNet
43 | """
44 | model = ResNet(Bottleneck, [3, 4, 6, 3])
45 | if pretrained:
46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
47 | return model
48 |
49 |
50 | def resnet101(pretrained=False):
51 | """Constructs a ResNet-101 model.
52 |
53 | Args:
54 | pretrained (bool): If True, returns a model pre-trained on ImageNet
55 | """
56 | model = ResNet(Bottleneck, [3, 4, 23, 3])
57 | if pretrained:
58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
59 | return model
60 |
61 |
62 | def resnet152(pretrained=False):
63 | """Constructs a ResNet-152 model.
64 |
65 | Args:
66 | pretrained (bool): If True, returns a model pre-trained on ImageNet
67 | """
68 | model = ResNet(Bottleneck, [3, 8, 36, 3])
69 | if pretrained:
70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
71 | return model
--------------------------------------------------------------------------------
/misc/rewards.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import time
7 | import misc.utils as utils
8 | from collections import OrderedDict
9 | import torch
10 |
11 | import sys
12 | sys.path.append("cider")
13 | from pyciderevalcap.ciderD.ciderD import CiderD
14 | sys.path.append("coco-caption")
15 | from pycocoevalcap.bleu.bleu import Bleu
16 |
17 | CiderD_scorer = None
18 | Bleu_scorer = None
19 | #CiderD_scorer = CiderD(df='corpus')
20 |
21 | def init_scorer(cached_tokens):
22 | global CiderD_scorer
23 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
24 | global Bleu_scorer
25 | Bleu_scorer = Bleu_scorer or Bleu(4)
26 |
27 | def array_to_str(arr):
28 | out = ''
29 | for i in range(len(arr)):
30 | out += str(arr[i]) + ' '
31 | if arr[i] == 0:
32 | break
33 | return out.strip()
34 |
35 | def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
36 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
37 | seq_per_img = batch_size // len(data_gts)
38 |
39 | res = OrderedDict()
40 |
41 | gen_result = gen_result.data.cpu().numpy()
42 | greedy_res = greedy_res.data.cpu().numpy()
43 | for i in range(batch_size):
44 | res[i] = [array_to_str(gen_result[i])]
45 | for i in range(batch_size):
46 | res[batch_size + i] = [array_to_str(greedy_res[i])]
47 |
48 | gts = OrderedDict()
49 | for i in range(len(data_gts)):
50 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
51 |
52 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)]
53 | res__ = {i: res[i] for i in range(2 * batch_size)}
54 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
55 | if opt.cider_reward_weight > 0:
56 | _, cider_scores = CiderD_scorer.compute_score(gts, res_)
57 | print('Cider scores:', _)
58 | else:
59 | cider_scores = 0
60 | if opt.bleu_reward_weight > 0:
61 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
62 | bleu_scores = np.array(bleu_scores[3])
63 | print('Bleu scores:', _[3])
64 | else:
65 | bleu_scores = 0
66 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
67 |
68 | scores = scores[:batch_size] - scores[batch_size:]
69 |
70 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
71 |
72 | return rewards
73 |
--------------------------------------------------------------------------------
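The baselined reward computed above is just score(sampled caption) minus score(greedy caption), broadcast across time steps; a toy numeric illustration with made-up CIDEr-D values:

```python
import numpy as np

batch_size, seq_len = 2, 4
scores = np.array([0.9, 1.2,    # CIDEr-D of the two sampled captions
                   1.0, 1.0])   # CIDEr-D of the corresponding greedy captions
advantage = scores[:batch_size] - scores[batch_size:]      # sampled minus greedy
rewards = np.repeat(advantage[:, np.newaxis], seq_len, 1)  # shape (2, 4)
print(rewards)
```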
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import copy
7 |
8 | import numpy as np
9 | import misc.utils as utils
10 | import torch
11 |
12 | from .ShowTellModel import ShowTellModel
13 | from .FCModel import FCModel
14 | from .OldModel import ShowAttendTellModel, AllImgModel
15 | from .AttModel import *
16 | from .TransformerModel import TransformerModel
17 | from .AoAModel import AoAModel
18 |
19 | def setup(opt):
20 | if opt.caption_model == 'fc':
21 | model = FCModel(opt)
22 | elif opt.caption_model == 'language_model':
23 | model = LMModel(opt)
24 | elif opt.caption_model == 'newfc':
25 | model = NewFCModel(opt)
26 | elif opt.caption_model == 'show_tell':
27 | model = ShowTellModel(opt)
28 | # Att2in model in self-critical
29 | elif opt.caption_model == 'att2in':
30 | model = Att2inModel(opt)
31 | # Att2in model with two-layer MLP img embedding and word embedding
32 | elif opt.caption_model == 'att2in2':
33 | model = Att2in2Model(opt)
34 | elif opt.caption_model == 'att2all2':
35 | model = Att2all2Model(opt)
36 | # Adaptive Attention model from Knowing when to look
37 | elif opt.caption_model == 'adaatt':
38 | model = AdaAttModel(opt)
39 | # Adaptive Attention with maxout lstm
40 | elif opt.caption_model == 'adaattmo':
41 | model = AdaAttMOModel(opt)
42 | # Top-down attention model
43 | elif opt.caption_model == 'topdown':
44 | model = TopDownModel(opt)
45 | # StackAtt
46 | elif opt.caption_model == 'stackatt':
47 | model = StackAttModel(opt)
48 | # DenseAtt
49 | elif opt.caption_model == 'denseatt':
50 | model = DenseAttModel(opt)
51 | # Transformer
52 | elif opt.caption_model == 'transformer':
53 | model = TransformerModel(opt)
54 | # AoANet
55 | elif opt.caption_model == 'aoa':
56 | model = AoAModel(opt)
57 | else:
58 | raise Exception("Caption model not supported: {}".format(opt.caption_model))
59 |
60 | # check compatibility if training is continued from previously saved model
61 | if vars(opt).get('start_from', None) is not None:
62 | # check if all necessary files exist
63 | assert os.path.isdir(opt.start_from)," %s must be a path" % opt.start_from
64 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from
65 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))
66 |
67 | return model
68 |
--------------------------------------------------------------------------------
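`setup()` builds the model purely from the options object, so a trained checkpoint can be reloaded the same way `eval.py` and `eval_ensemble.py` do it. A minimal sketch, assuming the `log/log_aoanet` layout produced by `train.sh`:

```python
import torch
import models
import misc.utils as utils

with open('log/log_aoanet/infos_aoanet.pkl', 'rb') as f:
    infos = utils.pickle_load(f)

opt = infos['opt']
opt.vocab = infos['vocab']     # eval.py sets the vocabulary on opt before setup()
opt.start_from = None          # skip the resume checks inside setup()
model = models.setup(opt).cuda().eval()
model.load_state_dict(torch.load('log/log_aoanet/model.pth'))
```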
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import numpy as np
7 |
8 | import time
9 | import os
10 | from six.moves import cPickle
11 |
12 | import opts
13 | import models
14 | from dataloader import *
15 | from dataloaderraw import *
16 | import eval_utils
17 | import argparse
18 | import misc.utils as utils
19 | import torch
20 |
21 | # Input arguments and options
22 | parser = argparse.ArgumentParser()
23 | # Input paths
24 | parser.add_argument('--model', type=str, default='',
25 | help='path to model to evaluate')
26 | parser.add_argument('--cnn_model', type=str, default='resnet101',
27 | help='resnet101, resnet152')
28 | parser.add_argument('--infos_path', type=str, default='',
29 | help='path to infos to evaluate')
30 | opts.add_eval_options(parser)
31 |
32 | opt = parser.parse_args()
33 |
34 | # Load infos
35 | with open(opt.infos_path, 'rb') as f:
36 | infos = utils.pickle_load(f)
37 |
38 | # override and collect parameters
39 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
40 | ignore = ['start_from']
41 |
42 | for k in vars(infos['opt']).keys():
43 | if k in replace:
44 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
45 | elif k not in ignore:
46 | if not k in vars(opt):
47 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
48 |
49 | vocab = infos['vocab'] # ix -> word mapping
50 |
51 | # Setup the model
52 | opt.vocab = vocab
53 | model = models.setup(opt)
54 | del opt.vocab
55 | model.load_state_dict(torch.load(opt.model))
56 | model.cuda()
57 | model.eval()
58 | crit = utils.LanguageModelCriterion()
59 |
60 | # Create the Data Loader instance
61 | if len(opt.image_folder) == 0:
62 | loader = DataLoader(opt)
63 | else:
64 | loader = DataLoaderRaw({'folder_path': opt.image_folder,
65 | 'coco_json': opt.coco_json,
66 | 'batch_size': opt.batch_size,
67 | 'cnn_model': opt.cnn_model})
68 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
69 | # So make sure to use the vocab in infos file.
70 | loader.ix_to_word = infos['vocab']
71 |
72 |
73 | # Set sample options
74 | opt.dataset = opt.input_json
75 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
76 | vars(opt))
77 |
78 | print('loss: ', loss)
79 | if lang_stats:
80 | print(lang_stats)
81 |
82 | if opt.dump_json == 1:
83 | # dump the json
84 | json.dump(split_predictions, open('vis/vis.json', 'w'))
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention on Attention for Image Captioning
2 |
3 | This repository includes the implementation for [Attention on Attention for Image Captioning](https://arxiv.org/abs/1908.06954).
4 |
5 | ## Requirements
6 |
7 | - Python 3.6
8 | - Java 1.8.0
9 | - PyTorch 1.0
10 | - cider (already added as a submodule)
11 | - coco-caption (already added as a submodule)
12 | - tensorboardX
13 |
14 |
15 | ## Training AoANet
16 |
17 | ### Prepare data
18 |
19 | See details in `data/README.md`.
20 |
21 | (**note:** Set `word_count_threshold` in `scripts/prepro_labels.py` to 4 to generate a vocabulary of size 10,369.)
22 |
23 | You should also preprocess the dataset to build the n-gram cache used for computing the CIDEr score in [SCST](https://arxiv.org/abs/1612.00563):
24 |
25 | ```bash
26 | $ python scripts/prepro_ngrams.py --input_json data/dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
27 | ```
28 | ### Start training
29 |
30 | ```bash
31 | $ CUDA_VISIBLE_DEVICES=0 sh train.sh
32 | ```
33 |
34 | See `opts.py` for the options. (You can download the pretrained models from [here](https://drive.google.com/drive/folders/1ab0iPNyxdVm79ml-oozsIlH7H6t6dIVl?usp=sharing).)
35 |
36 |
37 | ### Evaluation
38 |
39 | ```bash
40 | $ CUDA_VISIBLE_DEVICES=0 python eval.py --model log/log_aoanet_rl/model.pth --infos_path log/log_aoanet_rl/infos_aoanet.pkl --dump_images 0 --dump_json 1 --num_images -1 --language_eval 1 --beam_size 2 --batch_size 100 --split test
41 | ```
42 |
43 | ### Performance
44 | You should get scores close to the following after training with cross-entropy (XE) loss for 25 epochs:
45 | ```python
46 | {'Bleu_1': 0.7729384559899702, 'Bleu_2': 0.6163398035383025, 'Bleu_3': 0.4790123137715982, 'Bleu_4': 0.36944349063530374, 'METEOR': 0.2848188431924821, 'ROUGE_L': 0.5729849683867054, 'CIDEr': 1.1842173801790759, 'SPICE': 0.21650786258302354}
47 | ```
48 | (**note:** You can increase `--max_epochs` in `train.sh` to train for more epochs and further improve the scores.)
49 |
50 | After training with the SCST loss for another 15 epochs, you should get:
51 | ```python
52 | {'Bleu_1': 0.8054903453672397, 'Bleu_2': 0.6523038976984842, 'Bleu_3': 0.5096621263772566, 'Bleu_4': 0.39140307771618477, 'METEOR': 0.29011216375635934, 'ROUGE_L': 0.5890369750273199, 'CIDEr': 1.2892294296245852, 'SPICE': 0.22680092759866174}
53 | ```
54 |
55 |
56 | ## Reference
57 |
58 | If you find this repo helpful, please consider citing:
59 |
60 | ```
61 | @inproceedings{huang2019attention,
62 | title={Attention on Attention for Image Captioning},
63 | author={Huang, Lun and Wang, Wenmin and Chen, Jie and Wei, Xiao-Yong},
64 | booktitle={International Conference on Computer Vision},
65 | year={2019}
66 | }
67 | ```
68 |
69 | ## Acknowledgements
70 |
71 | This repository is based on [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch), and you may refer to it for more details about the code.
72 |
--------------------------------------------------------------------------------
/eval_ensemble.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import numpy as np
7 |
8 | import time
9 | import os
10 | from six.moves import cPickle
11 |
12 | import opts
13 | import models
14 | from dataloader import *
15 | from dataloaderraw import *
16 | import eval_utils
17 | import argparse
18 | import misc.utils as utils
19 | import torch
20 |
21 | # Input arguments and options
22 | parser = argparse.ArgumentParser()
23 | # Input paths
24 | parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble')
25 | parser.add_argument('--weights', nargs='+', required=False, default=None, help='weights of the models to ensemble')
26 | # parser.add_argument('--models', nargs='+', required=True
27 | # help='path to model to evaluate')
28 | # parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate')
29 | opts.add_eval_options(parser)
30 |
31 | opt = parser.parse_args()
32 |
33 | model_infos = []
34 | model_paths = []
35 | for id in opt.ids:
36 | if '-' in id:
37 | id, app = id.split('-')
38 | app = '-'+app
39 | else:
40 | app = ''
41 | model_infos.append(utils.pickle_load(open('log_%s/infos_%s%s.pkl' %(id, id, app), 'rb')))
42 | model_paths.append('log_%s/model%s.pth' %(id,app))
43 |
44 | # Load one infos
45 | infos = model_infos[0]
46 |
47 | # override and collect parameters
48 | replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
49 | for k in replace:
50 | setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
51 |
52 | vars(opt).update({k: vars(infos['opt'])[k] for k in vars(infos['opt']).keys() if k not in vars(opt)}) # copy over options from model
53 |
54 |
55 | opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos])
56 | assert len(set(getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos)) == 1, 'Different norm_att_feat is not supported'
57 | assert len(set(getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos)) == 1, 'Different norm_box_feat is not supported'
58 |
59 | vocab = infos['vocab'] # ix -> word mapping
60 |
61 | # Setup the model
62 | from models.AttEnsemble import AttEnsemble
63 |
64 | _models = []
65 | for i in range(len(model_infos)):
66 | model_infos[i]['opt'].start_from = None
67 | model_infos[i]['opt'].vocab = vocab
68 | tmp = models.setup(model_infos[i]['opt'])
69 | tmp.load_state_dict(torch.load(model_paths[i]))
70 | _models.append(tmp)
71 |
72 | if opt.weights is not None:
73 | opt.weights = [float(_) for _ in opt.weights]
74 | model = AttEnsemble(_models, weights=opt.weights)
75 | model.seq_length = opt.max_length
76 | model.cuda()
77 | model.eval()
78 | crit = utils.LanguageModelCriterion()
79 |
80 | # Create the Data Loader instance
81 | if len(opt.image_folder) == 0:
82 | loader = DataLoader(opt)
83 | else:
84 | loader = DataLoaderRaw({'folder_path': opt.image_folder,
85 | 'coco_json': opt.coco_json,
86 | 'batch_size': opt.batch_size,
87 | 'cnn_model': opt.cnn_model})
88 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
89 | # So make sure to use the vocab in infos file.
90 | loader.ix_to_word = infos['vocab']
91 |
92 | opt.id = '+'.join([_+str(__) for _,__ in zip(opt.ids, opt.weights)])
93 | # Set sample options
94 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
95 | vars(opt))
96 |
97 | print('loss: ', loss)
98 | if lang_stats:
99 | print(lang_stats)
100 |
101 | if opt.dump_json == 1:
102 | # dump the json
103 | json.dump(split_predictions, open('vis/vis.json', 'w'))
104 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Prepare data
2 |
3 | ## COCO
4 |
5 | ### Download COCO captions and preprocess them
6 |
7 | Download the preprocessed COCO captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides the preprocessed captions and the standard train-val-test splits.
8 |
9 | Then do:
10 |
11 | ```bash
12 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
13 | ```
14 |
15 | `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`.
16 |
17 | ### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature)
18 |
19 | Download the COCO images from [link](http://mscoco.org/dataset/#download). We need the 2014 training images and the 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
20 |
21 | Then:
22 |
23 | ```
24 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
25 | ```
26 |
27 |
28 | `prepro_feats.py` extracts the ResNet-101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB.
29 |
30 | (Check the prepro scripts for more options, like other resnet models or other attention sizes.)
31 |
32 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.
33 |
34 | ### Download Bottom-up features (Skip if you are using resnet features)
35 |
36 | Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version.
37 |
38 | For example:
39 | ```
40 | mkdir data/bu_data; cd data/bu_data
41 | wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
42 | unzip trainval.zip
43 |
44 | ```
45 |
46 | Then:
47 |
48 | ```bash
49 | python scripts/make_bu_data.py --output_dir data/cocobu
50 | ```
51 |
52 | This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace every `cocotalk` with `cocobu`.
53 |
54 | ## Flickr30k
55 |
56 | The procedure is similar.
57 |
58 | ```
59 | python scripts/prepro_labels.py --input_json data/dataset_flickr30k.json --output_json data/f30ktalk.json --output_h5 data/f30ktalk
60 |
61 | python scripts/prepro_ngrams.py --input_json data/dataset_flickr30k.json --dict_json data/f30ktalk.json --output_pkl data/f30k-train --split train
62 | ```
63 |
64 | The following generates the COCO-like annotation file needed for evaluation with coco-caption:
65 |
66 | ```
67 | python scripts/prepro_reference_json.py --input_json data/dataset_flickr30k.json --output_json data/f30k_captions4eval.json
68 | ```
69 |
70 | ### Feature extraction
71 |
72 | For ResNet features, you can do the same thing as for COCO.
73 |
74 | For bottom-up features, you can download them from [link](https://github.com/kuanghuei/SCAN)
75 |
76 | `wget https://scanproject.blob.core.windows.net/scan-data/data.zip`
77 |
78 | and then convert them to a pth file using the following script:
79 |
80 | ```
81 | import numpy as np
82 | import os
83 | import torch
84 | from tqdm import tqdm
85 |
86 | out = {}
87 | def transform(id_file, feat_file):
88 | ids = open(id_file, 'r').readlines()
89 | ids = [_.strip('\n') for _ in ids]
90 | feats = np.load(feat_file)
91 | assert feats.shape[0] == len(ids)
92 | for _id, _feat in tqdm(zip(ids, feats)):
93 | out[str(_id)] = _feat
94 |
95 | transform('dev_ids.txt', 'dev_ims.npy')
96 | transform('train_ids.txt', 'train_ims.npy')
97 | transform('test_ids.txt', 'test_ims.npy')
98 |
99 | torch.save(out, 'f30kbu_att.pth')
100 | ```
--------------------------------------------------------------------------------
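A quick way to inspect what `prepro_labels.py` wrote, assuming the default COCO output names used in the steps above:

```python
import json
import h5py

info = json.load(open('data/cocotalk.json'))
print(len(info['ix_to_word']))   # vocabulary size (10,369 with word_count_threshold=4,
                                 # as suggested in the top-level README)
print(info['images'][0])         # per-image split / id information

with h5py.File('data/cocotalk_label.h5', 'r') as f:
    print(f['labels'].shape)         # (num_captions, max_length), zero padded
    print(f['label_start_ix'][:5])   # 1-indexed pointers into labels, per image
```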
/scripts/prepro_reference_json.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | """
3 | Build a COCO-style reference json (for coco-caption evaluation) from a raw json dataset
4 |
5 | Input: json file that has the form
6 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
7 | example element in this list would look like
8 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
9 |
10 | This script reads this json, does some basic preprocessing on the captions
11 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
12 |
13 | Output: a single json file in the COCO "captions" annotation format:
14 | - an 'info' / 'licenses' / 'type' header copied from the 2014 MS COCO release
15 | - an 'images' list with one {'id': ...} entry per non-train image
16 | - an 'annotations' list of {'image_id', 'caption', 'id'} records, one per
17 | reference sentence
18 |
19 | coco-caption can load this file directly as the reference annotations when
20 | scoring generated captions.
21 |
22 | Note: unlike prepro_labels.py, this script does not write any hdf5 file;
23 | only the json reference file is produced.
24 |
25 | """
26 |
27 | from __future__ import absolute_import
28 | from __future__ import division
29 | from __future__ import print_function
30 |
31 | import os
32 | import json
33 | import argparse
34 | import sys
35 | import hashlib
36 | from random import shuffle, seed
37 |
38 |
39 | def main(params):
40 |
41 | imgs = json.load(open(params['input_json'][0], 'r'))['images']
42 | # tmp = []
43 | # for k in imgs.keys():
44 | # for img in imgs[k]:
45 | # img['filename'] = img['image_id'] # k+'/'+img['image_id']
46 | # img['image_id'] = int(
47 | # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint)
48 | # tmp.append(img)
49 | # imgs = tmp
50 |
51 | # create output json file
52 | out = {u'info': {u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-01-27 09:11:52.357475'}, u'licenses': [{u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nc/2.0/', u'id': 2, u'name': u'Attribution-NonCommercial License'}, {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}, {u'url': u'http://creativecommons.org/licenses/by/2.0/', u'id': 4, u'name': u'Attribution License'}, {u'url': u'http://creativecommons.org/licenses/by-sa/2.0/', u'id': 5, u'name': u'Attribution-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nd/2.0/', u'id': 6, u'name': u'Attribution-NoDerivs License'}, {u'url': u'http://flickr.com/commons/usage/', u'id': 7, u'name': u'No known copyright restrictions'}, {u'url': u'http://www.usa.gov/copyright.shtml', u'id': 8, u'name': u'United States Government Work'}], u'type': u'captions'}
53 | out.update({'images': [], 'annotations': []})
54 |
55 | cnt = 0
56 | empty_cnt = 0
57 | for i, img in enumerate(imgs):
58 | if img['split'] == 'train':
59 | continue
60 | out['images'].append(
61 | {u'id': img.get('cocoid', img['imgid'])})
62 | for j, s in enumerate(img['sentences']):
63 | if len(s) == 0:
64 | continue
65 | s = ' '.join(s['tokens'])
66 | out['annotations'].append(
67 | {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt})
68 | cnt += 1
69 |
70 | json.dump(out, open(params['output_json'], 'w'))
71 | print('wrote ', params['output_json'])
72 |
73 |
74 | if __name__ == "__main__":
75 |
76 | parser = argparse.ArgumentParser()
77 |
78 | # input json
79 | parser.add_argument('--input_json', nargs='+', required=True,
80 | help='input json file to process into hdf5')
81 | parser.add_argument('--output_json', default='data.json',
82 | help='output json file')
83 |
84 | args = parser.parse_args()
85 | params = vars(args) # convert to ordinary dict
86 | print('parsed input parameters:')
87 | print(json.dumps(params, indent=2))
88 | main(params)
89 |
90 |
--------------------------------------------------------------------------------
/scripts/prepro_feats.py:
--------------------------------------------------------------------------------
1 | """
2 | Precompute ResNet image features (fc and att) for an image captioning dataset
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: two directories of per-image feature files (no hdf5/json is written)
13 | <output_dir>_fc/<cocoid>.npy stores the pooled fc feature vector of each
14 | image (the spatial mean of the last conv layer)
15 | <output_dir>_att/<cocoid>.npz stores the att_size x att_size spatial
16 | feature map of the last conv layer, under the key 'feat'
17 |
18 | Both are saved as float32 numpy arrays. The json file given by
19 | --input_json is only read for the image paths and ids; it is not modified.
20 |
21 | See the --att_size, --model and --model_root options below for choosing
22 | the attention map size and the ResNet variant.
23 |
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import json
32 | import argparse
33 | from random import shuffle, seed
34 | import string
35 | # non-standard dependencies:
36 | import h5py
37 | from six.moves import cPickle
38 | import numpy as np
39 | import torch
40 | import torchvision.models as models
41 | import skimage.io
42 |
43 | from torchvision import transforms as trn
44 | preprocess = trn.Compose([
45 | #trn.ToTensor(),
46 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
47 | ])
48 |
49 | from misc.resnet_utils import myResnet
50 | import misc.resnet as resnet
51 |
52 | def main(params):
53 | net = getattr(resnet, params['model'])()
54 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth')))
55 | my_resnet = myResnet(net)
56 | my_resnet.cuda()
57 | my_resnet.eval()
58 |
59 | imgs = json.load(open(params['input_json'], 'r'))
60 | imgs = imgs['images']
61 | N = len(imgs)
62 |
63 | seed(123) # make reproducible
64 |
65 | dir_fc = params['output_dir']+'_fc'
66 | dir_att = params['output_dir']+'_att'
67 | if not os.path.isdir(dir_fc):
68 | os.mkdir(dir_fc)
69 | if not os.path.isdir(dir_att):
70 | os.mkdir(dir_att)
71 |
72 | for i,img in enumerate(imgs):
73 | # load the image
74 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename']))
75 | # handle grayscale input images
76 | if len(I.shape) == 2:
77 | I = I[:,:,np.newaxis]
78 | I = np.concatenate((I,I,I), axis=2)
79 |
80 | I = I.astype('float32')/255.0
81 | I = torch.from_numpy(I.transpose([2,0,1])).cuda()
82 | I = preprocess(I)
83 | with torch.no_grad():
84 | tmp_fc, tmp_att = my_resnet(I, params['att_size'])
85 | # write to pkl
86 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
87 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
88 |
89 | if i % 1000 == 0:
90 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
91 | print('wrote ', params['output_dir'])
92 |
93 | if __name__ == "__main__":
94 |
95 | parser = argparse.ArgumentParser()
96 |
97 | # input json
98 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
99 | parser.add_argument('--output_dir', default='data', help='output h5 file')
100 |
101 | # options
102 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
103 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
104 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152')
105 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root')
106 |
107 | args = parser.parse_args()
108 | params = vars(args) # convert to ordinary dict
109 | print('parsed input parameters:')
110 | print(json.dumps(params, indent = 2))
111 | main(params)
112 |
--------------------------------------------------------------------------------
/dataloaderraw.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import h5py
7 | import os
8 | import numpy as np
9 | import random
10 | import torch
11 | import skimage
12 | import skimage.io
13 | import scipy.misc
14 |
15 | from torchvision import transforms as trn
16 | preprocess = trn.Compose([
17 | #trn.ToTensor(),
18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19 | ])
20 |
21 | from misc.resnet_utils import myResnet
22 | import misc.resnet
23 |
24 | class DataLoaderRaw():
25 |
26 | def __init__(self, opt):
27 | self.opt = opt
28 | self.coco_json = opt.get('coco_json', '')
29 | self.folder_path = opt.get('folder_path', '')
30 |
31 | self.batch_size = opt.get('batch_size', 1)
32 | self.seq_per_img = 1
33 |
34 | # Load resnet
35 | self.cnn_model = opt.get('cnn_model', 'resnet101')
36 | self.my_resnet = getattr(misc.resnet, self.cnn_model)()
37 | self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth'))
38 | self.my_resnet = myResnet(self.my_resnet)
39 | self.my_resnet.cuda()
40 | self.my_resnet.eval()
41 |
42 |
43 |
44 | # load the json file which contains additional information about the dataset
45 | print('DataLoaderRaw loading images from folder: ', self.folder_path)
46 |
47 | self.files = []
48 | self.ids = []
49 |
50 | print(len(self.coco_json))
51 | if len(self.coco_json) > 0:
52 | print('reading from ' + opt.coco_json)
53 | # read in filenames from the coco-style json file
54 | self.coco_annotation = json.load(open(self.coco_json))
55 | for k,v in enumerate(self.coco_annotation['images']):
56 | fullpath = os.path.join(self.folder_path, v['file_name'])
57 | self.files.append(fullpath)
58 | self.ids.append(v['id'])
59 | else:
60 | # read in all the filenames from the folder
61 | print('listing all images in directory ' + self.folder_path)
62 | def isImage(f):
63 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM']
64 | for ext in supportedExt:
65 | start_idx = f.rfind(ext)
66 | if start_idx >= 0 and start_idx + len(ext) == len(f):
67 | return True
68 | return False
69 |
70 | n = 1
71 | for root, dirs, files in os.walk(self.folder_path, topdown=False):
72 | for file in files:
73 | fullpath = os.path.join(self.folder_path, file)
74 | if isImage(fullpath):
75 | self.files.append(fullpath)
76 | self.ids.append(str(n)) # just order them sequentially
77 | n = n + 1
78 |
79 | self.N = len(self.files)
80 | print('DataLoaderRaw found ', self.N, ' images')
81 |
82 | self.iterator = 0
83 |
84 | def get_batch(self, split, batch_size=None):
85 | batch_size = batch_size or self.batch_size
86 |
87 | # pick an index of the datapoint to load next
88 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
89 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32')
90 | max_index = self.N
91 | wrapped = False
92 | infos = []
93 |
94 | for i in range(batch_size):
95 | ri = self.iterator
96 | ri_next = ri + 1
97 | if ri_next >= max_index:
98 | ri_next = 0
99 | wrapped = True
100 | # wrap back around
101 | self.iterator = ri_next
102 |
103 | img = skimage.io.imread(self.files[ri])
104 |
105 | if len(img.shape) == 2:
106 | img = img[:,:,np.newaxis]
107 | img = np.concatenate((img, img, img), axis=2)
108 |
109 | img = img[:,:,:3].astype('float32')/255.0
110 | img = torch.from_numpy(img.transpose([2,0,1])).cuda()
111 | img = preprocess(img)
112 | with torch.no_grad():
113 | tmp_fc, tmp_att = self.my_resnet(img)
114 |
115 | fc_batch[i] = tmp_fc.data.cpu().float().numpy()
116 | att_batch[i] = tmp_att.data.cpu().float().numpy()
117 |
118 | info_struct = {}
119 | info_struct['id'] = self.ids[ri]
120 | info_struct['file_path'] = self.files[ri]
121 | infos.append(info_struct)
122 |
123 | data = {}
124 | data['fc_feats'] = fc_batch
125 | data['att_feats'] = att_batch.reshape(batch_size, -1, 2048)
126 | data['att_masks'] = None
127 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
128 | data['infos'] = infos
129 |
130 | return data
131 |
132 | def reset_iterator(self, split):
133 | self.iterator = 0
134 |
135 | def get_vocab_size(self):
136 | return len(self.ix_to_word)
137 |
138 | def get_vocab(self):
139 | return self.ix_to_word
140 |
--------------------------------------------------------------------------------
/models/AttEnsemble.py:
--------------------------------------------------------------------------------
1 | # This file contains AttEnsemble, an ensemble wrapper around AttModel-based
2 | # captioning models (e.g. Att2in2, AdaAtt, TopDown, AoA).
3 |
4 | # Each member model keeps its own recurrent state; AttEnsemble packs the
5 | # states into one flat list and, at every decoding step, averages the
6 | # members' per-word probabilities (optionally weighted by `weights`)
7 | # before taking the log.
8 |
9 | # References for typical member models:
10 | # Att2in: Self-critical Sequence Training for Image Captioning
11 | # https://arxiv.org/abs/1612.00563
12 | # TopDown: Bottom-Up and Top-Down Attention for Image Captioning and VQA
13 | # https://arxiv.org/abs/1707.07998
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | from torch.autograd import *
24 | import misc.utils as utils
25 |
26 | from .CaptionModel import CaptionModel
27 | from .AttModel import pack_wrapper, AttModel
28 |
29 | class AttEnsemble(AttModel):
30 | def __init__(self, models, weights=None):
31 | CaptionModel.__init__(self)
32 | # super(AttEnsemble, self).__init__()
33 |
34 | self.models = nn.ModuleList(models)
35 | self.vocab_size = models[0].vocab_size
36 | self.seq_length = models[0].seq_length
37 | self.bad_endings_ix = models[0].bad_endings_ix
38 | self.ss_prob = 0
39 | weights = weights or [1.0] * len(self.models)
40 | self.register_buffer('weights', torch.tensor(weights))
41 |
42 | def init_hidden(self, batch_size):
43 | state = [m.init_hidden(batch_size) for m in self.models]
44 | return self.pack_state(state)
45 |
46 | def pack_state(self, state):
47 | self.state_lengths = [len(_) for _ in state]
48 | return sum([list(_) for _ in state], [])
49 |
50 | def unpack_state(self, state):
51 | out = []
52 | for l in self.state_lengths:
53 | out.append(state[:l])
54 | state = state[l:]
55 | return out
56 |
57 | def embed(self, it):
58 | return [m.embed(it) for m in self.models]
59 |
60 | def core(self, *args):
61 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
62 |
63 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state):
64 | # 'it' contains a word index
65 | xt = self.embed(it)
66 |
67 | state = self.unpack_state(state)
68 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
69 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
70 |
71 | return logprobs, self.pack_state(state)
72 |
73 | def _prepare_feature(self, *args):
74 | return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
75 |
76 | # def _prepare_feature(self, fc_feats, att_feats, att_masks):
77 |
78 | # att_feats, att_masks = self.clip_att(att_feats, att_masks)
79 |
80 | # # embed fc and att feats
81 | # fc_feats = [m.fc_embed(fc_feats) for m in self.models]
82 | # att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models]
83 |
84 | # # Project the attention feats first to reduce memory and computation comsumptions.
85 | # p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)]
86 |
87 | # return fc_feats, att_feats, p_att_feats, [att_masks] * len(self.models)
88 |
89 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
90 | beam_size = opt.get('beam_size', 10)
91 | batch_size = fc_feats.size(0)
92 |
93 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
94 |
95 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
96 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
97 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
98 | # lets process every image independently for now, for simplicity
99 |
100 | self.done_beams = [[] for _ in range(batch_size)]
101 | for k in range(batch_size):
102 | state = self.init_hidden(beam_size)
103 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
104 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
105 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
106 | tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
107 |
108 | it = fc_feats[0].data.new(beam_size).long().zero_()
109 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
110 |
111 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
112 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
113 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
114 | # return the samples and their log likelihoods
115 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
--------------------------------------------------------------------------------
/scripts/prepro_ngrams.py:
--------------------------------------------------------------------------------
1 | """
2 | Precompute n-gram document-frequency statistics from the reference captions
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, maps out-of-vocabulary words to the special UNK token
10 | using an existing vocabulary, and counts n-gram document frequencies over the captions.
11 |
12 | Output: two pickle files, <output_pkl>-words.p and <output_pkl>-idxs.p.
13 | Each contains a dict with two fields:
14 | - 'document_frequency': a mapping from each n-gram (n = 1..4) to the number of images
15 |   whose reference captions contain it. In the -words file the n-grams are tuples of
16 |   word strings; in the -idxs file they are tuples of vocabulary indices (as strings).
17 | - 'ref_len': the number of images (reference sets) the statistics were computed on.
18 | 
19 | The --dict_json file provides:
20 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
21 | - optionally a 'bpe' field holding learned BPE codes; when present, the codes are applied
22 |   to the caption tokens before the n-grams are counted, so the statistics match the
23 |   BPE-segmented labels.
24 | """
25 |
26 | import os
27 | import json
28 | import argparse
29 | from six.moves import cPickle
30 | import misc.utils as utils
31 | from collections import defaultdict
32 |
33 | def precook(s, n=4, out=False):
34 | """
35 | Takes a string as input and returns an object that can be given to
36 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
37 | can take string arguments as well.
38 | :param s: string : sentence to be converted into ngrams
39 | :param n: int : number of ngrams for which representation is calculated
40 |     :return: term frequency vector for occurring ngrams
41 | """
42 | words = s.split()
43 | counts = defaultdict(int)
44 |     for k in range(1,n+1):
45 |         for i in range(len(words)-k+1):
46 | ngram = tuple(words[i:i+k])
47 | counts[ngram] += 1
48 | return counts
49 |
50 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
51 | '''Takes a list of reference sentences for a single segment
52 | and returns an object that encapsulates everything that BLEU
53 | needs to know about them.
54 | :param refs: list of string : reference sentences for some image
55 | :param n: int : number of ngrams for which (ngram) representation is calculated
56 | :return: result (list of dict)
57 | '''
58 | return [precook(ref, n) for ref in refs]
59 |
60 | def create_crefs(refs):
61 | crefs = []
62 | for ref in refs:
63 | # ref is a list of 5 captions
64 | crefs.append(cook_refs(ref))
65 | return crefs
66 |
67 | def compute_doc_freq(crefs):
68 | '''
69 | Compute term frequency for reference data.
70 | This will be used to compute idf (inverse document frequency later)
71 | The term frequency is stored in the object
72 | :return: None
73 | '''
74 | document_frequency = defaultdict(float)
75 | for refs in crefs:
76 | # refs, k ref captions of one image
77 |         for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
78 | document_frequency[ngram] += 1
79 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
80 | return document_frequency
81 |
82 | def build_dict(imgs, wtoi, params):
83 |     wtoi['<eos>'] = 0
84 |
85 | count_imgs = 0
86 |
87 | refs_words = []
88 | refs_idxs = []
89 | for img in imgs:
90 | if (params['split'] == img['split']) or \
91 | (params['split'] == 'train' and img['split'] == 'restval') or \
92 | (params['split'] == 'all'):
93 | #(params['split'] == 'val' and img['split'] == 'restval') or \
94 | ref_words = []
95 | ref_idxs = []
96 | for sent in img['sentences']:
97 |             if 'bpe' in params:
98 |                 sent['tokens'] = params['bpe'].segment(' '.join(sent['tokens'])).strip().split(' ')
99 |             tmp_tokens = sent['tokens'] + ['<eos>']
100 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens]
101 | ref_words.append(' '.join(tmp_tokens))
102 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens]))
103 | refs_words.append(ref_words)
104 | refs_idxs.append(ref_idxs)
105 | count_imgs += 1
106 | print('total imgs:', count_imgs)
107 |
108 | ngram_words = compute_doc_freq(create_crefs(refs_words))
109 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs))
110 | return ngram_words, ngram_idxs, count_imgs
111 |
112 | def main(params):
113 |
114 | imgs = json.load(open(params['input_json'], 'r'))
115 | dict_json = json.load(open(params['dict_json'], 'r'))
116 | itow = dict_json['ix_to_word']
117 | wtoi = {w:i for i,w in itow.items()}
118 |
119 | # Load bpe
120 | if 'bpe' in dict_json:
121 |         import tempfile, codecs
122 |         from subword_nmt import apply_bpe
123 | codes_f = tempfile.NamedTemporaryFile(delete=False)
124 | codes_f.close()
125 | with open(codes_f.name, 'w') as f:
126 | f.write(dict_json['bpe'])
127 | with codecs.open(codes_f.name, encoding='UTF-8') as codes:
128 | bpe = apply_bpe.BPE(codes)
129 |         params['bpe'] = bpe
130 |
131 | imgs = imgs['images']
132 |
133 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)
134 |
135 |     utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb'))
136 |     utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb'))
137 |
138 | if __name__ == "__main__":
139 |
140 | parser = argparse.ArgumentParser()
141 |
142 | # input json
143 | parser.add_argument('--input_json', default='/home-nfs/rluo/rluo/nips/code/prepro/dataset_coco.json', help='input json file to process into hdf5')
144 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='vocab json produced by prepro_labels.py (input)')
145 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file')
146 | parser.add_argument('--split', default='all', help='test, val, train, all')
147 | args = parser.parse_args()
148 | params = vars(args) # convert to ordinary dict
149 |
150 | main(params)
151 |
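152 |
153 | # Example invocation (a sketch -- the paths are illustrative, not mandated by this script):
154 | #   python scripts/prepro_ngrams.py --input_json data/dataset_coco.json \
155 | #       --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
156 | # The resulting '-idxs.p' pickle is the one typically consumed when CIDEr-based rewards
157 | # are computed during self-critical training.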
--------------------------------------------------------------------------------
/eval_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | import numpy as np
9 | import json
10 | from json import encoder
11 | import random
12 | import string
13 | import time
14 | import os
15 | import sys
16 | import misc.utils as utils
17 |
18 | bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
19 | bad_endings += ['the']
20 |
21 | def count_bad(sen):
22 | sen = sen.split(' ')
23 | if sen[-1] in bad_endings:
24 | return 1
25 | else:
26 | return 0
27 |
28 | def language_eval(dataset, preds, model_id, split):
29 | import sys
30 | sys.path.append("coco-caption")
31 | if 'coco' in dataset:
32 | annFile = 'coco-caption/annotations/captions_val2014.json'
33 | elif 'flickr30k' in dataset or 'f30k' in dataset:
34 | annFile = 'coco-caption/f30k_captions4eval.json'
35 | from pycocotools.coco import COCO
36 | from pycocoevalcap.eval import COCOEvalCap
37 |
38 | # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
39 |
40 | if not os.path.isdir('eval_results'):
41 | os.mkdir('eval_results')
42 | cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')
43 |
44 | coco = COCO(annFile)
45 | valids = coco.getImgIds()
46 |
47 | # filter results to only those in MSCOCO validation set (will be about a third)
48 | preds_filt = [p for p in preds if p['image_id'] in valids]
49 | print('using %d/%d predictions' % (len(preds_filt), len(preds)))
50 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
51 |
52 | cocoRes = coco.loadRes(cache_path)
53 | cocoEval = COCOEvalCap(coco, cocoRes)
54 | cocoEval.params['image_id'] = cocoRes.getImgIds()
55 | cocoEval.evaluate()
56 |
57 | # create output dictionary
58 | out = {}
59 | for metric, score in cocoEval.eval.items():
60 | out[metric] = score
61 |
62 | imgToEval = cocoEval.imgToEval
63 | for p in preds_filt:
64 | image_id, caption = p['image_id'], p['caption']
65 | imgToEval[image_id]['caption'] = caption
66 |
67 | out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
68 | outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
69 | with open(outfile_path, 'w') as outfile:
70 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
71 |
72 | return out
73 |
74 | def eval_split(model, crit, loader, eval_kwargs={}):
75 | verbose = eval_kwargs.get('verbose', True)
76 | verbose_beam = eval_kwargs.get('verbose_beam', 1)
77 | verbose_loss = eval_kwargs.get('verbose_loss', 1)
78 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
79 | split = eval_kwargs.get('split', 'val')
80 | lang_eval = eval_kwargs.get('language_eval', 0)
81 | dataset = eval_kwargs.get('dataset', 'coco')
82 | beam_size = eval_kwargs.get('beam_size', 1)
83 | remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
84 | os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
85 |
86 | # Make sure in the evaluation mode
87 | model.eval()
88 |
89 | loader.reset_iterator(split)
90 |
91 | n = 0
92 | loss = 0
93 | loss_sum = 0
94 | loss_evals = 1e-8
95 | predictions = []
96 | while True:
97 | data = loader.get_batch(split)
98 | n = n + loader.batch_size
99 |
100 | if data.get('labels', None) is not None and verbose_loss:
101 | # forward the model to get loss
102 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
103 | tmp = [_.cuda() if _ is not None else _ for _ in tmp]
104 | fc_feats, att_feats, labels, masks, att_masks = tmp
105 |
106 | with torch.no_grad():
107 | loss = crit(model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]).item()
108 | loss_sum = loss_sum + loss
109 | loss_evals = loss_evals + 1
110 |
111 | # forward the model to also get generated samples for each image
112 |         # Only keep one feature per image, to avoid duplicated samples
113 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
114 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
115 | data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None]
116 | tmp = [_.cuda() if _ is not None else _ for _ in tmp]
117 | fc_feats, att_feats, att_masks = tmp
118 | # forward the model to also get generated samples for each image
119 | with torch.no_grad():
120 | seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data
121 |
122 | # Print beam search
123 | if beam_size > 1 and verbose_beam:
124 | for i in range(loader.batch_size):
125 | print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
126 | print('--' * 10)
127 | sents = utils.decode_sequence(loader.get_vocab(), seq)
128 |
129 | for k, sent in enumerate(sents):
130 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
131 | if eval_kwargs.get('dump_path', 0) == 1:
132 | entry['file_name'] = data['infos'][k]['file_path']
133 | predictions.append(entry)
134 | if eval_kwargs.get('dump_images', 0) == 1:
135 | # dump the raw image to vis/ folder
136 | cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
137 | print(cmd)
138 | os.system(cmd)
139 |
140 | if verbose:
141 | print('image %s: %s' %(entry['image_id'], entry['caption']))
142 |
143 | # if we wrapped around the split or used up val imgs budget then bail
144 | ix0 = data['bounds']['it_pos_now']
145 | ix1 = data['bounds']['it_max']
146 | if num_images != -1:
147 | ix1 = min(ix1, num_images)
148 | for i in range(n - ix1):
149 | predictions.pop()
150 |
151 | if verbose:
152 |             print('evaluating validation performance... %d/%d (%f)' %(ix0 - 1, ix1, loss))
153 |
154 | if data['bounds']['wrapped']:
155 | break
156 | if num_images >= 0 and n >= num_images:
157 | break
158 |
159 | lang_stats = None
160 | if lang_eval == 1:
161 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split)
162 |
163 | # Switch back to training mode
164 | model.train()
165 | return loss_sum/loss_evals, predictions, lang_stats
166 |
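167 |
168 | # Usage sketch (nothing new here -- the keys below are exactly the ones read above):
169 | # eval_split(model, crit, loader, eval_kwargs) consumes 'split', 'num_images' (or
170 | # 'val_images_use'), 'dataset', 'beam_size', 'language_eval', 'remove_bad_endings',
171 | # 'verbose'/'verbose_beam'/'verbose_loss', plus 'id' when language_eval == 1, and
172 | # 'dump_images'/'image_root'/'dump_path' when raw images or file paths should be dumped.
173 | # For example (illustrative values):
174 | #   loss, predictions, lang_stats = eval_split(model, crit, loader,
175 | #       {'split': 'val', 'num_images': 5000, 'language_eval': 1, 'id': 'my_model', 'beam_size': 1})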
--------------------------------------------------------------------------------
/scripts/dump_to_lmdb.py:
--------------------------------------------------------------------------------
1 | # copy from https://github.com/Lyken17/Efficient-PyTorch/tools
2 |
3 | import os
4 | import os.path as osp
5 | import sys
6 |
7 | from PIL import Image
8 | import six
9 | import string
10 |
11 | import lmdb
12 | import pickle
13 | import tqdm
14 | import numpy as np
15 | import argparse
16 | import json
17 |
18 | import torch
19 | import torch.utils.data as data
20 | from torch.utils.data import DataLoader
21 |
22 | import csv
23 | csv.field_size_limit(sys.maxsize)
24 | FIELDNAMES = ['image_id', 'status']
25 |
26 | class FolderLMDB(data.Dataset):
27 | def __init__(self, db_path, fn_list=None):
28 | self.db_path = db_path
29 | self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
30 | readonly=True, lock=False,
31 | readahead=False, meminit=False)
32 | if fn_list is not None:
33 | self.length = len(fn_list)
34 | self.keys = fn_list
35 | else:
36 |             raise ValueError('fn_list must be provided')
37 |
38 | def __getitem__(self, index):
39 | env = self.env
40 | with env.begin(write=False) as txn:
41 |             byteflow = txn.get(self.keys[index].encode('ascii'))
42 |
43 | # load image
44 | imgbuf = byteflow
45 | buf = six.BytesIO()
46 | buf.write(imgbuf)
47 | buf.seek(0)
48 | try:
49 | if args.extension == '.npz':
50 | feat = np.load(buf)['feat']
51 | else:
52 | feat = np.load(buf)
53 | except Exception as e:
54 |             print(self.keys[index], e)
55 | return None
56 |
57 | return feat
58 |
59 | def __len__(self):
60 | return self.length
61 |
62 | def __repr__(self):
63 | return self.__class__.__name__ + ' (' + self.db_path + ')'
64 |
65 |
66 | def make_dataset(dir, extension):
67 | images = []
68 | dir = os.path.expanduser(dir)
69 | for root, _, fnames in sorted(os.walk(dir)):
70 | for fname in sorted(fnames):
71 |             if fname.lower().endswith(extension):
72 | path = os.path.join(root, fname)
73 | images.append(path)
74 |
75 | return images
76 |
77 |
78 | def raw_reader(path):
79 | with open(path, 'rb') as f:
80 | bin_data = f.read()
81 | return bin_data
82 |
83 |
84 | def raw_npz_reader(path):
85 | with open(path, 'rb') as f:
86 | bin_data = f.read()
87 | try:
88 | npz_data = np.load(six.BytesIO(bin_data))['feat']
89 | except Exception as e:
90 |         print(path)
91 |         npz_data = None
92 |         print(e)
93 | return bin_data, npz_data
94 |
95 |
96 | def raw_npy_reader(path):
97 | with open(path, 'rb') as f:
98 | bin_data = f.read()
99 | try:
100 | npy_data = np.load(six.BytesIO(bin_data))
101 | except Exception as e:
102 |         print(path)
103 |         npy_data = None
104 |         print(e)
105 | return bin_data, npy_data
106 |
107 |
108 | class Folder(data.Dataset):
109 |
110 | def __init__(self, root, loader, extension, fn_list=None):
111 | super(Folder, self).__init__()
112 | self.root = root
113 | if fn_list:
114 | samples = [os.path.join(root, str(_)+extension) for _ in fn_list]
115 | else:
116 |             samples = make_dataset(self.root, extension)
117 |
118 | self.loader = loader
119 | self.extension = extension
120 | self.samples = samples
121 |
122 | def __getitem__(self, index):
123 | """
124 | Args:
125 | index (int): Index
126 | Returns:
127 |             tuple: (file id without extension,) + whatever self.loader returns (raw bytes, decoded array).
128 | """
129 | path = self.samples[index]
130 | sample = self.loader(path)
131 |
132 | return (path.split('/')[-1].split('.')[0],) + sample
133 |
134 | def __len__(self):
135 | return len(self.samples)
136 |
137 |
138 | def folder2lmdb(dpath, fn_list, write_frequency=5000):
139 | directory = osp.expanduser(osp.join(dpath))
140 | print("Loading dataset from %s" % directory)
141 | if args.extension == '.npz':
142 | dataset = Folder(directory, loader=raw_npz_reader, extension='.npz',
143 | fn_list=fn_list)
144 | else:
145 | dataset = Folder(directory, loader=raw_npy_reader, extension='.npy',
146 | fn_list=fn_list)
147 | data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)
148 |
149 | # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1]))
150 | lmdb_path = osp.join("%s.lmdb" % (directory))
151 | isdir = os.path.isdir(lmdb_path)
152 |
153 | print("Generate LMDB to %s" % lmdb_path)
154 | db = lmdb.open(lmdb_path, subdir=isdir,
155 | map_size=1099511627776 * 2, readonly=False,
156 | meminit=False, map_async=True)
157 |
158 | txn = db.begin(write=True)
159 |
160 |     tsvfile = open(args.output_file, 'a')
161 | writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
162 | names = []
163 | for idx, data in enumerate(tqdm.tqdm(data_loader)):
164 | # print(type(data), data)
165 | name, byte, npz = data[0]
166 | if npz is not None:
167 |             txn.put(name.encode('ascii'), byte)
168 | names.append({'image_id': name, 'status': str(npz is not None)})
169 | if idx % write_frequency == 0:
170 | print("[%d/%d]" % (idx, len(data_loader)))
171 | print('writing')
172 | txn.commit()
173 | txn = db.begin(write=True)
174 | # write in tsv
175 | for name in names:
176 | writer.writerow(name)
177 | names = []
178 | tsvfile.flush()
179 | print('writing finished')
180 |
181 | # finish iterating through dataset
182 | txn.commit()
183 | for name in names:
184 | writer.writerow(name)
185 | tsvfile.flush()
186 | tsvfile.close()
187 |
188 | print("Flushing database ...")
189 | db.sync()
190 | db.close()
191 |
192 | def parse_args():
193 | """
194 | Parse input arguments
195 | """
196 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network')
197 | # parser.add_argument('--json)
198 | parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str)
199 | parser.add_argument('--output_file', default='.dump_cache.tsv', type=str)
200 | parser.add_argument('--folder', default='./data/cocobu_att', type=str)
201 | parser.add_argument('--extension', default='.npz', type=str)
202 |
203 | args = parser.parse_args()
204 | return args
205 |
206 | if __name__ == "__main__":
207 | global args
208 | args = parse_args()
209 |
210 | args.output_file += args.folder.split('/')[-1]
211 | if args.folder.find('/') > 0:
212 | args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file
213 | print(args.output_file)
214 |
215 | img_list = json.load(open(args.input_json, 'r'))['images']
216 | fn_list = [str(_['cocoid']) for _ in img_list]
217 | found_ids = set()
218 | try:
219 | with open(args.output_file, 'r') as tsvfile:
220 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
221 | for item in reader:
222 | if item['status'] == 'True':
223 | found_ids.add(item['image_id'])
224 | except:
225 | pass
226 | fn_list = [_ for _ in fn_list if _ not in found_ids]
227 | folder2lmdb(args.folder, fn_list)
228 |
229 | # Test existing.
230 | found_ids = set()
231 | with open(args.output_file, 'r') as tsvfile:
232 | reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
233 | for item in reader:
234 | if item['status'] == 'True':
235 | found_ids.add(item['image_id'])
236 |
237 | folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids))
238 | data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
239 | for data in tqdm.tqdm(data_loader):
240 | assert data[0] is not None
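241 |
242 | # Example invocation (a sketch; the paths are illustrative, the flags are the ones defined in parse_args):
243 | #   python scripts/dump_to_lmdb.py --folder data/cocobu_att --extension .npz \
244 | #       --input_json data/dataset_coco.json
245 | # This writes the database to <folder>.lmdb (a sibling of --folder) and appends one
246 | # {image_id, status} row per feature to a '.dump_cache.tsv*' file, which is re-read on
247 | # the next run so that features already dumped successfully are skipped.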
--------------------------------------------------------------------------------
/models/ShowTellModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import *
9 | import misc.utils as utils
10 |
11 | from .CaptionModel import CaptionModel
12 |
13 | class ShowTellModel(CaptionModel):
14 | def __init__(self, opt):
15 | super(ShowTellModel, self).__init__()
16 | self.vocab_size = opt.vocab_size
17 | self.input_encoding_size = opt.input_encoding_size
18 | self.rnn_type = opt.rnn_type
19 | self.rnn_size = opt.rnn_size
20 | self.num_layers = opt.num_layers
21 | self.drop_prob_lm = opt.drop_prob_lm
22 | self.seq_length = opt.seq_length
23 | self.fc_feat_size = opt.fc_feat_size
24 |
25 | self.ss_prob = 0.0 # Schedule sampling probability
26 |
27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
31 | self.dropout = nn.Dropout(self.drop_prob_lm)
32 |
33 | self.init_weights()
34 |
35 | def init_weights(self):
36 | initrange = 0.1
37 | self.embed.weight.data.uniform_(-initrange, initrange)
38 | self.logit.bias.data.fill_(0)
39 | self.logit.weight.data.uniform_(-initrange, initrange)
40 |
41 | def init_hidden(self, bsz):
42 | weight = next(self.parameters()).data
43 | if self.rnn_type == 'lstm':
44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size))
46 | else:
47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
48 |
49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
50 | batch_size = fc_feats.size(0)
51 | state = self.init_hidden(batch_size)
52 | outputs = []
53 |
54 | for i in range(seq.size(1)):
55 | if i == 0:
56 | xt = self.img_embed(fc_feats)
57 | else:
58 |                 if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
59 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
60 | sample_mask = sample_prob < self.ss_prob
61 | if sample_mask.sum() == 0:
62 | it = seq[:, i-1].clone()
63 | else:
64 | sample_ind = sample_mask.nonzero().view(-1)
65 | it = seq[:, i-1].data.clone()
66 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
67 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
68 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
69 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
70 | else:
71 | it = seq[:, i-1].clone()
72 | # break if all the sequences end
73 | if i >= 2 and seq[:, i-1].data.sum() == 0:
74 | break
75 | xt = self.embed(it)
76 |
77 | output, state = self.core(xt.unsqueeze(0), state)
78 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
79 | outputs.append(output)
80 |
81 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
82 |
83 | def get_logprobs_state(self, it, state):
84 | # 'it' contains a word index
85 | xt = self.embed(it)
86 |
87 | output, state = self.core(xt.unsqueeze(0), state)
88 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
89 |
90 | return logprobs, state
91 |
92 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
93 | beam_size = opt.get('beam_size', 10)
94 | batch_size = fc_feats.size(0)
95 |
96 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
97 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
98 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
99 | # lets process every image independently for now, for simplicity
100 |
101 | self.done_beams = [[] for _ in range(batch_size)]
102 | for k in range(batch_size):
103 | state = self.init_hidden(beam_size)
104 | for t in range(2):
105 | if t == 0:
106 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
107 | elif t == 1: # input
108 | it = fc_feats.data.new(beam_size).long().zero_()
109 | xt = self.embed(it)
110 |
111 | output, state = self.core(xt.unsqueeze(0), state)
112 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
113 |
114 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
115 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
116 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
117 | # return the samples and their log likelihoods
118 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
119 |
120 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
121 | sample_method = opt.get('sample_method', 'greedy')
122 | beam_size = opt.get('beam_size', 1)
123 | temperature = opt.get('temperature', 1.0)
124 | if beam_size > 1:
125 |             return self._sample_beam(fc_feats, att_feats, att_masks, opt)
126 |
127 | batch_size = fc_feats.size(0)
128 | state = self.init_hidden(batch_size)
129 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
130 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
131 | for t in range(self.seq_length + 2):
132 | if t == 0:
133 | xt = self.img_embed(fc_feats)
134 | else:
135 | if t == 1: # input
136 | it = fc_feats.data.new(batch_size).long().zero_()
137 | xt = self.embed(it)
138 |
139 | output, state = self.core(xt.unsqueeze(0), state)
140 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
141 |
142 | # sample the next word
143 | if t == self.seq_length + 1: # skip if we achieve maximum length
144 | break
145 | if sample_method == 'greedy':
146 | sampleLogprobs, it = torch.max(logprobs.data, 1)
147 | it = it.view(-1).long()
148 | else:
149 | if temperature == 1.0:
150 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
151 | else:
152 | # scale logprobs by temperature
153 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
154 | it = torch.multinomial(prob_prev, 1).cuda()
155 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
156 | it = it.view(-1).long() # and flatten indices for downstream processing
157 |
158 | if t >= 1:
159 | # stop when all finished
160 | if t == 1:
161 | unfinished = it > 0
162 | else:
163 | unfinished = unfinished * (it > 0)
164 | it = it * unfinished.type_as(it)
165 | seq[:,t-1] = it #seq[t] the input of t+2 time step
166 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
167 | if unfinished.sum() == 0:
168 | break
169 |
170 | return seq, seqLogprobs
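171 |
172 | if __name__ == '__main__':
173 |     # Minimal CPU smoke test -- a sketch only. The Namespace covers just the opt fields this
174 |     # file reads in __init__, not the full option set from opts.py. Because of the relative
175 |     # import above, run it from the repo root as a module: `python -m models.ShowTellModel`.
176 |     from argparse import Namespace
177 |     opt = Namespace(vocab_size=100, input_encoding_size=32, rnn_type='lstm', rnn_size=32,
178 |                     num_layers=1, drop_prob_lm=0.0, seq_length=8, fc_feat_size=64)
179 |     model = ShowTellModel(opt)
180 |     fc_feats = torch.randn(2, 64)                 # a batch of 2 fake fc features
181 |     seq = torch.randint(1, 100, (2, 9))           # dummy token ids; in real use column 0 is the BOS/zero token
182 |     logprobs = model._forward(fc_feats, None, seq)
183 |     print(logprobs.shape)                         # (2, 8, vocab_size + 1)
184 |     sampled, sample_logprobs = model._sample(fc_feats, None)
185 |     print(sampled.shape)                          # (2, seq_length); greedy decoding by default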
--------------------------------------------------------------------------------
/scripts/build_bpe_subword_nmt.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in dataloader.py, using BPE tokenization
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: a json file and an hdf5 file
13 | The hdf5 file contains several fields:
14 | (this script does not store raw image pixels; only the label fields below are written)
15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded
16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
17 | first and last indices (in range 1..M) of labels for each image
18 | /label_length stores the length of the sequence for each of the M sequences
19 |
20 | The json file has a dict that contains:
21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
22 | - an 'images' field that is a list holding auxiliary information for each image,
23 | such as in particular the 'split' it was assigned to.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import json
32 | import argparse
33 | from random import shuffle, seed
34 | import string
35 | # non-standard dependencies:
36 | import h5py
37 | import numpy as np
38 | import torch
39 | import torchvision.models as models
40 | import skimage.io
41 | from PIL import Image
42 |
43 | import codecs
44 | import tempfile
45 | from subword_nmt import learn_bpe, apply_bpe
46 |
47 | # python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe
48 |
49 | def build_vocab(imgs, params):
50 | # count up the number of words
51 | captions = []
52 | for img in imgs:
53 | for sent in img['sentences']:
54 | captions.append(' '.join(sent['tokens']))
55 | captions='\n'.join(captions)
56 | all_captions = tempfile.NamedTemporaryFile(delete=False)
57 | all_captions.close()
58 | with open(all_captions.name, 'w') as txt_file:
59 | txt_file.write(captions)
60 |
61 | #
62 | codecs_output = tempfile.NamedTemporaryFile(delete=False)
63 | codecs_output.close()
64 | with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output:
65 | learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count'])
66 |
67 | with codecs.open(codecs_output.name, encoding='UTF-8') as codes:
68 | bpe = apply_bpe.BPE(codes)
69 |
70 | tmp = tempfile.NamedTemporaryFile(delete=False)
71 | tmp.close()
72 |
73 | tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')
74 |
75 | for _, img in enumerate(imgs):
76 | img['final_captions'] = []
77 | for sent in img['sentences']:
78 | txt = ' '.join(sent['tokens'])
79 | txt = bpe.segment(txt).strip()
80 | img['final_captions'].append(txt.split(' '))
81 | tmpout.write(txt)
82 | tmpout.write('\n')
83 | if _ < 20:
84 | print(txt)
85 |
86 | tmpout.close()
87 | tmpin = codecs.open(tmp.name, encoding='UTF-8')
88 |
89 | vocab = learn_bpe.get_vocabulary(tmpin)
90 | vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True)
91 |
92 | # Always insert UNK
93 | print('inserting the special UNK token')
94 | vocab.append('UNK')
95 |
96 | print('Vocab size:', len(vocab))
97 |
98 | os.remove(all_captions.name)
99 | with open(codecs_output.name, 'r') as codes:
100 | bpe = codes.read()
101 | os.remove(codecs_output.name)
102 | os.remove(tmp.name)
103 |
104 | return vocab, bpe
105 |
106 | def encode_captions(imgs, params, wtoi):
107 | """
108 | encode all captions into one large array, which will be 1-indexed.
109 | also produces label_start_ix and label_end_ix which store 1-indexed
110 | and inclusive (Lua-style) pointers to the first and last caption for
111 | each image in the dataset.
112 | """
113 |
114 | max_length = params['max_length']
115 | N = len(imgs)
116 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions
117 |
118 | label_arrays = []
119 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
120 | label_end_ix = np.zeros(N, dtype='uint32')
121 | label_length = np.zeros(M, dtype='uint32')
122 | caption_counter = 0
123 | counter = 1
124 | for i,img in enumerate(imgs):
125 | n = len(img['final_captions'])
126 | assert n > 0, 'error: some image has no captions'
127 |
128 | Li = np.zeros((n, max_length), dtype='uint32')
129 | for j,s in enumerate(img['final_captions']):
130 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
131 | caption_counter += 1
132 | for k,w in enumerate(s):
133 | if k < max_length:
134 | Li[j,k] = wtoi[w]
135 |
136 | # note: word indices are 1-indexed, and captions are padded with zeros
137 | label_arrays.append(Li)
138 | label_start_ix[i] = counter
139 | label_end_ix[i] = counter + n - 1
140 |
141 | counter += n
142 |
143 | L = np.concatenate(label_arrays, axis=0) # put all the labels together
144 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
145 | assert np.all(label_length > 0), 'error: some caption had no words?'
146 |
147 | print('encoded captions to array of size ', L.shape)
148 | return L, label_start_ix, label_end_ix, label_length
149 |
150 | def main(params):
151 |
152 | imgs = json.load(open(params['input_json'], 'r'))
153 | imgs = imgs['images']
154 |
155 | seed(123) # make reproducible
156 |
157 | # create the vocab
158 | vocab, bpe = build_vocab(imgs, params)
159 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
160 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
161 |
162 | # encode captions in large arrays, ready to ship to hdf5 file
163 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
164 |
165 | # create output h5 file
166 | N = len(imgs)
167 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
168 | f_lb.create_dataset("labels", dtype='uint32', data=L)
169 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
170 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
171 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
172 | f_lb.close()
173 |
174 | # create output json file
175 | out = {}
176 | out['ix_to_word'] = itow # encode the (1-indexed) vocab
177 | out['images'] = []
178 | out['bpe'] = bpe
179 | for i,img in enumerate(imgs):
180 |
181 | jimg = {}
182 | jimg['split'] = img['split']
183 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need
184 |     if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
185 |
186 | if params['images_root'] != '':
187 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
188 | jimg['width'], jimg['height'] = _img.size
189 |
190 | out['images'].append(jimg)
191 |
192 | json.dump(out, open(params['output_json'], 'w'))
193 | print('wrote ', params['output_json'])
194 |
195 | if __name__ == "__main__":
196 |
197 | parser = argparse.ArgumentParser()
198 |
199 | # input json
200 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
201 | parser.add_argument('--output_json', default='data.json', help='output json file')
202 | parser.add_argument('--output_h5', default='data', help='output h5 file')
203 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
204 |
205 | # options
206 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
207 | parser.add_argument('--symbol_count', default=10000, type=int, help='number of BPE merge operations (symbols) to learn')
208 |
209 | args = parser.parse_args()
210 | params = vars(args) # convert to ordinary dict
211 | print('parsed input parameters:')
212 | print(json.dumps(params, indent = 2))
213 | main(params)
214 |
215 |
216 |
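217 | # Note: the learned BPE codes are stored verbatim in the output json under out['bpe'].
218 | # Downstream scripts (e.g. scripts/prepro_ngrams.py) re-create apply_bpe.BPE from that
219 | # field so that reference captions are segmented consistently with the encoded labels.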
--------------------------------------------------------------------------------
/scripts/prepro_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in dataloader.py
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: a json file and an hdf5 file
13 | The hdf5 file contains several fields:
14 | (this script does not store raw image pixels; only the label fields below are written)
15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded
16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
17 | first and last indices (in range 1..M) of labels for each image
18 | /label_length stores the length of the sequence for each of the M sequences
19 |
20 | The json file has a dict that contains:
21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
22 | - an 'images' field that is a list holding auxiliary information for each image,
23 | such as in particular the 'split' it was assigned to.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import json
32 | import argparse
33 | from random import shuffle, seed
34 | import string
35 | # non-standard dependencies:
36 | import h5py
37 | import numpy as np
38 | import torch
39 | import torchvision.models as models
40 | import skimage.io
41 | from PIL import Image
42 |
43 | def build_vocab(imgs, params):
44 | count_thr = params['word_count_threshold']
45 |
46 | # count up the number of words
47 | counts = {}
48 | for img in imgs:
49 | for sent in img['sentences']:
50 | for w in sent['tokens']:
51 | counts[w] = counts.get(w, 0) + 1
52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
53 | print('top words and their counts:')
54 | print('\n'.join(map(str,cw[:20])))
55 |
56 | # print some stats
57 | total_words = sum(counts.values())
58 | print('total words:', total_words)
59 | bad_words = [w for w,n in counts.items() if n <= count_thr]
60 | vocab = [w for w,n in counts.items() if n > count_thr]
61 | bad_count = sum(counts[w] for w in bad_words)
62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
63 | print('number of words in vocab would be %d' % (len(vocab), ))
64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
65 |
66 | # lets look at the distribution of lengths as well
67 | sent_lengths = {}
68 | for img in imgs:
69 | for sent in img['sentences']:
70 | txt = sent['tokens']
71 | nw = len(txt)
72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
73 | max_len = max(sent_lengths.keys())
74 | print('max length sentence in raw data: ', max_len)
75 | print('sentence length distribution (count, number of words):')
76 | sum_len = sum(sent_lengths.values())
77 | for i in range(max_len+1):
78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
79 |
80 | # lets now produce the final annotations
81 | if bad_count > 0:
82 | # additional special UNK token we will use below to map infrequent words to
83 | print('inserting the special UNK token')
84 | vocab.append('UNK')
85 |
86 | for img in imgs:
87 | img['final_captions'] = []
88 | for sent in img['sentences']:
89 | txt = sent['tokens']
90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt]
91 | img['final_captions'].append(caption)
92 |
93 | return vocab
94 |
95 | def encode_captions(imgs, params, wtoi):
96 | """
97 | encode all captions into one large array, which will be 1-indexed.
98 | also produces label_start_ix and label_end_ix which store 1-indexed
99 | and inclusive (Lua-style) pointers to the first and last caption for
100 | each image in the dataset.
101 | """
102 |
103 | max_length = params['max_length']
104 | N = len(imgs)
105 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions
106 |
107 | label_arrays = []
108 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
109 | label_end_ix = np.zeros(N, dtype='uint32')
110 | label_length = np.zeros(M, dtype='uint32')
111 | caption_counter = 0
112 | counter = 1
113 | for i,img in enumerate(imgs):
114 | n = len(img['final_captions'])
115 | assert n > 0, 'error: some image has no captions'
116 |
117 | Li = np.zeros((n, max_length), dtype='uint32')
118 | for j,s in enumerate(img['final_captions']):
119 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
120 | caption_counter += 1
121 | for k,w in enumerate(s):
122 | if k < max_length:
123 | Li[j,k] = wtoi[w]
124 |
125 | # note: word indices are 1-indexed, and captions are padded with zeros
126 | label_arrays.append(Li)
127 | label_start_ix[i] = counter
128 | label_end_ix[i] = counter + n - 1
129 |
130 | counter += n
131 |
132 | L = np.concatenate(label_arrays, axis=0) # put all the labels together
133 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
134 | assert np.all(label_length > 0), 'error: some caption had no words?'
135 |
136 | print('encoded captions to array of size ', L.shape)
137 | return L, label_start_ix, label_end_ix, label_length
138 |
139 | def main(params):
140 |
141 | imgs = json.load(open(params['input_json'], 'r'))
142 | imgs = imgs['images']
143 |
144 | seed(123) # make reproducible
145 |
146 | # create the vocab
147 | vocab = build_vocab(imgs, params)
148 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
149 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
150 |
151 | # encode captions in large arrays, ready to ship to hdf5 file
152 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
153 |
154 | # create output h5 file
155 | N = len(imgs)
156 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
157 | f_lb.create_dataset("labels", dtype='uint32', data=L)
158 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
159 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
160 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
161 | f_lb.close()
162 |
163 | # create output json file
164 | out = {}
165 | out['ix_to_word'] = itow # encode the (1-indexed) vocab
166 | out['images'] = []
167 | for i,img in enumerate(imgs):
168 |
169 | jimg = {}
170 | jimg['split'] = img['split']
171 | if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need
172 | if 'cocoid' in img:
173 |       jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
174 | elif 'imgid' in img:
175 | jimg['id'] = img['imgid']
176 |
177 | if params['images_root'] != '':
178 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
179 | jimg['width'], jimg['height'] = _img.size
180 |
181 | out['images'].append(jimg)
182 |
183 | json.dump(out, open(params['output_json'], 'w'))
184 | print('wrote ', params['output_json'])
185 |
186 | if __name__ == "__main__":
187 |
188 | parser = argparse.ArgumentParser()
189 |
190 | # input json
191 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
192 | parser.add_argument('--output_json', default='data.json', help='output json file')
193 | parser.add_argument('--output_h5', default='data', help='output h5 file')
194 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
195 |
196 | # options
197 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
198 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')
199 |
200 | args = parser.parse_args()
201 | params = vars(args) # convert to ordinary dict
202 | print('parsed input parameters:')
203 | print(json.dumps(params, indent = 2))
204 | main(params)
205 |
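206 |
207 | # Example invocation (a sketch; the paths are illustrative):
208 | #   python scripts/prepro_labels.py --input_json data/dataset_coco.json \
209 | #       --output_json data/cocotalk.json --output_h5 data/cocotalk
210 | # This writes data/cocotalk.json (ix_to_word + per-image info) and data/cocotalk_label.h5
211 | # (the encoded labels plus the label_start_ix/label_end_ix/label_length arrays).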
--------------------------------------------------------------------------------
/models/FCModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import *
9 | import misc.utils as utils
10 |
11 | from .CaptionModel import CaptionModel
12 |
13 | class LSTMCore(nn.Module):
14 | def __init__(self, opt):
15 | super(LSTMCore, self).__init__()
16 | self.input_encoding_size = opt.input_encoding_size
17 | self.rnn_size = opt.rnn_size
18 | self.drop_prob_lm = opt.drop_prob_lm
19 |
20 | # Build a LSTM
21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
23 | self.dropout = nn.Dropout(self.drop_prob_lm)
24 |
25 | def forward(self, xt, state):
26 |
27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
29 | sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
33 |
34 | in_transform = torch.max(\
35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform
38 | next_h = out_gate * torch.tanh(next_c)
39 |
40 | output = self.dropout(next_h)
41 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
42 | return output, state
43 |
44 | class FCModel(CaptionModel):
45 | def __init__(self, opt):
46 | super(FCModel, self).__init__()
47 | self.vocab_size = opt.vocab_size
48 | self.input_encoding_size = opt.input_encoding_size
49 | self.rnn_type = opt.rnn_type
50 | self.rnn_size = opt.rnn_size
51 | self.num_layers = opt.num_layers
52 | self.drop_prob_lm = opt.drop_prob_lm
53 | self.seq_length = opt.seq_length
54 | self.fc_feat_size = opt.fc_feat_size
55 |
56 | self.ss_prob = 0.0 # Schedule sampling probability
57 |
58 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
59 | self.core = LSTMCore(opt)
60 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
61 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
62 |
63 | self.init_weights()
64 |
65 | def init_weights(self):
66 | initrange = 0.1
67 | self.embed.weight.data.uniform_(-initrange, initrange)
68 | self.logit.bias.data.fill_(0)
69 | self.logit.weight.data.uniform_(-initrange, initrange)
70 |
71 | def init_hidden(self, bsz):
72 | weight = next(self.parameters())
73 | if self.rnn_type == 'lstm':
74 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
75 | weight.new_zeros(self.num_layers, bsz, self.rnn_size))
76 | else:
77 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
78 |
79 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
80 | batch_size = fc_feats.size(0)
81 | state = self.init_hidden(batch_size)
82 | outputs = []
83 |
84 | for i in range(seq.size(1)):
85 | if i == 0:
86 | xt = self.img_embed(fc_feats)
87 | else:
88 |                 if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
89 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
90 | sample_mask = sample_prob < self.ss_prob
91 | if sample_mask.sum() == 0:
92 | it = seq[:, i-1].clone()
93 | else:
94 | sample_ind = sample_mask.nonzero().view(-1)
95 | it = seq[:, i-1].data.clone()
96 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
97 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
98 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
99 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
100 | else:
101 | it = seq[:, i-1].clone()
102 | # break if all the sequences end
103 | if i >= 2 and seq[:, i-1].sum() == 0:
104 | break
105 | xt = self.embed(it)
106 |
107 | output, state = self.core(xt, state)
108 | output = F.log_softmax(self.logit(output), dim=1)
109 | outputs.append(output)
110 |
111 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
112 |
113 | def get_logprobs_state(self, it, state):
114 |         # 'it' contains a word index
115 | xt = self.embed(it)
116 |
117 | output, state = self.core(xt, state)
118 | logprobs = F.log_softmax(self.logit(output), dim=1)
119 |
120 | return logprobs, state
121 |
122 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
123 | beam_size = opt.get('beam_size', 10)
124 | batch_size = fc_feats.size(0)
125 |
126 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
127 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
128 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
129 | # lets process every image independently for now, for simplicity
130 |
131 | self.done_beams = [[] for _ in range(batch_size)]
132 | for k in range(batch_size):
133 | state = self.init_hidden(beam_size)
134 | for t in range(2):
135 | if t == 0:
136 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
137 | elif t == 1: # input
138 | it = fc_feats.data.new(beam_size).long().zero_()
139 | xt = self.embed(it)
140 |
141 | output, state = self.core(xt, state)
142 | logprobs = F.log_softmax(self.logit(output), dim=1)
143 |
144 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
145 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
146 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
147 | # return the samples and their log likelihoods
148 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
149 |
150 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
151 | sample_method = opt.get('sample_method', 'greedy')
152 | beam_size = opt.get('beam_size', 1)
153 | temperature = opt.get('temperature', 1.0)
154 | if beam_size > 1:
155 |             return self._sample_beam(fc_feats, att_feats, att_masks, opt)
156 |
157 | batch_size = fc_feats.size(0)
158 | state = self.init_hidden(batch_size)
159 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
160 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
161 | for t in range(self.seq_length + 2):
162 | if t == 0:
163 | xt = self.img_embed(fc_feats)
164 | else:
165 | if t == 1: # input
166 | it = fc_feats.data.new(batch_size).long().zero_()
167 | xt = self.embed(it)
168 |
169 | output, state = self.core(xt, state)
170 | logprobs = F.log_softmax(self.logit(output), dim=1)
171 |
172 |             # sample the next word
173 | if t == self.seq_length + 1: # skip if we achieve maximum length
174 | break
175 | if sample_method == 'greedy':
176 | sampleLogprobs, it = torch.max(logprobs.data, 1)
177 | it = it.view(-1).long()
178 | else:
179 | if temperature == 1.0:
180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
181 | else:
182 | # scale logprobs by temperature
183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
184 | it = torch.multinomial(prob_prev, 1).cuda()
185 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
186 | it = it.view(-1).long() # and flatten indices for downstream processing
187 |
188 | if t >= 1:
189 | # stop when all finished
190 | if t == 1:
191 | unfinished = it > 0
192 | else:
193 | unfinished = unfinished * (it > 0)
194 | it = it * unfinished.type_as(it)
195 | seq[:,t-1] = it #seq[t] the input of t+2 time step
196 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
197 | if unfinished.sum() == 0:
198 | break
199 |
200 | return seq, seqLogprobs
201 |
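202 |
203 | if __name__ == '__main__':
204 |     # Quick sanity check of LSTMCore in isolation -- a sketch; the Namespace holds only the
205 |     # fields LSTMCore reads, not the real training options. Run from the repo root as a
206 |     # module (`python -m models.FCModel`) because of the relative imports above.
207 |     from argparse import Namespace
208 |     core = LSTMCore(Namespace(input_encoding_size=16, rnn_size=32, drop_prob_lm=0.0))
209 |     xt = torch.randn(4, 16)                                  # a batch of 4 input embeddings
210 |     state = (torch.zeros(1, 4, 32), torch.zeros(1, 4, 32))   # (h, c), single layer
211 |     out, state = core(xt, state)
212 |     print(out.shape, state[0].shape)                         # (4, 32) and (1, 4, 32)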
--------------------------------------------------------------------------------
/models/AoAModel.py:
--------------------------------------------------------------------------------
1 | # Implementation for paper 'Attention on Attention for Image Captioning'
2 | # https://arxiv.org/abs/1908.06954
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import misc.utils as utils
12 |
13 | from .AttModel import pack_wrapper, AttModel, Attention
14 | from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
15 |
16 | class MultiHeadedDotAttention(nn.Module):
17 | def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
18 | super(MultiHeadedDotAttention, self).__init__()
19 | assert d_model * scale % h == 0
20 | # We assume d_v always equals d_k
21 | self.d_k = d_model * scale // h
22 | self.h = h
23 |
24 | # Do we need to do linear projections on K and V?
25 | self.project_k_v = project_k_v
26 |
27 | # normalize the query?
28 | if norm_q:
29 | self.norm = LayerNorm(d_model)
30 | else:
31 | self.norm = lambda x:x
32 | self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
33 |
34 | # output linear layer after the multi-head attention?
35 | self.output_layer = nn.Linear(d_model * scale, d_model)
36 |
37 | # apply aoa after attention?
38 | self.use_aoa = do_aoa
39 | if self.use_aoa:
40 | self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
41 | # dropout to the input of AoA layer
42 | if dropout_aoa > 0:
43 | self.dropout_aoa = nn.Dropout(p=dropout_aoa)
44 | else:
45 | self.dropout_aoa = lambda x:x
46 |
47 | if self.use_aoa or not use_output_layer:
48 | # AoA doesn't need the output linear layer
49 | del self.output_layer
50 | self.output_layer = lambda x:x
51 |
52 | self.attn = None
53 | self.dropout = nn.Dropout(p=dropout)
54 |
55 | def forward(self, query, value, key, mask=None):
56 | if mask is not None:
57 | if len(mask.size()) == 2:
58 | mask = mask.unsqueeze(-2)
59 | # Same mask applied to all h heads.
60 | mask = mask.unsqueeze(1)
61 |
62 | single_query = 0
63 | if len(query.size()) == 2:
64 | single_query = 1
65 | query = query.unsqueeze(1)
66 |
67 | nbatches = query.size(0)
68 |
69 | query = self.norm(query)
70 |
71 | # Do all the linear projections in batch from d_model => h x d_k
72 | if self.project_k_v == 0:
73 | query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
74 | key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
75 | value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
76 | else:
77 | query_, key_, value_ = \
78 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
79 | for l, x in zip(self.linears, (query, key, value))]
80 |
81 | # Apply attention on all the projected vectors in batch.
82 | x, self.attn = attention(query_, key_, value_, mask=mask,
83 | dropout=self.dropout)
84 |
85 | # "Concat" using a view
86 | x = x.transpose(1, 2).contiguous() \
87 | .view(nbatches, -1, self.h * self.d_k)
88 |
89 | if self.use_aoa:
90 | # Apply AoA
91 | x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
92 | x = self.output_layer(x)
93 |
94 | if single_query:
95 | query = query.squeeze(1)
96 | x = x.squeeze(1)
97 | return x
98 |
99 | class AoA_Refiner_Layer(nn.Module):
100 | def __init__(self, size, self_attn, feed_forward, dropout):
101 | super(AoA_Refiner_Layer, self).__init__()
102 | self.self_attn = self_attn
103 | self.feed_forward = feed_forward
104 | self.use_ff = 0
105 | if self.feed_forward is not None:
106 | self.use_ff = 1
107 | self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
108 | self.size = size
109 |
110 | def forward(self, x, mask):
111 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
112 | return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
113 |
114 | class AoA_Refiner_Core(nn.Module):
115 | def __init__(self, opt):
116 | super(AoA_Refiner_Core, self).__init__()
117 | attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
118 | layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
119 | self.layers = clones(layer, 6)
120 | self.norm = LayerNorm(layer.size)
121 |
122 | def forward(self, x, mask):
123 | for layer in self.layers:
124 | x = layer(x, mask)
125 | return self.norm(x)
126 |
127 | class AoA_Decoder_Core(nn.Module):
128 | def __init__(self, opt):
129 | super(AoA_Decoder_Core, self).__init__()
130 | self.drop_prob_lm = opt.drop_prob_lm
131 | self.d_model = opt.rnn_size
132 | self.use_multi_head = opt.use_multi_head
133 | self.multi_head_scale = opt.multi_head_scale
134 | self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
135 | self.out_res = getattr(opt, 'out_res', 0)
136 | self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
137 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
138 | self.out_drop = nn.Dropout(self.drop_prob_lm)
139 |
140 | if self.decoder_type == 'AoA':
141 | # AoA layer
142 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
143 | elif self.decoder_type == 'LSTM':
144 | # LSTM layer
145 | self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
146 | else:
147 | # Base linear layer
148 | self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
149 |
150 | # if opt.use_multi_head == 1: # TODO, not implemented for now
151 | # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
152 | if opt.use_multi_head == 2:
153 | self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
154 | else:
155 | self.attention = Attention(opt)
156 |
157 | if self.use_ctx_drop:
158 | self.ctx_drop = nn.Dropout(self.drop_prob_lm)
159 | else:
160 | self.ctx_drop = lambda x :x
161 |
162 | def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
163 | # state[0][1] is the context vector at the last step
164 | h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
165 |
166 | if self.use_multi_head == 2:
167 | att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
168 | else:
169 | att = self.attention(h_att, att_feats, p_att_feats, att_masks)
170 |
171 | ctx_input = torch.cat([att, h_att], 1)
172 | if self.decoder_type == 'LSTM':
173 | output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
174 | state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
175 | else:
176 | output = self.att2ctx(ctx_input)
177 | # save the context vector to state[0][1]
178 | state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
179 |
180 | if self.out_res:
181 | # add residual connection
182 | output = output + h_att
183 |
184 | output = self.out_drop(output)
185 | return output, state
186 |
187 | class AoAModel(AttModel):
188 | def __init__(self, opt):
189 | super(AoAModel, self).__init__(opt)
190 | self.num_layers = 2
191 | # mean pooling
192 | self.use_mean_feats = getattr(opt, 'mean_feats', 1)
193 | if opt.use_multi_head == 2:
194 | del self.ctx2att
195 | self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
196 |
197 | if self.use_mean_feats:
198 | del self.fc_embed
199 | if opt.refine:
200 | self.refiner = AoA_Refiner_Core(opt)
201 | else:
202 | self.refiner = lambda x,y : x
203 | self.core = AoA_Decoder_Core(opt)
204 |
205 |
206 | def _prepare_feature(self, fc_feats, att_feats, att_masks):
207 | att_feats, att_masks = self.clip_att(att_feats, att_masks)
208 |
209 | # embed att feats
210 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
211 | att_feats = self.refiner(att_feats, att_masks)
212 |
213 | if self.use_mean_feats:
214 |             # mean pooling
215 | if att_masks is None:
216 | mean_feats = torch.mean(att_feats, dim=1)
217 | else:
218 | mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
219 | else:
220 | mean_feats = self.fc_embed(fc_feats)
221 |
222 | # Project the attention feats first to reduce memory and computation.
223 | p_att_feats = self.ctx2att(att_feats)
224 |
225 | return mean_feats, att_feats, p_att_feats, att_masks
--------------------------------------------------------------------------------
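
The `aoa_layer` above is the Attention-on-Attention step: the attended vector and the (normalized) query are concatenated, projected to `2 * d_model`, and passed through a GLU, which amounts to an "information" vector gated by a sigmoid. Below is a minimal standalone sketch of that computation, assuming `scale=1`; the sizes and variable names are invented for illustration and do not come from the repo.

```
import torch
import torch.nn as nn

d_model = 8
# same structure as self.aoa_layer with scale=1
aoa = nn.Sequential(nn.Linear(2 * d_model, 2 * d_model), nn.GLU())

query = torch.randn(4, d_model)      # decoder query for a batch of 4
attended = torch.randn(4, d_model)   # multi-head attention result

out = aoa(torch.cat([attended, query], dim=-1))

# GLU splits its input in half along the last dim: information i, gate g,
# output i * sigmoid(g). Verify against a manual computation:
i, g = aoa[0](torch.cat([attended, query], dim=-1)).chunk(2, dim=-1)
assert torch.allclose(out, i * torch.sigmoid(g), atol=1e-6)
```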
/misc/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import torch.optim as optim
10 | import os
11 |
12 | import six
13 | from six.moves import cPickle
14 |
15 | bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that']
16 | bad_endings += ['the']
17 |
18 | def pickle_load(f):
19 | """ Load a pickle.
20 | Parameters
21 | ----------
22 | f: file-like object
23 | """
24 | if six.PY3:
25 | return cPickle.load(f, encoding='latin-1')
26 | else:
27 | return cPickle.load(f)
28 |
29 |
30 | def pickle_dump(obj, f):
31 | """ Dump a pickle.
32 | Parameters
33 | ----------
34 |     obj: object to pickle
35 | f: file-like object
36 | """
37 | if six.PY3:
38 | return cPickle.dump(obj, f, protocol=2)
39 | else:
40 | return cPickle.dump(obj, f)
41 |
42 |
43 | def if_use_feat(caption_model):
44 |     # Decide whether to load attention/fc features according to the caption model
45 | if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
46 | use_att, use_fc = False, True
47 | elif caption_model == 'language_model':
48 | use_att, use_fc = False, False
49 | elif caption_model in ['topdown', 'aoa']:
50 | use_fc, use_att = True, True
51 | else:
52 | use_att, use_fc = True, False
53 | return use_fc, use_att
54 |
55 | # Input: seq, N*D tensor, with elements 0 .. vocab_size. 0 is the END token.
56 | def decode_sequence(ix_to_word, seq):
57 | N, D = seq.size()
58 | out = []
59 | for i in range(N):
60 | txt = ''
61 | for j in range(D):
62 | ix = seq[i,j]
63 | if ix > 0 :
64 | if j >= 1:
65 | txt = txt + ' '
66 | txt = txt + ix_to_word[str(ix.item())]
67 | else:
68 | break
69 | if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
70 | flag = 0
71 | words = txt.split(' ')
72 | for j in range(len(words)):
73 | if words[-j-1] not in bad_endings:
74 | flag = -j
75 | break
76 | txt = ' '.join(words[0:len(words)+flag])
77 | out.append(txt.replace('@@ ', ''))
78 | return out
79 |
80 | def to_contiguous(tensor):
81 | if tensor.is_contiguous():
82 | return tensor
83 | else:
84 | return tensor.contiguous()
85 |
86 | class RewardCriterion(nn.Module):
87 | def __init__(self):
88 | super(RewardCriterion, self).__init__()
89 |
90 | def forward(self, input, seq, reward):
91 | input = to_contiguous(input).view(-1)
92 | reward = to_contiguous(reward).view(-1)
93 | mask = (seq>0).float()
94 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1)
95 | output = - input * reward * mask
96 | output = torch.sum(output) / torch.sum(mask)
97 |
98 | return output
99 |
100 | class LanguageModelCriterion(nn.Module):
101 | def __init__(self):
102 | super(LanguageModelCriterion, self).__init__()
103 |
104 | def forward(self, input, target, mask):
105 | # truncate to the same size
106 | target = target[:, :input.size(1)]
107 | mask = mask[:, :input.size(1)]
108 |
109 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
110 | output = torch.sum(output) / torch.sum(mask)
111 |
112 | return output
113 |
114 | class LabelSmoothing(nn.Module):
115 | "Implement label smoothing."
116 | def __init__(self, size=0, padding_idx=0, smoothing=0.0):
117 | super(LabelSmoothing, self).__init__()
118 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
119 | # self.padding_idx = padding_idx
120 | self.confidence = 1.0 - smoothing
121 | self.smoothing = smoothing
122 | # self.size = size
123 | self.true_dist = None
124 |
125 | def forward(self, input, target, mask):
126 | # truncate to the same size
127 | target = target[:, :input.size(1)]
128 | mask = mask[:, :input.size(1)]
129 |
130 | input = to_contiguous(input).view(-1, input.size(-1))
131 | target = to_contiguous(target).view(-1)
132 | mask = to_contiguous(mask).view(-1)
133 |
134 | # assert x.size(1) == self.size
135 | self.size = input.size(1)
136 | # true_dist = x.data.clone()
137 | true_dist = input.data.clone()
138 | # true_dist.fill_(self.smoothing / (self.size - 2))
139 | true_dist.fill_(self.smoothing / (self.size - 1))
140 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
141 | # true_dist[:, self.padding_idx] = 0
142 | # mask = torch.nonzero(target.data == self.padding_idx)
143 | # self.true_dist = true_dist
144 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
145 |
146 | def set_lr(optimizer, lr):
147 | for group in optimizer.param_groups:
148 | group['lr'] = lr
149 |
150 | def get_lr(optimizer):
151 | for group in optimizer.param_groups:
152 | return group['lr']
153 |
154 | def clip_gradient(optimizer, grad_clip):
155 | for group in optimizer.param_groups:
156 | for param in group['params']:
157 | param.grad.data.clamp_(-grad_clip, grad_clip)
158 |
159 | def build_optimizer(params, opt):
160 | if opt.optim == 'rmsprop':
161 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
162 | elif opt.optim == 'adagrad':
163 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
164 | elif opt.optim == 'sgd':
165 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
166 | elif opt.optim == 'sgdm':
167 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
168 | elif opt.optim == 'sgdmom':
169 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
170 | elif opt.optim == 'adam':
171 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
172 | else:
173 | raise Exception("bad option opt.optim: {}".format(opt.optim))
174 |
175 |
176 | def penalty_builder(penalty_config):
177 | if penalty_config == '':
178 | return lambda x,y: y
179 | pen_type, alpha = penalty_config.split('_')
180 | alpha = float(alpha)
181 | if pen_type == 'wu':
182 | return lambda x,y: length_wu(x,y,alpha)
183 | if pen_type == 'avg':
184 | return lambda x,y: length_average(x,y,alpha)
185 |
186 | def length_wu(length, logprobs, alpha=0.):
187 | """
188 | NMT length re-ranking score from
189 | "Google's Neural Machine Translation System" :cite:`wu2016google`.
190 | """
191 |
192 | modifier = (((5 + length) ** alpha) /
193 | ((5 + 1) ** alpha))
194 | return (logprobs / modifier)
195 |
196 | def length_average(length, logprobs, alpha=0.):
197 | """
198 | Returns the average probability of tokens in a sequence.
199 | """
200 | return logprobs / length
201 |
202 |
203 | class NoamOpt(object):
204 | "Optim wrapper that implements rate."
205 | def __init__(self, model_size, factor, warmup, optimizer):
206 | self.optimizer = optimizer
207 | self._step = 0
208 | self.warmup = warmup
209 | self.factor = factor
210 | self.model_size = model_size
211 | self._rate = 0
212 |
213 | def step(self):
214 | "Update parameters and rate"
215 | self._step += 1
216 | rate = self.rate()
217 | for p in self.optimizer.param_groups:
218 | p['lr'] = rate
219 | self._rate = rate
220 | self.optimizer.step()
221 |
222 | def rate(self, step = None):
223 | "Implement `lrate` above"
224 | if step is None:
225 | step = self._step
226 | return self.factor * \
227 | (self.model_size ** (-0.5) *
228 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
229 |
230 | def __getattr__(self, name):
231 | return getattr(self.optimizer, name)
232 |
233 | class ReduceLROnPlateau(object):
234 |     "Optim wrapper around torch's ReduceLROnPlateau scheduler."
235 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
236 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
237 | self.optimizer = optimizer
238 | self.current_lr = get_lr(optimizer)
239 |
240 | def step(self):
241 | "Update parameters and rate"
242 | self.optimizer.step()
243 |
244 | def scheduler_step(self, val):
245 | self.scheduler.step(val)
246 | self.current_lr = get_lr(self.optimizer)
247 |
248 | def state_dict(self):
249 | return {'current_lr':self.current_lr,
250 | 'scheduler_state_dict': self.scheduler.state_dict(),
251 | 'optimizer_state_dict': self.optimizer.state_dict()}
252 |
253 | def load_state_dict(self, state_dict):
254 | if 'current_lr' not in state_dict:
255 |             # it's a plain optimizer state dict
256 |             self.optimizer.load_state_dict(state_dict)
257 |             set_lr(self.optimizer, self.current_lr) # use the lr from the option
258 |         else:
259 |             # it's a scheduler state dict
260 | self.current_lr = state_dict['current_lr']
261 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
262 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
263 | # current_lr is actually useless in this case
264 |
265 | def rate(self, step = None):
266 |         "Copied from NoamOpt for interface parity; relies on NoamOpt attributes and is unused by this wrapper."
267 | if step is None:
268 | step = self._step
269 | return self.factor * \
270 | (self.model_size ** (-0.5) *
271 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
272 |
273 | def __getattr__(self, name):
274 | return getattr(self.optimizer, name)
275 |
276 | def get_std_opt(model, factor=1, warmup=2000):
277 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
278 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
279 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup,
280 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
281 |
282 |
--------------------------------------------------------------------------------
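
`NoamOpt.rate()` above implements the Transformer learning-rate schedule: linear warmup for `warmup` steps followed by inverse-square-root decay, scaled by `factor * model_size ** -0.5`. Here is a small sketch of the same formula as a plain function, using the `get_std_opt` defaults (factor 1, warmup 2000) and an assumed model size of 512; the function name is illustrative only.

```
def noam_rate(step, model_size=512, factor=1.0, warmup=2000):
    # mirrors NoamOpt.rate(): linear warmup, then 1/sqrt(step) decay
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

for step in (1, 500, 2000, 8000, 32000):
    print(step, round(noam_rate(step), 6))
```

The rate peaks at `step == warmup` and decays afterwards, which is why `train.py` restores `optimizer._step` from the saved iteration when resuming.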
/models/OldModel.py:
--------------------------------------------------------------------------------
1 | # This file contains ShowAttendTell and AllImg model
2 |
3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
4 | # https://arxiv.org/abs/1502.03044
5 |
6 | # AllImg is a model where
7 | # img feature is concatenated with word embedding at every time step as the input of lstm
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.autograd import *
16 | import misc.utils as utils
17 |
18 | from .CaptionModel import CaptionModel
19 |
20 | class OldModel(CaptionModel):
21 | def __init__(self, opt):
22 | super(OldModel, self).__init__()
23 | self.vocab_size = opt.vocab_size
24 | self.input_encoding_size = opt.input_encoding_size
25 | self.rnn_type = opt.rnn_type
26 | self.rnn_size = opt.rnn_size
27 | self.num_layers = opt.num_layers
28 | self.drop_prob_lm = opt.drop_prob_lm
29 | self.seq_length = opt.seq_length
30 | self.fc_feat_size = opt.fc_feat_size
31 | self.att_feat_size = opt.att_feat_size
32 |
33 | self.ss_prob = 0.0 # Schedule sampling probability
34 |
35 | self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size
36 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
37 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
38 | self.dropout = nn.Dropout(self.drop_prob_lm)
39 |
40 | self.init_weights()
41 |
42 | def init_weights(self):
43 | initrange = 0.1
44 | self.embed.weight.data.uniform_(-initrange, initrange)
45 | self.logit.bias.data.fill_(0)
46 | self.logit.weight.data.uniform_(-initrange, initrange)
47 |
48 | def init_hidden(self, fc_feats):
49 | image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1)
50 | if self.rnn_type == 'lstm':
51 | return (image_map, image_map)
52 | else:
53 | return image_map
54 |
55 | def forward(self, fc_feats, att_feats, seq):
56 | batch_size = fc_feats.size(0)
57 | state = self.init_hidden(fc_feats)
58 |
59 | outputs = []
60 |
61 | for i in range(seq.size(1) - 1):
62 |             if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
63 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
64 | sample_mask = sample_prob < self.ss_prob
65 | if sample_mask.sum() == 0:
66 | it = seq[:, i].clone()
67 | else:
68 | sample_ind = sample_mask.nonzero().view(-1)
69 | it = seq[:, i].data.clone()
70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
74 | else:
75 | it = seq[:, i].clone()
76 | # break if all the sequences end
77 | if i >= 1 and seq[:, i].sum() == 0:
78 | break
79 |
80 | xt = self.embed(it)
81 |
82 | output, state = self.core(xt, fc_feats, att_feats, state)
83 | output = F.log_softmax(self.logit(self.dropout(output)), dim=1)
84 | outputs.append(output)
85 |
86 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
87 |
88 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state):
89 | # 'it' contains a word index
90 | xt = self.embed(it)
91 |
92 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
93 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
94 |
95 | return logprobs, state
96 |
97 | def sample_beam(self, fc_feats, att_feats, opt={}):
98 | beam_size = opt.get('beam_size', 10)
99 | batch_size = fc_feats.size(0)
100 |
101 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
102 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
103 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
104 |         # let's process every image independently for now, for simplicity
105 |
106 | self.done_beams = [[] for _ in range(batch_size)]
107 | for k in range(batch_size):
108 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size)
109 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous()
110 |
111 | state = self.init_hidden(tmp_fc_feats)
112 |
113 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
114 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
115 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
116 | done_beams = []
117 | for t in range(1):
118 | if t == 0: # input
119 | it = fc_feats.data.new(beam_size).long().zero_()
120 | xt = self.embed(it)
121 |
122 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
123 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
124 |
125 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt)
126 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
127 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
128 | # return the samples and their log likelihoods
129 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
130 |
131 | def sample(self, fc_feats, att_feats, opt={}):
132 | sample_method = opt.get('sample_method', 'greedy')
133 | beam_size = opt.get('beam_size', 1)
134 | temperature = opt.get('temperature', 1.0)
135 | if beam_size > 1:
136 | return self.sample_beam(fc_feats, att_feats, opt)
137 |
138 | batch_size = fc_feats.size(0)
139 | state = self.init_hidden(fc_feats)
140 |
141 | seq = []
142 | seqLogprobs = []
143 | for t in range(self.seq_length + 1):
144 | if t == 0: # input
145 | it = fc_feats.data.new(batch_size).long().zero_()
146 | elif sample_method == 'greedy':
147 | sampleLogprobs, it = torch.max(logprobs.data, 1)
148 | it = it.view(-1).long()
149 | else:
150 | if temperature == 1.0:
151 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
152 | else:
153 | # scale logprobs by temperature
154 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
155 | it = torch.multinomial(prob_prev, 1).cuda()
156 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
157 | it = it.view(-1).long() # and flatten indices for downstream processing
158 |
159 | xt = self.embed(it)
160 |
161 | if t >= 1:
162 | # stop when all finished
163 | if t == 1:
164 | unfinished = it > 0
165 | else:
166 | unfinished = unfinished * (it > 0)
167 | if unfinished.sum() == 0:
168 | break
169 | it = it * unfinished.type_as(it)
170 |                 seq.append(it) # seq[t] is the input at time step t+2
171 | seqLogprobs.append(sampleLogprobs.view(-1))
172 |
173 | output, state = self.core(xt, fc_feats, att_feats, state)
174 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
175 |
176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
177 |
178 |
179 | class ShowAttendTellCore(nn.Module):
180 | def __init__(self, opt):
181 | super(ShowAttendTellCore, self).__init__()
182 | self.input_encoding_size = opt.input_encoding_size
183 | self.rnn_type = opt.rnn_type
184 | self.rnn_size = opt.rnn_size
185 | self.num_layers = opt.num_layers
186 | self.drop_prob_lm = opt.drop_prob_lm
187 | self.fc_feat_size = opt.fc_feat_size
188 | self.att_feat_size = opt.att_feat_size
189 | self.att_hid_size = opt.att_hid_size
190 |
191 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size,
192 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
193 |
194 | if self.att_hid_size > 0:
195 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
196 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
197 | self.alpha_net = nn.Linear(self.att_hid_size, 1)
198 | else:
199 | self.ctx2att = nn.Linear(self.att_feat_size, 1)
200 | self.h2att = nn.Linear(self.rnn_size, 1)
201 |
202 | def forward(self, xt, fc_feats, att_feats, state):
203 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
204 | att = att_feats.view(-1, self.att_feat_size)
205 | if self.att_hid_size > 0:
206 | att = self.ctx2att(att) # (batch * att_size) * att_hid_size
207 | att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size
208 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size
209 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
210 | dot = att + att_h # batch * att_size * att_hid_size
211 | dot = F.tanh(dot) # batch * att_size * att_hid_size
212 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
213 | dot = self.alpha_net(dot) # (batch * att_size) * 1
214 | dot = dot.view(-1, att_size) # batch * att_size
215 | else:
216 |             att = self.ctx2att(att)                             # (batch * att_size) * 1
217 | att = att.view(-1, att_size) # batch * att_size
218 | att_h = self.h2att(state[0][-1]) # batch * 1
219 | att_h = att_h.expand_as(att) # batch * att_size
220 | dot = att_h + att # batch * att_size
221 |
222 | weight = F.softmax(dot, dim=1)
223 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size
224 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
225 |
226 | output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state)
227 | return output.squeeze(0), state
228 |
229 | class AllImgCore(nn.Module):
230 | def __init__(self, opt):
231 | super(AllImgCore, self).__init__()
232 | self.input_encoding_size = opt.input_encoding_size
233 | self.rnn_type = opt.rnn_type
234 | self.rnn_size = opt.rnn_size
235 | self.num_layers = opt.num_layers
236 | self.drop_prob_lm = opt.drop_prob_lm
237 | self.fc_feat_size = opt.fc_feat_size
238 |
239 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size,
240 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
241 |
242 | def forward(self, xt, fc_feats, att_feats, state):
243 | output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state)
244 | return output.squeeze(0), state
245 |
246 | class ShowAttendTellModel(OldModel):
247 | def __init__(self, opt):
248 | super(ShowAttendTellModel, self).__init__(opt)
249 | self.core = ShowAttendTellCore(opt)
250 |
251 | class AllImgModel(OldModel):
252 | def __init__(self, opt):
253 | super(AllImgModel, self).__init__(opt)
254 | self.core = AllImgCore(opt)
255 |
256 |
--------------------------------------------------------------------------------
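
`ShowAttendTellCore.forward` computes standard additive attention: each spatial feature and the LSTM hidden state are projected into `att_hid_size`, combined through a tanh, scored by `alpha_net`, softmax-normalized, and used to take a weighted sum of the features. A condensed sketch of that computation follows, using broadcasting instead of the flatten/expand in the original and made-up sizes; all names below are illustrative.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, att_size, att_feat_size, rnn_size, att_hid_size = 2, 5, 7, 6, 4

ctx2att = nn.Linear(att_feat_size, att_hid_size)
h2att = nn.Linear(rnn_size, att_hid_size)
alpha_net = nn.Linear(att_hid_size, 1)

att_feats = torch.randn(batch, att_size, att_feat_size)  # spatial CNN features
h = torch.randn(batch, rnn_size)                         # current LSTM hidden state

# score every region against the hidden state, normalize, take the weighted sum
dot = torch.tanh(ctx2att(att_feats) + h2att(h).unsqueeze(1))    # batch x att_size x att_hid_size
weight = F.softmax(alpha_net(dot).squeeze(-1), dim=1)           # batch x att_size
att_res = torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)  # batch x att_feat_size
print(att_res.shape)  # torch.Size([2, 7])
```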
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | import numpy as np
10 |
11 | import time
12 | import os
13 | from six.moves import cPickle
14 | import traceback
15 |
16 | import opts
17 | import models
18 | from dataloader import *
19 | import skimage.io
20 | import eval_utils
21 | import misc.utils as utils
22 | from misc.rewards import init_scorer, get_self_critical_reward
23 | from misc.loss_wrapper import LossWrapper
24 |
25 | try:
26 | import tensorboardX as tb
27 | except ImportError:
28 | print("tensorboardX is not installed")
29 | tb = None
30 |
31 | def add_summary_value(writer, key, value, iteration):
32 | if writer:
33 | writer.add_scalar(key, value, iteration)
34 |
35 | def train(opt):
36 | # Deal with feature things before anything
37 | opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
38 | if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
39 |
40 | acc_steps = getattr(opt, 'acc_steps', 1)
41 |
42 | loader = DataLoader(opt)
43 | opt.vocab_size = loader.vocab_size
44 | opt.seq_length = loader.seq_length
45 |
46 | tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)
47 |
48 | infos = {}
49 | histories = {}
50 | if opt.start_from is not None:
51 | # open old infos and check if models are compatible
52 | with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
53 | infos = utils.pickle_load(f)
54 | saved_model_opt = infos['opt']
55 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
56 | for checkme in need_be_same:
57 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
58 |
59 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
60 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
61 | histories = utils.pickle_load(f)
62 | else:
63 | infos['iter'] = 0
64 | infos['epoch'] = 0
65 | infos['iterators'] = loader.iterators
66 | infos['split_ix'] = loader.split_ix
67 | infos['vocab'] = loader.get_vocab()
68 | infos['opt'] = opt
69 |
70 | iteration = infos.get('iter', 0)
71 | epoch = infos.get('epoch', 0)
72 |
73 | val_result_history = histories.get('val_result_history', {})
74 | loss_history = histories.get('loss_history', {})
75 | lr_history = histories.get('lr_history', {})
76 | ss_prob_history = histories.get('ss_prob_history', {})
77 |
78 | loader.iterators = infos.get('iterators', loader.iterators)
79 | loader.split_ix = infos.get('split_ix', loader.split_ix)
80 | if opt.load_best_score == 1:
81 | best_val_score = infos.get('best_val_score', None)
82 |
83 | opt.vocab = loader.get_vocab()
84 | model = models.setup(opt).cuda()
85 | del opt.vocab
86 | dp_model = torch.nn.DataParallel(model)
87 | lw_model = LossWrapper(model, opt)
88 | dp_lw_model = torch.nn.DataParallel(lw_model)
89 |
90 | epoch_done = True
91 |     # Ensure the model is in training mode
92 | dp_lw_model.train()
93 |
94 | if opt.noamopt:
95 | assert opt.caption_model in ['transformer','aoa'], 'noamopt can only work with transformer'
96 | optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
97 | optimizer._step = iteration
98 | elif opt.reduce_on_plateau:
99 | optimizer = utils.build_optimizer(model.parameters(), opt)
100 | optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
101 | else:
102 | optimizer = utils.build_optimizer(model.parameters(), opt)
103 | # Load the optimizer
104 | if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
105 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
106 |
107 |
108 | def save_checkpoint(model, infos, optimizer, histories=None, append=''):
109 | if len(append) > 0:
110 | append = '-' + append
111 |         # create checkpoint_path if it doesn't exist
112 | if not os.path.isdir(opt.checkpoint_path):
113 | os.makedirs(opt.checkpoint_path)
114 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
115 | torch.save(model.state_dict(), checkpoint_path)
116 | print("model saved to {}".format(checkpoint_path))
117 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
118 | torch.save(optimizer.state_dict(), optimizer_path)
119 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
120 | utils.pickle_dump(infos, f)
121 | if histories:
122 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
123 | utils.pickle_dump(histories, f)
124 |
125 | try:
126 | while True:
127 | if epoch_done:
128 | if not opt.noamopt and not opt.reduce_on_plateau:
129 | # Assign the learning rate
130 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
131 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
132 | decay_factor = opt.learning_rate_decay_rate ** frac
133 | opt.current_lr = opt.learning_rate * decay_factor
134 | else:
135 | opt.current_lr = opt.learning_rate
136 | utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
137 | # Assign the scheduled sampling prob
138 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
139 | frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
140 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
141 | model.ss_prob = opt.ss_prob
142 |
143 | # If start self critical training
144 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
145 | sc_flag = True
146 | init_scorer(opt.cached_tokens)
147 | else:
148 | sc_flag = False
149 |
150 | epoch_done = False
151 |
152 | start = time.time()
153 | if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup):
154 | opt.current_lr = opt.learning_rate * (iteration+1) / opt.noamopt_warmup
155 | utils.set_lr(optimizer, opt.current_lr)
156 | # Load data from train split (0)
157 | data = loader.get_batch('train')
158 | print('Read data:', time.time() - start)
159 |
160 | if (iteration % acc_steps == 0):
161 | optimizer.zero_grad()
162 |
163 | torch.cuda.synchronize()
164 | start = time.time()
165 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
166 | tmp = [_ if _ is None else _.cuda() for _ in tmp]
167 | fc_feats, att_feats, labels, masks, att_masks = tmp
168 |
169 | model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)
170 |
171 | loss = model_out['loss'].mean()
172 | loss_sp = loss / acc_steps
173 |
174 | loss_sp.backward()
175 | if ((iteration+1) % acc_steps == 0):
176 | utils.clip_gradient(optimizer, opt.grad_clip)
177 | optimizer.step()
178 | torch.cuda.synchronize()
179 | train_loss = loss.item()
180 | end = time.time()
181 | if not sc_flag:
182 | print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
183 | .format(iteration, epoch, train_loss, end - start))
184 | else:
185 | print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
186 | .format(iteration, epoch, model_out['reward'].mean(), end - start))
187 |
188 | # Update the iteration and epoch
189 | iteration += 1
190 | if data['bounds']['wrapped']:
191 | epoch += 1
192 | epoch_done = True
193 |
194 | # Write the training loss summary
195 | if (iteration % opt.losses_log_every == 0):
196 | add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
197 | if opt.noamopt:
198 | opt.current_lr = optimizer.rate()
199 | elif opt.reduce_on_plateau:
200 | opt.current_lr = optimizer.current_lr
201 | add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
202 | add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
203 | if sc_flag:
204 | add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)
205 |
206 | loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
207 | lr_history[iteration] = opt.current_lr
208 | ss_prob_history[iteration] = model.ss_prob
209 |
210 | # update infos
211 | infos['iter'] = iteration
212 | infos['epoch'] = epoch
213 | infos['iterators'] = loader.iterators
214 | infos['split_ix'] = loader.split_ix
215 |
216 | # make evaluation on validation set, and save model
217 | if (iteration % opt.save_checkpoint_every == 0):
218 | # eval model
219 | eval_kwargs = {'split': 'val',
220 | 'dataset': opt.input_json}
221 | eval_kwargs.update(vars(opt))
222 | val_loss, predictions, lang_stats = eval_utils.eval_split(
223 | dp_model, lw_model.crit, loader, eval_kwargs)
224 |
225 | if opt.reduce_on_plateau:
226 | if 'CIDEr' in lang_stats:
227 | optimizer.scheduler_step(-lang_stats['CIDEr'])
228 | else:
229 | optimizer.scheduler_step(val_loss)
230 | # Write validation result into summary
231 | add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
232 | if lang_stats is not None:
233 | for k,v in lang_stats.items():
234 | add_summary_value(tb_summary_writer, k, v, iteration)
235 | val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
236 |
237 | # Save model if is improving on validation result
238 | if opt.language_eval == 1:
239 | current_score = lang_stats['CIDEr']
240 | else:
241 | current_score = - val_loss
242 |
243 | best_flag = False
244 |
245 | if best_val_score is None or current_score > best_val_score:
246 | best_val_score = current_score
247 | best_flag = True
248 |
249 |                     # Dump miscellaneous information
250 | infos['best_val_score'] = best_val_score
251 | histories['val_result_history'] = val_result_history
252 | histories['loss_history'] = loss_history
253 | histories['lr_history'] = lr_history
254 | histories['ss_prob_history'] = ss_prob_history
255 |
256 | save_checkpoint(model, infos, optimizer, histories)
257 | if opt.save_history_ckpt:
258 | save_checkpoint(model, infos, optimizer, append=str(iteration))
259 |
260 | if best_flag:
261 | save_checkpoint(model, infos, optimizer, append='best')
262 |
263 | # Stop if reaching max epochs
264 | if epoch >= opt.max_epochs and opt.max_epochs != -1:
265 | break
266 | except (RuntimeError, KeyboardInterrupt):
267 | print('Save ckpt on exception ...')
268 | save_checkpoint(model, infos, optimizer)
269 | print('Save ckpt done.')
270 | stack_trace = traceback.format_exc()
271 | print(stack_trace)
272 |
273 |
274 | opt = opts.parse_opt()
275 | train(opt)
276 |
--------------------------------------------------------------------------------
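
The training loop above accumulates gradients over `acc_steps` mini-batches: the optimizer is zeroed every `acc_steps` iterations, each loss is divided by `acc_steps` before `backward()`, and the clipped step only happens once the window is full. A toy sketch of the same pattern, with `clip_grad_value_` standing in for `utils.clip_gradient` and all sizes invented:

```
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
acc_steps = 4

for iteration in range(16):
    x, y = torch.randn(8, 10), torch.randn(8, 1)

    if iteration % acc_steps == 0:
        optimizer.zero_grad()

    loss = F.mse_loss(model(x), y)
    (loss / acc_steps).backward()   # scale so the accumulated gradient averages the losses

    if (iteration + 1) % acc_steps == 0:
        torch.nn.utils.clip_grad_value_(model.parameters(), 0.1)
        optimizer.step()
```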
/models/CaptionModel.py:
--------------------------------------------------------------------------------
1 | # This file contains the CaptionModel base class shared by all captioning models
2 | 
3 | # It implements beam search (including diverse beam search with a per-group
4 | # diversity penalty) and next-word sampling strategies:
5 | # greedy, temperature, gumbel-softmax, top-k and nucleus sampling.
6 | 
7 | # Concrete models such as the FC, attention and Transformer models subclass it.
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.autograd import *
17 | import misc.utils as utils
18 |
19 | from functools import reduce
20 |
21 |
22 | class CaptionModel(nn.Module):
23 | def __init__(self):
24 | super(CaptionModel, self).__init__()
25 |
26 | # implements beam search
27 | # calls beam_step and returns the final set of beams
28 | # augments log-probabilities with diversity terms when number of groups > 1
29 |
30 | def forward(self, *args, **kwargs):
31 | mode = kwargs.get('mode', 'forward')
32 | if 'mode' in kwargs:
33 | del kwargs['mode']
34 | return getattr(self, '_'+mode)(*args, **kwargs)
35 |
36 | def beam_search(self, init_state, init_logprobs, *args, **kwargs):
37 |
38 | # function computes the similarity score to be augmented
39 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
40 | local_time = t - divm
41 | unaug_logprobsf = logprobsf.clone()
42 | for prev_choice in range(divm):
43 | prev_decisions = beam_seq_table[prev_choice][local_time]
44 | for sub_beam in range(bdash):
45 | for prev_labels in range(bdash):
46 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
47 | return unaug_logprobsf
48 |
49 | # does one step of classical beam search
50 |
51 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
52 | #INPUTS:
53 | #logprobsf: probabilities augmented after diversity
54 |             #beam_size: number of beams to expand and keep
55 |             #t : time instant
56 |             #beam_seq : tensor containing the beams
57 |             #beam_seq_logprobs: tensor containing the beam logprobs
58 |             #beam_logprobs_sum: tensor containing joint logprobs
59 |             #OUTPUTS:
60 | #beam_seq : tensor containing the word indices of the decoded captions
61 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
62 | #beam_logprobs_sum : joint log-probability of each beam
63 |
64 | ys,ix = torch.sort(logprobsf,1,True)
65 | candidates = []
66 | cols = min(beam_size, ys.size(1))
67 | rows = beam_size
68 | if t == 0:
69 | rows = 1
70 | for c in range(cols): # for each column (word, essentially)
71 | for q in range(rows): # for each beam expansion
72 | #compute logprob of expanding beam q with word in (sorted) position c
73 | local_logprob = ys[q,c].item()
74 | candidate_logprob = beam_logprobs_sum[q] + local_logprob
75 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
76 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob})
77 | candidates = sorted(candidates, key=lambda x: -x['p'])
78 |
79 | new_state = [_.clone() for _ in state]
80 | #beam_seq_prev, beam_seq_logprobs_prev
81 | if t >= 1:
82 |                 #we'll need these as reference when we fork beams around
83 | beam_seq_prev = beam_seq[:t].clone()
84 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
85 | for vix in range(beam_size):
86 | v = candidates[vix]
87 | #fork beam index q into index vix
88 | if t >= 1:
89 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
90 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
91 | #rearrange recurrent states
92 | for state_ix in range(len(new_state)):
93 | # copy over state in previous beam q to new beam at vix
94 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
95 | #append new end terminal at the end of this beam
96 | beam_seq[t, vix] = v['c'] # c'th word is the continuation
97 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
98 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
99 | state = new_state
100 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
101 |
102 | # Start diverse_beam_search
103 | opt = kwargs['opt']
104 | temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
105 | beam_size = opt.get('beam_size', 10)
106 | group_size = opt.get('group_size', 1)
107 | diversity_lambda = opt.get('diversity_lambda', 0.5)
108 | decoding_constraint = opt.get('decoding_constraint', 0)
109 | remove_bad_endings = opt.get('remove_bad_endings', 0)
110 | length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
111 | bdash = beam_size // group_size # beam per group
112 |
113 | # INITIALIZATIONS
114 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
115 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
116 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
117 |
118 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
119 | done_beams_table = [[] for _ in range(group_size)]
120 | # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
121 | state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
122 | logprobs_table = list(init_logprobs.chunk(group_size, 0))
123 | # END INIT
124 |
125 | # Chunk elements in the args
126 | args = list(args)
127 | if self.__class__.__name__ == 'AttEnsemble':
128 | args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
129 | args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
130 | else:
131 | args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
132 | args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
133 |
134 | for t in range(self.seq_length + group_size - 1):
135 | for divm in range(group_size):
136 | if t >= divm and t <= self.seq_length + divm - 1:
137 | # add diversity
138 | logprobsf = logprobs_table[divm].data.float()
139 | # suppress previous word
140 | if decoding_constraint and t-divm > 0:
141 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf'))
142 | if remove_bad_endings and t-divm > 0:
143 | logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix).astype('uint8')), 0] = float('-inf')
144 | # suppress UNK tokens in the decoding
145 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
146 | # diversity is added here
147 | # the function directly modifies the logprobsf values and hence, we need to return
148 | # the unaugmented ones for sorting the candidates in the end. # for historical
149 | # reasons :-)
150 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
151 |
152 | # infer new beams
153 | beam_seq_table[divm],\
154 | beam_seq_logprobs_table[divm],\
155 | beam_logprobs_sum_table[divm],\
156 | state_table[divm],\
157 | candidates_divm = beam_step(logprobsf,
158 | unaug_logprobsf,
159 | bdash,
160 | t-divm,
161 | beam_seq_table[divm],
162 | beam_seq_logprobs_table[divm],
163 | beam_logprobs_sum_table[divm],
164 | state_table[divm])
165 |
166 | # if time's up... or if end token is reached then copy beams
167 | for vix in range(bdash):
168 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1:
169 | final_beam = {
170 | 'seq': beam_seq_table[divm][:, vix].clone(),
171 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
172 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
173 | 'p': beam_logprobs_sum_table[divm][vix].item()
174 | }
175 | final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
176 | done_beams_table[divm].append(final_beam)
177 | # don't continue beams from finished sequences
178 | beam_logprobs_sum_table[divm][vix] = -1000
179 |
180 | # move the current group one step forward in time
181 |
182 | it = beam_seq_table[divm][t-divm]
183 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
184 | logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
185 |
186 | # all beams are sorted by their log-probabilities
187 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
188 | done_beams = reduce(lambda a,b:a+b, done_beams_table)
189 | return done_beams
190 |
191 |
192 | def sample_next_word(self, logprobs, sample_method, temperature):
193 | if sample_method == 'greedy':
194 | sampleLogprobs, it = torch.max(logprobs.data, 1)
195 | it = it.view(-1).long()
196 | elif sample_method == 'gumbel': # gumbel softmax
197 | # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
198 | def sample_gumbel(shape, eps=1e-20):
199 | U = torch.rand(shape).cuda()
200 | return -torch.log(-torch.log(U + eps) + eps)
201 | def gumbel_softmax_sample(logits, temperature):
202 | y = logits + sample_gumbel(logits.size())
203 | return F.log_softmax(y / temperature, dim=-1)
204 | _logprobs = gumbel_softmax_sample(logprobs, temperature)
205 | _, it = torch.max(_logprobs.data, 1)
206 | sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
207 | else:
208 | logprobs = logprobs / temperature
209 | if sample_method.startswith('top'): # topk sampling
210 | top_num = float(sample_method[3:])
211 | if 0 < top_num < 1:
212 | # nucleus sampling from # The Curious Case of Neural Text Degeneration
213 | probs = F.softmax(logprobs, dim=1)
214 | sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
215 | _cumsum = sorted_probs.cumsum(1)
216 | mask = _cumsum < top_num
217 | mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
218 | sorted_probs = sorted_probs * mask.float()
219 | sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
220 | logprobs.scatter_(1, sorted_indices, sorted_probs.log())
221 | else:
222 | the_k = int(top_num)
223 | tmp = torch.empty_like(logprobs).fill_(float('-inf'))
224 | topk, indices = torch.topk(logprobs, the_k, dim=1)
225 | tmp = tmp.scatter(1, indices, topk)
226 | logprobs = tmp
227 | it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
228 | sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
229 | return it, sampleLogprobs
--------------------------------------------------------------------------------
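
`sample_next_word` supports `sample_method='topX'`: with an integer X it keeps the top-k logits, and with 0 < X < 1 it performs nucleus (top-p) sampling, keeping the smallest prefix of sorted words whose cumulative probability stays below X (always including the top word) and renormalizing. A standalone restatement of the nucleus branch follows; the helper name and sizes are invented for illustration.

```
import torch
import torch.nn.functional as F

def nucleus_sample(logprobs, top_p=0.9):
    probs = F.softmax(logprobs, dim=1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=1)
    keep = sorted_probs.cumsum(1) < top_p
    keep = torch.cat([torch.ones_like(keep[:, :1]), keep[:, :-1]], 1)  # always keep the best word
    sorted_probs = sorted_probs * keep.float()
    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
    # scatter the truncated, renormalized probabilities back to vocabulary order
    renorm = torch.zeros_like(probs).scatter_(1, sorted_idx, sorted_probs)
    return torch.multinomial(renorm, 1).squeeze(1)

logprobs = torch.log_softmax(torch.randn(3, 11), dim=1)  # 3 beams, vocabulary of 11
print(nucleus_sample(logprobs))
```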
/models/TransformerModel.py:
--------------------------------------------------------------------------------
1 | # This file contains Transformer network
2 | # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3 |
4 | # The cfg name correspondence:
5 | # N=num_layers
6 | # d_model=input_encoding_size
7 | # d_ff=rnn_size
8 | # h is always 8
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import misc.utils as utils
18 |
19 | import copy
20 | import math
21 | import numpy as np
22 |
23 | from .CaptionModel import CaptionModel
24 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25 |
26 | class EncoderDecoder(nn.Module):
27 | """
28 | A standard Encoder-Decoder architecture. Base for this and many
29 | other models.
30 | """
31 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32 | super(EncoderDecoder, self).__init__()
33 | self.encoder = encoder
34 | self.decoder = decoder
35 | self.src_embed = src_embed
36 | self.tgt_embed = tgt_embed
37 | self.generator = generator
38 |
39 | def forward(self, src, tgt, src_mask, tgt_mask):
40 | "Take in and process masked src and target sequences."
41 | return self.decode(self.encode(src, src_mask), src_mask,
42 | tgt, tgt_mask)
43 |
44 | def encode(self, src, src_mask):
45 | return self.encoder(self.src_embed(src), src_mask)
46 |
47 | def decode(self, memory, src_mask, tgt, tgt_mask):
48 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
49 |
50 | class Generator(nn.Module):
51 | "Define standard linear + softmax generation step."
52 | def __init__(self, d_model, vocab):
53 | super(Generator, self).__init__()
54 | self.proj = nn.Linear(d_model, vocab)
55 |
56 | def forward(self, x):
57 | return F.log_softmax(self.proj(x), dim=-1)
58 |
59 | def clones(module, N):
60 | "Produce N identical layers."
61 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62 |
63 | class Encoder(nn.Module):
64 | "Core encoder is a stack of N layers"
65 | def __init__(self, layer, N):
66 | super(Encoder, self).__init__()
67 | self.layers = clones(layer, N)
68 | self.norm = LayerNorm(layer.size)
69 |
70 | def forward(self, x, mask):
71 | "Pass the input (and mask) through each layer in turn."
72 | for layer in self.layers:
73 | x = layer(x, mask)
74 | return self.norm(x)
75 |
76 | class LayerNorm(nn.Module):
77 | "Construct a layernorm module (See citation for details)."
78 | def __init__(self, features, eps=1e-6):
79 | super(LayerNorm, self).__init__()
80 | self.a_2 = nn.Parameter(torch.ones(features))
81 | self.b_2 = nn.Parameter(torch.zeros(features))
82 | self.eps = eps
83 |
84 | def forward(self, x):
85 | mean = x.mean(-1, keepdim=True)
86 | std = x.std(-1, keepdim=True)
87 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88 |
89 | class SublayerConnection(nn.Module):
90 | """
91 | A residual connection followed by a layer norm.
92 | Note for code simplicity the norm is first as opposed to last.
93 | """
94 | def __init__(self, size, dropout):
95 | super(SublayerConnection, self).__init__()
96 | self.norm = LayerNorm(size)
97 | self.dropout = nn.Dropout(dropout)
98 |
99 | def forward(self, x, sublayer):
100 | "Apply residual connection to any sublayer with the same size."
101 | return x + self.dropout(sublayer(self.norm(x)))
102 |
103 | class EncoderLayer(nn.Module):
104 | "Encoder is made up of self-attn and feed forward (defined below)"
105 | def __init__(self, size, self_attn, feed_forward, dropout):
106 | super(EncoderLayer, self).__init__()
107 | self.self_attn = self_attn
108 | self.feed_forward = feed_forward
109 | self.sublayer = clones(SublayerConnection(size, dropout), 2)
110 | self.size = size
111 |
112 | def forward(self, x, mask):
113 | "Follow Figure 1 (left) for connections."
114 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
115 | return self.sublayer[1](x, self.feed_forward)
116 |
117 | class Decoder(nn.Module):
118 | "Generic N layer decoder with masking."
119 | def __init__(self, layer, N):
120 | super(Decoder, self).__init__()
121 | self.layers = clones(layer, N)
122 | self.norm = LayerNorm(layer.size)
123 |
124 | def forward(self, x, memory, src_mask, tgt_mask):
125 | for layer in self.layers:
126 | x = layer(x, memory, src_mask, tgt_mask)
127 | return self.norm(x)
128 |
129 | class DecoderLayer(nn.Module):
130 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
131 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
132 | super(DecoderLayer, self).__init__()
133 | self.size = size
134 | self.self_attn = self_attn
135 | self.src_attn = src_attn
136 | self.feed_forward = feed_forward
137 | self.sublayer = clones(SublayerConnection(size, dropout), 3)
138 |
139 | def forward(self, x, memory, src_mask, tgt_mask):
140 | "Follow Figure 1 (right) for connections."
141 | m = memory
142 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
143 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
144 | return self.sublayer[2](x, self.feed_forward)
145 |
146 | def subsequent_mask(size):
147 | "Mask out subsequent positions."
148 | attn_shape = (1, size, size)
149 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
150 | return torch.from_numpy(subsequent_mask) == 0
151 |
152 | def attention(query, key, value, mask=None, dropout=None):
153 | "Compute 'Scaled Dot Product Attention'"
154 | d_k = query.size(-1)
155 | scores = torch.matmul(query, key.transpose(-2, -1)) \
156 | / math.sqrt(d_k)
157 | if mask is not None:
158 | scores = scores.masked_fill(mask == 0, -1e9)
159 | p_attn = F.softmax(scores, dim = -1)
160 | if dropout is not None:
161 | p_attn = dropout(p_attn)
162 | return torch.matmul(p_attn, value), p_attn
163 |
164 | class MultiHeadedAttention(nn.Module):
165 | def __init__(self, h, d_model, dropout=0.1):
166 | "Take in model size and number of heads."
167 | super(MultiHeadedAttention, self).__init__()
168 | assert d_model % h == 0
169 | # We assume d_v always equals d_k
170 | self.d_k = d_model // h
171 | self.h = h
172 | self.linears = clones(nn.Linear(d_model, d_model), 4)
173 | self.attn = None
174 | self.dropout = nn.Dropout(p=dropout)
175 |
176 | def forward(self, query, key, value, mask=None):
177 | "Implements Figure 2"
178 | if mask is not None:
179 | # Same mask applied to all h heads.
180 | mask = mask.unsqueeze(1)
181 | nbatches = query.size(0)
182 |
183 | # 1) Do all the linear projections in batch from d_model => h x d_k
184 | query, key, value = \
185 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
186 | for l, x in zip(self.linears, (query, key, value))]
187 |
188 | # 2) Apply attention on all the projected vectors in batch.
189 | x, self.attn = attention(query, key, value, mask=mask,
190 | dropout=self.dropout)
191 |
192 | # 3) "Concat" using a view and apply a final linear.
193 | x = x.transpose(1, 2).contiguous() \
194 | .view(nbatches, -1, self.h * self.d_k)
195 | return self.linears[-1](x)
196 |
197 | class PositionwiseFeedForward(nn.Module):
198 | "Implements FFN equation."
199 | def __init__(self, d_model, d_ff, dropout=0.1):
200 | super(PositionwiseFeedForward, self).__init__()
201 | self.w_1 = nn.Linear(d_model, d_ff)
202 | self.w_2 = nn.Linear(d_ff, d_model)
203 | self.dropout = nn.Dropout(dropout)
204 |
205 | def forward(self, x):
206 | return self.w_2(self.dropout(F.relu(self.w_1(x))))
207 |
208 | class Embeddings(nn.Module):
209 | def __init__(self, d_model, vocab):
210 | super(Embeddings, self).__init__()
211 | self.lut = nn.Embedding(vocab, d_model)
212 | self.d_model = d_model
213 |
214 | def forward(self, x):
215 | return self.lut(x) * math.sqrt(self.d_model)
216 |
217 | class PositionalEncoding(nn.Module):
218 | "Implement the PE function."
219 | def __init__(self, d_model, dropout, max_len=5000):
220 | super(PositionalEncoding, self).__init__()
221 | self.dropout = nn.Dropout(p=dropout)
222 |
223 | # Compute the positional encodings once in log space.
224 | pe = torch.zeros(max_len, d_model)
225 | position = torch.arange(0, max_len).unsqueeze(1).float()
226 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
227 | -(math.log(10000.0) / d_model))
228 | pe[:, 0::2] = torch.sin(position * div_term)
229 | pe[:, 1::2] = torch.cos(position * div_term)
230 | pe = pe.unsqueeze(0)
231 | self.register_buffer('pe', pe)
232 |
233 | def forward(self, x):
234 | x = x + self.pe[:, :x.size(1)]
235 | return self.dropout(x)
236 |
237 | class TransformerModel(AttModel):
238 |
239 | def make_model(self, src_vocab, tgt_vocab, N=6,
240 | d_model=512, d_ff=2048, h=8, dropout=0.1):
241 | "Helper: Construct a model from hyperparameters."
242 | c = copy.deepcopy
243 | attn = MultiHeadedAttention(h, d_model)
244 | ff = PositionwiseFeedForward(d_model, d_ff, dropout)
245 | position = PositionalEncoding(d_model, dropout)
246 | model = EncoderDecoder(
247 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
248 | Decoder(DecoderLayer(d_model, c(attn), c(attn),
249 | c(ff), dropout), N),
250 | lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
251 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
252 | Generator(d_model, tgt_vocab))
253 |
254 | # This was important from their code.
255 | # Initialize parameters with Glorot / fan_avg.
256 | for p in model.parameters():
257 | if p.dim() > 1:
258 | nn.init.xavier_uniform_(p)
259 | return model
260 |
261 | def __init__(self, opt):
262 | super(TransformerModel, self).__init__(opt)
263 | self.opt = opt
264 | # self.config = yaml.load(open(opt.config_file))
265 | # d_model = self.input_encoding_size # 512
266 |
267 | delattr(self, 'att_embed')
268 | self.att_embed = nn.Sequential(*(
269 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
270 | (nn.Linear(self.att_feat_size, self.input_encoding_size),
271 | nn.ReLU(),
272 | nn.Dropout(self.drop_prob_lm))+
273 | ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ())))
274 |
275 | delattr(self, 'embed')
276 | self.embed = lambda x : x
277 | delattr(self, 'fc_embed')
278 | self.fc_embed = lambda x : x
279 | delattr(self, 'logit')
280 | del self.ctx2att
281 |
282 | tgt_vocab = self.vocab_size + 1
283 | self.model = self.make_model(0, tgt_vocab,
284 | N=opt.num_layers,
285 | d_model=opt.input_encoding_size,
286 | d_ff=opt.rnn_size)
287 |
288 |     def logit(self, x): # hacky: reuse the generator's output projection directly
289 | return self.model.generator.proj(x)
290 |
291 | def init_hidden(self, bsz):
292 | return []
293 |
294 | def _prepare_feature(self, fc_feats, att_feats, att_masks):
295 |
296 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
297 | memory = self.model.encode(att_feats, att_masks)
298 |
299 | return fc_feats[...,:1], att_feats[...,:1], memory, att_masks
300 |
301 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
302 | att_feats, att_masks = self.clip_att(att_feats, att_masks)
303 |
304 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
305 |
306 | if att_masks is None:
307 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
308 | att_masks = att_masks.unsqueeze(-2)
309 |
310 | if seq is not None:
311 | # crop the last one
312 | seq = seq[:,:-1]
313 | seq_mask = (seq.data > 0)
314 |             seq_mask[:,0] = 1  # make sure the first (<bos>) position is never masked out
315 |
316 | seq_mask = seq_mask.unsqueeze(-2)
317 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
318 | else:
319 | seq_mask = None
320 |
321 | return att_feats, seq, att_masks, seq_mask
322 |
323 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
324 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
325 |
326 | out = self.model(att_feats, seq, att_masks, seq_mask)
327 |
328 | outputs = self.model.generator(out)
329 | return outputs
330 | # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
331 |
332 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
333 | """
334 | state = [ys.unsqueeze(0)]
335 | """
336 | if len(state) == 0:
337 | ys = it.unsqueeze(1)
338 | else:
339 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
340 | out = self.model.decode(memory, mask,
341 | ys,
342 | subsequent_mask(ys.size(1))
343 | .to(memory.device))
344 | return out[:, -1], [ys.unsqueeze(0)]
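345 | 
346 | # --- Illustrative sketch (added for exposition; not part of the original model code).
347 | # It shows how `_prepare_feature`, `core` and `logit` above fit together during
348 | # greedy decoding; it mirrors AttModel's sampling loop, and the function name and
349 | # `max_length` argument are assumptions for illustration only.
350 | def greedy_decode_sketch(model, fc_feats, att_feats, att_masks=None, max_length=20):
351 |     fc, att, memory, masks = model._prepare_feature(fc_feats, att_feats, att_masks)
352 |     state = model.init_hidden(fc_feats.size(0))                  # -> []
353 |     it = fc_feats.new_zeros(fc_feats.size(0), dtype=torch.long)  # <bos> is index 0
354 |     seq = []
355 |     for _ in range(max_length):
356 |         out, state = model.core(it, fc, att, memory, state, masks)
357 |         logprobs = F.log_softmax(model.logit(out), dim=1)
358 |         it = logprobs.argmax(dim=1)                              # greedy word choice
359 |         seq.append(it)
360 |     return torch.stack(seq, dim=1)                               # (batch, max_length) word ids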
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import h5py
7 | import lmdb
8 | import os
9 | import numpy as np
10 | import random
11 |
12 | import torch
13 | import torch.utils.data as data
14 |
15 | import multiprocessing
16 | import six
17 |
18 | class HybridLoader:
19 | """
20 |     If db_path is a directory, use normal file loading.
21 |     If db_path ends with .lmdb, load from lmdb; if it ends with .pth, treat it as a saved key/value dict.
22 |     The loading method depends on the extension (.npy -> np.load, otherwise .npz with a 'feat' key).
23 | """
24 | def __init__(self, db_path, ext):
25 | self.db_path = db_path
26 | self.ext = ext
27 | if self.ext == '.npy':
28 | self.loader = lambda x: np.load(x)
29 | else:
30 | self.loader = lambda x: np.load(x)['feat']
31 | if db_path.endswith('.lmdb'):
32 | self.db_type = 'lmdb'
33 | self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
34 | readonly=True, lock=False,
35 | readahead=False, meminit=False)
36 | elif db_path.endswith('.pth'): # Assume a key,value dictionary
37 | self.db_type = 'pth'
38 | self.feat_file = torch.load(db_path)
39 | self.loader = lambda x: x
40 | print('HybridLoader: ext is ignored')
41 | else:
42 | self.db_type = 'dir'
43 |
44 | def get(self, key):
45 |
46 | if self.db_type == 'lmdb':
47 | env = self.env
48 | with env.begin(write=False) as txn:
49 | byteflow = txn.get(key)
50 | f_input = six.BytesIO(byteflow)
51 | elif self.db_type == 'pth':
52 | f_input = self.feat_file[key]
53 | else:
54 | f_input = os.path.join(self.db_path, key + self.ext)
55 |
56 |         # load the feature
57 | feat = self.loader(f_input)
58 |
59 | return feat
60 |
61 |
62 | class DataLoader(data.Dataset):
63 |
64 | def reset_iterator(self, split):
65 | del self._prefetch_process[split]
66 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
67 | self.iterators[split] = 0
68 |
69 | def get_vocab_size(self):
70 | return self.vocab_size
71 |
72 | def get_vocab(self):
73 | return self.ix_to_word
74 |
75 | def get_seq_length(self):
76 | return self.seq_length
77 |
78 | def __init__(self, opt):
79 | self.opt = opt
80 | self.batch_size = self.opt.batch_size
81 | self.seq_per_img = opt.seq_per_img
82 |
83 | # feature related options
84 | self.use_fc = getattr(opt, 'use_fc', True)
85 | self.use_att = getattr(opt, 'use_att', True)
86 | self.use_box = getattr(opt, 'use_box', 0)
87 | self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
88 | self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
89 |
90 | # load the json file which contains additional information about the dataset
91 | print('DataLoader loading json file: ', opt.input_json)
92 | self.info = json.load(open(self.opt.input_json))
93 | if 'ix_to_word' in self.info:
94 | self.ix_to_word = self.info['ix_to_word']
95 | self.vocab_size = len(self.ix_to_word)
96 | print('vocab size is ', self.vocab_size)
97 |
98 | # open the hdf5 file
99 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
100 | if self.opt.input_label_h5 != 'none':
101 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
102 | # load in the sequence data
103 | seq_size = self.h5_label_file['labels'].shape
104 | self.label = self.h5_label_file['labels'][:]
105 | self.seq_length = seq_size[1]
106 | print('max sequence length in data is', self.seq_length)
107 | # load the pointers in full to RAM (should be small enough)
108 | self.label_start_ix = self.h5_label_file['label_start_ix'][:]
109 | self.label_end_ix = self.h5_label_file['label_end_ix'][:]
110 | else:
111 | self.seq_length = 1
112 |
113 | self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy')
114 | self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz')
115 | self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy')
116 |
117 | self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
118 | print('read %d image features' %(self.num_images))
119 |
120 | # separate out indexes for each of the provided splits
121 | self.split_ix = {'train': [], 'val': [], 'test': []}
122 | for ix in range(len(self.info['images'])):
123 | img = self.info['images'][ix]
124 | if not 'split' in img:
125 | self.split_ix['train'].append(ix)
126 | self.split_ix['val'].append(ix)
127 | self.split_ix['test'].append(ix)
128 | elif img['split'] == 'train':
129 | self.split_ix['train'].append(ix)
130 | elif img['split'] == 'val':
131 | self.split_ix['val'].append(ix)
132 | elif img['split'] == 'test':
133 | self.split_ix['test'].append(ix)
134 | elif opt.train_only == 0: # restval
135 | self.split_ix['train'].append(ix)
136 |
137 | print('assigned %d images to split train' %len(self.split_ix['train']))
138 | print('assigned %d images to split val' %len(self.split_ix['val']))
139 | print('assigned %d images to split test' %len(self.split_ix['test']))
140 |
141 | self.iterators = {'train': 0, 'val': 0, 'test': 0}
142 |
143 | self._prefetch_process = {} # The three prefetch process
144 | for split in self.iterators.keys():
145 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
146 |         # Terminate the child process when the parent exits
147 | def cleanup():
148 | print('Terminating BlobFetcher')
149 | for split in self.iterators.keys():
150 | del self._prefetch_process[split]
151 | import atexit
152 | atexit.register(cleanup)
153 |
154 | def get_captions(self, ix, seq_per_img):
155 | # fetch the sequence labels
156 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
157 | ix2 = self.label_end_ix[ix] - 1
158 | ncap = ix2 - ix1 + 1 # number of captions available for this image
159 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
160 |
161 | if ncap < seq_per_img:
162 | # we need to subsample (with replacement)
163 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
164 | for q in range(seq_per_img):
165 | ixl = random.randint(ix1,ix2)
166 | seq[q, :] = self.label[ixl, :self.seq_length]
167 | else:
168 | ixl = random.randint(ix1, ix2 - seq_per_img + 1)
169 | seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
170 |
171 | return seq
172 |
173 | def get_batch(self, split, batch_size=None):
174 | batch_size = batch_size or self.batch_size
175 | seq_per_img = self.seq_per_img
176 |
177 | fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32')
178 | att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32')
179 | label_batch = [] #np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int')
180 |
181 | wrapped = False
182 |
183 | infos = []
184 | gts = []
185 |
186 | for i in range(batch_size):
187 | # fetch image
188 | tmp_fc, tmp_att, tmp_seq, \
189 | ix, tmp_wrapped = self._prefetch_process[split].get()
190 | if tmp_wrapped:
191 | wrapped = True
192 |
193 | fc_batch.append(tmp_fc)
194 | att_batch.append(tmp_att)
195 |
196 | tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
197 | if hasattr(self, 'h5_label_file'):
198 | tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
199 | label_batch.append(tmp_label)
200 |
201 | # Used for reward evaluation
202 | if hasattr(self, 'h5_label_file'):
203 | gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
204 | else:
205 | gts.append([])
206 |
207 | # record associated info as well
208 | info_dict = {}
209 | info_dict['ix'] = ix
210 | info_dict['id'] = self.info['images'][ix]['id']
211 | info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
212 | infos.append(info_dict)
213 |
214 | # #sort by att_feat length
215 | # fc_batch, att_batch, label_batch, gts, infos = \
216 | # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
217 | fc_batch, att_batch, label_batch, gts, infos = \
218 | zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
219 | data = {}
220 | data['fc_feats'] = np.stack(sum([[_]*seq_per_img for _ in fc_batch], []))
221 | # merge att_feats
222 | max_att_len = max([_.shape[0] for _ in att_batch])
223 | data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32')
224 | for i in range(len(att_batch)):
225 | data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i]
226 | data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
227 | for i in range(len(att_batch)):
228 | data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1
229 | # set att_masks to None if attention features have same length
230 | if data['att_masks'].sum() == data['att_masks'].size:
231 | data['att_masks'] = None
232 |
233 | data['labels'] = np.vstack(label_batch)
234 | # generate mask
235 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
236 | mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
237 | for ix, row in enumerate(mask_batch):
238 | row[:nonzeros[ix]] = 1
239 | data['masks'] = mask_batch
240 |
241 |         data['gts'] = gts # all ground truth captions of each image
242 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
243 | data['infos'] = infos
244 |
245 | data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
246 |
247 | return data
248 |
249 |     # It's not coherent to make DataLoader a subclass of Dataset, but essentially we only need to implement the following two functions
250 |     # so that torch.utils.data.DataLoader can load the data according to the index.
251 |     # This keeps the change needed to switch to pytorch data loading minimal.
252 | def __getitem__(self, index):
253 | """This function returns a tuple that is further passed to collate_fn
254 | """
255 | ix = index #self.split_ix[index]
256 | if self.use_att:
257 | att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
258 | # Reshape to K x C
259 | att_feat = att_feat.reshape(-1, att_feat.shape[-1])
260 | if self.norm_att_feat:
261 | att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
262 | if self.use_box:
263 | box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
264 |                 # divided by image width and height
265 | x1,y1,x2,y2 = np.hsplit(box_feat, 4)
266 | h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
267 | box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
268 | if self.norm_box_feat:
269 | box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
270 | att_feat = np.hstack([att_feat, box_feat])
271 | # sort the features by the size of boxes
272 | att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
273 | else:
274 | att_feat = np.zeros((1,1,1), dtype='float32')
275 | if self.use_fc:
276 | fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
277 | else:
278 | fc_feat = np.zeros((1), dtype='float32')
279 | if hasattr(self, 'h5_label_file'):
280 | seq = self.get_captions(ix, self.seq_per_img)
281 | else:
282 | seq = None
283 | return (fc_feat,
284 | att_feat, seq,
285 | ix)
286 |
287 | def __len__(self):
288 | return len(self.info['images'])
289 |
290 | class SubsetSampler(torch.utils.data.sampler.Sampler):
291 |     r"""Samples elements sequentially from a given list of indices, in the given order, without replacement.
292 | Arguments:
293 | indices (list): a list of indices
294 | """
295 |
296 | def __init__(self, indices):
297 | self.indices = indices
298 |
299 | def __iter__(self):
300 | return (self.indices[i] for i in range(len(self.indices)))
301 |
302 | def __len__(self):
303 | return len(self.indices)
304 |
305 | class BlobFetcher():
306 | """Experimental class for prefetching blobs in a separate process."""
307 | def __init__(self, split, dataloader, if_shuffle=False):
308 | """
309 |         Prefetch items of `split` from `dataloader` in background workers; shuffle the split indices whenever the split wraps around if `if_shuffle` is set.
310 | """
311 | self.split = split
312 | self.dataloader = dataloader
313 | self.if_shuffle = if_shuffle
314 |
315 | # Add more in the queue
316 | def reset(self):
317 | """
318 | Two cases for this function to be triggered:
319 |         1. not hasattr(self, 'split_loader'): resume from previous training; create the loader given the saved split_ix and iterator.
320 |         2. wrapped: a new epoch; split_ix and the iterator have already been updated in _get_next_minibatch_inds.
321 | """
322 | # batch_size is 1, the merge is done in DataLoader class
323 | self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
324 | batch_size=1,
325 | sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]),
326 | shuffle=False,
327 | pin_memory=True,
328 | num_workers=4, # 4 is usually enough
329 | collate_fn=lambda x: x[0]))
330 |
331 | def _get_next_minibatch_inds(self):
332 | max_index = len(self.dataloader.split_ix[self.split])
333 | wrapped = False
334 |
335 | ri = self.dataloader.iterators[self.split]
336 | ix = self.dataloader.split_ix[self.split][ri]
337 |
338 | ri_next = ri + 1
339 | if ri_next >= max_index:
340 | ri_next = 0
341 | if self.if_shuffle:
342 | random.shuffle(self.dataloader.split_ix[self.split])
343 | wrapped = True
344 | self.dataloader.iterators[self.split] = ri_next
345 |
346 | return ix, wrapped
347 |
348 | def get(self):
349 | if not hasattr(self, 'split_loader'):
350 | self.reset()
351 |
352 | ix, wrapped = self._get_next_minibatch_inds()
353 |         tmp = next(self.split_loader)  # builtin next() works on both python 2 and 3
354 | if wrapped:
355 | self.reset()
356 |
357 | assert tmp[-1] == ix, "ix not equal"
358 |
359 | return tmp + [wrapped]
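360 | 
361 | # --- Usage sketch (added for exposition; not part of the original file).
362 | # A training script would typically do something like the following; `opt` must
363 | # carry the fields read in DataLoader.__init__ (input_json, input_fc_dir, ...).
364 | #
365 | #   loader = DataLoader(opt)
366 | #   data = loader.get_batch('train')
367 | #   # data['fc_feats']  : (batch_size * seq_per_img, fc_feat_size) float tensor
368 | #   # data['att_feats'] : (batch_size * seq_per_img, max_att_len, att_feat_size)
369 | #   # data['att_masks'] : same leading dims, or None when all images have the same number of regions
370 | #   # data['labels']    : (batch_size * seq_per_img, seq_length + 2) int tensor, zero padded
371 | #   # data['masks']     : 1.0 over the <bos> + words + <eos> positions, 0.0 over padding
372 | #   # data['gts'], data['infos'], data['bounds'] carry raw captions and iteration bookkeeping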
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_opt():
4 | parser = argparse.ArgumentParser()
5 | # Data input settings
6 | parser.add_argument('--input_json', type=str, default='data/coco.json',
7 | help='path to the json file containing additional info and vocab')
8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
9 | help='path to the directory containing the preprocessed fc feats')
10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
11 | help='path to the directory containing the preprocessed att feats')
12 | parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
13 | help='path to the directory containing the boxes of att feats')
14 | parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
15 | help='path to the h5file containing the preprocessed dataset')
16 |     parser.add_argument('--start_from', type=str, default=None,
17 |                     help="""continue training from a saved model at this path. The path must contain files
18 |                     saved by a previous training process:
19 |                         'infos_{id}.pkl' : configuration and vocabulary;
20 |                         'model.pth'      : saved model parameters.
21 |                     Here {id} is the experiment id passed via --id.
22 |                     """)
23 | parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
24 | help='Cached token file for calculating cider score during self critical training.')
25 |
26 | # Model settings
27 | parser.add_argument('--caption_model', type=str, default="show_tell",
28 | help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer')
29 | parser.add_argument('--rnn_size', type=int, default=512,
30 | help='size of the rnn in number of hidden nodes in each layer')
31 | parser.add_argument('--num_layers', type=int, default=1,
32 |                     help='number of layers in the RNN (also the number of encoder/decoder layers for the transformer)')
33 | parser.add_argument('--rnn_type', type=str, default='lstm',
34 | help='rnn, gru, or lstm')
35 | parser.add_argument('--input_encoding_size', type=int, default=512,
36 | help='the encoding size of each token in the vocabulary, and the image.')
37 | parser.add_argument('--att_hid_size', type=int, default=512,
38 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
39 | parser.add_argument('--fc_feat_size', type=int, default=2048,
40 | help='2048 for resnet, 4096 for vgg')
41 | parser.add_argument('--att_feat_size', type=int, default=2048,
42 | help='2048 for resnet, 512 for vgg')
43 | parser.add_argument('--logit_layers', type=int, default=1,
44 |                     help='number of layers in the output (logit) layer of the language model')
45 |
46 |
47 | parser.add_argument('--use_bn', type=int, default=0,
48 | help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
49 |
50 | # AoA settings
51 | parser.add_argument('--mean_feats', type=int, default=1,
52 | help='use mean pooling of feats?')
53 | parser.add_argument('--refine', type=int, default=1,
54 | help='refining feature vectors?')
55 | parser.add_argument('--refine_aoa', type=int, default=1,
56 | help='use aoa in the refining module?')
57 | parser.add_argument('--use_ff', type=int, default=1,
58 | help='keep feed-forward layer in the refining module?')
59 | parser.add_argument('--dropout_aoa', type=float, default=0.3,
60 | help='dropout_aoa in the refining module?')
61 |
62 | parser.add_argument('--ctx_drop', type=int, default=0,
63 | help='apply dropout to the context vector before fed into LSTM?')
64 | parser.add_argument('--decoder_type', type=str, default='AoA',
65 | help='AoA, LSTM, base')
66 | parser.add_argument('--use_multi_head', type=int, default=2,
67 |                     help='use multi-head attention? 0 for additive single-head; 1 for additive multi-head; 2 for dot-product multi-head.')
68 | parser.add_argument('--num_heads', type=int, default=8,
69 | help='number of attention heads?')
70 | parser.add_argument('--multi_head_scale', type=int, default=1,
71 | help='scale q,k,v?')
72 |
73 | parser.add_argument('--use_warmup', type=int, default=0,
74 |                     help='warm up the learning rate?')
75 | parser.add_argument('--acc_steps', type=int, default=1,
76 | help='accumulation steps')
77 |
78 |
79 | # feature manipulation
80 | parser.add_argument('--norm_att_feat', type=int, default=0,
81 | help='If normalize attention features')
82 | parser.add_argument('--use_box', type=int, default=0,
83 | help='If use box features')
84 | parser.add_argument('--norm_box_feat', type=int, default=0,
85 | help='If use box, do we normalize box feature')
86 |
87 | # Optimization: General
88 | parser.add_argument('--max_epochs', type=int, default=-1,
89 | help='number of epochs')
90 | parser.add_argument('--batch_size', type=int, default=16,
91 | help='minibatch size')
92 | parser.add_argument('--grad_clip', type=float, default=0.1, #5.,
93 | help='clip gradients at this value')
94 | parser.add_argument('--drop_prob_lm', type=float, default=0.5,
95 | help='strength of dropout in the Language Model RNN')
96 | parser.add_argument('--self_critical_after', type=int, default=-1,
97 |                     help='After what epoch do we start self-critical (RL) training? (-1 = disable; 0 = from the start)')
98 | parser.add_argument('--seq_per_img', type=int, default=5,
99 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
100 |
101 | # Sample related
102 | parser.add_argument('--beam_size', type=int, default=1,
103 | help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
104 | parser.add_argument('--max_length', type=int, default=20,
105 | help='Maximum length during sampling')
106 | parser.add_argument('--length_penalty', type=str, default='',
107 | help='wu_X or avg_X, X is the alpha')
108 | parser.add_argument('--block_trigrams', type=int, default=0,
109 | help='block repeated trigram.')
110 | parser.add_argument('--remove_bad_endings', type=int, default=0,
111 | help='Remove bad endings')
112 |
113 | #Optimization: for the Language Model
114 | parser.add_argument('--optim', type=str, default='adam',
115 | help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
116 | parser.add_argument('--learning_rate', type=float, default=4e-4,
117 | help='learning rate')
118 | parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
119 | help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
120 | parser.add_argument('--learning_rate_decay_every', type=int, default=3,
121 | help='every how many iterations thereafter to drop LR?(in epoch)')
122 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
123 |                     help='multiplicative factor by which the learning rate is decayed at each decay step')
124 | parser.add_argument('--optim_alpha', type=float, default=0.9,
125 | help='alpha for adam')
126 | parser.add_argument('--optim_beta', type=float, default=0.999,
127 | help='beta used for adam')
128 | parser.add_argument('--optim_epsilon', type=float, default=1e-8,
129 | help='epsilon that goes into denominator for smoothing')
130 | parser.add_argument('--weight_decay', type=float, default=0,
131 | help='weight_decay')
132 | # Transformer
133 | parser.add_argument('--label_smoothing', type=float, default=0,
134 | help='')
135 | parser.add_argument('--noamopt', action='store_true',
136 | help='')
137 | parser.add_argument('--noamopt_warmup', type=int, default=2000,
138 | help='')
139 | parser.add_argument('--noamopt_factor', type=float, default=1,
140 | help='')
141 | parser.add_argument('--reduce_on_plateau', action='store_true',
142 | help='')
143 |
144 | parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
145 |                     help='at what iteration to start decaying the gt probability')
146 | parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
147 |                     help='every how many iterations thereafter to increase the scheduled sampling probability')
148 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
149 | help='How much to update the prob')
150 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
151 | help='Maximum scheduled sampling prob.')
152 |
153 |
154 | # Evaluation/Checkpointing
155 | parser.add_argument('--val_images_use', type=int, default=3200,
156 | help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
157 | parser.add_argument('--save_checkpoint_every', type=int, default=2500,
158 | help='how often to save a model checkpoint (in iterations)?')
159 | parser.add_argument('--save_history_ckpt', type=int, default=0,
160 | help='If save checkpoints at every save point')
161 | parser.add_argument('--checkpoint_path', type=str, default='save',
162 | help='directory to store checkpointed models')
163 | parser.add_argument('--language_eval', type=int, default=0,
164 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
165 | parser.add_argument('--losses_log_every', type=int, default=25,
166 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
167 | parser.add_argument('--load_best_score', type=int, default=1,
168 | help='Do we load previous best score when resuming training.')
169 |
170 | # misc
171 | parser.add_argument('--id', type=str, default='',
172 | help='an id identifying this run/job. used in cross-val and appended when writing progress files')
173 | parser.add_argument('--train_only', type=int, default=0,
174 | help='if true then use 80k, else use 110k')
175 |
176 |
177 | # Reward
178 | parser.add_argument('--cider_reward_weight', type=float, default=1,
179 | help='The reward weight from cider')
180 | parser.add_argument('--bleu_reward_weight', type=float, default=0,
181 | help='The reward weight from bleu4')
182 |
183 | args = parser.parse_args()
184 |
185 | # Check if args are valid
186 | assert args.rnn_size > 0, "rnn_size should be greater than 0"
187 | assert args.num_layers > 0, "num_layers should be greater than 0"
188 | assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
189 | assert args.batch_size > 0, "batch_size should be greater than 0"
190 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
191 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
192 | assert args.beam_size > 0, "beam_size should be greater than 0"
193 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
194 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
195 | assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
196 |     assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
197 |     assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"
198 |
199 | return args
200 |
201 | def add_eval_options(parser):
202 | # Basic options
203 | parser.add_argument('--batch_size', type=int, default=0,
204 | help='if > 0 then overrule, otherwise load from checkpoint.')
205 | parser.add_argument('--num_images', type=int, default=-1,
206 | help='how many images to use when periodically evaluating the loss? (-1 = all)')
207 | parser.add_argument('--language_eval', type=int, default=0,
208 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
209 | parser.add_argument('--dump_images', type=int, default=1,
210 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
211 | parser.add_argument('--dump_json', type=int, default=1,
212 | help='Dump json with predictions into vis folder? (1=yes,0=no)')
213 | parser.add_argument('--dump_path', type=int, default=0,
214 | help='Write image paths along with predictions into vis json? (1=yes,0=no)')
215 |
216 | # Sampling options
217 | parser.add_argument('--sample_method', type=str, default='greedy',
218 | help='greedy; sample; gumbel; top, top<0-1>')
219 | parser.add_argument('--beam_size', type=int, default=2,
220 | help='indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
221 | parser.add_argument('--max_length', type=int, default=20,
222 | help='Maximum length during sampling')
223 | parser.add_argument('--length_penalty', type=str, default='',
224 | help='wu_X or avg_X, X is the alpha')
225 | parser.add_argument('--group_size', type=int, default=1,
226 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
227 | parser.add_argument('--diversity_lambda', type=float, default=0.5,
228 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
229 | parser.add_argument('--temperature', type=float, default=1.0,
230 | help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
231 | parser.add_argument('--decoding_constraint', type=int, default=0,
232 | help='If 1, not allowing same word in a row')
233 | parser.add_argument('--block_trigrams', type=int, default=0,
234 | help='block repeated trigram.')
235 | parser.add_argument('--remove_bad_endings', type=int, default=0,
236 | help='Remove bad endings')
237 | # For evaluation on a folder of images:
238 | parser.add_argument('--image_folder', type=str, default='',
239 | help='If this is nonempty then will predict on the images in this folder path')
240 | parser.add_argument('--image_root', type=str, default='',
241 | help='In case the image paths have to be preprended with a root path to an image folder')
242 | # For evaluation on MSCOCO images from some split:
243 |     parser.add_argument('--input_fc_dir', type=str, default='',
244 |                     help='path to the directory containing the preprocessed fc feats')
245 |     parser.add_argument('--input_att_dir', type=str, default='',
246 |                     help='path to the directory containing the preprocessed att feats')
247 |     parser.add_argument('--input_box_dir', type=str, default='',
248 |                     help='path to the directory containing the boxes of att feats')
249 |     parser.add_argument('--input_label_h5', type=str, default='',
250 |                     help='path to the h5file containing the preprocessed labels')
251 | parser.add_argument('--input_json', type=str, default='',
252 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
253 | parser.add_argument('--split', type=str, default='test',
254 | help='if running on MSCOCO images, which split to use: val|test|train')
255 | parser.add_argument('--coco_json', type=str, default='',
256 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
257 | # misc
258 | parser.add_argument('--id', type=str, default='',
259 | help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
260 | parser.add_argument('--verbose_beam', type=int, default=1,
261 | help='if we need to print out all beam search beams.')
262 | parser.add_argument('--verbose_loss', type=int, default=0,
263 | help='If calculate loss using ground truth during evaluation')
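264 | 
265 | # --- Usage sketch (added for exposition; not part of the original file).
266 | # Evaluation scripts typically build their own parser, add script-specific
267 | # arguments such as the checkpoint paths, and then pull in the shared options:
268 | #
269 | #   import argparse
270 | #   import opts
271 | #
272 | #   parser = argparse.ArgumentParser()
273 | #   parser.add_argument('--model', type=str, default='', help='path to the model .pth file')
274 | #   parser.add_argument('--infos_path', type=str, default='', help='path to the infos .pkl file')
275 | #   opts.add_eval_options(parser)
276 | #   args = parser.parse_args()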
--------------------------------------------------------------------------------