├── data
│   └── README.md
├── KarpathySplit.py
├── README.md
├── data_loader.py
├── test.py
├── train.py
└── model.py
/data/README.md: -------------------------------------------------------------------------------- 1 | This folder is for storing the [MS-COCO dataset](http://mscoco.org/) and its annotations 2 | -------------------------------------------------------------------------------- /KarpathySplit.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # # Karpathy Split for MS-COCO Dataset 4 | import json 5 | from random import shuffle, seed 6 | 7 | seed( 123 ) 8 | 9 | num_val = 5000 10 | num_test = 5000 11 | 12 | val = json.load( open('./data/annotations/captions_val2014.json', 'r') ) 13 | train = json.load( open('./data/annotations/captions_train2014.json', 'r') ) 14 | 15 | # Merge the val and train splits together 16 | imgs = val['images'] + train['images'] 17 | annots = val['annotations'] + train['annotations'] 18 | 19 | shuffle( imgs ) 20 | 21 | # Split into val, test, train 22 | dataset = {} 23 | dataset[ 'val' ] = imgs[ :num_val ] 24 | dataset[ 'test' ] = imgs[ num_val: num_val + num_test ] 25 | dataset[ 'train' ] = imgs[ num_val + num_test: ] 26 | 27 | # Group annotations by image id 28 | itoa = {} 29 | for a in annots: 30 | imgid = a['image_id'] 31 | if imgid not in itoa: itoa[imgid] = [] 32 | itoa[imgid].append(a) 33 | 34 | 35 | json_data = {} 36 | info = train['info'] 37 | licenses = train['licenses'] 38 | 39 | split = [ 'test', 'val', 'train' ] 40 | 41 | for subset in split: 42 | 43 | json_data[ subset ] = { 'type':'caption', 'info':info, 'licenses': licenses, 44 | 'images':[], 'annotations':[] } 45 | 46 | for img in dataset[ subset ]: 47 | 48 | img_id = img['id'] 49 | anns = itoa[ img_id ] 50 | 51 | json_data[ subset ]['images'].append( img ) 52 | json_data[ subset ]['annotations'].extend( anns ) 53 | 54 | json.dump( json_data[ subset ], open( './data/annotations/karpathy_split_' + subset + '.json', 'w' ) ) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simNet 2 | Implementation of "simNet: Stepwise Image-Topic Merging Network for Generating Detailed and Comprehensive Image Captions" by Fenglin Liu, Xuancheng Ren, Yuanxin Liu, Houfeng Wang, and Xu Sun. The paper can be found at [[arxiv]](https://arxiv.org/abs/1808.08732). 3 | 4 | ## Usage 5 | 6 | ### Requirements 7 | This code is written in Python 2.7 and requires PyTorch 0.3. 8 | 9 | You need to download a pre-trained ResNet-152 model from [torchvision](https://github.com/pytorch/vision) for both training and evaluation. 10 | 11 | You may take a look at https://github.com/s-gupta/visual-concepts to see how to obtain the topic words of an image.
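Judging from `data_loader.py` and `test.py`, the topic file passed as `--topic_path` (e.g. `./data/topics/image_topic.json`) is expected to be a JSON list with one entry per image, holding the image id and its concept words; the values below are purely illustrative:

```
[
  {"image_id": 391895, "image_concepts": ["man", "bike", "dirt", "road", "riding"]},
  {"image_id": 522418, "image_concepts": ["woman", "kitchen", "food", "plate", "table"]}
]
```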
12 | 13 | ### Training a simNet model 14 | Now we can train a simNet model with 15 | 16 | ``` 17 | CUDA_VISIBLE_DEVICES=1,2,3 screen python train.py 18 | ``` 19 | 20 | ### Testing a trained model 21 | We can test a trained simNet model with 22 | 23 | ``` 24 | CUDA_VISIBLE_DEVICES=1,2,3 screen python test.py 25 | ``` 26 | 27 | 28 | ## Reference 29 | If you use this code as part of any published research, please acknowledge the following paper: 30 | ``` 31 | @inproceedings{Liu2018simNet, 32 | author = {Fenglin Liu and Xuancheng Ren and Yuanxin Liu and Houfeng Wang and Xu Sun}, 33 | title = {sim{N}et: Stepwise Image-Topic Merging Network for Generating Detailed and Comprehensive Image Captions}, 34 | booktitle = {EMNLP 2018}, 35 | year = {2018} 36 | } 37 | ``` 38 | 39 | ## Acknowledgements 40 | 41 | Thanks to the [PyTorch](https://github.com/pytorch/pytorch) team for providing PyTorch 0.3, the [CodaLab](https://competitions.codalab.org/) team for providing online evaluation, the [COCO](http://cocodataset.org/) and [Flickr30k](http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/) teams for providing the datasets, [Tsung-Yi Lin](https://github.com/tylin/coco-caption) for providing the evaluation code for MS COCO caption generation, [Yufeng Ma](https://github.com/yufengm)'s open-source repositories, and the torchvision [ResNet](https://github.com/pytorch/vision) implementation. 42 | 43 | ### Note 44 | If you have any questions about the code or our paper, please send an email to lfl@bupt.edu.cn. 45 | 46 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torch.utils.data as data 5 | import os 6 | import pickle 7 | import string 8 | import numpy as np 9 | import nltk 10 | from PIL import Image 11 | from build_vocab import Vocabulary 12 | from coco.pycocotools.coco import COCO 13 | 14 | class CocoDataset(data.Dataset): 15 | """COCO Custom Dataset compatible with torch.utils.data.DataLoader.""" 16 | def __init__(self, root, json, topic, vocab, transform=None): 17 | """Set the paths for images, captions, topics and the vocabulary wrapper. 18 | 19 | Args: 20 | root: image directory. 21 | json: coco annotation file path. topic: per-image topic (concept) annotations. 22 | vocab: vocabulary wrapper. 23 | transform: image transformer. 24 | """ 25 | 26 | self.root = root 27 | self.coco = COCO( json ) 28 | self.ids = list( self.coco.anns.keys() ) 29 | self.vocab = vocab 30 | self.transform = transform 31 | self.topic_train = topic 32 | 33 | def __getitem__(self, index): 34 | """Returns one data sample ( image, caption, image_id, filename, topic ids T ).""" 35 | coco = self.coco 36 | vocab = self.vocab 37 | ann_id = self.ids[index] 38 | caption = coco.anns[ann_id]['caption'] 39 | img_id = coco.anns[ann_id]['image_id'] 40 | filename = coco.loadImgs(img_id)[0]['file_name'] 41 | 42 | if 'val2014' in filename.lower(): 43 | path = 'val2014/' + filename 44 | elif 'train2014' in filename.lower(): 45 | path = 'train2014/' + filename 46 | else: 47 | path = 'test2014/' + filename 48 | 49 | image = Image.open( os.path.join( self.root, path ) ).convert('RGB') 50 | if self.transform is not None: 51 | image = self.transform( image ) 52 | 53 | # Convert the caption (string) to word ids.
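# NOTE: str.translate( None, string.punctuation ) deletes punctuation and is the Python-2-only form, matching the Python 2.7 requirement in the README.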
54 | tokens = str( caption ).lower().translate( None, string.punctuation ).strip().split() 55 | caption = [] 56 | caption.append(vocab('<start>')) 57 | caption.extend([vocab(token) for token in tokens]) 58 | caption.append(vocab('<end>')) 59 | target = torch.Tensor(caption) 60 | 61 | # Load the image topic (concept) word ids 62 | 63 | T = [] 64 | for topic in self.topic_train: 65 | if topic['image_id'] == img_id: 66 | image_topic = topic['image_concepts'] 67 | T.extend([vocab(token) for token in image_topic]) 68 | break 69 | T = torch.Tensor(T) 70 | 71 | return image, target, img_id, filename, T 72 | 73 | def __len__(self): 74 | return len( self.ids ) 75 | 76 | def collate_fn(data): 77 | """Creates mini-batch tensors from the list of tuples (image, caption, image_id, filename, topic). 78 | 79 | We should build a custom collate_fn rather than using the default one, 80 | because merging captions (including padding) is not supported by the default collate_fn. 81 | 82 | Args: 83 | data: list of tuples (image, caption, image_id, filename, topic). 84 | - image: torch tensor of shape (3, 256, 256). 85 | - caption: torch tensor of shape (?); variable length. 86 | 87 | Returns: 88 | images: torch tensor of shape (batch_size, 3, 256, 256). 89 | targets: torch tensor of shape (batch_size, padded_length). 90 | lengths: list; valid length for each padded caption. 91 | img_ids: image ids in the COCO dataset, for evaluation purposes. 92 | filenames: image filenames in the COCO dataset, for evaluation purposes. T: torch tensor of shape (batch_size, num_topics); topic word ids. 93 | """ 94 | 95 | # Sort the data list by caption length (descending order). 96 | data.sort( key=lambda x: len( x[1] ), reverse=True ) 97 | images, captions, img_ids, filenames, Topic = zip( *data ) # unzip 98 | 99 | # Merge images (from tuple of 3D tensors to a 4D tensor). 100 | images = torch.stack(images, 0) 101 | img_ids = list( img_ids ) 102 | filenames = list( filenames ) 103 | 104 | # Merge captions (from tuple of 1D tensors to a 2D tensor). 105 | lengths = [len(cap) for cap in captions] 106 | targets = torch.zeros(len(captions), max(lengths)).long() 107 | for i, cap in enumerate(captions): 108 | end = lengths[i] 109 | targets[i, :end] = cap[:end] 110 | 111 | # Merge image topics (from tuple of 1D tensors to a 2D tensor). 112 | lengths_topic = len(Topic[0]) 113 | T = torch.zeros(len(Topic), lengths_topic).long() 114 | for j, capj in enumerate(Topic): 115 | end_topic = lengths_topic 116 | T[j, :end_topic] = capj[:end_topic] 117 | 118 | return images, targets, lengths, img_ids, filenames, T 119 | 120 | 121 | def get_loader(root, json, topic, vocab, transform, batch_size, shuffle, num_workers): 122 | """Returns a torch.utils.data.DataLoader for the custom COCO dataset.""" 123 | # COCO caption dataset 124 | coco = CocoDataset(root=root, 125 | json=json, 126 | topic=topic, 127 | vocab=vocab, 128 | transform=transform) 129 | 130 | # Data loader for the COCO dataset 131 | # Every iteration returns (images, targets, lengths, img_ids, filenames, T). 132 | # images: tensor of shape (batch_size, 3, 224, 224). 133 | # targets: tensor of shape (batch_size, padded_length). 134 | # lengths: list indicating the valid length of each caption; its length is batch_size.
135 | data_loader = torch.utils.data.DataLoader(dataset=coco, 136 | batch_size=batch_size, 137 | shuffle=shuffle, 138 | num_workers=num_workers, 139 | collate_fn=collate_fn) 140 | return data_loader 141 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import os 7 | import glob 8 | import pickle 9 | from build_vocab import Vocabulary 10 | from model import Encoder2Decoder 11 | from torch.autograd import Variable 12 | from torchvision import transforms, datasets 13 | from coco.pycocotools.coco import COCO 14 | from coco.pycocoevalcap.eval import COCOEvalCap 15 | import matplotlib.pyplot as plt 16 | 17 | # Variable wrapper 18 | def to_var(x, volatile=False): 19 | ''' 20 | Wrapper torch tensor into Variable 21 | ''' 22 | if torch.cuda.is_available(): 23 | x = x.cuda() 24 | return Variable( x, volatile=volatile ) 25 | 26 | # MS COCO evaluation data loader 27 | class CocoEvalLoader( datasets.ImageFolder ): 28 | def __init__( self, root, ann_path, topic_path, transform=None, target_transform=None, 29 | loader=datasets.folder.default_loader ): 30 | ''' 31 | Customized COCO loader to get Image ids and Image Filenames 32 | root: path for images 33 | ann_path: path for the annotation file (e.g., caption_val2014.json) 34 | ''' 35 | self.root = root 36 | self.transform = transform 37 | self.target_transform = target_transform 38 | self.loader = loader 39 | self.imgs = json.load( open( ann_path, 'r' ) )['images'] 40 | self.image_topic = json.load(open( topic_path , 'r')) 41 | 42 | def __getitem__(self, index): 43 | 44 | filename = self.imgs[ index ]['file_name'] 45 | img_id = self.imgs[ index ]['id'] 46 | 47 | # Filename for the image 48 | if 'val2014' in filename.lower(): 49 | path = os.path.join( self.root, 'val2014' , filename ) 50 | elif 'train2014' in filename.lower(): 51 | path = os.path.join( self.root, 'train2014' , filename ) 52 | else: 53 | path = os.path.join( self.root, 'test2014', filename ) 54 | 55 | # Load the vocabulary 56 | with open( 'vocab.pkl', 'rb' ) as f: 57 | vocab = pickle.load( f ) 58 | 59 | img = self.loader( path ) 60 | if self.transform is not None: 61 | img = self.transform( img ) 62 | 63 | # Load the image topic 64 | T_val = [] 65 | for topic in self.image_topic: 66 | if topic['image_id'] == img_id: 67 | image_topic = topic['image_concepts'] 68 | T_val.extend([vocab(token) for token in image_topic]) 69 | break 70 | 71 | T_val = torch.LongTensor(T_val) 72 | 73 | return img, img_id, filename, T_val 74 | 75 | # MSCOCO Evaluation function 76 | def main( args ): 77 | 78 | ''' 79 | model: trained model to be evaluated 80 | args: parameters 81 | ''' 82 | # Load vocabulary wrapper. 
83 | with open( args.vocab_path, 'rb') as f: 84 | vocab = pickle.load( f ) 85 | # Load the trained model 86 | model = Encoder2Decoder( args.embed_size, len(vocab), args.hidden_size ) 87 | model.load_state_dict(torch.load(args.trained)) 88 | 89 | # Change to GPU mode if available 90 | if torch.cuda.is_available(): 91 | model.cuda() 92 | 93 | model.eval() 94 | 95 | transform = transforms.Compose([ 96 | transforms.Resize( (args.crop_size, args.crop_size) ), 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.485, 0.456, 0.406), 99 | (0.229, 0.224, 0.225))]) 100 | 101 | # Wrap the COCO VAL dataset 102 | eval_data_loader = torch.utils.data.DataLoader( 103 | CocoEvalLoader( args.image_dir, args.caption_test_path, args.topic_path, transform ), 104 | batch_size = args.eval_size, 105 | shuffle = False, num_workers = args.num_workers, 106 | drop_last = False ) 107 | epoch = int( args.trained.split('/')[-1].split('-')[1].split('.')[0] ) 108 | 109 | # Generated captions to be compared with the ground truth 110 | results = [] 111 | print '---------------------Start evaluation on MS-COCO dataset-----------------------' 112 | for i, (images, image_ids, _, T_val ) in enumerate( eval_data_loader ): 113 | 114 | images = to_var( images ) 115 | T_val = to_var( T_val ) 116 | generated_captions = model.sampler( epoch, images, T_val ) 117 | 118 | if torch.cuda.is_available(): 119 | captions = generated_captions.cpu().data.numpy() 120 | else: 121 | captions = generated_captions.data.numpy() 122 | 123 | # Build the caption based on the Vocabulary and the '<end>' token 124 | for image_idx in range( captions.shape[0] ): 125 | 126 | sampled_ids = captions[ image_idx ] 127 | sampled_caption = [] 128 | 129 | for word_id in sampled_ids: 130 | 131 | word = vocab.idx2word[ word_id ] 132 | if word == '<end>': 133 | break 134 | else: 135 | sampled_caption.append( word ) 136 | 137 | sentence = ' '.join( sampled_caption ) 138 | 139 | temp = { 'image_id': int( image_ids[ image_idx ] ), 'caption': sentence} 140 | results.append( temp ) 141 | 142 | # Display evaluation progress 143 | if (i+1) % 10 == 0: 144 | print '[%d/%d]'%( (i+1),len( eval_data_loader ) ) 145 | 146 | print '------------------------Caption Generated-------------------------------------' 147 | 148 | # Evaluate the results based on the COCO API 149 | resFile = args.save_path 150 | json.dump( results, open( resFile , 'w' ) ) 151 | 152 | annFile = args.caption_test_path 153 | coco = COCO( annFile ) 154 | cocoRes = coco.loadRes( resFile ) 155 | 156 | cocoEval = COCOEvalCap( coco, cocoRes ) 157 | cocoEval.params['image_id'] = cocoRes.getImgIds() 158 | cocoEval.evaluate() 159 | 160 | print '-----------Evaluation performance on MS-COCO dataset----------' 161 | for metric, score in cocoEval.eval.items(): 162 | print '%s: %.4f'%( metric, score ) 163 | 164 | if __name__ == '__main__': 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('-f', default='self', help='To make it runnable in jupyter') 167 | parser.add_argument('--crop_size', type=int, default=224, 168 | help='size for randomly cropping images') 169 | parser.add_argument('--vocab_path', type=str, default='vocab.pkl', 170 | help='path for vocabulary wrapper') 171 | parser.add_argument('--image_dir', type=str, default='./data/coco2014', 172 | help='directory for resized training images') 173 | parser.add_argument('--caption_test_path', type=str, 174 | default='./data/annotations/karpathy_split_test.json', 175 | help='path for test annotation json file') 176 | parser.add_argument('--topic_path', type=str, 177 |
default='./data/topics/image_topic.json', 178 | help='path for test topic json file') 179 | 180 | # ---------------------------Hyper Parameter Setup------------------------------------ 181 | parser.add_argument('--save_path', type=str, default='model_generated_caption.json') 182 | parser.add_argument('--embed_size', type=int, default=256, 183 | help='dimension of word embedding vectors, also dimension of v_g') 184 | parser.add_argument('--hidden_size', type=int, default=512, 185 | help='dimension of lstm hidden states') 186 | parser.add_argument('--trained', type=str, default='./models/simNet-30.pkl', 187 | help='start from checkpoint or scratch') 188 | parser.add_argument('--eval_size', type=int, default=200) 189 | parser.add_argument('--num_workers', type=int, default=4) 190 | 191 | args = parser.parse_args() 192 | 193 | print '------------------------Model and Testing Details--------------------------' 194 | print(args) 195 | 196 | # Start training 197 | main(args) 198 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import json 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import os 8 | import pickle 9 | from data_loader import get_loader 10 | from model import Encoder2Decoder 11 | from build_vocab import Vocabulary 12 | from torch.autograd import Variable 13 | from torchvision import transforms 14 | from torch.nn.utils.rnn import pack_padded_sequence 15 | 16 | def to_var( x, volatile=False ): 17 | ''' 18 | Wrapper torch tensor into Variable 19 | ''' 20 | if torch.cuda.is_available(): 21 | x = x.cuda() 22 | return Variable( x, volatile=volatile ) 23 | 24 | def main( args ): 25 | # To reproduce training results 26 | torch.manual_seed( args.seed ) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed( args.seed ) 29 | 30 | # Create model directory 31 | if not os.path.exists( args.model_path ): 32 | os.makedirs( args.model_path ) 33 | 34 | # Image Preprocessing 35 | # For normalization, see https://github.com/pytorch/vision#models 36 | transform = transforms.Compose([ 37 | transforms.RandomCrop( args.crop_size ), 38 | transforms.RandomHorizontalFlip(), 39 | transforms.ToTensor(), 40 | transforms.Normalize(( 0.485, 0.456, 0.406 ), 41 | ( 0.229, 0.224, 0.225 ))]) 42 | 43 | # Load vocabulary wrapper. 44 | with open( args.vocab_path, 'rb') as f: 45 | vocab = pickle.load( f ) 46 | 47 | # Load pretrained model or build from scratch 48 | simNet = Encoder2Decoder( args.embed_size, len( vocab ), args.hidden_size ) 49 | 50 | if args.pretrained: 51 | simNet.load_state_dict( torch.load( args.pretrained ) ) 52 | # Get starting epoch #, note that model is named as '...your path to model/algoname-epoch#.pkl' 53 | # A little messy here. 54 | start_epoch = int( args.pretrained.split('/')[-1].split('-')[1].split('.')[0] ) + 1 55 | 56 | elif args.pretrained_cnn: 57 | pretrained_dict = torch.load( args.pretrained_cnn ) 58 | model_dict=simNet.state_dict() 59 | 60 | # 1. filter out unnecessary keys 61 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 62 | # 2. 
overwrite entries in the existing state dict 63 | model_dict.update( pretrained_dict ) 64 | simNet.load_state_dict( model_dict ) 65 | 66 | start_epoch = 1 67 | 68 | else: 69 | start_epoch = 1 70 | 71 | # Parameter optimization 72 | params = list( simNet.encoder.affine_VI.parameters() ) + list( simNet.decoder.parameters() ) 73 | 74 | # Will decay later 75 | learning_rate = args.learning_rate 76 | 77 | # Language Modeling Loss 78 | LMcriterion = nn.CrossEntropyLoss() 79 | 80 | # Change to GPU mode if available 81 | if torch.cuda.is_available(): 82 | simNet.cuda() 83 | LMcriterion.cuda() 84 | # Load image_topic 85 | topic = json.load( open( args.topic_path , 'r' ) ) 86 | # Build training data loader 87 | data_loader = get_loader(args.image_dir, args.caption_path, topic, vocab, 88 | transform, args.batch_size, 89 | shuffle=True, num_workers=args.num_workers) 90 | # Train the Models 91 | total_step = len( data_loader ) 92 | 93 | # Start Training 94 | for epoch in range( start_epoch, args.num_epochs + 1 ): 95 | if epoch == args.visual_attention_epoch: 96 | print 'Starting Training Visual Attention' 97 | 98 | # Start Learning Rate Decay 99 | if epoch > args.lr_decay: 100 | 101 | frac = float( epoch - args.lr_decay ) / args.learning_rate_decay_every 102 | decay_factor = math.pow( 0.5, frac ) 103 | 104 | # Decay the learning rate 105 | learning_rate = args.learning_rate * decay_factor 106 | 107 | print 'Learning Rate for Epoch %d: %.6f'%( epoch, learning_rate ) 108 | 109 | optimizer = torch.optim.Adam( params, lr=learning_rate, betas=( args.alpha, args.beta ) ) 110 | 111 | # Language Modeling Training 112 | print '------------------Training for Epoch %d----------------'%( epoch ) 113 | for i, ( images, captions, lengths, _, _, T ) in enumerate( data_loader ): 114 | 115 | # Set mini-batch dataset 116 | images = to_var( images ) 117 | captions = to_var( captions ) 118 | T = to_var( T ) 119 | lengths = [ cap_len - 1 for cap_len in lengths ] 120 | targets = pack_padded_sequence( captions[:,1:], lengths, batch_first=True )[0] 121 | 122 | # Forward, Backward and Optimize 123 | simNet.train() 124 | simNet.zero_grad() 125 | 126 | packed_scores = simNet( epoch, images, captions, lengths, T ) 127 | 128 | # Compute loss and backprop 129 | loss = LMcriterion( packed_scores[0], targets ) 130 | loss.backward() 131 | 132 | # Gradient clipping for gradient exploding problem in LSTM 133 | for p in simNet.decoder.LSTM.parameters(): 134 | p.data.clamp_( -args.clip, args.clip ) 135 | 136 | optimizer.step() 137 | 138 | # Print log info 139 | if i % args.log_step == 0: 140 | print 'Epoch [%d/%d], Step [%d/%d], CrossEntropy Loss: %.4f'%( epoch, args.num_epochs, i, total_step, loss.data[0] ) 141 | 142 | # Save the simNet model after each epoch 143 | torch.save( simNet.state_dict(), 144 | os.path.join( args.model_path, 145 | 'simNet-%d.pkl'%( epoch ) ) ) 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument( '-f', default='self', help='To make it runnable in jupyter' ) 150 | parser.add_argument( '--model_path', type=str, default='./models', 151 | help='path for saving trained models') 152 | parser.add_argument('--crop_size', type=int, default=224 , 153 | help='size for randomly cropping images') 154 | parser.add_argument('--vocab_path', type=str, default='vocab.pkl', 155 | help='path for vocabulary wrapper') 156 | parser.add_argument('--image_dir', type=str, default='./data/coco2014' , 157 | help='directory for training images') 158 | parser.add_argument('--caption_path', 
type=str, 159 | default='./data/annotations/karpathy_split_train.json', 160 | help='path for train annotation json file') 161 | parser.add_argument('--topic_path', type=str, 162 | default='./data/topics/image_topic.json', 163 | help='path for image topic json file') 164 | parser.add_argument('--log_step', type=int, default=10, 165 | help='step size for printing log info') 166 | parser.add_argument('--seed', type=int, default=123, 167 | help='random seed for reproducibility') 168 | 169 | # ---------------------------Hyper Parameter Setup------------------------------------ 170 | # Adam optimizer parameters 171 | parser.add_argument( '--alpha', type=float, default=0.8, 172 | help='alpha (beta1) in Adam' ) 173 | parser.add_argument( '--beta', type=float, default=0.999, 174 | help='beta (beta2) in Adam' ) 175 | parser.add_argument( '--learning_rate', type=float, default=4e-4, 176 | help='learning rate for the whole model' ) 177 | 178 | # LSTM hyper parameters 179 | parser.add_argument( '--embed_size', type=int, default=256, 180 | help='dimension of word embedding vectors' ) 181 | parser.add_argument( '--hidden_size', type=int, default=512, 182 | help='dimension of lstm hidden states' ) 183 | 184 | # Training details 185 | parser.add_argument( '--pretrained', type=str, default='', help='start from a checkpoint or from scratch' ) 186 | parser.add_argument( '--pretrained_cnn', type=str, default='models/pretrained_cnn.pkl', help='load pretrained CNN parameters' ) 187 | parser.add_argument( '--num_epochs', type=int, default=30 ) 188 | parser.add_argument( '--batch_size', type=int, default=80 ) 189 | parser.add_argument( '--num_workers', type=int, default=4 ) 190 | parser.add_argument( '--clip', type=float, default=0.1 ) 191 | parser.add_argument( '--visual_attention_epoch', type=int, default=20, help='epoch at which to start training visual attention' ) 192 | parser.add_argument( '--lr_decay', type=int, default=20, help='epoch at which to start lr decay' ) 193 | parser.add_argument( '--learning_rate_decay_every', type=int, default=50, 194 | help='decay the learning rate every this many epochs after lr_decay') 195 | 196 | 197 | args = parser.parse_args() 198 | 199 | print '------------------------Model and Training Details--------------------------' 200 | print(args) 201 | 202 | # Start training 203 | main( args ) 204 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | import numpy as np 9 | 10 | #=========================================simNet========================================= 11 | class AttentiveCNN( nn.Module ): 12 | def __init__( self, hidden_size ): 13 | super( AttentiveCNN, self ).__init__() 14 | 15 | # ResNet-152 backbone 16 | resnet = models.resnet152() 17 | modules = list( resnet.children() )[ :-2 ] # delete the last fc layer and avg pool.
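# NOTE: for 224 x 224 inputs the remaining layers output a 2048 x 7 x 7 feature map, i.e. the 49 spatial locations attended over below.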
18 | resnet_conv = nn.Sequential( *modules ) # last conv feature 19 | 20 | self.resnet_conv = resnet_conv 21 | self.affine_VI = nn.Linear( 2048, hidden_size ) # reduce the dimension 22 | 23 | # Dropout before affine transformation 24 | self.dropout = nn.Dropout( 0.5 ) 25 | 26 | self.init_weights() 27 | 28 | def init_weights( self ): 29 | """Initialize the weights.""" 30 | init.kaiming_uniform( self.affine_VI.weight, mode='fan_in' ) 31 | self.affine_VI.bias.data.fill_( 0 ) 32 | 33 | def forward( self, images ): 34 | ''' 35 | Input: images 36 | Output: V=[v_1, ..., v_n], v_g 37 | ''' 38 | # Last conv layer feature map 39 | A = self.resnet_conv( images ) 40 | 41 | # V = [ v_1, v_2, ..., v_49 ] 42 | V = A.view( A.size( 0 ), A.size( 1 ), -1 ).transpose( 1,2 ) 43 | V = F.relu( self.affine_VI( self.dropout( V ) ) ) 44 | 45 | return V 46 | 47 | # Encoder Block 48 | class EncoderBlock( nn.Module ): 49 | def __init__( self, embed_size, hidden_size, vocab_size ): 50 | super( EncoderBlock, self ).__init__() 51 | 52 | self.affine_ZV = nn.Linear(hidden_size, 49) # W_zv_output_attention 53 | self.affine_Zh = nn.Linear(hidden_size, 49) # W_zh_output_attention 54 | self.affine_alphaz = nn.Linear(49, 1) # w_alphaz_output_attention 55 | 56 | self.affine_QT = nn.Linear(embed_size, 5) # W_Qt 57 | self.affine_Qh = nn.Linear(hidden_size, 5) # W_Qh 58 | self.affine_betaq = nn.Linear(5, 1) # u_betaq 59 | 60 | self.affine_sq = nn.Linear(embed_size, embed_size) # W_sq 61 | self.affine_sh = nn.Linear(hidden_size, embed_size) # W_sh 62 | 63 | self.affine_Ss = nn.Linear(embed_size, 5) # W_Ss 64 | self.affine_Sr = nn.Linear(embed_size, 5) # W_Sr 65 | 66 | self.affine_sz = nn.Linear(hidden_size, embed_size) 67 | 68 | # Final Caption generator 69 | self.mlp = nn.Linear( embed_size, vocab_size ) 70 | 71 | # Dropout layer inside Affine Transformation 72 | self.dropout = nn.Dropout( 0.5 ) 73 | 74 | self.hidden_size = hidden_size 75 | self.init_weights() 76 | 77 | def init_weights( self ): 78 | ''' 79 | """Initialize the weights.""" 80 | ''' 81 | 82 | init.xavier_uniform(self.affine_ZV.weight) 83 | self.affine_ZV.bias.data.fill_(0) 84 | init.xavier_uniform(self.affine_Zh.weight) 85 | self.affine_Zh.bias.data.fill_(0) 86 | init.xavier_uniform(self.affine_alphaz.weight) 87 | self.affine_alphaz.bias.data.fill_(0) 88 | 89 | init.xavier_uniform(self.affine_QT.weight) 90 | self.affine_QT.bias.data.fill_(0) 91 | init.xavier_uniform(self.affine_Qh.weight) 92 | self.affine_Qh.bias.data.fill_(0) 93 | init.xavier_uniform(self.affine_betaq.weight) 94 | self.affine_betaq.bias.data.fill_(0) 95 | 96 | init.xavier_uniform(self.affine_sq.weight) 97 | self.affine_sq.bias.data.fill_(0) 98 | init.xavier_uniform(self.affine_sh.weight) 99 | self.affine_sh.bias.data.fill_(0) 100 | 101 | init.xavier_uniform(self.affine_Ss.weight) 102 | self.affine_Ss.bias.data.fill_(0) 103 | init.xavier_uniform(self.affine_Sr.weight) 104 | self.affine_Sr.bias.data.fill_(0) 105 | 106 | init.xavier_uniform(self.affine_sz.weight) 107 | self.affine_sz.bias.data.fill_(0) 108 | 109 | init.kaiming_normal( self.mlp.weight, mode='fan_in' ) 110 | self.mlp.bias.data.fill_( 0 ) 111 | 112 | def forward( self, epoch, h_t, V, T ): 113 | 114 | ''' 115 | Input: V=[v_1, v_2, ... 
v_k], h_t from the LSTM and T from the topic extractor 116 | Output: scores over the vocabulary D indicating how likely each word is the current output word 117 | ''' 118 | 119 | # -------------------------Output Attention :z_t_output-------------------------------------------------------------------- 120 | # W_ZV * V + W_Zh * h_t * 1^T 121 | content_V = self.affine_ZV(self.dropout(V)).unsqueeze(1) + self.affine_Zh(self.dropout(h_t)).unsqueeze(2) 122 | 123 | # visual_t = W_alphaz * tanh( content_V ) 124 | visual_t = self.affine_alphaz(self.dropout(F.tanh(content_V))).squeeze(3) 125 | alpha_t = F.softmax(visual_t.view(-1, visual_t.size(2))).view(visual_t.size(0), visual_t.size(1), -1) 126 | 127 | z_t = torch.bmm(alpha_t, V).squeeze(2) 128 | r_t = F.tanh(self.affine_sz(self.dropout(z_t))) 129 | 130 | # -------------------------Topic Attention :q_t-------------------------------------------------------------------- 131 | content_T = self.affine_QT(self.dropout(T)).unsqueeze(1) + self.affine_Qh(self.dropout(h_t)).unsqueeze(2) 132 | 133 | # topic_t = W_betaq * tanh( content_T ) 134 | topic_t = self.affine_betaq(self.dropout(F.tanh(content_T))).squeeze(3) 135 | beta_t = F.softmax(topic_t.view(-1, topic_t.size(2))).view(topic_t.size(0), topic_t.size(1), -1) 136 | 137 | q_t = torch.bmm(beta_t, T).squeeze(2) 138 | s_t = F.tanh(self.affine_sq(self.dropout(q_t)) + self.affine_sh(self.dropout(h_t))) 139 | 140 | # ------------------------------------------Merging Gate---------------------------------------------------- 141 | for ip in range(r_t.size(1)): 142 | 143 | # compute score_s_t 144 | s_t_ip = s_t[:, ip, :].contiguous().view(s_t.size(0), 1, s_t.size(2)) 145 | s_t_extended = torch.cat([s_t_ip] * 5, 1) 146 | 147 | content_s_t = self.affine_Ss( s_t_extended ).unsqueeze(1) + self.affine_Qh( h_t ).unsqueeze(2) 148 | score_s_t = self.affine_betaq( F.tanh( content_s_t ) ).squeeze(3) 149 | 150 | if ip == 0: 151 | score_s = score_s_t[0][0][0].view(1, 1, 1) 152 | else: 153 | score_s = torch.cat([score_s, score_s_t[0][ip][0].view(1, 1, 1)], 1) 154 | 155 | # compute score_r_t 156 | r_t_ip = r_t[:, ip, :].contiguous().view(r_t.size(0), 1, r_t.size(2)) 157 | r_t_extended = torch.cat([r_t_ip] * 5, 1) 158 | 159 | content_r_t = self.affine_Sr( r_t_extended ).unsqueeze(1) + self.affine_Qh( h_t ).unsqueeze(2) 160 | score_r_t = self.affine_betaq( F.tanh( content_r_t ) ).squeeze(3) 161 | 162 | if ip == 0: 163 | score_r = score_r_t[0][0][0].view(1, 1, 1) 164 | else: 165 | score_r = torch.cat([score_r, score_r_t[0][ip][0].view(1, 1, 1)], 1) 166 | 167 | # Train the model without visual attention for the first 20 epochs (hard-coded here; matches the default --visual_attention_epoch in train.py) 168 | if epoch <= 20: 169 | gama_t = 1.0 170 | else: 171 | gama_t = F.sigmoid(score_s - score_r) 172 | 173 | # Final scores along the vocabulary 174 | #scores = self.mlp( self.dropout( c_t ) ) 175 | c_t = gama_t * s_t + (1-gama_t) * r_t 176 | scores = self.mlp( self.dropout( c_t ) ) 177 | 178 | return scores 179 | 180 | # Caption Decoder 181 | class Decoder(nn.Module): 182 | def __init__(self, embed_size, vocab_size, hidden_size): 183 | super(Decoder, self).__init__() 184 | 185 | # word embedding 186 | self.embed = nn.Embedding(vocab_size, embed_size) 187 | # LSTM decoder: input = [ w_t; v_input ] => 2 x word_embed_size; 188 | self.LSTM = nn.LSTM(embed_size * 2, hidden_size, 1, batch_first=True) 189 | 190 | # Save hidden_size for the hidden and cell variables 191 | self.hidden_size = hidden_size 192 | 193 | # Encoder Block: final scores for caption sampling 194 | self.encoder = EncoderBlock(embed_size,
hidden_size, vocab_size) 195 | 196 | # reduce the feature map dimension 197 | self.affine_b = nn.Linear( hidden_size, embed_size ) 198 | 199 | # input_attention weights 200 | self.affine_ZV_input = nn.Linear(embed_size, 49 ) # W_ZV_input_attention 201 | self.affine_Zh_input = nn.Linear(hidden_size, 49 ) # W_Zh_input_attention 202 | self.affine_alphaz_input = nn.Linear(49, 1 ) # w_alphaz_input_attention 203 | 204 | self.dropout = nn.Dropout(0.5) 205 | self.init_weights() 206 | 207 | def init_weights(self): 208 | """Initialize the weights.""" 209 | init.kaiming_uniform( self.affine_b.weight, mode='fan_in' ) 210 | self.affine_b.bias.data.fill_( 0 ) 211 | 212 | init.xavier_uniform(self.affine_ZV_input.weight) 213 | self.affine_ZV_input.bias.data.fill_( 0 ) 214 | init.xavier_uniform(self.affine_Zh_input.weight) 215 | self.affine_Zh_input.bias.data.fill_( 0 ) 216 | init.xavier_uniform(self.affine_alphaz_input.weight) 217 | self.affine_alphaz_input.bias.data.fill_( 0 ) 218 | 219 | def forward(self, epoch, V, captions, T, states=None): 220 | 221 | # Reduce the feature map dimension 222 | V_input = F.relu( self.affine_b( self.dropout( V ) ) ) 223 | v_g = torch.mean( V_input,dim=1 ) 224 | 225 | # Word Embedding 226 | embeddings = self.embed( captions ) 227 | 228 | # Topic Embedding 229 | T = self.embed( T ) 230 | 231 | # x_t = embeddings 232 | x = embeddings 233 | 234 | # Hiddens: Batch x seq_len x hidden_size 235 | if torch.cuda.is_available(): 236 | hiddens = Variable(torch.zeros(x.size(0), x.size(1), self.hidden_size).cuda()) 237 | else: 238 | hiddens = Variable(torch.zeros(x.size(0), x.size(1), self.hidden_size)) 239 | 240 | # Recurrent Block 241 | for time_step in range(x.size(1)): 242 | 243 | # Feed in x_t one at a time 244 | x_t = x[:, time_step, :] 245 | x_t = x_t.unsqueeze(1) 246 | 247 | #-----input attention----- 248 | if time_step == 0: 249 | x_t = torch.cat((x_t, v_g.unsqueeze(1).expand_as(x_t)), dim=2) 250 | else: 251 | # W_ZV * V + W_Zh * h_t * 1^T 252 | content_v_input = self.affine_ZV_input(self.dropout(V_input)).unsqueeze(1) + self.affine_Zh_input(self.dropout(h_t)).unsqueeze(2) 253 | 254 | # visual_t = W_alphaz * tanh( content_v_input ) 255 | visual_t_input = self.affine_alphaz_input(self.dropout(F.tanh(content_v_input))).squeeze(3) 256 | alpha_t_input = F.softmax(visual_t_input.view(-1, visual_t_input.size(2))).view(visual_t_input.size(0),visual_t_input.size(1), -1) 257 | z_t_input = torch.bmm(alpha_t_input, V_input).squeeze(2) 258 | 259 | #x_t =[embeddings;z_t_input] 260 | x_t = torch.cat((x_t, z_t_input), dim=2) 261 | 262 | h_t, states = self.LSTM(x_t, states) 263 | 264 | # Save hidden 265 | hiddens[:, time_step, :] = h_t 266 | 267 | # Data parallelism for Encoder block 268 | if torch.cuda.device_count() > 1: 269 | device_ids = range( torch.cuda.device_count() ) 270 | encoder_block_parallel = nn.DataParallel( self.encoder, device_ids=device_ids ) 271 | scores = encoder_block_parallel( epoch, hiddens, V, T ) 272 | else: 273 | scores = self.encoder( epoch, hiddens, V, T ) 274 | 275 | # Return states for Caption Sampling purpose 276 | return scores, states 277 | 278 | # Whole Architecture with Image Encoder and Caption decoder 279 | class Encoder2Decoder( nn.Module ): 280 | def __init__( self, embed_size, vocab_size, hidden_size ): 281 | super( Encoder2Decoder, self ).__init__() 282 | 283 | # Image CNN encoder and simNet Decoder 284 | self.encoder = AttentiveCNN( hidden_size ) 285 | self.decoder = Decoder( embed_size, vocab_size, hidden_size ) 286 | 287 | def forward( self, epoch, 
images, captions, lengths, T ): 288 | 289 | if torch.cuda.device_count() > 1: 290 | device_ids = range( torch.cuda.device_count() ) 291 | encoder_parallel = torch.nn.DataParallel( self.encoder, device_ids=device_ids ) 292 | V = encoder_parallel( images ) 293 | else: 294 | V = self.encoder( images ) 295 | 296 | # Language Modeling on word prediction 297 | scores, _ = self.decoder( epoch, V, captions, T ) 298 | 299 | # Pack it to make criterion calculation more efficient 300 | packed_scores = pack_padded_sequence( scores, lengths, batch_first=True ) 301 | 302 | return packed_scores 303 | 304 | # Caption generator 305 | def sampler( self, epoch, images, T, max_len=20 ): 306 | """ 307 | Samples captions for given image features. 308 | """ 309 | 310 | # Data parallelism if multiple GPUs 311 | if torch.cuda.device_count() > 1: 312 | device_ids = range( torch.cuda.device_count() ) 313 | encoder_parallel = torch.nn.DataParallel( self.encoder, device_ids=device_ids ) 314 | V = encoder_parallel( images ) 315 | else: 316 | V = self.encoder( images ) 317 | 318 | # Build the starting token Variable (index 1): B x 1 319 | if torch.cuda.is_available(): 320 | captions = Variable( torch.LongTensor( images.size( 0 ), 1 ).fill_( 1 ).cuda() ) 321 | else: 322 | captions = Variable( torch.LongTensor( images.size( 0 ), 1 ).fill_( 1 ) ) 323 | 324 | # Get generated caption idx list 325 | sampled_ids = [] 326 | 327 | # Initial hidden states 328 | states = None 329 | 330 | for i in range( max_len ): 331 | scores, states = self.decoder( epoch, V, captions, T, states ) 332 | captions = scores.max( 2 )[ 1 ] 333 | 334 | # Save sampled word 335 | sampled_ids.append( captions ) 336 | 337 | # caption: B x max_len 338 | sampled_ids = torch.cat( sampled_ids, dim=1 ) 339 | 340 | return sampled_ids 341 | --------------------------------------------------------------------------------
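For reference, here is a minimal inference sketch that ties the pieces above together. It is illustrative only: it assumes Python 2.7 with PyTorch 0.3, a pickled `Vocabulary` at `vocab.pkl`, a trained checkpoint at `./models/simNet-30.pkl`, and hand-picked topic words; real topic ids would come from `image_topic.json` as in `test.py`.

```
# Illustrative sketch -- the paths, topic words, and sizes below are assumptions.
import pickle
import torch
from torch.autograd import Variable

from build_vocab import Vocabulary
from model import Encoder2Decoder

with open('vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

model = Encoder2Decoder(embed_size=256, vocab_size=len(vocab), hidden_size=512)
model.load_state_dict(torch.load('./models/simNet-30.pkl'))
model.eval()

# One image (batch x 3 x 224 x 224) and its topic word ids (batch x num_topics).
# A real image would be loaded and transformed as in test.py; random data keeps the sketch self-contained.
images = Variable(torch.randn(1, 3, 224, 224), volatile=True)
topics = Variable(torch.LongTensor(
    [[vocab(w) for w in ['dog', 'grass', 'frisbee', 'park', 'running']]]), volatile=True)

if torch.cuda.is_available():
    model.cuda()
    images, topics = images.cuda(), topics.cuda()

# Passing an epoch > 20 enables the learned merging gate instead of the warm-up value gama_t = 1.0.
sampled_ids = model.sampler(30, images, topics)

# Decode word ids until the '<end>' token, as in test.py.
words = []
for word_id in sampled_ids[0].cpu().data.numpy():
    word = vocab.idx2word[word_id]
    if word == '<end>':
        break
    words.append(word)
print ' '.join(words)
```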