├── misc
│   ├── __init__.py
│   ├── resnet_utils.py
│   ├── resnet.py
│   ├── rewards.py
│   └── utils.py
├── vis
│   ├── imgs
│   │   └── dummy
│   └── index.html
├── ADVANCED.md
├── .gitmodules
├── scripts
│   ├── copy_model.sh
│   ├── make_bu_data.py
│   ├── prepro_feats.py
│   ├── prepro_ngrams.py
│   └── prepro_labels.py
├── models
│   ├── __init__.py
│   ├── ShowTellModel.py
│   ├── FCModel.py
│   ├── CaptionModel.py
│   ├── OldModel.py
│   ├── AttEnsemble.py
│   ├── TransformerModel.py
│   └── AttModel.py
├── dataloaderraw.py
├── eval_utils.py
├── eval.py
├── eval_ensemble.py
├── opts.py
├── train.py
├── README.md
└── dataloader.py
/misc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vis/imgs/dummy:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ADVANCED.md:
--------------------------------------------------------------------------------
1 | # Advanced
2 |
3 | ## Ensemble
4 |
5 | ## Batch normalization
6 |
7 | ## Box feature
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "cider"]
2 | path = cider
3 | url = https://github.com/ruotianluo/cider.git
4 |
--------------------------------------------------------------------------------
/scripts/copy_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | cp -r log_$1 log_$2
4 | cd log_$2
5 | mv infos_$1-best.pkl infos_$2-best.pkl
6 | mv infos_$1.pkl infos_$2.pkl
7 | cd ../
--------------------------------------------------------------------------------
/misc/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class myResnet(nn.Module):
6 | def __init__(self, resnet):
7 | super(myResnet, self).__init__()
8 | self.resnet = resnet
9 |
10 | def forward(self, img, att_size=14):
11 | x = img.unsqueeze(0)
12 |
13 | x = self.resnet.conv1(x)
14 | x = self.resnet.bn1(x)
15 | x = self.resnet.relu(x)
16 | x = self.resnet.maxpool(x)
17 |
18 | x = self.resnet.layer1(x)
19 | x = self.resnet.layer2(x)
20 | x = self.resnet.layer3(x)
21 | x = self.resnet.layer4(x)
22 |
23 | fc = x.mean(3).mean(2).squeeze()
24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
25 |
26 | return fc, att
27 |
28 |
--------------------------------------------------------------------------------
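
A minimal usage sketch for myResnet (not part of the repository): it only assumes a torchvision ResNet with the standard conv1/bn1/relu/maxpool/layer1-4 attributes and a single CxHxW image tensor. The randomly initialized resnet101 below stands in for the ImageNet weights the preprocessing scripts actually load, and the repo root is assumed to be on PYTHONPATH.

    import torch
    import torchvision
    from misc.resnet_utils import myResnet

    # stand-in backbone; prepro_feats.py / dataloaderraw.py load ./data/imagenet_weights instead
    my_resnet = myResnet(torchvision.models.resnet101()).eval()

    img = torch.rand(3, 224, 224)  # one preprocessed RGB image, CxHxW
    with torch.no_grad():
        fc, att = my_resnet(img, att_size=14)
    print(fc.shape)   # torch.Size([2048])          pooled feature
    print(att.shape)  # torch.Size([14, 14, 2048])  spatial feature map
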
/scripts/make_bu_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import base64
7 | import numpy as np
8 | import csv
9 | import sys
10 | import zlib
11 | import time
12 | import mmap
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | # output_dir
18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory')
19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files')
20 |
21 | args = parser.parse_args()
22 |
23 | csv.field_size_limit(sys.maxsize)
24 |
25 |
26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features']
27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv',
28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\
29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \
30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1']
31 |
32 | os.makedirs(args.output_dir+'_att')
33 | os.makedirs(args.output_dir+'_fc')
34 | os.makedirs(args.output_dir+'_box')
35 |
36 | for infile in infiles:
37 | print('Reading ' + infile)
38 | with open(os.path.join(args.downloaded_feats, infile), "r+b") as tsv_in_file:
39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES)
40 | for item in reader:
41 | item['image_id'] = int(item['image_id'])
42 | item['num_boxes'] = int(item['num_boxes'])
43 | for field in ['boxes', 'features']:
44 | item[field] = np.frombuffer(base64.decodestring(item[field]),
45 | dtype=np.float32).reshape((item['num_boxes'],-1))
46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features'])
47 | np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0))
48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes'])
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
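
Once the script has run, the per-image bottom-up features can be read back as below. This is only a sketch: the image id is hypothetical, the paths assume the default --output_dir=data/cocobu, and the 2048-d per-box dimension is that of the released bottom-up-attention features.

    import numpy as np

    image_id = '391895'  # hypothetical COCO image id; files are named by image_id

    att = np.load('data/cocobu_att/' + image_id + '.npz')['feat']  # (num_boxes, 2048) per-box features
    fc  = np.load('data/cocobu_fc/' + image_id + '.npy')           # (2048,) mean over the boxes
    box = np.load('data/cocobu_box/' + image_id + '.npy')          # (num_boxes, 4) box coordinates
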
/vis/index.html:
--------------------------------------------------------------------------------
(contents garbled in extraction: the HTML markup and inline scripts are missing; only the page title survives)
neuraltalk2 results visualization
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import copy
7 |
8 | import numpy as np
9 | import misc.utils as utils
10 | import torch
11 |
12 | from .ShowTellModel import ShowTellModel
13 | from .FCModel import FCModel
14 | from .OldModel import ShowAttendTellModel, AllImgModel
15 | from .AttModel import *
16 | from .TransformerModel import TransformerModel
17 |
18 | def setup(opt):
19 |
20 | if opt.caption_model == 'fc':
21 | model = FCModel(opt)
22 |     elif opt.caption_model == 'show_tell':
23 | model = ShowTellModel(opt)
24 | # Att2in model in self-critical
25 | elif opt.caption_model == 'att2in':
26 | model = Att2inModel(opt)
27 | # Att2in model with two-layer MLP img embedding and word embedding
28 | elif opt.caption_model == 'att2in2':
29 | model = Att2in2Model(opt)
30 | elif opt.caption_model == 'att2all2':
31 | model = Att2all2Model(opt)
32 | # Adaptive Attention model from Knowing when to look
33 | elif opt.caption_model == 'adaatt':
34 | model = AdaAttModel(opt)
35 | # Adaptive Attention with maxout lstm
36 | elif opt.caption_model == 'adaattmo':
37 | model = AdaAttMOModel(opt)
38 | # Top-down attention model
39 | elif opt.caption_model == 'topdown':
40 | model = TopDownModel(opt)
41 | # StackAtt
42 | elif opt.caption_model == 'stackatt':
43 | model = StackAttModel(opt)
44 | # DenseAtt
45 | elif opt.caption_model == 'denseatt':
46 | model = DenseAttModel(opt)
47 | # Transformer
48 | elif opt.caption_model == 'transformer':
49 | model = TransformerModel(opt)
50 | else:
51 | raise Exception("Caption model not supported: {}".format(opt.caption_model))
52 |
53 | # check compatibility if training is continued from previously saved model
54 | if vars(opt).get('start_from', None) is not None:
55 | # check if all necessary files exist
56 |         assert os.path.isdir(opt.start_from), "%s must be a path" % opt.start_from
57 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from
58 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))
59 |
60 | return model
61 |
--------------------------------------------------------------------------------
/misc/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.utils.model_zoo as model_zoo  # needed by the pretrained= branches below
2 | import torch.nn as nn
3 | import torchvision.models.resnet
4 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
5 |
6 | class ResNet(torchvision.models.resnet.ResNet):
7 | def __init__(self, block, layers, num_classes=1000):
8 | super(ResNet, self).__init__(block, layers, num_classes)
9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
10 | for i in range(2, 5):
11 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
13 |
14 | def resnet18(pretrained=False):
15 | """Constructs a ResNet-18 model.
16 |
17 | Args:
18 | pretrained (bool): If True, returns a model pre-trained on ImageNet
19 | """
20 | model = ResNet(BasicBlock, [2, 2, 2, 2])
21 | if pretrained:
22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
23 | return model
24 |
25 |
26 | def resnet34(pretrained=False):
27 | """Constructs a ResNet-34 model.
28 |
29 | Args:
30 | pretrained (bool): If True, returns a model pre-trained on ImageNet
31 | """
32 | model = ResNet(BasicBlock, [3, 4, 6, 3])
33 | if pretrained:
34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
35 | return model
36 |
37 |
38 | def resnet50(pretrained=False):
39 | """Constructs a ResNet-50 model.
40 |
41 | Args:
42 | pretrained (bool): If True, returns a model pre-trained on ImageNet
43 | """
44 | model = ResNet(Bottleneck, [3, 4, 6, 3])
45 | if pretrained:
46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
47 | return model
48 |
49 |
50 | def resnet101(pretrained=False):
51 | """Constructs a ResNet-101 model.
52 |
53 | Args:
54 | pretrained (bool): If True, returns a model pre-trained on ImageNet
55 | """
56 | model = ResNet(Bottleneck, [3, 4, 23, 3])
57 | if pretrained:
58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
59 | return model
60 |
61 |
62 | def resnet152(pretrained=False):
63 | """Constructs a ResNet-152 model.
64 |
65 | Args:
66 | pretrained (bool): If True, returns a model pre-trained on ImageNet
67 | """
68 | model = ResNet(Bottleneck, [3, 8, 36, 3])
69 | if pretrained:
70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
71 | return model
--------------------------------------------------------------------------------
/misc/rewards.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import time
7 | import misc.utils as utils
8 | from collections import OrderedDict
9 | import torch
10 |
11 | import sys
12 | sys.path.append("cider")
13 | from pyciderevalcap.ciderD.ciderD import CiderD
14 | sys.path.append("coco-caption")
15 | from pycocoevalcap.bleu.bleu import Bleu
16 |
17 | CiderD_scorer = None
18 | Bleu_scorer = None
19 | #CiderD_scorer = CiderD(df='corpus')
20 |
21 | def init_scorer(cached_tokens):
22 | global CiderD_scorer
23 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
24 | global Bleu_scorer
25 | Bleu_scorer = Bleu_scorer or Bleu(4)
26 |
27 | def array_to_str(arr):
28 | out = ''
29 | for i in range(len(arr)):
30 | out += str(arr[i]) + ' '
31 | if arr[i] == 0:
32 | break
33 | return out.strip()
34 |
35 | def get_self_critical_reward(model, fc_feats, att_feats, att_masks, data, gen_result, opt):
36 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
37 | seq_per_img = batch_size // len(data['gts'])
38 |
39 | # get greedy decoding baseline
40 | model.eval()
41 | with torch.no_grad():
42 | greedy_res, _ = model(fc_feats, att_feats, att_masks=att_masks, mode='sample')
43 | model.train()
44 |
45 | res = OrderedDict()
46 |
47 | gen_result = gen_result.data.cpu().numpy()
48 | greedy_res = greedy_res.data.cpu().numpy()
49 | for i in range(batch_size):
50 | res[i] = [array_to_str(gen_result[i])]
51 | for i in range(batch_size):
52 | res[batch_size + i] = [array_to_str(greedy_res[i])]
53 |
54 | gts = OrderedDict()
55 | for i in range(len(data['gts'])):
56 | gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))]
57 |
58 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)]
59 | res__ = {i: res[i] for i in range(2 * batch_size)}
60 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
61 | if opt.cider_reward_weight > 0:
62 | _, cider_scores = CiderD_scorer.compute_score(gts, res_)
63 | print('Cider scores:', _)
64 | else:
65 | cider_scores = 0
66 | if opt.bleu_reward_weight > 0:
67 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
68 | bleu_scores = np.array(bleu_scores[3])
69 | print('Bleu scores:', _[3])
70 | else:
71 | bleu_scores = 0
72 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
73 |
74 | scores = scores[:batch_size] - scores[batch_size:]
75 |
76 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
77 |
78 | return rewards
--------------------------------------------------------------------------------
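
The last three lines of get_self_critical_reward are the core of the self-critical idea: the greedy caption's score serves as a baseline, the sampled caption's score minus that baseline is the advantage, and the advantage is broadcast to every generated word. A toy, self-contained numpy illustration with made-up scores:

    import numpy as np

    batch_size, seq_len = 2, 4
    # first half: CIDEr-D of the sampled captions, second half: of the greedy baselines
    scores = np.array([0.9, 0.4,
                       0.7, 0.5])
    advantage = scores[:batch_size] - scores[batch_size:]      # [ 0.2, -0.1]
    rewards = np.repeat(advantage[:, np.newaxis], seq_len, 1)  # (2, 4): one value per generated word
    print(rewards)
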
/scripts/prepro_feats.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: a json file and an hdf5 file
13 | The hdf5 file contains several fields:
14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format
15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded
16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
17 | first and last indices (in range 1..M) of labels for each image
18 | /label_length stores the length of the sequence for each of the M sequences
19 |
20 | The json file has a dict that contains:
21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
22 | - an 'images' field that is a list holding auxiliary information for each image,
23 | such as in particular the 'split' it was assigned to.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import json
32 | import argparse
33 | from random import shuffle, seed
34 | import string
35 | # non-standard dependencies:
36 | import h5py
37 | from six.moves import cPickle
38 | import numpy as np
39 | import torch
40 | import torchvision.models as models
41 | import skimage.io
42 |
43 | from torchvision import transforms as trn
44 | preprocess = trn.Compose([
45 | #trn.ToTensor(),
46 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
47 | ])
48 |
49 | from misc.resnet_utils import myResnet
50 | import misc.resnet as resnet
51 |
52 | def main(params):
53 | net = getattr(resnet, params['model'])()
54 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth')))
55 | my_resnet = myResnet(net)
56 | my_resnet.cuda()
57 | my_resnet.eval()
58 |
59 | imgs = json.load(open(params['input_json'], 'r'))
60 | imgs = imgs['images']
61 | N = len(imgs)
62 |
63 | seed(123) # make reproducible
64 |
65 | dir_fc = params['output_dir']+'_fc'
66 | dir_att = params['output_dir']+'_att'
67 | if not os.path.isdir(dir_fc):
68 | os.mkdir(dir_fc)
69 | if not os.path.isdir(dir_att):
70 | os.mkdir(dir_att)
71 |
72 | for i,img in enumerate(imgs):
73 | # load the image
74 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename']))
75 | # handle grayscale input images
76 | if len(I.shape) == 2:
77 | I = I[:,:,np.newaxis]
78 | I = np.concatenate((I,I,I), axis=2)
79 |
80 | I = I.astype('float32')/255.0
81 | I = torch.from_numpy(I.transpose([2,0,1])).cuda()
82 | I = preprocess(I)
83 | with torch.no_grad():
84 | tmp_fc, tmp_att = my_resnet(I, params['att_size'])
85 |         # write the fc / att features as per-image .npy / .npz files
86 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy())
87 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy())
88 |
89 | if i % 1000 == 0:
90 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N))
91 | print('wrote ', params['output_dir'])
92 |
93 | if __name__ == "__main__":
94 |
95 | parser = argparse.ArgumentParser()
96 |
97 | # input json
98 | parser.add_argument('--input_json', required=True, help='input json file (e.g. dataset_coco.json) listing the images to process')
99 | parser.add_argument('--output_dir', default='data', help='output directory prefix; features go to <output_dir>_fc and <output_dir>_att')
100 |
101 | # options
102 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
103 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
104 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152')
105 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root')
106 |
107 | args = parser.parse_args()
108 | params = vars(args) # convert to ordinary dict
109 | print('parsed input parameters:')
110 | print(json.dumps(params, indent = 2))
111 | main(params)
112 |
--------------------------------------------------------------------------------
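
The features written above can be loaded back as below (a sketch; the cocoid is the one from the docstring example and the paths assume the default --output_dir=data and --att_size=14):

    import numpy as np

    cocoid = '391895'

    fc  = np.load('data_fc/' + cocoid + '.npy')           # (2048,) float32
    att = np.load('data_att/' + cocoid + '.npz')['feat']  # (14, 14, 2048) float32
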
/dataloaderraw.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import h5py
7 | import os
8 | import numpy as np
9 | import random
10 | import torch
11 | import skimage
12 | import skimage.io
13 | import scipy.misc
14 |
15 | from torchvision import transforms as trn
16 | preprocess = trn.Compose([
17 | #trn.ToTensor(),
18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19 | ])
20 |
21 | from misc.resnet_utils import myResnet
22 | import misc.resnet
23 |
24 | class DataLoaderRaw():
25 |
26 | def __init__(self, opt):
27 | self.opt = opt
28 | self.coco_json = opt.get('coco_json', '')
29 | self.folder_path = opt.get('folder_path', '')
30 |
31 | self.batch_size = opt.get('batch_size', 1)
32 | self.seq_per_img = 1
33 |
34 | # Load resnet
35 | self.cnn_model = opt.get('cnn_model', 'resnet101')
36 | self.my_resnet = getattr(misc.resnet, self.cnn_model)()
37 | self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth'))
38 | self.my_resnet = myResnet(self.my_resnet)
39 | self.my_resnet.cuda()
40 | self.my_resnet.eval()
41 |
42 |
43 |
44 | # load the json file which contains additional information about the dataset
45 | print('DataLoaderRaw loading images from folder: ', self.folder_path)
46 |
47 | self.files = []
48 | self.ids = []
49 |
50 | print(len(self.coco_json))
51 | if len(self.coco_json) > 0:
52 | print('reading from ' + opt.coco_json)
53 | # read in filenames from the coco-style json file
54 | self.coco_annotation = json.load(open(self.coco_json))
55 | for k,v in enumerate(self.coco_annotation['images']):
56 | fullpath = os.path.join(self.folder_path, v['file_name'])
57 | self.files.append(fullpath)
58 | self.ids.append(v['id'])
59 | else:
60 | # read in all the filenames from the folder
61 | print('listing all images in directory ' + self.folder_path)
62 | def isImage(f):
63 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM']
64 | for ext in supportedExt:
65 | start_idx = f.rfind(ext)
66 | if start_idx >= 0 and start_idx + len(ext) == len(f):
67 | return True
68 | return False
69 |
70 | n = 1
71 | for root, dirs, files in os.walk(self.folder_path, topdown=False):
72 | for file in files:
73 | fullpath = os.path.join(self.folder_path, file)
74 | if isImage(fullpath):
75 | self.files.append(fullpath)
76 | self.ids.append(str(n)) # just order them sequentially
77 | n = n + 1
78 |
79 | self.N = len(self.files)
80 | print('DataLoaderRaw found ', self.N, ' images')
81 |
82 | self.iterator = 0
83 |
84 | def get_batch(self, split, batch_size=None):
85 | batch_size = batch_size or self.batch_size
86 |
87 | # pick an index of the datapoint to load next
88 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
89 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32')
90 | max_index = self.N
91 | wrapped = False
92 | infos = []
93 |
94 | for i in range(batch_size):
95 | ri = self.iterator
96 | ri_next = ri + 1
97 | if ri_next >= max_index:
98 | ri_next = 0
99 | wrapped = True
100 | # wrap back around
101 | self.iterator = ri_next
102 |
103 | img = skimage.io.imread(self.files[ri])
104 |
105 | if len(img.shape) == 2:
106 | img = img[:,:,np.newaxis]
107 | img = np.concatenate((img, img, img), axis=2)
108 |
109 | img = img.astype('float32')/255.0
110 | img = torch.from_numpy(img.transpose([2,0,1])).cuda()
111 | img = preprocess(img)
112 | with torch.no_grad():
113 | tmp_fc, tmp_att = self.my_resnet(img)
114 |
115 | fc_batch[i] = tmp_fc.data.cpu().float().numpy()
116 | att_batch[i] = tmp_att.data.cpu().float().numpy()
117 |
118 | info_struct = {}
119 | info_struct['id'] = self.ids[ri]
120 | info_struct['file_path'] = self.files[ri]
121 | infos.append(info_struct)
122 |
123 | data = {}
124 | data['fc_feats'] = fc_batch
125 | data['att_feats'] = att_batch
126 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
127 | data['infos'] = infos
128 |
129 | return data
130 |
131 | def reset_iterator(self, split):
132 | self.iterator = 0
133 |
134 | def get_vocab_size(self):
135 | return len(self.ix_to_word)
136 |
137 | def get_vocab(self):
138 | return self.ix_to_word
139 |
--------------------------------------------------------------------------------
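
A sketch of how DataLoaderRaw is typically driven (eval.py below constructs it the same way). It is not runnable as-is: it assumes ./data/imagenet_weights/resnet101.pth exists, a CUDA device is available, and the folder path (hypothetical here) contains images.

    from dataloaderraw import DataLoaderRaw

    loader = DataLoaderRaw({'folder_path': 'path/to/images',  # hypothetical folder of .jpg/.png files
                            'coco_json': '',
                            'batch_size': 2,
                            'cnn_model': 'resnet101'})
    data = loader.get_batch('test')        # the split argument is ignored by this loader
    print(data['fc_feats'].shape)          # (2, 2048)
    print(data['att_feats'].shape)         # (2, 14, 14, 2048)
    print(data['infos'][0])                # {'id': ..., 'file_path': ...}
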
/scripts/prepro_ngrams.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: a json file and an hdf5 file
13 | The hdf5 file contains several fields:
14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format
15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded
16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
17 | first and last indices (in range 1..M) of labels for each image
18 | /label_length stores the length of the sequence for each of the M sequences
19 |
20 | The json file has a dict that contains:
21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
22 | - an 'images' field that is a list holding auxiliary information for each image,
23 | such as in particular the 'split' it was assigned to.
24 | """
25 |
26 | import os
27 | import json
28 | import argparse
29 | from six.moves import cPickle
30 | from collections import defaultdict
31 |
32 | def precook(s, n=4, out=False):
33 | """
34 | Takes a string as input and returns an object that can be given to
35 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
36 | can take string arguments as well.
37 | :param s: string : sentence to be converted into ngrams
38 | :param n: int : number of ngrams for which representation is calculated
39 |     :return: term frequency vector for occurring ngrams
40 | """
41 | words = s.split()
42 | counts = defaultdict(int)
43 |     for k in range(1,n+1):
44 |         for i in range(len(words)-k+1):
45 | ngram = tuple(words[i:i+k])
46 | counts[ngram] += 1
47 | return counts
48 |
49 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
50 | '''Takes a list of reference sentences for a single segment
51 | and returns an object that encapsulates everything that BLEU
52 | needs to know about them.
53 | :param refs: list of string : reference sentences for some image
54 | :param n: int : number of ngrams for which (ngram) representation is calculated
55 | :return: result (list of dict)
56 | '''
57 | return [precook(ref, n) for ref in refs]
58 |
59 | def create_crefs(refs):
60 | crefs = []
61 | for ref in refs:
62 | # ref is a list of 5 captions
63 | crefs.append(cook_refs(ref))
64 | return crefs
65 |
66 | def compute_doc_freq(crefs):
67 | '''
68 | Compute term frequency for reference data.
69 | This will be used to compute idf (inverse document frequency later)
70 | The term frequency is stored in the object
71 | :return: None
72 | '''
73 | document_frequency = defaultdict(float)
74 | for refs in crefs:
75 | # refs, k ref captions of one image
76 |         for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
77 | document_frequency[ngram] += 1
78 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
79 | return document_frequency
80 |
81 | def build_dict(imgs, wtoi, params):
82 |   wtoi['<eos>'] = 0
83 |
84 | count_imgs = 0
85 |
86 | refs_words = []
87 | refs_idxs = []
88 | for img in imgs:
89 | if (params['split'] == img['split']) or \
90 | (params['split'] == 'train' and img['split'] == 'restval') or \
91 | (params['split'] == 'all'):
92 | #(params['split'] == 'val' and img['split'] == 'restval') or \
93 | ref_words = []
94 | ref_idxs = []
95 | for sent in img['sentences']:
96 |         tmp_tokens = sent['tokens'] + ['<eos>']
97 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens]
98 | ref_words.append(' '.join(tmp_tokens))
99 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens]))
100 | refs_words.append(ref_words)
101 | refs_idxs.append(ref_idxs)
102 | count_imgs += 1
103 | print('total imgs:', count_imgs)
104 |
105 | ngram_words = compute_doc_freq(create_crefs(refs_words))
106 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs))
107 | return ngram_words, ngram_idxs, count_imgs
108 |
109 | def main(params):
110 |
111 | imgs = json.load(open(params['input_json'], 'r'))
112 | itow = json.load(open(params['dict_json'], 'r'))['ix_to_word']
113 | wtoi = {w:i for i,w in itow.items()}
114 |
115 | imgs = imgs['images']
116 |
117 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)
118 |
119 |   cPickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb'), protocol=cPickle.HIGHEST_PROTOCOL)
120 |   cPickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb'), protocol=cPickle.HIGHEST_PROTOCOL)
121 |
122 | if __name__ == "__main__":
123 |
124 | parser = argparse.ArgumentParser()
125 |
126 | # input json
127 | parser.add_argument('--input_json', default='/home-nfs/rluo/rluo/nips/code/prepro/dataset_coco.json', help='input json file to process into hdf5')
128 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='input json file containing the ix_to_word vocab')
129 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file')
130 | parser.add_argument('--split', default='all', help='test, val, train, all')
131 | args = parser.parse_args()
132 | params = vars(args) # convert to ordinary dict
133 |
134 | main(params)
135 |
--------------------------------------------------------------------------------
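
A quick way to inspect what the script wrote (a sketch; it assumes the script has been run with the default --output_pkl=data/coco-all):

    from six.moves import cPickle

    with open('data/coco-all-words.p', 'rb') as f:
        df = cPickle.load(f)

    print(df['ref_len'])                  # number of images whose references were counted
    print(len(df['document_frequency']))  # number of distinct n-grams
    # document_frequency maps an n-gram tuple, e.g. ('a', 'man'), to the number of
    # images whose reference captions contain it
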
/eval_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | import numpy as np
9 | import json
10 | from json import encoder
11 | import random
12 | import string
13 | import time
14 | import os
15 | import sys
16 | import misc.utils as utils
17 |
18 | def language_eval(dataset, preds, model_id, split):
19 | import sys
20 | sys.path.append("coco-caption")
21 | annFile = 'coco-caption/annotations/captions_val2014.json'
22 | from pycocotools.coco import COCO
23 | from pycocoevalcap.eval import COCOEvalCap
24 |
25 | # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
26 |
27 | if not os.path.isdir('eval_results'):
28 | os.mkdir('eval_results')
29 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
30 |
31 | coco = COCO(annFile)
32 | valids = coco.getImgIds()
33 |
34 | # filter results to only those in MSCOCO validation set (will be about a third)
35 | preds_filt = [p for p in preds if p['image_id'] in valids]
36 | print('using %d/%d predictions' % (len(preds_filt), len(preds)))
37 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
38 |
39 | cocoRes = coco.loadRes(cache_path)
40 | cocoEval = COCOEvalCap(coco, cocoRes)
41 | cocoEval.params['image_id'] = cocoRes.getImgIds()
42 | cocoEval.evaluate()
43 |
44 | # create output dictionary
45 | out = {}
46 | for metric, score in cocoEval.eval.items():
47 | out[metric] = score
48 |
49 | imgToEval = cocoEval.imgToEval
50 | for p in preds_filt:
51 | image_id, caption = p['image_id'], p['caption']
52 | imgToEval[image_id]['caption'] = caption
53 | with open(cache_path, 'w') as outfile:
54 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
55 |
56 | return out
57 |
58 | def eval_split(model, crit, loader, eval_kwargs={}):
59 | verbose = eval_kwargs.get('verbose', True)
60 | verbose_beam = eval_kwargs.get('verbose_beam', 1)
61 | verbose_loss = eval_kwargs.get('verbose_loss', 1)
62 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
63 | split = eval_kwargs.get('split', 'val')
64 | lang_eval = eval_kwargs.get('language_eval', 0)
65 | dataset = eval_kwargs.get('dataset', 'coco')
66 | beam_size = eval_kwargs.get('beam_size', 1)
67 |
68 | # Make sure in the evaluation mode
69 | model.eval()
70 |
71 | loader.reset_iterator(split)
72 |
73 | n = 0
74 | loss = 0
75 | loss_sum = 0
76 | loss_evals = 1e-8
77 | predictions = []
78 | while True:
79 | data = loader.get_batch(split)
80 | n = n + loader.batch_size
81 |
82 | if data.get('labels', None) is not None and verbose_loss:
83 | # forward the model to get loss
84 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
85 | tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
86 | fc_feats, att_feats, labels, masks, att_masks = tmp
87 |
88 | with torch.no_grad():
89 | loss = crit(model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]).item()
90 | loss_sum = loss_sum + loss
91 | loss_evals = loss_evals + 1
92 |
93 | # forward the model to also get generated samples for each image
94 | # Only leave one feature for each image, in case duplicate sample
95 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
96 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img],
97 | data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None]
98 | tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
99 | fc_feats, att_feats, att_masks = tmp
100 | # forward the model to also get generated samples for each image
101 | with torch.no_grad():
102 | seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data
103 |
104 | # Print beam search
105 | if beam_size > 1 and verbose_beam:
106 | for i in range(loader.batch_size):
107 | print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
108 | print('--' * 10)
109 | sents = utils.decode_sequence(loader.get_vocab(), seq)
110 |
111 | for k, sent in enumerate(sents):
112 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
113 | if eval_kwargs.get('dump_path', 0) == 1:
114 | entry['file_name'] = data['infos'][k]['file_path']
115 | predictions.append(entry)
116 | if eval_kwargs.get('dump_images', 0) == 1:
117 | # dump the raw image to vis/ folder
118 | cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
119 | print(cmd)
120 | os.system(cmd)
121 |
122 | if verbose:
123 | print('image %s: %s' %(entry['image_id'], entry['caption']))
124 |
125 | # if we wrapped around the split or used up val imgs budget then bail
126 | ix0 = data['bounds']['it_pos_now']
127 | ix1 = data['bounds']['it_max']
128 | if num_images != -1:
129 | ix1 = min(ix1, num_images)
130 | for i in range(n - ix1):
131 | predictions.pop()
132 |
133 | if verbose:
134 |             print('evaluating validation performance... %d/%d (%f)' %(ix0 - 1, ix1, loss))
135 |
136 | if data['bounds']['wrapped']:
137 | break
138 | if num_images >= 0 and n >= num_images:
139 | break
140 |
141 | lang_stats = None
142 | if lang_eval == 1:
143 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split)
144 |
145 | # Switch back to training mode
146 | model.train()
147 | return loss_sum/loss_evals, predictions, lang_stats
148 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import numpy as np
7 |
8 | import time
9 | import os
10 | from six.moves import cPickle
11 |
12 | import opts
13 | import models
14 | from dataloader import *
15 | from dataloaderraw import *
16 | import eval_utils
17 | import argparse
18 | import misc.utils as utils
19 | import torch
20 |
21 | # Input arguments and options
22 | parser = argparse.ArgumentParser()
23 | # Input paths
24 | parser.add_argument('--model', type=str, default='',
25 | help='path to model to evaluate')
26 | parser.add_argument('--cnn_model', type=str, default='resnet101',
27 | help='resnet101, resnet152')
28 | parser.add_argument('--infos_path', type=str, default='',
29 | help='path to infos to evaluate')
30 | # Basic options
31 | parser.add_argument('--batch_size', type=int, default=0,
32 |                 help='if > 0 then override the checkpoint batch size, otherwise use the batch size from the checkpoint.')
33 | parser.add_argument('--num_images', type=int, default=-1,
34 | help='how many images to use when periodically evaluating the loss? (-1 = all)')
35 | parser.add_argument('--language_eval', type=int, default=0,
36 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
37 | parser.add_argument('--dump_images', type=int, default=1,
38 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
39 | parser.add_argument('--dump_json', type=int, default=1,
40 | help='Dump json with predictions into vis folder? (1=yes,0=no)')
41 | parser.add_argument('--dump_path', type=int, default=0,
42 | help='Write image paths along with predictions into vis json? (1=yes,0=no)')
43 |
44 | # Sampling options
45 | parser.add_argument('--sample_max', type=int, default=1,
46 | help='1 = sample argmax words. 0 = sample from distributions.')
47 | parser.add_argument('--max_ppl', type=int, default=0,
48 | help='beam search by max perplexity or max probability.')
49 | parser.add_argument('--beam_size', type=int, default=2,
50 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
51 | parser.add_argument('--group_size', type=int, default=1,
52 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
53 | parser.add_argument('--diversity_lambda', type=float, default=0.5,
54 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
55 | parser.add_argument('--temperature', type=float, default=1.0,
56 | help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
57 | parser.add_argument('--decoding_constraint', type=int, default=0,
58 | help='If 1, not allowing same word in a row')
59 | # For evaluation on a folder of images:
60 | parser.add_argument('--image_folder', type=str, default='',
61 | help='If this is nonempty then will predict on the images in this folder path')
62 | parser.add_argument('--image_root', type=str, default='',
63 |                 help='In case the image paths have to be prepended with a root path to an image folder')
64 | # For evaluation on MSCOCO images from some split:
65 | parser.add_argument('--input_fc_dir', type=str, default='',
66 |                 help='path to the directory containing the preprocessed fc feats')
67 | parser.add_argument('--input_att_dir', type=str, default='',
68 |                 help='path to the directory containing the preprocessed att feats')
69 | parser.add_argument('--input_box_dir', type=str, default='',
70 |                 help='path to the directory containing the boxes for the att feats')
71 | parser.add_argument('--input_label_h5', type=str, default='',
72 |                 help='path to the h5file containing the preprocessed labels')
73 | parser.add_argument('--input_json', type=str, default='',
74 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
75 | parser.add_argument('--split', type=str, default='test',
76 | help='if running on MSCOCO images, which split to use: val|test|train')
77 | parser.add_argument('--coco_json', type=str, default='',
78 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
79 | # misc
80 | parser.add_argument('--id', type=str, default='',
81 | help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
82 | parser.add_argument('--verbose_beam', type=int, default=1,
83 | help='if we need to print out all beam search beams.')
84 | parser.add_argument('--verbose_loss', type=int, default=0,
85 | help='if we need to calculate loss.')
86 |
87 | opt = parser.parse_args()
88 |
89 | # Load infos
90 | with open(opt.infos_path, 'rb') as f:
91 | infos = cPickle.load(f)
92 |
93 | # override and collect parameters
94 | if len(opt.input_fc_dir) == 0:
95 | opt.input_fc_dir = infos['opt'].input_fc_dir
96 | opt.input_att_dir = infos['opt'].input_att_dir
97 | opt.input_box_dir = infos['opt'].input_box_dir
98 | opt.input_label_h5 = infos['opt'].input_label_h5
99 | if len(opt.input_json) == 0:
100 | opt.input_json = infos['opt'].input_json
101 | if opt.batch_size == 0:
102 | opt.batch_size = infos['opt'].batch_size
103 | if len(opt.id) == 0:
104 | opt.id = infos['opt'].id
105 | ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval"]
106 | for k in vars(infos['opt']).keys():
107 | if k not in ignore:
108 | if k in vars(opt):
109 | assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
110 | else:
111 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model
112 |
113 | vocab = infos['vocab'] # ix -> word mapping
114 |
115 | # Setup the model
116 | model = models.setup(opt)
117 | model.load_state_dict(torch.load(opt.model))
118 | model.cuda()
119 | model.eval()
120 | crit = utils.LanguageModelCriterion()
121 |
122 | # Create the Data Loader instance
123 | if len(opt.image_folder) == 0:
124 | loader = DataLoader(opt)
125 | else:
126 | loader = DataLoaderRaw({'folder_path': opt.image_folder,
127 | 'coco_json': opt.coco_json,
128 | 'batch_size': opt.batch_size,
129 | 'cnn_model': opt.cnn_model})
130 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
131 | # So make sure to use the vocab in infos file.
132 | loader.ix_to_word = infos['vocab']
133 |
134 |
135 | # Set sample options
136 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
137 | vars(opt))
138 |
139 | print('loss: ', loss)
140 | if lang_stats:
141 | print(lang_stats)
142 |
143 | if opt.dump_json == 1:
144 | # dump the json
145 | json.dump(split_predictions, open('vis/vis.json', 'w'))
146 |
--------------------------------------------------------------------------------
/eval_ensemble.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import numpy as np
7 |
8 | import time
9 | import os
10 | from six.moves import cPickle
11 |
12 | import opts
13 | import models
14 | from dataloader import *
15 | from dataloaderraw import *
16 | import eval_utils
17 | import argparse
18 | import misc.utils as utils
19 | import torch
20 |
21 | # Input arguments and options
22 | parser = argparse.ArgumentParser()
23 | # Input paths
24 | parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble')
25 | # parser.add_argument('--models', nargs='+', required=True
26 | # help='path to model to evaluate')
27 | # parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate')
28 | # Basic options
29 | parser.add_argument('--batch_size', type=int, default=0,
30 |                 help='if > 0 then override the checkpoint batch size, otherwise use the batch size from the checkpoint.')
31 | parser.add_argument('--num_images', type=int, default=-1,
32 | help='how many images to use when periodically evaluating the loss? (-1 = all)')
33 | parser.add_argument('--language_eval', type=int, default=0,
34 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
35 | parser.add_argument('--dump_images', type=int, default=1,
36 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
37 | parser.add_argument('--dump_json', type=int, default=1,
38 | help='Dump json with predictions into vis folder? (1=yes,0=no)')
39 | parser.add_argument('--dump_path', type=int, default=0,
40 | help='Write image paths along with predictions into vis json? (1=yes,0=no)')
41 |
42 | # Sampling options
43 | parser.add_argument('--sample_max', type=int, default=1,
44 | help='1 = sample argmax words. 0 = sample from distributions.')
45 | parser.add_argument('--max_ppl', type=int, default=0,
46 | help='beam search by max perplexity or max probability.')
47 | parser.add_argument('--beam_size', type=int, default=2,
48 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
49 | parser.add_argument('--group_size', type=int, default=1,
50 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
51 | parser.add_argument('--diversity_lambda', type=float, default=0.5,
52 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
53 | parser.add_argument('--temperature', type=float, default=1.0,
54 | help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
55 | parser.add_argument('--decoding_constraint', type=int, default=0,
56 | help='If 1, not allowing same word in a row')
57 | # For evaluation on a folder of images:
58 | parser.add_argument('--image_folder', type=str, default='',
59 | help='If this is nonempty then will predict on the images in this folder path')
60 | parser.add_argument('--image_root', type=str, default='',
61 |                 help='In case the image paths have to be prepended with a root path to an image folder')
62 | # For evaluation on MSCOCO images from some split:
63 | parser.add_argument('--input_fc_dir', type=str, default='',
64 |                 help='path to the directory containing the preprocessed fc feats')
65 | parser.add_argument('--input_att_dir', type=str, default='',
66 |                 help='path to the directory containing the preprocessed att feats')
67 | parser.add_argument('--input_box_dir', type=str, default='',
68 |                 help='path to the directory containing the boxes for the att feats')
69 | parser.add_argument('--input_label_h5', type=str, default='',
70 |                 help='path to the h5file containing the preprocessed labels')
71 | parser.add_argument('--input_json', type=str, default='',
72 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
73 | parser.add_argument('--split', type=str, default='test',
74 | help='if running on MSCOCO images, which split to use: val|test|train')
75 | parser.add_argument('--coco_json', type=str, default='',
76 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
77 | parser.add_argument('--seq_length', type=int, default=40,
78 | help='maximum sequence length during sampling')
79 | # misc
80 | parser.add_argument('--id', type=str, default='',
81 | help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
82 | parser.add_argument('--verbose_beam', type=int, default=1,
83 | help='if we need to print out all beam search beams.')
84 | parser.add_argument('--verbose_loss', type=int, default=0,
85 | help='If calculate loss using ground truth during evaluation')
86 |
87 | opt = parser.parse_args()
88 |
89 | model_infos = [cPickle.load(open('log_%s/infos_%s-best.pkl' % (id, id), 'rb')) for id in opt.ids]
90 | model_paths = ['log_%s/model-best.pth' %(id) for id in opt.ids]
91 |
92 | # Load one infos
93 | infos = model_infos[0]
94 |
95 | # override and collect parameters
96 | if len(opt.input_fc_dir) == 0:
97 | opt.input_fc_dir = infos['opt'].input_fc_dir
98 | opt.input_att_dir = infos['opt'].input_att_dir
99 | opt.input_box_dir = infos['opt'].input_box_dir
100 | opt.input_label_h5 = infos['opt'].input_label_h5
101 | if len(opt.input_json) == 0:
102 | opt.input_json = infos['opt'].input_json
103 | if opt.batch_size == 0:
104 | opt.batch_size = infos['opt'].batch_size
105 | if len(opt.id) == 0:
106 | opt.id = infos['opt'].id
107 |
108 | vars(opt).update({k: vars(infos['opt'])[k] for k in vars(infos['opt']).keys() if k not in vars(opt)}) # copy over options from model
109 |
110 |
111 | opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos])
112 | assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == min([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Models must agree on norm_att_feat'
113 | assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == min([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Models must agree on norm_box_feat'
114 |
115 | vocab = infos['vocab'] # ix -> word mapping
116 |
117 | # Setup the model
118 | from models.AttEnsemble import AttEnsemble
119 |
120 | _models = []
121 | for i in range(len(model_infos)):
122 | model_infos[i]['opt'].start_from = None
123 | tmp = models.setup(model_infos[i]['opt'])
124 | tmp.load_state_dict(torch.load(model_paths[i]))
125 | tmp.cuda()
126 | tmp.eval()
127 | _models.append(tmp)
128 |
129 | model = AttEnsemble(_models)
130 | model.seq_length = opt.seq_length
131 | model.eval()
132 | crit = utils.LanguageModelCriterion()
133 |
134 | # Create the Data Loader instance
135 | if len(opt.image_folder) == 0:
136 | loader = DataLoader(opt)
137 | else:
138 | loader = DataLoaderRaw({'folder_path': opt.image_folder,
139 | 'coco_json': opt.coco_json,
140 | 'batch_size': opt.batch_size,
141 | 'cnn_model': opt.cnn_model})
142 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
143 | # So make sure to use the vocab in infos file.
144 | loader.ix_to_word = infos['vocab']
145 |
146 |
147 | # Set sample options
148 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader,
149 | vars(opt))
150 |
151 | print('loss: ', loss)
152 | if lang_stats:
153 | print(lang_stats)
154 |
155 | if opt.dump_json == 1:
156 | # dump the json
157 | json.dump(split_predictions, open('vis/vis.json', 'w'))
158 |
--------------------------------------------------------------------------------
/models/ShowTellModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import *
9 | import misc.utils as utils
10 |
11 | from .CaptionModel import CaptionModel
12 |
13 | class ShowTellModel(CaptionModel):
14 | def __init__(self, opt):
15 | super(ShowTellModel, self).__init__()
16 | self.vocab_size = opt.vocab_size
17 | self.input_encoding_size = opt.input_encoding_size
18 | self.rnn_type = opt.rnn_type
19 | self.rnn_size = opt.rnn_size
20 | self.num_layers = opt.num_layers
21 | self.drop_prob_lm = opt.drop_prob_lm
22 | self.seq_length = opt.seq_length
23 | self.fc_feat_size = opt.fc_feat_size
24 |
25 | self.ss_prob = 0.0 # Schedule sampling probability
26 |
27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
31 | self.dropout = nn.Dropout(self.drop_prob_lm)
32 |
33 | self.init_weights()
34 |
35 | def init_weights(self):
36 | initrange = 0.1
37 | self.embed.weight.data.uniform_(-initrange, initrange)
38 | self.logit.bias.data.fill_(0)
39 | self.logit.weight.data.uniform_(-initrange, initrange)
40 |
41 | def init_hidden(self, bsz):
42 | weight = next(self.parameters()).data
43 | if self.rnn_type == 'lstm':
44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size))
46 | else:
47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
48 |
49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
50 | batch_size = fc_feats.size(0)
51 | state = self.init_hidden(batch_size)
52 | outputs = []
53 |
54 | for i in range(seq.size(1)):
55 | if i == 0:
56 | xt = self.img_embed(fc_feats)
57 | else:
58 |                 if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
59 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
60 | sample_mask = sample_prob < self.ss_prob
61 | if sample_mask.sum() == 0:
62 | it = seq[:, i-1].clone()
63 | else:
64 | sample_ind = sample_mask.nonzero().view(-1)
65 | it = seq[:, i-1].data.clone()
66 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
67 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
68 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
69 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
70 | else:
71 | it = seq[:, i-1].clone()
72 | # break if all the sequences end
73 | if i >= 2 and seq[:, i-1].data.sum() == 0:
74 | break
75 | xt = self.embed(it)
76 |
77 | output, state = self.core(xt.unsqueeze(0), state)
78 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
79 | outputs.append(output)
80 |
81 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
82 |
83 | def get_logprobs_state(self, it, state):
84 | # 'it' contains a word index
85 | xt = self.embed(it)
86 |
87 | output, state = self.core(xt.unsqueeze(0), state)
88 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
89 |
90 | return logprobs, state
91 |
92 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
93 | beam_size = opt.get('beam_size', 10)
94 | batch_size = fc_feats.size(0)
95 |
96 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
97 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
98 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
99 | # lets process every image independently for now, for simplicity
100 |
101 | self.done_beams = [[] for _ in range(batch_size)]
102 | for k in range(batch_size):
103 | state = self.init_hidden(beam_size)
104 | for t in range(2):
105 | if t == 0:
106 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
107 |                 elif t == 1: # input <bos>
108 | it = fc_feats.data.new(beam_size).long().zero_()
109 | xt = self.embed(it)
110 |
111 | output, state = self.core(xt.unsqueeze(0), state)
112 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
113 |
114 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
115 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
116 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
117 | # return the samples and their log likelihoods
118 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
119 |
120 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
121 | sample_max = opt.get('sample_max', 1)
122 | beam_size = opt.get('beam_size', 1)
123 | temperature = opt.get('temperature', 1.0)
124 | if beam_size > 1:
125 |             return self._sample_beam(fc_feats, att_feats, att_masks, opt)
126 |
127 | batch_size = fc_feats.size(0)
128 | state = self.init_hidden(batch_size)
129 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
130 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
131 | for t in range(self.seq_length + 2):
132 | if t == 0:
133 | xt = self.img_embed(fc_feats)
134 | else:
135 |                 if t == 1: # input <bos>
136 | it = fc_feats.data.new(batch_size).long().zero_()
137 | xt = self.embed(it)
138 |
139 | output, state = self.core(xt.unsqueeze(0), state)
140 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
141 |
142 | # sample the next word
143 | if t == self.seq_length + 1: # skip if we achieve maximum length
144 | break
145 | if sample_max:
146 | sampleLogprobs, it = torch.max(logprobs.data, 1)
147 | it = it.view(-1).long()
148 | else:
149 | if temperature == 1.0:
150 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
151 | else:
152 | # scale logprobs by temperature
153 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
154 | it = torch.multinomial(prob_prev, 1).cuda()
155 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
156 | it = it.view(-1).long() # and flatten indices for downstream processing
157 |
158 | if t >= 1:
159 | # stop when all finished
160 | if t == 1:
161 | unfinished = it > 0
162 | else:
163 | unfinished = unfinished * (it > 0)
164 | it = it * unfinished.type_as(it)
165 | seq[:,t-1] = it #seq[t] the input of t+2 time step
166 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
167 | if unfinished.sum() == 0:
168 | break
169 |
170 | return seq, seqLogprobs
--------------------------------------------------------------------------------
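
A tiny, self-contained smoke test for ShowTellModel (not from the repository). The option values are arbitrary, chosen only so the module can be instantiated, the repo root is assumed to be on PYTHONPATH, and _forward is called directly rather than through the CaptionModel dispatch.

    from argparse import Namespace
    import torch
    from models.ShowTellModel import ShowTellModel

    # arbitrary, tiny hyperparameters just to build the model
    opt = Namespace(vocab_size=10, input_encoding_size=8, rnn_type='lstm',
                    rnn_size=8, num_layers=1, drop_prob_lm=0.5,
                    seq_length=5, fc_feat_size=16)
    model = ShowTellModel(opt)

    fc_feats = torch.randn(2, 16)          # (batch, fc_feat_size)
    att_feats = torch.zeros(2, 1, 16)      # unused by ShowTellModel
    seq = torch.randint(1, 11, (2, 7))     # word indices in 1..vocab_size
    logprobs = model._forward(fc_feats, att_feats, seq)
    print(logprobs.shape)                  # (2, 6, 11) = (batch, seq_len - 1, vocab_size + 1)
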
/misc/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import torch.optim as optim
10 |
11 | def if_use_att(caption_model):
12 | # Decide if load attention feature according to caption model
13 | if caption_model in ['show_tell', 'all_img', 'fc']:
14 | return False
15 | return True
16 |
17 | # Input: seq, an N*D torch tensor with elements in 0 .. vocab_size, where 0 is the END token.
18 | def decode_sequence(ix_to_word, seq):
19 | N, D = seq.size()
20 | out = []
21 | for i in range(N):
22 | txt = ''
23 | for j in range(D):
24 | ix = seq[i,j]
25 | if ix > 0 :
26 | if j >= 1:
27 | txt = txt + ' '
28 | txt = txt + ix_to_word[str(ix.item())]
29 | else:
30 | break
31 | out.append(txt)
32 | return out
33 |
34 | def to_contiguous(tensor):
35 | if tensor.is_contiguous():
36 | return tensor
37 | else:
38 | return tensor.contiguous()
39 |
40 | class RewardCriterion(nn.Module):
41 | def __init__(self):
42 | super(RewardCriterion, self).__init__()
43 |
44 | def forward(self, input, seq, reward):
45 | input = to_contiguous(input).view(-1)
46 | reward = to_contiguous(reward).view(-1)
47 | mask = (seq>0).float()
48 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1)
49 | output = - input * reward * mask
50 | output = torch.sum(output) / torch.sum(mask)
51 |
52 | return output
53 |
54 | class LanguageModelCriterion(nn.Module):
55 | def __init__(self):
56 | super(LanguageModelCriterion, self).__init__()
57 |
58 | def forward(self, input, target, mask):
59 | # truncate to the same size
60 | target = target[:, :input.size(1)]
61 | mask = mask[:, :input.size(1)]
62 |
63 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
64 | output = torch.sum(output) / torch.sum(mask)
65 |
66 | return output
67 |
68 | class LabelSmoothing(nn.Module):
69 | "Implement label smoothing."
70 | def __init__(self, size=0, padding_idx=0, smoothing=0.0):
71 | super(LabelSmoothing, self).__init__()
72 |         self.criterion = nn.KLDivLoss(reduction='none')
73 | # self.padding_idx = padding_idx
74 | self.confidence = 1.0 - smoothing
75 | self.smoothing = smoothing
76 | # self.size = size
77 | self.true_dist = None
78 |
79 | def forward(self, input, target, mask):
80 | # truncate to the same size
81 | target = target[:, :input.size(1)]
82 | mask = mask[:, :input.size(1)]
83 |
84 | input = to_contiguous(input).view(-1, input.size(-1))
85 | target = to_contiguous(target).view(-1)
86 | mask = to_contiguous(mask).view(-1)
87 |
88 | # assert x.size(1) == self.size
89 | self.size = input.size(1)
90 | # true_dist = x.data.clone()
91 | true_dist = input.data.clone()
92 | # true_dist.fill_(self.smoothing / (self.size - 2))
93 | true_dist.fill_(self.smoothing / (self.size - 1))
94 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
95 | # true_dist[:, self.padding_idx] = 0
96 | # mask = torch.nonzero(target.data == self.padding_idx)
97 | # self.true_dist = true_dist
98 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
99 |
100 | def set_lr(optimizer, lr):
101 | for group in optimizer.param_groups:
102 | group['lr'] = lr
103 |
104 | def get_lr(optimizer):
105 | for group in optimizer.param_groups:
106 | return group['lr']
107 |
108 | def clip_gradient(optimizer, grad_clip):
109 | for group in optimizer.param_groups:
110 | for param in group['params']:
111 | param.grad.data.clamp_(-grad_clip, grad_clip)
112 |
113 | def build_optimizer(params, opt):
114 | if opt.optim == 'rmsprop':
115 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
116 | elif opt.optim == 'adagrad':
117 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
118 | elif opt.optim == 'sgd':
119 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
120 | elif opt.optim == 'sgdm':
121 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
122 | elif opt.optim == 'sgdmom':
123 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
124 | elif opt.optim == 'adam':
125 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
126 | else:
127 | raise Exception("bad option opt.optim: {}".format(opt.optim))
128 |
129 |
130 | class NoamOpt(object):
131 | "Optim wrapper that implements rate."
132 | def __init__(self, model_size, factor, warmup, optimizer):
133 | self.optimizer = optimizer
134 | self._step = 0
135 | self.warmup = warmup
136 | self.factor = factor
137 | self.model_size = model_size
138 | self._rate = 0
139 |
140 | def step(self):
141 | "Update parameters and rate"
142 | self._step += 1
143 | rate = self.rate()
144 | for p in self.optimizer.param_groups:
145 | p['lr'] = rate
146 | self._rate = rate
147 | self.optimizer.step()
148 |
149 | def rate(self, step = None):
150 | "Implement `lrate` above"
151 | if step is None:
152 | step = self._step
153 | return self.factor * \
154 | (self.model_size ** (-0.5) *
155 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
156 |
157 | def __getattr__(self, name):
158 | return getattr(self.optimizer, name)
159 |
160 | class ReduceLROnPlateau(object):
161 | "Optim wrapper that implements rate."
162 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
163 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
164 | self.optimizer = optimizer
165 | self.current_lr = get_lr(optimizer)
166 |
167 | def step(self):
168 | "Update parameters and rate"
169 | self.optimizer.step()
170 |
171 | def scheduler_step(self, val):
172 | self.scheduler.step(val)
173 | self.current_lr = get_lr(self.optimizer)
174 |
175 | def state_dict(self):
176 | return {'current_lr':self.current_lr,
177 | 'scheduler_state_dict': {key: value for key, value in self.scheduler.__dict__.items() if key not in {'optimizer', 'is_better'}},
178 | 'optimizer_state_dict': self.optimizer.state_dict()}
179 |
180 | def load_state_dict(self, state_dict):
181 | if 'current_lr' not in state_dict:
182 | # it's normal optimizer
183 | self.optimizer.load_state_dict(state_dict)
184 |             set_lr(self.optimizer, self.current_lr) # use the lr from the option
185 |         else:
186 |             # it's a scheduler
187 | self.current_lr = state_dict['current_lr']
188 | self.scheduler.__dict__.update(state_dict['scheduler_state_dict'])
189 | self.scheduler._init_is_better(mode=self.scheduler.mode, threshold=self.scheduler.threshold, threshold_mode=self.scheduler.threshold_mode)
190 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
191 | # current_lr is actually useless in this case
192 |
193 | def rate(self, step = None):
194 | "Implement `lrate` above"
195 | if step is None:
196 | step = self._step
197 | return self.factor * \
198 | (self.model_size ** (-0.5) *
199 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
200 |
201 | def __getattr__(self, name):
202 | return getattr(self.optimizer, name)
203 |
204 | def get_std_opt(model, factor=1, warmup=2000):
205 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
206 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
207 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup,
208 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
209 |
--------------------------------------------------------------------------------
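
The NoamOpt wrapper above sets lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5) at every step. A small sketch of that schedule with illustrative values (model_size=512, and the default --noamopt_warmup / --noamopt_factor of 2000 / 1):

    def noam_rate(step, model_size=512, factor=1.0, warmup=2000):
        return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

    for step in (1, 500, 2000, 10000):
        # rises roughly linearly until `warmup`, then decays as step^-0.5
        print(step, noam_rate(step))
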
/scripts/prepro_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in dataloader.py
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays
11 |
12 | Output: a json file and an hdf5 file
13 | The hdf5 file contains several fields:
14 | (note: no raw image data is written; only the label arrays listed below)
15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded
16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the
17 | first and last indices (in range 1..M) of labels for each image
18 | /label_length stores the length of the sequence for each of the M sequences
19 |
20 | The json file has a dict that contains:
21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed
22 | - an 'images' field that is a list holding auxiliary information for each image,
23 | such as in particular the 'split' it was assigned to.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import json
32 | import argparse
33 | from random import shuffle, seed
34 | import string
35 | # non-standard dependencies:
36 | import h5py
37 | import numpy as np
38 | import torch
39 | import torchvision.models as models
40 | import skimage.io
41 | from PIL import Image
42 |
43 | def build_vocab(imgs, params):
44 | count_thr = params['word_count_threshold']
45 |
46 | # count up the number of words
47 | counts = {}
48 | for img in imgs:
49 | for sent in img['sentences']:
50 | for w in sent['tokens']:
51 | counts[w] = counts.get(w, 0) + 1
52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True)
53 | print('top words and their counts:')
54 | print('\n'.join(map(str,cw[:20])))
55 |
56 | # print some stats
57 | total_words = sum(counts.values())
58 | print('total words:', total_words)
59 | bad_words = [w for w,n in counts.items() if n <= count_thr]
60 | vocab = [w for w,n in counts.items() if n > count_thr]
61 | bad_count = sum(counts[w] for w in bad_words)
62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
63 | print('number of words in vocab would be %d' % (len(vocab), ))
64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
65 |
66 | # lets look at the distribution of lengths as well
67 | sent_lengths = {}
68 | for img in imgs:
69 | for sent in img['sentences']:
70 | txt = sent['tokens']
71 | nw = len(txt)
72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1
73 | max_len = max(sent_lengths.keys())
74 | print('max length sentence in raw data: ', max_len)
75 | print('sentence length distribution (count, number of words):')
76 | sum_len = sum(sent_lengths.values())
77 | for i in range(max_len+1):
78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
79 |
80 | # lets now produce the final annotations
81 | if bad_count > 0:
82 | # additional special UNK token we will use below to map infrequent words to
83 | print('inserting the special UNK token')
84 | vocab.append('UNK')
85 |
86 | for img in imgs:
87 | img['final_captions'] = []
88 | for sent in img['sentences']:
89 | txt = sent['tokens']
90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt]
91 | img['final_captions'].append(caption)
92 |
93 | return vocab
94 |
95 | def encode_captions(imgs, params, wtoi):
96 | """
97 | encode all captions into one large array, which will be 1-indexed.
98 | also produces label_start_ix and label_end_ix which store 1-indexed
99 | and inclusive (Lua-style) pointers to the first and last caption for
100 | each image in the dataset.
101 | """
102 |
103 | max_length = params['max_length']
104 | N = len(imgs)
105 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions
106 |
107 | label_arrays = []
108 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
109 | label_end_ix = np.zeros(N, dtype='uint32')
110 | label_length = np.zeros(M, dtype='uint32')
111 | caption_counter = 0
112 | counter = 1
113 | for i,img in enumerate(imgs):
114 | n = len(img['final_captions'])
115 | assert n > 0, 'error: some image has no captions'
116 |
117 | Li = np.zeros((n, max_length), dtype='uint32')
118 | for j,s in enumerate(img['final_captions']):
119 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence
120 | caption_counter += 1
121 | for k,w in enumerate(s):
122 | if k < max_length:
123 | Li[j,k] = wtoi[w]
124 |
125 | # note: word indices are 1-indexed, and captions are padded with zeros
126 | label_arrays.append(Li)
127 | label_start_ix[i] = counter
128 | label_end_ix[i] = counter + n - 1
129 |
130 | counter += n
131 |
132 | L = np.concatenate(label_arrays, axis=0) # put all the labels together
133 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird'
134 | assert np.all(label_length > 0), 'error: some caption had no words?'
135 |
136 | print('encoded captions to array of size ', L.shape)
137 | return L, label_start_ix, label_end_ix, label_length
138 |
139 | def main(params):
140 |
141 | imgs = json.load(open(params['input_json'], 'r'))
142 | imgs = imgs['images']
143 |
144 | seed(123) # make reproducible
145 |
146 | # create the vocab
147 | vocab = build_vocab(imgs, params)
148 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
149 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table
150 |
151 | # encode captions in large arrays, ready to ship to hdf5 file
152 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
153 |
154 | # create output h5 file
155 | N = len(imgs)
156 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w")
157 | f_lb.create_dataset("labels", dtype='uint32', data=L)
158 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix)
159 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix)
160 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length)
161 | f_lb.close()
162 |
163 | # create output json file
164 | out = {}
165 | out['ix_to_word'] = itow # encode the (1-indexed) vocab
166 | out['images'] = []
167 | for i,img in enumerate(imgs):
168 |
169 | jimg = {}
170 | jimg['split'] = img['split']
171 |         if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might be needed later
172 |         if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful)
173 |
174 | if params['images_root'] != '':
175 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img:
176 | jimg['width'], jimg['height'] = _img.size
177 |
178 | out['images'].append(jimg)
179 |
180 | json.dump(out, open(params['output_json'], 'w'))
181 | print('wrote ', params['output_json'])
182 |
183 | if __name__ == "__main__":
184 |
185 | parser = argparse.ArgumentParser()
186 |
187 | # input json
188 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5')
189 | parser.add_argument('--output_json', default='data.json', help='output json file')
190 | parser.add_argument('--output_h5', default='data', help='output h5 file')
191 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json')
192 |
193 | # options
194 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.')
195 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab')
196 |
197 | args = parser.parse_args()
198 | params = vars(args) # convert to ordinary dict
199 | print('parsed input parameters:')
200 | print(json.dumps(params, indent = 2))
201 | main(params)
202 |
--------------------------------------------------------------------------------
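
The docstring above describes label_start_ix/label_end_ix as 1-indexed, inclusive pointers into the label array. A minimal sketch of reading one image's captions back out of the files this script writes, assuming the default --output_h5/--output_json names ('data' and 'data.json') and an arbitrary image index:

    import json
    import h5py

    info = json.load(open('data.json'))
    ix_to_word = info['ix_to_word']                  # 1-indexed vocab; 0 is the pad/END index

    with h5py.File('data_label.h5', 'r') as f:
        i = 0                                        # image index
        start = f['label_start_ix'][i] - 1           # convert the 1-indexed pointer to 0-indexed
        end = f['label_end_ix'][i]                   # inclusive, so no -1 here
        for row in f['labels'][start:end]:
            print(' '.join(ix_to_word[str(ix)] for ix in row if ix > 0))
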
/models/FCModel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.autograd import *
9 | import misc.utils as utils
10 |
11 | from .CaptionModel import CaptionModel
12 |
13 | class LSTMCore(nn.Module):
14 | def __init__(self, opt):
15 | super(LSTMCore, self).__init__()
16 | self.input_encoding_size = opt.input_encoding_size
17 | self.rnn_size = opt.rnn_size
18 | self.drop_prob_lm = opt.drop_prob_lm
19 |
20 | # Build a LSTM
21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
23 | self.dropout = nn.Dropout(self.drop_prob_lm)
24 |
25 | def forward(self, xt, state):
26 |
27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
29 | sigmoid_chunk = F.sigmoid(sigmoid_chunk)
30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
33 |
34 | in_transform = torch.max(\
35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform
38 | next_h = out_gate * F.tanh(next_c)
39 |
40 | output = self.dropout(next_h)
41 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
42 | return output, state
43 |
44 | class FCModel(CaptionModel):
45 | def __init__(self, opt):
46 | super(FCModel, self).__init__()
47 | self.vocab_size = opt.vocab_size
48 | self.input_encoding_size = opt.input_encoding_size
49 | self.rnn_type = opt.rnn_type
50 | self.rnn_size = opt.rnn_size
51 | self.num_layers = opt.num_layers
52 | self.drop_prob_lm = opt.drop_prob_lm
53 | self.seq_length = opt.seq_length
54 | self.fc_feat_size = opt.fc_feat_size
55 |
56 | self.ss_prob = 0.0 # Schedule sampling probability
57 |
58 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
59 | self.core = LSTMCore(opt)
60 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
61 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
62 |
63 | self.init_weights()
64 |
65 | def init_weights(self):
66 | initrange = 0.1
67 | self.embed.weight.data.uniform_(-initrange, initrange)
68 | self.logit.bias.data.fill_(0)
69 | self.logit.weight.data.uniform_(-initrange, initrange)
70 |
71 | def init_hidden(self, bsz):
72 | weight = next(self.parameters())
73 | if self.rnn_type == 'lstm':
74 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
75 | weight.new_zeros(self.num_layers, bsz, self.rnn_size))
76 | else:
77 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
78 |
79 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
80 | batch_size = fc_feats.size(0)
81 | state = self.init_hidden(batch_size)
82 | outputs = []
83 |
84 | for i in range(seq.size(1)):
85 | if i == 0:
86 | xt = self.img_embed(fc_feats)
87 | else:
88 |                 if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
89 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
90 | sample_mask = sample_prob < self.ss_prob
91 | if sample_mask.sum() == 0:
92 | it = seq[:, i-1].clone()
93 | else:
94 | sample_ind = sample_mask.nonzero().view(-1)
95 | it = seq[:, i-1].data.clone()
96 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
97 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
98 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
99 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
100 | else:
101 | it = seq[:, i-1].clone()
102 | # break if all the sequences end
103 | if i >= 2 and seq[:, i-1].sum() == 0:
104 | break
105 | xt = self.embed(it)
106 |
107 | output, state = self.core(xt, state)
108 | output = F.log_softmax(self.logit(output), dim=1)
109 | outputs.append(output)
110 |
111 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
112 |
113 | def get_logprobs_state(self, it, state):
114 |         # 'it' contains a word index
115 | xt = self.embed(it)
116 |
117 | output, state = self.core(xt, state)
118 | logprobs = F.log_softmax(self.logit(output), dim=1)
119 |
120 | return logprobs, state
121 |
122 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
123 | beam_size = opt.get('beam_size', 10)
124 | batch_size = fc_feats.size(0)
125 |
126 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
127 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
128 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
129 | # lets process every image independently for now, for simplicity
130 |
131 | self.done_beams = [[] for _ in range(batch_size)]
132 | for k in range(batch_size):
133 | state = self.init_hidden(beam_size)
134 | for t in range(2):
135 | if t == 0:
136 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
137 | elif t == 1: # input
138 | it = fc_feats.data.new(beam_size).long().zero_()
139 | xt = self.embed(it)
140 |
141 | output, state = self.core(xt, state)
142 | logprobs = F.log_softmax(self.logit(output), dim=1)
143 |
144 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
145 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
146 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
147 | # return the samples and their log likelihoods
148 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
149 |
150 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
151 | sample_max = opt.get('sample_max', 1)
152 | beam_size = opt.get('beam_size', 1)
153 | temperature = opt.get('temperature', 1.0)
154 | if beam_size > 1:
155 |             return self._sample_beam(fc_feats, att_feats, opt)
156 |
157 | batch_size = fc_feats.size(0)
158 | state = self.init_hidden(batch_size)
159 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
160 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
161 | for t in range(self.seq_length + 2):
162 | if t == 0:
163 | xt = self.img_embed(fc_feats)
164 | else:
165 | if t == 1: # input
166 | it = fc_feats.data.new(batch_size).long().zero_()
167 | xt = self.embed(it)
168 |
169 | output, state = self.core(xt, state)
170 | logprobs = F.log_softmax(self.logit(output), dim=1)
171 |
172 |             # sample the next word
173 | if t == self.seq_length + 1: # skip if we achieve maximum length
174 | break
175 | if sample_max:
176 | sampleLogprobs, it = torch.max(logprobs.data, 1)
177 | it = it.view(-1).long()
178 | else:
179 | if temperature == 1.0:
180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
181 | else:
182 | # scale logprobs by temperature
183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
184 | it = torch.multinomial(prob_prev, 1).cuda()
185 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
186 | it = it.view(-1).long() # and flatten indices for downstream processing
187 |
188 | if t >= 1:
189 | # stop when all finished
190 | if t == 1:
191 | unfinished = it > 0
192 | else:
193 | unfinished = unfinished * (it > 0)
194 | it = it * unfinished.type_as(it)
195 | seq[:,t-1] = it #seq[t] the input of t+2 time step
196 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
197 | if unfinished.sum() == 0:
198 | break
199 |
200 | return seq, seqLogprobs
201 |
--------------------------------------------------------------------------------
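
LSTMCore above packs the five gate pre-activations into a single 5*rnn_size projection: chunks 0-2 become the sigmoid input/forget/output gates, and chunks 3-4 are combined with an elementwise max into the cell input. A minimal sketch of driving it on toy sizes (the Namespace fields and sizes are illustrative, and the import assumes the repository root is on the path):

    import torch
    from argparse import Namespace
    from models.FCModel import LSTMCore

    opt = Namespace(input_encoding_size=8, rnn_size=16, drop_prob_lm=0.5)
    core = LSTMCore(opt)
    xt = torch.randn(4, 8)                                   # a batch of 4 embedded tokens
    state = (torch.zeros(1, 4, 16), torch.zeros(1, 4, 16))   # (h, c) for a single layer
    output, state = core(xt, state)
    print(output.shape, state[0].shape)                      # (4, 16) and (1, 4, 16)
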
/models/CaptionModel.py:
--------------------------------------------------------------------------------
1 | # This file contains the CaptionModel base class shared by the captioning models
2 |
3 | # It dispatches forward/sample calls by mode and implements beam search,
4 | # including the diverse beam search variant used when group_size > 1
5 |
6 | # Model-specific subclasses (FCModel, AttModel, OldModel, ...) provide the recurrent
7 | # core and the get_logprobs_state used by beam_search below
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.autograd import *
16 | import misc.utils as utils
17 | from functools import reduce
18 |
19 | class CaptionModel(nn.Module):
20 | def __init__(self):
21 | super(CaptionModel, self).__init__()
22 |
23 | # implements beam search
24 | # calls beam_step and returns the final set of beams
25 | # augments log-probabilities with diversity terms when number of groups > 1
26 |
27 | def forward(self, *args, **kwargs):
28 | mode = kwargs.get('mode', 'forward')
29 | if 'mode' in kwargs:
30 | del kwargs['mode']
31 | return getattr(self, '_'+mode)(*args, **kwargs)
32 |
33 | def beam_search(self, init_state, init_logprobs, *args, **kwargs):
34 |
35 | # function computes the similarity score to be augmented
36 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
37 | local_time = t - divm
38 | unaug_logprobsf = logprobsf.clone()
39 | for prev_choice in range(divm):
40 | prev_decisions = beam_seq_table[prev_choice][local_time]
41 | for sub_beam in range(bdash):
42 | for prev_labels in range(bdash):
43 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
44 | return unaug_logprobsf
45 |
46 | # does one step of classical beam search
47 |
48 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
49 | #INPUTS:
50 | #logprobsf: probabilities augmented after diversity
51 | #beam_size: obvious
52 | #t : time instant
53 |             #beam_seq : tensor containing the beams
54 |             #beam_seq_logprobs: tensor containing the beam logprobs
55 |             #beam_logprobs_sum: tensor containing joint logprobs
56 |             #OUTPUTS:
57 | #beam_seq : tensor containing the word indices of the decoded captions
58 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
59 | #beam_logprobs_sum : joint log-probability of each beam
60 |
61 | ys,ix = torch.sort(logprobsf,1,True)
62 | candidates = []
63 | cols = min(beam_size, ys.size(1))
64 | rows = beam_size
65 | if t == 0:
66 | rows = 1
67 | for c in range(cols): # for each column (word, essentially)
68 | for q in range(rows): # for each beam expansion
69 | #compute logprob of expanding beam q with word in (sorted) position c
70 | local_logprob = ys[q,c].item()
71 | candidate_logprob = beam_logprobs_sum[q] + local_logprob
72 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
73 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob})
74 | candidates = sorted(candidates, key=lambda x: -x['p'])
75 |
76 | new_state = [_.clone() for _ in state]
77 | #beam_seq_prev, beam_seq_logprobs_prev
78 | if t >= 1:
79 |                 #we'll need these as a reference when we fork the beams
80 | beam_seq_prev = beam_seq[:t].clone()
81 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
82 | for vix in range(beam_size):
83 | v = candidates[vix]
84 | #fork beam index q into index vix
85 | if t >= 1:
86 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
87 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
88 | #rearrange recurrent states
89 | for state_ix in range(len(new_state)):
90 | # copy over state in previous beam q to new beam at vix
91 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
92 | #append new end terminal at the end of this beam
93 | beam_seq[t, vix] = v['c'] # c'th word is the continuation
94 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
95 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
96 | state = new_state
97 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
98 |
99 | # Start diverse_beam_search
100 | opt = kwargs['opt']
101 | beam_size = opt.get('beam_size', 10)
102 | group_size = opt.get('group_size', 1)
103 | diversity_lambda = opt.get('diversity_lambda', 0.5)
104 | decoding_constraint = opt.get('decoding_constraint', 0)
105 | max_ppl = opt.get('max_ppl', 0)
106 | bdash = beam_size // group_size # beam per group
107 |
108 | # INITIALIZATIONS
109 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
110 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
111 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
112 |
113 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
114 | done_beams_table = [[] for _ in range(group_size)]
115 | state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
116 | logprobs_table = list(init_logprobs.chunk(group_size, 0))
117 | # END INIT
118 |
119 | # Chunk elements in the args
120 | args = list(args)
121 | args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
122 | args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
123 |
124 | for t in range(self.seq_length + group_size - 1):
125 | for divm in range(group_size):
126 | if t >= divm and t <= self.seq_length + divm - 1:
127 | # add diversity
128 | logprobsf = logprobs_table[divm].data.float()
129 | # suppress previous word
130 | if decoding_constraint and t-divm > 0:
131 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf'))
132 | # suppress UNK tokens in the decoding
133 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
134 | # diversity is added here
135 | # the function directly modifies the logprobsf values and hence, we need to return
136 | # the unaugmented ones for sorting the candidates in the end. # for historical
137 | # reasons :-)
138 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
139 |
140 | # infer new beams
141 | beam_seq_table[divm],\
142 | beam_seq_logprobs_table[divm],\
143 | beam_logprobs_sum_table[divm],\
144 | state_table[divm],\
145 | candidates_divm = beam_step(logprobsf,
146 | unaug_logprobsf,
147 | bdash,
148 | t-divm,
149 | beam_seq_table[divm],
150 | beam_seq_logprobs_table[divm],
151 | beam_logprobs_sum_table[divm],
152 | state_table[divm])
153 |
154 | # if time's up... or if end token is reached then copy beams
155 | for vix in range(bdash):
156 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1:
157 | final_beam = {
158 | 'seq': beam_seq_table[divm][:, vix].clone(),
159 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
160 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
161 | 'p': beam_logprobs_sum_table[divm][vix].item()
162 | }
163 | if max_ppl:
164 | final_beam['p'] = final_beam['p'] / (t-divm+1)
165 | done_beams_table[divm].append(final_beam)
166 | # don't continue beams from finished sequences
167 | beam_logprobs_sum_table[divm][vix] = -1000
168 |
169 | # move the current group one step forward in time
170 |
171 | it = beam_seq_table[divm][t-divm]
172 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
173 |
174 | # all beams are sorted by their log-probabilities
175 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
176 | done_beams = reduce(lambda a,b:a+b, done_beams_table)
177 | return done_beams
--------------------------------------------------------------------------------
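
A toy illustration of the diversity term applied in add_diversity above: words already chosen by earlier beam groups at the same local time step have diversity_lambda subtracted from their log-probabilities before the current group expands, pushing the groups toward different words. All numbers below are made up:

    import torch

    diversity_lambda = 0.5
    logprobsf = torch.log(torch.tensor([[0.5, 0.3, 0.2],
                                        [0.6, 0.3, 0.1]]))   # 2 sub-beams over a 3-word vocab
    prev_decisions = torch.tensor([0, 2])                     # words picked by an earlier group
    for sub_beam in range(logprobsf.size(0)):
        for w in prev_decisions:
            logprobsf[sub_beam][w] -= diversity_lambda
    print(logprobsf)   # columns 0 and 2 are now 0.5 lower for both sub-beams
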
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_opt():
4 | parser = argparse.ArgumentParser()
5 | # Data input settings
6 | parser.add_argument('--input_json', type=str, default='data/coco.json',
7 | help='path to the json file containing additional info and vocab')
8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
9 | help='path to the directory containing the preprocessed fc feats')
10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
11 | help='path to the directory containing the preprocessed att feats')
12 | parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
13 | help='path to the directory containing the boxes of att feats')
14 | parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
15 | help='path to the h5file containing the preprocessed dataset')
16 | parser.add_argument('--start_from', type=str, default=None,
17 | help="""continue training from saved model at this path. Path must contain files saved by previous training process:
18 | 'infos.pkl' : configuration;
19 | 'checkpoint' : paths to model file(s) (created by tf).
20 | Note: this file contains absolute paths, be careful when moving files around;
21 | 'model.ckpt-*' : file(s) with model definition (created by tf)
22 | """)
23 | parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
24 | help='Cached token file for calculating cider score during self critical training.')
25 |
26 | # Model settings
27 | parser.add_argument('--caption_model', type=str, default="show_tell",
28 | help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer')
29 | parser.add_argument('--rnn_size', type=int, default=512,
30 | help='size of the rnn in number of hidden nodes in each layer')
31 | parser.add_argument('--num_layers', type=int, default=1,
32 | help='number of layers in the RNN')
33 | parser.add_argument('--rnn_type', type=str, default='lstm',
34 | help='rnn, gru, or lstm')
35 | parser.add_argument('--input_encoding_size', type=int, default=512,
36 | help='the encoding size of each token in the vocabulary, and the image.')
37 | parser.add_argument('--att_hid_size', type=int, default=512,
38 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
39 | parser.add_argument('--fc_feat_size', type=int, default=2048,
40 | help='2048 for resnet, 4096 for vgg')
41 | parser.add_argument('--att_feat_size', type=int, default=2048,
42 | help='2048 for resnet, 512 for vgg')
43 | parser.add_argument('--logit_layers', type=int, default=1,
44 |                     help='number of layers in the output word-prediction (logit) MLP')
45 |
46 |
47 | parser.add_argument('--use_bn', type=int, default=0,
48 | help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
49 |
50 | # feature manipulation
51 |     parser.add_argument('--norm_att_feat', type=int, default=0,
52 |                     help='whether to normalize the attention features')
53 |     parser.add_argument('--use_box', type=int, default=0,
54 |                     help='whether to use bounding box features')
55 |     parser.add_argument('--norm_box_feat', type=int, default=0,
56 |                     help='if using boxes, whether to normalize the box features')
57 |
58 | # Optimization: General
59 | parser.add_argument('--max_epochs', type=int, default=-1,
60 | help='number of epochs')
61 | parser.add_argument('--batch_size', type=int, default=16,
62 | help='minibatch size')
63 | parser.add_argument('--grad_clip', type=float, default=0.1, #5.,
64 | help='clip gradients at this value')
65 | parser.add_argument('--drop_prob_lm', type=float, default=0.5,
66 | help='strength of dropout in the Language Model RNN')
67 | parser.add_argument('--self_critical_after', type=int, default=-1,
68 |                     help='After what epoch do we start self-critical (CIDEr reward) training? (-1 = disable, 0 = from the start)')
69 | parser.add_argument('--seq_per_img', type=int, default=5,
70 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
71 | parser.add_argument('--beam_size', type=int, default=1,
72 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
73 |
74 | #Optimization: for the Language Model
75 | parser.add_argument('--optim', type=str, default='adam',
76 |                     help='what update to use? rmsprop|sgd|sgdm|sgdmom|adagrad|adam')
77 | parser.add_argument('--learning_rate', type=float, default=4e-4,
78 | help='learning rate')
79 |     parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
80 |                     help='at what epoch to start decaying the learning rate? (-1 = do not decay)')
81 |     parser.add_argument('--learning_rate_decay_every', type=int, default=3,
82 |                     help='every how many epochs thereafter to decay the learning rate')
83 |     parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
84 |                     help='multiplicative factor applied to the learning rate at each decay step')
85 | parser.add_argument('--optim_alpha', type=float, default=0.9,
86 |                     help='alpha for adam (beta1); also used as momentum for sgdm/sgdmom and alpha for rmsprop')
87 | parser.add_argument('--optim_beta', type=float, default=0.999,
88 | help='beta used for adam')
89 | parser.add_argument('--optim_epsilon', type=float, default=1e-8,
90 | help='epsilon that goes into denominator for smoothing')
91 | parser.add_argument('--weight_decay', type=float, default=0,
92 | help='weight_decay')
93 |
94 |     parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
95 |                     help='at what epoch to start decaying the ground-truth feeding probability (-1 = disable scheduled sampling)')
96 |     parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
97 |                     help='every how many epochs thereafter to increase the scheduled sampling probability')
98 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
99 | help='How much to update the prob')
100 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
101 | help='Maximum scheduled sampling prob.')
102 |
103 |
104 | # Evaluation/Checkpointing
105 | parser.add_argument('--val_images_use', type=int, default=3200,
106 | help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
107 | parser.add_argument('--save_checkpoint_every', type=int, default=2500,
108 | help='how often to save a model checkpoint (in iterations)?')
109 | parser.add_argument('--checkpoint_path', type=str, default='save',
110 | help='directory to store checkpointed models')
111 | parser.add_argument('--language_eval', type=int, default=0,
112 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
113 | parser.add_argument('--losses_log_every', type=int, default=25,
114 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
115 | parser.add_argument('--load_best_score', type=int, default=1,
116 | help='Do we load previous best score when resuming training.')
117 |
118 | # misc
119 | parser.add_argument('--id', type=str, default='',
120 | help='an id identifying this run/job. used in cross-val and appended when writing progress files')
121 | parser.add_argument('--train_only', type=int, default=0,
122 |                     help='if 1 then train on the 80k train split only, else also use the restval images (110k total)')
123 |
124 |
125 | # Reward
126 | parser.add_argument('--cider_reward_weight', type=float, default=1,
127 | help='The reward weight from cider')
128 | parser.add_argument('--bleu_reward_weight', type=float, default=0,
129 | help='The reward weight from bleu4')
130 |
131 | # Transformer
132 |     parser.add_argument('--label_smoothing', type=float, default=0,
133 |                     help='label smoothing strength (0 = disabled); switches the training loss to LabelSmoothing')
134 |     parser.add_argument('--noamopt', action='store_true',
135 |                     help='use the Noam learning rate schedule (transformer only)')
136 |     parser.add_argument('--noamopt_warmup', type=int, default=2000,
137 |                     help='number of warmup steps for the Noam schedule')
138 |     parser.add_argument('--noamopt_factor', type=float, default=1,
139 |                     help='scaling factor for the Noam schedule')
140 |
141 |     parser.add_argument('--reduce_on_plateau', action='store_true',
142 |                     help='wrap the optimizer in ReduceLROnPlateau driven by the validation score')
143 |
144 | args = parser.parse_args()
145 |
146 | # Check if args are valid
147 | assert args.rnn_size > 0, "rnn_size should be greater than 0"
148 | assert args.num_layers > 0, "num_layers should be greater than 0"
149 | assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
150 | assert args.batch_size > 0, "batch_size should be greater than 0"
151 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
152 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
153 | assert args.beam_size > 0, "beam_size should be greater than 0"
154 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
155 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
156 | assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
157 |     assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
158 |     assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"
159 |
160 | return args
--------------------------------------------------------------------------------
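
parse_opt above reads everything from sys.argv, so it can also be driven programmatically (for example from a notebook). A minimal sketch with illustrative flag values; unset options keep the defaults defined above:

    import sys
    import opts

    sys.argv = ['train.py', '--caption_model', 'fc', '--batch_size', '16', '--language_eval', '1']
    opt = opts.parse_opt()
    print(opt.caption_model, opt.batch_size, opt.rnn_size)   # fc 16 512
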
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | import numpy as np
10 |
11 | import time
12 | import os
13 | from six.moves import cPickle
14 |
15 | import opts
16 | import models
17 | from dataloader import *
18 | import eval_utils
19 | import misc.utils as utils
20 | from misc.rewards import init_scorer, get_self_critical_reward
21 |
22 | try:
23 | import tensorboardX as tb
24 | except ImportError:
25 | print("tensorboardX is not installed")
26 | tb = None
27 |
28 | def add_summary_value(writer, key, value, iteration):
29 | if writer:
30 | writer.add_scalar(key, value, iteration)
31 |
32 | def train(opt):
33 | # Deal with feature things before anything
34 | opt.use_att = utils.if_use_att(opt.caption_model)
35 | if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
36 |
37 | loader = DataLoader(opt)
38 | opt.vocab_size = loader.vocab_size
39 | opt.seq_length = loader.seq_length
40 |
41 | tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)
42 |
43 | infos = {}
44 | histories = {}
45 | if opt.start_from is not None:
46 | # open old infos and check if models are compatible
47 |         with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
48 | infos = cPickle.load(f)
49 | saved_model_opt = infos['opt']
50 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
51 | for checkme in need_be_same:
52 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
53 |
54 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
55 |             with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
56 | histories = cPickle.load(f)
57 |
58 | iteration = infos.get('iter', 0)
59 | epoch = infos.get('epoch', 0)
60 |
61 | val_result_history = histories.get('val_result_history', {})
62 | loss_history = histories.get('loss_history', {})
63 | lr_history = histories.get('lr_history', {})
64 | ss_prob_history = histories.get('ss_prob_history', {})
65 |
66 | loader.iterators = infos.get('iterators', loader.iterators)
67 | loader.split_ix = infos.get('split_ix', loader.split_ix)
68 |     # default best_val_score to None so the comparison at save time works even when load_best_score == 0
69 |     best_val_score = infos.get('best_val_score', None) if opt.load_best_score == 1 else None
70 |
71 | model = models.setup(opt).cuda()
72 | dp_model = torch.nn.DataParallel(model)
73 |
74 | epoch_done = True
75 |     # Ensure the model is in training mode
76 | dp_model.train()
77 |
78 | if opt.label_smoothing > 0:
79 | crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
80 | else:
81 | crit = utils.LanguageModelCriterion()
82 | rl_crit = utils.RewardCriterion()
83 |
84 | if opt.noamopt:
85 | assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
86 | optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
87 | optimizer._step = iteration
88 | elif opt.reduce_on_plateau:
89 | optimizer = utils.build_optimizer(model.parameters(), opt)
90 | optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
91 | else:
92 | optimizer = utils.build_optimizer(model.parameters(), opt)
93 | # Load the optimizer
94 | if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
95 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
96 |
97 | while True:
98 | if epoch_done:
99 | if not opt.noamopt and not opt.reduce_on_plateau:
100 | # Assign the learning rate
101 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
102 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
103 | decay_factor = opt.learning_rate_decay_rate ** frac
104 | opt.current_lr = opt.learning_rate * decay_factor
105 | else:
106 | opt.current_lr = opt.learning_rate
107 | utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
108 | # Assign the scheduled sampling prob
109 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
110 | frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
111 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
112 | model.ss_prob = opt.ss_prob
113 |
114 | # If start self critical training
115 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
116 | sc_flag = True
117 | init_scorer(opt.cached_tokens)
118 | else:
119 | sc_flag = False
120 |
121 | epoch_done = False
122 |
123 | start = time.time()
124 | # Load data from train split (0)
125 | data = loader.get_batch('train')
126 | print('Read data:', time.time() - start)
127 |
128 | torch.cuda.synchronize()
129 | start = time.time()
130 |
131 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
132 | tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
133 | fc_feats, att_feats, labels, masks, att_masks = tmp
134 |
135 | optimizer.zero_grad()
136 | if not sc_flag:
137 | loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
138 | else:
139 | gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
140 | reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
141 | loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())
142 |
143 | loss.backward()
144 | utils.clip_gradient(optimizer, opt.grad_clip)
145 | optimizer.step()
146 | train_loss = loss.item()
147 | torch.cuda.synchronize()
148 | end = time.time()
149 | if not sc_flag:
150 | print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
151 | .format(iteration, epoch, train_loss, end - start))
152 | else:
153 | print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
154 | .format(iteration, epoch, np.mean(reward[:,0]), end - start))
155 |
156 | # Update the iteration and epoch
157 | iteration += 1
158 | if data['bounds']['wrapped']:
159 | epoch += 1
160 | epoch_done = True
161 |
162 | # Write the training loss summary
163 | if (iteration % opt.losses_log_every == 0):
164 | add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
165 | if opt.noamopt:
166 | opt.current_lr = optimizer.rate()
167 | elif opt.reduce_on_plateau:
168 | opt.current_lr = optimizer.current_lr
169 | add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
170 | add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
171 | if sc_flag:
172 | add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
173 |
174 | loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
175 | lr_history[iteration] = opt.current_lr
176 | ss_prob_history[iteration] = model.ss_prob
177 |
178 | # make evaluation on validation set, and save model
179 | if (iteration % opt.save_checkpoint_every == 0):
180 | # eval model
181 | eval_kwargs = {'split': 'val',
182 | 'dataset': opt.input_json}
183 | eval_kwargs.update(vars(opt))
184 | val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)
185 |
186 | if opt.reduce_on_plateau:
187 |                 if lang_stats is not None and 'CIDEr' in lang_stats:
188 | optimizer.scheduler_step(-lang_stats['CIDEr'])
189 | else:
190 | optimizer.scheduler_step(val_loss)
191 |
192 | # Write validation result into summary
193 | add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
194 |             for k,v in (lang_stats or {}).items():
195 | add_summary_value(tb_summary_writer, k, v, iteration)
196 | val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
197 |
198 | # Save model if is improving on validation result
199 | if opt.language_eval == 1:
200 | current_score = lang_stats['CIDEr']
201 | else:
202 | current_score = - val_loss
203 |
204 | best_flag = False
205 | if True: # if true
206 | if best_val_score is None or current_score > best_val_score:
207 | best_val_score = current_score
208 | best_flag = True
209 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
210 | torch.save(model.state_dict(), checkpoint_path)
211 | print("model saved to {}".format(checkpoint_path))
212 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
213 | torch.save(optimizer.state_dict(), optimizer_path)
214 |
215 |                     # Dump miscellaneous information
216 | infos['iter'] = iteration
217 | infos['epoch'] = epoch
218 | infos['iterators'] = loader.iterators
219 | infos['split_ix'] = loader.split_ix
220 | infos['best_val_score'] = best_val_score
221 | infos['opt'] = opt
222 | infos['vocab'] = loader.get_vocab()
223 |
224 | histories['val_result_history'] = val_result_history
225 | histories['loss_history'] = loss_history
226 | histories['lr_history'] = lr_history
227 | histories['ss_prob_history'] = ss_prob_history
228 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
229 | cPickle.dump(infos, f)
230 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
231 | cPickle.dump(histories, f)
232 |
233 | if best_flag:
234 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
235 | torch.save(model.state_dict(), checkpoint_path)
236 | print("model saved to {}".format(checkpoint_path))
237 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
238 | cPickle.dump(infos, f)
239 |
240 | # Stop if reaching max epochs
241 | if epoch >= opt.max_epochs and opt.max_epochs != -1:
242 | break
243 |
244 | opt = opts.parse_opt()
245 | train(opt)
246 |
--------------------------------------------------------------------------------
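
The epoch-based schedules inside train() above can be previewed in isolation. A small sketch using the default option values from opts.py, except that the decay and scheduled-sampling start epochs are set to 0 here (the defaults of -1 disable them), purely to illustrate the formulas:

    learning_rate, decay_start, decay_every, decay_rate = 4e-4, 0, 3, 0.8
    ss_start, ss_every, ss_increase, ss_max = 0, 5, 0.05, 0.25

    for epoch in (1, 5, 10, 20):
        frac = (epoch - decay_start) // decay_every
        current_lr = learning_rate * decay_rate ** frac                      # step decay
        ss_prob = min(ss_increase * ((epoch - ss_start) // ss_every), ss_max)
        print(epoch, round(current_lr, 6), ss_prob)
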
/models/OldModel.py:
--------------------------------------------------------------------------------
1 | # This file contains ShowAttendTell and AllImg model
2 |
3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
4 | # https://arxiv.org/abs/1502.03044
5 |
6 | # AllImg is a model where
7 | # img feature is concatenated with word embedding at every time step as the input of lstm
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.autograd import *
16 | import misc.utils as utils
17 |
18 | from .CaptionModel import CaptionModel
19 |
20 | class OldModel(CaptionModel):
21 | def __init__(self, opt):
22 | super(OldModel, self).__init__()
23 | self.vocab_size = opt.vocab_size
24 | self.input_encoding_size = opt.input_encoding_size
25 | self.rnn_type = opt.rnn_type
26 | self.rnn_size = opt.rnn_size
27 | self.num_layers = opt.num_layers
28 | self.drop_prob_lm = opt.drop_prob_lm
29 | self.seq_length = opt.seq_length
30 | self.fc_feat_size = opt.fc_feat_size
31 | self.att_feat_size = opt.att_feat_size
32 |
33 | self.ss_prob = 0.0 # Schedule sampling probability
34 |
35 | self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size
36 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
37 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
38 | self.dropout = nn.Dropout(self.drop_prob_lm)
39 |
40 | self.init_weights()
41 |
42 | def init_weights(self):
43 | initrange = 0.1
44 | self.embed.weight.data.uniform_(-initrange, initrange)
45 | self.logit.bias.data.fill_(0)
46 | self.logit.weight.data.uniform_(-initrange, initrange)
47 |
48 | def init_hidden(self, fc_feats):
49 | image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1)
50 | if self.rnn_type == 'lstm':
51 | return (image_map, image_map)
52 | else:
53 | return image_map
54 |
55 | def forward(self, fc_feats, att_feats, seq):
56 | batch_size = fc_feats.size(0)
57 | state = self.init_hidden(fc_feats)
58 |
59 | outputs = []
60 |
61 | for i in range(seq.size(1) - 1):
62 |             if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
63 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
64 | sample_mask = sample_prob < self.ss_prob
65 | if sample_mask.sum() == 0:
66 | it = seq[:, i].clone()
67 | else:
68 | sample_ind = sample_mask.nonzero().view(-1)
69 | it = seq[:, i].data.clone()
70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
74 | else:
75 | it = seq[:, i].clone()
76 | # break if all the sequences end
77 | if i >= 1 and seq[:, i].sum() == 0:
78 | break
79 |
80 | xt = self.embed(it)
81 |
82 | output, state = self.core(xt, fc_feats, att_feats, state)
83 | output = F.log_softmax(self.logit(self.dropout(output)), dim=1)
84 | outputs.append(output)
85 |
86 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
87 |
88 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state):
89 | # 'it' contains a word index
90 | xt = self.embed(it)
91 |
92 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
93 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
94 |
95 | return logprobs, state
96 |
97 | def sample_beam(self, fc_feats, att_feats, opt={}):
98 | beam_size = opt.get('beam_size', 10)
99 | batch_size = fc_feats.size(0)
100 |
101 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
102 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
103 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
104 | # lets process every image independently for now, for simplicity
105 |
106 | self.done_beams = [[] for _ in range(batch_size)]
107 | for k in range(batch_size):
108 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size)
109 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous()
110 |
111 | state = self.init_hidden(tmp_fc_feats)
112 |
113 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
114 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
115 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
116 | done_beams = []
117 | for t in range(1):
118 | if t == 0: # input
119 | it = fc_feats.data.new(beam_size).long().zero_()
120 | xt = self.embed(it)
121 |
122 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
123 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
124 |
125 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt)
126 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
127 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
128 | # return the samples and their log likelihoods
129 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
130 |
131 | def sample(self, fc_feats, att_feats, opt={}):
132 | sample_max = opt.get('sample_max', 1)
133 | beam_size = opt.get('beam_size', 1)
134 | temperature = opt.get('temperature', 1.0)
135 | if beam_size > 1:
136 | return self.sample_beam(fc_feats, att_feats, opt)
137 |
138 | batch_size = fc_feats.size(0)
139 | state = self.init_hidden(fc_feats)
140 |
141 | seq = []
142 | seqLogprobs = []
143 | for t in range(self.seq_length + 1):
144 | if t == 0: # input
145 | it = fc_feats.data.new(batch_size).long().zero_()
146 | elif sample_max:
147 | sampleLogprobs, it = torch.max(logprobs.data, 1)
148 | it = it.view(-1).long()
149 | else:
150 | if temperature == 1.0:
151 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
152 | else:
153 | # scale logprobs by temperature
154 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
155 | it = torch.multinomial(prob_prev, 1).cuda()
156 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
157 | it = it.view(-1).long() # and flatten indices for downstream processing
158 |
159 | xt = self.embed(it)
160 |
161 | if t >= 1:
162 | # stop when all finished
163 | if t == 1:
164 | unfinished = it > 0
165 | else:
166 | unfinished = unfinished * (it > 0)
167 | if unfinished.sum() == 0:
168 | break
169 | it = it * unfinished.type_as(it)
170 | seq.append(it) # seq[t] is the input at time step t+2
171 | seqLogprobs.append(sampleLogprobs.view(-1))
172 |
173 | output, state = self.core(xt, fc_feats, att_feats, state)
174 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1)
175 |
176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
177 |
178 |
179 | class ShowAttendTellCore(nn.Module):
180 | def __init__(self, opt):
181 | super(ShowAttendTellCore, self).__init__()
182 | self.input_encoding_size = opt.input_encoding_size
183 | self.rnn_type = opt.rnn_type
184 | self.rnn_size = opt.rnn_size
185 | self.num_layers = opt.num_layers
186 | self.drop_prob_lm = opt.drop_prob_lm
187 | self.fc_feat_size = opt.fc_feat_size
188 | self.att_feat_size = opt.att_feat_size
189 | self.att_hid_size = opt.att_hid_size
190 |
191 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size,
192 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
193 |
194 | if self.att_hid_size > 0:
195 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
196 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
197 | self.alpha_net = nn.Linear(self.att_hid_size, 1)
198 | else:
199 | self.ctx2att = nn.Linear(self.att_feat_size, 1)
200 | self.h2att = nn.Linear(self.rnn_size, 1)
201 |
202 | def forward(self, xt, fc_feats, att_feats, state):
203 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size
204 | att = att_feats.view(-1, self.att_feat_size)
205 | if self.att_hid_size > 0:
206 | att = self.ctx2att(att) # (batch * att_size) * att_hid_size
207 | att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size
208 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size
209 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
210 | dot = att + att_h # batch * att_size * att_hid_size
211 | dot = F.tanh(dot) # batch * att_size * att_hid_size
212 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
213 | dot = self.alpha_net(dot) # (batch * att_size) * 1
214 | dot = dot.view(-1, att_size) # batch * att_size
215 | else:
216 | att = self.ctx2att(att) # (batch * att_size) * 1
217 | att = att.view(-1, att_size) # batch * att_size
218 | att_h = self.h2att(state[0][-1]) # batch * 1
219 | att_h = att_h.expand_as(att) # batch * att_size
220 | dot = att_h + att # batch * att_size
221 |
222 | weight = F.softmax(dot, dim=1)
223 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size
224 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
225 |
226 | output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state)
227 | return output.squeeze(0), state
228 |
229 | class AllImgCore(nn.Module):
230 | def __init__(self, opt):
231 | super(AllImgCore, self).__init__()
232 | self.input_encoding_size = opt.input_encoding_size
233 | self.rnn_type = opt.rnn_type
234 | self.rnn_size = opt.rnn_size
235 | self.num_layers = opt.num_layers
236 | self.drop_prob_lm = opt.drop_prob_lm
237 | self.fc_feat_size = opt.fc_feat_size
238 |
239 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size,
240 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
241 |
242 | def forward(self, xt, fc_feats, att_feats, state):
243 | output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state)
244 | return output.squeeze(0), state
245 |
246 | class ShowAttendTellModel(OldModel):
247 | def __init__(self, opt):
248 | super(ShowAttendTellModel, self).__init__(opt)
249 | self.core = ShowAttendTellCore(opt)
250 |
251 | class AllImgModel(OldModel):
252 | def __init__(self, opt):
253 | super(AllImgModel, self).__init__(opt)
254 | self.core = AllImgCore(opt)
255 |
256 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer for captioning
2 |
3 | # Note: This repository is deprecated, and the code has been merged into [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch). The same training script should work there too.
4 |
5 | This is an experiment to use the transformer model to do captioning. Most of the code is copied from the [Harvard detailed tutorial for the transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html).
6 |
7 | Also, notice that this repository is a fork of my [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch) repository. Most of the code is shared.
8 |
9 | The additions to self-critical.pytorch are the following:
10 | - Transformer model
11 | - Warmup Adam for training the transformer (important)
12 | - `reduce_on_plateau` learning-rate scheduling (not really useful)
13 |
14 | A training script that can achieve a CIDEr score of 1.25 on the validation set without beam search:
15 |
16 | ```bash
17 | id="transformer"
18 | ckpt_path="log_"$id
19 | if [ ! -d $ckpt_path ]; then
20 | mkdir $ckpt_path
21 | fi
22 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then
23 | start_from=""
24 | else
25 | start_from="--start_from "$ckpt_path
26 | fi
27 |
28 | python train.py --id $id --caption_model transformer --noamopt --noamopt_warmup 20000 --label_smoothing 0.0 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 15
29 |
30 | python train.py --id $id --caption_model transformer --reduce_on_plateau --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --self_critical_after 10
31 | ```
32 |
33 | **Notice**: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer:
34 | ```
35 | N=num_layers
36 | d_model=input_encoding_size
37 | d_ff=rnn_size
38 | h is always 8
39 | ```
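For example, the first training command above instantiates the usual base configuration (N=6, d_model=512, d_ff=2048) purely through the reused RNN flags. Stripped down to the relevant options, it looks like this (other hyperparameters left at their defaults):

```bash
# N=6, d_model=512, d_ff=2048; the number of heads stays at 8
python train.py --id transformer --caption_model transformer \
  --num_layers 6 --input_encoding_size 512 --rnn_size 2048 \
  --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 \
  --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att \
  --checkpoint_path log_transformer
```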
40 |
41 |
42 | # Self-critical Sequence Training for Image Captioning (+ misc.)
43 |
44 | This repository includes unofficial implementations of [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998).
45 |
46 | The author of SCST helped me a lot when I tried to replicate the result. Great thanks. The att2in2 model can achieve a CIDEr score of more than 1.20 on Karpathy's test split (with self-critical training, bottom-up features, a large RNN hidden size, and no ensemble).
47 |
48 | This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch) repository. The modifications are:
49 | - Self critical training.
50 | - Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.)
51 | - Ensemble
52 | - Multi-GPU training
53 |
54 | ## Requirements
55 | Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3)
56 | PyTorch 0.4 (along with torchvision)
57 | cider (already added as a submodule)
58 |
59 | (**Skip if you are using bottom-up features**): If you want to use a resnet to extract image features, you need to download the pretrained resnet models for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM) and should be placed in `data/imagenet_weights`.
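For example (the checkpoint filename below is just an assumption; use whatever file you actually downloaded):

```bash
mkdir -p data/imagenet_weights
# the exact .pth filename depends on which resnet model you downloaded
mv ~/Downloads/resnet101.pth data/imagenet_weights/
```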
60 |
61 | ## Pretrained models (using resnet101 feature)
62 | Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg). The performance of each model is maintained in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10).
63 |
64 | If you want to do evaluation only, you can then follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101).
65 |
66 | ## Train your own network on COCO
67 |
68 | ### Download COCO captions and preprocess them
69 |
70 | Download the preprocessed COCO captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides the preprocessed captions and also the standard train-val-test splits.
71 |
72 | Then do:
73 |
74 | ```bash
75 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
76 | ```
77 |
78 | `prepro_labels.py` maps all words that occur <= 5 times to a special `UNK` token and creates a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.
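The thresholding itself is simple; below is a minimal sketch of the idea (not the actual script, and the names here are made up -- `scripts/prepro_labels.py` is the reference):

```python
from collections import Counter

def build_vocab(tokenized_captions, count_threshold=5):
    # tokenized_captions: a list of captions, each a list of words (hypothetical input format)
    counts = Counter(w for caption in tokenized_captions for w in caption)
    # keep only words seen more than `count_threshold` times; everything else becomes UNK
    vocab = sorted(w for w, n in counts.items() if n > count_threshold) + ['UNK']

    def encode(caption):
        return [w if counts[w] > count_threshold else 'UNK' for w in caption]

    return vocab, encode
```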
79 |
80 | ### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature)
81 |
82 | Download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
83 |
84 | Then:
85 |
86 | ```
87 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
88 | ```
89 |
90 |
91 | `prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB.
92 |
93 | (Check the prepro scripts for more options, like other resnet models or other attention sizes.)
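For instance, something along these lines; the `--model` and `--att_size` flag names are assumptions here, so double-check the argparse options at the top of `scripts/prepro_feats.py`:

```bash
# hypothetical invocation: a different backbone and a smaller attention map
python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk \
  --images_root $IMAGE_ROOT --model resnet152 --att_size 7
```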
94 |
95 | **Warning**: the prepro script will fail with the default MSCOCO data because one of the images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.
96 |
97 | ### Download Bottom-up features (Skip if you are using resnet features)
98 |
99 | Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version.
100 |
101 | For example:
102 | ```
103 | mkdir data/bu_data; cd data/bu_data
104 | wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
105 | unzip trainval.zip
106 |
107 | ```
108 |
109 | Then:
110 |
111 | ```bash
112 | python scripts/make_bu_data.py --output_dir data/cocobu
113 | ```
114 |
115 | This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace the `cocotalk` feature directories with their `cocobu` counterparts, as in the example below.
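For example, the cross-entropy training command from the next section, pointed at the bottom-up features, becomes:

```bash
python train.py --id fc --caption_model fc --input_json data/cocotalk.json \
  --input_label_h5 data/cocotalk_label.h5 \
  --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att \
  --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 \
  --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
```

Only the feature directories change; the label files (`cocotalk.json`, `cocotalk_label.h5`) stay the same, which matches the transformer script at the top of this README.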
116 |
117 | ### Start training
118 |
119 | ```bash
120 | $ python train.py --id fc --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
121 | ```
122 |
123 | The train script will dump checkpoints into the folder specified by `--checkpoint_path` (default = `save/`). We only save the best-performing checkpoint on validation and the latest checkpoint to save disk space.
124 |
125 | To resume training, point the `--start_from` option to the folder containing `infos.pkl` and `model.pth` (usually you can just set `--start_from` and `--checkpoint_path` to the same folder).
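For example, to resume the FC model trained above from its own checkpoint folder (a minimal sketch reusing the options from the command above):

```bash
python train.py --id fc --caption_model fc --input_json data/cocotalk.json \
  --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 \
  --batch_size 10 --learning_rate 5e-4 \
  --start_from log_fc --checkpoint_path log_fc \
  --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
```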
126 |
127 | If you have tensorflow installed, the loss histories are automatically dumped into `--checkpoint_path` and can be visualized using tensorboard.
128 |
129 | The command above uses scheduled sampling; you can set `scheduled_sampling_start` to -1 to turn scheduled sampling off.
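Scheduled sampling, as implemented in the models in this repository, works per example and per time step: with probability `ss_prob` the input word is sampled from the model's previous output distribution instead of being taken from the ground-truth caption. A condensed sketch of that decision (mirroring the logic in `models/AttModel.py`; the function boundary is hypothetical):

```python
import torch

def choose_input_token(seq, i, prev_logprobs, ss_prob, training=True):
    # seq: ground-truth word indices, shape (batch, length)
    # prev_logprobs: log-probabilities predicted at step i-1, shape (batch, vocab_size + 1)
    it = seq[:, i].clone()
    if training and i >= 1 and ss_prob > 0.0:
        sample_mask = torch.rand(seq.size(0), device=seq.device) < ss_prob
        if sample_mask.any():
            sample_ind = sample_mask.nonzero().view(-1)
            # replace the ground-truth word with a sample from the previous predicted distribution
            prob_prev = torch.exp(prev_logprobs.detach())
            sampled = torch.multinomial(prob_prev, 1).view(-1)
            it.index_copy_(0, sample_ind, sampled.index_select(0, sample_ind))
    return it
```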
130 |
131 | If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition to the validation cross-entropy loss, use the `--language_eval 1` option, but don't forget to download the [coco-caption code](https://github.com/tylin/coco-caption) into the `coco-caption` directory.
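If it is not already present, clone it into a `coco-caption` folder at the repository root, for example:

```bash
git clone https://github.com/tylin/coco-caption.git coco-caption
```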
132 |
133 | For more options, see `opts.py`.
134 |
135 | **A few notes on training.** To give you an idea, with the default settings one epoch of MS COCO images is about 11000 iterations. One epoch of training results in a validation loss of ~2.5 and a CIDEr score of ~0.68. By iteration 60,000, CIDEr climbs to about 0.84 (validation loss ~2.4, under scheduled sampling).
136 |
137 | ### Train using self critical
138 |
139 | First, preprocess the dataset and build the cache for calculating the CIDEr score:
140 | ```
141 | $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
142 | ```
143 |
144 | Then copy the model pretrained with cross entropy. (Copying is not mandatory; it is just a back-up.)
145 | ```
146 | $ bash scripts/copy_model.sh fc fc_rl
147 | ```
148 |
149 | Then run:
150 | ```bash
151 | $ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30
152 | ```
153 |
154 | You will see a huge boost in CIDEr score, :).
155 |
156 | **A few notes on training.** Starting self-critical training after 30 epochs, the CIDEr score goes up to 1.05 after 600k iterations (including the 30 epochs of pretraining).
157 |
158 | ### Caption images after training
159 |
160 | ## Generate image captions
161 |
162 | ### Evaluate on raw images
163 | Now place all your images of interest into a folder, e.g. `blah`, and run
164 | the eval script:
165 |
166 | ```bash
167 | $ python eval.py --model model.pth --infos_path infos.pkl --image_folder blah --num_images 10
168 | ```
169 |
170 | This tells the `eval` script to run up to 10 images from the given folder. If you have a big GPU you can speed up the evaluation by increasing `batch_size`. Use `--num_images -1` to process all images (an example is shown below). The eval script will create a `vis.json` file inside the `vis` folder, which can then be visualized with the provided HTML interface:
171 |
172 | ```bash
173 | $ cd vis
174 | $ python -m SimpleHTTPServer
175 | ```
176 |
177 | Now visit `localhost:8000` in your browser and you should see your predicted captions.
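If you want to caption every image in the folder in larger batches, the invocation looks like this (the batch size is just an example; pick one that fits your GPU memory):

```bash
python eval.py --model model.pth --infos_path infos.pkl --image_folder blah \
  --batch_size 32 --num_images -1
```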
178 |
179 | ### Evaluate on Karpathy's test split
180 |
181 | ```bash
182 | $ python eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl --language_eval 1
183 | ```
184 |
185 | The default split to evaluate is test. The default inference method is greedy decoding (`--sample_max 1`); to sample from the posterior, set `--sample_max 0`.
186 |
187 | **Beam Search**. Beam search can improve performance over greedy decoding by ~5%. However, it is a little more expensive. To turn on beam search, use `--beam_size N`, where N should be greater than 1.
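For example, to rerun the test-split evaluation above with a beam of 5:

```bash
python eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl \
  --language_eval 1 --beam_size 5
```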
188 |
189 | ## Miscellanea
190 | **Using CPU**. The code currently runs on GPU by default; there is no option for switching. If you really need a CPU model, please open an issue; I can potentially create a CPU checkpoint and modify `eval.py` to run the model on CPU. However, there is no point in using a CPU to train the model.
191 |
192 | **Train on other dataset**. It should be trivial to port the code if you can create a file like `dataset_coco.json` for your own dataset.
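For reference, each entry of `dataset_coco.json` looks roughly like the following (field names as in Karpathy's released splits; some fields are omitted here, so double-check against the actual file):

```
{
  "dataset": "coco",
  "images": [
    {
      "filepath": "val2014",
      "filename": "COCO_val2014_000000000042.jpg",
      "cocoid": 42,
      "split": "train",
      "sentences": [
        {"tokens": ["a", "dog", "rides", "a", "skateboard"],
         "raw": "A dog rides a skateboard.",
         "imgid": 0, "sentid": 0}
      ]
    }
  ]
}
```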
193 |
194 | **Live demo**. Not supported for now. Pull requests are welcome.
195 |
196 | ## For more advanced features:
197 |
198 | Check out `ADVANCED.md`.
199 |
200 | ## Reference
201 |
202 | If you find this repo useful, please consider citing (no obligation at all):
203 |
204 | ```
205 | @article{luo2018discriminability,
206 | title={Discriminability objective for training descriptive captions},
207 | author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory},
208 | journal={arXiv preprint arXiv:1803.04376},
209 | year={2018}
210 | }
211 | ```
212 |
213 | Of course, please also cite the original papers of the models you are using (you can find the references in the model files).
214 |
215 | ## Acknowledgements
216 |
217 | Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team.
218 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import h5py
7 | import os
8 | import numpy as np
9 | import random
10 |
11 | import torch
12 | import torch.utils.data as data
13 |
14 | import multiprocessing
15 |
16 | class DataLoader(data.Dataset):
17 |
18 | def reset_iterator(self, split):
19 | del self._prefetch_process[split]
20 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
21 | self.iterators[split] = 0
22 |
23 | def get_vocab_size(self):
24 | return self.vocab_size
25 |
26 | def get_vocab(self):
27 | return self.ix_to_word
28 |
29 | def get_seq_length(self):
30 | return self.seq_length
31 |
32 | def __init__(self, opt):
33 | self.opt = opt
34 | self.batch_size = self.opt.batch_size
35 | self.seq_per_img = opt.seq_per_img
36 |
37 | # feature related options
38 | self.use_att = getattr(opt, 'use_att', True)
39 | self.use_box = getattr(opt, 'use_box', 0)
40 | self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
41 | self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
42 |
43 | # load the json file which contains additional information about the dataset
44 | print('DataLoader loading json file: ', opt.input_json)
45 | self.info = json.load(open(self.opt.input_json))
46 | self.ix_to_word = self.info['ix_to_word']
47 | self.vocab_size = len(self.ix_to_word)
48 | print('vocab size is ', self.vocab_size)
49 |
50 | # open the hdf5 file
51 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
52 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
53 |
54 | self.input_fc_dir = self.opt.input_fc_dir
55 | self.input_att_dir = self.opt.input_att_dir
56 | self.input_box_dir = self.opt.input_box_dir
57 |
58 | # load in the sequence data
59 | seq_size = self.h5_label_file['labels'].shape
60 | self.seq_length = seq_size[1]
61 | print('max sequence length in data is', self.seq_length)
62 | # load the pointers in full to RAM (should be small enough)
63 | self.label_start_ix = self.h5_label_file['label_start_ix'][:]
64 | self.label_end_ix = self.h5_label_file['label_end_ix'][:]
65 |
66 | self.num_images = self.label_start_ix.shape[0]
67 | print('read %d image features' %(self.num_images))
68 |
69 | # separate out indexes for each of the provided splits
70 | self.split_ix = {'train': [], 'val': [], 'test': []}
71 | for ix in range(len(self.info['images'])):
72 | img = self.info['images'][ix]
73 | if img['split'] == 'train':
74 | self.split_ix['train'].append(ix)
75 | elif img['split'] == 'val':
76 | self.split_ix['val'].append(ix)
77 | elif img['split'] == 'test':
78 | self.split_ix['test'].append(ix)
79 | elif opt.train_only == 0: # restval
80 | self.split_ix['train'].append(ix)
81 |
82 | print('assigned %d images to split train' %len(self.split_ix['train']))
83 | print('assigned %d images to split val' %len(self.split_ix['val']))
84 | print('assigned %d images to split test' %len(self.split_ix['test']))
85 |
86 | self.iterators = {'train': 0, 'val': 0, 'test': 0}
87 |
88 | self._prefetch_process = {} # The three prefetch process
89 | for split in self.iterators.keys():
90 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
91 | # Terminate the child process when the parent exits
92 | def cleanup():
93 | print('Terminating BlobFetcher')
94 | for split in self.iterators.keys():
95 | del self._prefetch_process[split]
96 | import atexit
97 | atexit.register(cleanup)
98 |
99 | def get_captions(self, ix, seq_per_img):
100 | # fetch the sequence labels
101 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
102 | ix2 = self.label_end_ix[ix] - 1
103 | ncap = ix2 - ix1 + 1 # number of captions available for this image
104 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
105 |
106 | if ncap < seq_per_img:
107 | # we need to subsample (with replacement)
108 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
109 | for q in range(seq_per_img):
110 | ixl = random.randint(ix1,ix2)
111 | seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
112 | else:
113 | ixl = random.randint(ix1, ix2 - seq_per_img + 1)
114 | seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]
115 |
116 | return seq
117 |
118 | def get_batch(self, split, batch_size=None, seq_per_img=None):
119 | batch_size = batch_size or self.batch_size
120 | seq_per_img = seq_per_img or self.seq_per_img
121 |
122 | fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32')
123 | att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32')
124 | label_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int')
125 | mask_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'float32')
126 |
127 | wrapped = False
128 |
129 | infos = []
130 | gts = []
131 |
132 | for i in range(batch_size):
133 | # fetch image
134 | tmp_fc, tmp_att,\
135 | ix, tmp_wrapped = self._prefetch_process[split].get()
136 | fc_batch.append(tmp_fc)
137 | att_batch.append(tmp_att)
138 |
139 | label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = self.get_captions(ix, seq_per_img)
140 |
141 | if tmp_wrapped:
142 | wrapped = True
143 |
144 | # Used for reward evaluation
145 | gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
146 |
147 | # record associated info as well
148 | info_dict = {}
149 | info_dict['ix'] = ix
150 | info_dict['id'] = self.info['images'][ix]['id']
151 | info_dict['file_path'] = self.info['images'][ix]['file_path']
152 | infos.append(info_dict)
153 |
154 | # #sort by att_feat length
155 | # fc_batch, att_batch, label_batch, gts, infos = \
156 | # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
157 | fc_batch, att_batch, label_batch, gts, infos = \
158 | zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True))
159 | data = {}
160 | data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch]))
161 | # merge att_feats
162 | max_att_len = max([_.shape[0] for _ in att_batch])
163 | data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32')
164 | for i in range(len(att_batch)):
165 | data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i]
166 | data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
167 | for i in range(len(att_batch)):
168 | data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1
169 | # set att_masks to None if attention features have same length
170 | if data['att_masks'].sum() == data['att_masks'].size:
171 | data['att_masks'] = None
172 |
173 | data['labels'] = np.vstack(label_batch)
174 | # generate mask
175 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
176 | for ix, row in enumerate(mask_batch):
177 | row[:nonzeros[ix]] = 1
178 | data['masks'] = mask_batch
179 |
180 | data['gts'] = gts # all ground truth captions of each image
181 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
182 | data['infos'] = infos
183 |
184 | return data
185 |
186 | # It's not coherent to make DataLoader a subclass of Dataset, but essentially we only need to implement the following two functions,
187 | # so that torch.utils.data.DataLoader can load the data according to the index.
188 | # However, it's the minimum change needed to switch to pytorch data loading.
189 | def __getitem__(self, index):
190 | """This function returns a tuple that is further passed to collate_fn
191 | """
192 | ix = index #self.split_ix[index]
193 | if self.use_att:
194 | att_feat = np.load(os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'))['feat']
195 | # Reshape to K x C
196 | att_feat = att_feat.reshape(-1, att_feat.shape[-1])
197 | if self.norm_att_feat:
198 | att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
199 | if self.use_box:
200 | box_feat = np.load(os.path.join(self.input_box_dir, str(self.info['images'][ix]['id']) + '.npy'))
201 | # divided by image width and height
202 | x1,y1,x2,y2 = np.hsplit(box_feat, 4)
203 | h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
204 | box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
205 | if self.norm_box_feat:
206 | box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
207 | att_feat = np.hstack([att_feat, box_feat])
208 | # sort the features by the size of boxes
209 | att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
210 | else:
211 | att_feat = np.zeros((1,1,1))
212 | return (np.load(os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy')),
213 | att_feat,
214 | ix)
215 |
216 | def __len__(self):
217 | return len(self.info['images'])
218 |
219 | class SubsetSampler(torch.utils.data.sampler.Sampler):
220 | r"""Samples elements randomly from a given list of indices, without replacement.
221 | Arguments:
222 | indices (list): a list of indices
223 | """
224 |
225 | def __init__(self, indices):
226 | self.indices = indices
227 |
228 | def __iter__(self):
229 | return (self.indices[i] for i in range(len(self.indices)))
230 |
231 | def __len__(self):
232 | return len(self.indices)
233 |
234 | class BlobFetcher():
235 | """Experimental class for prefetching blobs in a separate process."""
236 | def __init__(self, split, dataloader, if_shuffle=False):
237 | """
238 | db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname
239 | """
240 | self.split = split
241 | self.dataloader = dataloader
242 | self.if_shuffle = if_shuffle
243 |
244 | # Add more in the queue
245 | def reset(self):
246 | """
247 | Two cases for this function to be triggered:
248 | 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
249 | 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
250 | """
251 | # batch_size is 1, the merge is done in DataLoader class
252 | self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
253 | batch_size=1,
254 | sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]),
255 | shuffle=False,
256 | pin_memory=True,
257 | num_workers=4, # 4 is usually enough
258 | collate_fn=lambda x: x[0]))
259 |
260 | def _get_next_minibatch_inds(self):
261 | max_index = len(self.dataloader.split_ix[self.split])
262 | wrapped = False
263 |
264 | ri = self.dataloader.iterators[self.split]
265 | ix = self.dataloader.split_ix[self.split][ri]
266 |
267 | ri_next = ri + 1
268 | if ri_next >= max_index:
269 | ri_next = 0
270 | if self.if_shuffle:
271 | random.shuffle(self.dataloader.split_ix[self.split])
272 | wrapped = True
273 | self.dataloader.iterators[self.split] = ri_next
274 |
275 | return ix, wrapped
276 |
277 | def get(self):
278 | if not hasattr(self, 'split_loader'):
279 | self.reset()
280 |
281 | ix, wrapped = self._get_next_minibatch_inds()
282 | tmp = self.split_loader.next()
283 | if wrapped:
284 | self.reset()
285 |
286 | assert tmp[2] == ix, "ix not equal"
287 |
288 | return tmp + [wrapped]
--------------------------------------------------------------------------------
/models/AttEnsemble.py:
--------------------------------------------------------------------------------
1 | # This file contains AttEnsemble, an ensemble wrapper over the attention models (Att2in2, AdaAtt, AdaAttMO, TopDown)
2 |
3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4 | # https://arxiv.org/abs/1612.01887
5 | # AdaAttMO is a modified version with maxout lstm
6 |
7 | # Att2in is from Self-critical Sequence Training for Image Captioning
8 | # https://arxiv.org/abs/1612.00563
9 | # In this file we only have Att2in2, which is a slightly different version of att2in,
10 | # in which the image feature embedding and the word embedding are the same as in AdaAtt.
11 |
12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
13 | # https://arxiv.org/abs/1707.07998
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.autograd import *
23 | import misc.utils as utils
24 |
25 | from .CaptionModel import CaptionModel
26 | from .AttModel import pack_wrapper, AttModel
27 |
28 | class AttEnsemble(AttModel):
29 | def __init__(self, models):
30 | CaptionModel.__init__(self)
31 | # super(AttEnsemble, self).__init__()
32 |
33 | self.models = nn.ModuleList(models)
34 | self.vocab_size = models[0].vocab_size
35 | self.seq_length = models[0].seq_length
36 | self.ss_prob = 0
37 |
38 | def init_hidden(self, batch_size):
39 | return [m.init_hidden(batch_size) for m in self.models]
40 |
41 | def embed(self, it):
42 | return [m.embed(it) for m in self.models]
43 |
44 | def core(self, *args):
45 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
46 |
47 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state):
48 | # 'it' contains a word index
49 | xt = self.embed(it)
50 |
51 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
52 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log()
53 |
54 | return logprobs, state
55 |
56 | def _prepare_feature(self, fc_feats, att_feats, att_masks):
57 | att_feats, att_masks = self.clip_att(att_feats, att_masks)
58 |
59 | # embed fc and att feats
60 | fc_feats = [m.fc_embed(fc_feats) for m in self.models]
61 | att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models]
62 |
63 | # Project the attention feats first to reduce memory and computation consumption.
64 | p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)]
65 |
66 | return fc_feats, att_feats, p_att_feats, [att_masks] * len(self.models)
67 |
68 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
69 | beam_size = opt.get('beam_size', 10)
70 | batch_size = fc_feats.size(0)
71 |
72 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
73 |
74 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
75 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
76 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
77 | # lets process every image independently for now, for simplicity
78 |
79 | self.done_beams = [[] for _ in range(batch_size)]
80 | for k in range(batch_size):
81 | state = self.init_hidden(beam_size)
82 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
83 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
84 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
85 | tmp_att_masks = [att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() for i,m in enumerate(self.models)] if att_masks[0] is not None else att_masks
86 |
87 | it = fc_feats[0].data.new(beam_size).long().zero_()
88 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
89 |
90 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
91 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
92 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
93 | # return the samples and their log likelihoods
94 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
95 |
96 | def beam_search(self, init_state, init_logprobs, *args, **kwargs):
97 |
98 | # function computes the similarity score to be augmented
99 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
100 | local_time = t - divm
101 | unaug_logprobsf = logprobsf.clone()
102 | for prev_choice in range(divm):
103 | prev_decisions = beam_seq_table[prev_choice][local_time]
104 | for sub_beam in range(bdash):
105 | for prev_labels in range(bdash):
106 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
107 | return unaug_logprobsf
108 |
109 | # does one step of classical beam search
110 |
111 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
112 | #INPUTS:
113 | #logprobsf: probabilities augmented after diversity
114 | #beam_size: obvious
115 | #t : time instant
116 | #beam_seq : tensor containing the beams
117 | #beam_seq_logprobs: tensor containing the beam logprobs
118 | #beam_logprobs_sum: tensor containing joint logprobs
119 | #OUTPUTS:
120 | #beam_seq : tensor containing the word indices of the decoded captions
121 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
122 | #beam_logprobs_sum : joint log-probability of each beam
123 |
124 | ys,ix = torch.sort(logprobsf,1,True)
125 | candidates = []
126 | cols = min(beam_size, ys.size(1))
127 | rows = beam_size
128 | if t == 0:
129 | rows = 1
130 | for c in range(cols): # for each column (word, essentially)
131 | for q in range(rows): # for each beam expansion
132 | #compute logprob of expanding beam q with word in (sorted) position c
133 | local_logprob = ys[q,c].item()
134 | candidate_logprob = beam_logprobs_sum[q] + local_logprob
135 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
136 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob})
137 | candidates = sorted(candidates, key=lambda x: -x['p'])
138 |
139 | new_state = [[_.clone() for _ in state_] for state_ in state]
140 | #beam_seq_prev, beam_seq_logprobs_prev
141 | if t >= 1:
142 | #we'll need these as reference when we fork beams around
143 | beam_seq_prev = beam_seq[:t].clone()
144 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
145 | for vix in range(beam_size):
146 | v = candidates[vix]
147 | #fork beam index q into index vix
148 | if t >= 1:
149 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
150 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
151 | #rearrange recurrent states
152 | for ii in range(len(new_state)):
153 | for state_ix in range(len(new_state[ii])):
154 | # copy over state in previous beam q to new beam at vix
155 | new_state[ii][state_ix][:, vix] = state[ii][state_ix][:, v['q']] # dimension one is time step
156 | #append new end terminal at the end of this beam
157 | beam_seq[t, vix] = v['c'] # c'th word is the continuation
158 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
159 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
160 | state = new_state
161 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
162 |
163 | # Start diverse_beam_search
164 | opt = kwargs['opt']
165 | beam_size = opt.get('beam_size', 10)
166 | group_size = opt.get('group_size', 1)
167 | diversity_lambda = opt.get('diversity_lambda', 0.5)
168 | decoding_constraint = opt.get('decoding_constraint', 0)
169 | max_ppl = opt.get('max_ppl', 0)
170 | bdash = beam_size // group_size # beam per group
171 |
172 | # INITIALIZATIONS
173 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
174 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
175 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
176 |
177 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
178 | done_beams_table = [[] for _ in range(group_size)]
179 | state_table = zip(*[[list(torch.unbind(_)) for _ in torch.stack(init_state_).chunk(group_size, 2)] for init_state_ in init_state])
180 | logprobs_table = list(init_logprobs.chunk(group_size, 0))
181 | # END INIT
182 |
183 | # Chunk elements in the args
184 | args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
185 | args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
186 |
187 | for t in range(self.seq_length + group_size - 1):
188 | for divm in range(group_size):
189 | if t >= divm and t <= self.seq_length + divm - 1:
190 | # add diversity
191 | logprobsf = logprobs_table[divm].data.float()
192 | # suppress previous word
193 | if decoding_constraint and t-divm > 0:
194 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf'))
195 | # suppress UNK tokens in the decoding
196 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
197 | # diversity is added here
198 | # the function directly modifies the logprobsf values and hence, we need to return
199 | # the unaugmented ones for sorting the candidates in the end. # for historical
200 | # reasons :-)
201 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
202 |
203 | # infer new beams
204 | beam_seq_table[divm],\
205 | beam_seq_logprobs_table[divm],\
206 | beam_logprobs_sum_table[divm],\
207 | state_table[divm],\
208 | candidates_divm = beam_step(logprobsf,
209 | unaug_logprobsf,
210 | bdash,
211 | t-divm,
212 | beam_seq_table[divm],
213 | beam_seq_logprobs_table[divm],
214 | beam_logprobs_sum_table[divm],
215 | state_table[divm])
216 |
217 | # if time's up... or if end token is reached then copy beams
218 | for vix in range(bdash):
219 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1:
220 | final_beam = {
221 | 'seq': beam_seq_table[divm][:, vix].clone(),
222 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
223 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
224 | 'p': beam_logprobs_sum_table[divm][vix].item()
225 | }
226 | if max_ppl:
227 | final_beam['p'] = final_beam['p'] / (t-divm+1)
228 | done_beams_table[divm].append(final_beam)
229 | # don't continue beams from finished sequences
230 | beam_logprobs_sum_table[divm][vix] = -1000
231 |
232 | # move the current group one step forward in time
233 |
234 | it = beam_seq_table[divm][t-divm]
235 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
236 |
237 | # all beams are sorted by their log-probabilities
238 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
239 | done_beams = reduce(lambda a,b:a+b, done_beams_table)
240 | return done_beams
241 |
--------------------------------------------------------------------------------
/models/TransformerModel.py:
--------------------------------------------------------------------------------
1 | # This file contains Transformer network
2 | # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3 |
4 | # The cfg name correspondence:
5 | # N=num_layers
6 | # d_model=input_encoding_size
7 | # d_ff=rnn_size
8 | # h is always 8
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import misc.utils as utils
18 |
19 | import copy
20 | import math
21 | import numpy as np
22 |
23 | from .CaptionModel import CaptionModel
24 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25 |
26 | class EncoderDecoder(nn.Module):
27 | """
28 | A standard Encoder-Decoder architecture. Base for this and many
29 | other models.
30 | """
31 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32 | super(EncoderDecoder, self).__init__()
33 | self.encoder = encoder
34 | self.decoder = decoder
35 | self.src_embed = src_embed
36 | self.tgt_embed = tgt_embed
37 | self.generator = generator
38 |
39 | def forward(self, src, tgt, src_mask, tgt_mask):
40 | "Take in and process masked src and target sequences."
41 | return self.decode(self.encode(src, src_mask), src_mask,
42 | tgt, tgt_mask)
43 |
44 | def encode(self, src, src_mask):
45 | return self.encoder(self.src_embed(src), src_mask)
46 |
47 | def decode(self, memory, src_mask, tgt, tgt_mask):
48 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
49 |
50 | class Generator(nn.Module):
51 | "Define standard linear + softmax generation step."
52 | def __init__(self, d_model, vocab):
53 | super(Generator, self).__init__()
54 | self.proj = nn.Linear(d_model, vocab)
55 |
56 | def forward(self, x):
57 | return F.log_softmax(self.proj(x), dim=-1)
58 |
59 | def clones(module, N):
60 | "Produce N identical layers."
61 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62 |
63 | class Encoder(nn.Module):
64 | "Core encoder is a stack of N layers"
65 | def __init__(self, layer, N):
66 | super(Encoder, self).__init__()
67 | self.layers = clones(layer, N)
68 | self.norm = LayerNorm(layer.size)
69 |
70 | def forward(self, x, mask):
71 | "Pass the input (and mask) through each layer in turn."
72 | for layer in self.layers:
73 | x = layer(x, mask)
74 | return self.norm(x)
75 |
76 | class LayerNorm(nn.Module):
77 | "Construct a layernorm module (See citation for details)."
78 | def __init__(self, features, eps=1e-6):
79 | super(LayerNorm, self).__init__()
80 | self.a_2 = nn.Parameter(torch.ones(features))
81 | self.b_2 = nn.Parameter(torch.zeros(features))
82 | self.eps = eps
83 |
84 | def forward(self, x):
85 | mean = x.mean(-1, keepdim=True)
86 | std = x.std(-1, keepdim=True)
87 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88 |
89 | class SublayerConnection(nn.Module):
90 | """
91 | A residual connection followed by a layer norm.
92 | Note for code simplicity the norm is first as opposed to last.
93 | """
94 | def __init__(self, size, dropout):
95 | super(SublayerConnection, self).__init__()
96 | self.norm = LayerNorm(size)
97 | self.dropout = nn.Dropout(dropout)
98 |
99 | def forward(self, x, sublayer):
100 | "Apply residual connection to any sublayer with the same size."
101 | return x + self.dropout(sublayer(self.norm(x)))
102 |
103 | class EncoderLayer(nn.Module):
104 | "Encoder is made up of self-attn and feed forward (defined below)"
105 | def __init__(self, size, self_attn, feed_forward, dropout):
106 | super(EncoderLayer, self).__init__()
107 | self.self_attn = self_attn
108 | self.feed_forward = feed_forward
109 | self.sublayer = clones(SublayerConnection(size, dropout), 2)
110 | self.size = size
111 |
112 | def forward(self, x, mask):
113 | "Follow Figure 1 (left) for connections."
114 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
115 | return self.sublayer[1](x, self.feed_forward)
116 |
117 | class Decoder(nn.Module):
118 | "Generic N layer decoder with masking."
119 | def __init__(self, layer, N):
120 | super(Decoder, self).__init__()
121 | self.layers = clones(layer, N)
122 | self.norm = LayerNorm(layer.size)
123 |
124 | def forward(self, x, memory, src_mask, tgt_mask):
125 | for layer in self.layers:
126 | x = layer(x, memory, src_mask, tgt_mask)
127 | return self.norm(x)
128 |
129 | class DecoderLayer(nn.Module):
130 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
131 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
132 | super(DecoderLayer, self).__init__()
133 | self.size = size
134 | self.self_attn = self_attn
135 | self.src_attn = src_attn
136 | self.feed_forward = feed_forward
137 | self.sublayer = clones(SublayerConnection(size, dropout), 3)
138 |
139 | def forward(self, x, memory, src_mask, tgt_mask):
140 | "Follow Figure 1 (right) for connections."
141 | m = memory
142 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
143 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
144 | return self.sublayer[2](x, self.feed_forward)
145 |
146 | def subsequent_mask(size):
147 | "Mask out subsequent positions."
148 | attn_shape = (1, size, size)
149 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
150 | return torch.from_numpy(subsequent_mask) == 0
151 |
152 | def attention(query, key, value, mask=None, dropout=None):
153 | "Compute 'Scaled Dot Product Attention'"
154 | d_k = query.size(-1)
155 | scores = torch.matmul(query, key.transpose(-2, -1)) \
156 | / math.sqrt(d_k)
157 | if mask is not None:
158 | scores = scores.masked_fill(mask == 0, -1e9)
159 | p_attn = F.softmax(scores, dim = -1)
160 | if dropout is not None:
161 | p_attn = dropout(p_attn)
162 | return torch.matmul(p_attn, value), p_attn
163 |
164 | class MultiHeadedAttention(nn.Module):
165 | def __init__(self, h, d_model, dropout=0.1):
166 | "Take in model size and number of heads."
167 | super(MultiHeadedAttention, self).__init__()
168 | assert d_model % h == 0
169 | # We assume d_v always equals d_k
170 | self.d_k = d_model // h
171 | self.h = h
172 | self.linears = clones(nn.Linear(d_model, d_model), 4)
173 | self.attn = None
174 | self.dropout = nn.Dropout(p=dropout)
175 |
176 | def forward(self, query, key, value, mask=None):
177 | "Implements Figure 2"
178 | if mask is not None:
179 | # Same mask applied to all h heads.
180 | mask = mask.unsqueeze(1)
181 | nbatches = query.size(0)
182 |
183 | # 1) Do all the linear projections in batch from d_model => h x d_k
184 | query, key, value = \
185 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
186 | for l, x in zip(self.linears, (query, key, value))]
187 |
188 | # 2) Apply attention on all the projected vectors in batch.
189 | x, self.attn = attention(query, key, value, mask=mask,
190 | dropout=self.dropout)
191 |
192 | # 3) "Concat" using a view and apply a final linear.
193 | x = x.transpose(1, 2).contiguous() \
194 | .view(nbatches, -1, self.h * self.d_k)
195 | return self.linears[-1](x)
196 |
197 | class PositionwiseFeedForward(nn.Module):
198 | "Implements FFN equation."
199 | def __init__(self, d_model, d_ff, dropout=0.1):
200 | super(PositionwiseFeedForward, self).__init__()
201 | self.w_1 = nn.Linear(d_model, d_ff)
202 | self.w_2 = nn.Linear(d_ff, d_model)
203 | self.dropout = nn.Dropout(dropout)
204 |
205 | def forward(self, x):
206 | return self.w_2(self.dropout(F.relu(self.w_1(x))))
207 |
208 | class Embeddings(nn.Module):
209 | def __init__(self, d_model, vocab):
210 | super(Embeddings, self).__init__()
211 | self.lut = nn.Embedding(vocab, d_model)
212 | self.d_model = d_model
213 |
214 | def forward(self, x):
215 | return self.lut(x) * math.sqrt(self.d_model)
216 |
217 | class PositionalEncoding(nn.Module):
218 | "Implement the PE function."
219 | def __init__(self, d_model, dropout, max_len=5000):
220 | super(PositionalEncoding, self).__init__()
221 | self.dropout = nn.Dropout(p=dropout)
222 |
223 | # Compute the positional encodings once in log space.
224 | pe = torch.zeros(max_len, d_model)
225 | position = torch.arange(0, max_len).unsqueeze(1).float()
226 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
227 | -(math.log(10000.0) / d_model))
228 | pe[:, 0::2] = torch.sin(position * div_term)
229 | pe[:, 1::2] = torch.cos(position * div_term)
230 | pe = pe.unsqueeze(0)
231 | self.register_buffer('pe', pe)
232 |
233 | def forward(self, x):
234 | x = x + self.pe[:, :x.size(1)]
235 | return self.dropout(x)
236 |
237 | class TransformerModel(AttModel):
238 |
239 | def make_model(self, src_vocab, tgt_vocab, N=6,
240 | d_model=512, d_ff=2048, h=8, dropout=0.1):
241 | "Helper: Construct a model from hyperparameters."
242 | c = copy.deepcopy
243 | attn = MultiHeadedAttention(h, d_model)
244 | ff = PositionwiseFeedForward(d_model, d_ff, dropout)
245 | position = PositionalEncoding(d_model, dropout)
246 | model = EncoderDecoder(
247 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
248 | Decoder(DecoderLayer(d_model, c(attn), c(attn),
249 | c(ff), dropout), N),
250 | lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
251 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
252 | Generator(d_model, tgt_vocab))
253 |
254 | # This was important from their code.
255 | # Initialize parameters with Glorot / fan_avg.
256 | for p in model.parameters():
257 | if p.dim() > 1:
258 | nn.init.xavier_uniform_(p)
259 | return model
260 |
261 | def __init__(self, opt):
262 | super(TransformerModel, self).__init__(opt)
263 | self.opt = opt
264 | # self.config = yaml.load(open(opt.config_file))
265 | # d_model = self.input_encoding_size # 512
266 |
267 | delattr(self, 'att_embed')
268 | self.att_embed = nn.Sequential(*(
269 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
270 | (nn.Linear(self.att_feat_size, self.input_encoding_size),
271 | nn.ReLU(),
272 | nn.Dropout(self.drop_prob_lm))+
273 | ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ())))
274 |
275 | delattr(self, 'embed')
276 | self.embed = lambda x : x
277 | delattr(self, 'fc_embed')
278 | self.fc_embed = lambda x : x
279 | delattr(self, 'logit')
280 | del self.ctx2att
281 |
282 | tgt_vocab = self.vocab_size + 1
283 | self.model = self.make_model(0, tgt_vocab,
284 | N=opt.num_layers,
285 | d_model=opt.input_encoding_size,
286 | d_ff=opt.rnn_size)
287 |
288 | def logit(self, x): # unsafe way
289 | return self.model.generator.proj(x)
290 |
291 | def init_hidden(self, bsz):
292 | return None
293 |
294 | def _prepare_feature(self, fc_feats, att_feats, att_masks):
295 |
296 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
297 | memory = self.model.encode(att_feats, att_masks)
298 |
299 | return fc_feats[...,:1], att_feats[...,:1], memory, att_masks
300 |
301 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
302 | att_feats, att_masks = self.clip_att(att_feats, att_masks)
303 |
304 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
305 |
306 | if att_masks is None:
307 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
308 | att_masks = att_masks.unsqueeze(-2)
309 |
310 | if seq is not None:
311 | # crop the last one
312 | seq = seq[:,:-1]
313 | seq_mask = (seq.data > 0)
314 | seq_mask[:,0] += 1
315 |
316 | seq_mask = seq_mask.unsqueeze(-2)
317 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
318 | else:
319 | seq_mask = None
320 |
321 | return att_feats, seq, att_masks, seq_mask
322 |
323 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
324 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
325 |
326 | out = self.model(att_feats, seq, att_masks, seq_mask)
327 |
328 | outputs = self.model.generator(out)
329 | return outputs
330 | # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
331 |
332 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
333 | """
334 | state = [ys.unsqueeze(0)]
335 | """
336 | if state is None:
337 | ys = it.unsqueeze(1)
338 | else:
339 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
340 | out = self.model.decode(memory, mask,
341 | ys,
342 | subsequent_mask(ys.size(1))
343 | .to(memory.device))
344 | return out[:, -1], [ys.unsqueeze(0)]
--------------------------------------------------------------------------------
/models/AttModel.py:
--------------------------------------------------------------------------------
1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model
2 |
3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4 | # https://arxiv.org/abs/1612.01887
5 | # AdaAttMO is a modified version with maxout lstm
6 |
7 | # Att2in is from Self-critical Sequence Training for Image Captioning
8 | # https://arxiv.org/abs/1612.00563
9 | # In this file we only have Att2in2, which is a slightly different version of att2in,
10 | # in which the image feature embedding and the word embedding are the same as in AdaAtt.
11 |
12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
13 | # https://arxiv.org/abs/1707.07998
14 | # However, it may not be identical to the author's architecture.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | import misc.utils as utils
24 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
25 |
26 | from .CaptionModel import CaptionModel
27 |
28 | def sort_pack_padded_sequence(input, lengths):
29 | sorted_lengths, indices = torch.sort(lengths, descending=True)
30 | tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
31 | inv_ix = indices.clone()
32 | inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
33 | return tmp, inv_ix
34 |
35 | def pad_unsort_packed_sequence(input, inv_ix):
36 | tmp, _ = pad_packed_sequence(input, batch_first=True)
37 | tmp = tmp[inv_ix]
38 | return tmp
39 |
40 | def pack_wrapper(module, att_feats, att_masks):
41 | if att_masks is not None:
42 | packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
43 | return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
44 | else:
45 | return module(att_feats)
46 |
47 | class AttModel(CaptionModel):
48 | def __init__(self, opt):
49 | super(AttModel, self).__init__()
50 | self.vocab_size = opt.vocab_size
51 | self.input_encoding_size = opt.input_encoding_size
52 | #self.rnn_type = opt.rnn_type
53 | self.rnn_size = opt.rnn_size
54 | self.num_layers = opt.num_layers
55 | self.drop_prob_lm = opt.drop_prob_lm
56 | self.seq_length = opt.seq_length
57 | self.fc_feat_size = opt.fc_feat_size
58 | self.att_feat_size = opt.att_feat_size
59 | self.att_hid_size = opt.att_hid_size
60 |
61 | self.use_bn = getattr(opt, 'use_bn', 0)
62 |
63 |         self.ss_prob = 0.0 # Scheduled sampling probability
64 |
65 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
66 | nn.ReLU(),
67 | nn.Dropout(self.drop_prob_lm))
68 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
69 | nn.ReLU(),
70 | nn.Dropout(self.drop_prob_lm))
71 | self.att_embed = nn.Sequential(*(
72 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
73 | (nn.Linear(self.att_feat_size, self.rnn_size),
74 | nn.ReLU(),
75 | nn.Dropout(self.drop_prob_lm))+
76 | ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
77 |
78 | self.logit_layers = getattr(opt, 'logit_layers', 1)
79 | if self.logit_layers == 1:
80 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
81 | else:
82 | self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
83 |             self.logit = nn.Sequential(*(sum(self.logit, []) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))  # sum(..., []) flattens the sublists; avoids the Python 2-only reduce builtin
84 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
85 |
86 | def init_hidden(self, bsz):
87 | weight = next(self.parameters())
88 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
89 | weight.new_zeros(self.num_layers, bsz, self.rnn_size))
90 |
91 | def clip_att(self, att_feats, att_masks):
92 | # Clip the length of att_masks and att_feats to the maximum length
93 | if att_masks is not None:
94 | max_len = att_masks.data.long().sum(1).max()
95 | att_feats = att_feats[:, :max_len].contiguous()
96 | att_masks = att_masks[:, :max_len].contiguous()
97 | return att_feats, att_masks
98 |
99 | def _prepare_feature(self, fc_feats, att_feats, att_masks):
100 | att_feats, att_masks = self.clip_att(att_feats, att_masks)
101 |
102 | # embed fc and att feats
103 | fc_feats = self.fc_embed(fc_feats)
104 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
105 |
106 |         # Project the attention feats first to reduce memory and computation consumption.
107 | p_att_feats = self.ctx2att(att_feats)
108 |
109 | return fc_feats, att_feats, p_att_feats, att_masks
110 |
111 | def _forward(self, fc_feats, att_feats, seq, att_masks=None):
112 | batch_size = fc_feats.size(0)
113 | state = self.init_hidden(batch_size)
114 |
115 | outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1)
116 |
117 | # Prepare the features
118 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
119 |         # pp_att_feats is used for attention; we cache it in advance to reduce computation cost
120 |
121 | for i in range(seq.size(1) - 1):
122 |             if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
123 | sample_prob = fc_feats.new(batch_size).uniform_(0, 1)
124 | sample_mask = sample_prob < self.ss_prob
125 | if sample_mask.sum() == 0:
126 | it = seq[:, i].clone()
127 | else:
128 | sample_ind = sample_mask.nonzero().view(-1)
129 | it = seq[:, i].data.clone()
130 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
131 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
132 | # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
133 | prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
134 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
135 | else:
136 | it = seq[:, i].clone()
137 | # break if all the sequences end
138 | if i >= 1 and seq[:, i].sum() == 0:
139 | break
140 |
141 | output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
142 | outputs[:, i] = output
143 |
144 | return outputs
145 |
146 | def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state):
147 | # 'it' contains a word index
148 | xt = self.embed(it)
149 |
150 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
151 | logprobs = F.log_softmax(self.logit(output), dim=1)
152 |
153 | return logprobs, state
154 |
155 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
156 | beam_size = opt.get('beam_size', 10)
157 | batch_size = fc_feats.size(0)
158 |
159 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
160 |
161 |         assert beam_size <= self.vocab_size + 1, 'assume this for now; otherwise this corner case causes a few headaches down the road. Can be dealt with in the future if needed'
162 | seq = torch.LongTensor(self.seq_length, batch_size).zero_()
163 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
164 |         # let's process every image independently for now, for simplicity
165 |
166 | self.done_beams = [[] for _ in range(batch_size)]
167 | for k in range(batch_size):
168 | state = self.init_hidden(beam_size)
169 | tmp_fc_feats = p_fc_feats[k:k+1].expand(beam_size, p_fc_feats.size(1))
170 | tmp_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous()
171 | tmp_p_att_feats = pp_att_feats[k:k+1].expand(*((beam_size,)+pp_att_feats.size()[1:])).contiguous()
172 | tmp_att_masks = p_att_masks[k:k+1].expand(*((beam_size,)+p_att_masks.size()[1:])).contiguous() if att_masks is not None else None
173 |
174 |             for t in range(1):  # only the first step: feed <bos>, then hand off to beam_search
175 | if t == 0: # input
176 | it = fc_feats.new_zeros([beam_size], dtype=torch.long)
177 |
178 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
179 |
180 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
181 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
182 | seqLogprobs[:, k] = self.done_beams[k][0]['logps']
183 | # return the samples and their log likelihoods
184 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
185 |
186 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
187 |
188 | sample_max = opt.get('sample_max', 1)
189 | beam_size = opt.get('beam_size', 1)
190 | temperature = opt.get('temperature', 1.0)
191 | decoding_constraint = opt.get('decoding_constraint', 0)
192 | if beam_size > 1:
193 | return self._sample_beam(fc_feats, att_feats, att_masks, opt)
194 |
195 | batch_size = fc_feats.size(0)
196 | state = self.init_hidden(batch_size)
197 |
198 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
199 |
200 | seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
201 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
202 | for t in range(self.seq_length + 1):
203 | if t == 0: # input
204 | it = fc_feats.new_zeros(batch_size, dtype=torch.long)
205 |
206 | logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
207 |
208 | if decoding_constraint and t > 0:
209 | tmp = logprobs.new_zeros(logprobs.size())
210 | tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
211 | logprobs = logprobs + tmp
212 |
213 | # sample the next word
214 |             if t == self.seq_length: # stop once we reach the maximum length
215 | break
216 | if sample_max:
217 | sampleLogprobs, it = torch.max(logprobs.data, 1)
218 | it = it.view(-1).long()
219 | else:
220 | if temperature == 1.0:
221 | prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1)
222 | else:
223 | # scale logprobs by temperature
224 | prob_prev = torch.exp(torch.div(logprobs.data, temperature))
225 | it = torch.multinomial(prob_prev, 1)
226 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
227 | it = it.view(-1).long() # and flatten indices for downstream processing
228 |
229 | # stop when all finished
230 | if t == 0:
231 | unfinished = it > 0
232 | else:
233 | unfinished = unfinished * (it > 0)
234 | it = it * unfinished.type_as(it)
235 | seq[:,t] = it
236 | seqLogprobs[:,t] = sampleLogprobs.view(-1)
237 | # quit loop if all sequences have finished
238 | if unfinished.sum() == 0:
239 | break
240 |
241 | return seq, seqLogprobs
242 |
243 | class AdaAtt_lstm(nn.Module):
244 | def __init__(self, opt, use_maxout=True):
245 | super(AdaAtt_lstm, self).__init__()
246 | self.input_encoding_size = opt.input_encoding_size
247 | #self.rnn_type = opt.rnn_type
248 | self.rnn_size = opt.rnn_size
249 | self.num_layers = opt.num_layers
250 | self.drop_prob_lm = opt.drop_prob_lm
251 | self.fc_feat_size = opt.fc_feat_size
252 | self.att_feat_size = opt.att_feat_size
253 | self.att_hid_size = opt.att_hid_size
254 |
255 | self.use_maxout = use_maxout
256 |
257 | # Build a LSTM
258 | self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
259 | self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
260 |
261 | self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
262 | self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
263 |
264 | # Layers for getting the fake region
265 | if self.num_layers == 1:
266 | self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
267 | self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
268 | else:
269 | self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
270 | self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
271 |
272 |
273 | def forward(self, xt, img_fc, state):
274 |
275 | hs = []
276 | cs = []
277 | for L in range(self.num_layers):
278 | # c,h from previous timesteps
279 | prev_h = state[0][L]
280 | prev_c = state[1][L]
281 | # the input to this layer
282 | if L == 0:
283 | x = xt
284 | i2h = self.w2h(x) + self.v2h(img_fc)
285 | else:
286 | x = hs[-1]
287 | x = F.dropout(x, self.drop_prob_lm, self.training)
288 | i2h = self.i2h[L-1](x)
289 |
290 | all_input_sums = i2h+self.h2h[L](prev_h)
291 |
292 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
293 | sigmoid_chunk = F.sigmoid(sigmoid_chunk)
294 | # decode the gates
295 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
296 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
297 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
298 | # decode the write inputs
299 | if not self.use_maxout:
300 | in_transform = F.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
301 | else:
302 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
303 | in_transform = torch.max(\
304 | in_transform.narrow(1, 0, self.rnn_size),
305 | in_transform.narrow(1, self.rnn_size, self.rnn_size))
306 | # perform the LSTM update
307 | next_c = forget_gate * prev_c + in_gate * in_transform
308 | # gated cells form the output
309 |             tanh_next_c = F.tanh(next_c)
310 |             next_h = out_gate * tanh_next_c
311 | if L == self.num_layers-1:
312 | if L == 0:
313 | i2h = self.r_w2h(x) + self.r_v2h(img_fc)
314 | else:
315 | i2h = self.r_i2h(x)
316 | n5 = i2h+self.r_h2h(prev_h)
317 |                 fake_region = F.sigmoid(n5) * tanh_next_c
318 |
319 | cs.append(next_c)
320 | hs.append(next_h)
321 |
322 | # set up the decoder
323 | top_h = hs[-1]
324 | top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
325 | fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
326 |
327 | state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
328 | torch.cat([_.unsqueeze(0) for _ in cs], 0))
329 | return top_h, fake_region, state
330 |
331 | class AdaAtt_attention(nn.Module):
332 | def __init__(self, opt):
333 | super(AdaAtt_attention, self).__init__()
334 | self.input_encoding_size = opt.input_encoding_size
335 | #self.rnn_type = opt.rnn_type
336 | self.rnn_size = opt.rnn_size
337 | self.drop_prob_lm = opt.drop_prob_lm
338 | self.att_hid_size = opt.att_hid_size
339 |
340 | # fake region embed
341 | self.fr_linear = nn.Sequential(
342 | nn.Linear(self.rnn_size, self.input_encoding_size),
343 | nn.ReLU(),
344 | nn.Dropout(self.drop_prob_lm))
345 | self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
346 |
347 | # h out embed
348 | self.ho_linear = nn.Sequential(
349 | nn.Linear(self.rnn_size, self.input_encoding_size),
350 | nn.Tanh(),
351 | nn.Dropout(self.drop_prob_lm))
352 | self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
353 |
354 | self.alpha_net = nn.Linear(self.att_hid_size, 1)
355 | self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
356 |
357 | def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
358 |
359 | # View into three dimensions
360 | att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
361 | conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
362 | conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
363 |
364 |         # view neighbor from batch_size * neighbor_num x rnn_size to batch_size x rnn_size * neighbor_num
365 | fake_region = self.fr_linear(fake_region)
366 | fake_region_embed = self.fr_embed(fake_region)
367 |
368 | h_out_linear = self.ho_linear(h_out)
369 | h_out_embed = self.ho_embed(h_out_linear)
370 |
371 | txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
372 |
373 | img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
374 |         img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.att_hid_size), conv_feat_embed], 1)  # fr_embed outputs att_hid_size, not input_encoding_size
375 |
376 | hA = F.tanh(img_all_embed + txt_replicate)
377 | hA = F.dropout(hA,self.drop_prob_lm, self.training)
378 |
379 | hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
380 | PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
381 |
382 | if att_masks is not None:
383 | att_masks = att_masks.view(-1, att_size)
384 |             PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume the first mask entry is one; its copy covers the sentinel (fake region) slot
385 | PI = PI / PI.sum(1, keepdim=True)
386 |
387 | visAtt = torch.bmm(PI.unsqueeze(1), img_all)
388 | visAttdim = visAtt.squeeze(1)
389 |
390 | atten_out = visAttdim + h_out_linear
391 |
392 | h = F.tanh(self.att2h(atten_out))
393 | h = F.dropout(h, self.drop_prob_lm, self.training)
394 | return h
395 |
396 | class AdaAttCore(nn.Module):
397 | def __init__(self, opt, use_maxout=False):
398 | super(AdaAttCore, self).__init__()
399 | self.lstm = AdaAtt_lstm(opt, use_maxout)
400 | self.attention = AdaAtt_attention(opt)
401 |
402 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
403 | h_out, p_out, state = self.lstm(xt, fc_feats, state)
404 | atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
405 | return atten_out, state
406 |
407 | class TopDownCore(nn.Module):
408 | def __init__(self, opt, use_maxout=False):
409 | super(TopDownCore, self).__init__()
410 | self.drop_prob_lm = opt.drop_prob_lm
411 |
412 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
413 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
414 | self.attention = Attention(opt)
415 |
416 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
417 | prev_h = state[0][-1]
418 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
419 |
420 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
421 |
422 | att = self.attention(h_att, att_feats, p_att_feats, att_masks)
423 |
424 | lang_lstm_input = torch.cat([att, h_att], 1)
425 |         # alternative: lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) -- unclear whether dropout on h_att helps here
426 |
427 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
428 |
429 | output = F.dropout(h_lang, self.drop_prob_lm, self.training)
430 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
431 |
432 | return output, state
433 |
434 |
435 | ############################################################################
436 | # Notice:
437 | # StackAtt and DenseAtt are models that I randomly designed.
438 | # They are not related to any paper.
439 | ############################################################################
440 |
441 | from .FCModel import LSTMCore
442 | class StackAttCore(nn.Module):
443 | def __init__(self, opt, use_maxout=False):
444 | super(StackAttCore, self).__init__()
445 | self.drop_prob_lm = opt.drop_prob_lm
446 |
447 | # self.att0 = Attention(opt)
448 | self.att1 = Attention(opt)
449 | self.att2 = Attention(opt)
450 |
451 | opt_input_encoding_size = opt.input_encoding_size
452 | opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
453 | self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
454 | opt.input_encoding_size = opt.rnn_size * 2
455 | self.lstm1 = LSTMCore(opt)
456 | self.lstm2 = LSTMCore(opt)
457 | opt.input_encoding_size = opt_input_encoding_size
458 |
459 | # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
460 | self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
461 |
462 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
463 | # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
464 | h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
465 | att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
466 | h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
467 | att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
468 | h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
469 |
470 | return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
471 |
472 | class DenseAttCore(nn.Module):
473 | def __init__(self, opt, use_maxout=False):
474 | super(DenseAttCore, self).__init__()
475 | self.drop_prob_lm = opt.drop_prob_lm
476 |
477 | # self.att0 = Attention(opt)
478 | self.att1 = Attention(opt)
479 | self.att2 = Attention(opt)
480 |
481 | opt_input_encoding_size = opt.input_encoding_size
482 | opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
483 | self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
484 | opt.input_encoding_size = opt.rnn_size * 2
485 | self.lstm1 = LSTMCore(opt)
486 | self.lstm2 = LSTMCore(opt)
487 | opt.input_encoding_size = opt_input_encoding_size
488 |
489 | # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
490 | self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
491 |
492 | # fuse h_0 and h_1
493 | self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
494 | nn.ReLU(),
495 | nn.Dropout(opt.drop_prob_lm))
496 | # fuse h_0, h_1 and h_2
497 | self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
498 | nn.ReLU(),
499 | nn.Dropout(opt.drop_prob_lm))
500 |
501 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
502 | # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
503 | h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
504 | att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
505 | h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
506 | att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
507 | h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
508 |
509 | return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
510 |
511 | class Attention(nn.Module):
512 | def __init__(self, opt):
513 | super(Attention, self).__init__()
514 | self.rnn_size = opt.rnn_size
515 | self.att_hid_size = opt.att_hid_size
516 |
517 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
518 | self.alpha_net = nn.Linear(self.att_hid_size, 1)
519 |
520 | def forward(self, h, att_feats, p_att_feats, att_masks=None):
521 | # The p_att_feats here is already projected
522 | att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
523 | att = p_att_feats.view(-1, att_size, self.att_hid_size)
524 |
525 | att_h = self.h2att(h) # batch * att_hid_size
526 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
527 | dot = att + att_h # batch * att_size * att_hid_size
528 | dot = F.tanh(dot) # batch * att_size * att_hid_size
529 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
530 | dot = self.alpha_net(dot) # (batch * att_size) * 1
531 | dot = dot.view(-1, att_size) # batch * att_size
532 |
533 | weight = F.softmax(dot, dim=1) # batch * att_size
534 | if att_masks is not None:
535 | weight = weight * att_masks.view(-1, att_size).float()
536 | weight = weight / weight.sum(1, keepdim=True) # normalize to 1
537 | att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
538 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
539 |
540 | return att_res
541 |
542 | class Att2in2Core(nn.Module):
543 | def __init__(self, opt):
544 | super(Att2in2Core, self).__init__()
545 | self.input_encoding_size = opt.input_encoding_size
546 | #self.rnn_type = opt.rnn_type
547 | self.rnn_size = opt.rnn_size
548 | #self.num_layers = opt.num_layers
549 | self.drop_prob_lm = opt.drop_prob_lm
550 | self.fc_feat_size = opt.fc_feat_size
551 | self.att_feat_size = opt.att_feat_size
552 | self.att_hid_size = opt.att_hid_size
553 |
554 | # Build a LSTM
555 | self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
556 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
557 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
558 | self.dropout = nn.Dropout(self.drop_prob_lm)
559 |
560 | self.attention = Attention(opt)
561 |
562 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
563 | att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
564 |
565 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
566 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
567 | sigmoid_chunk = F.sigmoid(sigmoid_chunk)
568 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
569 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
570 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
571 |
572 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
573 | self.a2c(att_res)
574 | in_transform = torch.max(\
575 | in_transform.narrow(1, 0, self.rnn_size),
576 | in_transform.narrow(1, self.rnn_size, self.rnn_size))
577 | next_c = forget_gate * state[1][-1] + in_gate * in_transform
578 | next_h = out_gate * F.tanh(next_c)
579 |
580 | output = self.dropout(next_h)
581 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
582 | return output, state
583 |
584 | class Att2inCore(Att2in2Core):
585 | def __init__(self, opt):
586 | super(Att2inCore, self).__init__(opt)
587 | del self.a2c
588 | self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
589 |
590 | """
591 | Note: this is my attempt to replicate the att2all model in the self-critical paper.
592 | However, it is not an exact replication; to be fixed.
593 | """
594 | class Att2all2Core(nn.Module):
595 | def __init__(self, opt):
596 | super(Att2all2Core, self).__init__()
597 | self.input_encoding_size = opt.input_encoding_size
598 | #self.rnn_type = opt.rnn_type
599 | self.rnn_size = opt.rnn_size
600 | #self.num_layers = opt.num_layers
601 | self.drop_prob_lm = opt.drop_prob_lm
602 | self.fc_feat_size = opt.fc_feat_size
603 | self.att_feat_size = opt.att_feat_size
604 | self.att_hid_size = opt.att_hid_size
605 |
606 | # Build a LSTM
607 | self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
608 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
609 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
610 | self.dropout = nn.Dropout(self.drop_prob_lm)
611 |
612 | self.attention = Attention(opt)
613 |
614 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
615 | att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
616 |
617 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
618 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
619 | sigmoid_chunk = F.sigmoid(sigmoid_chunk)
620 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
621 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
622 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
623 |
624 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
625 | in_transform = torch.max(\
626 | in_transform.narrow(1, 0, self.rnn_size),
627 | in_transform.narrow(1, self.rnn_size, self.rnn_size))
628 | next_c = forget_gate * state[1][-1] + in_gate * in_transform
629 | next_h = out_gate * F.tanh(next_c)
630 |
631 | output = self.dropout(next_h)
632 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
633 | return output, state
634 |
635 | class AdaAttModel(AttModel):
636 | def __init__(self, opt):
637 | super(AdaAttModel, self).__init__(opt)
638 | self.core = AdaAttCore(opt)
639 |
640 | # AdaAtt with maxout lstm
641 | class AdaAttMOModel(AttModel):
642 | def __init__(self, opt):
643 | super(AdaAttMOModel, self).__init__(opt)
644 | self.core = AdaAttCore(opt, True)
645 |
646 | class Att2in2Model(AttModel):
647 | def __init__(self, opt):
648 | super(Att2in2Model, self).__init__(opt)
649 | self.core = Att2in2Core(opt)
650 | delattr(self, 'fc_embed')
651 | self.fc_embed = lambda x : x
652 |
653 | class Att2all2Model(AttModel):
654 | def __init__(self, opt):
655 | super(Att2all2Model, self).__init__(opt)
656 | self.core = Att2all2Core(opt)
657 | delattr(self, 'fc_embed')
658 | self.fc_embed = lambda x : x
659 |
660 | class TopDownModel(AttModel):
661 | def __init__(self, opt):
662 | super(TopDownModel, self).__init__(opt)
663 | self.num_layers = 2
664 | self.core = TopDownCore(opt)
665 |
666 | class StackAttModel(AttModel):
667 | def __init__(self, opt):
668 | super(StackAttModel, self).__init__(opt)
669 | self.num_layers = 3
670 | self.core = StackAttCore(opt)
671 |
672 | class DenseAttModel(AttModel):
673 | def __init__(self, opt):
674 | super(DenseAttModel, self).__init__(opt)
675 | self.num_layers = 3
676 | self.core = DenseAttCore(opt)
677 |
678 | class Att2inModel(AttModel):
679 | def __init__(self, opt):
680 | super(Att2inModel, self).__init__(opt)
681 | del self.embed, self.fc_embed, self.att_embed
682 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
683 | self.fc_embed = self.att_embed = lambda x: x
684 | del self.ctx2att
685 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
686 | self.core = Att2inCore(opt)
687 | self.init_weights()
688 |
689 | def init_weights(self):
690 | initrange = 0.1
691 | self.embed.weight.data.uniform_(-initrange, initrange)
692 | self.logit.bias.data.fill_(0)
693 | self.logit.weight.data.uniform_(-initrange, initrange)
694 |
--------------------------------------------------------------------------------
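As a quick way to exercise the classes in AttModel.py, the snippet below is a hypothetical smoke test, not repo code: it builds a TopDownModel from a minimal argparse-style options object (all sizes and the vocabulary are made-up toy values), then runs one teacher-forced pass and one greedy sample. It calls the underscore methods directly rather than assuming the dispatch signature of CaptionModel.forward, and it assumes it is executed from the repository root so that the models and misc packages are importable.

from argparse import Namespace
import torch
from models.AttModel import TopDownModel

opt = Namespace(vocab_size=20, input_encoding_size=32, rnn_size=64, num_layers=2,
                drop_prob_lm=0.5, seq_length=16, fc_feat_size=2048,
                att_feat_size=2048, att_hid_size=64)

model = TopDownModel(opt)
B, N = 4, 36                                   # batch size, number of attention regions
fc_feats = torch.randn(B, opt.fc_feat_size)
att_feats = torch.randn(B, N, opt.att_feat_size)
seq = torch.randint(1, opt.vocab_size + 1, (B, opt.seq_length + 1))
seq[:, 0] = 0                                  # <bos> slot

logprobs = model._forward(fc_feats, att_feats, seq)   # (B, seq_length, vocab_size + 1)
sample, sample_logprobs = model._sample(fc_feats, att_feats, opt={'sample_max': 1})
print(logprobs.shape, sample.shape)

The same pattern applies to the other AttModel subclasses (Att2in2Model, AdaAttModel, StackAttModel, ...); only the core differs, so the feature preparation and sampling interfaces shown above stay the same.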