├── README.md ├── __init__.py ├── config.py ├── dataloader.py ├── dataloaderraw.py ├── eval.py ├── eval_utils.py ├── misc ├── .DS_Store ├── __init__.py ├── resnet.py ├── resnet_utils.py ├── rewards.py ├── symlink.py └── utils.py ├── models ├── .DS_Store ├── Att2inModel.py ├── AttModel.py ├── AttModel_CCG.py ├── AttModel_V1.py ├── AttModel_V2.py ├── TDModel.py └── __init__.py ├── opts.py ├── scripts ├── .DS_Store ├── __init__.py ├── prepro_feats.py ├── prepro_feats_coco.py ├── prepro_labels.py ├── prepro_labels_detection.py └── prepro_ngrams.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Multitask_Image_Captioning 2 | 3 | This is a preliminary rough sketch of the Multitask Image Captioning implementation. The main functions are already in this repository. The next update will include more detailed instructions. 4 | 5 | For any questions, you can open an issue here. 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/__init__.py -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import opts 3 | def get_cn_fixed_opts(): 4 | opt = opts.parse_opt() 5 | opt.caption_model ='topdown' 6 | opt.batch_size=10 7 | #Pretrain 8 | opt.id ='topdown' 9 | opt.learning_rate= 5e-4 10 | opt.learning_rate_decay_start= 0 11 | opt.scheduled_sampling_start=0 12 | opt.save_checkpoint_every=1300#11500 13 | opt.val_images_use=5000 14 | opt.max_epochs=40 15 | opt.start_from=None 16 | opt.input_json='data/meta/aitalk_meta.json' 17 | opt.input_label_h5='data/dataset/aitalk_label.h5' 18 | opt.input_fc_dir='/media/andyweizhao/Elements/CVPR/cocotalk_fc' 19 | opt.input_att_dir='data/cocotalk_att' 20 | return opt 21 | 22 | def get_en_fixed_opts(): 23 | opt = opts.parse_opt() 24 | opt.caption_model ='topdown' 25 | opt.batch_size=10 26 | #Pretrain 27 | opt.id ='topdown' 28 | opt.learning_rate= 5e-4 29 | opt.learning_rate_decay_start= 0 30 | opt.scheduled_sampling_start=0 31 | opt.save_checkpoint_every=1300#11500 32 | opt.val_images_use=5000 33 | opt.max_epochs=40 34 | opt.start_from=None 35 | # opt.input_json='data/meta/aitalk_meta.json' 36 | # opt.input_label_h5='data/dataset/aitalk_label.h5' 37 | opt.input_json='/home/andyweizhao/wabywang/010/data/dataset/coco_processed.json' 38 | opt.input_label_h5='data/dataset/coco_label.h5' 39 | opt.input_fc_dir='/media/andyweizhao/Elements/CVPR/cocotalk_fc' 40 | opt.input_att_dir='data/cocotalk_att' 41 | return opt 42 | def get_cn_opts(): 43 | opt = opts.parse_opt() 44 | opt.caption_model ='cross_topdown' 45 | opt.batch_size=10 46 | #Pretrain 47 | opt.id ='topdown' 48 | opt.learning_rate= 5e-4 49 | opt.learning_rate_decay_start= 0 50 | opt.scheduled_sampling_start=0 51 | opt.save_checkpoint_every=1300#11500 52 | opt.val_images_use=5000 53 | opt.max_epochs=40 54 | opt.start_from=None 55 | opt.input_json='data/meta/aitalk_meta.json' 56 | opt.input_label_h5='data/dataset/aitalk_label.h5' 57 | # opt.input_json='data/dataset/tmp/aitalk_cross.json' 58 | # opt.input_label_h5='data/dataset/tmp/aitalk_cross_label.h5' 59 | opt.input_fc_dir='/media/andyweizhao/Elements/CVPR/cocotalk_fc' 60 | opt.input_att_dir='data/cocotalk_att' 61 | return
opt 62 | 63 | def get_en_opts(): 64 | opt = opts.parse_opt() 65 | opt.caption_model ='cross_topdown' 66 | opt.batch_size=10 67 | #Pretrain 68 | opt.id ='topdown' 69 | opt.learning_rate= 5e-4 70 | opt.learning_rate_decay_start= 0 71 | opt.scheduled_sampling_start=0 72 | opt.save_checkpoint_every=1300#11500 73 | opt.val_images_use=5000 74 | opt.max_epochs=40 75 | opt.start_from=None 76 | opt.input_json='/home/andyweizhao/wabywang/010/data/dataset/coco_processed.json' 77 | opt.input_label_h5='data/dataset/coco_label.h5' 78 | opt.input_fc_dir='/media/andyweizhao/Elements/CVPR/cocotalk_fc' 79 | opt.input_att_dir='data/cocotalk_att' 80 | return opt -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import numpy as np 8 | import random 9 | import torch 10 | import cPickle 11 | import skimage.io 12 | from torchvision import transforms as trn 13 | preprocess = trn.Compose([ 14 | #trn.ToTensor(), 15 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 16 | ]) 17 | 18 | class DataLoader(): 19 | def __init__(self, opt): 20 | self.opt = opt 21 | self.batch_size = self.opt.batch_size 22 | self.seq_per_img = self.opt.seq_per_img 23 | 24 | print('DataLoader loading json file: ', opt.input_json) 25 | self.info = json.load(open(self.opt.input_json)) 26 | self.ix_to_word = self.info['ix_to_word'] 27 | self.ix_to_word_ccg = cPickle.load( open("data/ix_to_ccg.pkl","rb") ) 28 | self.detection_dataset = cPickle.load(open("data/detection_all.json", 'rb')) 29 | 30 | self.vocab_size = len(self.ix_to_word) 31 | print('vocab word size is ', self.vocab_size) 32 | self.vocab_ccg_size = len(self.ix_to_word_ccg) 33 | print('vocab ccg size is ', self.vocab_ccg_size) 34 | 35 | print('DataLoader loading h5 file: ', opt.input_label_h5, opt.input_image_h5) 36 | self.h5_label_file = h5py.File(self.opt.input_label_h5) 37 | self.h5_image_file = h5py.File(self.opt.input_image_h5) 38 | self.h5_image_path = np.load('data/image_path.npy') 39 | 40 | self.input_fc_dir = self.opt.input_fc_dir 41 | self.input_att_dir = self.opt.input_att_dir 42 | 43 | # extract image size from dataset 44 | images_size = self.h5_image_file['images'].shape 45 | assert len(images_size) == 4, 'images should be a 4D tensor' 46 | assert images_size[2] == images_size[3], 'width and height must match' 47 | self.num_images = images_size[0] 48 | self.num_channels = images_size[1] 49 | self.max_image_size = images_size[2] 50 | print('read %d images of size %dx%dx%d' %(self.num_images, 51 | self.num_channels, self.max_image_size, self.max_image_size)) 52 | 53 | # load in the sequence data 54 | seq_size = self.h5_label_file['labels'].shape 55 | self.seq_length = seq_size[1] 56 | print('max sequence length in data is', self.seq_length) 57 | 58 | # load the pointers in full to RAM (should be small enough) 59 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 60 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 61 | 62 | self.split_ix = {'train': [], 'val': [], 'test': []} 63 | for ix in range(len(self.info['images'])): 64 | img = self.info['images'][ix] 65 | if img['split'] == 'train': 66 | self.split_ix['train'].append(ix) 67 | elif img['split'] == 'val': 68 | self.split_ix['val'].append(ix) 69 | elif img['split'] == 'test': 70 | 
self.split_ix['test'].append(ix) 71 | elif opt.train_only == 0: # restval 72 | # I used some of val for train, and that's "restval". So train/restval is train, val is val, test is test 73 | self.split_ix['train'].append(ix) 74 | 75 | print('assigned %d images to split train' %len(self.split_ix['train'])) 76 | print('assigned %d images to split val' %len(self.split_ix['val'])) 77 | print('assigned %d images to split test' %len(self.split_ix['test'])) 78 | 79 | self.iterators = {'train': 0, 'val': 0, 'test': 0} 80 | 81 | def get_vocab_ccg(self): 82 | result=dict() 83 | for k, v in self.ix_to_word_ccg.items(): 84 | result[str(k)]=v 85 | return result 86 | 87 | def get_vocab(self): 88 | return self.ix_to_word 89 | 90 | def get_seq_length(self): 91 | return self.seq_length 92 | 93 | def get_batch(self, split, batch_size=None): 94 | split_ix = self.split_ix[split] 95 | batch_size = batch_size or self.batch_size 96 | seq_per_img = self.seq_per_img or self.seq_per_img 97 | 98 | # img_batch = np.ndarray([batch_size, 3, 512,512], dtype = 'float32') 99 | img_batch = [] 100 | label_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'int') 101 | mask_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'float32') 102 | ccg_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int') 103 | 104 | max_index = len(split_ix) 105 | wrapped = False 106 | 107 | infos = [] 108 | gts = [] 109 | detection_infos = [] 110 | for i in range(batch_size): 111 | ri = self.iterators[split] 112 | ri_next = ri + 1 113 | if ri_next >= max_index: 114 | ri_next = 0 115 | wrapped = True 116 | 117 | self.iterators[split] = ri_next 118 | ix = split_ix[ri] 119 | 120 | # img = self.h5_image_file['images'][ix, :, :, :] 121 | img_path = self.h5_image_path[ix] 122 | img_path = img_path.replace('/nlp/dataset/MSCOCO','/data1/zsfx/wabywang/caption/dataset/MSCOCO') 123 | img = skimage.io.imread(img_path) 124 | if len(img.shape) == 2: 125 | img = img[:,:,np.newaxis] 126 | img = np.concatenate((img,img,img), axis=2) 127 | img = img.transpose(2,0,1) 128 | 129 | # img_batch[i] = preprocess(torch.from_numpy(img.astype('float32')/255.0)).numpy() 130 | img_batch.append(preprocess(torch.from_numpy(img.astype('float32')/255.0)).numpy()) 131 | # fetch the sequence labels 132 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 133 | ix2 = self.label_end_ix[ix] - 1 134 | ncap = ix2 - ix1 + 1 # number of captions available for this image 135 | assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' 136 | 137 | if ncap < self.seq_per_img: 138 | seq = np.zeros([self.seq_per_img, self.seq_length], dtype = 'int') 139 | ccg_seq = np.zeros([self.seq_per_img, self.seq_length], dtype = 'int') 140 | for q in range(self.seq_per_img): 141 | ixl = random.randint(ix1,ix2) 142 | seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] 143 | if self.opt.ccg: 144 | ccg_seq[q, :] = self.h5_label_file['ccg'][ixl, :self.seq_length]# zero with padding and starts with 1 145 | else: 146 | ixl = random.randint(ix1, ix2 - self.seq_per_img + 1)# pick the last 5 captions 147 | seq = self.h5_label_file['labels'][ixl: ixl + self.seq_per_img, :self.seq_length] 148 | if self.opt.ccg: 149 | ccg_seq = self.h5_label_file['ccg'][ixl: ixl + self.seq_per_img, :self.seq_length] 150 | # leave bos and eos to 0 151 | if self.opt.ccg: 152 | ccg_batch[i * self.seq_per_img : (i + 1) * self.seq_per_img, 1 : self.seq_length + 1] = ccg_seq 153 | label_batch[i * self.seq_per_img : (i + 1) * self.seq_per_img, 1 : self.seq_length + 1] = seq 154 | 155 | # Used for reward evaluation 156 | gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) 157 | 158 | # record associated info as well 159 | info_dict = {} 160 | info_dict['id'] = self.info['images'][ix]['id'] 161 | info_dict['file_path'] = self.info['images'][ix]['file_path'] 162 | infos.append(info_dict) 163 | 164 | detection_dict = {} 165 | if (self.detection_dataset.has_key(info_dict['id'])): 166 | img_id = info_dict['id'] 167 | detection_dict['label'] = self.detection_dataset[img_id]['label'].astype(int) 168 | detection_dict['l_import'] = self.detection_dataset[img_id]['l_import'] 169 | detection_dict['super_words'] = self.detection_dataset[img_id]['super_words'] 170 | detection_dict['sw_import'] = self.detection_dataset[img_id]['sw_import'] 171 | detection_dict['w_import'] = self.detection_dataset[img_id]['w_import'] 172 | detection_dict['words'] = self.detection_dataset[img_id]['words'] 173 | else: 174 | detection_dict['label'] = list(np.zeros(81)) 175 | detection_infos.append(detection_dict) 176 | # generate mask 177 | nonzeros = np.array(list(map(lambda x: (x != 0).sum() + 2, label_batch))) 178 | for ix, row in enumerate(mask_batch): 179 | row[:nonzeros[ix]] = 1 180 | 181 | data = {} 182 | data['images'] = img_batch 183 | data['labels'] = label_batch 184 | if self.opt.ccg: 185 | data['ccg'] = ccg_batch 186 | data['gts'] = gts 187 | data['masks'] = mask_batch 188 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(split_ix), 'wrapped': wrapped} 189 | data['infos'] = infos 190 | data['detection_infos'] = detection_infos 191 | 192 | return data 193 | 194 | def reset_iterator(self, split): 195 | self.iterators[split] = 0 196 | 197 | def main(): 198 | import opts 199 | import misc.utils as utils 200 | opt = opts.parse_opt() 201 | opt.caption_model ='topdown' 202 | opt.batch_size=10 203 | opt.id ='topdown' 204 | opt.learning_rate= 5e-4 205 | opt.learning_rate_decay_start= 0 206 | opt.scheduled_sampling_start=0 207 | opt.save_checkpoint_every=25#11500 208 | opt.val_images_use=5000 209 | opt.max_epochs=40 210 | opt.start_from=None 211 | opt.input_json='data/meta_coco_en.json' 212 | opt.input_label_h5='data/label_coco_en.h5' 213 | opt.input_image_h5 = 'data/coco_image_512.h5' 214 | opt.use_att = utils.if_use_att(opt.caption_model) 215 | opt.ccg = False 216 | loader = DataLoader(opt) 217 | opt.vocab_size = loader.vocab_size 218 | opt.seq_length = loader.seq_length 219 | data = 
loader.get_batch('train') 220 | 221 | data = loader.get_batch('val') 222 | -------------------------------------------------------------------------------- /dataloaderraw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | import torch 11 | from scipy.misc import imread, imresize 12 | from torch.autograd import Variable 13 | import skimage 14 | import skimage.io 15 | import scipy.misc 16 | 17 | from torchvision import transforms as trn 18 | preprocess = trn.Compose([ 19 | #trn.ToTensor(), 20 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | 23 | from misc.resnet_utils import myResnet 24 | import misc.resnet 25 | 26 | class DataLoaderRaw(): 27 | 28 | def __init__(self, opt): 29 | self.opt = opt 30 | self.coco_json = opt.get('coco_json', '') 31 | self.folder_path = opt.get('folder_path', '') 32 | 33 | self.batch_size = opt.get('batch_size', 1) 34 | self.seq_per_img = 1 35 | 36 | # Load resnet 37 | # self.cnn_model = opt.get('cnn_model', 'resnet101') 38 | # self.my_resnet = getattr(misc.resnet, self.cnn_model)() 39 | # self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth')) 40 | # self.my_resnet = myResnet(self.my_resnet) 41 | # self.my_resnet.cuda() 42 | # self.my_resnet.eval() 43 | 44 | 45 | 46 | # load the json file which contains additional information about the dataset 47 | print('DataLoaderRaw loading images from folder: ', self.folder_path) 48 | 49 | self.files = [] 50 | self.ids = [] 51 | 52 | print(len(self.coco_json)) 53 | if len(self.coco_json) > 0: 54 | print('reading from ' + self.coco_json) 55 | # read in filenames from the coco-style json file 56 | self.coco_annotation = json.load(open(self.coco_json)) 57 | for k,v in enumerate(self.coco_annotation['images']): 58 | fullpath = os.path.join(self.folder_path, v['file_name']) 59 | self.files.append(fullpath) 60 | self.ids.append(v['id']) 61 | else: 62 | # read in all the filenames from the folder 63 | print('listing all images in directory ' + self.folder_path) 64 | def isImage(f): 65 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM'] 66 | for ext in supportedExt: 67 | start_idx = f.rfind(ext) 68 | if start_idx >= 0 and start_idx + len(ext) == len(f): 69 | return True 70 | return False 71 | 72 | n = 1 73 | for root, dirs, files in os.walk(self.folder_path, topdown=False): 74 | for file in files: 75 | fullpath = os.path.join(self.folder_path, file) 76 | if isImage(fullpath): 77 | filename, _ = os.path.splitext(fullpath) 78 | self.files.append(fullpath) 79 | self.ids.append(filename) 80 | # self.ids.append(str(n)) # just order them sequentially 81 | n = n + 1 82 | 83 | self.N = len(self.files) 84 | print('DataLoaderRaw found ', self.N, ' images') 85 | 86 | self.iterator = 0 87 | 88 | def get_batch(self, split, batch_size=None): 89 | batch_size = batch_size or self.batch_size 90 | 91 | # pick an index of the datapoint to load next 92 | # fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32') 93 | # att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32') 94 | img_batch = np.ndarray([batch_size, 3, 256,256], dtype = 'float32') 95 | max_index = self.N 96 | wrapped = False 97 | infos = [] 98 | 99 | for i in range(batch_size): 100 | ri = self.iterator 101 | ri_next = ri + 1 102 | if 
ri_next >= max_index: 103 | ri_next = 0 104 | wrapped = True 105 | # wrap back around 106 | self.iterator = ri_next 107 | 108 | # img = skimage.io.imread(self.files[ri]) 109 | img = imread(self.files[ri]) 110 | img = imresize(img, (256,256)) 111 | 112 | if len(img.shape) == 2: 113 | img = img[:,:,np.newaxis] 114 | img = np.concatenate((img, img, img), axis=2) 115 | 116 | # img = img.astype('float32')/255.0 117 | # img = torch.from_numpy(img.transpose([2,0,1])).cuda() 118 | # img = Variable(preprocess(img), volatile=True) 119 | # tmp_fc, tmp_att = self.my_resnet(img) 120 | # 121 | # fc_batch[i] = tmp_fc.data.cpu().float().numpy() 122 | # att_batch[i] = tmp_att.data.cpu().float().numpy() 123 | img_batch[i] = preprocess(torch.from_numpy(img.transpose(2,0,1).astype('float32')/255.0)).numpy() 124 | 125 | info_struct = {} 126 | info_struct['id'] = self.ids[ri] 127 | info_struct['file_path'] = self.files[ri] 128 | infos.append(info_struct) 129 | 130 | data = {} 131 | # data['fc_feats'] = fc_batch 132 | # data['att_feats'] = att_batch 133 | data['images'] = img_batch 134 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped} 135 | data['infos'] = infos 136 | 137 | return data 138 | 139 | def reset_iterator(self, split): 140 | self.iterator = 0 141 | 142 | def get_vocab_size(self): 143 | return len(self.ix_to_word) 144 | 145 | def get_vocab(self): 146 | return self.ix_to_word 147 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 22 | # Input arguments and options 23 | parser = argparse.ArgumentParser() 24 | # Input paths 25 | #parser.add_argument('--model', type=str, default='', 26 | parser.add_argument('--model_path', type=str, default='', 27 | help='path to model to evaluate') 28 | #parser.add_argument('--cnn_model', type=str, default='resnet101', 29 | # help='resnet101, resnet152') 30 | parser.add_argument('--cnn_model_path', type=str, default='', 31 | help='path to cnn model to evaluate') 32 | parser.add_argument('--infos_path', type=str, default='', 33 | help='path to infos to evaluate') 34 | # Basic options 35 | parser.add_argument('--batch_size', type=int, default=0, 36 | help='if > 0 then overrule, otherwise load from checkpoint.') 37 | parser.add_argument('--num_images', type=int, default=-1, 38 | help='how many images to use when periodically evaluating the loss? (-1 = all)') 39 | parser.add_argument('--language_eval', type=int, default=0, 40 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 41 | parser.add_argument('--dump_images', type=int, default=1, 42 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') 43 | parser.add_argument('--dump_json', type=int, default=1, 44 | help='Dump json with predictions into vis folder? (1=yes,0=no)') 45 | parser.add_argument('--dump_path', type=int, default=0, 46 | help='Write image paths along with predictions into vis json? 
(1=yes,0=no)') 47 | 48 | # Sampling options 49 | parser.add_argument('--sample_max', type=int, default=1, 50 | help='1 = sample argmax words. 0 = sample from distributions.') 51 | parser.add_argument('--beam_size', type=int, default=2, 52 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 53 | parser.add_argument('--temperature', type=float, default=1.0, 54 | help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') 55 | # For evaluation on a folder of images: 56 | parser.add_argument('--image_folder', type=str, default='', 57 | help='If this is nonempty then will predict on the images in this folder path') 58 | parser.add_argument('--image_root', type=str, default='', 59 | help='In case the image paths have to be preprended with a root path to an image folder') 60 | # For evaluation on MSCOCO images from some split: 61 | parser.add_argument('--input_fc_dir', type=str, default='', 62 | help='path to the h5file containing the preprocessed dataset') 63 | parser.add_argument('--input_att_dir', type=str, default='', 64 | help='path to the h5file containing the preprocessed dataset') 65 | parser.add_argument('--input_label_h5', type=str, default='', 66 | help='path to the h5file containing the preprocessed label') 67 | parser.add_argument('--input_image_h5', type=str, default='', 68 | help='path to the h5file containing the preprocessed image') 69 | 70 | parser.add_argument('--input_json', type=str, default='', 71 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') 72 | parser.add_argument('--split', type=str, default='test', 73 | help='if running on MSCOCO images, which split to use: val|test|train') 74 | parser.add_argument('--coco_json', type=str, default='', 75 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') 76 | # misc 77 | parser.add_argument('--id', type=str, default='', 78 | help='an id identifying this run/job. 
used only if language_eval = 1 for appending to intermediate files') 79 | 80 | opt = parser.parse_args() 81 | opt.dump_images = 0 82 | opt.num_images = -1 83 | opt.cnn_model_path = 'save/multitask_pretrain_rl/best/model-cnn-best.pth' 84 | opt.model_path = 'save/multitask_pretrain_rl/best/model-best.pth' 85 | opt.infos_path = 'save/multitask_pretrain_rl/infos_topdown-best.pkl' 86 | opt.start_from = 'save/multitask_pretrain_rl' 87 | opt.split = 'test' 88 | opt.language_eval = 1 89 | #opt.batch_size=10 90 | opt.beam_size = 20 91 | opt.temperature = 1 92 | opt.sample_max = 1 93 | opt.verbose = True 94 | 95 | #opt.image_folder='/nlp/dataset/MSCOCO/test2014' 96 | #opt.num_images=-1 97 | #opt.coco_json='image_info_test2014.json' 98 | # Load infos 99 | with open(opt.infos_path) as f: 100 | infos = cPickle.load(f) 101 | 102 | # override and collect parameters 103 | if len(opt.input_fc_dir) == 0: 104 | opt.input_fc_dir = infos['opt'].input_fc_dir 105 | opt.input_att_dir = infos['opt'].input_att_dir 106 | if len(opt.input_label_h5) == 0: 107 | opt.input_label_h5 = infos['opt'].input_label_h5 108 | if len(opt.input_image_h5) == 0: 109 | opt.input_image_h5 = infos['opt'].input_image_h5 110 | 111 | if len(opt.input_json) == 0: 112 | opt.input_json = infos['opt'].input_json 113 | if opt.batch_size == 0: 114 | opt.batch_size = infos['opt'].batch_size 115 | if len(opt.id) == 0: 116 | opt.id = infos['opt'].id 117 | ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval"] 118 | for k in vars(infos['opt']).keys(): 119 | if k not in ignore: 120 | if k in vars(opt): 121 | assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent' 122 | else: 123 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 124 | 125 | vocab = infos['vocab'] # ix -> word mapping 126 | 127 | # Setup the model 128 | cnn_model = utils.build_cnn(opt) 129 | cnn_model.load_state_dict(torch.load(opt.cnn_model_path)) 130 | cnn_model.cuda() 131 | cnn_model.eval() 132 | 133 | model = models.setup(opt) 134 | model.load_state_dict(torch.load(opt.model_path)) 135 | model.cuda() 136 | model.eval() 137 | crit = utils.LanguageModelCriterion() 138 | 139 | # Create the Data Loader instance 140 | if len(opt.image_folder) == 0: 141 | loader = DataLoader(opt) 142 | else: 143 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 144 | 'coco_json': opt.coco_json, 145 | 'batch_size': opt.batch_size}) 146 | # 'cnn_model': opt.cnn_model}) 147 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 148 | # So make sure to use the vocab in infos file. 
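# (The assignment below swaps in the checkpoint's own index-to-word table;
#  decode_sequence in misc/utils.py looks words up via ix_to_word.get(str(ix), ...),
#  so the keys are string indices, e.g. {'1': 'a', '2': 'person', ...} -- that mapping
#  is only an illustration, the real one comes from the infos pickle.)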
149 | loader.ix_to_word = infos['vocab'] 150 | 151 | 152 | # Set sample options 153 | #loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 154 | # vars(opt)) 155 | loss, split_predictions, lang_stats = eval_utils.eval_split(cnn_model, model, crit, loader, vars(opt), True) 156 | 157 | print('loss: ', loss) 158 | if lang_stats: 159 | print(lang_stats) 160 | 161 | #if opt.dump_json == 1: 162 | # # dump the json 163 | # json.dump(split_predictions, open('vis/vis.json', 'w')) 164 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | import json 11 | from json import encoder 12 | import random 13 | import os 14 | import sys 15 | import misc.utils as utils 16 | import torch.nn.functional as F 17 | 18 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | def language_eval(dataset, preds, model_id, split): 21 | sys.path.append("coco-caption") 22 | annFile = 'coco-caption/annotations/captions_val2014.json' 23 | from pycocotools.coco import COCO 24 | from pycocoevalcap.eval import COCOEvalCap 25 | 26 | encoder.FLOAT_REPR = lambda o: format(o, '.3f') 27 | 28 | if not os.path.isdir('eval_results'): 29 | os.mkdir('eval_results') 30 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '.json') 31 | 32 | 33 | coco = COCO(annFile) 34 | valids = coco.getImgIds() 35 | 36 | # filter results to only those in MSCOCO validation set (will be about a third) 37 | preds_filt = [p for p in preds if p['image_id'] in valids] 38 | print('using %d/%d predictions' % (len(preds_filt), len(preds))) 39 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 
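# At this point cache_path holds a plain list of {'image_id': ..., 'caption': ...}
# records (the format produced in eval_split below), which is what COCO.loadRes
# expects; COCOEvalCap then scores those predictions with the coco-caption
# metrics (BLEU, METEOR, ROUGE_L, CIDEr).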
40 | 41 | cocoRes = coco.loadRes(cache_path) 42 | cocoEval = COCOEvalCap(coco, cocoRes) 43 | cocoEval.params['image_id'] = cocoRes.getImgIds() 44 | cocoEval.evaluate() 45 | 46 | # create output dictionary 47 | out = {} 48 | for metric, score in cocoEval.eval.items(): 49 | out[metric] = score 50 | 51 | imgToEval = cocoEval.imgToEval 52 | for p in preds_filt: 53 | image_id, caption = p['image_id'], p['caption'] 54 | imgToEval[image_id]['caption'] = caption 55 | with open(cache_path, 'w') as outfile: 56 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile) 57 | 58 | return out 59 | 60 | def eval_split(cnn_model, model, crit, loader, eval_kwargs={}, new_features=False): 61 | verbose = eval_kwargs.get('verbose', False) 62 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 63 | split = eval_kwargs.get('split', 'val') 64 | lang_eval = eval_kwargs.get('language_eval', 0) 65 | dataset = eval_kwargs.get('dataset', 'coco') 66 | 67 | cnn_model.eval() 68 | model.eval() 69 | loader.reset_iterator(split) 70 | 71 | n = 0 72 | loss = 0 73 | loss_sum = 0 74 | loss_evals = 1e-8 75 | predictions = [] 76 | while True: 77 | data = loader.get_batch(split) 78 | n = n + loader.batch_size 79 | 80 | #evaluate loss if we have the labels 81 | loss = 0 82 | torch.cuda.synchronize() 83 | if new_features: 84 | # tmp = [data['images'], data.get('labels', np.zeros(1)), data.get('masks', np.zeros(1))] 85 | # tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 86 | # images, labels, masks = tmp 87 | # att_feats, _ = _att_feats, _ = cnn_model(images) 88 | # fc_feats = _fc_feats = att_feats.mean(3).mean(2).squeeze() 89 | # att_feats = _att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(0, 2, 3, 1) 90 | # att_feats = att_feats.unsqueeze(1).expand(*((att_feats.size(0), loader.seq_per_img,) + att_feats.size()[1:])).contiguous().view(*((att_feats.size(0) * loader.seq_per_img,) + att_feats.size()[1:])) 91 | # fc_feats = fc_feats.unsqueeze(1).expand(*((fc_feats.size(0), loader.seq_per_img,) + fc_feats.size()[1:])).contiguous().view(*((fc_feats.size(0) * loader.seq_per_img,) + fc_feats.size()[1:])) 92 | tmp = [data.get('labels', np.zeros(1)), data.get('masks', np.zeros(1))] 93 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 94 | labels, masks = tmp 95 | images = data['images'] 96 | _fc_feats = [] 97 | _att_feats = [] 98 | for i in range(loader.batch_size): 99 | x = Variable(torch.from_numpy(images[i]), volatile=True).cuda() 100 | x = x.unsqueeze(0) 101 | att_feats, _ = cnn_model(x) 102 | fc_feats = att_feats.mean(3).mean(2).squeeze() 103 | att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(1, 2, 0)#(0, 2, 3, 1) 104 | _fc_feats.append(fc_feats) 105 | _att_feats.append(att_feats) 106 | _fc_feats = torch.stack(_fc_feats) 107 | _att_feats = torch.stack(_att_feats) 108 | att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \ 109 | _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \ 110 | _att_feats.size()[1:])) 111 | fc_feats = _fc_feats.unsqueeze(1).expand(*((_fc_feats.size(0), loader.seq_per_img,) + \ 112 | _fc_feats.size()[1:])).contiguous().view(*((_fc_feats.size(0) * loader.seq_per_img,) + \ 113 | _fc_feats.size()[1:])) 114 | else: 115 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] 116 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 117 | fc_feats, att_feats, labels, masks = 
tmp 118 | 119 | # forward the model to get loss 120 | if data.get('labels', None) is not None: 121 | if eval_kwargs.get("ccg",False)==False: 122 | loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]).data[0] 123 | else: 124 | tmp = [data['ccg']] 125 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 126 | ccg = tmp 127 | # tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'],data['ccg']] 128 | # tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 129 | # fc_feats, att_feats, labels, masks,ccg = tmp 130 | word_labels, ccg_labels= model(fc_feats, att_feats, labels, ccg) 131 | loss = crit(word_labels, labels[:,1:], masks[:,1:]).data[0] 132 | 133 | loss_sum = loss_sum + loss 134 | loss_evals = loss_evals + 1 135 | 136 | # forward the model to also get generated samples for each image 137 | # Only leave one feature for each image, in case duplicate sample 138 | if new_features: 139 | fc_feats, att_feats = _fc_feats, _att_feats 140 | else: 141 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 142 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]] 143 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 144 | fc_feats, att_feats = tmp 145 | 146 | # forward the model to also get generated samples for each image 147 | if eval_kwargs.get("ccg",False): 148 | seq, _,seq_ccg,___ = model.sample(fc_feats, att_feats, eval_kwargs)#model.module.sample(fc_feats, att_feats, eval_kwargs) 149 | else: 150 | seq, _ = model.sample(fc_feats.contiguous(), att_feats.contiguous(), eval_kwargs)#model.module.sample(fc_feats, att_feats, eval_kwargs) 151 | torch.cuda.synchronize() 152 | 153 | sents = utils.decode_sequence(loader.get_vocab(), seq) 154 | if eval_kwargs.get("ccg",False): 155 | sents_ccg = utils.decode_sequence(loader.get_vocab_ccg(),seq_ccg) 156 | for k, sent in enumerate(sents): 157 | if eval_kwargs.get("ccg",False): 158 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent,"caption_ccg":sents_ccg[k]} 159 | else: 160 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent} 161 | predictions.append(entry) 162 | 163 | if verbose and random.random()<0.0001 : 164 | print('image %s: %s' %(entry['image_id'], entry['caption'])) 165 | if eval_kwargs.get("ccg",False): 166 | print('image %s: %s' %(entry['image_id'], entry['caption_ccg'])) 167 | 168 | # if we wrapped around the split or used up val imgs budget then bail 169 | ix0 = data['bounds']['it_pos_now'] 170 | ix1 = data['bounds']['it_max'] 171 | if num_images != -1: 172 | ix1 = min(ix1, num_images) 173 | for i in range(n - ix1): 174 | predictions.pop() 175 | 176 | if verbose and ix0 % 2500 == 0: 177 | print('evaluating validation preformance... 
%d/%d (%f)' %(ix0 - 1, ix1, loss)) 178 | 179 | if data['bounds']['wrapped']: 180 | break 181 | if num_images >= 0 and n >= num_images: 182 | break 183 | 184 | lang_stats = None 185 | if lang_eval == 1: 186 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split) 187 | 188 | # Switch back to training mode 189 | model.train() 190 | return loss_sum/loss_evals, predictions, lang_stats 191 | 192 | def main(): 193 | import opts 194 | import misc.utils as utils 195 | opt = opts.parse_opt() 196 | opt.caption_model ='topdown' 197 | opt.batch_size=10#512#32*4*4 198 | opt.id ='topdown' 199 | opt.learning_rate= 5e-4 200 | opt.learning_rate_decay_start= 0 201 | opt.scheduled_sampling_start=0 202 | opt.save_checkpoint_every=5000#450#5000#11500 203 | opt.val_images_use=5000 204 | opt.max_epochs=50#30 205 | opt.start_from='save/rt'#"save" #None 206 | opt.language_eval = 1 207 | opt.input_json='data/meta_coco_en.json' 208 | opt.input_label_h5='data/label_coco_en.h5' 209 | # opt.input_json='data/coco_ccg.json' #'data/meta_coco_en.json' 210 | # opt.input_label_h5='data/coco_ccg_label.h5' #'data/label_coco_en.h5' 211 | # opt.input_fc_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_fc' 212 | # opt.input_att_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_att' 213 | opt.finetune_cnn_after = 0 214 | opt.ccg = False 215 | opt.input_image_h5 = 'data/coco_image_512.h5' 216 | 217 | opt.use_att = utils.if_use_att(opt.caption_model) 218 | 219 | from dataloader import DataLoader # just-in-time generated features 220 | loader = DataLoader(opt) 221 | 222 | # from dataloader_fixcnn import DataLoader # load pre-processed features 223 | # loader = DataLoader(opt) 224 | 225 | opt.vocab_size = loader.vocab_size 226 | opt.vocab_ccg_size = loader.vocab_ccg_size 227 | opt.seq_length = loader.seq_length 228 | 229 | import models 230 | model = models.setup(opt) 231 | cnn_model = utils.build_cnn(opt) 232 | cnn_model.cuda() 233 | model.cuda() 234 | 235 | data = loader.get_batch('train') 236 | images = data['images'] 237 | 238 | # _fc_feats_2048 = [] 239 | # _fc_feats_81 = [] 240 | # _att_feats = [] 241 | # for i in range(loader.batch_size): 242 | # x = Variable(torch.from_numpy(images[i]), volatile=True).cuda() 243 | # x = x.unsqueeze(0) 244 | # att_feats, fc_feats_81 = cnn_model(x) 245 | # fc_feats_2048 = att_feats.mean(3).mean(2).squeeze() 246 | # att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(1, 2, 0)#(0, 2, 3, 1) 247 | # _fc_feats_2048.append(fc_feats_2048) 248 | # _fc_feats_81.append(fc_feats_81) 249 | # _att_feats.append(att_feats) 250 | # _fc_feats_2048 = torch.stack(_fc_feats_2048) 251 | # _fc_feats_81 = torch.stack(_fc_feats_81) 252 | # _att_feats = torch.stack(_att_feats) 253 | # att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \ 254 | # _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \ 255 | # _att_feats.size()[1:])) 256 | # fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \ 257 | # _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \ 258 | # _fc_feats_2048.size()[1:])) 259 | # fc_feats_81 = _fc_feats_81 260 | # 261 | # att_feats = Variable(att_feats, requires_grad=False).cuda() 262 | # Variable(fc_feats_81) 263 | 264 | crit = utils.LanguageModelCriterion() 265 | eval_kwargs = {'split': 'val','dataset': opt.input_json,'verbose':True} 266 | eval_kwargs.update(vars(opt)) 
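# eval_split (defined above) returns the mean language-model loss over the split,
# the list of {'image_id', 'caption'} predictions, and, since opt.language_eval = 1
# here, the dict of COCO metrics; lang_stats stays None when language_eval is 0.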
267 | val_loss, predictions, lang_stats = eval_split(cnn_model, model, crit, loader, eval_kwargs, True) 268 | 269 | # from models.AttModel import TopDownModel 270 | # model = TopDownModel(opt) 271 | # 272 | # import torch.optim as optim 273 | # optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate) 274 | # cnn_optimizer = optim.Adam([\ 275 | # {'params': module.parameters()} for module in cnn_model._modules.values()[5:]\ 276 | # ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay) 277 | # 278 | # cnn_optimizer.state_dict().keys() 279 | # import misc.resnet as resnet 280 | # net = getattr(resnet, opt.cnn_model)() 281 | ## net.load_state_dict(torch.load('save/'+opt.cnn_weight)) 282 | # net.load_state_dict(torch.load('save/rt/model-cnn.pth')) 283 | ## cnn_model = net 284 | ## net.state_dict().keys() 285 | # net = nn.Sequential(\ 286 | # net.conv1, 287 | # net.bn1, 288 | # net.relu, 289 | # net.maxpool, 290 | # net.layer1, 291 | # net.layer2, 292 | # net.layer3, 293 | # net.layer4) 294 | # 295 | # net.load_state_dict(torch.load('save/'+opt.cnn_weight)) 296 | 297 | #main() 298 | -------------------------------------------------------------------------------- /misc/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/misc/.DS_Store -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/misc/__init__.py -------------------------------------------------------------------------------- /misc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class 
Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, layers, num_classes=81):#1000 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AvgPool2d(7) 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | # x = self.avgpool(x) 149 | # x = x.view(x.size(0), -1) 150 | # x = self.fc(x) 151 | fc = x.mean(3).mean(2).squeeze() 152 | fc = self.fc(fc) 153 | return x,fc 154 | 155 | 156 | def resnet18(pretrained=False): 157 | """Constructs a ResNet-18 model. 
158 | 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 165 | return model 166 | 167 | 168 | def resnet34(pretrained=False): 169 | """Constructs a ResNet-34 model. 170 | 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | """ 174 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 175 | if pretrained: 176 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 177 | return model 178 | 179 | 180 | def resnet50(pretrained=False): 181 | """Constructs a ResNet-50 model. 182 | 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 187 | if pretrained: 188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 189 | return model 190 | 191 | 192 | def resnet101(pretrained=False): 193 | """Constructs a ResNet-101 model. 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 201 | return model 202 | 203 | 204 | def resnet152(pretrained=False): 205 | """Constructs a ResNet-152 model. 206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 211 | if pretrained: 212 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 213 | return model 214 | -------------------------------------------------------------------------------- /misc/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | class myResnet(nn.Module): 7 | def __init__(self, resnet): 8 | super(myResnet, self).__init__() 9 | self.resnet = resnet 10 | 11 | def forward(self, img, att_size=14): 12 | x = img.unsqueeze(0) 13 | 14 | x = self.resnet.conv1(x) 15 | x = self.resnet.bn1(x) 16 | x = self.resnet.relu(x) 17 | x = self.resnet.maxpool(x) 18 | 19 | x = self.resnet.layer1(x) 20 | x = self.resnet.layer2(x) 21 | x = self.resnet.layer3(x) 22 | x = self.resnet.layer4(x) 23 | 24 | fc = x.mean(3).mean(2).squeeze() 25 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 26 | 27 | return fc, att 28 | 29 | -------------------------------------------------------------------------------- /misc/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | import misc.utils as utils 8 | from collections import OrderedDict 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | import sys 13 | sys.path.append("cider") 14 | from pyciderevalcap.ciderD.ciderD import CiderD 15 | 16 | 17 | CiderD_scorer = CiderD(df='coco_en-train-idxs') 18 | 19 | 20 | 21 | 22 | #CiderD_scorer = CiderD(df='corpus') 23 | 24 | def array_to_str(arr): 25 | out = '' 26 | for i in range(len(arr)): 27 | out += str(arr[i]) + ' ' 28 | if arr[i] == 0: 29 | break 30 | return out.strip() 31 | 32 | def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result): 33 | 
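# Summary of the reward computation below: the sampled captions and the
# greedy-decoded captions are both scored with CIDEr-D, the greedy score is
# subtracted as a self-critical baseline, and the resulting per-caption
# difference is repeated across every time step of the sampled sequence.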
batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 34 | seq_per_img = batch_size // len(data['gts']) 35 | 36 | # get greedy decoding baseline 37 | greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True), Variable(att_feats.data, volatile=True)) 38 | 39 | res = OrderedDict() 40 | 41 | gen_result = gen_result.cpu().numpy() 42 | greedy_res = greedy_res.cpu().numpy() 43 | for i in range(batch_size): 44 | res[i] = [array_to_str(gen_result[i])] 45 | for i in range(batch_size): 46 | res[batch_size + i] = [array_to_str(greedy_res[i])] 47 | 48 | gts = OrderedDict() 49 | for i in range(len(data['gts'])): 50 | gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))] 51 | 52 | #_, scores = Bleu(4).compute_score(gts, res) 53 | #scores = np.array(scores[3]) 54 | res = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] 55 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} 56 | _, scores = CiderD_scorer.compute_score(gts, res) 57 | # print('Cider scores:', _) 58 | 59 | scores = scores[:batch_size] - scores[batch_size:] 60 | 61 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 62 | 63 | return rewards 64 | -------------------------------------------------------------------------------- /misc/symlink.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Oct 4 15:56:13 2017 5 | 6 | @author: nlp 7 | """ 8 | import os 9 | i = 0 10 | spath = '/data1/zsfx/wabywang/caption/010/data/aitalk_fc/' 11 | tpath = '/data1/zsfx/wabywang/caption/010/data/combinedtalk_fc/' 12 | for item in os.listdir(spath): 13 | i = i+1 14 | os.symlink(spath + item, 15 | tpath + item) 16 | print(i) 17 | 18 | 19 | #os.path.join('data/cocotalk_fc', str(self.info['images'][ix]['id']) + '.npy'), 20 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | import misc.resnet as resnet 12 | import os 13 | import random 14 | 15 | def set_bn_fix(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('BatchNorm') != -1: 18 | for p in m.parameters(): 19 | p.requires_grad=False 20 | 21 | def set_bn_eval(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('BatchNorm') != -1: 24 | m.eval() 25 | 26 | def build_cnn(opt): 27 | net = getattr(resnet, opt.cnn_model)() 28 | if vars(opt).get('start_from', None) is None and vars(opt).get('cnn_weight', '') != '': 29 | print('loading pretrained ' + opt.cnn_model + ' from ' + opt.cnn_weight) 30 | # net.load_state_dict(torch.load(opt.cnn_weight)) 31 | pretrained_m = torch.load(opt.cnn_weight) 32 | 33 | if vars(opt).get('start_from', None) is not None: 34 | print(vars(opt).get('start_from')+'/model-cnn.pth') 35 | # net.load_state_dict(torch.load(os.path.join(opt.start_from, 'model-cnn.pth'))) 36 | pretrained_m = torch.load(os.path.join(opt.start_from, 'model-cnn.pth')) 37 | 38 | net_dict = net.state_dict() 39 | pretrained_m = { k:v for k,v in pretrained_m.iteritems() 40 | if k in net_dict and v.size() == net_dict[k].size() } 41 | net_dict.update(pretrained_m) 42 | net.load_state_dict(net_dict) 43 |
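# Only entries whose names and shapes match the custom ResNet are copied above, so
# the 81-way detection head (misc/resnet.py builds the fc layer with num_classes=81
# rather than the ImageNet 1000) keeps its fresh initialisation when ImageNet
# weights are loaded.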
44 | # compact_net = nn.Sequential(\ 45 | # net.conv1, 46 | # net.bn1, 47 | # net.relu, 48 | # net.maxpool, 49 | # net.layer1, 50 | # net.layer2, 51 | # net.layer3, 52 | # net.layer4) 53 | # 54 | # return net,compact_net 55 | return net 56 | 57 | def prepro_images(imgs, data_augment=False): 58 | # crop the image 59 | h,w = imgs.shape[2], imgs.shape[3] 60 | cnn_input_size = 224 61 | 62 | # cropping data augmentation, if needed 63 | if h > cnn_input_size or w > cnn_input_size: 64 | if data_augment: 65 | xoff, yoff = random.randint(0, w-cnn_input_size), random.randint(0, h-cnn_input_size) 66 | else: 67 | # sample the center 68 | xoff, yoff = (w-cnn_input_size)//2, (h-cnn_input_size)//2 69 | # crop. 70 | imgs = imgs[:,:, yoff:yoff+cnn_input_size, xoff:xoff+cnn_input_size] 71 | 72 | return imgs 73 | 74 | def if_use_att(caption_model): 75 | # Decide if load attention feature according to caption model 76 | if caption_model in ['show_tell', 'all_img', 'fc']: 77 | return False 78 | return True 79 | 80 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 81 | def decode_sequence(ix_to_word, seq): 82 | 83 | N, D = seq.size() 84 | out = [] 85 | 86 | 87 | for i in range(N): 88 | txt = '' 89 | for j in range(D): 90 | ix = seq[i,j] 91 | if ix > 0 : 92 | if j >= 1: 93 | txt = txt + ' ' 94 | 95 | txt = txt + ix_to_word.get(str(ix),'unknown_token') 96 | else: 97 | break 98 | out.append(txt) 99 | 100 | return out 101 | 102 | def to_contiguous(tensor): 103 | if tensor.is_contiguous(): 104 | return tensor 105 | else: 106 | return tensor.contiguous() 107 | 108 | class RewardCriterion(nn.Module): 109 | def __init__(self): 110 | super(RewardCriterion, self).__init__() 111 | 112 | def forward(self, input, seq, reward): 113 | input = to_contiguous(input).view(-1) 114 | reward = to_contiguous(reward).view(-1) 115 | mask = (seq>0).float() 116 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 117 | output = - input * reward * Variable(mask) 118 | output = torch.sum(output) / torch.sum(mask) 119 | 120 | return output 121 | class LanguageModelCriterion(nn.Module): 122 | def __init__(self): 123 | super(LanguageModelCriterion, self).__init__() 124 | 125 | def forward(self, input, target, mask): 126 | # truncate to the same size 127 | target = target[:, :input.size(1)] 128 | mask = mask[:, :input.size(1)] 129 | input = to_contiguous(input).view(-1, input.size(2)) 130 | target = to_contiguous(target).view(-1, 1) 131 | mask = to_contiguous(mask).view(-1, 1) 132 | output = - input.gather(1, target) * mask 133 | output = torch.sum(output) / torch.sum(mask) 134 | 135 | return output 136 | class LanguageModel_CCG_Criterion(nn.Module): 137 | def __init__(self): 138 | super(LanguageModel_CCG_Criterion, self).__init__() 139 | 140 | def forward(self, word_labels,ccg_labels, word_target,ccg_target, mask): 141 | # truncate to the same size 142 | word_target = word_target[:, :word_labels.size(1)] 143 | mask = mask[:, :word_labels.size(1)] 144 | word_labels = to_contiguous(word_labels).view(-1, word_labels.size(2)) 145 | word_target = to_contiguous(word_target).view(-1, 1) 146 | mask = to_contiguous(mask).view(-1, 1) 147 | output = - word_labels.gather(1, word_target) * mask 148 | output_word = torch.sum(output) / torch.sum(mask) 149 | 150 | 151 | ccg_target = ccg_target[:, :ccg_labels.size(1)] 152 | mask = mask[:, :ccg_labels.size(1)] 153 | ccg_labels = to_contiguous(ccg_labels).view(-1, ccg_labels.size(2)) 154 | ccg_target = to_contiguous(ccg_target).view(-1, 
1) 155 | mask = to_contiguous(mask).view(-1, 1) 156 | output = - ccg_labels.gather(1, ccg_target) * mask 157 | output_ccg = torch.sum(output) / torch.sum(mask) 158 | 159 | return output_word,output_ccg 160 | def set_lr(optimizer, lr): 161 | for group in optimizer.param_groups: 162 | group['lr'] = lr 163 | 164 | def clip_gradient(optimizer, grad_clip): 165 | for group in optimizer.param_groups: 166 | for param in group['params']: 167 | if param.grad is not None and param.requires_grad: 168 | param.grad.data.clamp_(-grad_clip, grad_clip) 169 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/models/.DS_Store -------------------------------------------------------------------------------- /models/Att2inModel.py: -------------------------------------------------------------------------------- 1 | # This file contains att2in model 2 | # Att2in is from Self-critical Sequence Training for Image Captioning 3 | # https://arxiv.org/abs/1612.00563 4 | # In this file we only have Att2in2, which is a slightly different version of att2in, 5 | # in which the img feature embedding and word embedding is the same as what in adaatt. 6 | 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | class Att2inCore(nn.Module): 19 | def __init__(self, opt): 20 | super(Att2inCore, self).__init__() 21 | self.input_encoding_size = opt.input_encoding_size 22 | #self.rnn_type = opt.rnn_type 23 | self.rnn_size = opt.rnn_size 24 | #self.num_layers = opt.num_layers 25 | self.drop_prob_lm = opt.drop_prob_lm 26 | self.fc_feat_size = opt.fc_feat_size 27 | self.att_feat_size = opt.att_feat_size 28 | self.att_hid_size = opt.att_hid_size 29 | 30 | # Build a LSTM 31 | self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size) 32 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 33 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 34 | self.dropout = nn.Dropout(self.drop_prob_lm) 35 | 36 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 37 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 38 | 39 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 40 | # The p_att_feats here is already projected 41 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size 42 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 43 | 44 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size 45 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 46 | dot = att + att_h # batch * att_size * att_hid_size 47 | dot = F.tanh(dot) # batch * att_size * att_hid_size 48 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 49 | dot = self.alpha_net(dot) # (batch * att_size) * 1 50 | dot = dot.view(-1, att_size) # batch * att_size 51 | 52 | weight = F.softmax(dot) # batch * att_size 53 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size 54 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 55 | 56 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 57 | sigmoid_chunk = 
all_input_sums.narrow(1, 0, 3 * self.rnn_size) 58 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 59 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 60 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 61 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 62 | 63 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ 64 | self.a2c(att_res) 65 | in_transform = torch.max(\ 66 | in_transform.narrow(1, 0, self.rnn_size), 67 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 68 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 69 | next_h = out_gate * F.tanh(next_c) 70 | 71 | output = self.dropout(next_h) 72 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 73 | return output, state 74 | 75 | class Att2inModel(nn.Module): 76 | def __init__(self, opt): 77 | super(Att2inModel, self).__init__() 78 | self.vocab_size = opt.vocab_size 79 | self.input_encoding_size = opt.input_encoding_size 80 | #self.rnn_type = opt.rnn_type 81 | self.rnn_size = opt.rnn_size 82 | self.num_layers = 1 83 | self.drop_prob_lm = opt.drop_prob_lm 84 | self.seq_length = opt.seq_length 85 | self.fc_feat_size = opt.fc_feat_size 86 | self.att_feat_size = opt.att_feat_size 87 | self.att_hid_size = opt.att_hid_size 88 | 89 | self.ss_prob = 0.0 # Schedule sampling probability 90 | 91 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 92 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 93 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 94 | self.core = Att2inCore(opt) 95 | 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | initrange = 0.1 100 | self.embed.weight.data.uniform_(-initrange, initrange) 101 | self.logit.bias.data.fill_(0) 102 | self.logit.weight.data.uniform_(-initrange, initrange) 103 | 104 | def init_hidden(self, bsz): 105 | weight = next(self.parameters()).data 106 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 107 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 108 | 109 | def forward(self, fc_feats, att_feats, seq): 110 | batch_size = fc_feats.size(0) 111 | state = self.init_hidden(batch_size) 112 | 113 | outputs = [] 114 | 115 | # Project the attention feats first to reduce memory and computation comsumptions. 
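# ctx2att projects every spatial feature from att_feat_size to att_hid_size once, outside the
# time loop; Att2inCore then only adds the projected hidden state h2att(h) to this cached
# p_att_feats at each step instead of re-projecting the whole feature map.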
116 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 117 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 118 | 119 | for i in range(seq.size(1) - 1): 120 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 121 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 122 | sample_mask = sample_prob < self.ss_prob 123 | if sample_mask.sum() == 0: 124 | it = seq[:, i].clone() 125 | else: 126 | sample_ind = sample_mask.nonzero().view(-1) 127 | it = seq[:, i].data.clone() 128 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 129 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 130 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 131 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 132 | it = Variable(it, requires_grad=False) 133 | else: 134 | it = seq[:, i].clone() 135 | # break if all the sequences end 136 | if i >= 1 and seq[:, i].data.sum() == 0: 137 | break 138 | 139 | xt = self.embed(it) 140 | 141 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 142 | output = F.log_softmax(self.logit(output)) 143 | outputs.append(output) 144 | 145 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 146 | 147 | def sample_beam(self, fc_feats, att_feats, opt={}): 148 | beam_size = opt.get('beam_size', 10) 149 | batch_size = fc_feats.size(0) 150 | 151 | # Project the attention feats first to reduce memory and computation comsumptions. 152 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 153 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 154 | 155 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 156 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 157 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 158 | # lets process every image independently for now, for simplicity 159 | 160 | self.done_beams = [[] for _ in range(batch_size)] 161 | for k in range(batch_size): 162 | state = self.init_hidden(beam_size) 163 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size) 164 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 165 | tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 166 | 167 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 168 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 169 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 170 | done_beams = [] 171 | for t in range(self.seq_length + 1): 172 | if t == 0: # input 173 | it = fc_feats.data.new(beam_size).long().zero_() 174 | xt = self.embed(Variable(it, requires_grad=False)) 175 | else: 176 | """pem a beam merge. 
that is, 177 | for every previous beam we now many new possibilities to branch out 178 | we need to resort our beams to maintain the loop invariant of keeping 179 | the top beam_size most likely sequences.""" 180 | logprobsf = logprobs.float() # lets go to CPU for more efficiency in indexing operations 181 | ys,ix = torch.sort(logprobsf,1,True) # sorted array of logprobs along each previous beam (last true = descending) 182 | candidates = [] 183 | cols = min(beam_size, ys.size(1)) 184 | rows = beam_size 185 | if t == 1: # at first time step only the first beam is active 186 | rows = 1 187 | for c in range(cols): 188 | for q in range(rows): 189 | # compute logprob of expanding beam q with word in (sorted) position c 190 | local_logprob = ys[q,c] 191 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 192 | candidates.append({'c':ix.data[q,c], 'q':q, 'p':candidate_logprob.data[0], 'r':local_logprob.data[0]}) 193 | candidates = sorted(candidates, key=lambda x: -x['p']) 194 | 195 | # construct new beams 196 | new_state = [_.clone() for _ in state] 197 | if t > 1: 198 | # well need these as reference when we fork beams around 199 | beam_seq_prev = beam_seq[:t-1].clone() 200 | beam_seq_logprobs_prev = beam_seq_logprobs[:t-1].clone() 201 | for vix in range(beam_size): 202 | v = candidates[vix] 203 | # fork beam index q into index vix 204 | if t > 1: 205 | beam_seq[:t-1, vix] = beam_seq_prev[:, v['q']] 206 | beam_seq_logprobs[:t-1, vix] = beam_seq_logprobs_prev[:, v['q']] 207 | 208 | # rearrange recurrent states 209 | for state_ix in range(len(new_state)): 210 | # copy over state in previous beam q to new beam at vix 211 | new_state[state_ix][0, vix] = state[state_ix][0, v['q']] # dimension one is time step 212 | 213 | # append new end terminal at the end of this beam 214 | beam_seq[t-1, vix] = v['c'] # c'th word is the continuation 215 | beam_seq_logprobs[t-1, vix] = v['r'] # the raw logprob here 216 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 217 | 218 | if v['c'] == 0 or t == self.seq_length: 219 | # END token special case here, or we reached the end. 220 | # add the beam to a set of done beams 221 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 222 | 'logps': beam_seq_logprobs[:, vix].clone(), 223 | 'p': beam_logprobs_sum[vix] 224 | }) 225 | 226 | # encode as vectors 227 | it = beam_seq[t-1] 228 | xt = self.embed(Variable(it.cuda())) 229 | 230 | if t >= 1: 231 | state = new_state 232 | 233 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 234 | logprobs = F.log_softmax(self.logit(output)) 235 | 236 | self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 237 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 238 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 239 | # return the samples and their log likelihoods 240 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 241 | 242 | def sample(self, fc_feats, att_feats, opt={}): 243 | sample_max = opt.get('sample_max', 1) 244 | beam_size = opt.get('beam_size', 1) 245 | temperature = opt.get('temperature', 1.0) 246 | if beam_size > 1: 247 | return self.sample_beam(fc_feats, att_feats, opt) 248 | 249 | batch_size = fc_feats.size(0) 250 | state = self.init_hidden(batch_size) 251 | 252 | # Project the attention feats first to reduce memory and computation comsumptions. 
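# Decoding loop below: when sample_max is set the next word is the argmax of the current
# log-probabilities; otherwise it is drawn with torch.multinomial from the (optionally
# temperature-scaled) distribution. With beam_size > 1 the call was already dispatched to
# sample_beam above.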
253 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 254 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 255 | 256 | seq = [] 257 | seqLogprobs = [] 258 | for t in range(self.seq_length + 1): 259 | if t == 0: # input 260 | it = fc_feats.data.new(batch_size).long().zero_() 261 | elif sample_max: 262 | sampleLogprobs, it = torch.max(logprobs.data, 1) 263 | it = it.view(-1).long() 264 | else: 265 | if temperature == 1.0: 266 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 267 | else: 268 | # scale logprobs by temperature 269 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 270 | it = torch.multinomial(prob_prev, 1).cuda() 271 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 272 | it = it.view(-1).long() # and flatten indices for downstream processing 273 | 274 | xt = self.embed(Variable(it, requires_grad=False)) 275 | 276 | if t >= 1: 277 | # stop when all finished 278 | if t == 1: 279 | unfinished = it > 0 280 | else: 281 | unfinished = unfinished * (it > 0) 282 | if unfinished.sum() == 0: 283 | break 284 | it = it * unfinished.type_as(it) 285 | seq.append(it) #seq[t] the input of t+2 time step 286 | 287 | seqLogprobs.append(sampleLogprobs.view(-1)) 288 | 289 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 290 | logprobs = F.log_softmax(self.logit(output)) 291 | 292 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) -------------------------------------------------------------------------------- /models/AttModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | from .CaptionModel import CaptionModel 26 | 27 | class AttModel(CaptionModel): 28 | def __init__(self, opt): 29 | super(AttModel, self).__init__() 30 | self.vocab_size = opt.vocab_size 31 | self.input_encoding_size = opt.input_encoding_size 32 | #self.rnn_type = opt.rnn_type 33 | self.rnn_size = opt.rnn_size 34 | self.num_layers = opt.num_layers 35 | self.drop_prob_lm = opt.drop_prob_lm 36 | self.seq_length = opt.seq_length 37 | self.fc_feat_size = opt.fc_feat_size 38 | self.att_feat_size = opt.att_feat_size 39 | self.att_hid_size = opt.att_hid_size 40 | 41 | self.ss_prob = 0.0 # Schedule sampling probability 42 | 43 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 44 | nn.ReLU(), 45 | nn.Dropout(self.drop_prob_lm)) 46 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 47 | nn.ReLU(), 48 | nn.Dropout(self.drop_prob_lm)) 49 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 50 | nn.ReLU(), 51 | nn.Dropout(self.drop_prob_lm)) 52 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 53 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 54 | 55 | def init_hidden(self, bsz): 56 | weight = next(self.parameters()).data 57 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 58 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 59 | 60 | def forward(self, fc_feats, att_feats, seq): 61 | batch_size = fc_feats.size(0) 62 | state = self.init_hidden(batch_size) 63 | 64 | outputs = [] 65 | 66 | # embed fc and att feats 67 | fc_feats = self.fc_embed(fc_feats) 68 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 69 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 70 | 71 | # Project the attention feats first to reduce memory and computation comsumptions. 
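# Unlike Att2inModel, the raw CNN features were already embedded to rnn_size by att_embed,
# so ctx2att here maps rnn_size -> att_hid_size. The loop below also implements scheduled
# sampling: from step 1 on, each example is replaced with probability ss_prob by a token
# sampled from the model's own previous output distribution instead of the ground-truth word.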
72 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 73 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 74 | 75 | for i in range(seq.size(1) - 1): 76 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample 77 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 78 | sample_mask = sample_prob < self.ss_prob 79 | if sample_mask.sum() == 0: 80 | it = seq[:, i].clone() 81 | else: 82 | sample_ind = sample_mask.nonzero().view(-1) 83 | it = seq[:, i].data.clone() 84 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 85 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 86 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 87 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 88 | it = Variable(it, requires_grad=False) 89 | else: 90 | it = seq[:, i].clone() 91 | # break if all the sequences end 92 | if i >= 1 and seq[:, i].data.sum() == 0: 93 | break 94 | 95 | xt = self.embed(it) 96 | 97 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 98 | output = F.log_softmax(self.logit(output)) 99 | outputs.append(output) 100 | 101 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 102 | 103 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state): 104 | # 'it' is a Variable containing a word index 105 | xt = self.embed(it) 106 | 107 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 108 | logprobs = F.log_softmax(self.logit(output)) 109 | 110 | return logprobs, state 111 | 112 | def sample_beam(self, fc_feats, att_feats, opt={}): 113 | beam_size = opt.get('beam_size', 10) 114 | batch_size = fc_feats.size(0) 115 | 116 | # embed fc and att feats 117 | fc_feats = self.fc_embed(fc_feats) 118 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 119 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 120 | 121 | # Project the attention feats first to reduce memory and computation consumption. 122 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 123 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 124 | 125 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road.
can be dealt with in future if needed' 126 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 127 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 128 | # lets process every image independently for now, for simplicity 129 | 130 | self.done_beams = [[] for _ in range(batch_size)] 131 | for k in range(batch_size): 132 | state = self.init_hidden(beam_size) 133 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) 134 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 135 | tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 136 | 137 | for t in range(1): 138 | if t == 0: # input 139 | it = fc_feats.data.new(beam_size).long().zero_() 140 | xt = self.embed(Variable(it, requires_grad=False)) 141 | 142 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 143 | logprobs = F.log_softmax(self.logit(output)) 144 | 145 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, opt=opt) 146 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 147 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 148 | # return the samples and their log likelihoods 149 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 150 | 151 | def sample(self, fc_feats, att_feats, opt={}): 152 | sample_max = opt.get('sample_max', 1) 153 | beam_size = opt.get('beam_size', 1) 154 | temperature = opt.get('temperature', 1.0) 155 | if beam_size > 1: 156 | return self.sample_beam(fc_feats, att_feats, opt) 157 | 158 | batch_size = fc_feats.size(0) 159 | state = self.init_hidden(batch_size) 160 | 161 | # embed fc and att feats 162 | fc_feats = self.fc_embed(fc_feats) 163 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 164 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 165 | 166 | # Project the attention feats first to reduce memory and computation comsumptions. 
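# In the loop below, index 0 doubles as both the start token (fed at t == 0) and the END
# token: `unfinished` tracks which sequences have not produced a 0 yet, finished ones are
# zero-padded, and decoding stops early once every sequence is done.
# A rough call sketch (model stands for any AttModel subclass; the option values are only
# an illustration, not taken from this repository):
#   seq, seqLogprobs = model.sample(fc_feats, att_feats, {'sample_max': 0, 'temperature': 0.7})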
167 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 168 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 169 | 170 | seq = [] 171 | seqLogprobs = [] 172 | for t in range(self.seq_length + 1): 173 | if t == 0: # input 174 | it = fc_feats.data.new(batch_size).long().zero_() 175 | elif sample_max: 176 | sampleLogprobs, it = torch.max(logprobs.data, 1) 177 | it = it.view(-1).long() 178 | else: 179 | if temperature == 1.0: 180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 181 | else: 182 | # scale logprobs by temperature 183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 184 | it = torch.multinomial(prob_prev, 1).cuda() 185 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 186 | it = it.view(-1).long() # and flatten indices for downstream processing 187 | 188 | xt = self.embed(Variable(it, requires_grad=False)) 189 | 190 | if t >= 1: 191 | # stop when all finished 192 | if t == 1: 193 | unfinished = it > 0 194 | else: 195 | unfinished = unfinished * (it > 0) 196 | if unfinished.sum() == 0: 197 | break 198 | it = it * unfinished.type_as(it) 199 | seq.append(it) #seq[t] the input of t+2 time step 200 | 201 | seqLogprobs.append(sampleLogprobs.view(-1)) 202 | 203 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 204 | logprobs = F.log_softmax(self.logit(output)) 205 | 206 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 207 | 208 | class AdaAtt_lstm(nn.Module): 209 | def __init__(self, opt, use_maxout=True): 210 | super(AdaAtt_lstm, self).__init__() 211 | self.input_encoding_size = opt.input_encoding_size 212 | #self.rnn_type = opt.rnn_type 213 | self.rnn_size = opt.rnn_size 214 | self.num_layers = opt.num_layers 215 | self.drop_prob_lm = opt.drop_prob_lm 216 | self.fc_feat_size = opt.fc_feat_size 217 | self.att_feat_size = opt.att_feat_size 218 | self.att_hid_size = opt.att_hid_size 219 | 220 | self.use_maxout = use_maxout 221 | 222 | # Build a LSTM 223 | self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size) 224 | self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) 225 | 226 | self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)]) 227 | self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)]) 228 | 229 | # Layers for getting the fake region 230 | if self.num_layers == 1: 231 | self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size) 232 | self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size) 233 | else: 234 | self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size) 235 | self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size) 236 | 237 | 238 | def forward(self, xt, img_fc, state): 239 | 240 | hs = [] 241 | cs = [] 242 | for L in range(self.num_layers): 243 | # c,h from previous timesteps 244 | prev_h = state[0][L] 245 | prev_c = state[1][L] 246 | # the input to this layer 247 | if L == 0: 248 | x = xt 249 | i2h = self.w2h(x) + self.v2h(img_fc) 250 | else: 251 | x = hs[-1] 252 | x = F.dropout(x, self.drop_prob_lm, self.training) 253 | i2h = self.i2h[L-1](x) 254 | 255 | all_input_sums = i2h+self.h2h[L](prev_h) 256 | 257 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 258 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 259 | # decode the 
gates 260 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 261 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 262 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 263 | # decode the write inputs 264 | if not self.use_maxout: 265 | in_transform = F.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size)) 266 | else: 267 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) 268 | in_transform = torch.max(\ 269 | in_transform.narrow(1, 0, self.rnn_size), 270 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 271 | # perform the LSTM update 272 | next_c = forget_gate * prev_c + in_gate * in_transform 273 | # gated cells form the output 274 | tanh_nex_c = F.tanh(next_c) 275 | next_h = out_gate * tanh_nex_c 276 | if L == self.num_layers-1: 277 | if L == 0: 278 | i2h = self.r_w2h(x) + self.r_v2h(img_fc) 279 | else: 280 | i2h = self.r_i2h(x) 281 | n5 = i2h+self.r_h2h(prev_h) 282 | fake_region = F.sigmoid(n5) * tanh_nex_c 283 | 284 | cs.append(next_c) 285 | hs.append(next_h) 286 | 287 | # set up the decoder 288 | top_h = hs[-1] 289 | top_h = F.dropout(top_h, self.drop_prob_lm, self.training) 290 | fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training) 291 | 292 | state = (torch.cat([_.unsqueeze(0) for _ in hs], 0), 293 | torch.cat([_.unsqueeze(0) for _ in cs], 0)) 294 | return top_h, fake_region, state 295 | 296 | class AdaAtt_attention(nn.Module): 297 | def __init__(self, opt): 298 | super(AdaAtt_attention, self).__init__() 299 | self.input_encoding_size = opt.input_encoding_size 300 | #self.rnn_type = opt.rnn_type 301 | self.rnn_size = opt.rnn_size 302 | self.drop_prob_lm = opt.drop_prob_lm 303 | self.att_hid_size = opt.att_hid_size 304 | 305 | # fake region embed 306 | self.fr_linear = nn.Sequential( 307 | nn.Linear(self.rnn_size, self.input_encoding_size), 308 | nn.ReLU(), 309 | nn.Dropout(self.drop_prob_lm)) 310 | self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 311 | 312 | # h out embed 313 | self.ho_linear = nn.Sequential( 314 | nn.Linear(self.rnn_size, self.input_encoding_size), 315 | nn.Tanh(), 316 | nn.Dropout(self.drop_prob_lm)) 317 | self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 318 | 319 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 320 | self.att2h = nn.Linear(self.rnn_size, self.rnn_size) 321 | 322 | def forward(self, h_out, fake_region, conv_feat, conv_feat_embed): 323 | 324 | # View into three dimensions 325 | att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size 326 | conv_feat = conv_feat.view(-1, att_size, self.rnn_size) 327 | conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size) 328 | 329 | # view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num 330 | fake_region = self.fr_linear(fake_region) 331 | fake_region_embed = self.fr_embed(fake_region) 332 | 333 | h_out_linear = self.ho_linear(h_out) 334 | h_out_embed = self.ho_embed(h_out_linear) 335 | 336 | txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1)) 337 | 338 | img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1) 339 | img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1) 340 | 341 | hA = F.tanh(img_all_embed + txt_replicate) 342 | hA = F.dropout(hA,self.drop_prob_lm, self.training) 343 | 344 | hAflat = self.alpha_net(hA.view(-1, self.att_hid_size)) 345 | PI = 
F.softmax(hAflat.view(-1, att_size + 1)) 346 | 347 | visAtt = torch.bmm(PI.unsqueeze(1), img_all) 348 | visAttdim = visAtt.squeeze(1) 349 | 350 | atten_out = visAttdim + h_out_linear 351 | 352 | h = F.tanh(self.att2h(atten_out)) 353 | h = F.dropout(h, self.drop_prob_lm, self.training) 354 | return h 355 | 356 | class AdaAttCore(nn.Module): 357 | def __init__(self, opt, use_maxout=False): 358 | super(AdaAttCore, self).__init__() 359 | self.lstm = AdaAtt_lstm(opt, use_maxout) 360 | self.attention = AdaAtt_attention(opt) 361 | 362 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 363 | h_out, p_out, state = self.lstm(xt, fc_feats, state) 364 | atten_out = self.attention(h_out, p_out, att_feats, p_att_feats) 365 | return atten_out, state 366 | 367 | class TopDownCore(nn.Module): 368 | def __init__(self, opt, use_maxout=False): 369 | super(TopDownCore, self).__init__() 370 | self.drop_prob_lm = opt.drop_prob_lm 371 | 372 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 373 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 374 | self.attention = Attention(opt) 375 | 376 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 377 | prev_h = state[0][-1] 378 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) 379 | 380 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) 381 | 382 | att = self.attention(h_att, att_feats, p_att_feats) 383 | 384 | lang_lstm_input = torch.cat([att, h_att], 1) 385 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 386 | 387 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 388 | 389 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 390 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 391 | 392 | return output, state 393 | 394 | class Attention(nn.Module): 395 | def __init__(self, opt): 396 | super(Attention, self).__init__() 397 | self.rnn_size = opt.rnn_size 398 | self.att_hid_size = opt.att_hid_size 399 | 400 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 401 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 402 | 403 | def forward(self, h, att_feats, p_att_feats): 404 | # The p_att_feats here is already projected 405 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 406 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 407 | 408 | att_h = self.h2att(h) # batch * att_hid_size 409 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 410 | dot = att + att_h # batch * att_size * att_hid_size 411 | dot = F.tanh(dot) # batch * att_size * att_hid_size 412 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 413 | dot = self.alpha_net(dot) # (batch * att_size) * 1 414 | dot = dot.view(-1, att_size) # batch * att_size 415 | 416 | weight = F.softmax(dot) # batch * att_size 417 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 418 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 419 | 420 | return att_res 421 | 422 | 423 | class Att2in2Core(nn.Module): 424 | def __init__(self, opt): 425 | super(Att2in2Core, self).__init__() 426 | self.input_encoding_size = opt.input_encoding_size 427 | #self.rnn_type = opt.rnn_type 428 | self.rnn_size = opt.rnn_size 429 | #self.num_layers = opt.num_layers 430 | self.drop_prob_lm = opt.drop_prob_lm 431 | self.fc_feat_size = 
opt.fc_feat_size 432 | self.att_feat_size = opt.att_feat_size 433 | self.att_hid_size = opt.att_hid_size 434 | 435 | # Build a LSTM 436 | self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size) 437 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 438 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 439 | self.dropout = nn.Dropout(self.drop_prob_lm) 440 | 441 | self.attention = Attention(opt) 442 | 443 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 444 | att_res = self.attention(state[0][-1], att_feats, p_att_feats) 445 | 446 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 447 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 448 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 449 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 450 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 451 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 452 | 453 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ 454 | self.a2c(att_res) 455 | in_transform = torch.max(\ 456 | in_transform.narrow(1, 0, self.rnn_size), 457 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 458 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 459 | next_h = out_gate * F.tanh(next_c) 460 | 461 | output = self.dropout(next_h) 462 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 463 | return output, state 464 | 465 | class AdaAttModel(AttModel): 466 | def __init__(self, opt): 467 | super(AdaAttModel, self).__init__(opt) 468 | self.core = AdaAttCore(opt) 469 | 470 | # AdaAtt with maxout lstm 471 | class AdaAttMOModel(AttModel): 472 | def __init__(self, opt): 473 | super(AdaAttMOModel, self).__init__(opt) 474 | self.core = AdaAttCore(opt, True) 475 | 476 | class Att2in2Model(AttModel): 477 | def __init__(self, opt): 478 | super(Att2in2Model, self).__init__(opt) 479 | self.core = Att2in2Core(opt) 480 | delattr(self, 'fc_embed') 481 | self.fc_embed = lambda x : x 482 | 483 | class TopDownModel(AttModel): 484 | def __init__(self, opt): 485 | super(TopDownModel, self).__init__(opt) 486 | self.num_layers = 2 487 | self.core = TopDownCore(opt) 488 | -------------------------------------------------------------------------------- /models/AttModel_CCG.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | class CCGAttModel(nn.Module): 26 | def __init__(self, opt): 27 | super(CCGAttModel, self).__init__() 28 | self.vocab_size = opt.vocab_size 29 | self.input_encoding_size = opt.input_encoding_size 30 | #self.rnn_type = opt.rnn_type 31 | self.rnn_size = opt.rnn_size 32 | self.num_layers = opt.num_layers 33 | self.drop_prob_lm = opt.drop_prob_lm 34 | self.seq_length = opt.seq_length 35 | self.fc_feat_size = opt.fc_feat_size 36 | self.att_feat_size = opt.att_feat_size 37 | self.att_hid_size = opt.att_hid_size 38 | self.ccg_embedding_dim = opt.ccg_embedding_dim 39 | self.ccg_vocab_size = opt.ccg_vocab_size 40 | self.ss_prob = 0.0 # Schedule sampling probability 41 | 42 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 43 | nn.ReLU(), 44 | nn.Dropout(self.drop_prob_lm)) 45 | self.ccg_embed = nn.Sequential(nn.Embedding(self.ccg_vocab_size+1, self.ccg_embedding_dim), 46 | nn.ReLU(), 47 | nn.Dropout(self.drop_prob_lm)) 48 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 49 | nn.ReLU(), 50 | nn.Dropout(self.drop_prob_lm)) 51 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 52 | nn.ReLU(), 53 | nn.Dropout(self.drop_prob_lm)) 54 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 55 | self.logit_ccg = nn.Linear(self.rnn_size, self.ccg_vocab_size + 1) 56 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 57 | 58 | def init_hidden(self, bsz): 59 | weight = next(self.parameters()).data 60 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 61 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 62 | 63 | def forward(self, fc_feats, att_feats, seq, ccg_seq): 64 | batch_size = fc_feats.size(0) 65 | state = self.init_hidden(batch_size) 66 | 67 | outputs_word = [] 68 | outputs_ccg=[] 69 | 70 | # embed fc and att feats 71 | fc_feats = self.fc_embed(fc_feats) 72 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 73 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 74 | 75 | # Project the attention feats first to reduce memory and computation comsumptions. 
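# Joint decoding: at every step the word embedding and the CCG supertag embedding are
# concatenated into a single input for the core, and the shared hidden state feeds two
# heads (logit for words, logit_ccg for supertags). The loop runs for
# min(seq.size(1), ccg_seq.size(1)) - 1 steps and forward() returns a pair of
# log-probability tensors (words, CCG tags).
# A rough training sketch with the criterion from misc/utils.py (the exact label slicing
# depends on train.py, which is not shown here, so treat it as an assumption):
#   out_word, out_ccg = model(fc_feats, att_feats, labels, ccg_labels)
#   loss_word, loss_ccg = utils.LanguageModel_CCG_Criterion()(out_word, out_ccg,
#                                                             labels[:, 1:], ccg_labels[:, 1:],
#                                                             masks[:, 1:])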
76 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 77 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 78 | 79 | seq_size = min (seq.size(1) - 1,ccg_seq.size(1) - 1) 80 | for i in range(seq_size): 81 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 82 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 83 | sample_mask = sample_prob < self.ss_prob 84 | if sample_mask.sum() == 0: # disable schedule sampling 85 | it = seq[:, i].clone() 86 | it_ccg = ccg_seq[:, i].clone() 87 | else: # enable schedule sampling 88 | sample_ind = sample_mask.nonzero().view(-1) 89 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 90 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 91 | prob_prev = torch.exp(outputs_word[-1].data) # fetch prev distribution: shape Nx(M+1) 92 | it = seq[:, i].data.clone() 93 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 94 | it = Variable(it, requires_grad=False) 95 | 96 | # disable shedule sampling for CCG. 97 | # prob_prev = torch.exp(outputs_ccg[-1].data) # fetch prev distribution: shape Nx(M+1) 98 | it_ccg = ccg_seq[:, i].data.clone() 99 | # it_ccg.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 100 | it_ccg = Variable(it_ccg, requires_grad=False) 101 | else: 102 | it = seq[:, i].clone() 103 | it_ccg = ccg_seq[:, i].clone() 104 | # it = Variable(it, requires_grad=False) 105 | # it_cross = Variable(it_cross, requires_grad=False) 106 | 107 | # break if all the sequences end 108 | if i >= 1 and seq[:, i].data.sum() == 0: 109 | break 110 | 111 | xt = self.embed(it) 112 | xt_ccg = self.ccg_embed(it_ccg) 113 | xt_concat=torch.cat([xt_ccg, xt], 1) 114 | output, state = self.core(xt_concat, fc_feats, att_feats, p_att_feats, state) 115 | 116 | output_word = F.log_softmax(self.logit(output)) 117 | output_ccg = F.log_softmax(self.logit_ccg(output)) 118 | 119 | outputs_word.append(output_word) 120 | outputs_ccg.append(output_ccg) 121 | 122 | return torch.cat([_.unsqueeze(1) for _ in outputs_word], 1) , torch.cat([_.unsqueeze(1) for _ in outputs_ccg], 1) 123 | 124 | def sample_beam(self, fc_feats, att_feats, opt={}): 125 | beam_size = opt.get('beam_size', 10) 126 | batch_size = fc_feats.size(0) 127 | 128 | # embed fc and att feats 129 | fc_feats = self.fc_embed(fc_feats) 130 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 131 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 132 | 133 | # Project the attention feats first to reduce memory and computation comsumptions. 134 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 135 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 136 | 137 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 138 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 139 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 140 | # let's process every image independently for now, for simplicity 141 | 142 | self.done_beams = [[] for _ in range(batch_size)] 143 | for k in range(batch_size): 144 | state = self.init_hidden(beam_size) 145 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) 146 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 147 | tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 148 | 149 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 150 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 151 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 152 | done_beams = [] 153 | for t in range(self.seq_length + 1): 154 | if t == 0: # input 155 | it = fc_feats.data.new(beam_size).long().zero_() 156 | xt = self.embed(Variable(it, requires_grad=False)) 157 | else: 158 | """perform a beam merge. That is, 159 | for every previous beam we now have many new possibilities to branch out, 160 | and we need to re-sort our beams to maintain the loop invariant of keeping 161 | the top beam_size most likely sequences.""" 162 | logprobsf = logprobs.float() # let's go to CPU for more efficiency in indexing operations 163 | ys,ix = torch.sort(logprobsf,1,True) # sorted array of logprobs along each previous beam (last true = descending) 164 | candidates = [] 165 | cols = min(beam_size, ys.size(1)) 166 | rows = beam_size 167 | if t == 1: # at first time step only the first beam is active 168 | rows = 1 169 | for c in range(cols): 170 | for q in range(rows): 171 | # compute logprob of expanding beam q with word in (sorted) position c 172 | local_logprob = ys[q,c] 173 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 174 | candidates.append({'c':ix.data[q,c], 'q':q, 'p':candidate_logprob.data[0], 'r':local_logprob.data[0]}) 175 | candidates = sorted(candidates, key=lambda x: -x['p']) 176 | 177 | # construct new beams 178 | new_state = [_.clone() for _ in state] 179 | if t > 1: 180 | # we'll need these as reference when we fork beams around 181 | beam_seq_prev = beam_seq[:t-1].clone() 182 | beam_seq_logprobs_prev = beam_seq_logprobs[:t-1].clone() 183 | for vix in range(beam_size): 184 | v = candidates[vix] 185 | # fork beam index q into index vix 186 | if t > 1: 187 | beam_seq[:t-1, vix] = beam_seq_prev[:, v['q']] 188 | beam_seq_logprobs[:t-1, vix] = beam_seq_logprobs_prev[:, v['q']] 189 | 190 | # rearrange recurrent states 191 | for state_ix in range(len(new_state)): 192 | # copy over state in previous beam q to new beam at vix 193 | new_state[state_ix][0, vix] = state[state_ix][0, v['q']] # dimension 0 indexes the layer (num_layers x beam x rnn_size) 194 | 195 | # append new end terminal at the end of this beam 196 | beam_seq[t-1, vix] = v['c'] # c'th word is the continuation 197 | beam_seq_logprobs[t-1, vix] = v['r'] # the raw logprob here 198 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 199 | 200 | if v['c'] == 0 or t == self.seq_length: 201 | # END token special case here, or we reached the end.
202 | # add the beam to a set of done beams 203 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 204 | 'logps': beam_seq_logprobs[:, vix].clone(), 205 | 'p': beam_logprobs_sum[vix] 206 | }) 207 | 208 | # encode as vectors 209 | it = beam_seq[t-1] 210 | xt = self.embed(Variable(it.cuda())) 211 | 212 | if t >= 1: 213 | state = new_state 214 | 215 | combined_xt=torch.cat([it, xt], 1) 216 | output, state = self.core(combined_xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 217 | logprobs = F.log_softmax(self.logit(output)) 218 | 219 | self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 220 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 221 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 222 | # return the samples and their log likelihoods 223 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 224 | 225 | def sample(self, fc_feats, att_feats, opt={}): 226 | sample_max = opt.get('sample_max', 1) 227 | beam_size = opt.get('beam_size', 1) 228 | temperature = opt.get('temperature', 1.0) 229 | if beam_size > 1: 230 | return self.sample_beam(fc_feats, att_feats, opt) 231 | 232 | batch_size = fc_feats.size(0) 233 | state = self.init_hidden(batch_size) 234 | 235 | # embed fc and att feats 236 | fc_feats = self.fc_embed(fc_feats) 237 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 238 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 239 | 240 | # Project the attention feats first to reduce memory and computation comsumptions. 241 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 242 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 243 | 244 | seq = [] 245 | seq_ccg=[] 246 | seqLogprobs = [] 247 | seqLogprobs_ccg = [] 248 | 249 | for t in range(self.seq_length + 1): 250 | if t == 0: # input 251 | it = fc_feats.data.new(batch_size).long().zero_() 252 | it_ccg= fc_feats.data.new(batch_size).long().zero_() 253 | elif sample_max: 254 | sampleLogprobs, it = torch.max(logprobs.data, 1) 255 | it = it.view(-1).long() 256 | 257 | sampleLogprobs_ccg, it_ccg = torch.max(logprobs_ccg.data, 1) 258 | it_ccg = it_ccg.view(-1).long() 259 | else: 260 | if temperature == 1.0: 261 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 262 | prob_prev_ccg = torch.exp(logprobs_ccg.data).cpu() 263 | else: 264 | # scale logprobs by temperature 265 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 266 | prob_prev_ccg = torch.exp(torch.div(logprobs_ccg.data, temperature)).cpu() 267 | 268 | it = torch.multinomial(prob_prev, 1).cuda() 269 | it_ccg = torch.multinomial(prob_prev_ccg, 1).cuda() 270 | 271 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 272 | sampleLogprobs_ccg = logprobs_ccg.gather(1, Variable(it_ccg, requires_grad=False)) 273 | it = it.view(-1).long() # and flatten indices for downstream processing 274 | it_ccg = it_ccg.view(-1).long() 275 | 276 | xt = self.embed(Variable(it, requires_grad=False)) 277 | xt_ccg = self.ccg_embed(Variable(it_ccg, requires_grad=False)) 278 | xt_concat = torch.cat([xt_ccg, xt],1) 279 | 280 | if t >= 1: 281 | # stop when all finished 282 | if t == 1: 283 | unfinished = it > 0 284 | else: 285 | unfinished = unfinished * (it > 0) 286 | if unfinished.sum() == 0: 287 | break 288 | it = it * unfinished.type_as(it) 289 | seq.append(it) #seq[t] the input of t+2 time step 290 | 291 | it_ccg = 
it_ccg * unfinished.type_as(it_ccg) 292 | seq_ccg.append(it_ccg) 293 | 294 | seqLogprobs.append(sampleLogprobs.view(-1)) 295 | seqLogprobs_ccg.append(sampleLogprobs_ccg.view(-1)) 296 | 297 | output, state = self.core(xt_concat, fc_feats, att_feats, p_att_feats, state) 298 | 299 | logprobs = F.log_softmax(self.logit(output)) 300 | logprobs_ccg = F.log_softmax(self.logit_ccg(output)) 301 | 302 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1), \ 303 | torch.cat([_.unsqueeze(1) for _ in seq_ccg], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs_ccg], 1) 304 | 305 | class TopDownCore(nn.Module): 306 | def __init__(self, opt, use_maxout=False,ccg_embedding_dim=None): 307 | super(TopDownCore, self).__init__() 308 | self.drop_prob_lm = opt.drop_prob_lm 309 | if ccg_embedding_dim is not None: 310 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + ccg_embedding_dim + opt.rnn_size * 2, 311 | opt.rnn_size) # we, ccg, fc, h^2_t-1 312 | else: 313 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 314 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 315 | self.attention = Attention(opt) 316 | 317 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 318 | prev_h = state[0][-1] 319 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) # h_{t-1}^2, v_{bar}, w_{e}*pi_{t} 320 | 321 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))# h_{t}^1 322 | 323 | att = self.attention(h_att, att_feats, p_att_feats)# h_{t}^1, conv_feats, p_conv_feats 324 | 325 | lang_lstm_input = torch.cat([att, h_att], 1)# \hat v, h_{t}^1 326 | 327 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 
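# The commented-out line above is an alternative that would apply dropout to h_att before
# feeding it to the language LSTM; it is left disabled, and the '?????' flags it as unresolved.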
328 | 329 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 330 | 331 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 332 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 333 | 334 | return output, state 335 | 336 | class Attention(nn.Module): 337 | def __init__(self, opt): 338 | super(Attention, self).__init__() 339 | self.rnn_size = opt.rnn_size 340 | self.att_hid_size = opt.att_hid_size 341 | 342 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 343 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 344 | 345 | def forward(self, h, att_feats, p_att_feats): 346 | # The p_att_feats here is already projected 347 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 348 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 349 | 350 | att_h = self.h2att(h) # batch * att_hid_size 351 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 352 | dot = att + att_h # batch * att_size * att_hid_size 353 | dot = F.tanh(dot) # batch * att_size * att_hid_size 354 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 355 | dot = self.alpha_net(dot) # (batch * att_size) * 1 356 | dot = dot.view(-1, att_size) # batch * att_size 357 | 358 | weight = F.softmax(dot) # batch * att_size 359 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 360 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 361 | 362 | return att_res 363 | 364 | class CCG_TopDownModel(CCGAttModel): 365 | def __init__(self, opt): 366 | super(CCG_TopDownModel, self).__init__(opt) 367 | self.num_layers = 2 368 | self.core = TopDownCore(opt,ccg_embedding_dim=opt.ccg_embedding_dim) 369 | -------------------------------------------------------------------------------- /models/AttModel_V1.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | class CCGAttModel(nn.Module): 26 | def __init__(self, opt): 27 | super(CCGAttModel, self).__init__() 28 | self.vocab_size = opt.vocab_size 29 | self.input_encoding_size = opt.input_encoding_size 30 | #self.rnn_type = opt.rnn_type 31 | self.rnn_size = opt.rnn_size 32 | self.num_layers = opt.num_layers 33 | self.drop_prob_lm = opt.drop_prob_lm 34 | self.seq_length = opt.seq_length 35 | self.fc_feat_size = opt.fc_feat_size 36 | self.att_feat_size = opt.att_feat_size 37 | self.att_hid_size = opt.att_hid_size 38 | self.ccg_embedding_dim = opt.ccg_embedding_dim 39 | self.ccg_vocab_size = opt.ccg_vocab_size 40 | self.ss_prob = 0.0 # Schedule sampling probability 41 | 42 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 43 | nn.ReLU(), 44 | nn.Dropout(self.drop_prob_lm)) 45 | self.ccg_embed = nn.Sequential(nn.Embedding(self.ccg_vocab_size+1, self.ccg_embedding_dim), 46 | nn.ReLU(), 47 | nn.Dropout(self.drop_prob_lm)) 48 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 49 | nn.ReLU(), 50 | nn.Dropout(self.drop_prob_lm)) 51 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 52 | nn.ReLU(), 53 | nn.Dropout(self.drop_prob_lm)) 54 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 55 | self.logit_ccg = nn.Linear(self.rnn_size, self.ccg_vocab_size + 1) 56 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 57 | 58 | def init_hidden(self, bsz): 59 | weight = next(self.parameters()).data 60 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 61 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 62 | 63 | def forward(self, fc_feats, att_feats, seq, ccg_seq): 64 | batch_size = fc_feats.size(0) 65 | state = self.init_hidden(batch_size) 66 | 67 | outputs_word = [] 68 | outputs_ccg=[] 69 | 70 | # embed fc and att feats 71 | fc_feats = self.fc_embed(fc_feats) 72 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 73 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 74 | 75 | # Project the attention feats first to reduce memory and computation comsumptions. 
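# Note the difference from AttModel_CCG.py: here the core is called as
# self.core(xt, xt_ccg, ...) with the word and CCG embeddings passed as separate arguments,
# and the TopDownCore in this file builds its attention-LSTM input from [prev_h, fc_feats, xt]
# only (the concatenation with xt_ccg is commented out), so the CCG embedding never enters
# the LSTM; the CCG signal is used only through the extra logit_ccg prediction head.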
76 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 77 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 78 | 79 | seq_size = min (seq.size(1) - 1,ccg_seq.size(1) - 1) 80 | for i in range(seq_size): 81 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 82 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 83 | sample_mask = sample_prob < self.ss_prob 84 | if sample_mask.sum() == 0: # disable schedule sampling 85 | it = seq[:, i].clone() 86 | it_ccg = ccg_seq[:, i].clone() 87 | else: # enable schedule sampling 88 | sample_ind = sample_mask.nonzero().view(-1) 89 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 90 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 91 | prob_prev = torch.exp(outputs_word[-1].data) # fetch prev distribution: shape Nx(M+1) 92 | it = seq[:, i].data.clone() 93 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 94 | it = Variable(it, requires_grad=False) 95 | 96 | # disable shedule sampling for CCG. 97 | # prob_prev = torch.exp(outputs_ccg[-1].data) # fetch prev distribution: shape Nx(M+1) 98 | it_ccg = ccg_seq[:, i].data.clone() 99 | # it_ccg.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 100 | it_ccg = Variable(it_ccg, requires_grad=False) 101 | else: 102 | it = seq[:, i].clone() 103 | it_ccg = ccg_seq[:, i].clone() 104 | # it = Variable(it, requires_grad=False) 105 | # it_cross = Variable(it_cross, requires_grad=False) 106 | 107 | # break if all the sequences end 108 | if i >= 1 and seq[:, i].data.sum() == 0: 109 | break 110 | 111 | xt = self.embed(it) 112 | xt_ccg = self.ccg_embed(it_ccg) 113 | output, state = self.core(xt,xt_ccg, fc_feats, att_feats, p_att_feats, state) 114 | 115 | output_word = F.log_softmax(self.logit(output)) 116 | output_ccg = F.log_softmax(self.logit_ccg(output)) 117 | 118 | outputs_word.append(output_word) 119 | outputs_ccg.append(output_ccg) 120 | 121 | return torch.cat([_.unsqueeze(1) for _ in outputs_word], 1) , torch.cat([_.unsqueeze(1) for _ in outputs_ccg], 1) 122 | 123 | def sample(self, fc_feats, att_feats, opt={}): 124 | sample_max = opt.get('sample_max', 1) 125 | beam_size = opt.get('beam_size', 1) 126 | temperature = opt.get('temperature', 1.0) 127 | if beam_size > 1: 128 | return self.sample_beam(fc_feats, att_feats, opt) 129 | 130 | batch_size = fc_feats.size(0) 131 | state = self.init_hidden(batch_size) 132 | 133 | # embed fc and att feats 134 | fc_feats = self.fc_embed(fc_feats) 135 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 136 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 137 | 138 | # Project the attention feats first to reduce memory and computation comsumptions. 
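# In this sample() the CCG tag at each step is always taken greedily with torch.max on
# logprobs_ccg, even in the temperature-sampling branch where the word is drawn with
# torch.multinomial; the multinomial variant for CCG is left commented out below.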
139 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 140 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 141 | 142 | seq = [] 143 | seq_ccg=[] 144 | seqLogprobs = [] 145 | seqLogprobs_ccg = [] 146 | 147 | for t in range(self.seq_length + 1): 148 | if t == 0: # input 149 | it = fc_feats.data.new(batch_size).long().zero_() 150 | it_ccg= fc_feats.data.new(batch_size).long().zero_() 151 | elif sample_max: 152 | sampleLogprobs, it = torch.max(logprobs.data, 1) 153 | it = it.view(-1).long() 154 | 155 | sampleLogprobs_ccg, it_ccg = torch.max(logprobs_ccg.data, 1) 156 | it_ccg = it_ccg.view(-1).long() 157 | else: 158 | if temperature == 1.0: 159 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 160 | prob_prev_ccg = torch.exp(logprobs_ccg.data).cpu() 161 | else: 162 | # scale logprobs by temperature 163 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 164 | prob_prev_ccg = torch.exp(torch.div(logprobs_ccg.data, temperature)).cpu() 165 | 166 | it = torch.multinomial(prob_prev, 1).cuda() 167 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 168 | it = it.view(-1).long() # and flatten indices for downstream processing 169 | 170 | sampleLogprobs_ccg, it_ccg = torch.max(logprobs_ccg.data, 1) 171 | it_ccg = it_ccg.view(-1).long() 172 | # it_ccg = torch.multinomial(prob_prev_ccg, 1).cuda() 173 | # sampleLogprobs_ccg = logprobs_ccg.gather(1, Variable(it_ccg, requires_grad=False)) 174 | # it_ccg = it_ccg.view(-1).long() 175 | 176 | xt = self.embed(Variable(it, requires_grad=False)) 177 | xt_ccg = self.ccg_embed(Variable(it_ccg, requires_grad=False)) 178 | 179 | if t >= 1: 180 | # stop when all finished 181 | if t == 1: 182 | unfinished = it > 0 183 | else: 184 | unfinished = unfinished * (it > 0) 185 | if unfinished.sum() == 0: 186 | break 187 | it = it * unfinished.type_as(it) 188 | seq.append(it) #seq[t] the input of t+2 time step 189 | 190 | it_ccg = it_ccg * unfinished.type_as(it_ccg) 191 | seq_ccg.append(it_ccg) 192 | 193 | seqLogprobs.append(sampleLogprobs.view(-1)) 194 | seqLogprobs_ccg.append(sampleLogprobs_ccg.view(-1)) 195 | 196 | output, state = self.core(xt,xt_ccg, fc_feats, att_feats, p_att_feats, state) 197 | 198 | logprobs = F.log_softmax(self.logit(output)) 199 | logprobs_ccg = F.log_softmax(self.logit_ccg(output)) 200 | 201 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1), \ 202 | torch.cat([_.unsqueeze(1) for _ in seq_ccg], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs_ccg], 1) 203 | 204 | class TopDownCore(nn.Module): 205 | def __init__(self, opt, use_maxout=False,ccg_embedding_dim=None): 206 | super(TopDownCore, self).__init__() 207 | self.drop_prob_lm = opt.drop_prob_lm 208 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, 209 | opt.rnn_size) # we, ccg, fc, h^2_t-1 210 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 211 | self.attention = Attention(opt) 212 | #self.ccg_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) 213 | 214 | def forward(self, xt, xt_ccg, fc_feats, att_feats, p_att_feats, state): 215 | # combined_xt=torch.cat([xt, xt_ccg], 1) 216 | prev_h = state[0][-1] # h_lang 217 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) # h_{t-1}^2, v_{bar}, w_{e}*pi_{t} 218 | 219 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))# h_{t}^1 220 | 221 | att = 
self.attention(h_att, att_feats, p_att_feats)# h_{t}^1, conv_feats, p_conv_feats 222 | 223 | lang_lstm_input = torch.cat([att, h_att], 1)# \hat v, h_{t}^1 224 | 225 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 226 | 227 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 228 | 229 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 230 | state = (torch.stack([h_att, h_lang]), 231 | torch.stack([c_att, c_lang])) 232 | 233 | return output, state 234 | 235 | class Attention(nn.Module): 236 | def __init__(self, opt): 237 | super(Attention, self).__init__() 238 | self.rnn_size = opt.rnn_size 239 | self.att_hid_size = opt.att_hid_size 240 | 241 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 242 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 243 | 244 | def forward(self, h, att_feats, p_att_feats): 245 | # The p_att_feats here is already projected 246 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 247 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 248 | 249 | att_h = self.h2att(h) # batch * att_hid_size 250 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 251 | dot = att + att_h # batch * att_size * att_hid_size 252 | dot = F.tanh(dot) # batch * att_size * att_hid_size 253 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 254 | dot = self.alpha_net(dot) # (batch * att_size) * 1 255 | dot = dot.view(-1, att_size) # batch * att_size 256 | 257 | weight = F.softmax(dot) # batch * att_size 258 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 259 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 260 | 261 | return att_res 262 | 263 | class CCG_TopDownModel(CCGAttModel): 264 | def __init__(self, opt): 265 | print("use new model") 266 | super(CCG_TopDownModel, self).__init__(opt) 267 | self.num_layers = 2 268 | self.core = TopDownCore(opt,ccg_embedding_dim=opt.ccg_embedding_dim) 269 | -------------------------------------------------------------------------------- /models/AttModel_V2.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
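#
# Note: in this repository the file actually defines the multitask variant. CCGAttModel and
# CCG_TopDownModel (below) add a CCG-tag embedding (ccg_embed) and a second output head
# (logit_ccg) on top of the TopDown decoder, so forward() returns word log-probabilities and
# CCG log-probabilities at every time step. A rough construction sketch, for illustration only
# (opt must also carry vocab_size, seq_length, ccg_vocab_size and ccg_embedding_dim, which are
# expected to come from the training script / data loader rather than from opts.py):
#
#   opt = opts.parse_opt()
#   opt.vocab_size, opt.seq_length = 9487, 16             # hypothetical values from the loader
#   opt.ccg_vocab_size, opt.ccg_embedding_dim = 400, 128  # hypothetical values
#   model = CCG_TopDownModel(opt)
#   word_logprobs, ccg_logprobs = model(fc_feats, att_feats, seq, ccg_seq)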
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | class CCGAttModel(nn.Module): 26 | def __init__(self, opt): 27 | super(CCGAttModel, self).__init__() 28 | self.vocab_size = opt.vocab_size 29 | self.input_encoding_size = opt.input_encoding_size 30 | #self.rnn_type = opt.rnn_type 31 | self.rnn_size = opt.rnn_size 32 | self.num_layers = opt.num_layers 33 | self.drop_prob_lm = opt.drop_prob_lm 34 | self.seq_length = opt.seq_length 35 | self.fc_feat_size = opt.fc_feat_size 36 | self.att_feat_size = opt.att_feat_size 37 | self.att_hid_size = opt.att_hid_size 38 | self.ccg_embedding_dim = opt.ccg_embedding_dim 39 | self.ccg_vocab_size = opt.ccg_vocab_size 40 | self.ss_prob = 0.0 # Schedule sampling probability 41 | 42 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 43 | nn.ReLU(), 44 | nn.Dropout(self.drop_prob_lm)) 45 | self.ccg_embed = nn.Sequential(nn.Embedding(self.ccg_vocab_size + 1, self.ccg_embedding_dim), 46 | nn.ReLU(), 47 | nn.Dropout(self.drop_prob_lm)) 48 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 49 | nn.ReLU(), 50 | nn.Dropout(self.drop_prob_lm)) 51 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 52 | nn.ReLU(), 53 | nn.Dropout(self.drop_prob_lm)) 54 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 55 | self.logit_ccg = nn.Linear(self.rnn_size, self.ccg_vocab_size + 1) 56 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 57 | 58 | def init_hidden(self, bsz): 59 | weight = next(self.parameters()).data 60 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 61 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 62 | 63 | def forward(self, fc_feats, att_feats, seq, ccg_seq): 64 | batch_size = fc_feats.size(0) 65 | state = self.init_hidden(batch_size) 66 | 67 | outputs_word = [] 68 | outputs_ccg=[] 69 | 70 | # embed fc and att feats 71 | fc_feats = self.fc_embed(fc_feats) 72 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 73 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 74 | 75 | # Project the attention feats first to reduce memory and computation comsumptions. 
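# (Shape note: after att_embed, att_feats is (batch, att_size, rnn_size); ctx2att projects it
#  once to p_att_feats of shape (batch, att_size, att_hid_size), so the Attention module only
#  has to project the hidden state h at each decoding step instead of re-projecting the
#  attention features every time.)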
76 |         p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size))
77 |         p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,)))
78 |
79 |         seq_size = min(seq.size(1) - 1, ccg_seq.size(1) - 1)
80 |         for i in range(seq_size):
81 |             if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
82 |                 sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
83 |                 sample_mask = sample_prob < self.ss_prob
84 |                 if sample_mask.sum() == 0: # disable scheduled sampling
85 |                     it = seq[:, i].clone()
86 |                     it_ccg = ccg_seq[:, i].clone()
87 |                 else: # enable scheduled sampling
88 |                     sample_ind = sample_mask.nonzero().view(-1)
89 |                     #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
90 |                     #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
91 |                     prob_prev = torch.exp(outputs_word[-1].data) # fetch prev distribution: shape Nx(M+1)
92 |                     it = seq[:, i].data.clone()
93 |                     it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
94 |                     it = Variable(it, requires_grad=False)
95 |
96 |                     # disable scheduled sampling for CCG.
97 |                     # prob_prev = torch.exp(outputs_ccg[-1].data) # fetch prev distribution: shape Nx(M+1)
98 |                     it_ccg = ccg_seq[:, i].data.clone()
99 |                     # it_ccg.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
100 |                     it_ccg = Variable(it_ccg, requires_grad=False)
101 |             else:
102 |                 it = seq[:, i].clone()
103 |                 it_ccg = ccg_seq[:, i].clone()
104 |                 # it = Variable(it, requires_grad=False)
105 |                 # it_cross = Variable(it_cross, requires_grad=False)
106 |
107 |             # break if all the sequences end
108 |             if i >= 1 and seq[:, i].data.sum() == 0:
109 |                 break
110 |
111 |             xt = self.embed(it)
112 |             xt_ccg = self.ccg_embed(it_ccg)
113 |             output, state = self.core(xt, xt_ccg, fc_feats, att_feats, p_att_feats, state)
114 |
115 |             output_word = F.log_softmax(self.logit(output))
116 |             output_ccg = F.log_softmax(self.logit_ccg(output))
117 |
118 |             outputs_word.append(output_word)
119 |             outputs_ccg.append(output_ccg)
120 |
121 |         return torch.cat([_.unsqueeze(1) for _ in outputs_word], 1), torch.cat([_.unsqueeze(1) for _ in outputs_ccg], 1)
122 |
123 |     def sample(self, fc_feats, att_feats, opt={}):
124 |         sample_max = opt.get('sample_max', 1)
125 |         beam_size = opt.get('beam_size', 1)
126 |         temperature = opt.get('temperature', 1.0)
127 |         if beam_size > 1:
128 |             return self.sample_beam(fc_feats, att_feats, opt) # note: sample_beam is not defined in this file
129 |
130 |         batch_size = fc_feats.size(0)
131 |         state = self.init_hidden(batch_size)
132 |
133 |         # embed fc and att feats
134 |         fc_feats = self.fc_embed(fc_feats)
135 |         _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size))
136 |         att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,)))
137 |
138 |         # Project the attention feats first to reduce memory and computation consumption.
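# (Decoding options are read from the opt dict above: sample_max=1 takes the argmax word at each
#  step, sample_max=0 samples words with the given temperature, and beam_size>1 dispatches to
#  sample_beam. Note that in the loop below the CCG head is always decoded greedily with
#  torch.max, even when words are sampled. Illustrative call, assuming a trained model:
#      seq, logp, seq_ccg, logp_ccg = model.sample(fc_feats, att_feats, {'sample_max': 0, 'temperature': 0.8})
# )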
139 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 140 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 141 | 142 | seq = [] 143 | seq_ccg=[] 144 | seqLogprobs = [] 145 | seqLogprobs_ccg = [] 146 | 147 | for t in range(self.seq_length + 1): 148 | if t == 0: # input 149 | it = fc_feats.data.new(batch_size).long().zero_() 150 | it_ccg= fc_feats.data.new(batch_size).long().zero_() 151 | elif sample_max: 152 | sampleLogprobs, it = torch.max(logprobs.data, 1) 153 | it = it.view(-1).long() 154 | 155 | sampleLogprobs_ccg, it_ccg = torch.max(logprobs_ccg.data, 1) 156 | it_ccg = it_ccg.view(-1).long() 157 | else: 158 | if temperature == 1.0: 159 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 160 | prob_prev_ccg = torch.exp(logprobs_ccg.data).cpu() 161 | else: 162 | # scale logprobs by temperature 163 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 164 | prob_prev_ccg = torch.exp(torch.div(logprobs_ccg.data, temperature)).cpu() 165 | 166 | it = torch.multinomial(prob_prev, 1).cuda() 167 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 168 | it = it.view(-1).long() # and flatten indices for downstream processing 169 | 170 | sampleLogprobs_ccg, it_ccg = torch.max(logprobs_ccg.data, 1) 171 | it_ccg = it_ccg.view(-1).long() 172 | # it_ccg = torch.multinomial(prob_prev_ccg, 1).cuda() 173 | # sampleLogprobs_ccg = logprobs_ccg.gather(1, Variable(it_ccg, requires_grad=False)) 174 | # it_ccg = it_ccg.view(-1).long() 175 | 176 | xt = self.embed(Variable(it, requires_grad=False)) 177 | xt_ccg = self.ccg_embed(Variable(it_ccg, requires_grad=False)) 178 | 179 | if t >= 1: 180 | # stop when all finished 181 | if t == 1: 182 | unfinished = it > 0 183 | else: 184 | unfinished = unfinished * (it > 0) 185 | if unfinished.sum() == 0: 186 | break 187 | it = it * unfinished.type_as(it) 188 | seq.append(it) #seq[t] the input of t+2 time step 189 | 190 | it_ccg = it_ccg * unfinished.type_as(it_ccg) 191 | seq_ccg.append(it_ccg) 192 | 193 | seqLogprobs.append(sampleLogprobs.view(-1)) 194 | seqLogprobs_ccg.append(sampleLogprobs_ccg.view(-1)) 195 | 196 | output, state = self.core(xt,xt_ccg, fc_feats, att_feats, p_att_feats, state) 197 | 198 | logprobs = F.log_softmax(self.logit(output)) 199 | logprobs_ccg = F.log_softmax(self.logit_ccg(output)) 200 | 201 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1), \ 202 | torch.cat([_.unsqueeze(1) for _ in seq_ccg], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs_ccg], 1) 203 | 204 | class TopDownCore(nn.Module): 205 | def __init__(self, opt, use_maxout=False,ccg_embedding_dim=None): 206 | super(TopDownCore, self).__init__() 207 | self.drop_prob_lm = opt.drop_prob_lm 208 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + ccg_embedding_dim + opt.rnn_size * 2, 209 | opt.rnn_size) # we, ccg, fc, h^2_t-1 210 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 211 | self.attention = Attention(opt) 212 | #self.ccg_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) 213 | 214 | def forward(self, xt, xt_ccg, fc_feats, att_feats, p_att_feats, state): 215 | prev_h = state[0][-1] # h_lang 216 | att_lstm_input = torch.cat([prev_h, fc_feats, xt, xt_ccg], 1) # h_{t-1}^2, v_{bar}, w_{e}*pi_{t} 217 | 218 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))# h_{t}^1 219 | 220 | att = self.attention(h_att, att_feats, 
p_att_feats)# h_{t}^1, conv_feats, p_conv_feats 221 | 222 | lang_lstm_input = torch.cat([att, h_att], 1)# \hat v, h_{t}^1 223 | 224 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 225 | 226 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 227 | 228 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 229 | state = (torch.stack([h_att, h_lang]), 230 | torch.stack([c_att, c_lang])) 231 | 232 | return output, state 233 | 234 | class Attention(nn.Module): 235 | def __init__(self, opt): 236 | super(Attention, self).__init__() 237 | self.rnn_size = opt.rnn_size 238 | self.att_hid_size = opt.att_hid_size 239 | 240 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 241 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 242 | 243 | def forward(self, h, att_feats, p_att_feats): 244 | # The p_att_feats here is already projected 245 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 246 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 247 | 248 | att_h = self.h2att(h) # batch * att_hid_size 249 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 250 | dot = att + att_h # batch * att_size * att_hid_size 251 | dot = F.tanh(dot) # batch * att_size * att_hid_size 252 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 253 | dot = self.alpha_net(dot) # (batch * att_size) * 1 254 | dot = dot.view(-1, att_size) # batch * att_size 255 | 256 | weight = F.softmax(dot) # batch * att_size 257 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 258 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 259 | 260 | return att_res 261 | 262 | class CCG_TopDownModel(CCGAttModel): 263 | def __init__(self, opt): 264 | print("use new model") 265 | super(CCG_TopDownModel, self).__init__(opt) 266 | self.num_layers = 2 267 | self.core = TopDownCore(opt,ccg_embedding_dim=opt.ccg_embedding_dim) 268 | -------------------------------------------------------------------------------- /models/TDModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
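#
# Note: in this repository this file provides the single-task baseline. AttModel and
# TopDownModel (below) decode word captions only, with no CCG head, but otherwise mirror the
# multitask classes in AttModel_V2.py, including beam search in sample_beam().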
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | import numpy as np 25 | class AttModel(nn.Module): 26 | def __init__(self, opt): 27 | super(AttModel, self).__init__() 28 | self.vocab_size = opt.vocab_size 29 | self.input_encoding_size = opt.input_encoding_size 30 | #self.rnn_type = opt.rnn_type 31 | self.rnn_size = opt.rnn_size 32 | self.num_layers = opt.num_layers 33 | self.drop_prob_lm = opt.drop_prob_lm 34 | self.seq_length = opt.seq_length 35 | self.fc_feat_size = opt.fc_feat_size 36 | self.att_feat_size = opt.att_feat_size 37 | self.att_hid_size = opt.att_hid_size 38 | 39 | self.ss_prob = 0.0 # Schedule sampling probability 40 | 41 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 42 | nn.ReLU(), 43 | nn.Dropout(self.drop_prob_lm)) 44 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 45 | nn.ReLU(), 46 | nn.Dropout(self.drop_prob_lm)) 47 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 48 | nn.ReLU(), 49 | nn.Dropout(self.drop_prob_lm)) 50 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 51 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 52 | 53 | def init_hidden(self, bsz): 54 | weight = next(self.parameters()).data 55 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 56 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 57 | 58 | def forward(self, fc_feats, att_feats, seq): 59 | batch_size = fc_feats.size(0) 60 | state = self.init_hidden(batch_size) 61 | outputs = [] 62 | 63 | # embed fc and att feats 64 | fc_feats = self.fc_embed(fc_feats) 65 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 66 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 67 | 68 | # Project the attention feats first to reduce memory and computation comsumptions. 
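# (The decoding loop below is standard teacher forcing: seq has shape (batch, L) and the method
#  returns log-probabilities of shape (batch, L-1, vocab_size+1). With probability self.ss_prob
#  an example is fed the model's own sampled word instead of the ground-truth token, and the
#  outputs are padded with a zero row once every sequence in the batch has ended.)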
69 |         p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size))
70 |         p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,)))
71 |
72 |         for i in range(seq.size(1) - 1):
73 |             if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
74 |                 sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
75 |                 sample_mask = sample_prob < self.ss_prob
76 |                 if sample_mask.sum() == 0:
77 |                     it = seq[:, i].clone()
78 |                 else:
79 |                     sample_ind = sample_mask.nonzero().view(-1)
80 |                     it = seq[:, i].data.clone()
81 |                     #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
82 |                     #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
83 |                     prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
84 |                     it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
85 |                     it = Variable(it, requires_grad=False)
86 |             else:
87 |                 it = seq[:, i].clone()
88 |             # break if all the sequences end
89 |             if i >= 1 and seq[:, i].data.sum() == 0:
90 |                 # print(str(i)+':padding:'+str(len(outputs)))
91 |                 for _ in range(seq.size(1) - i - 1):
92 |                     outputs.append(padding.clone())
93 |                 break
94 |             xt = self.embed(it)
95 |
96 |             output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state)
97 |             output = F.log_softmax(self.logit(output))
98 |             outputs.append(output)
99 |
100 |             if i == 1:
101 |                 padding = Variable(output.data.new(batch_size, self.vocab_size + 1).zero_()) # zero row kept around to pad the outputs when all sequences end early
102 |
103 |             # pad_v = torch.rand(batch_size, self.vocab_size + 1)
104 |             # pad_v = pad_v.type_as(it)
105 |
106 |         return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
107 |
108 |     def sample_beam(self, fc_feats, att_feats, opt={}):
109 |         beam_size = opt.get('beam_size', 10)
110 |         batch_size = fc_feats.size(0)
111 |
112 |         # embed fc and att feats
113 |         fc_feats = self.fc_embed(fc_feats)
114 |         _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size))
115 |         att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,)))
116 |
117 |         # Project the attention feats first to reduce memory and computation consumption.
118 |         p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size))
119 |         p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,)))
120 |
121 |         assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
122 |         seq = torch.LongTensor(self.seq_length, batch_size).zero_()
123 |         seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
124 |         # let's process every image independently for now, for simplicity
125 |
126 |         self.done_beams = [[] for _ in range(batch_size)]
127 |         for k in range(batch_size):
128 |             state = self.init_hidden(beam_size)
129 |             tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1))
130 |             tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous()
131 |             tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous()
132 |
133 |             beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
134 |             beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
135 |             beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
136 |             done_beams = []
137 |             for t in range(self.seq_length + 1):
138 |                 if t == 0: # input the start token
139 |                     it = fc_feats.data.new(beam_size).long().zero_()
140 |                     xt = self.embed(Variable(it, requires_grad=False))
141 |                 else:
142 |                     """perform a beam merge. that is,
143 |                     for every previous beam we now have many new possibilities to branch out,
144 |                     so we need to re-sort our beams to maintain the loop invariant of keeping
145 |                     the top beam_size most likely sequences."""
146 |                     logprobsf = logprobs.float() # let's go to CPU for more efficiency in indexing operations
147 |                     ys, ix = torch.sort(logprobsf, 1, True) # sorted array of logprobs along each previous beam (last true = descending)
148 |                     candidates = []
149 |                     cols = min(beam_size, ys.size(1))
150 |                     rows = beam_size
151 |                     if t == 1: # at first time step only the first beam is active
152 |                         rows = 1
153 |                     for c in range(cols):
154 |                         for q in range(rows):
155 |                             # compute logprob of expanding beam q with word in (sorted) position c
156 |                             local_logprob = ys[q, c]
157 |                             candidate_logprob = beam_logprobs_sum[q] + local_logprob
158 |                             candidates.append({'c': ix.data[q, c], 'q': q, 'p': candidate_logprob.data[0], 'r': local_logprob.data[0]})
159 |                     candidates = sorted(candidates, key=lambda x: -x['p'])
160 |
161 |                     # construct new beams
162 |                     new_state = [_.clone() for _ in state]
163 |                     if t > 1:
164 |                         # we'll need these as reference when we fork beams around
165 |                         beam_seq_prev = beam_seq[:t-1].clone()
166 |                         beam_seq_logprobs_prev = beam_seq_logprobs[:t-1].clone()
167 |                     for vix in range(beam_size):
168 |                         v = candidates[vix]
169 |                         # fork beam index q into index vix
170 |                         if t > 1:
171 |                             beam_seq[:t-1, vix] = beam_seq_prev[:, v['q']]
172 |                             beam_seq_logprobs[:t-1, vix] = beam_seq_logprobs_prev[:, v['q']]
173 |
174 |                         # rearrange recurrent states
175 |                         for state_ix in range(len(new_state)):
176 |                             # copy over state in previous beam q to new beam at vix
177 |                             new_state[state_ix][0, vix] = state[state_ix][0, v['q']] # dimension one is time step
178 |
179 |                         # append new end terminal at the end of this beam
180 |                         beam_seq[t-1, vix] = v['c'] # c'th word is the continuation
181 |                         beam_seq_logprobs[t-1, vix] = v['r'] # the raw logprob here
182 |                         beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
183 |
184 |                         if v['c'] == 0 or t == self.seq_length:
185 |                             # END token special case here, or we reached the end.
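# (Column vix of beam_seq now holds a finished candidate caption whose cumulative
#  log-probability is beam_logprobs_sum[vix]; it is stored in done_beams below, and the
#  highest-scoring beam is picked once the time loop finishes.)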
186 | # add the beam to a set of done beams 187 | self.done_beams[k].append({'seq': beam_seq[:, vix].clone(), 188 | 'logps': beam_seq_logprobs[:, vix].clone(), 189 | 'p': beam_logprobs_sum[vix] 190 | }) 191 | 192 | # encode as vectors 193 | it = beam_seq[t-1] 194 | xt = self.embed(Variable(it.cuda())) 195 | 196 | if t >= 1: 197 | state = new_state 198 | 199 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 200 | logprobs = F.log_softmax(self.logit(output)) 201 | 202 | self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p']) 203 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 204 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 205 | # return the samples and their log likelihoods 206 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 207 | 208 | def sample(self, fc_feats, att_feats, opt={}): 209 | sample_max = opt.get('sample_max', 1) 210 | beam_size = opt.get('beam_size', 1) 211 | temperature = opt.get('temperature', 1.0) 212 | if beam_size > 1: 213 | return self.sample_beam(fc_feats, att_feats, opt) 214 | 215 | batch_size = fc_feats.size(0) 216 | state = self.init_hidden(batch_size) 217 | 218 | # embed fc and att feats 219 | fc_feats = self.fc_embed(fc_feats) 220 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 221 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 222 | 223 | # Project the attention feats first to reduce memory and computation comsumptions. 224 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 225 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 226 | 227 | seq = [] 228 | seqLogprobs = [] 229 | for t in range(self.seq_length + 1): 230 | if t == 0: # input 231 | it = fc_feats.data.new(batch_size).long().zero_() 232 | elif sample_max: 233 | sampleLogprobs, it = torch.max(logprobs.data, 1) 234 | it = it.view(-1).long() 235 | else: 236 | if temperature == 1.0: 237 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 238 | else: 239 | # scale logprobs by temperature 240 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 241 | it = torch.multinomial(prob_prev, 1).cuda() 242 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 243 | it = it.view(-1).long() # and flatten indices for downstream processing 244 | 245 | xt = self.embed(Variable(it, requires_grad=False)) 246 | 247 | if t >= 1: 248 | # stop when all finished 249 | if t == 1: 250 | unfinished = it > 0 251 | else: 252 | unfinished = unfinished * (it > 0) 253 | if unfinished.sum() == 0: 254 | break 255 | it = it * unfinished.type_as(it) 256 | seq.append(it) #seq[t] the input of t+2 time step 257 | 258 | seqLogprobs.append(sampleLogprobs.view(-1)) 259 | 260 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 261 | logprobs = F.log_softmax(self.logit(output)) 262 | 263 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 264 | 265 | 266 | class TopDownCore(nn.Module): 267 | def __init__(self, opt, use_maxout=False): 268 | super(TopDownCore, self).__init__() 269 | self.drop_prob_lm = opt.drop_prob_lm 270 | 271 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 272 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 273 | self.attention = Attention(opt) 274 | 275 | def 
forward(self, xt, fc_feats, att_feats, p_att_feats, state): 276 | prev_h = state[0][-1] 277 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) 278 | 279 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) 280 | 281 | att = self.attention(h_att, att_feats, p_att_feats) 282 | 283 | lang_lstm_input = torch.cat([att, h_att], 1) 284 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 285 | 286 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 287 | 288 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 289 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 290 | 291 | return output, state 292 | 293 | class Attention(nn.Module): 294 | def __init__(self, opt): 295 | super(Attention, self).__init__() 296 | self.rnn_size = opt.rnn_size 297 | self.att_hid_size = opt.att_hid_size 298 | 299 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 300 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 301 | 302 | def forward(self, h, att_feats, p_att_feats): 303 | # The p_att_feats here is already projected 304 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 305 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 306 | 307 | att_h = self.h2att(h) # batch * att_hid_size 308 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 309 | dot = att + att_h # batch * att_size * att_hid_size 310 | dot = F.tanh(dot) # batch * att_size * att_hid_size 311 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 312 | dot = self.alpha_net(dot) # (batch * att_size) * 1 313 | dot = dot.view(-1, att_size) # batch * att_size 314 | 315 | weight = F.softmax(dot) # batch * att_size 316 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 317 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 318 | 319 | return att_res 320 | 321 | class TopDownModel(AttModel): 322 | def __init__(self, opt): 323 | super(TopDownModel, self).__init__(opt) 324 | self.num_layers = 2 325 | self.core = TopDownCore(opt) 326 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import misc.utils as utils 10 | import torch 11 | 12 | from .ShowTellModel import ShowTellModel 13 | from .FCModel import FCModel 14 | #from .OldModel import ShowAttendTellModel, AllImgModel 15 | from .Att2inModel import Att2inModel 16 | from .AttModel import * 17 | 18 | def setup(opt): 19 | 20 | if opt.caption_model == 'fc': 21 | model = FCModel(opt) 22 | # Att2in model in self-critical 23 | elif opt.caption_model == 'att2in': 24 | model = Att2inModel(opt) 25 | # Att2in model with two-layer MLP img embedding and word embedding 26 | elif opt.caption_model == 'att2in2': 27 | model = Att2in2Model(opt) 28 | # Adaptive Attention model from Knowing when to look 29 | elif opt.caption_model == 'adaatt': 30 | model = AdaAttModel(opt) 31 | # Adaptive Attention with maxout lstm 32 | elif opt.caption_model == 'adaattmo': 33 | model = AdaAttMOModel(opt) 34 | # Top-down attention model 35 | elif opt.caption_model == 'topdown': 36 | model = TopDownModel(opt) 37 | else: 38 | raise 
Exception("Caption model not supported: {}".format(opt.caption_model)) 39 | 40 | # check compatibility if training is continued from previously saved model 41 | if vars(opt).get('start_from', None) is not None: 42 | # check if all necessary files exist 43 | assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from 44 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from 45 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) 46 | 47 | return model -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_opt(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--input_json', type=str, default='data/cocotalk.json', 7 | help='path to the json file containing additional info and vocab') 8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', 9 | help='path to the directory containing the preprocessed fc feats') 10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', 11 | help='path to the directory containing the preprocessed att feats') 12 | parser.add_argument('--input_label_h5', type=str, default='data/cocotalk_label.h5', 13 | # help='path to the h5file containing the preprocessed dataset') 14 | help='path to the h5file containing the preprocessed label') 15 | parser.add_argument('--input_image_h5', type=str, default='data/coco_image.h5', 16 | help='path to the h5file containing the preprocessed image') 17 | parser.add_argument('--cnn_model', type=str, default='resnet101', 18 | help='resnet') 19 | parser.add_argument('--cnn_weight', type=str, default='resnet101.pth', 20 | help='path to CNN tf model. Note this MUST be a resnet right now.') 21 | parser.add_argument('--start_from', type=str, default='save/', 22 | help="""continue training from saved model at this path. Path must contain files saved by previous training process: 23 | 'infos.pkl' : configuration; 24 | 'checkpoint' : paths to model file(s) (created by tf). 
25 |                     Note: this file contains absolute paths, be careful when moving files around;
26 |                     'model.ckpt-*' : file(s) with model definition (created by tf)
27 |                     """)
28 |     # Model settings
29 |     parser.add_argument('--caption_model', type=str, default="topdown",
30 |                         help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, adaatt, adaattmo, topdown')
31 |     parser.add_argument('--rnn_size', type=int, default=512,
32 |                         help='size of the rnn in number of hidden nodes in each layer')
33 |     parser.add_argument('--num_layers', type=int, default=1,
34 |                         help='number of layers in the RNN')
35 |     parser.add_argument('--rnn_type', type=str, default='lstm',
36 |                         help='rnn, gru, or lstm')
37 |     parser.add_argument('--input_encoding_size', type=int, default=512,
38 |                         help='the encoding size of each token in the vocabulary, and the image.')
39 |     parser.add_argument('--att_hid_size', type=int, default=512,
40 |                         help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
41 |     parser.add_argument('--fc_feat_size', type=int, default=2048,
42 |                         help='2048 for resnet, 4096 for vgg')
43 |     parser.add_argument('--att_feat_size', type=int, default=2048,
44 |                         help='2048 for resnet, 512 for vgg')
45 |
46 |     # Optimization: General
47 |     parser.add_argument('--max_epochs', type=int, default=10,
48 |                         help='number of epochs')
49 |     parser.add_argument('--batch_size', type=int, default=16,
50 |                         help='minibatch size')
51 |     parser.add_argument('--grad_clip', type=float, default=0.1, #5.,
52 |                         help='clip gradients at this value')
53 |     parser.add_argument('--drop_prob_lm', type=float, default=0.5,
54 |                         help='strength of dropout in the Language Model RNN')
55 |     parser.add_argument('--self_critical_after', type=int, default=-1,
56 |                         help='After what epoch do we start self-critical training? (-1 = disable; never, 0 = from the start)')
57 |     parser.add_argument('--finetune_cnn_after', type=int, default=-1,
58 |                         help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
59 |     parser.add_argument('--seq_per_img', type=int, default=5,
60 |                         help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
61 |     parser.add_argument('--beam_size', type=int, default=1,
62 |                         help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
63 |
64 |     # Optimization: for the Language Model
65 |     parser.add_argument('--optim', type=str, default='adam',
66 |                         help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
67 |     parser.add_argument('--learning_rate', type=float, default=5e-5, #4e-4,
68 |                         help='learning rate')
69 |     parser.add_argument('--learning_rate_decay_start', type=int, default=-1, #-1,
70 |                         help='at what iteration to start decaying the learning rate? (-1 = dont) (in epoch)')
71 |     parser.add_argument('--learning_rate_decay_every', type=int, default=3,
72 |                         help='every how many iterations thereafter to drop LR? (in epoch)')
73 |     parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
74 |                         help='factor by which the learning rate is multiplied at each decay step')
75 |     # parser.add_argument('--optim_alpha', type=float, default=0.9,
76 |     #                     help='alpha for adam')
77 |     parser.add_argument('--optim_alpha', type=float, default=0.8,
78 |                         help='alpha for adam')
79 |     parser.add_argument('--optim_beta', type=float, default=0.999,
80 |                         help='beta used for adam')
81 |     parser.add_argument('--optim_epsilon', type=float, default=1e-8,
82 |                         help='epsilon that goes into denominator for smoothing')
83 |     # parser.add_argument('--weight_decay', type=float, default=0,
84 |     #                     help='weight_decay')
85 |
86 |     # Optimization: for the CNN
87 |     parser.add_argument('--cnn_optim', type=str, default='adam',
88 |                         help='optimization to use for CNN')
89 |     parser.add_argument('--cnn_optim_alpha', type=float, default=0.8,
90 |                         help='alpha for momentum of CNN')
91 |     parser.add_argument('--cnn_optim_beta', type=float, default=0.999,
92 |                         help='beta for momentum of CNN')
93 |     parser.add_argument('--cnn_learning_rate', type=float, default=1e-5,
94 |                         help='learning rate for the CNN')
95 |     parser.add_argument('--cnn_weight_decay', type=float, default=0,
96 |                         help='L2 weight decay just for the CNN')
97 |
98 |     parser.add_argument('--scheduled_sampling_start', type=int, default=-1, #-1
99 |                         help='at what iteration to start decaying the ground-truth feeding probability (scheduled sampling)')
100 |     parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
101 |                         help='every how many iterations thereafter to increase the scheduled sampling probability')
102 |     parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
103 |                         help='How much to increase the scheduled sampling probability each time')
104 |     parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
105 |                         help='Maximum scheduled sampling prob.')
106 |
107 |
108 |     # Evaluation/Checkpointing
109 |     parser.add_argument('--val_images_use', type=int, default=5000, #3200
110 |                         help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
111 |     parser.add_argument('--save_checkpoint_every', type=int, default=2500,
112 |                         help='how often to save a model checkpoint (in iterations)?')
113 |     parser.add_argument('--checkpoint_path', type=str, default='save',
114 |                         help='directory to store checkpointed models')
115 |     parser.add_argument('--language_eval', type=int, default=1,
116 |                         help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
117 |     parser.add_argument('--losses_log_every', type=int, default=25,
118 |                         help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
119 |     parser.add_argument('--load_best_score', type=int, default=1,
120 |                         help='Do we load previous best score when resuming training.')
121 |
122 |     # misc
123 |     parser.add_argument('--id', type=str, default='topdown_rl',
124 |                         help='an id identifying this run/job. used in cross-val and appended when writing progress files')
125 |     parser.add_argument('--train_only', type=int, default=0,
126 |                         help='if true then use 80k, else use 110k')
127 |
128 |     args = parser.parse_args()
129 |
130 |     # Check if args are valid
131 |     assert args.rnn_size > 0, "rnn_size should be greater than 0"
132 |     assert args.num_layers > 0, "num_layers should be greater than 0"
133 |     assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
134 |     assert args.batch_size > 0, "batch_size should be greater than 0"
135 |     assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
136 |     assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
137 |     assert args.beam_size > 0, "beam_size should be greater than 0"
138 |     assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
139 |     assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
140 |     assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
141 |     assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
142 |     assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"
143 |
144 |     return args
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andyweizhao/Multitask_Image_Captioning/c672fe480618fccce1239600a394cf62f0b32719/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/prepro_feats.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
3 |
4 | Input: json file that has the form
5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...]
6 | example element in this list would look like
7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}
8 |
9 | This script reads this json, does some basic preprocessing on the captions
10 | (e.g.
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | from six.moves import cPickle 38 | import numpy as np 39 | import torch 40 | #import torchvision.models as models 41 | from torch.autograd import Variable 42 | import skimage.io 43 | #from torchvision import transforms as trn 44 | #preprocess = trn.Compose([ 45 | # #trn.ToTensor(), 46 | # trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 47 | #]) 48 | from misc.resnet_utils import myResnet 49 | import misc.resnet as resnet 50 | from scipy.misc import imread, imresize 51 | 52 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 53 | 54 | def main(params): 55 | # net = getattr(resnet, params['model'])() 56 | # net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 57 | # my_resnet = myResnet(net) 58 | # my_resnet.cuda() 59 | # my_resnet.eval() 60 | 61 | imgs = json.load(open(params['input_json'], 'r')) 62 | imgs = imgs['images'] 63 | 64 | seed(123) # make reproducible 65 | 66 | # dir_fc = params['output_dir']+'_fc' 67 | # dir_att = params['output_dir']+'_att' 68 | # if not os.path.isdir(dir_fc): 69 | # os.mkdir(dir_fc) 70 | # if not os.path.isdir(dir_att): 71 | # os.mkdir(dir_att) 72 | 73 | # create output h5 file 74 | N = len(imgs) 75 | resize = 512 #256,448 76 | 77 | f = h5py.File(params['output_h5'], "w") 78 | 79 | dset = f.create_dataset("images", (N,3,resize,resize), dtype='uint8') # space for resized images 80 | # img = imgs[1] 81 | dset = [] 82 | # f = h5py.File('image_path2.h5', "w") 83 | # dset = f.create_dataset("images_path", (N,), dtype='str') # space for resized images 84 | 85 | for i,img in enumerate(imgs): 86 | # load the image 87 | # I = imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 88 | # try: 89 | # Ir = imresize(I, (resize,resize)) 90 | # except: 91 | # print('failed resizing image %s - see http://git.io/vBIE0' % (img['file_path'],)) 92 | # raise 93 | # if len(Ir.shape) == 2: 94 | # Ir = Ir[:,:,np.newaxis] 95 | # Ir = np.concatenate((Ir,Ir,Ir), axis=2) 96 | # # and swap order of axes from (256,256,3) to (3,256,256) 97 | # Ir = Ir.transpose(2,0,1) 98 | # # write to h5 99 | # dset[i] = Ir 100 | 101 | I_path = os.path.join(params['images_root'], img['filepath'], img['filename']) 102 | dset.append(I_path) 103 | 104 | if i % 1000 == 0: 105 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 106 | # f.close() 107 | # print('wrote ', params['output_h5']+'_image.h5') 108 | 109 | 
#np.save('image_path.npy', np.array(dset)) 110 | 111 | if __name__ == "__main__": 112 | 113 | parser = argparse.ArgumentParser() 114 | 115 | # input json 116 | parser.add_argument('--input_json', default='/data1/zsfx/wabywang/caption/010/data/dataset_coco.json', help='input json file to process into hdf5') 117 | # parser.add_argument('--output_dir', default='data/cocotalk', help='output h5 file') 118 | parser.add_argument('--output_h5', default='coco_image_512.h5', help='output h5 file') 119 | 120 | # options 121 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 122 | parser.add_argument('--images_root', default='/data1/zsfx/wabywang/caption/dataset/MSCOCO', help='root location in which images are stored, to be prepended to file_path in input json') 123 | # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 124 | # parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 125 | # parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 126 | 127 | args = parser.parse_args() 128 | params = vars(args) # convert to ordinary dict 129 | print('parsed input parameters:') 130 | print(json.dumps(params, indent = 2)) 131 | main(params) 132 | 133 | #train = json.load(open('/nlp/dataset/Caption_CN/annotations/captions_train2017.json', 'r')) 134 | #val = json.load(open('/nlp/dataset/Caption_CN/annotations/captions_val2017.json', 'r')) 135 | # 136 | #imgs1 = [{'cocoid':os.path.splitext(o.get('image_id'))[0],'filename':o.get('image_id'),'filepath':'train2017'} for o in train] 137 | #imgs2 = [{'cocoid':os.path.splitext(o.get('image_id'))[0],'filename':o.get('image_id'),'filepath':'val2017'} for o in val] 138 | #imgs = imgs1 + imgs2 139 | # 140 | #files = [ os.path.splitext(f)[0] for f in os.listdir('/nlp/andyweizhao/self-critical.pytorch_CN/data/cocotalk_fc/')] 141 | #files = set(files) 142 | #imgs = [img for img in imgs if img.get('cocoid') not in files] 143 | #for o in b: 144 | # img_id = o.get('image_id') 145 | # if img_id not in image_list: 146 | # entry = {'image_id': img_id, 'caption': o.get('caption')} 147 | # image_list.append(img_id) 148 | # predictions.append(entry) 149 | #json.dump(predictions, open('vis/val.json', 'w')) 150 | -------------------------------------------------------------------------------- /scripts/prepro_feats_coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | from six.moves import cPickle 38 | import numpy as np 39 | import torch 40 | #import torchvision.models as models 41 | from torch.autograd import Variable 42 | import skimage.io 43 | 44 | from torchvision import transforms as trn 45 | preprocess = trn.Compose([ 46 | #trn.ToTensor(), 47 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 48 | ]) 49 | 50 | from misc.resnet_utils import myResnet 51 | import misc.resnet as resnet 52 | 53 | 54 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 55 | 56 | 57 | 58 | 59 | def main(params): 60 | net = getattr(resnet, params['model'])() 61 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 62 | my_resnet = myResnet(net) 63 | my_resnet.cuda() 64 | my_resnet.eval() 65 | 66 | imgs = json.load(open(params['input_json'], 'r')) 67 | imgs = imgs['images'] 68 | N = len(imgs) 69 | 70 | seed(123) # make reproducible 71 | 72 | dir_fc = params['output_dir']+'_fc' 73 | dir_att = params['output_dir']+'_att' 74 | if not os.path.isdir(dir_fc): 75 | os.mkdir(dir_fc) 76 | if not os.path.isdir(dir_att): 77 | os.mkdir(dir_att) 78 | 79 | for i,img in enumerate(imgs): 80 | # load the image 81 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 82 | # handle grayscale input images 83 | if len(I.shape) == 2: 84 | I = I[:,:,np.newaxis] 85 | I = np.concatenate((I,I,I), axis=2) 86 | 87 | I = I.astype('float32')/255.0 88 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 89 | I = Variable(preprocess(I), volatile=True) 90 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 91 | # write to pkl 92 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 93 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 94 | 95 | if i % 1000 == 0: 96 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 97 | print('wrote ', params['output_dir']) 98 | 99 | if __name__ == "__main__": 100 | 101 | parser = argparse.ArgumentParser() 102 | 103 | # input json 104 | parser.add_argument('--input_json', default='data/dataset_coco_ccg.json', help='input json file to process into hdf5') 105 | parser.add_argument('--output_dir', default='data/cocotalk', help='output h5 file') 106 | 107 | # options 108 | parser.add_argument('--images_root', default='../dataset/MSCOCO', help='root location in which images 
are stored, to be prepended to file_path in input json') 109 | #parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') 110 | #parser.add_argument('--output_dir', default='data/cocotalk', help='output h5 file') 111 | 112 | # options 113 | #parser.add_argument('--images_root', default='/nlp/dataset/MSCOCO', help='root location in which images are stored, to be prepended to file_path in input json') 114 | 115 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 116 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 117 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 118 | 119 | args = parser.parse_args() 120 | params = vars(args) # convert to ordinary dict 121 | print('parsed input parameters:') 122 | print(json.dumps(params, indent = 2)) 123 | main(params) 124 | -------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
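
For illustration only (paths and the output prefix here are placeholders), the files written by
this script could later be read back roughly like this:

    import json, h5py
    info = json.load(open('data/cocotalk.json'))      # has 'ix_to_word' and 'images'
    h5 = h5py.File('data/cocotalk_label.h5', 'r')     # 'labels', 'label_start_ix', 'label_end_ix', 'label_length'
    start, end = h5['label_start_ix'][0], h5['label_end_ix'][0]   # 1-indexed, inclusive
    caps = h5['labels'][start - 1:end]                # captions of the first image
    words = [info['ix_to_word'][str(ix)] for ix in caps[0] if ix > 0]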
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | import numpy as np 38 | import torch 39 | #import torchvision.models as models 40 | from torch.autograd import Variable 41 | import skimage.io 42 | 43 | #python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk 44 | 45 | def build_vocab(imgs, params): 46 | count_thr = params['word_count_threshold'] 47 | 48 | # count up the number of words 49 | counts = {} 50 | for img in imgs: 51 | for sent in img['sentences']: 52 | for w in sent['tokens']: 53 | counts[w] = counts.get(w, 0) + 1 54 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 55 | print('top words and their counts:') 56 | print('\n'.join(map(str,cw[:20]))) 57 | 58 | # print some stats 59 | total_words = sum(counts.values()) 60 | print('total words:', total_words) 61 | bad_words = [w for w,n in counts.items() if n <= count_thr] 62 | vocab = [w for w,n in counts.items() if n > count_thr] 63 | bad_count = sum(counts[w] for w in bad_words) 64 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 65 | print('number of words in vocab would be %d' % (len(vocab), )) 66 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 67 | 68 | # lets look at the distribution of lengths as well 69 | sent_lengths = {} 70 | for img in imgs: 71 | for sent in img['sentences']: 72 | txt = sent['tokens'] 73 | nw = len(txt) 74 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 75 | max_len = max(sent_lengths.keys()) 76 | print('max length sentence in raw data: ', max_len) 77 | print('sentence length distribution (count, number of words):') 78 | sum_len = sum(sent_lengths.values()) 79 | for i in range(max_len+1): 80 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 81 | 82 | # lets now produce the final annotations 83 | if bad_count > 0: 84 | # additional special UNK token we will use below to map infrequent words to 85 | print('inserting the special UNK token') 86 | vocab.append('UNK') 87 | 88 | for img in imgs: 89 | img['final_captions'] = [] 90 | for sent in img['sentences']: 91 | txt = sent['tokens'] 92 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 93 | img['final_captions'].append(caption) 94 | 95 | return vocab 96 | 97 | def build_en_vocab(imgs, params): 98 | count_thr = params['word_count_threshold'] 99 | 100 | # count up the number of words 101 | counts = {} 102 | for img in imgs: 103 | for sent in img['sentences']: 104 | for w in sent['coco_en'].lower().split(): 105 | counts[w] = counts.get(w, 0) + 1 106 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 107 | print('top en words and their counts:') 108 | print('\n'.join(map(str,cw[:20]))) 109 | 110 | # print some stats 111 | total_words = sum(counts.values()) 112 | print('total words:', total_words) 113 | bad_words = [w for w,n in counts.items() if n <= count_thr] 114 | vocab = [w for w,n in counts.items() if n > count_thr] 115 | bad_count = sum(counts[w] for w in bad_words) 116 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 117 | print('number of words in vocab 
would be %d' % (len(vocab), )) 118 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 119 | 120 | # lets look at the distribution of lengths as well 121 | sent_lengths = {} 122 | for img in imgs: 123 | for sent in img['sentences']: 124 | txt = sent['coco_en'].lower().split() 125 | nw = len(txt) 126 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 127 | max_len = max(sent_lengths.keys()) 128 | print('max length sentence in raw data: ', max_len) 129 | print('sentence length distribution (count, number of words):') 130 | sum_len = sum(sent_lengths.values()) 131 | for i in range(max_len+1): 132 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 133 | 134 | # lets now produce the final annotations 135 | if bad_count > 0: 136 | # additional special UNK token we will use below to map infrequent words to 137 | print('inserting the special UNK token') 138 | vocab.append('UNK') 139 | 140 | for img in imgs: 141 | img['en_captions'] = [] 142 | for sent in img['sentences']: 143 | txt =sent['coco_en'].lower().split() 144 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 145 | img['en_captions'].append(caption) 146 | 147 | return vocab 148 | 149 | def encode_captions(imgs, params, wtoi): 150 | """ 151 | encode all captions into one large array, which will be 1-indexed. 152 | also produces label_start_ix and label_end_ix which store 1-indexed 153 | and inclusive (Lua-style) pointers to the first and last caption for 154 | each image in the dataset. 155 | """ 156 | 157 | max_length = params['max_length'] 158 | N = len(imgs) 159 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 160 | 161 | label_arrays = [] 162 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 163 | label_end_ix = np.zeros(N, dtype='uint32') 164 | label_length = np.zeros(M, dtype='uint32') 165 | caption_counter = 0 166 | counter = 1 167 | for i,img in enumerate(imgs): 168 | n = len(img['final_captions']) 169 | assert n > 0, 'error: some image has no captions' 170 | 171 | Li = np.zeros((n, max_length), dtype='uint32') 172 | for j,s in enumerate(img['final_captions']): 173 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 174 | caption_counter += 1 175 | for k,w in enumerate(s): 176 | if k < max_length: 177 | Li[j,k] = wtoi[w] 178 | 179 | # note: word indices are 1-indexed, and captions are padded with zeros 180 | label_arrays.append(Li) 181 | label_start_ix[i] = counter 182 | label_end_ix[i] = counter + n - 1 183 | 184 | counter += n 185 | print(np.sum(label_length==0)) 186 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 187 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 188 | # assert np.all(label_length > 0), 'error: some caption had no words?' 
# there are 15 sentence without any words 189 | 190 | print('encoded captions to array of size ', L.shape) 191 | return L, label_start_ix, label_end_ix, label_length 192 | 193 | def main(params): 194 | 195 | imgs = json.load(open(params['input_json'], 'r')) 196 | imgs = imgs['images'] 197 | 198 | seed(123) # make reproducible 199 | 200 | # create the vocab 201 | vocab = build_vocab(imgs, params) 202 | en_vocab = build_en_vocab(imgs, params) 203 | import cPickle as pickle 204 | pickle.dump(vocab,open("vocab.pkl","w")) 205 | pickle.dump(en_vocab,open("en_vocab.pkl","w")) 206 | 207 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 208 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 209 | 210 | # encode captions in large arrays, ready to ship to hdf5 file 211 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 212 | 213 | # create output h5 file 214 | N = len(imgs) 215 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 216 | f_lb.create_dataset("labels", dtype='uint32', data=L) 217 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 218 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 219 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 220 | f_lb.close() 221 | 222 | # create output json file 223 | out = {} 224 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 225 | out['images'] = [] 226 | for i,img in enumerate(imgs): 227 | 228 | jimg = {} 229 | jimg['split'] = img['split'] 230 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need 231 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. coco ids, useful) 232 | 233 | out['images'].append(jimg) 234 | 235 | json.dump(out, open(params['output_json'], 'w')) 236 | print('wrote ', params['output_json']) 237 | 238 | if __name__ == "__main__": 239 | 240 | # dataset="/nlp/andyweizhao/self-critical.pytorch_CN/data1/" 241 | dataset="/home/andyweizhao/wabywang/010/data/dataset/" 242 | # dataset="coco" 243 | parser = argparse.ArgumentParser() 244 | 245 | # input json 246 | # parser.add_argument('--input_json', default=dataset+'dataset_coco.json', help='input json file to process into hdf5') 247 | parser.add_argument('--input_json', default=dataset+'dataset_coco_zh.json', help='input json file to process into hdf5') 248 | # parser.add_argument('--input_json', default=dataset+'dataset_coco_zh.json', help='input json file to process into hdf5') 249 | parser.add_argument('--output_json', default=dataset+'tmp/aitalk.json', help='output json file') 250 | parser.add_argument('--output_h5', default=dataset+'tmp/aitalk', help='output h5 file') 251 | 252 | # options 253 | parser.add_argument('--max_length', default=21, type=int, help='max length of a caption, in number of words. 
captions longer than this get clipped.') 254 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 255 | 256 | args = parser.parse_args() 257 | params = vars(args) # convert to ordinary dict 258 | print('parsed input parameters:') 259 | print(json.dumps(params, indent = 2)) 260 | main(params) 261 | 262 | #predictions = [] 263 | #image_list = [] 264 | #a = json.load(open('/nlp/dataset/MSCOCO/annotations/captions_val2014.json', 'r')) 265 | #b = a.get('annotations') 266 | # 267 | #for o in b: 268 | # img_id = o.get('image_id') 269 | # if img_id not in image_list: 270 | # entry = {'image_id': img_id, 'caption': o.get('caption')} 271 | # image_list.append(img_id) 272 | # predictions.append(entry) 273 | #json.dump(predictions, open('vis/val.json', 'w')) -------------------------------------------------------------------------------- /scripts/prepro_labels_detection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import json 7 | import argparse 8 | from random import shuffle, seed 9 | import string 10 | # non-standard dependencies: 11 | import h5py 12 | import numpy as np 13 | import torch 14 | import torchvision.models as models 15 | from torch.autograd import Variable 16 | import skimage.io 17 | import pprint as pp 18 | from shutil import copy2 19 | 20 | images_root='/home/dl/dataset/MSCOCO' 21 | karpathy_annFile='data/dataset_coco.json' 22 | train_annFile='/nlp/dataset/MSCOCO/annotations/instances_train2014.json' 23 | val_annFile='/nlp/dataset/MSCOCO/annotations/instances_val2014.json' 24 | 25 | karpathy_dataset = json.load(open(karpathy_annFile,'r')) 26 | #train_imgs = json.load(open(train_annFile,'r')) 27 | #train_imgs['type']='instances' 28 | #json.dump(train_imgs,open('data/instances_train2014_new.json',"w")) 29 | # 30 | #val_imgs = json.load(open(val_annFile,'r')) 31 | #val_imgs['type']='instances' 32 | #json.dump(val_imgs,open('data/instances_val2014_new.json',"w")) 33 | 34 | 35 | 36 | #Copy images to corresponding folders 37 | ''' 38 | for x in xrange(len(karpathy_imgs['images'])): 39 | img = karpathy_imgs['images'][x] 40 | file_path = os.path.join(images_root,img['filepath'],img['filename']) 41 | dest_path = os.path.join(images_root,'karpathy',img['split']) 42 | copy2(file_path,dest_path) 43 | ''' 44 | 45 | import torchvision.datasets as datasets 46 | import sys 47 | sys.path.append("coco-caption") 48 | from pycocotools.coco import COCO 49 | 50 | trainset = datasets.CocoDetection(root='/nlp/dataset/MSCOCO/train2014', 51 | annFile='/nlp/dataset/MSCOCO/annotations/instances_train2014_new.json') 52 | 53 | valset = datasets.CocoDetection(root='/nlp/dataset/MSCOCO/val2014', 54 | annFile='/nlp/dataset/MSCOCO/annotations/instances_val2014_new.json') 55 | 56 | 57 | info = json.load(open('data/meta_coco_en.json')) 58 | ix_to_word = info['ix_to_word'] 59 | word_to_ix = {} 60 | for ix,w in ix_to_word.items(): 61 | word_to_ix[w]=ix 62 | 63 | category_table = train_imgs['categories'] 64 | category_dict = {} 65 | for i in xrange(len(category_table)): 66 | cat_id = category_table[i]['id'] 67 | cat_name = category_table[i]['name'] 68 | super_cat_name = category_table[i]['supercategory'] 69 | label_id = i+1 70 | category_dict[cat_id] = {'label_id':i+1,'cat_name':cat_name,'super_cat_name':super_cat_name} 71 | 72 | pre_train = {} 73 
| pre_val = {} 74 | for x in xrange(len(trainset)): 75 | img,target = trainset[x] 76 | if len(target) < 1 : continue 77 | image_id = target[0]['image_id'] 78 | label_list = [category_dict.get(target[i]['category_id'])['label_id'] 79 | for i in xrange(len(target))] 80 | words,w_c = np.unique([word_to_ix.get(category_dict.get(target[i]['category_id'])['cat_name'],-1) 81 | for i in xrange(len(target))], return_counts=True) 82 | w_c_dict = dict(zip(words.astype(int), w_c)) 83 | w_c_dict.pop(-1, None) 84 | words = words.astype(int) 85 | words = words[words > 0] 86 | 87 | super_words,sw_c = np.unique([word_to_ix.get(category_dict.get(target[i]['category_id'])['super_cat_name'],-1) 88 | for i in xrange(len(target))], return_counts=True) 89 | sw_c_dict = dict(zip(super_words.astype(int), sw_c)) 90 | sw_c_dict.pop(-1, None) 91 | super_words = super_words.astype(int) 92 | super_words = super_words[super_words > 0] 93 | 94 | l, c = np.unique(label_list, return_counts=True) 95 | l_c_dict = dict(zip(l, c)) 96 | l_c_dict[0] = 1 97 | 98 | label = np.zeros(80 + 1) # 0 for background 99 | label[np.array(label_list,dtype=np.int16)] = 1 100 | label[0] = 1 # 0 for background 101 | l_import = np.array( [l_c_dict.get(i,0)/(len(label_list) + 1) for i in range(len(label))]) 102 | pre_train[image_id] = {'label':label,'words':words, 103 | 'super_words':super_words, 104 | 'l_import':l_import, 105 | 'w_import':w_c_dict, 106 | 'sw_import':sw_c_dict} 107 | 108 | for x in xrange(len(valset)): 109 | img,target = valset[x] 110 | if len(target) < 1 : continue 111 | image_id = target[0]['image_id'] 112 | label_list = [category_dict.get(target[i]['category_id'])['label_id'] 113 | for i in xrange(len(target))] 114 | words,w_c = np.unique([word_to_ix.get(category_dict.get(target[i]['category_id'])['cat_name'],-1) 115 | for i in xrange(len(target))], return_counts=True) 116 | w_c_dict = dict(zip(words.astype(int), w_c)) 117 | w_c_dict.pop(-1, None) 118 | words = words.astype(int) 119 | words = words[words > 0] 120 | 121 | super_words,sw_c = np.unique([word_to_ix.get(category_dict.get(target[i]['category_id'])['super_cat_name'],-1) 122 | for i in xrange(len(target))], return_counts=True) 123 | sw_c_dict = dict(zip(super_words.astype(int), sw_c)) 124 | sw_c_dict.pop(-1, None) 125 | super_words = super_words.astype(int) 126 | super_words = super_words[super_words > 0] 127 | 128 | l, c = np.unique(label_list, return_counts=True) 129 | l_c_dict = dict(zip(l, c)) 130 | l_c_dict[0] = 1 131 | 132 | label = np.zeros(80 + 1) # 0 for background 133 | label[np.array(label_list,dtype=np.int16)] = 1 134 | label[0] = 1 # 0 for background 135 | l_import = np.array( [l_c_dict.get(i,0)/(len(label_list) + 1) for i in range(len(label))]) 136 | pre_val[image_id] = {'label':label,'words':words, 137 | 'super_words':super_words, 138 | 'l_import':l_import, 139 | 'w_import':w_c_dict, 140 | 'sw_import':sw_c_dict} 141 | 142 | detection_dataset = dict(pre_train.items() + pre_val.items()) 143 | 144 | from six.moves import cPickle 145 | cPickle.dump(pre_train, open("data/detection_train.json", "w")) 146 | cPickle.dump(pre_val, open("data/detection_val.json", "w")) 147 | cPickle.dump(detection_dataset, open("data/detection_all.json", "w")) 148 | 149 | n_label_cnt = 0 150 | for x in xrange(len(karpathy_dataset['images'])): 151 | img_id = karpathy_dataset['images'][x]['cocoid'] 152 | if (detection_dataset.has_key(img_id)): 153 | label = list(detection_dataset[img_id]['label'].astype(int)) 154 | l_import = list(detection_dataset[img_id]['l_import']) 155 | 
super_words = list(detection_dataset[img_id]['super_words']) 156 | sw_import = detection_dataset[img_id]['sw_import'] 157 | w_import = detection_dataset[img_id]['w_import'] 158 | words = list(detection_dataset[img_id]['words']) 159 | karpathy_dataset['images'][x]['label'] = label 160 | karpathy_dataset['images'][x]['l_import'] = l_import 161 | karpathy_dataset['images'][x]['super_words'] = super_words 162 | karpathy_dataset['images'][x]['sw_import'] = sw_import 163 | karpathy_dataset['images'][x]['w_import'] = w_import 164 | karpathy_dataset['images'][x]['words'] = words 165 | else: 166 | label = list(np.zeros(81)) 167 | n_label_cnt += 1 168 | karpathy_dataset['images'][x]['label'] = label 169 | 170 | 171 | coco_label_dataset = 'data/dataset_coco_label.json' 172 | with open(coco_label_dataset,'w') as f: 173 | json.dump(karpathy_dataset['images'],f) 174 | 175 | train_img_list = [] 176 | val_img_list = [] 177 | test_img_list = [] 178 | imgs = karpathy_dataset['images'] 179 | 180 | for x in xrange(len(imgs)): 181 | if imgs[x]['split'] == 'train': 182 | filename = imgs[x]['filename'] 183 | path = os.path.join(images_root,'karpathy','train',filename) 184 | label = imgs[x]['label'] 185 | item = (path, label) 186 | train_img_list.append(item) 187 | if imgs[x]['split'] == 'restval': 188 | filename = imgs[x]['filename'] 189 | path = os.path.join(images_root,'karpathy','train',filename) 190 | label = imgs[x]['label'] 191 | item = (path, label) 192 | train_img_list.append(item) 193 | if imgs[x]['split'] == 'val': 194 | filename = imgs[x]['filename'] 195 | path = os.path.join(images_root,'karpathy','val',filename) 196 | label = imgs[x]['label'] 197 | item = (path, label) 198 | val_img_list.append(item) 199 | if imgs[x]['split'] == 'test': 200 | filename = imgs[x]['filename'] 201 | path = os.path.join(images_root,'karpathy','test',filename) 202 | label = imgs[x]['label'] 203 | item = (path, label) 204 | test_img_list.append(item) 205 | 206 | train_json = os.path.join(images_root,'karpathy','train/data.json') 207 | val_json = os.path.join(images_root,'karpathy','val/data.json') 208 | test_json = os.path.join(images_root,'karpathy','test/data.json') 209 | 210 | -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from six.moves import cPickle 30 | from collections import defaultdict 31 | 32 | def precook(s, n=4, out=False): 33 | """ 34 | Takes a string as input and returns an object that can be given to 35 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 36 | can take string arguments as well. 37 | :param s: string : sentence to be converted into ngrams 38 | :param n: int : number of ngrams for which representation is calculated 39 | :return: term frequency vector for occuring ngrams 40 | """ 41 | words = s.split() 42 | counts = defaultdict(int) 43 | for k in xrange(1,n+1): 44 | for i in xrange(len(words)-k+1): 45 | ngram = tuple(words[i:i+k]) 46 | counts[ngram] += 1 47 | return counts 48 | 49 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 50 | '''Takes a list of reference sentences for a single segment 51 | and returns an object that encapsulates everything that BLEU 52 | needs to know about them. 53 | :param refs: list of string : reference sentences for some image 54 | :param n: int : number of ngrams for which (ngram) representation is calculated 55 | :return: result (list of dict) 56 | ''' 57 | return [precook(ref, n) for ref in refs] 58 | 59 | def create_crefs(refs): 60 | crefs = [] 61 | for ref in refs: 62 | # ref is a list of 5 captions 63 | crefs.append(cook_refs(ref)) 64 | return crefs 65 | 66 | def compute_doc_freq(crefs): 67 | ''' 68 | Compute term frequency for reference data. 
69 | This will be used to compute idf (inverse document frequency later) 70 | The term frequency is stored in the object 71 | :return: None 72 | ''' 73 | document_frequency = defaultdict(float) 74 | for refs in crefs: 75 | # refs, k ref captions of one image 76 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 77 | document_frequency[ngram] += 1 78 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 79 | return document_frequency 80 | 81 | def build_dict(imgs, wtoi, params): 82 | wtoi[''] = 0 83 | 84 | count_imgs = 0 85 | 86 | refs_words = [] 87 | refs_idxs = [] 88 | for img in imgs: 89 | if (params['split'] == img['split']) or \ 90 | (params['split'] == 'train' and img['split'] == 'restval') or \ 91 | (params['split'] == 'all'): 92 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 93 | ref_words = [] 94 | ref_idxs = [] 95 | for sent in img['sentences']: 96 | tmp_tokens = sent['tokens'] + [''] 97 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 98 | ref_words.append(' '.join(tmp_tokens)) 99 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 100 | refs_words.append(ref_words) 101 | refs_idxs.append(ref_idxs) 102 | count_imgs += 1 103 | print('total imgs:', count_imgs) 104 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 105 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 106 | return ngram_words, ngram_idxs, count_imgs 107 | 108 | def main(params): 109 | 110 | imgs = json.load(open(params['input_json'], 'r')) 111 | itow = json.load(open(params['dict_json'], 'r'))['ix_to_word'] 112 | wtoi = {w:i for i,w in itow.items()} 113 | 114 | imgs = imgs['images'] 115 | 116 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 117 | 118 | cPickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 119 | cPickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 120 | 121 | if __name__ == "__main__": 122 | path = '/nlp/andyweizhao/self-critical.pytorch_CN/' 123 | parser = argparse.ArgumentParser() 124 | 125 | parser.add_argument('--input_json', default=path + 'data/dataset/dataset_combined.json', help='input json file to process into hdf5') 126 | parser.add_argument('--dict_json', default=path + 'data/meta/combined_meta.json', help='output json file') 127 | parser.add_argument('--output_pkl', default=path + 'data/combined-train', help='output pickle file') 128 | parser.add_argument('--split', default='train', help='test, val, train, all') 129 | args = parser.parse_args() 130 | params = vars(args) # convert to ordinary dict 131 | 132 | main(params) 133 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | import torch.optim as optim 8 | import numpy as np 9 | 10 | import time 11 | import os 12 | from six.moves import cPickle 13 | 14 | import opts 15 | import models 16 | import torch.nn as nn 17 | import eval_utils 18 | import misc.utils as utils 19 | import torch.nn.functional as F 20 | 21 | from misc.rewards import get_self_critical_reward 22 | 23 | 24 | def train(opt): 25 | opt.use_att = 
utils.if_use_att(opt.caption_model) 26 | 27 | from dataloader import DataLoader 28 | loader = DataLoader(opt) 29 | 30 | opt.vocab_size = loader.vocab_size 31 | opt.vocab_ccg_size = loader.vocab_ccg_size 32 | opt.seq_length = loader.seq_length 33 | 34 | infos = {} 35 | histories = {} 36 | if opt.start_from is not None: 37 | # open old infos and check if models are compatible 38 | with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: 39 | infos = cPickle.load(f) 40 | saved_model_opt = infos['opt'] 41 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] 42 | for checkme in need_be_same: 43 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme 44 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): 45 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: 46 | histories = cPickle.load(f) 47 | 48 | iteration = infos.get('iter', 0) 49 | epoch = infos.get('epoch', 0) 50 | 51 | val_result_history = histories.get('val_result_history', {}) 52 | loss_history = histories.get('loss_history', {}) 53 | lr_history = histories.get('lr_history', {}) 54 | ss_prob_history = histories.get('ss_prob_history', {}) 55 | 56 | loader.iterators = infos.get('iterators', loader.iterators) 57 | loader.split_ix = infos.get('split_ix', loader.split_ix) 58 | 59 | if opt.load_best_score == 1: 60 | best_val_score = infos.get('best_val_score', None) 61 | 62 | cnn_model = utils.build_cnn(opt) 63 | cnn_model.cuda() 64 | 65 | model = models.setup(opt) 66 | model.cuda() 67 | # model = DataParallel(model) 68 | 69 | if vars(opt).get('start_from', None) is not None: 70 | # check if all necessary files exist 71 | assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from 72 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from 73 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) 74 | 75 | update_lr_flag = True 76 | model.train() 77 | 78 | crit = utils.LanguageModelCriterion() 79 | rl_crit = utils.RewardCriterion() 80 | multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda() 81 | # optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 82 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate) 83 | if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: 84 | print('finetune mode') 85 | cnn_optimizer = optim.Adam([\ 86 | {'params': module.parameters()} for module in cnn_model._modules.values()[5:]\ 87 | ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay) 88 | 89 | if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): 90 | if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')): 91 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) 92 | if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: 93 | if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')): 94 | cnn_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-cnn.pth'))) 95 | 96 | eval_kwargs = {'split': 'val','dataset': opt.input_json,'verbose':True} 97 | eval_kwargs.update(vars(opt)) 98 | val_loss, predictions, lang_stats = eval_utils.eval_split(cnn_model, model, crit, 99 | loader, eval_kwargs, True) 100 | epoch_start = time.time() 101 | while True: 102 | if 
update_lr_flag: 103 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 104 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every 105 | decay_factor = opt.learning_rate_decay_rate ** frac 106 | opt.current_lr = opt.learning_rate * decay_factor 107 | utils.set_lr(optimizer, opt.current_lr) # set the decayed rate 108 | else: 109 | opt.current_lr = opt.learning_rate 110 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: 111 | frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every 112 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) 113 | model.ss_prob = opt.ss_prob 114 | #model.module.ss_prob = opt.ss_prob 115 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: 116 | sc_flag = True 117 | else: 118 | sc_flag = False 119 | 120 | # Update the training stage of cnn 121 | for p in cnn_model.parameters(): 122 | p.requires_grad = True 123 | # Fix the first few layers: 124 | for module in cnn_model._modules.values()[:5]: 125 | for p in module.parameters(): 126 | p.requires_grad = False 127 | cnn_model.train() 128 | update_lr_flag = False 129 | 130 | cnn_model.apply(utils.set_bn_fix) 131 | cnn_model.apply(utils.set_bn_eval) 132 | 133 | start = time.time() 134 | torch.cuda.synchronize() 135 | data = loader.get_batch('train') 136 | if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: 137 | 138 | multilabels = [data['detection_infos'][i]['label'] for i in range(len(data['detection_infos']))] 139 | 140 | tmp = [data['labels'], data['masks'],np.array(multilabels,dtype=np.int16)] 141 | tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp] 142 | labels, masks, multilabels = tmp 143 | images = data['images'] # it cannot be turned into tensor since different sizes. 
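# Per-image CNN forward pass (the raw images in data['images'] can have different
# spatial sizes, so they cannot be batched into one tensor): cnn_model returns a
# spatial feature map (att_feats) plus the multi-label prediction fc_feats_81.
# The 2048-d fc feature is the spatial mean of att_feats, and att_feats is pooled
# to a fixed 14x14 grid before stacking. fc_feats_2048 and att_feats are then
# repeated seq_per_img times so every caption of an image sees the same visual
# features, while fc_feats_81 stays per-image for the multi-label loss.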
144 | _fc_feats_2048 = [] 145 | _fc_feats_81 = [] 146 | _att_feats = [] 147 | for i in range(loader.batch_size): 148 | x = Variable(torch.from_numpy(images[i]), requires_grad=False).cuda() 149 | x = x.unsqueeze(0) 150 | att_feats, fc_feats_81 = cnn_model(x) 151 | fc_feats_2048 = att_feats.mean(3).mean(2).squeeze() 152 | att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(1, 2, 0)#(0, 2, 3, 1) 153 | _fc_feats_2048.append(fc_feats_2048) 154 | _fc_feats_81.append(fc_feats_81) 155 | _att_feats.append(att_feats) 156 | _fc_feats_2048 = torch.stack(_fc_feats_2048) 157 | _fc_feats_81 = torch.stack(_fc_feats_81) 158 | _att_feats = torch.stack(_att_feats) 159 | att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \ 160 | _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \ 161 | _att_feats.size()[1:])) 162 | fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \ 163 | _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \ 164 | _fc_feats_2048.size()[1:])) 165 | fc_feats_81 = _fc_feats_81 166 | # 167 | cnn_optimizer.zero_grad() 168 | else: 169 | 170 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] 171 | tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp] 172 | fc_feats, att_feats, labels, masks = tmp 173 | 174 | optimizer.zero_grad() 175 | 176 | if not sc_flag: 177 | loss1 = crit(model(fc_feats_2048, att_feats, labels), labels[:,1:], masks[:,1:]) 178 | loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double()) 179 | loss = 0.8*loss1 + 0.2*loss2.float() 180 | else: 181 | gen_result, sample_logprobs = model.sample(fc_feats_2048, att_feats, {'sample_max':0}) 182 | reward = get_self_critical_reward(model, fc_feats_2048, att_feats, data, gen_result) 183 | loss1 = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) 184 | loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double()) 185 | loss3 = crit(model(fc_feats_2048, att_feats, labels), labels[:,1:], masks[:,1:]) 186 | loss = 0.995*loss1 + 0.005*(loss2.float() + loss3) 187 | loss.backward() 188 | 189 | utils.clip_gradient(optimizer, opt.grad_clip) 190 | optimizer.step() 191 | 192 | train_loss = loss.data[0] 193 | mle_loss = loss1.data[0] 194 | multilabel_loss = loss2.data[0] 195 | torch.cuda.synchronize() 196 | end = time.time() 197 | if not sc_flag and iteration % 2500==0: 198 | print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \ 199 | .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start)) 200 | 201 | if sc_flag and iteration % 2500==0: 202 | print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \ 203 | .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start)) 204 | iteration += 1 205 | if (iteration % opt.losses_log_every == 0): 206 | loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0]) 207 | lr_history[iteration] = opt.current_lr 208 | ss_prob_history[iteration] = model.ss_prob 209 | 210 | if (iteration % opt.save_checkpoint_every == 0): 211 | eval_kwargs = {'split': 'val','dataset': opt.input_json,'verbose':True} 212 | eval_kwargs.update(vars(opt)) 213 | 214 | if opt.finetune_cnn_after != -1 and epoch >= 
opt.finetune_cnn_after: 215 | val_loss, predictions, lang_stats = eval_utils.eval_split(cnn_model, model, crit, 216 | loader, eval_kwargs, True) 217 | else: 218 | val_loss, predictions, lang_stats = eval_utils.eval_split(cnn_model, model, crit, 219 | loader, eval_kwargs, False) 220 | 221 | val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} 222 | 223 | if opt.language_eval == 1: 224 | current_score = lang_stats['CIDEr'] 225 | else: 226 | current_score = - val_loss 227 | 228 | best_flag = False 229 | if True: 230 | if best_val_score is None or current_score > best_val_score: 231 | best_val_score = current_score 232 | best_flag = True 233 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') 234 | torch.save(model.state_dict(), checkpoint_path) 235 | print("model saved to {}".format(checkpoint_path)) 236 | 237 | cnn_checkpoint_path = os.path.join(opt.checkpoint_path, 'model-cnn.pth') 238 | torch.save(cnn_model.state_dict(), cnn_checkpoint_path) 239 | print("cnn model saved to {}".format(cnn_checkpoint_path)) 240 | 241 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') 242 | torch.save(optimizer.state_dict(), optimizer_path) 243 | 244 | if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after: 245 | cnn_optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer-cnn.pth') 246 | torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path) 247 | 248 | infos['iter'] = iteration 249 | infos['epoch'] = epoch 250 | infos['iterators'] = loader.iterators 251 | infos['split_ix'] = loader.split_ix 252 | infos['best_val_score'] = best_val_score 253 | infos['opt'] = opt 254 | infos['vocab'] = loader.get_vocab() 255 | 256 | histories['val_result_history'] = val_result_history 257 | histories['loss_history'] = loss_history 258 | histories['lr_history'] = lr_history 259 | histories['ss_prob_history'] = ss_prob_history 260 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: 261 | cPickle.dump(infos, f) 262 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: 263 | cPickle.dump(histories, f) 264 | 265 | if best_flag: 266 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') 267 | torch.save(model.state_dict(), checkpoint_path) 268 | print("model saved to {}".format(checkpoint_path)) 269 | 270 | cnn_checkpoint_path = os.path.join(opt.checkpoint_path, 'model-cnn-best.pth') 271 | torch.save(cnn_model.state_dict(), cnn_checkpoint_path) 272 | print("cnn model saved to {}".format(cnn_checkpoint_path)) 273 | 274 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: 275 | cPickle.dump(infos, f) 276 | 277 | if data['bounds']['wrapped']: 278 | epoch += 1 279 | update_lr_flag = True 280 | print("epoch: "+str(epoch)+ " during: " + str(time.time()-epoch_start)) 281 | epoch_start = time.time() 282 | 283 | if epoch >= opt.max_epochs and opt.max_epochs != -1: 284 | break 285 | 286 | def main(): 287 | opt = opts.parse_opt() 288 | opt.caption_model ='topdown' 289 | opt.batch_size=10 290 | opt.id ='topdown' 291 | opt.learning_rate= 5e-5 292 | opt.learning_rate_decay_start= -1 293 | opt.scheduled_sampling_start=-1 294 | opt.save_checkpoint_every=5000# 295 | opt.val_images_use=5000 296 | opt.max_epochs=60 297 | opt.start_from='save/multitask_pretrain'#"save" #None 298 | opt.language_eval = 1 299 | opt.input_json='data/meta_coco_en.json' 300 | opt.input_label_h5='data/label_coco_en.h5' 301 | opt.self_critical_after = 25 
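# With self_critical_after = 25, training uses the weighted cross-entropy +
# multi-label objective for the first 25 epochs and then switches to the
# self-critical (RL) objective; finetune_cnn_after = 0 below means the CNN is
# fine-tuned (and the detection multi-label loss is applied) from the first epoch.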
302 | opt.finetune_cnn_after = 0 303 | opt.ccg = False 304 | opt.input_image_h5 = 'data/coco_image_512.h5' 305 | opt.checkpoint_path = 'save/multitask_pretrain_rl' 306 | train(opt) 307 | if __name__ == "__main__": main() 308 | --------------------------------------------------------------------------------
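As a quick sanity check of the label preprocessing above, here is a minimal sketch (not part of the repository) showing how the JSON/HDF5 outputs of scripts/prepro_labels.py can be read back and one caption decoded. The paths are illustrative and follow the script's default --output_json / --output_h5 names; substitute whatever values were actually used.

import json
import h5py

# Illustrative paths -- replace with the actual --output_json / --output_h5 values used.
meta_path = 'data/dataset/tmp/aitalk.json'
label_path = 'data/dataset/tmp/aitalk_label.h5'

with open(meta_path) as fp:
    meta = json.load(fp)
ix_to_word = meta['ix_to_word']                 # 1-indexed vocab; keys become strings after the JSON round-trip

with h5py.File(label_path, 'r') as f:
    labels = f['labels'][:]                     # (M, max_length) uint32, zero-padded word indices
    label_start_ix = f['label_start_ix'][:]     # (N,) 1-indexed, inclusive pointers into labels
    label_end_ix = f['label_end_ix'][:]

# Decode the first caption of the first image (the pointers are 1-indexed, rows are 0-indexed).
row = labels[label_start_ix[0] - 1]
print(' '.join(ix_to_word[str(ix)] for ix in row if ix > 0))

The 1-indexed, inclusive start/end convention here matches the encode_captions docstring above.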