├── LICENCE.md
├── README.md
├── coco-caption
│   ├── .gitignore
│   └── myeval.py
├── data
│   ├── coco
│   │   ├── coco_preprocess.ipynb
│   │   └── coco_preprocess_test.ipynb
│   └── flickr30k
│       └── flickr30k_preprocess.ipynb
├── demo.ipynb
├── demo
│   ├── demo_img.jpg
│   ├── fig1.png
│   ├── fig2.png
│   └── fig3.png
├── eval_visulization.lua
├── image_model
│   └── download_model.py
├── misc
│   ├── DataLoaderResNet.lua
│   ├── LSTM.lua
│   ├── LanguageModel.lua
│   ├── LookupTableMaskZero.lua
│   ├── attention.lua
│   ├── call_python_caption_eval.sh
│   ├── img_embedding.lua
│   ├── net_utils.lua
│   ├── optim_updates.lua
│   ├── transforms.lua
│   └── utils.lua
├── prepro
│   ├── prepro_coco.py
│   ├── prepro_coco_test.py
│   └── prepro_flickr.py
├── train.lua
└── visu
    ├── DataLoaderResNetEval.lua
    ├── LanguageModel_visu.lua
    ├── attention_visu.lua
    └── visAtten.lua

/LICENCE.md:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Patent (Pending)
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaptiveAttention
2 | Implementation of "[Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning](https://arxiv.org/pdf/1612.01887.pdf)"
3 | 
4 | ![teaser results](https://raw.github.com/jiasenlu/AdaptiveAttention/master/demo/fig1.png)
5 | 
6 | ### Requirements
7 | 
8 | Training the model requires a GPU with 12GB of memory. If you do not have a GPU, you can directly use the pretrained model for inference.
9 | 
10 | This code is written in Lua and requires [Torch](http://torch.ch/). The preprocessing code is in Python, and you need to install [NLTK](http://www.nltk.org/) if you want to use NLTK to tokenize the captions.
11 | 
12 | You also need to install the following packages in order to successfully run the code.
13 | 
14 | - [cudnn.torch](https://github.com/soumith/cudnn.torch)
15 | - [torch-hdf5](https://github.com/deepmind/torch-hdf5)
16 | - [lua-cjson](http://www.kyne.com.au/~mark/software/lua-cjson.php)
17 | - [iTorch](https://github.com/facebook/iTorch)
18 | 
19 | ##### Pretrained Model
20 | The pre-trained model for COCO can be downloaded [here](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/AdaptiveAttention/model/COCO/).
21 | The pre-trained model for Flickr30K can be downloaded [here](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/AdaptiveAttention/model/Flickr30k/).
22 | 
23 | ##### Vocabulary File
24 | Download the corresponding vocabulary file for [COCO](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/AdaptiveAttention/data/COCO/) and [Flickr30k](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/AdaptiveAttention/data/Flickr30k/).
25 | 
26 | ##### Download Dataset
27 | The first thing you need to do is download the data and do some preprocessing. Head over to the `data/` folder and run the corresponding IPython notebook. It will download and preprocess the caption annotations and generate `coco_raw.json`.
28 | 
29 | Download the [COCO](http://mscoco.org/) and Flickr30k image datasets, extract the images, and put them somewhere on disk.
30 | 
31 | 
32 | ### Training a new model on MS COCO
33 | First, train the language model without fine-tuning the image CNN.
34 | ```
35 | th train.lua -batch_size 20
36 | ```
37 | To fine-tune the CNN, load the saved model and train for another 15~20 epochs.
38 | ```
39 | th train.lua -batch_size 16 -startEpoch 21 -start_from 'model_id1_20.t7'
40 | ```
41 | 
42 | 
43 | ### More results on spatial attention and the visual sentinel
44 | 
45 | ![teaser results](https://raw.github.com/jiasenlu/AdaptiveAttention/master/demo/fig2.png)
46 | 
47 | ![teaser results](https://raw.github.com/jiasenlu/AdaptiveAttention/master/demo/fig3.png)
48 | 
49 | For more visualization results, you can visit [here](https://filebox.ece.vt.edu/~jiasenlu/demo/caption_atten/demo.html)
50 | (it will load more than 1000 images and their results...)
51 | 
52 | ### Reference
53 | If you use this code as part of any published research, please acknowledge the following paper:
54 | ```
55 | @inproceedings{Lu2017Adaptive,
56 |   author = {Lu, Jiasen and Xiong, Caiming and Parikh, Devi and Socher, Richard},
57 |   title = {Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning},
58 |   booktitle = {CVPR},
59 |   year = {2017}
60 | }
61 | ```
62 | 
63 | ### Acknowledgement
64 | 
65 | This code was developed based on [NeuralTalk2](https://github.com/karpathy/neuraltalk2).
66 | 
67 | Thanks to the [Torch](http://torch.ch/) team and to Facebook for the [ResNet](https://github.com/facebook/fb.resnet.torch) implementation.
68 | 
69 | ### License
70 | 
71 | BSD 3-Clause License
72 | 
73 | 
74 | 
75 | 
--------------------------------------------------------------------------------
/coco-caption/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *~
3 | 
--------------------------------------------------------------------------------
/coco-caption/myeval.py:
--------------------------------------------------------------------------------
1 | """
2 | This script should be run from the root directory of this codebase:
3 | https://github.com/tylin/coco-caption
4 | """
5 | 
6 | from pycocotools.coco import COCO
7 | from pycocoevalcap.eval import COCOEvalCap
8 | import json
9 | from json import encoder
10 | encoder.FLOAT_REPR = lambda o: format(o, '.3f')
11 | import sys
12 | import pdb
13 | input_json = sys.argv[1]
14 | annFile = sys.argv[2]
15 | 
16 | coco = COCO(annFile)
17 | valids = coco.getImgIds()
18 | 
19 | checkpoint = json.load(open(input_json, 'r'))
20 | preds = checkpoint['val_predictions']
21 | 
22 | # filter results to only those in MSCOCO validation set (will be about a third)
23 | preds_filt = [p for p in preds if p['image_id'] in valids]
24 | print 'using %d/%d predictions' % (len(preds_filt), len(preds))
25 | json.dump(preds_filt, open('tmp.json', 'w')) # serialize to temporary json file. Sigh, COCO API...
26 | 
27 | resFile = 'tmp.json'
28 | cocoRes = coco.loadRes(resFile)
29 | cocoEval = COCOEvalCap(coco, cocoRes)
30 | cocoEval.params['image_id'] = cocoRes.getImgIds()
31 | cocoEval.evaluate()
32 | 
33 | # create output dictionary
34 | out = {}
35 | for metric, score in cocoEval.eval.items():
36 |     out[metric] = score
37 | # serialize to file, to be read from Lua
38 | json.dump(out, open(input_json + '_out.json', 'w'))
39 | 
40 | 
--------------------------------------------------------------------------------
/data/coco/coco_preprocess.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# COCO data preprocessing\n", 8 | "\n", 9 | "This code will download the caption annotations for COCO and preprocess them into an hdf5 file and a json file. \n", 10 | "\n", 11 | "These will then be read by the COCO data loader in Lua and trained on."
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": { 18 | "collapsed": false 19 | }, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": [ 24 | "0" 25 | ] 26 | }, 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "output_type": "execute_result" 30 | } 31 | ], 32 | "source": [ 33 | "# lets download the annotations from http://mscoco.org/dataset/#download\n", 34 | "import os\n", 35 | "os.system('wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip') # ~19MB" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "metadata": { 42 | "collapsed": false 43 | }, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "256" 49 | ] 50 | }, 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "os.system('unzip captions_train-val2014.zip')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": { 64 | "collapsed": false 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "import json\n", 69 | "val = json.load(open('annotations/captions_val2014.json', 'r'))\n", 70 | "train = json.load(open('annotations/captions_train2014.json', 'r'))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": { 77 | "collapsed": false 78 | }, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "[u'info', u'images', u'licenses', u'annotations']\n", 85 | "{u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-01-27 09:11:52.357475'}\n", 86 | "40504\n", 87 | "202654\n", 88 | "{u'license': 3, u'file_name': u'COCO_val2014_000000391895.jpg', u'coco_url': u'http://mscoco.org/images/391895', u'height': 360, u'width': 640, u'date_captured': u'2013-11-14 11:18:45', u'flickr_url': u'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg', u'id': 391895}\n", 89 | "{u'image_id': 203564, u'id': 37, u'caption': u'A bicycle replica with a clock as the front wheel.'}\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "print val.keys()\n", 95 | "print val['info']\n", 96 | "print len(val['images'])\n", 97 | "print len(val['annotations'])\n", 98 | "print val['images'][0]\n", 99 | "print val['annotations'][0]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": { 106 | "collapsed": false 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "import json\n", 111 | "import os\n", 112 | "\n", 113 | "# combine all images and annotations together\n", 114 | "imgs = val['images'] + train['images']\n", 115 | "annots = val['annotations'] + train['annotations']\n", 116 | "\n", 117 | "# for efficiency lets group annotations by image\n", 118 | "itoa = {}\n", 119 | "for a in annots:\n", 120 | " imgid = a['image_id']\n", 121 | " if not imgid in itoa: itoa[imgid] = []\n", 122 | " itoa[imgid].append(a)\n", 123 | "\n", 124 | "# create the json blob\n", 125 | "out = []\n", 126 | "for i,img in enumerate(imgs):\n", 127 | " imgid = img['id']\n", 128 | " \n", 129 | " # coco specific here, they store train/val images separately\n", 130 | " loc = 'train2014' if 'train' in img['file_name'] else 'val2014'\n", 131 | " \n", 132 | " jimg = {}\n", 133 | " jimg['file_path'] = os.path.join(loc, img['file_name'])\n", 134 | " jimg['id'] = imgid\n", 135 | " \n", 136 | " sents = 
[]\n", 137 | " annotsi = itoa[imgid]\n", 138 | " for a in annotsi:\n", 139 | " sents.append(a['caption'])\n", 140 | " jimg['captions'] = sents\n", 141 | " out.append(jimg)\n", 142 | " \n", 143 | "json.dump(out, open('coco_raw.json', 'w'))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 11, 149 | "metadata": { 150 | "collapsed": false 151 | }, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "u'train2014/COCO_train2014_000000475546.jpg'" 157 | ] 158 | }, 159 | "execution_count": 11, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "jimg['file_path']" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 7, 171 | "metadata": { 172 | "collapsed": false 173 | }, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895}\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# lets see what they look like\n", 185 | "print out[0]" 186 | ] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "Python 2", 192 | "language": "python", 193 | "name": "python2" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 2 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython2", 205 | "version": "2.7.11" 206 | } 207 | }, 208 | "nbformat": 4, 209 | "nbformat_minor": 0 210 | } 211 | -------------------------------------------------------------------------------- /data/coco/coco_preprocess_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# COCO data preprocessing\n", 8 | "\n", 9 | "This code will download the caption anotations for coco and preprocess them into an hdf5 file and a json file. \n", 10 | "\n", 11 | "These will then be read by the COCO data loader in Lua and trained on." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": { 18 | "collapsed": false 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import json\n", 23 | "test = json.load(open('annotations/image_info_test2014.json', 'r'))" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 8, 29 | "metadata": { 30 | "collapsed": false 31 | }, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "[u'info', u'images', u'licenses', u'categories']\n", 38 | "{u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-11-11 02:11:36.777541'}\n", 39 | "40775\n", 40 | "{u'license': 2, u'file_name': u'COCO_test2014_000000523573.jpg', u'coco_url': u'http://mscoco.org/images/523573', u'height': 500, u'width': 423, u'date_captured': u'2013-11-14 12:21:59', u'id': 523573}\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "print test.keys()\n", 46 | "print test['info']\n", 47 | "print len(test['images'])\n", 48 | "print test['images'][0]" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 9, 54 | "metadata": { 55 | "collapsed": false 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "import json\n", 60 | "import os\n", 61 | "\n", 62 | "# combine all images and annotations together\n", 63 | "imgs = test['images']\n", 64 | "\n", 65 | "# create the json blob\n", 66 | "out = []\n", 67 | "for i,img in enumerate(imgs):\n", 68 | " imgid = img['id']\n", 69 | " \n", 70 | " # coco specific here, they store train/val images separately\n", 71 | " loc = 'test2014'\n", 72 | " \n", 73 | " jimg = {}\n", 74 | " jimg['file_path'] = os.path.join(loc, img['file_name'])\n", 75 | " jimg['id'] = imgid\n", 76 | " \n", 77 | " out.append(jimg)\n", 78 | " \n", 79 | "json.dump(out, open('coco_test_raw.json', 'w'))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 6, 85 | "metadata": { 86 | "collapsed": false 87 | }, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "u'val2014/COCO_test2014_000000155724.jpg'" 93 | ] 94 | }, 95 | "execution_count": 6, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "jimg['file_path']" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 7, 107 | "metadata": { 108 | "collapsed": false 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "{'file_path': u'val2014/COCO_test2014_000000523573.jpg', 'id': 523573}\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "# lets see what they look like\n", 121 | "print out[0]" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": { 128 | "collapsed": true 129 | }, 130 | "outputs": [], 131 | "source": [] 132 | } 133 | ], 134 | "metadata": { 135 | "kernelspec": { 136 | "display_name": "Python [Root]", 137 | "language": "python", 138 | "name": "Python [Root]" 139 | }, 140 | "language_info": { 141 | "codemirror_mode": { 142 | "name": "ipython", 143 | "version": 2 144 | }, 145 | "file_extension": ".py", 146 | "mimetype": "text/x-python", 147 | "name": "python", 148 | "nbconvert_exporter": "python", 149 | "pygments_lexer": "ipython2", 150 | "version": "2.7.12" 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 0 155 | } 156 | 
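Note: the notebooks above only produce `coco_raw.json` / `coco_test_raw.json`. The vocabulary and the `cocotalk.h5` / `cocotalk.json` files that the Lua data loader reads are built by the scripts under `prepro/`, which are not shown in this section. The snippet below is a hypothetical sketch of that step, assuming NLTK tokenization as suggested in the README; the variable names and the rare-word threshold are illustrative and not taken from `prepro_coco.py`.

```python
# Hypothetical sketch only: prepro/prepro_coco.py presumably does something along
# these lines; names and the count threshold here are illustrative assumptions.
import json
from collections import Counter
from nltk.tokenize import word_tokenize  # requires the NLTK 'punkt' tokenizer data

imgs = json.load(open('coco_raw.json', 'r'))

# count word frequencies over all ground-truth captions
counts = Counter()
for img in imgs:
    for caption in img['captions']:
        counts.update(w.lower() for w in word_tokenize(caption))

# keep frequent words; rare words are typically mapped to an UNK token downstream
word_count_threshold = 5  # assumed cutoff, not taken from the repo
vocab = [w for w, n in counts.items() if n >= word_count_threshold]

# 1-indexed mapping, in the spirit of the ix_to_word table read by misc/DataLoaderResNet.lua
ix_to_word = {i + 1: w for i, w in enumerate(vocab)}
print('vocab size: %d' % len(ix_to_word))
```

The resulting `ix_to_word`-style mapping is what `misc/DataLoaderResNet.lua` later loads from the `cocotalk.json` vocabulary file.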
-------------------------------------------------------------------------------- /demo/demo_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/AdaptiveAttention/2618080c02d9cd708f36c65b199d41c4ab225087/demo/demo_img.jpg -------------------------------------------------------------------------------- /demo/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/AdaptiveAttention/2618080c02d9cd708f36c65b199d41c4ab225087/demo/fig1.png -------------------------------------------------------------------------------- /demo/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/AdaptiveAttention/2618080c02d9cd708f36c65b199d41c4ab225087/demo/fig2.png -------------------------------------------------------------------------------- /demo/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/AdaptiveAttention/2618080c02d9cd708f36c65b199d41c4ab225087/demo/fig3.png -------------------------------------------------------------------------------- /eval_visulization.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'nngraph' 4 | -- local imports 5 | require 'visu.DataLoaderResNetEval' 6 | local utils = require 'misc.utils' 7 | require 'visu.LanguageModel_visu' 8 | local net_utils = require 'misc.net_utils' 9 | require 'misc.optim_updates' 10 | require 'gnuplot' 11 | require 'xlua' 12 | ------------------------------------------------------------------------------- 13 | -- Input arguments and options 14 | ------------------------------------------------------------------------------- 15 | cmd = torch.CmdLine() 16 | cmd:text() 17 | cmd:text('Train an Image Captioning model') 18 | cmd:text() 19 | cmd:text('Options') 20 | 21 | 22 | -- Model settings 23 | --[[ 24 | cmd:option('-dataset','flickr30k','') 25 | cmd:option('-input_h5','/data/flickr30k/cocotalk.h5','path to the h5file containing the preprocessed dataset') 26 | cmd:option('-input_json','/data/flickr30k/cocotalk.json','path to the json file containing additional info and vocab') 27 | cmd:option('-cnn_model','../image_model/resnet-152.t7','path to CNN model file containing the weights, Caffe format. Note this MUST be a VGGNet-16 right now.') 28 | ]]-- 29 | 30 | cmd:option('-input_h5','/data/coco/cocotalk.h5','path to the h5file containing the preprocessed dataset') 31 | cmd:option('-input_json','/data/coco/cocotalk.json','path to the json file containing additional info and vocab') 32 | cmd:option('-cnn_model','../image_model/resnet-152.t7','path to CNN model file containing the weights, Caffe format. Note this MUST be a VGGNet-16 right now.') 33 | cmd:option('-checkpoint_path', 'save/coco_val_1', 'folder to save checkpoints into (empty = this folder)') 34 | 35 | --[[ 36 | cmd:option('-input_h5','/data/coco/cocotalk_test.h5','path to the h5file containing the preprocessed dataset') 37 | cmd:option('-input_json','/data/coco/cocotalk_test.json','path to the json file containing additional info and vocab') 38 | cmd:option('-input_vocab_json','/data/coco/cocotalk.json','path to the json file containing additional info and vocab') 39 | cmd:option('-cnn_model','../image_model/resnet-152.t7','path to CNN model file containing the weights, Caffe format. 
Note this MUST be a VGGNet-16 right now.') 40 | ]]-- 41 | cmd:option('-start_from', 'model_id1_36.t7', 'path to a model checkpoint to initialize model weights from. Empty = don\'t') 42 | cmd:option('-beam_size', 3, 'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 43 | --cmd:option('-checkpoint_path', 'save/flickr30k_512x1_1', 'folder to save checkpoints into (empty = this folder)') 44 | 45 | cmd:option('-drop_prob_lm', 0.5, 'strength of dropout in the Language Model RNN') 46 | cmd:option('-rnn_size',512,'size of the rnn in number of hidden nodes in each layer') 47 | cmd:option('-num_layers',1,'the encoding size of each token in the vocabulary, and the image.') 48 | cmd:option('-input_encoding_size',512,'the encoding size of each token in the vocabulary, and the image.') 49 | cmd:option('-batch_size',10,'what is the batch size in number of images per batch? (there will be x seq_per_img sentences)') 50 | 51 | cmd:option('-fc_size',2048,'the encoding size of each token in the vocabulary, and the image.') 52 | cmd:option('-conv_size',2048,'the encoding size of each token in the vocabulary, and the image.') 53 | cmd:option('-seq_per_img',5,'number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') 54 | 55 | cmd:option('-val_images_use', -1, 'how many images to use when periodically evaluating the validation loss? (-1 = all)') 56 | cmd:option('-save_checkpoint_every', 3, 'how often to save a model checkpoint?') 57 | cmd:option('-language_eval', 0, 'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 58 | 59 | -- misc 60 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 61 | cmd:option('-id', '1', 'an id identifying this run/job. used in cross-val and appended when writing progress files') 62 | cmd:option('-seed', 123, 'random number generator seed to use') 63 | cmd:option('-gpuid', 0, 'which gpu to use. 
-1 = use CPU') 64 | 65 | cmd:text() 66 | 67 | ------------------------------------------------------------------------------- 68 | -- Basic Torch initializations 69 | ------------------------------------------------------------------------------- 70 | local opt = cmd:parse(arg) 71 | --torch.manualSeed(opt.seed) 72 | torch.setdefaulttensortype('torch.FloatTensor') -- for CPU 73 | 74 | if opt.gpuid >= 0 then 75 | require 'cutorch' 76 | require 'cunn' 77 | if opt.backend == 'cudnn' then require 'cudnn' end 78 | --cutorch.manualSeed(opt.seed) 79 | cutorch.setDevice(opt.gpuid + 1) -- note +1 because lua is 1-indexed 80 | end 81 | 82 | ------------------------------------------------------------------------------- 83 | -- Create the Data Loader instance 84 | ------------------------------------------------------------------------------- 85 | local loader = DataLoader{h5_file = opt.input_h5, json_file = opt.input_json, neighbor_h5 = opt.nn_neighbor, 86 | batch_size = opt.batch_size, seq_per_img = opt.seq_per_img, thread_num = opt.thread_num} 87 | --local loader = DataLoader{h5_file = opt.input_h5, json_file = opt.input_json, vocab_json_file = opt.input_vocab_json,neighbor_h5 = opt.nn_neighbor, 88 | -- batch_size = opt.batch_size, seq_per_img = opt.seq_per_img, thread_num = opt.thread_num} 89 | 90 | ------------------------------------------------------------------------------- 91 | -- Initialize the networks 92 | ------------------------------------------------------------------------------- 93 | -- create protos from scratch 94 | -- intialize language model 95 | local lmOpt = {} 96 | lmOpt.vocab_size = loader:getVocabSize() 97 | lmOpt.input_encoding_size = opt.input_encoding_size 98 | lmOpt.rnn_size = opt.rnn_size 99 | lmOpt.num_layers = opt.num_layers 100 | lmOpt.dropout = opt.drop_prob_lm 101 | lmOpt.seq_length = loader:getSeqLength() 102 | lmOpt.batch_size = opt.batch_size * opt.seq_per_img 103 | lmOpt.fc_size = opt.fc_size 104 | lmOpt.conv_size = opt.conv_size 105 | 106 | local loaded_checkpoint 107 | if opt.start_from ~= '' then -- just copy to gpu1 params 108 | local loaded_checkpoint_path = path.join(opt.checkpoint_path, opt.start_from) 109 | print(loaded_checkpoint_path) 110 | loaded_checkpoint = torch.load(loaded_checkpoint_path) 111 | end 112 | 113 | -- iterate over different gpu 114 | local protos = {} 115 | 116 | protos.lm = nn.LanguageModel(lmOpt):cuda() 117 | -- initialize the ConvNet 118 | if opt.start_from ~= '' then -- just copy to gpu1 params 119 | protos.cnn_conv_fix = loaded_checkpoint.protos.cnn_conv_fix:cuda() 120 | protos.cnn_conv = loaded_checkpoint.protos.cnn_conv:cuda() 121 | protos.cnn_fc = loaded_checkpoint.protos.cnn_fc:cuda() 122 | else 123 | local cnn_raw = torch.load(opt.cnn_model) 124 | 125 | protos.cnn_conv_fix = net_utils.build_residual_cnn_conv_fix(cnn_raw, 126 | {backend = cnn_backend, start_layer_num = opt.finetune_start_layer}):cuda() 127 | 128 | protos.cnn_conv = net_utils.build_residual_cnn_conv(cnn_raw, 129 | {backend = cnn_backend, start_layer_num = opt.finetune_start_layer}):cuda() 130 | 131 | protos.cnn_fc = net_utils.build_residual_cnn_fc(cnn_raw, 132 | {backend = cnn_backend}):cuda() 133 | end 134 | protos.expanderConv = nn.FeatExpanderConv(opt.seq_per_img):cuda() 135 | protos.expanderFC = nn.FeatExpander(opt.seq_per_img):cuda() 136 | protos.transform_cnn_conv = net_utils.transform_cnn_conv(opt.conv_size):cuda() 137 | -- criterion for the language model 138 | protos.crit = nn.LanguageModelCriterion():cuda() 139 | 140 | params, grad_params = 
protos.lm:getParameters() 141 | cnn1_params, cnn1_grad_params = protos.cnn_conv:getParameters() 142 | 143 | print('total number of parameters in LM: ', params:nElement()) 144 | print('total number of parameters in CNN_conv: ', cnn1_params:nElement()) 145 | 146 | assert(params:nElement() == grad_params:nElement()) 147 | assert(cnn1_params:nElement() == cnn1_grad_params:nElement()) 148 | 149 | if opt.start_from ~= '' then -- just copy to gpu1 params 150 | params:copy(loaded_checkpoint.lmparam) 151 | end 152 | 153 | protos.lm:createClones() 154 | 155 | -- create clones and ensure parameter sharing. we have to do this 156 | -- all the way here at the end because calls such as :cuda() and 157 | -- :getParameters() reshuffle memory around. 158 | 159 | collectgarbage() -- "yeah, sure why not" 160 | ------------------------------------------------------------------------------- 161 | -- Evaluation fun(ction) 162 | ------------------------------------------------------------------------------- 163 | local function evaluate_split(split, evalopt) 164 | local val_images_use = utils.getopt(evalopt, 'val_images_use', -1) 165 | 166 | print('=> evaluating ...') 167 | 168 | -- setting to the evaluation mode, use only the first gpu 169 | protos.cnn_conv:evaluate() 170 | protos.cnn_fc:evaluate() 171 | protos.lm:evaluate() 172 | protos.cnn_conv_fix:evaluate() 173 | 174 | local n = 0 175 | local loss_sum = 0 176 | local loss_evals = 0 177 | local predictions = {} 178 | local vocab = loader:getVocab() 179 | local imgId_cell = {} 180 | 181 | local nbatch = math.ceil(val_images_use / opt.batch_size) 182 | if val_images_use == -1 then 183 | nbatch = loader:getnBatch(split) 184 | end 185 | 186 | loader:init_rand(split) 187 | loader:reset_iterator(split) 188 | 189 | local atten_out_all = torch.FloatTensor(loader:getSeqLength()+1, 5*nbatch*opt.batch_size, 50):zero() 190 | --for n, data in loader:run({split = split, size_image_use = val_images_use}) do 191 | for n = 1, nbatch do 192 | local data = loader:run({split = split, size_image_use = val_images_use}) 193 | xlua.progress(n,nbatch) 194 | 195 | -- convert the data to cuda 196 | data.images = data.images:cuda() 197 | data.labels = data.labels:cuda() 198 | 199 | -- forward the model to get loss 200 | local feats_conv_fix = protos.cnn_conv_fix:forward(data.images) 201 | 202 | local feats_conv = protos.cnn_conv:forward(feats_conv_fix) 203 | local feat_conv_t = protos.transform_cnn_conv:forward(feats_conv) 204 | local feats_fc = protos.cnn_fc:forward(feats_conv) 205 | 206 | local expanded_feats_conv = protos.expanderConv:forward(feat_conv_t) 207 | local expanded_feats_fc = protos.expanderFC:forward(feats_fc) 208 | 209 | local logprobs, atten = protos.lm:forward({expanded_feats_conv, expanded_feats_fc, data.labels}) 210 | --local loss = protos.crit:forward({logprobs, data.labels}) 211 | --loss_sum = loss_sum + loss 212 | --loss_evals = loss_evals + 1 213 | 214 | -- forward the model to also get generated samples for each image 215 | local sampleOpt = {beam_size = opt.beam_size} 216 | --local seq, atten = protos.lm:sample({feat_conv_t, feats_fc, vocab}, sampleOpt) 217 | local sents, count = net_utils.decode_sequence(vocab, data.labels) 218 | 219 | local s = (n-1)*opt.batch_size*5+1 220 | atten_out_all:narrow(2,s,opt.batch_size*5):copy(atten) 221 | 222 | for k=1,#sents do 223 | local idx = math.floor((k-1)/5)+1 224 | local img_id = data.img_id[idx] 225 | local entry 226 | --if imgId_cell[img_id] == nil then -- make sure there are one caption for each image. 
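        -- atten holds the attention weights returned by the visualization language
        -- model for every timestep; its first entry appears to correspond to the
        -- visual-sentinel weight (beta in the paper), so 1 - atten[m][k][1] below is
        -- recorded as an estimate of how strongly word m of caption k was visually grounded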
227 | --imgId_cell[img_id] = 1 228 | local prob_tmp = {} 229 | for m = 1, count[k] do 230 | table.insert(prob_tmp, 1-atten[m][k][1]) 231 | end 232 | entry = {image_id = img_id, caption = sents[k], prob = prob_tmp} 233 | table.insert(predictions, entry) 234 | --end 235 | end 236 | end 237 | local lang_stats 238 | if opt.language_eval == 1 then 239 | lang_stats = net_utils.language_eval(predictions, {id = opt.id, dataset = opt.dataset}) 240 | end 241 | 242 | return predictions, lang_stats, atten_out_all 243 | end 244 | 245 | local split_predictions, lang_stats, atten_out_all = evaluate_split('test', {val_images_use = opt.val_images_use, verbose = opt.verbose}) 246 | 247 | if lang_stats then 248 | print(lang_stats) 249 | end 250 | 251 | utils.write_json('visu_gt_test.json', split_predictions) 252 | torch.save('atten_gt_test_1.t7', atten_out_all) 253 | 254 | 255 | -------------------------------------------------------------------------------- /image_model/download_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download the VGG and deep residual model to extract image features. 3 | 4 | Version: 1.0 5 | Contributor: Jiasen Lu 6 | """ 7 | 8 | import os 9 | import argparse 10 | import json 11 | def download_VGG(): 12 | print('Downloading VGG model from http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel') 13 | os.system('wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_19_layers.caffemodel') 14 | os.system('wget https://gist.githubusercontent.com/ksimonyan/3785162f95cd2d5fee77/raw/bb2b4fe0a9bb0669211cf3d0bc949dfdda173e9e/VGG_ILSVRC_19_layers_deploy.prototxt') 15 | 16 | def download_deep_residual(): 17 | print('Downloading deep residual model from https://d2j0dndfm35trm.cloudfront.net/resnet-152.t7') 18 | os.system('wget https://d2j0dndfm35trm.cloudfront.net/resnet-152.t7') 19 | os.system('wget https://raw.githubusercontent.com/facebook/fb.resnet.torch/master/datasets/transforms.lua') 20 | 21 | def main(params): 22 | if params['download'] == 'VGG': 23 | download_VGG() 24 | else: 25 | download_deep_residual() 26 | 27 | if __name__ == "__main__": 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--download', default='Residual', help='VGG or Residual') 31 | # input json 32 | args = parser.parse_args() 33 | params = vars(args) 34 | print 'parsed input parameters:' 35 | print json.dumps(params, indent = 2) 36 | main(params) 37 | -------------------------------------------------------------------------------- /misc/DataLoaderResNet.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | local utils = require 'misc.utils' 3 | local net_utils = require 'misc.net_utils' 4 | local t = require 'misc.transforms' 5 | 6 | local DataLoader = torch.class('DataLoader') 7 | 8 | function DataLoader:__init(opt) 9 | 10 | -- load the json file which contains additional information about the dataset 11 | print('DataLoader loading json file: ', opt.json_file) 12 | self.info = utils.read_json(opt.json_file) 13 | self.ix_to_word = self.info.ix_to_word 14 | self.vocab_size = utils.count_keys(self.ix_to_word) 15 | 16 | self.batch_size = utils.getopt(opt, 'batch_size', 5) -- how many images get returned at one time (to go through CNN) 17 | self.seq_per_img = utils.getopt(opt, 'seq_per_img', 5) -- number of sequences to return per image 18 | 19 | print('vocab size is ' .. 
self.vocab_size) 20 | 21 | -- open the hdf5 file 22 | print('DataLoader loading h5 file: ', opt.h5_file) 23 | self.h5_file = hdf5.open(opt.h5_file, 'r') 24 | 25 | -- extract image size from dataset 26 | local images_size = self.h5_file:read('/images'):dataspaceSize() 27 | assert(#images_size == 4, '/images should be a 4D tensor') 28 | assert(images_size[3] == images_size[4], 'width and height must match') 29 | self.num_images = images_size[1] 30 | self.num_channels = images_size[2] 31 | self.max_image_size = images_size[3] 32 | 33 | self.imgs = self.h5_file:read('/images'):all() 34 | 35 | print(string.format('read %d images of size %dx%dx%d', self.num_images, 36 | self.num_channels, self.max_image_size, self.max_image_size)) 37 | 38 | -- load in the sequence data 39 | local seq_size = self.h5_file:read('/labels'):dataspaceSize() 40 | self.seq_length = seq_size[2] 41 | print('max sequence length in data is ' .. self.seq_length) 42 | -- load the pointers in full to RAM (should be small enough) 43 | self.label_start_ix = self.h5_file:read('/label_start_ix'):all() 44 | self.label_end_ix = self.h5_file:read('/label_end_ix'):all() 45 | self.labels = self.h5_file:read('/labels'):all() 46 | self.label_lens = self.h5_file:read('/label_length'):all() 47 | -- separate out indexes for each of the provided splits 48 | self.split_ix = {} 49 | self.iterator = {} 50 | self.image_ids = torch.LongTensor(self.num_images):zero() 51 | for i,img in pairs(self.info.images) do 52 | local split = img.split 53 | if not self.split_ix[split] then 54 | -- initialize new split 55 | self.split_ix[split] = {} 56 | self.iterator[split] = 1 57 | end 58 | table.insert(self.split_ix[split], i) 59 | self.image_ids[i] = img.id 60 | end 61 | 62 | self.__size = {} 63 | for k,v in pairs(self.split_ix) do 64 | print(string.format('assigned %d images to split %s', #v, k)) 65 | end 66 | 67 | self.meanstd = { 68 | mean = { 0.485, 0.456, 0.406 }, 69 | std = { 0.229, 0.224, 0.225 }, 70 | } 71 | 72 | self.transform = t.Compose{ 73 | t.ColorNormalize(self.meanstd) 74 | } 75 | end 76 | 77 | function DataLoader:init_rand(split) 78 | local size = #self.split_ix[split] 79 | if split == 'train' then 80 | self.perm = torch.randperm(size) 81 | else 82 | self.perm = torch.range(1,size) -- for test and validation, do not permutate 83 | end 84 | end 85 | 86 | function DataLoader:reset_iterator(split) 87 | self.iterator[split] = 1 88 | end 89 | 90 | function DataLoader:getVocabSize() 91 | return self.vocab_size 92 | end 93 | 94 | function DataLoader:getVocab() 95 | return self.ix_to_word 96 | end 97 | 98 | function DataLoader:getSeqLength() 99 | return self.seq_length 100 | end 101 | 102 | function DataLoader:getnBatch(split) 103 | return math.ceil(#self.split_ix[split] / self.batch_size) 104 | end 105 | 106 | function DataLoader:run(opt) 107 | local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety 108 | local size_image_use = utils.getopt(opt, 'size_image_use', -1) 109 | local size, batch_size = #self.split_ix[split], self.batch_size 110 | local seq_per_img, seq_length = self.seq_per_img, self.seq_length 111 | local num_channels, max_image_size = self.num_channels, self.max_image_size 112 | 113 | if size_image_use ~= -1 and size_image_use <= size then size = size_image_use end 114 | local split_ix = self.split_ix[split] 115 | local idx = self.iterator[split] 116 | 117 | if idx <= size then 118 | 119 | local indices = self.perm:narrow(1, idx, math.min(batch_size, size - idx + 1)) 120 | 121 | local 
img_batch_raw = torch.ByteTensor(batch_size, 3, 256, 256) 122 | local label_batch = torch.LongTensor(batch_size * seq_per_img, seq_length):zero() 123 | local img_id_batch = torch.LongTensor(batch_size*seq_per_img):zero() 124 | for i, ixm in ipairs(indices:totable()) do 125 | 126 | local ix = split_ix[ixm] 127 | img_batch_raw[i] = self.imgs[ix] 128 | 129 | -- fetch the sequence labels 130 | local ix1 = self.label_start_ix[ix] 131 | local ix2 = self.label_end_ix[ix] 132 | 133 | local ncap = ix2 - ix1 + 1 -- number of captions available for this image 134 | assert(ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t') 135 | local seq 136 | if ncap < seq_per_img then 137 | -- we need to subsample (with replacement) 138 | seq = torch.LongTensor(seq_per_img, seq_length) 139 | for q=1, seq_per_img do 140 | local ixl = torch.random(ix1,ix2) 141 | seq[{{q,q}}] = self.labels[{{ixl, ixl}, {1,seq_length}}] 142 | end 143 | else 144 | -- there is enough data to read a contiguous chunk, but subsample the chunk position 145 | local ixl = torch.random(ix1, ix2 - seq_per_img + 1) -- generates integer in the range 146 | seq = self.labels[{{ixl, ixl+seq_per_img-1}, {1,seq_length}}] 147 | end 148 | 149 | local il = (i-1)*seq_per_img+1 150 | label_batch[{{il,il+seq_per_img-1} }] = seq 151 | img_id_batch[i] = self.image_ids[ix] 152 | end 153 | 154 | local data_augment = false 155 | if split == 'train' then 156 | data_augment = true 157 | end 158 | 159 | local h,w = img_batch_raw:size(3), img_batch_raw:size(4) 160 | local cnn_input_size = 224 161 | -- cropping data augmentation, if needed 162 | if h > cnn_input_size or w > cnn_input_size then 163 | local xoff, yoff 164 | if data_augment then 165 | xoff, yoff = torch.random(w-cnn_input_size), torch.random(h-cnn_input_size) 166 | else 167 | -- sample the center 168 | xoff, yoff = math.ceil((w-cnn_input_size)/2), math.ceil((h-cnn_input_size)/2) 169 | end 170 | -- crop. 
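      -- the 224x224 window selected below is later scaled to [0,1] and normalized
      -- with the ImageNet channel mean/std in self.meanstd via self.transform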
171 | img_batch_raw = img_batch_raw[{ {}, {}, {yoff,yoff+cnn_input_size-1}, {xoff,xoff+cnn_input_size-1}}] 172 | end 173 | 174 | img_batch_raw = self.transform(img_batch_raw:float():div(255)) 175 | --img_batch_raw = img_batch_raw:float():div(255) 176 | 177 | local batch_data = {} 178 | batch_data.labels = label_batch:transpose(1,2):contiguous() 179 | batch_data.images = img_batch_raw 180 | batch_data.img_id = img_id_batch 181 | 182 | self.iterator[split] = self.iterator[split] + batch_size 183 | return batch_data 184 | end 185 | end 186 | 187 | -------------------------------------------------------------------------------- /misc/LSTM.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | local LSTM = {} 5 | function LSTM.lstm(input_size, rnn_size, n, dropout) 6 | dropout = dropout or 0 7 | 8 | -- there will be 2*n+1 inputs 9 | local inputs = {} 10 | table.insert(inputs, nn.Identity()()) -- indices giving the sequence of symbols 11 | table.insert(inputs, nn.Identity()()) -- indices giving the image feature 12 | for L = 1,n do 13 | table.insert(inputs, nn.Identity()()) -- prev_c[L] 14 | table.insert(inputs, nn.Identity()()) -- prev_h[L] 15 | end 16 | 17 | local img_fc = inputs[2] 18 | 19 | local x, input_size_L, i2h, fake_region, atten_region 20 | local outputs = {} 21 | for L = 1,n do 22 | -- c,h from previos timesteps 23 | local prev_h = inputs[L*2+2] 24 | local prev_c = inputs[L*2+1] 25 | -- the input to this layer 26 | if L == 1 then 27 | x = inputs[1] 28 | input_size_L = input_size 29 | local w2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='t2h_'..L} 30 | local v2h = nn.Linear(input_size_L, 4 * rnn_size)(img_fc):annotate{name='v2h_'..L} 31 | i2h = nn.CAddTable()({w2h, v2h}) 32 | else 33 | x = outputs[(L-1)*2] 34 | if dropout > 0 then x = nn.Dropout(dropout)(x):annotate{name='drop_' .. 
L} end -- apply dropout, if any 35 | input_size_L = rnn_size 36 | i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} 37 | end 38 | 39 | local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} 40 | local all_input_sums = nn.CAddTable()({i2h, h2h}) 41 | 42 | local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) 43 | local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) 44 | -- decode the gates 45 | local in_gate = nn.Sigmoid()(n1) 46 | local forget_gate = nn.Sigmoid()(n2) 47 | local out_gate = nn.Sigmoid()(n3) 48 | -- decode the write inputs 49 | local in_transform = nn.Tanh()(n4) 50 | -- perform the LSTM update 51 | local next_c = nn.CAddTable()({ 52 | nn.CMulTable()({forget_gate, prev_c}), 53 | nn.CMulTable()({in_gate, in_transform}) 54 | }) 55 | 56 | local tanh_nex_c = nn.Tanh()(next_c) 57 | -- gated cells form the output 58 | local next_h = nn.CMulTable()({out_gate,tanh_nex_c}) 59 | if L == n then 60 | if L==1 then 61 | local w2h = nn.Linear(input_size_L, 1 * rnn_size)(x) 62 | local v2h = nn.Linear(input_size_L, 1 * rnn_size)(img_fc) 63 | i2h = nn.CAddTable()({w2h, v2h}) 64 | else 65 | i2h = nn.Linear(input_size_L, rnn_size)(x) 66 | end 67 | local h2h = nn.Linear(rnn_size, rnn_size)(prev_h) 68 | local n5 = nn.CAddTable()({i2h, h2h}) 69 | 70 | fake_region = nn.CMulTable()({nn.Sigmoid()(n5), tanh_nex_c}) 71 | end 72 | 73 | table.insert(outputs, next_c) 74 | table.insert(outputs, next_h) 75 | end 76 | -- set up the decoder 77 | local top_h = nn.Identity()(outputs[#outputs]) 78 | if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end 79 | if dropout > 0 then fake_region = nn.Dropout(dropout)(fake_region) end 80 | 81 | table.insert(outputs, top_h) 82 | table.insert(outputs, fake_region) 83 | return nn.gModule(inputs, outputs) 84 | end 85 | 86 | return LSTM 87 | 88 | -------------------------------------------------------------------------------- /misc/LanguageModel.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'misc.LookupTableMaskZero' 3 | local utils = require 'misc.utils' 4 | local net_utils = require 'misc.net_utils' 5 | local LSTM = require 'misc.LSTM' 6 | 7 | local attention = require 'misc.attention' 8 | local img_embedding = require 'misc.img_embedding' 9 | 10 | ------------------------------------------------------------------------------- 11 | -- Language Model core 12 | ------------------------------------------------------------------------------- 13 | 14 | local layer, parent = torch.class('nn.LanguageModel', 'nn.Module') 15 | function layer:__init(opt) 16 | parent.__init(self) 17 | 18 | -- options for core network 19 | self.vocab_size = utils.getopt(opt, 'vocab_size') -- required 20 | self.input_encoding_size = utils.getopt(opt, 'input_encoding_size') 21 | self.n_rnn_layer = utils.getopt(opt, 'n_rnn_layer', 1) 22 | 23 | self.rnn_size = utils.getopt(opt, 'rnn_size') 24 | self.num_layers = utils.getopt(opt, 'num_layers', 1) 25 | local dropout = utils.getopt(opt, 'dropout', 0) 26 | 27 | self.fc_size = utils.getopt(opt, 'fc_size', 4096) 28 | self.conv_size = utils.getopt(opt, 'conv_size', 512) 29 | 30 | -- options for Language Model 31 | self.seq_length = utils.getopt(opt, 'seq_length') 32 | 33 | print('rnn_size: ' .. self.rnn_size .. ' num_layers: ' .. self.num_layers) 34 | print('input_encoding_size: ' .. self.input_encoding_size) 35 | print('dropout rate: ' .. dropout) 36 | 37 | -- create the core lstm network. 
note +1 for both the START and END tokens 38 | self.core = LSTM.lstm(self.input_encoding_size, self.rnn_size, self.num_layers, dropout) 39 | 40 | self.lookup_table = nn.Sequential() 41 | :add(nn.LookupTableMaskZero(self.vocab_size+1, self.input_encoding_size)) 42 | :add(nn.ReLU()) 43 | :add(nn.Dropout(dropout)) 44 | 45 | self.img_embedding = img_embedding.img_embedding(self.input_encoding_size, self.fc_size, self.conv_size, 49, dropout) 46 | 47 | self.attention = attention.attention(self.input_encoding_size, self.rnn_size, self.vocab_size+1, dropout) 48 | 49 | self:_createInitState(1) -- will be lazily resized later during forward passes 50 | end 51 | 52 | function layer:_createInitState(batch_size) 53 | assert(batch_size ~= nil, 'batch size must be provided') 54 | -- construct the initial state for the LSTM 55 | if not self.init_state then self.init_state = {} end -- lazy init 56 | for h=1,self.num_layers*2 do 57 | -- note, the init state Must be zeros because we are using init_state to init grads in backward call too 58 | if self.init_state[h] then 59 | if self.init_state[h]:size(1) ~= batch_size then 60 | self.init_state[h]:resize(batch_size, self.rnn_size):zero() -- expand the memory 61 | end 62 | else 63 | self.init_state[h] = torch.zeros(batch_size, self.rnn_size) 64 | end 65 | end 66 | self.num_state = #self.init_state 67 | end 68 | 69 | 70 | function layer:createClones() 71 | -- construct the net clones 72 | print('constructing clones inside the LanguageModel') 73 | self.clones = {self.core} 74 | self.lookup_tables = {self.lookup_table} 75 | self.attentions = {self.attention} 76 | for t=2,self.seq_length+1 do 77 | self.clones[t] = self.core:clone('weight', 'bias', 'gradWeight', 'gradBias') 78 | self.lookup_tables[t] = self.lookup_table:clone('weight', 'gradWeight') 79 | self.attentions[t] = self.attention:clone('weight', 'bias', 'gradWeight', 'gradBias') 80 | end 81 | end 82 | 83 | 84 | function layer:getModulesList() 85 | return {self.core, self.lookup_table, self.img_embedding, self.attention} 86 | end 87 | 88 | function layer:parameters() 89 | -- we only have two internal modules, return their params 90 | local p1,g1 = self.core:parameters() 91 | local p2,g2 = self.lookup_table:parameters() 92 | local p3,g3 = self.img_embedding:parameters() 93 | local p4,g4 = self.attention:parameters() 94 | 95 | 96 | local params = {} 97 | for k,v in pairs(p1) do table.insert(params, v) end 98 | for k,v in pairs(p2) do table.insert(params, v) end 99 | for k,v in pairs(p3) do table.insert(params, v) end 100 | for k,v in pairs(p4) do table.insert(params, v) end 101 | 102 | local grad_params = {} 103 | for k,v in pairs(g1) do table.insert(grad_params, v) end 104 | for k,v in pairs(g2) do table.insert(grad_params, v) end 105 | for k,v in pairs(g3) do table.insert(grad_params, v) end 106 | for k,v in pairs(g4) do table.insert(grad_params, v) end 107 | 108 | return params, grad_params 109 | end 110 | 111 | function layer:training() 112 | for k,v in pairs(self.clones) do v:training() end 113 | for k,v in pairs(self.lookup_tables) do v:training() end 114 | for k,v in pairs(self.attentions) do v:training() end 115 | self.img_embedding:training() 116 | 117 | end 118 | 119 | function layer:evaluate() 120 | for k,v in pairs(self.clones) do v:evaluate() end 121 | for k,v in pairs(self.lookup_tables) do v:evaluate() end 122 | for k,v in pairs(self.attentions) do v:evaluate() end 123 | self.img_embedding:evaluate() 124 | end 125 | 126 | function layer:sample(inputs, opt) 127 | local conv = inputs[1] 128 | 
local fc = inputs[2] 129 | local ix_to_word = inputs[3] 130 | 131 | 132 | local sample_max = utils.getopt(opt, 'sample_max', 1) 133 | local beam_size = utils.getopt(opt, 'beam_size', 1) 134 | 135 | local temperature = utils.getopt(opt, 'temperature', 1.0) 136 | 137 | local batch_size = fc:size(1) 138 | 139 | if sample_max == 1 and beam_size > 1 then return self:sample_beam(inputs, opt) end -- indirection for beam search 140 | 141 | self:_createInitState(batch_size) 142 | local state = self.init_state 143 | 144 | local img_input = {conv, fc} 145 | local conv_feat, conv_feat_embed, fc_embed = unpack(self.img_embedding:forward(img_input)) 146 | 147 | -- we will write output predictions into tensor seq 148 | local seq = torch.LongTensor(self.seq_length, batch_size):zero() 149 | local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 150 | 151 | local logprobs -- logprobs predicted in last time step 152 | local x_xt 153 | 154 | for t=1,self.seq_length+1 do 155 | local xt, it, sampleLogprobs 156 | if t == 1 then 157 | it = torch.LongTensor(batch_size):fill(self.vocab_size+1) 158 | xt = self.lookup_table:forward(it) 159 | else 160 | -- take predictions from previous time step and feed them in 161 | if sample_max == 1 then 162 | -- use argmax "sampling" 163 | sampleLogprobs, it = torch.max(logprobs, 2) 164 | it = it:view(-1):long() 165 | else 166 | -- sample from the distribution of previous predictions 167 | local prob_prev 168 | if temperature == 1.0 then 169 | prob_prev = torch.exp(logprobs) -- fetch prev distribution: shape Nx(M+1) 170 | else 171 | -- scale logprobs by temperature 172 | prob_prev = torch.exp(torch.div(logprobs, temperature)) 173 | end 174 | it = torch.multinomial(prob_prev, 1) 175 | sampleLogprobs = logprobs:gather(2, it) -- gather the logprobs at sampled positions 176 | it = it:view(-1):long() -- and flatten indices for downstream processing 177 | end 178 | xt = self.lookup_table:forward(it) 179 | end 180 | 181 | if t >= 2 then 182 | seq[t-1] = it -- record the samples 183 | seqLogprobs[t-1] = sampleLogprobs:view(-1):float() -- and also their log likelihoods 184 | end 185 | 186 | local inputs = {xt,fc_embed, unpack(state)} 187 | local out = self.core:forward(inputs) 188 | state = {} 189 | for i=1,self.num_state do table.insert(state, out[i]) end 190 | 191 | local h_out = out[self.num_state+1] 192 | local p_out = out[self.num_state+2] 193 | 194 | local atten_input = {h_out, p_out, conv_feat, conv_feat_embed} 195 | logprobs = self.attention:forward(atten_input) 196 | 197 | end 198 | -- return the samples and their log likelihoods 199 | return seq, seqLogprobs 200 | end 201 | 202 | 203 | function layer:sample_beam(inputs, opt) 204 | local beam_size = utils.getopt(opt, 'beam_size', 10) 205 | 206 | local conv = inputs[1] 207 | local fc = inputs[2] 208 | local ix_to_word = inputs[3] 209 | 210 | local batch_size = fc:size(1) 211 | local function compare(a,b) return a.p > b.p end -- used downstream 212 | 213 | assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed') 214 | 215 | local img_input = {conv, fc} 216 | local conv_feat, conv_feat_embed, fc_embed = unpack(self.img_embedding:forward(img_input)) 217 | 218 | local seq = torch.LongTensor(self.seq_length, batch_size):zero() 219 | local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 220 | local seqLogprobs_sum = torch.FloatTensor(batch_size) 221 | 222 | -- lets process every image independently for now, for simplicity 223 | for k=1,batch_size do 224 | 225 | -- create initial states for all beams 226 | self:_createInitState(beam_size) 227 | local state = self.init_state 228 | 229 | -- we will write output predictions into tensor seq 230 | local beam_seq = torch.LongTensor(self.seq_length, beam_size):zero() 231 | local beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size):zero() 232 | local beam_logprobs_sum = torch.zeros(beam_size) -- running sum of logprobs for each beam 233 | local logprobs -- logprobs predicted in last time step, shape (beam_size, vocab_size+1) 234 | local done_beams = {} 235 | local imgk = fc_embed[{ {k,k} }]:expand(beam_size, self.input_encoding_size) -- k'th image feature expanded out 236 | local conv_feat_k = conv_feat[{ {k,k} }]:expand(beam_size, conv_feat:size(2), self.input_encoding_size) -- k'th image feature expanded out 237 | local conv_feat_embed_k = conv_feat_embed[{ {k,k} }]:expand(beam_size, conv_feat_embed:size(2), self.input_encoding_size) -- k'th image feature expanded out 238 | 239 | for t=1,self.seq_length+1 do 240 | 241 | local xt, it, sampleLogprobs 242 | local new_state 243 | if t == 1 then 244 | -- feed in the start tokens 245 | it = torch.LongTensor(beam_size):fill(self.vocab_size+1) 246 | xt = self.lookup_table:forward(it) 247 | else 248 | --[[ 249 | perform a beam merge. that is, 250 | for every previous beam we now many new possibilities to branch out 251 | we need to resort our beams to maintain the loop invariant of keeping 252 | the top beam_size most likely sequences. 
253 | ]]-- 254 | local logprobsf = logprobs:float() -- lets go to CPU for more efficiency in indexing operations 255 | ys,ix = torch.sort(logprobsf,2,true) -- sorted array of logprobs along each previous beam (last true = descending) 256 | local candidates = {} 257 | local cols = math.min(beam_size,ys:size(2)) 258 | local rows = beam_size 259 | if t == 2 then rows = 1 end -- at first time step only the first beam is active 260 | for c=1,cols do -- for each column (word, essentially) 261 | for q=1,rows do -- for each beam expansion 262 | -- compute logprob of expanding beam q with word in (sorted) position c 263 | local local_logprob = ys[{ q,c }] 264 | local candidate_logprob = beam_logprobs_sum[q] + local_logprob 265 | table.insert(candidates, {c=ix[{ q,c }], q=q, p=candidate_logprob, r=local_logprob }) 266 | end 267 | end 268 | table.sort(candidates, compare) -- find the best c,q pairs 269 | 270 | -- construct new beams 271 | new_state = net_utils.clone_list(state) 272 | local beam_seq_prev, beam_seq_logprobs_prev 273 | if t > 2 then 274 | -- well need these as reference when we fork beams around 275 | beam_seq_prev = beam_seq[{ {1,t-2}, {} }]:clone() 276 | beam_seq_logprobs_prev = beam_seq_logprobs[{ {1,t-2}, {} }]:clone() 277 | end 278 | 279 | for vix=1,beam_size do 280 | local v = candidates[vix] 281 | -- fork beam index q into index vix 282 | if t > 2 then 283 | beam_seq[{ {1,t-2}, vix }] = beam_seq_prev[{ {}, v.q }] 284 | beam_seq_logprobs[{ {1,t-2}, vix }] = beam_seq_logprobs_prev[{ {}, v.q }] 285 | end 286 | -- rearrange recurrent states 287 | for state_ix = 1,#new_state do 288 | -- copy over state in previous beam q to new beam at vix 289 | new_state[state_ix][vix] = state[state_ix][v.q] 290 | end 291 | -- append new end terminal at the end of this beam 292 | beam_seq[{ t-1, vix }] = v.c -- c'th word is the continuation 293 | beam_seq_logprobs[{ t-1, vix }] = v.r -- the raw logprob here 294 | beam_logprobs_sum[vix] = v.p -- the new (sum) logprob along this beam 295 | 296 | if v.c == self.vocab_size+1 or t == self.seq_length+1 then 297 | -- END token special case here, or we reached the end. 
298 | -- add the beam to a set of done beams 299 | table.insert(done_beams, {seq = beam_seq[{ {}, vix }]:clone(), 300 | logps = beam_seq_logprobs[{ {}, vix }]:clone(), 301 | p = beam_logprobs_sum[vix] 302 | }) 303 | end 304 | end 305 | 306 | -- encode as vectors 307 | it = beam_seq[t-1] 308 | xt = self.lookup_table:forward(it) 309 | end 310 | 311 | if new_state then state = new_state end -- swap rnn state, if we reassinged beams 312 | 313 | local inputs = {xt,imgk,unpack(state)} 314 | local out = self.core:forward(inputs) 315 | state = {} 316 | for i=1,self.num_state do table.insert(state, out[i]) end 317 | local h_out = out[self.num_state+1] 318 | local p_out = out[self.num_state+2] 319 | local atten_input = {h_out, p_out, conv_feat_k, conv_feat_embed_k} 320 | logprobs = self.attention:forward(atten_input) 321 | 322 | end 323 | 324 | table.sort(done_beams, compare) 325 | seq[{ {}, k }] = done_beams[1].seq -- the first beam has highest cumulative score 326 | seqLogprobs[{ {}, k }] = done_beams[1].logps 327 | seqLogprobs_sum[k]=done_beams[1].p 328 | end 329 | 330 | -- return the samples and their log likelihoods 331 | return seq, seqLogprobs_sum 332 | end 333 | 334 | function layer:updateOutput(input) 335 | local conv = input[1] 336 | local fc = input[2] 337 | local seq = input[3] 338 | 339 | assert(seq:size(1) == self.seq_length) 340 | local batch_size = seq:size(2) 341 | 342 | self:_createInitState(batch_size) 343 | 344 | -- first get the nearest neighbor representation. 345 | self.output:resize(self.seq_length+1, batch_size, self.vocab_size+1):zero() 346 | 347 | self.img_input = {conv, fc} 348 | self.conv_feat, self.conv_feat_embed, self.fc_embed = unpack(self.img_embedding:forward(self.img_input)) 349 | 350 | self.state = {[0] = self.init_state} 351 | self.inputs = {} 352 | self.atten_inputs = {} 353 | --self.x_inputs = {} 354 | self.lookup_tables_inputs = {} 355 | self.tmax = 0 -- we will keep track of max sequence length encountered in the data for efficiency 356 | 357 | for t = 1,self.seq_length+1 do 358 | local can_skip = false 359 | local xt 360 | if t == 1 then 361 | -- feed in the start tokens 362 | local it = torch.LongTensor(batch_size):fill(self.vocab_size+1) 363 | self.lookup_tables_inputs[t] = it 364 | xt = self.lookup_table:forward(it) -- NxK sized input (token embedding vectors) 365 | else 366 | -- feed in the rest of the sequence... 
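      -- seq is a (seq_length x batch) tensor of ground-truth token indices, zero-padded
      -- past the end of each caption; once a whole column is zero, every caption in the
      -- batch has ended, so the remaining timesteps are skipped via can_skip and
      -- self.tmax keeps the last timestep that was actually forwarded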
367 | local it = seq[t-1]:clone() 368 | if torch.sum(it) == 0 then 369 | can_skip = true 370 | end 371 | 372 | if not can_skip then 373 | self.lookup_tables_inputs[t] = it 374 | xt = self.lookup_tables[t]:forward(it) 375 | end 376 | end 377 | 378 | if not can_skip then 379 | -- construct the inputs 380 | self.inputs[t] = {xt, self.fc_embed, unpack(self.state[t-1])} 381 | -- forward the network 382 | local out = self.clones[t]:forward(self.inputs[t]) 383 | -- insert the hidden state 384 | self.state[t] = {} -- the rest is state 385 | for i=1,self.num_state do table.insert(self.state[t], out[i]) end 386 | local h_out = out[self.num_state+1] 387 | local p_out = out[self.num_state+2] 388 | 389 | --forward the attention 390 | self.atten_inputs[t] = {h_out, p_out, self.conv_feat, self.conv_feat_embed} 391 | local atten_out = self.attentions[t]:forward(self.atten_inputs[t]) 392 | 393 | self.output:narrow(1,t,1):copy(atten_out) 394 | self.tmax = t 395 | end 396 | end 397 | 398 | return self.output 399 | end 400 | 401 | function layer:updateGradInput(input, gradOutput) 402 | local dconv, dconv_embed, dfc-- grad on input images 403 | 404 | local batch_size = self.output:size(2) 405 | -- go backwards and lets compute gradients 406 | local dstate = self.init_state -- this works when init_state is all zeros 407 | 408 | for t=self.tmax,1,-1 do 409 | 410 | local d_atten = self.attentions[t]:backward(self.atten_inputs[t], gradOutput[t]) 411 | if not dconv then dconv = d_atten[3] else dconv:add(d_atten[3]) end 412 | if not dconv_embed then dconv_embed = d_atten[4] else dconv_embed:add(d_atten[4]) end 413 | 414 | local dout = {} 415 | for k=1, self.num_state do table.insert(dout, dstate[k]) end 416 | table.insert(dout, d_atten[1]) 417 | table.insert(dout, d_atten[2]) 418 | 419 | local dinputs = self.clones[t]:backward(self.inputs[t], dout) 420 | 421 | local dxt = dinputs[1] -- first element is the input vector 422 | if not dfc then dfc = dinputs[2] else dfc:add(dinputs[2]) end 423 | 424 | dstate = {} -- copy over rest to state grad 425 | for k=3,self.num_state+2 do table.insert(dstate, dinputs[k]) end 426 | 427 | -- continue backprop of xt 428 | local it = self.lookup_tables_inputs[t] 429 | self.lookup_tables[t]:backward(it, dxt) -- backprop into lookup table 430 | end 431 | 432 | -- backprob to the visual features. 
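  -- dconv/dconv_embed hold the attention gradients accumulated over all timesteps and
  -- dfc the LSTM's gradient w.r.t. the expanded fc feature; the backward pass below
  -- maps them to gradients on the raw CNN conv/fc features, which can then be pushed
  -- into the CNN when fine-tuning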
433 | local dimgs_cnn, dfc_cnn = unpack(self.img_embedding:backward(self.img_input, {dconv, dconv_embed, dfc})) 434 | 435 | self.gradInput = {dimgs_cnn, dfc_cnn} 436 | return self.gradInput 437 | end 438 | 439 | ------------------------------------------------------------------------------- 440 | -- Language Model-aware Criterion 441 | ------------------------------------------------------------------------------- 442 | 443 | local crit, parent = torch.class('nn.LanguageModelCriterion', 'nn.Criterion') 444 | function crit:__init() 445 | parent.__init(self) 446 | end 447 | 448 | function crit:updateOutput(inputs) 449 | local input = inputs[1] 450 | local seq = inputs[2] 451 | --local seq_len = inputs[3] 452 | 453 | local L,N,Mp1 = input:size(1), input:size(2), input:size(3) 454 | local D = seq:size(1) 455 | assert(D == L-1, 'input Tensor should be 1 larger in time') 456 | 457 | self.gradInput:resizeAs(input):zero() 458 | local loss = 0 459 | local n = 0 460 | for b=1,N do -- iterate over batches 461 | local first_time = true 462 | for t=1,L do -- iterate over sequence time (ignore t=1, dummy forward for the image) 463 | -- fetch the index of the next token in the sequence 464 | local target_index 465 | if t > D then -- we are out of bounds of the index sequence: pad with null tokens 466 | target_index = 0 467 | else 468 | target_index = seq[{t,b}] 469 | end 470 | -- the first time we see null token as next index, actually want the model to predict the END token 471 | if target_index == 0 and first_time then 472 | target_index = Mp1 473 | first_time = false 474 | end 475 | 476 | -- if there is a non-null next token, enforce loss! 477 | if target_index ~= 0 then 478 | -- accumulate loss 479 | loss = loss - input[{ t,b,target_index }] -- log(p) 480 | self.gradInput[{ t,b,target_index }] = -1 481 | n = n + 1 482 | end 483 | end 484 | end 485 | self.output = loss / n -- normalize by number of predictions that were made 486 | self.gradInput:div(n) 487 | 488 | return self.output 489 | end 490 | 491 | function crit:updateGradInput(inputs) 492 | return self.gradInput 493 | end 494 | -------------------------------------------------------------------------------- /misc/LookupTableMaskZero.lua: -------------------------------------------------------------------------------- 1 | local LookupTableMaskZero, parent = torch.class('nn.LookupTableMaskZero', 'nn.LookupTable') 2 | 3 | function LookupTableMaskZero:__init(nIndex, nOutput) 4 | parent.__init(self, nIndex + 1, nOutput) 5 | end 6 | 7 | function LookupTableMaskZero:updateOutput(input) 8 | self.weight[1]:zero() 9 | if self.__input and (torch.type(self.__input) ~= torch.type(input)) then 10 | self.__input = nil -- fixes old casting bug 11 | end 12 | self.__input = self.__input or input.new() 13 | self.__input:resizeAs(input):add(input, 1) 14 | return parent.updateOutput(self, self.__input) 15 | end 16 | 17 | function LookupTableMaskZero:accGradParameters(input, gradOutput, scale) 18 | parent.accGradParameters(self, self.__input, gradOutput, scale) 19 | end 20 | 21 | function LookupTableMaskZero:type(type, cache) 22 | self.__input = nil 23 | return parent.type(self, type, cache) 24 | end -------------------------------------------------------------------------------- /misc/attention.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | local attention = {} 5 | function attention.attention(input_size, rnn_size, output_size, dropout) 6 | local inputs = {} 7 | local outputs = {} 
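-- This graph implements the adaptive attention block: "fake_region" plays the
-- role of the visual sentinel s_t, conv_feat holds the spatial CNN features
-- and conv_feat_embed their learned embedding; the hard-coded 50 further down
-- is the number of candidate regions (presumably 49 spatial locations plus the
-- sentinel slot). A minimal construction sketch, with illustrative sizes:
--   local attention = require 'misc.attention'
--   local atten = attention.attention(512, 512, vocab_size + 1, 0.5) -- input_size, rnn_size, output_size, dropout
--   -- forward expects {h_t, sentinel, conv_feat, conv_feat_embed}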
8 | table.insert(inputs, nn.Identity()()) -- top_h 9 | table.insert(inputs, nn.Identity()()) -- fake_region 10 | table.insert(inputs, nn.Identity()()) -- conv_feat 11 | table.insert(inputs, nn.Identity()()) -- conv_feat_embed 12 | 13 | local h_out = inputs[1] 14 | local fake_region = inputs[2] 15 | local conv_feat = inputs[3] 16 | local conv_feat_embed = inputs[4] 17 | 18 | local fake_region = nn.ReLU()(nn.Linear(rnn_size, input_size)(fake_region)) 19 | -- view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num 20 | if dropout > 0 then fake_region = nn.Dropout(dropout)(fake_region) end 21 | 22 | local fake_region_embed = nn.Linear(input_size, input_size)(fake_region) 23 | 24 | local h_out_linear = nn.Tanh()(nn.Linear(rnn_size, input_size)(h_out)) 25 | if dropout > 0 then h_out_linear = nn.Dropout(dropout)(h_out_linear) end 26 | 27 | local h_out_embed = nn.Linear(input_size, input_size)(h_out_linear) 28 | 29 | local txt_replicate = nn.Replicate(50,2)(h_out_embed) 30 | 31 | local img_all = nn.JoinTable(2)({nn.View(-1,1,input_size)(fake_region), conv_feat}) 32 | local img_all_embed = nn.JoinTable(2)({nn.View(-1,1,input_size)(fake_region_embed), conv_feat_embed}) 33 | 34 | local hA = nn.Tanh()(nn.CAddTable()({img_all_embed, txt_replicate})) 35 | if dropout > 0 then hA = nn.Dropout(dropout)(hA) end 36 | local hAflat = nn.Linear(input_size,1)(nn.View(input_size):setNumInputDims(2)(hA)) 37 | local PI = nn.SoftMax()(nn.View(50):setNumInputDims(2)(hAflat)) 38 | 39 | local probs3dim = nn.View(1,-1):setNumInputDims(1)(PI) 40 | local visAtt = nn.MM(false, false)({probs3dim, img_all}) 41 | local visAttdim = nn.View(input_size):setNumInputDims(2)(visAtt) 42 | local atten_out = nn.CAddTable()({visAttdim, h_out_linear}) 43 | 44 | local h = nn.Tanh()(nn.Linear(input_size, input_size)(atten_out)) 45 | if dropout > 0 then h = nn.Dropout(dropout)(h) end 46 | local proj = nn.Linear(input_size, output_size)(h) 47 | 48 | local logsoft = nn.LogSoftMax()(proj) 49 | --local logsoft = nn.SoftMax()(proj) 50 | table.insert(outputs, logsoft) 51 | 52 | return nn.gModule(inputs, outputs) 53 | end 54 | return attention -------------------------------------------------------------------------------- /misc/call_python_caption_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd coco-caption 4 | python myeval.py $1 $2 5 | cd ../ 6 | -------------------------------------------------------------------------------- /misc/img_embedding.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | local img_embedding = {} 5 | 6 | function img_embedding.img_embedding(hidden_size, fc_size, conv_size, conv_num, dropout) 7 | local inputs = {} 8 | local outputs = {} 9 | 10 | table.insert(inputs, nn.Identity()()) -- image feature 11 | table.insert(inputs, nn.Identity()()) 12 | 13 | local conv_feat = inputs[1] 14 | local fc_feat = inputs[2] 15 | 16 | -- embed the fc7 feature -- dropout here? 
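-- (with the ResNet-152 front end this "fc" input is the 2048-d pooled feature
-- rather than a VGG fc7 vector; its embedding below becomes the global image
-- feature that the language model receives at every time step)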
17 | local fc_feat_out = nn.ReLU()(nn.Linear(fc_size, hidden_size)(fc_feat)) 18 | if dropout > 0 then fc_feat_out = nn.Dropout(dropout)(fc_feat_out) end 19 | 20 | -- embed the conv feature 21 | local conv_feat_embed = nn.Linear(conv_size, hidden_size)(nn.View(conv_size):setNumInputDims(2)(conv_feat)) 22 | 23 | local conv_feat_out = nn.ReLU()(conv_feat_embed) 24 | if dropout > 0 then conv_feat_out = nn.Dropout(dropout)(conv_feat_out) end 25 | 26 | local conv_feat_back = nn.View(-1, conv_num, hidden_size)(conv_feat_out) 27 | 28 | 29 | local img_feat_dim = nn.View(hidden_size):setNumInputDims(2)(conv_feat_back) 30 | local embed_feat = nn.Linear(hidden_size, hidden_size)(img_feat_dim) 31 | local embed_feat_out = nn.View(-1, conv_num, hidden_size)(embed_feat) 32 | 33 | 34 | table.insert(outputs, conv_feat_back) 35 | table.insert(outputs, embed_feat_out) 36 | table.insert(outputs, fc_feat_out) 37 | 38 | return nn.gModule(inputs, outputs) 39 | 40 | end 41 | 42 | return img_embedding 43 | 44 | -------------------------------------------------------------------------------- /misc/net_utils.lua: -------------------------------------------------------------------------------- 1 | local utils = require 'misc.utils' 2 | local net_utils = {} 3 | 4 | 5 | 6 | -- take a raw CNN from Caffe and perform surgery. Note: VGG-16 SPECIFIC! 7 | function net_utils.build_cnn_conv_fix(cnn, opt) 8 | local layer_num = utils.getopt(opt, 'layer_num', 10) 9 | local backend = utils.getopt(opt, 'backend', 'cudnn') 10 | local encoding_size = utils.getopt(opt, 'encoding_size', 512) 11 | 12 | if backend == 'cudnn' then 13 | require 'cudnn' 14 | backend = cudnn 15 | elseif backend == 'nn' then 16 | require 'nn' 17 | backend = nn 18 | else 19 | error(string.format('Unrecognized backend "%s"', backend)) 20 | end 21 | 22 | -- copy over the first layer_num layers of the CNN 23 | local cnn_part = nn.Sequential() 24 | for i = 1, layer_num do 25 | local layer = cnn:get(i) 26 | if i == 1 then 27 | -- convert kernels in first conv layer into RGB format instead of BGR, 28 | -- which is the order in which it was trained in Caffe 29 | local w = layer.weight:clone() 30 | -- swap weights to R and B channels 31 | print('converting first layer conv filters from BGR to RGB...') 32 | layer.weight[{ {}, 1, {}, {} }]:copy(w[{ {}, 3, {}, {} }]) 33 | layer.weight[{ {}, 3, {}, {} }]:copy(w[{ {}, 1, {}, {} }]) 34 | end 35 | 36 | cnn_part:add(layer) 37 | end 38 | 39 | return cnn_part 40 | end 41 | 42 | 43 | 44 | -- take a raw CNN from Caffe and perform surgery. Note: VGG-16 SPECIFIC! 
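-- build_cnn_conv_fix above keeps the early layers frozen, while build_cnn_conv
-- below returns the part that gets finetuned. A rough usage sketch for the
-- VGG-16 path (the path is hypothetical; the layer indices follow the defaults):
--   local cnn_raw  = torch.load('image_model/vgg16.t7')
--   local cnn_fix  = net_utils.build_cnn_conv_fix(cnn_raw, {backend = 'cudnn', layer_num = 10})
--   local cnn_tune = net_utils.build_cnn_conv(cnn_raw, {backend = 'cudnn', layer_num_start = 11, layer_num = 37})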
45 | function net_utils.build_cnn_conv(cnn, opt) 46 | local layer_num_start = utils.getopt(opt, 'layer_num_start', 11) 47 | local layer_num = utils.getopt(opt, 'layer_num', 37) 48 | local backend = utils.getopt(opt, 'backend', 'cudnn') 49 | local encoding_size = utils.getopt(opt, 'encoding_size', 512) 50 | 51 | if backend == 'cudnn' then 52 | require 'cudnn' 53 | backend = cudnn 54 | elseif backend == 'nn' then 55 | require 'nn' 56 | backend = nn 57 | else 58 | error(string.format('Unrecognized backend "%s"', backend)) 59 | end 60 | 61 | -- copy over the first layer_num layers of the CNN 62 | local cnn_part = nn.Sequential() 63 | for i = layer_num_start, layer_num do 64 | local layer = cnn:get(i) 65 | cnn_part:add(layer) 66 | end 67 | 68 | return cnn_part 69 | end 70 | 71 | function net_utils.build_residual_cnn_conv_fix(cnn, opt) 72 | local layer_num = utils.getopt(opt, 'start_layer_num', 6) 73 | local backend = utils.getopt(opt, 'backend', 'cudnn') 74 | local encoding_size = utils.getopt(opt, 'encoding_size', 512) 75 | 76 | if backend == 'cudnn' then 77 | require 'cudnn' 78 | backend = cudnn 79 | elseif backend == 'nn' then 80 | require 'nn' 81 | backend = nn 82 | else 83 | error(string.format('Unrecognized backend "%s"', backend)) 84 | end 85 | 86 | -- copy over the first layer_num layers of the CNN 87 | local cnn_part = nn.Sequential() 88 | for i = 1, layer_num-1 do 89 | local layer = cnn:get(i) 90 | cnn_part:add(layer) 91 | end 92 | --cnn_part:add(nn.View(512, -1):setNumInputDims(3)) 93 | --cnn_part:add(nn.Transpose({2,3})) 94 | return cnn_part 95 | end 96 | 97 | 98 | -- take a raw CNN from Caffe and perform surgery. Note: VGG-16 SPECIFIC! 99 | function net_utils.build_residual_cnn_conv(cnn, opt) 100 | local start_layer_num = utils.getopt(opt, 'start_layer_num', 6) 101 | local layer_num = utils.getopt(opt, 'layer_num', 8) 102 | local backend = utils.getopt(opt, 'backend', 'cudnn') 103 | 104 | if backend == 'cudnn' then 105 | require 'cudnn' 106 | backend = cudnn 107 | elseif backend == 'nn' then 108 | require 'nn' 109 | backend = nn 110 | else 111 | error(string.format('Unrecognized backend "%s"', backend)) 112 | end 113 | 114 | -- copy over the first layer_num layers of the CNN 115 | local cnn_part = nn.Sequential() 116 | for i = start_layer_num, layer_num do 117 | local layer = cnn:get(i) 118 | cnn_part:add(layer) 119 | end 120 | 121 | --cnn_part:add(nn.View(512, -1):setNumInputDims(3)) 122 | --cnn_part:add(nn.Transpose({2,3})) 123 | return cnn_part 124 | end 125 | 126 | 127 | function net_utils.transform_cnn_conv(nDim) 128 | 129 | local cnn_part = nn.Sequential() 130 | 131 | cnn_part:add(nn.View(nDim, -1):setNumInputDims(3)) 132 | cnn_part:add(nn.Transpose({2,3})) 133 | return cnn_part 134 | end 135 | 136 | 137 | -- take a raw CNN from Caffe and perform surgery. Note: VGG-16 SPECIFIC! 
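-- At training time the residual builders and transform_cnn_conv above are
-- chained roughly as follows (a sketch; names and sizes are illustrative,
-- assuming ResNet-152 with 224x224 crops):
--   local maps  = cnn_conv:forward(cnn_conv_fix:forward(images))   -- B x 2048 x 7 x 7
--   local spots = net_utils.transform_cnn_conv(2048):forward(maps) -- B x 49 x 2048 spatial features
--   local fc    = cnn_fc:forward(maps)                             -- B x 2048 global feature
-- build_cnn_fc below is the VGG-16 counterpart that extracts the fully-connected head.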
138 | function net_utils.build_cnn_fc(cnn, opt) 139 | local layer_num_start = utils.getopt(opt, 'layer_num', 38) 140 | local layer_num_end = utils.getopt(opt, 'layer_num', 43) 141 | local backend = utils.getopt(opt, 'backend', 'cudnn') 142 | 143 | if backend == 'cudnn' then 144 | require 'cudnn' 145 | backend = cudnn 146 | elseif backend == 'nn' then 147 | require 'nn' 148 | backend = nn 149 | else 150 | error(string.format('Unrecognized backend "%s"', backend)) 151 | end 152 | 153 | -- copy over the first layer_num layers of the CNN 154 | local cnn_part = nn.Sequential() 155 | for i = layer_num_start, layer_num_end do 156 | local layer = cnn:get(i) 157 | 158 | cnn_part:add(layer) 159 | end 160 | 161 | return cnn_part 162 | end 163 | 164 | function net_utils.build_residual_cnn_fc(cnn, opt) 165 | local layer_num_start = utils.getopt(opt, 'layer_num', 9) 166 | local layer_num_end = utils.getopt(opt, 'layer_num', 10) 167 | local backend = utils.getopt(opt, 'backend', 'cudnn') 168 | 169 | if backend == 'cudnn' then 170 | require 'cudnn' 171 | backend = cudnn 172 | elseif backend == 'nn' then 173 | require 'nn' 174 | backend = nn 175 | else 176 | error(string.format('Unrecognized backend "%s"', backend)) 177 | end 178 | 179 | -- copy over the first layer_num layers of the CNN 180 | local cnn_part = nn.Sequential() 181 | for i = layer_num_start, layer_num_end do 182 | local layer = cnn:get(i) 183 | 184 | cnn_part:add(layer) 185 | end 186 | 187 | return cnn_part 188 | end 189 | 190 | -- layer that expands features out so we can forward multiple sentences per image 191 | local layer, parent = torch.class('nn.FeatExpander', 'nn.Module') 192 | function layer:__init(n) 193 | parent.__init(self) 194 | self.n = n 195 | end 196 | function layer:updateOutput(input) 197 | if self.n == 1 then self.output = input; return self.output end -- act as a noop for efficiency 198 | -- simply expands out the features. Performs a copy information 199 | assert(input:nDimension() == 2) 200 | local d = input:size(2) 201 | self.output:resize(input:size(1)*self.n, d) 202 | for k=1,input:size(1) do 203 | local j = (k-1)*self.n+1 204 | self.output[{ {j,j+self.n-1} }] = input[{ {k,k}, {} }]:expand(self.n, d) -- copy over 205 | end 206 | return self.output 207 | end 208 | function layer:updateGradInput(input, gradOutput) 209 | if self.n == 1 then self.gradInput = gradOutput; return self.gradInput end -- act as noop for efficiency 210 | -- add up the gradients for each block of expanded features 211 | self.gradInput:resizeAs(input) 212 | local d = input:size(2) 213 | for k=1,input:size(1) do 214 | local j = (k-1)*self.n+1 215 | self.gradInput[k] = torch.sum(gradOutput[{ {j,j+self.n-1} }], 1) 216 | end 217 | return self.gradInput 218 | end 219 | 220 | local layer, parent = torch.class('nn.FeatExpanderConv', 'nn.Module') 221 | function layer:__init(n) 222 | parent.__init(self) 223 | self.n = n 224 | end 225 | function layer:updateOutput(input) 226 | if self.n == 1 then self.output = input; return self.output end -- act as a noop for efficiency 227 | -- simply expands out the features. 
Performs a copy information 228 | local s = input:size(2) 229 | local d = input:size(3) 230 | self.output:resize(input:size(1)*self.n, s, d) 231 | for k=1,input:size(1) do 232 | local j = (k-1)*self.n+1 233 | self.output[{ {j,j+self.n-1} }] = input[{ {k,k}, {} }]:expand(self.n, s, d) -- copy over 234 | end 235 | return self.output 236 | end 237 | function layer:updateGradInput(input, gradOutput) 238 | if self.n == 1 then self.gradInput = gradOutput; return self.gradInput end -- act as noop for efficiency 239 | -- add up the gradients for each block of expanded features 240 | self.gradInput:resizeAs(input) 241 | local d = input:size(2) 242 | for k=1,input:size(1) do 243 | local j = (k-1)*self.n+1 244 | self.gradInput[k] = torch.sum(gradOutput[{ {j,j+self.n-1} }], 1) 245 | end 246 | return self.gradInput 247 | end 248 | 249 | 250 | function net_utils.list_nngraph_modules(g) 251 | local omg = {} 252 | for i,node in ipairs(g.forwardnodes) do 253 | local m = node.data.module 254 | if m then 255 | table.insert(omg, m) 256 | end 257 | end 258 | return omg 259 | end 260 | function net_utils.listModules(net) 261 | -- torch, our relationship is a complicated love/hate thing. And right here it's the latter 262 | local t = torch.type(net) 263 | local moduleList 264 | if t == 'nn.gModule' then 265 | moduleList = net_utils.list_nngraph_modules(net) 266 | else 267 | moduleList = net:listModules() 268 | end 269 | return moduleList 270 | end 271 | function net_utils.sanitize_gradients(net) 272 | local moduleList = net_utils.listModules(net) 273 | for k,m in ipairs(moduleList) do 274 | if m.weight and m.gradWeight then 275 | --print('sanitizing gradWeight in of size ' .. m.gradWeight:nElement()) 276 | --print(m.weight:size()) 277 | m.gradWeight = nil 278 | end 279 | if m.bias and m.gradBias then 280 | --print('sanitizing gradWeight in of size ' .. m.gradBias:nElement()) 281 | --print(m.bias:size()) 282 | m.gradBias = nil 283 | end 284 | end 285 | end 286 | 287 | function net_utils.unsanitize_gradients(net) 288 | local moduleList = net_utils.listModules(net) 289 | for k,m in ipairs(moduleList) do 290 | if m.weight and (not m.gradWeight) then 291 | m.gradWeight = m.weight:clone():zero() 292 | --print('unsanitized gradWeight in of size ' .. m.gradWeight:nElement()) 293 | --print(m.weight:size()) 294 | end 295 | if m.bias and (not m.gradBias) then 296 | m.gradBias = m.bias:clone():zero() 297 | --print('unsanitized gradWeight in of size ' .. m.gradBias:nElement()) 298 | --print(m.bias:size()) 299 | end 300 | end 301 | end 302 | 303 | --[[ 304 | take a LongTensor of size DxN with elements 1..vocab_size+1 305 | (where last dimension is END token), and decode it into table of raw text sentences. 306 | each column is a sequence. ix_to_word gives the mapping to strings, as a table 307 | --]] 308 | function net_utils.decode_sequence(ix_to_word, seq) 309 | local D,N = seq:size(1), seq:size(2) 310 | local out = {} 311 | local count = {} 312 | for i=1,N do 313 | local tmp = 0 314 | local txt = '' 315 | for j=1,D do 316 | local ix = seq[{j,i}] 317 | local word = ix_to_word[tostring(ix)] 318 | if not word then break end -- END token, likely. Or null token 319 | if j >= 2 then txt = txt .. ' ' end 320 | tmp = tmp + 1 321 | txt = txt .. word 322 | end 323 | --txt = txt .. '.' 
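-- tmp is the number of words actually decoded for this column, so the second
-- return value reports per-caption lengths alongside the sentences in out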
324 | table.insert(count, tmp) 325 | 326 | table.insert(out, txt) 327 | end 328 | return out, count 329 | end 330 | 331 | function net_utils.clone_list(lst) 332 | -- takes list of tensors, clone all 333 | local new = {} 334 | for k,v in pairs(lst) do 335 | new[k] = v:clone() 336 | end 337 | return new 338 | end 339 | 340 | function net_utils.clone_list_all(lst) 341 | -- takes list of tensors, clone all 342 | local new = {} 343 | for k,v in pairs(lst) do 344 | local new_sub = {} 345 | for m,n in pairs(v) do 346 | new_sub[m] = n:clone() 347 | end 348 | new[k] = new_sub 349 | end 350 | return new 351 | end 352 | 353 | -- hiding this piece of code on the bottom of the file, in hopes that 354 | -- noone will ever find it. Lets just pretend it doesn't exist 355 | function net_utils.language_eval(predictions, opt) 356 | -- this is gross, but we have to call coco python code. 357 | -- Not my favorite kind of thing, but here we go 358 | local id = utils.getopt(opt, 'id', 1) 359 | local dataset = utils.getopt(opt, 'dataset','coco') 360 | 361 | local out_struct = {val_predictions = predictions} 362 | utils.write_json('coco-caption/val' .. id .. '.json', out_struct) -- serialize to json (ew, so gross) 363 | print('./misc/call_python_caption_eval.sh val' .. id .. '.json annotations/' ..dataset..'.json') 364 | os.execute('./misc/call_python_caption_eval.sh val' .. id .. '.json annotations/' ..dataset..'.json') -- i'm dying over here 365 | local result_struct = utils.read_json('coco-caption/val' .. id .. '.json_out.json') -- god forgive me 366 | return result_struct 367 | end 368 | 369 | function net_utils.init_noise(graph, batch_size) 370 | if batch_size == nil then 371 | error('please provide valid batch_size value') 372 | end 373 | for i, node in pairs(graph:listModules()) do 374 | local layer = graph:get(i) 375 | local t = torch.type(layer) 376 | if t == 'nn.DropoutFix' then 377 | layer:init_noise(batch_size) 378 | end 379 | end 380 | end 381 | 382 | function net_utils.deepCopy(tbl) 383 | -- creates a copy of a network with new modules and the same tensors 384 | local copy = {} 385 | for k, v in pairs(tbl) do 386 | if type(v) == 'table' then 387 | copy[k] = net_utils.deepCopy(v) 388 | else 389 | copy[k] = v 390 | end 391 | end 392 | if torch.typename(tbl) then 393 | torch.setmetatable(copy, torch.typename(tbl)) 394 | end 395 | return copy 396 | end 397 | 398 | function net_utils.setBNGradient0(graph) 399 | -- setting the gradient of BN to be zero 400 | local BNlayers = graph:findModules('nn.SpatialBatchNormalization') 401 | for i, node in pairs(BNlayers) do 402 | node.gradWeight:zero() 403 | node.gradBias:zero() 404 | end 405 | end 406 | 407 | 408 | return net_utils -------------------------------------------------------------------------------- /misc/optim_updates.lua: -------------------------------------------------------------------------------- 1 | 2 | -- optim, simple as it should be, written from scratch. 
That's how I roll 3 | 4 | function sgd(x, dx, lr) 5 | x:add(-lr, dx) 6 | end 7 | 8 | function sgdm(x, dx, lr, alpha, state) 9 | -- sgd with momentum, standard update 10 | if not state.v then 11 | state.v = x.new(#x):zero() 12 | end 13 | state.v:mul(alpha) 14 | state.v:add(lr, dx) 15 | x:add(-1, state.v) 16 | end 17 | 18 | function sgdmom(x, dx, lr, alpha, state) 19 | -- sgd momentum, uses nesterov update (reference: http://cs231n.github.io/neural-networks-3/#sgd) 20 | if not state.m then 21 | state.m = x.new(#x):zero() 22 | state.tmp = x.new(#x) 23 | end 24 | state.tmp:copy(state.m) 25 | state.m:mul(alpha):add(-lr, dx) 26 | x:add(-alpha, state.tmp) 27 | x:add(1+alpha, state.m) 28 | end 29 | 30 | function adagrad(x, dx, lr, epsilon, state) 31 | if not state.m then 32 | state.m = x.new(#x):zero() 33 | state.tmp = x.new(#x) 34 | end 35 | -- calculate new mean squared values 36 | state.m:addcmul(1.0, dx, dx) 37 | -- perform update 38 | state.tmp:sqrt(state.m):add(epsilon) 39 | x:addcdiv(-lr, dx, state.tmp) 40 | end 41 | 42 | -- rmsprop implementation, simple as it should be 43 | function rmsprop(x, dx, lr, alpha, epsilon, state) 44 | if not state.m then 45 | state.m = x.new(#x):zero() 46 | state.tmp = x.new(#x) 47 | end 48 | -- calculate new (leaky) mean squared values 49 | state.m:mul(alpha) 50 | state.m:addcmul(1.0-alpha, dx, dx) 51 | -- perform update 52 | state.tmp:sqrt(state.m):add(epsilon) 53 | x:addcdiv(-lr, dx, state.tmp) 54 | end 55 | 56 | function adam(x, dx, lr, beta1, beta2, epsilon, state) 57 | local beta1 = beta1 or 0.9 58 | local beta2 = beta2 or 0.999 59 | local epsilon = epsilon or 1e-8 60 | 61 | if not state.m then 62 | -- Initialization 63 | state.t = 0 64 | -- Exponential moving average of gradient values 65 | state.m = x.new(#dx):zero() 66 | -- Exponential moving average of squared gradient values 67 | state.v = x.new(#dx):zero() 68 | -- A tmp tensor to hold the sqrt(v) + epsilon 69 | state.tmp = x.new(#dx):zero() 70 | end 71 | 72 | -- Decay the first and second moment running average coefficient 73 | state.m:mul(beta1):add(1-beta1, dx) 74 | state.v:mul(beta2):addcmul(1-beta2, dx, dx) 75 | state.tmp:copy(state.v):sqrt():add(epsilon) 76 | 77 | state.t = state.t + 1 78 | local biasCorrection1 = 1 - beta1^state.t 79 | local biasCorrection2 = 1 - beta2^state.t 80 | local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 81 | 82 | -- perform update 83 | x:addcdiv(-stepSize, state.m, state.tmp) 84 | end 85 | -------------------------------------------------------------------------------- /misc/transforms.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | 3 | local M = {} 4 | 5 | function M.Compose(transforms) 6 | return function(input) 7 | for _, transform in ipairs(transforms) do 8 | input = transform(input) 9 | end 10 | return input 11 | end 12 | end 13 | 14 | function M.ColorNormalize(meanstd) 15 | return function(img) 16 | img = img:clone() 17 | for i=1,3 do 18 | img[{{},{i},{},{}}]:add(-meanstd.mean[i]) 19 | img[{{},{i},{},{}}]:div(meanstd.std[i]) 20 | end 21 | return img 22 | end 23 | end 24 | 25 | return M 26 | -------------------------------------------------------------------------------- /misc/utils.lua: -------------------------------------------------------------------------------- 1 | local cjson = require 'cjson' 2 | local utils = {} 3 | 4 | -- Assume required if default_value is nil 5 | function utils.getopt(opt, key, default_value) 6 | if default_value == nil and (opt == nil or opt[key] == nil) 
then 7 | error('error: required key ' .. key .. ' was not provided in an opt.') 8 | end 9 | if opt == nil then return default_value end 10 | local v = opt[key] 11 | if v == nil then v = default_value end 12 | return v 13 | end 14 | 15 | function utils.read_json(path) 16 | local file = io.open(path, 'r') 17 | local text = file:read() 18 | file:close() 19 | local info = cjson.decode(text) 20 | return info 21 | end 22 | 23 | function utils.write_json(path, j) 24 | -- API reference http://www.kyne.com.au/~mark/software/lua-cjson-manual.html#encode 25 | cjson.encode_sparse_array(true, 2, 10) 26 | local text = cjson.encode(j) 27 | local file = io.open(path, 'w') 28 | file:write(text) 29 | file:close() 30 | end 31 | 32 | -- dicts is a list of tables of k:v pairs, create a single 33 | -- k:v table that has the mean of the v's for each k 34 | -- assumes that all dicts have same keys always 35 | function utils.dict_average(dicts) 36 | local dict = {} 37 | local n = 0 38 | for i,d in pairs(dicts) do 39 | for k,v in pairs(d) do 40 | if dict[k] == nil then dict[k] = 0 end 41 | dict[k] = dict[k] + v 42 | end 43 | n=n+1 44 | end 45 | for k,v in pairs(dict) do 46 | dict[k] = dict[k] / n -- produce the average 47 | end 48 | return dict 49 | end 50 | 51 | -- seriously this is kind of ridiculous 52 | function utils.count_keys(t) 53 | local n = 0 54 | for k,v in pairs(t) do 55 | n = n + 1 56 | end 57 | return n 58 | end 59 | 60 | -- return average of all values in a table... 61 | function utils.average_values(t) 62 | local n = 0 63 | local vsum = 0 64 | for k,v in pairs(t) do 65 | vsum = vsum + v 66 | n = n + 1 67 | end 68 | return vsum / n 69 | end 70 | 71 | return utils 72 | -------------------------------------------------------------------------------- /prepro/prepro_coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
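An example invocation (the paths are illustrative; the argparse defaults at the
bottom of this file point at /data/coco/):
  python prepro_coco.py --input_json data/coco/coco_raw.json \
    --output_json data/coco/cocotalk.json --output_h5 data/coco/cocotalk.h5 \
    --images_root /path/to/coco/images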
24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from random import shuffle, seed 30 | import string 31 | # non-standard dependencies: 32 | import h5py 33 | import numpy as np 34 | from scipy.misc import imread, imresize 35 | 36 | def prepro_captions(imgs): 37 | 38 | # preprocess all the captions 39 | print 'example processed tokens:' 40 | for i,img in enumerate(imgs): 41 | img['processed_tokens'] = [] 42 | for j,s in enumerate(img['captions']): 43 | txt = str(s).lower().translate(None, string.punctuation).strip().split() 44 | img['processed_tokens'].append(txt) 45 | if i < 10 and j == 0: print txt 46 | 47 | def build_vocab(imgs, params): 48 | count_thr = params['word_count_threshold'] 49 | 50 | # count up the number of words 51 | counts = {} 52 | for img in imgs: 53 | for txt in img['processed_tokens']: 54 | for w in txt: 55 | counts[w] = counts.get(w, 0) + 1 56 | cw = sorted([(count,w) for w,count in counts.iteritems()], reverse=True) 57 | print 'top words and their counts:' 58 | print '\n'.join(map(str,cw[:20])) 59 | 60 | # print some stats 61 | total_words = sum(counts.itervalues()) 62 | print 'total words:', total_words 63 | bad_words = [w for w,n in counts.iteritems() if n <= count_thr] 64 | vocab = [w for w,n in counts.iteritems() if n > count_thr] 65 | bad_count = sum(counts[w] for w in bad_words) 66 | print 'number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)) 67 | print 'number of words in vocab would be %d' % (len(vocab), ) 68 | print 'number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words) 69 | 70 | # lets look at the distribution of lengths as well 71 | sent_lengths = {} 72 | for img in imgs: 73 | for txt in img['processed_tokens']: 74 | nw = len(txt) 75 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 76 | max_len = max(sent_lengths.keys()) 77 | print 'max length sentence in raw data: ', max_len 78 | print 'sentence length distribution (count, number of words):' 79 | sum_len = sum(sent_lengths.values()) 80 | for i in xrange(max_len+1): 81 | print '%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len) 82 | 83 | # lets now produce the final annotations 84 | if bad_count > 0: 85 | # additional special UNK token we will use below to map infrequent words to 86 | print 'inserting the special UNK token' 87 | vocab.append('UNK') 88 | 89 | for img in imgs: 90 | img['final_captions'] = [] 91 | for txt in img['processed_tokens']: 92 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 93 | img['final_captions'].append(caption) 94 | 95 | return vocab 96 | 97 | def assign_splits(imgs, params): 98 | num_val = params['num_val'] 99 | num_test = params['num_test'] 100 | 101 | for i,img in enumerate(imgs): 102 | if i < num_val: 103 | img['split'] = 'val' 104 | elif i < num_val + num_test: 105 | img['split'] = 'test' 106 | else: 107 | img['split'] = 'train' 108 | 109 | print 'assigned %d to val, %d to test.' % (num_val, num_test) 110 | 111 | def encode_captions(imgs, params, wtoi): 112 | """ 113 | encode all captions into one large array, which will be 1-indexed. 114 | also produces label_start_ix and label_end_ix which store 1-indexed 115 | and inclusive (Lua-style) pointers to the first and last caption for 116 | each image in the dataset. 
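As a small worked example: if the first image has 5 captions and the second has
3, then label_start_ix = [1, 6, ...] and label_end_ix = [5, 8, ...], so the
captions of any image can be sliced directly out of /labels.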
117 | """ 118 | 119 | max_length = params['max_length'] 120 | N = len(imgs) 121 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 122 | 123 | label_arrays = [] 124 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 125 | label_end_ix = np.zeros(N, dtype='uint32') 126 | label_length = np.zeros(M, dtype='uint32') 127 | caption_counter = 0 128 | counter = 1 129 | for i,img in enumerate(imgs): 130 | n = len(img['final_captions']) 131 | assert n > 0, 'error: some image has no captions' 132 | 133 | Li = np.zeros((n, max_length), dtype='uint32') 134 | for j,s in enumerate(img['final_captions']): 135 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 136 | caption_counter += 1 137 | for k,w in enumerate(s): 138 | if k < max_length: 139 | Li[j,k] = wtoi[w] 140 | 141 | # note: word indices are 1-indexed, and captions are padded with zeros 142 | label_arrays.append(Li) 143 | label_start_ix[i] = counter 144 | label_end_ix[i] = counter + n - 1 145 | 146 | counter += n 147 | 148 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 149 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 150 | assert np.all(label_length > 0), 'error: some caption had no words?' 151 | 152 | print 'encoded captions to array of size ', `L.shape` 153 | return L, label_start_ix, label_end_ix, label_length 154 | 155 | def main(params): 156 | 157 | imgs = json.load(open(params['input_json'], 'r')) 158 | seed(123) # make reproducible 159 | shuffle(imgs) # shuffle the order 160 | 161 | # tokenization and preprocessing 162 | prepro_captions(imgs) 163 | 164 | # create the vocab 165 | vocab = build_vocab(imgs, params) 166 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 167 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 168 | 169 | # assign the splits 170 | assign_splits(imgs, params) 171 | 172 | # encode captions in large arrays, ready to ship to hdf5 file 173 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 174 | 175 | # create output h5 file 176 | N = len(imgs) 177 | f = h5py.File(params['output_h5'], "w") 178 | f.create_dataset("labels", dtype='uint32', data=L) 179 | f.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 180 | f.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 181 | f.create_dataset("label_length", dtype='uint32', data=label_length) 182 | dset = f.create_dataset("images", (N,3,256,256), dtype='uint8') # space for resized images 183 | for i,img in enumerate(imgs): 184 | # load the image 185 | I = imread(os.path.join(params['images_root'], img['file_path'])) 186 | try: 187 | Ir = imresize(I, (256,256)) 188 | except: 189 | print 'failed resizing image %s - see http://git.io/vBIE0' % (img['file_path'],) 190 | raise 191 | # handle grayscale input images 192 | if len(Ir.shape) == 2: 193 | Ir = Ir[:,:,np.newaxis] 194 | Ir = np.concatenate((Ir,Ir,Ir), axis=2) 195 | # and swap order of axes from (256,256,3) to (3,256,256) 196 | Ir = Ir.transpose(2,0,1) 197 | # write to h5 198 | dset[i] = Ir 199 | if i % 1000 == 0: 200 | print 'processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N) 201 | f.close() 202 | print 'wrote ', params['output_h5'] 203 | 204 | # create output json file 205 | out = {} 206 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 207 | out['images'] = [] 208 | for i,img in enumerate(imgs): 209 | 210 | jimg = {} 211 | jimg['split'] = img['split'] 212 | 
if 'file_path' in img: jimg['file_path'] = img['file_path'] # copy it over, might need 213 | if 'id' in img: jimg['id'] = img['id'] # copy over & mantain an id, if present (e.g. coco ids, useful) 214 | 215 | out['images'].append(jimg) 216 | 217 | json.dump(out, open(params['output_json'], 'w')) 218 | print 'wrote ', params['output_json'] 219 | 220 | if __name__ == "__main__": 221 | 222 | parser = argparse.ArgumentParser() 223 | 224 | # input json 225 | parser.add_argument('--input_json', default='/data/coco/coco_raw.json', help='input json file to process into hdf5') 226 | parser.add_argument('--num_val', default=2000, type=int, help='number of images to assign to validation data (for CV etc)') 227 | parser.add_argument('--output_json', default='/data/coco/cocotalk_challenge.json', help='output json file') 228 | parser.add_argument('--output_h5', default='/data/coco/cocotalk_challenge.h5', help='output h5 file') 229 | 230 | # options 231 | parser.add_argument('--max_length', default=18, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 232 | parser.add_argument('--images_root', default='/home/jiasen/dataset/coco/', help='root location in which images are stored, to be prepended to file_path in input json') 233 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 234 | parser.add_argument('--num_test', default=0, type=int, help='number of test images (to withold until very very end)') 235 | 236 | args = parser.parse_args() 237 | params = vars(args) # convert to ordinary dict 238 | print 'parsed input parameters:' 239 | print json.dumps(params, indent = 2) 240 | main(params) 241 | -------------------------------------------------------------------------------- /prepro/prepro_coco_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from random import shuffle, seed 30 | import string 31 | # non-standard dependencies: 32 | import h5py 33 | import numpy as np 34 | from scipy.misc import imread, imresize 35 | 36 | def prepro_captions(imgs): 37 | 38 | # preprocess all the captions 39 | print 'example processed tokens:' 40 | for i,img in enumerate(imgs): 41 | img['processed_tokens'] = [] 42 | for j,s in enumerate(img['captions']): 43 | txt = str(s).lower().translate(None, string.punctuation).strip().split() 44 | img['processed_tokens'].append(txt) 45 | if i < 10 and j == 0: print txt 46 | 47 | def build_vocab(imgs, params): 48 | count_thr = params['word_count_threshold'] 49 | 50 | # count up the number of words 51 | counts = {} 52 | for img in imgs: 53 | for txt in img['processed_tokens']: 54 | for w in txt: 55 | counts[w] = counts.get(w, 0) + 1 56 | cw = sorted([(count,w) for w,count in counts.iteritems()], reverse=True) 57 | print 'top words and their counts:' 58 | print '\n'.join(map(str,cw[:20])) 59 | 60 | # print some stats 61 | total_words = sum(counts.itervalues()) 62 | print 'total words:', total_words 63 | bad_words = [w for w,n in counts.iteritems() if n <= count_thr] 64 | vocab = [w for w,n in counts.iteritems() if n > count_thr] 65 | bad_count = sum(counts[w] for w in bad_words) 66 | print 'number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)) 67 | print 'number of words in vocab would be %d' % (len(vocab), ) 68 | print 'number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words) 69 | 70 | # lets look at the distribution of lengths as well 71 | sent_lengths = {} 72 | for img in imgs: 73 | for txt in img['processed_tokens']: 74 | nw = len(txt) 75 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 76 | max_len = max(sent_lengths.keys()) 77 | print 'max length sentence in raw data: ', max_len 78 | print 'sentence length distribution (count, number of words):' 79 | sum_len = sum(sent_lengths.values()) 80 | for i in xrange(max_len+1): 81 | print '%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len) 82 | 83 | # lets now produce the final annotations 84 | if bad_count > 0: 85 | # additional special UNK token we will use below to map infrequent words to 86 | print 'inserting the special UNK token' 87 | vocab.append('UNK') 88 | 89 | for img in imgs: 90 | img['final_captions'] = [] 91 | for txt in img['processed_tokens']: 92 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 93 | img['final_captions'].append(caption) 94 | 95 | return vocab 96 | 97 | def assign_splits(imgs, params): 98 | 99 | for i,img in 
enumerate(imgs): 100 | img['split'] = 'test' 101 | 102 | def encode_captions(imgs, params, wtoi): 103 | """ 104 | encode all captions into one large array, which will be 1-indexed. 105 | also produces label_start_ix and label_end_ix which store 1-indexed 106 | and inclusive (Lua-style) pointers to the first and last caption for 107 | each image in the dataset. 108 | """ 109 | 110 | max_length = params['max_length'] 111 | N = len(imgs) 112 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 113 | 114 | label_arrays = [] 115 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 116 | label_end_ix = np.zeros(N, dtype='uint32') 117 | label_length = np.zeros(M, dtype='uint32') 118 | caption_counter = 0 119 | counter = 1 120 | for i,img in enumerate(imgs): 121 | n = len(img['final_captions']) 122 | assert n > 0, 'error: some image has no captions' 123 | 124 | Li = np.zeros((n, max_length), dtype='uint32') 125 | for j,s in enumerate(img['final_captions']): 126 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 127 | caption_counter += 1 128 | for k,w in enumerate(s): 129 | if k < max_length: 130 | Li[j,k] = wtoi[w] 131 | 132 | # note: word indices are 1-indexed, and captions are padded with zeros 133 | label_arrays.append(Li) 134 | label_start_ix[i] = counter 135 | label_end_ix[i] = counter + n - 1 136 | 137 | counter += n 138 | 139 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 140 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 141 | assert np.all(label_length > 0), 'error: some caption had no words?' 142 | 143 | print 'encoded captions to array of size ', `L.shape` 144 | return L, label_start_ix, label_end_ix, label_length 145 | 146 | def main(params): 147 | 148 | imgs = json.load(open(params['input_json'], 'r')) 149 | assign_splits(imgs, params) 150 | # create output h5 file 151 | N = len(imgs) 152 | f = h5py.File(params['output_h5'], "w") 153 | dset = f.create_dataset("images", (N,3,256,256), dtype='uint8') # space for resized images 154 | for i,img in enumerate(imgs): 155 | # load the image 156 | I = imread(os.path.join(params['images_root'], img['file_path'])) 157 | try: 158 | Ir = imresize(I, (256,256)) 159 | except: 160 | print 'failed resizing image %s - see http://git.io/vBIE0' % (img['file_path'],) 161 | raise 162 | # handle grayscale input images 163 | if len(Ir.shape) == 2: 164 | Ir = Ir[:,:,np.newaxis] 165 | Ir = np.concatenate((Ir,Ir,Ir), axis=2) 166 | # and swap order of axes from (256,256,3) to (3,256,256) 167 | Ir = Ir.transpose(2,0,1) 168 | # write to h5 169 | dset[i] = Ir 170 | if i % 1000 == 0: 171 | print 'processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N) 172 | f.close() 173 | print 'wrote ', params['output_h5'] 174 | 175 | # create output json file 176 | out = {} 177 | out['images'] = [] 178 | for i,img in enumerate(imgs): 179 | jimg = {} 180 | jimg['split'] = img['split'] 181 | if 'file_path' in img: jimg['file_path'] = img['file_path'] # copy it over, might need 182 | if 'id' in img: jimg['id'] = img['id'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) 183 | 184 | out['images'].append(jimg) 185 | 186 | json.dump(out, open(params['output_json'], 'w')) 187 | print 'wrote ', params['output_json'] 188 | 189 | if __name__ == "__main__": 190 | 191 | parser = argparse.ArgumentParser() 192 | 193 | # input json 194 | parser.add_argument('--input_json', default='/data/coco/coco_val_raw.json', help='input json file to process into hdf5') 195 | parser.add_argument('--output_json', default='/data/coco/cocotalk_val.json', help='output json file') 196 | parser.add_argument('--output_h5', default='/data/coco/cocotalk_val.h5', help='output h5 file') 197 | 198 | # options 199 | parser.add_argument('--images_root', default='/home/jiasen/dataset/coco/', help='root location in which images are stored, to be prepended to file_path in input json') 200 | 201 | args = parser.parse_args() 202 | params = vars(args) # convert to ordinary dict 203 | print 'parsed input parameters:' 204 | print json.dumps(params, indent = 2) 205 | main(params) 206 | -------------------------------------------------------------------------------- /prepro/prepro_flickr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
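Unlike the COCO script, the input json here is expected to already carry
tokenized 'sentences' and a 'split' field per image (as in the standard
Flickr30k dataset.json), so no splitting or re-tokenization is done in main().
An example invocation (the paths are the argparse defaults and are illustrative):
  python prepro_flickr.py --input_json /data/flickr30k/dataset.json \
    --output_json /data/flickr30k/cocotalk.json --output_h5 /data/flickr30k/cocotalk.h5 \
    --images_root /data/flickr30k/images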
24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from random import shuffle, seed 30 | import string 31 | # non-standard dependencies: 32 | import h5py 33 | import numpy as np 34 | from scipy.misc import imread, imresize 35 | import pdb 36 | 37 | def build_vocab(imgs, params): 38 | count_thr = params['word_count_threshold'] 39 | 40 | # count up the number of words 41 | counts = {} 42 | for img in imgs: 43 | for txt in img['sentences']: 44 | for w in txt['tokens']: 45 | counts[w] = counts.get(w, 0) + 1 46 | cw = sorted([(count,w) for w,count in counts.iteritems()], reverse=True) 47 | print 'top words and their counts:' 48 | print '\n'.join(map(str,cw[:20])) 49 | 50 | # print some stats 51 | total_words = sum(counts.itervalues()) 52 | print 'total words:', total_words 53 | bad_words = [w for w,n in counts.iteritems() if n <= count_thr] 54 | vocab = [w for w,n in counts.iteritems() if n > count_thr] 55 | bad_count = sum(counts[w] for w in bad_words) 56 | print 'number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)) 57 | print 'number of words in vocab would be %d' % (len(vocab), ) 58 | print 'number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words) 59 | 60 | # lets look at the distribution of lengths as well 61 | sent_lengths = {} 62 | for img in imgs: 63 | for txt in img['sentences']: 64 | nw = len(txt['tokens']) 65 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 66 | max_len = max(sent_lengths.keys()) 67 | print 'max length sentence in raw data: ', max_len 68 | print 'sentence length distribution (count, number of words):' 69 | sum_len = sum(sent_lengths.values()) 70 | for i in xrange(max_len+1): 71 | print '%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len) 72 | 73 | # lets now produce the final annotations 74 | if bad_count > 0: 75 | # additional special UNK token we will use below to map infrequent words to 76 | print 'inserting the special UNK token' 77 | vocab.append('UNK') 78 | 79 | for img in imgs: 80 | img['final_captions'] = [] 81 | for txt in img['sentences']: 82 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt['tokens']] 83 | img['final_captions'].append(caption) 84 | 85 | return vocab 86 | 87 | def assign_splits(imgs, params): 88 | num_val = params['num_val'] 89 | num_test = params['num_test'] 90 | 91 | for i,img in enumerate(imgs): 92 | if i < num_val: 93 | img['split'] = 'val' 94 | elif i < num_val + num_test: 95 | img['split'] = 'test' 96 | else: 97 | img['split'] = 'train' 98 | 99 | print 'assigned %d to val, %d to test.' % (num_val, num_test) 100 | 101 | def encode_captions(imgs, params, wtoi): 102 | """ 103 | encode all captions into one large array, which will be 1-indexed. 104 | also produces label_start_ix and label_end_ix which store 1-indexed 105 | and inclusive (Lua-style) pointers to the first and last caption for 106 | each image in the dataset. 
107 | """ 108 | 109 | max_length = params['max_length'] 110 | N = len(imgs) 111 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 112 | 113 | label_arrays = [] 114 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 115 | label_end_ix = np.zeros(N, dtype='uint32') 116 | label_length = np.zeros(M, dtype='uint32') 117 | caption_counter = 0 118 | counter = 1 119 | for i,img in enumerate(imgs): 120 | n = len(img['final_captions']) 121 | assert n > 0, 'error: some image has no captions' 122 | 123 | Li = np.zeros((n, max_length), dtype='uint32') 124 | for j,s in enumerate(img['final_captions']): 125 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 126 | caption_counter += 1 127 | for k,w in enumerate(s): 128 | if k < max_length: 129 | Li[j,k] = wtoi[w] 130 | 131 | # note: word indices are 1-indexed, and captions are padded with zeros 132 | label_arrays.append(Li) 133 | label_start_ix[i] = counter 134 | label_end_ix[i] = counter + n - 1 135 | 136 | counter += n 137 | 138 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 139 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 140 | assert np.all(label_length > 0), 'error: some caption had no words?' 141 | 142 | print 'encoded captions to array of size ', `L.shape` 143 | return L, label_start_ix, label_end_ix, label_length 144 | 145 | def main(params): 146 | 147 | imgs = json.load(open(params['input_json'], 'r')) 148 | imgs = imgs['images'] 149 | seed(123) # make reproducible 150 | 151 | shuffle(imgs) # shuffle the order 152 | 153 | # create the vocab 154 | vocab = build_vocab(imgs, params) 155 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 156 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 157 | 158 | # assign the splits 159 | # assign_splits(imgs, params) 160 | 161 | # encode captions in large arrays, ready to ship to hdf5 file 162 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 163 | 164 | # create output h5 file 165 | N = len(imgs) 166 | f = h5py.File(params['output_h5'], "w") 167 | f.create_dataset("labels", dtype='uint32', data=L) 168 | f.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 169 | f.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 170 | f.create_dataset("label_length", dtype='uint32', data=label_length) 171 | dset = f.create_dataset("images", (N,3,256,256), dtype='uint8') # space for resized images 172 | for i,img in enumerate(imgs): 173 | # load the image 174 | I = imread(os.path.join(params['images_root'], img['filename'])) 175 | 176 | try: 177 | Ir = imresize(I, (256,256)) 178 | except: 179 | print 'failed resizing image %s - see http://git.io/vBIE0' % (img['file_path'],) 180 | raise 181 | # handle grayscale input images 182 | if len(Ir.shape) == 2: 183 | Ir = Ir[:,:,np.newaxis] 184 | Ir = np.concatenate((Ir,Ir,Ir), axis=2) 185 | # and swap order of axes from (256,256,3) to (3,256,256) 186 | Ir = Ir.transpose(2,0,1) 187 | # write to h5 188 | dset[i] = Ir 189 | 190 | if i % 1000 == 0: 191 | print 'processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N) 192 | f.close() 193 | print 'wrote ', params['output_h5'] 194 | 195 | # create output json file 196 | out = {} 197 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 198 | out['images'] = [] 199 | for i,img in enumerate(imgs): 200 | 201 | jimg = {} 202 | jimg['split'] = img['split'] 203 | if 'filename' in img: 
jimg['file_path'] = img['filename'] # copy it over, might need 204 | if 'imgid' in img: jimg['id'] = img['imgid'] # copy over & mantain an id, if present (e.g. coco ids, useful) 205 | 206 | out['images'].append(jimg) 207 | 208 | json.dump(out, open(params['output_json'], 'w')) 209 | print 'wrote ', params['output_json'] 210 | 211 | if __name__ == "__main__": 212 | 213 | parser = argparse.ArgumentParser() 214 | 215 | # input json 216 | parser.add_argument('--input_json', default='/data/flickr30k/dataset.json', help='input json file to process into hdf5') 217 | parser.add_argument('--output_json', default='/data/flickr30k/cocotalk.json', help='output json file') 218 | parser.add_argument('--output_h5', default='/data/flickr30k/cocotalk.h5', help='output h5 file') 219 | 220 | # options 221 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 222 | parser.add_argument('--images_root', default='/data/flickr30k/images', help='root location in which images are stored, to be prepended to file_path in input json') 223 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 224 | 225 | args = parser.parse_args() 226 | params = vars(args) # convert to ordinary dict 227 | print 'parsed input parameters:' 228 | print json.dumps(params, indent = 2) 229 | main(params) 230 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'nngraph' 4 | require 'misc.DataLoaderResNet' 5 | 6 | local utils = require 'misc.utils' 7 | require 'misc.LanguageModel' 8 | local net_utils = require 'misc.net_utils' 9 | require 'misc.optim_updates' 10 | require 'gnuplot' 11 | require 'xlua' 12 | ------------------------------------------------------------------------------- 13 | -- Input arguments and options 14 | ------------------------------------------------------------------------------- 15 | cmd = torch.CmdLine() 16 | cmd:text() 17 | cmd:text('Train an Image Captioning model') 18 | cmd:text() 19 | cmd:text('Options') 20 | 21 | 22 | -- Data input settings 23 | 24 | cmd:option('-input_h5','/data/coco/cocotalk.h5','path to the h5file containing the preprocessed dataset') 25 | cmd:option('-input_json','/data/coco/cocotalk.json','path to the json file containing additional info and vocab') 26 | cmd:option('-cnn_model','../image_model/resnet-152.t7','path to CNN model file containing the weights, Caffe format. Note this MUST be a VGGNet-16 right now.') 27 | 28 | cmd:option('-start_from', '', 'path to a model checkpoint to initialize model weights from. Empty = don\'t') 29 | cmd:option('-checkpoint_path', 'save/', 'folder to save checkpoints into (empty = this folder)') 30 | cmd:option('-startEpoch', 1, 'Max number of training epoch') 31 | 32 | -- Model settings 33 | cmd:option('-rnn_size',512,'size of the rnn in number of hidden nodes in each layer') 34 | cmd:option('-num_layers',1,'the encoding size of each token in the vocabulary, and the image.') 35 | cmd:option('-input_encoding_size',512,'the encoding size of each token in the vocabulary, and the image.') 36 | cmd:option('-batch_size',20,'what is the batch size in number of images per batch? 
(there will be x seq_per_img sentences)') 37 | 38 | -- training setting 39 | cmd:option('-nEpochs', 20, 'Max number of training epoch') 40 | cmd:option('-finetune_cnn_after', 20, 'After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)') 41 | 42 | --actuall batch size = gpu_num * batch_size 43 | 44 | cmd:option('-fc_size',2048,'the encoding size of each token in the vocabulary, and the image.') 45 | cmd:option('-conv_size',2048,'the encoding size of each token in the vocabulary, and the image.') 46 | cmd:option('-seq_per_img',5,'number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') 47 | 48 | -- Optimization: General 49 | cmd:option('-grad_clip',0.1,'clip gradients at this value (note should be lower than usual 5 because we normalize grads by both batch and seq_length)') 50 | cmd:option('-drop_prob_lm', 0.5, 'strength of dropout in the Language Model RNN') 51 | 52 | -- Optimization: for the Language Model 53 | cmd:option('-optim','adam','what update to use? rmsprop|sgd|sgdmom|adagrad|adam') 54 | cmd:option('-learning_rate',4e-4,'learning rate') 55 | cmd:option('-learning_rate_decay_start', 20, 'at what iteration to start decaying learning rate? (-1 = dont)') 56 | cmd:option('-learning_rate_decay_every', 50, 'how many epoch the learning rate x 0.5') 57 | cmd:option('-optim_alpha',0.8,'alpha for adagrad/rmsprop/momentum/adam') 58 | cmd:option('-optim_beta',0.999,'beta used for adam') 59 | cmd:option('-optim_epsilon',1e-8,'epsilon that goes into denominator for smoothing') 60 | 61 | -- Optimization: for the CNN 62 | cmd:option('-cnn_optim','adam','optimization to use for CNN') 63 | cmd:option('-cnn_optim_alpha',0.8,'alpha for momentum of CNN') 64 | cmd:option('-cnn_optim_beta',0.999,'alpha for momentum of CNN') 65 | cmd:option('-cnn_learning_rate',1e-5,'learning rate for the CNN') 66 | cmd:option('-cnn_weight_decay', 0, 'L2 weight decay just for the CNN') 67 | cmd:option('-finetune_start_layer', 6, 'finetune start layer. [1-10]') 68 | 69 | -- Evaluation/Checkpointing 70 | cmd:option('-val_images_use', -1, 'how many images to use when periodically evaluating the validation loss? (-1 = all)') 71 | cmd:option('-save_checkpoint_every', 4, 'how often to save a model checkpoint?') 72 | cmd:option('-language_eval', 1, 'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 73 | 74 | -- misc 75 | cmd:option('-backend', 'cudnn', 'nn|cudnn') 76 | cmd:option('-id', '1', 'an id identifying this run/job. used in cross-val and appended when writing progress files') 77 | cmd:option('-seed', 123, 'random number generator seed to use') 78 | cmd:option('-gpuid', 0, 'which gpu to use. 
-1 = use CPU') 79 | 80 | cmd:text() 81 | 82 | ------------------------------------------------------------------------------- 83 | -- Basic Torch initializations 84 | ------------------------------------------------------------------------------- 85 | local opt = cmd:parse(arg) 86 | --torch.manualSeed(opt.seed) 87 | torch.setdefaulttensortype('torch.FloatTensor') -- for CPU 88 | 89 | if opt.gpuid >= 0 then 90 | require 'cutorch' 91 | require 'cunn' 92 | if opt.backend == 'cudnn' then require 'cudnn' end 93 | --cutorch.manualSeed(opt.seed) 94 | cutorch.setDevice(opt.gpuid + 1) -- note +1 because lua is 1-indexed 95 | end 96 | 97 | ------------------------------------------------------------------------------- 98 | -- Create the Data Loader instance 99 | ------------------------------------------------------------------------------- 100 | local loader = DataLoader{h5_file = opt.input_h5, json_file = opt.input_json, neighbor_h5 = opt.nn_neighbor, 101 | batch_size = opt.batch_size, seq_per_img = opt.seq_per_img, thread_num = opt.thread_num} 102 | 103 | ------------------------------------------------------------------------------- 104 | -- Initialize the networks 105 | ------------------------------------------------------------------------------- 106 | -- create protos from scratch 107 | -- intialize language model 108 | local lmOpt = {} 109 | lmOpt.vocab_size = loader:getVocabSize() 110 | lmOpt.input_encoding_size = opt.input_encoding_size 111 | lmOpt.rnn_size = opt.rnn_size 112 | lmOpt.num_layers = opt.num_layers 113 | lmOpt.dropout = opt.drop_prob_lm 114 | lmOpt.seq_length = loader:getSeqLength() 115 | lmOpt.batch_size = opt.batch_size * opt.seq_per_img 116 | lmOpt.fc_size = opt.fc_size 117 | lmOpt.conv_size = opt.conv_size 118 | 119 | local loaded_checkpoint 120 | if opt.start_from ~= '' then -- just copy to gpu1 params 121 | local loaded_checkpoint_path = path.join(opt.checkpoint_path, opt.start_from) 122 | loaded_checkpoint = torch.load(loaded_checkpoint_path) 123 | end 124 | 125 | -- iterate over different gpu 126 | local protos = {} 127 | 128 | protos.lm = nn.LanguageModel(lmOpt):cuda() 129 | -- initialize the ConvNet 130 | if opt.start_from ~= '' then -- just copy to gpu1 params 131 | protos.cnn_conv_fix = loaded_checkpoint.protos.cnn_conv_fix:cuda() 132 | protos.cnn_conv = loaded_checkpoint.protos.cnn_conv:cuda() 133 | protos.cnn_fc = loaded_checkpoint.protos.cnn_fc:cuda() 134 | else 135 | local cnn_raw = torch.load(opt.cnn_model) 136 | 137 | protos.cnn_conv_fix = net_utils.build_residual_cnn_conv_fix(cnn_raw, 138 | {backend = cnn_backend, start_layer_num = opt.finetune_start_layer}):cuda() 139 | 140 | protos.cnn_conv = net_utils.build_residual_cnn_conv(cnn_raw, 141 | {backend = cnn_backend, start_layer_num = opt.finetune_start_layer}):cuda() 142 | 143 | protos.cnn_fc = net_utils.build_residual_cnn_fc(cnn_raw, 144 | {backend = cnn_backend}):cuda() 145 | end 146 | protos.expanderConv = nn.FeatExpanderConv(opt.seq_per_img):cuda() 147 | protos.expanderFC = nn.FeatExpander(opt.seq_per_img):cuda() 148 | protos.transform_cnn_conv = net_utils.transform_cnn_conv(opt.conv_size):cuda() 149 | -- criterion for the language model 150 | protos.crit = nn.LanguageModelCriterion():cuda() 151 | 152 | params, grad_params = protos.lm:getParameters() 153 | cnn1_params, cnn1_grad_params = protos.cnn_conv:getParameters() 154 | 155 | print('total number of parameters in LM: ', params:nElement()) 156 | print('total number of parameters in CNN_conv: ', cnn1_params:nElement()) 157 | 158 | 
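-- sanity check: getParameters() flattens each network's weights and gradients into single tensors, so the two must contain the same number of elements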
assert(params:nElement() == grad_params:nElement()) 159 | assert(cnn1_params:nElement() == cnn1_grad_params:nElement()) 160 | 161 | if opt.start_from ~= '' then -- just copy to gpu1 params 162 | params:copy(loaded_checkpoint.lmparam) 163 | end 164 | 165 | protos.lm:createClones() 166 | collectgarbage() 167 | 168 | ------------------------------------------------------------------------------- 169 | -- Validation evaluation 170 | ------------------------------------------------------------------------------- 171 | local function evaluate_split(split, evalopt) 172 | local val_images_use = utils.getopt(evalopt, 'val_images_use', true) 173 | 174 | print('=> evaluating ...') 175 | -- setting to the evaluation mode, use only the first gpu 176 | protos.cnn_conv:evaluate() 177 | protos.cnn_fc:evaluate() 178 | protos.lm:evaluate() 179 | protos.cnn_conv_fix:evaluate() 180 | 181 | local n = 0 182 | local loss_sum = 0 183 | local loss_evals = 0 184 | local predictions = {} 185 | local vocab = loader:getVocab() 186 | local imgId_cell = {} 187 | 188 | local nbatch = math.ceil(val_images_use / opt.batch_size) 189 | if val_images_use == -1 then 190 | nbatch = loader:getnBatch(split) 191 | end 192 | loader:init_rand(split) 193 | loader:reset_iterator(split) 194 | 195 | for n = 1, nbatch do 196 | local data = loader:run({split = split, size_image_use = val_images_use}) 197 | -- convert the data to cuda 198 | data.images = data.images:cuda() 199 | data.labels = data.labels:cuda() 200 | 201 | -- forward the model to get loss 202 | local feats_conv_fix = protos.cnn_conv_fix:forward(data.images) 203 | local feats_conv = protos.cnn_conv:forward(feats_conv_fix) 204 | local feat_conv_t = protos.transform_cnn_conv:forward(feats_conv) 205 | local feats_fc = protos.cnn_fc:forward(feats_conv) 206 | 207 | local expanded_feats_conv = protos.expanderConv:forward(feat_conv_t) 208 | local expanded_feats_fc = protos.expanderFC:forward(feats_fc) 209 | local logprobs = protos.lm:forward({expanded_feats_conv, expanded_feats_fc, data.labels}) 210 | 211 | local loss = protos.crit:forward({logprobs, data.labels}) 212 | loss_sum = loss_sum + loss 213 | loss_evals = loss_evals + 1 214 | -- forward the model to also get generated samples for each image 215 | local seq = protos.lm:sample({feat_conv_t, feats_fc, vocab, data.labels:cuda()}) 216 | local sents = net_utils.decode_sequence(vocab, seq) 217 | 218 | for k=1,#sents do 219 | local img_id = data.img_id[k] 220 | local entry 221 | if imgId_cell[img_id] == nil then -- make sure there are one caption for each image. 222 | imgId_cell[img_id] = 1 223 | entry = {image_id = img_id, caption = sents[k]} 224 | table.insert(predictions, entry) 225 | end 226 | if n == 1 then -- print the first batch 227 | print(string.format('image %s: %s', entry.image_id, entry.caption)) 228 | end 229 | end 230 | end 231 | local lang_stats 232 | if opt.language_eval == 1 then 233 | local sampleOpt = {beam_size = 3} 234 | lang_stats = net_utils.language_eval(predictions, {id = opt.id, dataset = opt.dataset},sampleOpt) 235 | end 236 | 237 | return loss_sum/loss_evals, predictions, lang_stats 238 | end 239 | 240 | ------------------------------------------------------------------------------- 241 | -- train function 242 | ------------------------------------------------------------------------------- 243 | local function Train(epoch) 244 | 245 | local size_image_use = -1 246 | print('=> Training epoch # ' .. epoch) 247 | print('lm_learning_rate: ' .. learning_rate 248 | .. ' cnn_learning_rate: ' .. 
cnn_learning_rate) 249 | 250 | protos.cnn_conv:training() 251 | protos.cnn_fc:training() 252 | protos.lm:training() 253 | protos.cnn_conv_fix:training() 254 | 255 | local nbatch = math.ceil(size_image_use / opt.batch_size) 256 | if size_image_use == -1 then 257 | nbatch = loader:getnBatch('train') 258 | end 259 | 260 | local ave_loss = 0 261 | loader:init_rand('train') 262 | loader:reset_iterator('train') 263 | for n = 1, nbatch do 264 | xlua.progress(n,nbatch) 265 | grad_params:zero() 266 | 267 | -- setting the gradient of the CNN network 268 | if epoch >= opt.finetune_cnn_after and opt.finetune_cnn_after ~= -1 then 269 | cnn1_grad_params:zero() 270 | end 271 | 272 | local data = loader:run({split = 'train', size_image_use = size_image_use}) 273 | -- convert the data to cuda 274 | data.images = data.images:cuda() 275 | data.labels = data.labels:cuda() 276 | 277 | local feats_conv_fix = protos.cnn_conv_fix:forward(data.images) 278 | local feats_conv = protos.cnn_conv:forward(feats_conv_fix) 279 | local feat_conv_t = protos.transform_cnn_conv:forward(feats_conv) 280 | 281 | -- we have to expand out image features, once for each sentence 282 | local feats_fc = protos.cnn_fc:forward(feats_conv) 283 | local expanded_feats_conv = protos.expanderConv:forward(feat_conv_t) 284 | local expanded_feats_fc = protos.expanderFC:forward(feats_fc) 285 | 286 | -- forward the language model 287 | local log_prob = protos.lm:forward({expanded_feats_conv, expanded_feats_fc, data.labels}) 288 | -- forward the language model criterion 289 | local loss = protos.crit:forward({log_prob, data.labels}) 290 | ----------------------------------------------------------------------------- 291 | -- Backward pass 292 | ----------------------------------------------------------------------------- 293 | -- backprop criterion 294 | local d_logprobs = protos.crit:backward({}) 295 | -- backprop language model 296 | local dexpanded_conv, dexpanded_fc = unpack(protos.lm:backward({}, d_logprobs)) 297 | 298 | grad_params:clamp(-opt.grad_clip, opt.grad_clip) 299 | 300 | if epoch >= opt.finetune_cnn_after and opt.finetune_cnn_after ~= -1 then 301 | net_utils.setBNGradient0(protos.transform_cnn_conv) 302 | net_utils.setBNGradient0(protos.cnn_fc) 303 | net_utils.setBNGradient0(protos.cnn_conv) 304 | -- backprop the CNN, but only if we are finetuning 305 | dconv_t = protos.expanderConv:backward(feat_conv_t, dexpanded_conv) 306 | dfc = protos.expanderFC:backward(feats_fc, dexpanded_fc) 307 | dconv = protos.transform_cnn_conv:backward(feats_conv, dconv_t) 308 | dx = protos.cnn_fc:backward(feats_conv, dfc) 309 | dconv:add(dx) 310 | local dummy = protos.cnn_conv:backward(feats_conv_fix, dconv) 311 | 312 | -- apply L2 regularization 313 | if opt.cnn_weight_decay > 0 then 314 | cnn1_grad_params:add(opt.cnn_weight_decay, cnn1_params) 315 | end 316 | 317 | cnn1_grad_params:clamp(-opt.grad_clip, opt.grad_clip) 318 | end 319 | 320 | ----------------------------------------------------------------------------- 321 | ave_loss = ave_loss + loss 322 | -- update the parameters 323 | if opt.optim == 'rmsprop' then 324 | rmsprop(params, grad_params, learning_rate, opt.optim_alpha, opt.optim_epsilon, optim_state) 325 | elseif opt.optim == 'adam' then 326 | adam(params, grad_params, learning_rate, opt.optim_alpha, opt.optim_beta, opt.optim_epsilon, optim_state) 327 | else 328 | error('bad option opt.optim') 329 | end 330 | 331 | if epoch >= opt.finetune_cnn_after and opt.finetune_cnn_after ~= -1 then 332 | if opt.cnn_optim == 'sgd' then 333 | 
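-- plain SGD update for the finetuned CNN parameters (no momentum or adaptive state)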
sgd(cnn1_params, cnn1_grad_params, cnn1_learning_rate) 334 | elseif opt.cnn_optim == 'sgdm' then 335 | sgdm(cnn1_params, cnn1_grad_params, cnn_learning_rate, opt.cnn_optim_alpha, cnn1_optim_state) 336 | elseif opt.cnn_optim == 'adam' then 337 | adam(cnn1_params, cnn1_grad_params, cnn_learning_rate, opt.cnn_optim_alpha, opt.cnn_optim_beta, opt.optim_epsilon, cnn1_optim_state) 338 | else 339 | error('bad option for opt.cnn_optim') 340 | end 341 | end 342 | end 343 | ave_loss = ave_loss / nbatch 344 | 345 | return ave_loss 346 | end 347 | 348 | 349 | paths.mkdir(opt.checkpoint_path) 350 | 351 | ------------------------------------------------------------------------------- 352 | -- Main loop 353 | ------------------------------------------------------------------------------- 354 | optim_state = {} 355 | cnn1_optim_state = {} 356 | learning_rate = opt.learning_rate 357 | cnn_learning_rate = opt.cnn_learning_rate 358 | 359 | local startEpoch = opt.startEpoch 360 | local loss0 361 | local loss_history = {} 362 | local val_lang_stats_history = {} 363 | local val_loss_history = {} 364 | local checkpoint_path = path.join(opt.checkpoint_path, 'model_id' .. opt.id) 365 | local timer = torch.Timer() 366 | 367 | for epoch = startEpoch, opt.nEpochs do 368 | 369 | -- doing the learning rate decay 370 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0 then 371 | local frac = (epoch - opt.learning_rate_decay_start) / opt.learning_rate_decay_every 372 | local decay_factor = math.pow(0.5, frac) 373 | learning_rate = learning_rate * decay_factor -- set the decayed rate 374 | end 375 | 376 | 377 | local train_loss = Train(epoch) 378 | print('training loss for # ' .. epoch .. ' : ' .. train_loss) 379 | 380 | -- save the model. 381 | if epoch % opt.save_checkpoint_every == 0 then 382 | local val_loss, val_predictions, lang_stats = evaluate_split('val', {val_images_use = opt.val_images_use}) 383 | print('validation loss for # ' .. epoch .. ' : ' .. val_loss) 384 | 385 | loss_history[epoch] = train_loss 386 | val_loss_history[epoch] = val_loss 387 | 388 | if lang_stats then 389 | val_lang_stats_history[epoch] = lang_stats 390 | end 391 | 392 | local checkpoint = {} 393 | checkpoint.loss_history = loss_history 394 | checkpoint.val_loss_history = val_loss_history 395 | checkpoint.val_predictions = val_predictions -- save these too for CIDEr/METEOR/etc eval 396 | checkpoint.val_lang_stats_history = val_lang_stats_history 397 | 398 | utils.write_json(checkpoint_path .. '.json', checkpoint) 399 | print('wrote json checkpoint to ' .. checkpoint_path .. '.json') 400 | 401 | local save_protos = {} 402 | save_protos.cnn_conv = net_utils.deepCopy(protos.cnn_conv):float():clearState() 403 | save_protos.cnn_conv_fix = net_utils.deepCopy(protos.cnn_conv_fix):float():clearState() 404 | save_protos.cnn_fc = net_utils.deepCopy(protos.cnn_fc):float():clearState() 405 | 406 | checkpoint.protos = save_protos 407 | checkpoint.lmparam = params:float() 408 | checkpoint.vocab = loader:getVocab() 409 | torch.save(checkpoint_path .. '_' .. epoch .. '.t7', checkpoint) 410 | print('wrote checkpoint to ' .. checkpoint_path .. '_' .. epoch .. 
'.t7') 411 | end 412 | 413 | end 414 | -------------------------------------------------------------------------------- /visu/DataLoaderResNetEval.lua: -------------------------------------------------------------------------------- 1 | require 'hdf5' 2 | local utils = require 'misc.utils' 3 | local net_utils = require 'misc.net_utils' 4 | local t = require 'misc.transforms' 5 | 6 | local DataLoader = torch.class('DataLoader') 7 | 8 | function DataLoader:__init(opt) 9 | 10 | -- load the json file which contains additional information about the dataset 11 | print('DataLoader loading json file: ', opt.json_file) 12 | self.info = utils.read_json(opt.json_file) 13 | self.ix_to_word = self.info.ix_to_word 14 | self.vocab_size = utils.count_keys(self.ix_to_word) 15 | 16 | self.batch_size = utils.getopt(opt, 'batch_size', 5) -- how many images get returned at one time (to go through CNN) 17 | self.seq_per_img = utils.getopt(opt, 'seq_per_img', 5) -- number of sequences to return per image 18 | 19 | print('vocab size is ' .. self.vocab_size) 20 | 21 | -- open the hdf5 file 22 | print('DataLoader loading h5 file: ', opt.h5_file) 23 | self.h5_file = hdf5.open(opt.h5_file, 'r') 24 | 25 | -- extract image size from dataset 26 | local images_size = self.h5_file:read('/images'):dataspaceSize() 27 | assert(#images_size == 4, '/images should be a 4D tensor') 28 | assert(images_size[3] == images_size[4], 'width and height must match') 29 | self.num_images = images_size[1] 30 | self.num_channels = images_size[2] 31 | self.max_image_size = images_size[3] 32 | 33 | self.imgs = self.h5_file:read('/images'):all() 34 | 35 | print(string.format('read %d images of size %dx%dx%d', self.num_images, 36 | self.num_channels, self.max_image_size, self.max_image_size)) 37 | 38 | -- load in the sequence data 39 | local seq_size = self.h5_file:read('/labels'):dataspaceSize() 40 | self.seq_length = seq_size[2] 41 | print('max sequence length in data is ' .. 
self.seq_length) 42 | -- load the pointers in full to RAM (should be small enough) 43 | self.label_start_ix = self.h5_file:read('/label_start_ix'):all() 44 | self.label_end_ix = self.h5_file:read('/label_end_ix'):all() 45 | self.labels = self.h5_file:read('/labels'):all() 46 | self.label_lens = self.h5_file:read('/label_length'):all() 47 | -- separate out indexes for each of the provided splits 48 | self.split_ix = {} 49 | self.iterator = {} 50 | self.image_ids = torch.LongTensor(self.num_images):zero() 51 | for i,img in pairs(self.info.images) do 52 | local split = img.split 53 | if not self.split_ix[split] then 54 | -- initialize new split 55 | self.split_ix[split] = {} 56 | self.iterator[split] = 1 57 | end 58 | table.insert(self.split_ix[split], i) 59 | self.image_ids[i] = img.id 60 | end 61 | 62 | self.__size = {} 63 | for k,v in pairs(self.split_ix) do 64 | print(string.format('assigned %d images to split %s', #v, k)) 65 | end 66 | 67 | self.meanstd = { 68 | mean = { 0.485, 0.456, 0.406 }, 69 | std = { 0.229, 0.224, 0.225 }, 70 | } 71 | 72 | self.transform = t.Compose{ 73 | t.ColorNormalize(self.meanstd) 74 | } 75 | end 76 | 77 | function DataLoader:init_rand(split) 78 | local size = #self.split_ix[split] 79 | if split == 'train' then 80 | self.perm = torch.randperm(size) 81 | else 82 | self.perm = torch.range(1,size) -- for test and validation, do not permutate 83 | end 84 | end 85 | 86 | function DataLoader:reset_iterator(split) 87 | self.iterator[split] = 1 88 | end 89 | 90 | function DataLoader:getVocabSize() 91 | return self.vocab_size 92 | end 93 | 94 | function DataLoader:getVocab() 95 | return self.ix_to_word 96 | end 97 | 98 | function DataLoader:getSeqLength() 99 | return self.seq_length 100 | end 101 | 102 | function DataLoader:getnBatch(split) 103 | return math.ceil(#self.split_ix[split] / self.batch_size) 104 | end 105 | 106 | function DataLoader:run(opt) 107 | local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety 108 | local size_image_use = utils.getopt(opt, 'size_image_use', -1) 109 | local size, batch_size = #self.split_ix[split], self.batch_size 110 | local seq_per_img, seq_length = self.seq_per_img, self.seq_length 111 | local num_channels, max_image_size = self.num_channels, self.max_image_size 112 | 113 | if size_image_use ~= -1 and size_image_use <= size then size = size_image_use end 114 | local split_ix = self.split_ix[split] 115 | local idx = self.iterator[split] 116 | 117 | if idx <= size then 118 | 119 | local indices = self.perm:narrow(1, idx, math.min(batch_size, size - idx + 1)) 120 | 121 | local img_batch_raw = torch.ByteTensor(batch_size, 3, 256, 256) 122 | local label_batch = torch.LongTensor(batch_size * seq_per_img, seq_length):zero() 123 | local img_id_batch = torch.LongTensor(batch_size*seq_per_img):zero() 124 | for i, ixm in ipairs(indices:totable()) do 125 | 126 | local ix = split_ix[ixm] 127 | img_batch_raw[i] = self.imgs[ix] 128 | 129 | -- fetch the sequence labels 130 | local ix1 = self.label_start_ix[ix] 131 | local ix2 = self.label_end_ix[ix] 132 | 133 | local ncap = ix2 - ix1 + 1 -- number of captions available for this image 134 | assert(ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t') 135 | local seq 136 | if ncap < seq_per_img then 137 | -- we need to subsample (with replacement) 138 | seq = torch.LongTensor(seq_per_img, seq_length) 139 | for q=1, seq_per_img do 140 | local ixl = torch.random(ix1,ix2) 141 | seq[{{q,q}}] = self.labels[{{ixl, ixl}, {1,seq_length}}] 142 | end 143 | else 144 | -- there is enough data to read a contiguous chunk, but subsample the chunk position 145 | local ixl = torch.random(ix1, ix2 - seq_per_img + 1) -- generates integer in the range 146 | seq = self.labels[{{ixl, ixl+seq_per_img-1}, {1,seq_length}}] 147 | end 148 | 149 | local il = (i-1)*seq_per_img+1 150 | label_batch[{{il,il+seq_per_img-1} }] = seq 151 | img_id_batch[i] = self.image_ids[ix] 152 | end 153 | 154 | local data_augment = false 155 | if split == 'train' then 156 | data_augment = true 157 | end 158 | 159 | local h,w = img_batch_raw:size(3), img_batch_raw:size(4) 160 | local cnn_input_size = 224 161 | -- cropping data augmentation, if needed 162 | if h > cnn_input_size or w > cnn_input_size then 163 | local xoff, yoff 164 | if data_augment then 165 | xoff, yoff = torch.random(w-cnn_input_size), torch.random(h-cnn_input_size) 166 | else 167 | -- sample the center 168 | xoff, yoff = math.ceil((w-cnn_input_size)/2), math.ceil((h-cnn_input_size)/2) 169 | end 170 | -- crop. 171 | img_batch_raw = img_batch_raw[{ {}, {}, {yoff,yoff+cnn_input_size-1}, {xoff,xoff+cnn_input_size-1}}] 172 | end 173 | 174 | img_batch_raw = img_batch_raw:float():div(255) 175 | 176 | local batch_data = {} 177 | batch_data.labels = label_batch:transpose(1,2):contiguous() 178 | batch_data.images = img_batch_raw 179 | batch_data.img_id = img_id_batch 180 | 181 | self.iterator[split] = self.iterator[split] + batch_size 182 | return batch_data 183 | end 184 | end 185 | 186 | -------------------------------------------------------------------------------- /visu/LanguageModel_visu.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'misc.LookupTableMaskZero' 3 | local utils = require 'misc.utils' 4 | local net_utils = require 'misc.net_utils' 5 | local LSTM = require 'misc.LSTM' 6 | 7 | local attention = require 'visu.attention_visu' 8 | local img_embedding = require 'misc.img_embedding' 9 | 10 | ------------------------------------------------------------------------------- 11 | -- Language Model core 12 | ------------------------------------------------------------------------------- 13 | 14 | local layer, parent = torch.class('nn.LanguageModel', 'nn.Module') 15 | function layer:__init(opt) 16 | parent.__init(self) 17 | 18 | -- options for core network 19 | self.vocab_size = utils.getopt(opt, 'vocab_size') -- required 20 | self.input_encoding_size = utils.getopt(opt, 'input_encoding_size') 21 | self.n_rnn_layer = utils.getopt(opt, 'n_rnn_layer', 1) 22 | 23 | self.rnn_size = utils.getopt(opt, 'rnn_size') 24 | self.num_layers = utils.getopt(opt, 'num_layers', 1) 25 | local dropout = utils.getopt(opt, 'dropout', 0) 26 | 27 | self.fc_size = utils.getopt(opt, 'fc_size', 4096) 28 | self.conv_size = utils.getopt(opt, 'conv_size', 512) 29 | 30 | -- options for Language Model 31 | self.seq_length = utils.getopt(opt, 'seq_length') 32 | 33 | print('rnn_size: ' .. self.rnn_size .. ' num_layers: ' .. self.num_layers) 34 | print('input_encoding_size: ' .. self.input_encoding_size) 35 | print('dropout rate: ' .. dropout) 36 | 37 | -- create the core lstm network. 
note +1 for both the START and END tokens 38 | self.core = LSTM.lstm(self.input_encoding_size, self.rnn_size, self.num_layers, dropout) 39 | 40 | self.lookup_table = nn.Sequential() 41 | :add(nn.LookupTableMaskZero(self.vocab_size+1, self.input_encoding_size)) 42 | :add(nn.ReLU()) 43 | :add(nn.Dropout(dropout)) 44 | 45 | self.img_embedding = img_embedding.img_embedding(self.input_encoding_size, self.fc_size, self.conv_size, 49, dropout) 46 | 47 | self.attention = attention.attention(self.input_encoding_size, self.rnn_size, self.vocab_size+1, dropout) 48 | 49 | self:_createInitState(1) -- will be lazily resized later during forward passes 50 | end 51 | 52 | function layer:_createInitState(batch_size) 53 | assert(batch_size ~= nil, 'batch size must be provided') 54 | -- construct the initial state for the LSTM 55 | if not self.init_state then self.init_state = {} end -- lazy init 56 | for h=1,self.num_layers*2 do 57 | -- note, the init state Must be zeros because we are using init_state to init grads in backward call too 58 | if self.init_state[h] then 59 | if self.init_state[h]:size(1) ~= batch_size then 60 | self.init_state[h]:resize(batch_size, self.rnn_size):zero() -- expand the memory 61 | end 62 | else 63 | self.init_state[h] = torch.zeros(batch_size, self.rnn_size) 64 | end 65 | end 66 | self.num_state = #self.init_state 67 | end 68 | 69 | 70 | function layer:createClones() 71 | -- construct the net clones 72 | print('constructing clones inside the LanguageModel') 73 | self.clones = {self.core} 74 | self.lookup_tables = {self.lookup_table} 75 | self.attentions = {self.attention} 76 | for t=2,self.seq_length+1 do 77 | self.clones[t] = self.core:clone('weight', 'bias', 'gradWeight', 'gradBias') 78 | self.lookup_tables[t] = self.lookup_table:clone('weight', 'gradWeight') 79 | self.attentions[t] = self.attention:clone('weight', 'bias', 'gradWeight', 'gradBias') 80 | end 81 | end 82 | 83 | 84 | function layer:getModulesList() 85 | return {self.core, self.lookup_table, self.img_embedding, self.attention} 86 | end 87 | 88 | function layer:parameters() 89 | -- we only have two internal modules, return their params 90 | local p1,g1 = self.core:parameters() 91 | local p2,g2 = self.lookup_table:parameters() 92 | local p3,g3 = self.img_embedding:parameters() 93 | local p4,g4 = self.attention:parameters() 94 | 95 | 96 | local params = {} 97 | for k,v in pairs(p1) do table.insert(params, v) end 98 | for k,v in pairs(p2) do table.insert(params, v) end 99 | for k,v in pairs(p3) do table.insert(params, v) end 100 | for k,v in pairs(p4) do table.insert(params, v) end 101 | 102 | local grad_params = {} 103 | for k,v in pairs(g1) do table.insert(grad_params, v) end 104 | for k,v in pairs(g2) do table.insert(grad_params, v) end 105 | for k,v in pairs(g3) do table.insert(grad_params, v) end 106 | for k,v in pairs(g4) do table.insert(grad_params, v) end 107 | 108 | return params, grad_params 109 | end 110 | 111 | function layer:training() 112 | for k,v in pairs(self.clones) do v:training() end 113 | for k,v in pairs(self.lookup_tables) do v:training() end 114 | for k,v in pairs(self.attentions) do v:training() end 115 | self.img_embedding:training() 116 | 117 | end 118 | 119 | function layer:evaluate() 120 | for k,v in pairs(self.clones) do v:evaluate() end 121 | for k,v in pairs(self.lookup_tables) do v:evaluate() end 122 | for k,v in pairs(self.attentions) do v:evaluate() end 123 | self.img_embedding:evaluate() 124 | end 125 | 126 | function layer:sample(inputs, opt) 127 | local conv = inputs[1] 128 | 
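-- inputs: {conv features, fc feature, ix_to_word vocab table}; opt may carry sample_max, beam_size and temperature (read below)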
local fc = inputs[2] 129 | local ix_to_word = inputs[3] 130 | 131 | 132 | local sample_max = utils.getopt(opt, 'sample_max', 1) 133 | local beam_size = utils.getopt(opt, 'beam_size', 1) 134 | 135 | local temperature = utils.getopt(opt, 'temperature', 1.0) 136 | 137 | local batch_size = fc:size(1) 138 | 139 | if sample_max == 1 and beam_size > 1 then return self:sample_beam(inputs, opt) end -- indirection for beam search 140 | 141 | self:_createInitState(batch_size) 142 | local state = self.init_state 143 | 144 | local img_input = {conv, fc} 145 | local conv_feat, conv_feat_embed, fc_embed = unpack(self.img_embedding:forward(img_input)) 146 | 147 | -- we will write output predictions into tensor seq 148 | local seq = torch.LongTensor(self.seq_length, batch_size):zero() 149 | local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 150 | 151 | local logprobs -- logprobs predicted in last time step 152 | local x_xt 153 | 154 | for t=1,self.seq_length+1 do 155 | local xt, it, sampleLogprobs 156 | if t == 1 then 157 | it = torch.LongTensor(batch_size):fill(self.vocab_size+1) 158 | xt = self.lookup_table:forward(it) 159 | else 160 | -- take predictions from previous time step and feed them in 161 | if sample_max == 1 then 162 | -- use argmax "sampling" 163 | sampleLogprobs, it = torch.max(logprobs, 2) 164 | it = it:view(-1):long() 165 | else 166 | -- sample from the distribution of previous predictions 167 | local prob_prev 168 | if temperature == 1.0 then 169 | prob_prev = torch.exp(logprobs) -- fetch prev distribution: shape Nx(M+1) 170 | else 171 | -- scale logprobs by temperature 172 | prob_prev = torch.exp(torch.div(logprobs, temperature)) 173 | end 174 | it = torch.multinomial(prob_prev, 1) 175 | sampleLogprobs = logprobs:gather(2, it) -- gather the logprobs at sampled positions 176 | it = it:view(-1):long() -- and flatten indices for downstream processing 177 | end 178 | xt = self.lookup_table:forward(it) 179 | end 180 | 181 | if t >= 2 then 182 | seq[t-1] = it -- record the samples 183 | seqLogprobs[t-1] = sampleLogprobs:view(-1):float() -- and also their log likelihoods 184 | end 185 | 186 | local inputs = {xt,fc_embed, unpack(state)} 187 | local out = self.core:forward(inputs) 188 | state = {} 189 | for i=1,self.num_state do table.insert(state, out[i]) end 190 | 191 | local h_out = out[self.num_state+1] 192 | local p_out = out[self.num_state+2] 193 | 194 | local atten_input = {h_out, p_out, conv_feat, conv_feat_embed} 195 | logprobs, attenprobs = unpack(self.attention:forward(atten_input)) 196 | end 197 | -- return the samples and their log likelihoods 198 | return seq, seqLogprobs 199 | end 200 | 201 | 202 | function layer:sample_beam(inputs, opt) 203 | local beam_size = utils.getopt(opt, 'beam_size', 10) 204 | 205 | local conv = inputs[1] 206 | local fc = inputs[2] 207 | local ix_to_word = inputs[3] 208 | 209 | local batch_size = fc:size(1) 210 | local function compare(a,b) return a.p > b.p end -- used downstream 211 | 212 | assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed') 213 | 214 | local img_input = {conv, fc} 215 | local conv_feat, conv_feat_embed, fc_embed = unpack(self.img_embedding:forward(img_input)) 216 | 217 | local seq = torch.LongTensor(self.seq_length, batch_size):zero() 218 | local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 219 | local seqLogprobs_sum = torch.FloatTensor(batch_size) 220 | local atten = torch.FloatTensor(self.seq_length, batch_size, 50) 221 | 222 | -- lets process every image independently for now, for simplicity 223 | for k=1,batch_size do 224 | 225 | -- create initial states for all beams 226 | self:_createInitState(beam_size) 227 | local state = self.init_state 228 | 229 | -- we will write output predictions into tensor seq 230 | local beam_seq = torch.LongTensor(self.seq_length, beam_size):zero() 231 | local beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size):zero() 232 | local beam_logprobs_sum = torch.zeros(beam_size) -- running sum of logprobs for each beam 233 | local logprobs, attenprobs -- logprobs predicted in last time step, shape (beam_size, vocab_size+1) 234 | local done_beams = {} 235 | local imgk = fc_embed[{ {k,k} }]:expand(beam_size, self.input_encoding_size) -- k'th image feature expanded out 236 | local conv_feat_k = conv_feat[{ {k,k} }]:expand(beam_size, conv_feat:size(2), self.input_encoding_size) -- k'th image feature expanded out 237 | local conv_feat_embed_k = conv_feat_embed[{ {k,k} }]:expand(beam_size, conv_feat_embed:size(2), self.input_encoding_size) -- k'th image feature expanded out 238 | local attention = torch.FloatTensor(self.seq_length, beam_size, 50):zero() 239 | 240 | for t=1,self.seq_length+1 do 241 | 242 | local xt, it, sampleLogprobs 243 | local new_state 244 | local new_atten = torch.Tensor(beam_size,50):zero() 245 | if t == 1 then 246 | -- feed in the start tokens 247 | it = torch.LongTensor(beam_size):fill(self.vocab_size+1) 248 | xt = self.lookup_table:forward(it) 249 | else 250 | --[[ 251 | perform a beam merge. that is, 252 | for every previous beam we now many new possibilities to branch out 253 | we need to resort our beams to maintain the loop invariant of keeping 254 | the top beam_size most likely sequences. 
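Concretely: every live beam is expanded with its top-scoring next words, all candidates are pooled and sorted by cumulative log-probability, and only the best beam_size continuations survive to the next time step.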
255 | ]]-- 256 | local logprobsf = logprobs:float() -- lets go to CPU for more efficiency in indexing operations 257 | ys,ix = torch.sort(logprobsf,2,true) -- sorted array of logprobs along each previous beam (last true = descending) 258 | 259 | local candidates = {} 260 | local cols = math.min(beam_size,ys:size(2)) 261 | local rows = beam_size 262 | if t == 2 then rows = 1 end -- at first time step only the first beam is active 263 | for c=1,cols do -- for each column (word, essentially) 264 | for q=1,rows do -- for each beam expansion 265 | -- compute logprob of expanding beam q with word in (sorted) position c 266 | local local_logprob = ys[{ q,c }] 267 | local candidate_logprob = beam_logprobs_sum[q] + local_logprob 268 | table.insert(candidates, {c=ix[{ q,c }], q=q, p=candidate_logprob, r=local_logprob }) 269 | end 270 | end 271 | 272 | table.sort(candidates, compare) -- find the best c,q pairs 273 | 274 | -- construct new beams 275 | new_state = net_utils.clone_list(state) 276 | local beam_seq_prev, beam_seq_logprobs_prev 277 | if t > 2 then 278 | -- well need these as reference when we fork beams around 279 | beam_seq_prev = beam_seq[{ {1,t-2}, {} }]:clone() 280 | beam_seq_logprobs_prev = beam_seq_logprobs[{ {1,t-2}, {} }]:clone() 281 | end 282 | 283 | for vix=1,beam_size do 284 | local v = candidates[vix] 285 | -- fork beam index q into index vix 286 | if t > 2 then 287 | beam_seq[{ {1,t-2}, vix }] = beam_seq_prev[{ {}, v.q }] 288 | beam_seq_logprobs[{ {1,t-2}, vix }] = beam_seq_logprobs_prev[{ {}, v.q }] 289 | end 290 | -- rearrange recurrent states 291 | for state_ix = 1,#new_state do 292 | -- copy over state in previous beam q to new beam at vix 293 | new_state[state_ix][vix] = state[state_ix][v.q] 294 | end 295 | 296 | new_atten[{vix,{}}] = attenprobs[{v.q,{}}] 297 | 298 | -- append new end terminal at the end of this beam 299 | beam_seq[{ t-1, vix }] = v.c -- c'th word is the continuation 300 | beam_seq_logprobs[{ t-1, vix }] = v.r -- the raw logprob here 301 | beam_logprobs_sum[vix] = v.p -- the new (sum) logprob along this beam 302 | 303 | if v.c == self.vocab_size+1 or t == self.seq_length+1 then 304 | -- END token special case here, or we reached the end. 
305 | -- add the beam to a set of done beams 306 | table.insert(done_beams, {seq = beam_seq[{ {}, vix }]:clone(), 307 | logps = beam_seq_logprobs[{ {}, vix }]:clone(), 308 | p = beam_logprobs_sum[vix], 309 | idx = vix 310 | }) 311 | end 312 | end 313 | -- encode as vectors 314 | it = beam_seq[t-1] 315 | xt = self.lookup_table:forward(it) 316 | end 317 | 318 | if t > 1 then 319 | attention:narrow(1,t-1,1):copy(new_atten) 320 | end 321 | if new_state then state = new_state end -- swap rnn state, if we reassinged beams 322 | 323 | local inputs = {xt,imgk,unpack(state)} 324 | local out = self.core:forward(inputs) 325 | state = {} 326 | for i=1,self.num_state do table.insert(state, out[i]) end 327 | local h_out = out[self.num_state+1] 328 | local p_out = out[self.num_state+2] 329 | local atten_input = {h_out, p_out, conv_feat_k, conv_feat_embed_k} 330 | logprobs, attenprobs = unpack(self.attention:forward(atten_input)) 331 | attenprobs = attenprobs:view(beam_size, 50):float() 332 | end 333 | 334 | table.sort(done_beams, compare) 335 | 336 | seq[{ {}, k }] = done_beams[1].seq -- the first beam has highest cumulative score 337 | seqLogprobs[{ {}, k }] = done_beams[1].logps 338 | seqLogprobs_sum[k]=done_beams[1].p 339 | 340 | atten[{{}, k, {}}] = attention[{{},{done_beams[1].idx},{}}] 341 | end 342 | 343 | -- return the samples and their log likelihoods 344 | return seq, atten 345 | end 346 | 347 | 348 | 349 | function layer:updateOutput(input) 350 | local conv = input[1] 351 | local fc = input[2] 352 | local seq = input[3] 353 | 354 | assert(seq:size(1) == self.seq_length) 355 | local batch_size = seq:size(2) 356 | 357 | self:_createInitState(batch_size) 358 | local atten = torch.FloatTensor(self.seq_length+1, batch_size, 50) 359 | 360 | -- first get the nearest neighbor representation. 361 | self.output:resize(self.seq_length+1, batch_size, self.vocab_size+1):zero() 362 | 363 | self.img_input = {conv, fc} 364 | self.conv_feat, self.conv_feat_embed, self.fc_embed = unpack(self.img_embedding:forward(self.img_input)) 365 | 366 | self.state = {[0] = self.init_state} 367 | self.inputs = {} 368 | self.atten_inputs = {} 369 | --self.x_inputs = {} 370 | self.lookup_tables_inputs = {} 371 | self.tmax = 0 -- we will keep track of max sequence length encountered in the data for efficiency 372 | 373 | for t = 1,self.seq_length+1 do 374 | local can_skip = false 375 | local xt 376 | if t == 1 then 377 | -- feed in the start tokens 378 | local it = torch.LongTensor(batch_size):fill(self.vocab_size+1) 379 | self.lookup_tables_inputs[t] = it 380 | xt = self.lookup_table:forward(it) -- NxK sized input (token embedding vectors) 381 | else 382 | -- feed in the rest of the sequence... 383 | local it = seq[t-1]:clone() 384 | if torch.sum(it) == 0 then 385 | -- computational shortcut for efficiency. All sequences have already terminated and only 386 | -- contain null tokens from here on. 
We can skip the rest of the forward pass and save time 387 | can_skip = true 388 | end 389 | 390 | if not can_skip then 391 | self.lookup_tables_inputs[t] = it 392 | xt = self.lookup_tables[t]:forward(it) 393 | end 394 | end 395 | 396 | if not can_skip then 397 | -- construct the inputs 398 | self.inputs[t] = {xt, self.fc_embed, unpack(self.state[t-1])} 399 | 400 | -- forward the network 401 | local out = self.clones[t]:forward(self.inputs[t]) 402 | 403 | -- insert the hidden state 404 | self.state[t] = {} -- the rest is state 405 | for i=1,self.num_state do table.insert(self.state[t], out[i]) end 406 | local h_out = out[self.num_state+1] 407 | local p_out = out[self.num_state+2] 408 | 409 | --forward the attention 410 | self.atten_inputs[t] = {h_out, p_out, self.conv_feat, self.conv_feat_embed} 411 | local atten_out, attenprobs = unpack(self.attention:forward(self.atten_inputs[t])) 412 | 413 | atten:narrow(1,t,1):copy(attenprobs:view(batch_size, 50)) 414 | self.output:narrow(1,t,1):copy(atten_out) 415 | self.tmax = t 416 | end 417 | end 418 | 419 | return self.output, atten 420 | end 421 | 422 | --[[ 423 | gradOutput is an (D+2)xNx(M+1) Tensor. 424 | --]] 425 | function layer:updateGradInput(input, gradOutput) 426 | local dconv, dconv_embed, dfc-- grad on input images 427 | 428 | local batch_size = self.output:size(2) 429 | -- go backwards and lets compute gradients 430 | local dstate = self.init_state -- this works when init_state is all zeros 431 | 432 | for t=self.tmax,1,-1 do 433 | 434 | local d_atten = self.attentions[t]:backward(self.atten_inputs[t], gradOutput[t]) 435 | if not dconv then dconv = d_atten[3] else dconv:add(d_atten[3]) end 436 | if not dconv_embed then dconv_embed = d_atten[4] else dconv_embed:add(d_atten[4]) end 437 | 438 | local dout = {} 439 | for k=1, self.num_state do table.insert(dout, dstate[k]) end 440 | table.insert(dout, d_atten[1]) 441 | table.insert(dout, d_atten[2]) 442 | 443 | local dinputs = self.clones[t]:backward(self.inputs[t], dout) 444 | 445 | local dxt = dinputs[1] -- first element is the input vector 446 | if not dfc then dfc = dinputs[2] else dfc:add(dinputs[2]) end 447 | 448 | dstate = {} -- copy over rest to state grad 449 | for k=3,self.num_state+2 do table.insert(dstate, dinputs[k]) end 450 | 451 | -- continue backprop of xt 452 | local it = self.lookup_tables_inputs[t] 453 | self.lookup_tables[t]:backward(it, dxt) -- backprop into lookup table 454 | end 455 | 456 | -- backprob to the visual features. 457 | local dimgs_cnn, dfc_cnn = unpack(self.img_embedding:backward(self.img_input, {dconv, dconv_embed, dfc})) 458 | 459 | -- we have gradient on image, but for LongTensor gt sequence we only create an empty tensor - can't backprop 460 | self.gradInput = {dimgs_cnn, dfc_cnn} 461 | return self.gradInput 462 | end 463 | 464 | ------------------------------------------------------------------------------- 465 | -- Language Model-aware Criterion 466 | ------------------------------------------------------------------------------- 467 | 468 | local crit, parent = torch.class('nn.LanguageModelCriterion', 'nn.Criterion') 469 | function crit:__init() 470 | parent.__init(self) 471 | end 472 | 473 | function crit:updateOutput(inputs) 474 | local input = inputs[1] 475 | local seq = inputs[2] 476 | --local seq_len = inputs[3] 477 | 478 | local L,N,Mp1 = input:size(1), input:size(2), input:size(3) 479 | local D = seq:size(1) 480 | assert(D == L-1, 'input Tensor should be 1 larger in time') 481 | 482 | --[[ 483 | -- making the seq with the end token. 
484 | local eseq = eseq or torch.CudaTensor() 485 | eseq:resize(L, N):zero() 486 | eseq:narrow(1,1,D):copy(seq) -- copy the seq 487 | 488 | eseq:scatter(1,seq_len:add(1):view(1,-1), Mp1) -- insert the END token 489 | eseq = eseq:view(-1,1) 490 | -- making a mask by using the seq 491 | local mask = mask or torch.CudaByteTensor() 492 | mask:resize(eseq:size()):zero() 493 | mask[torch.eq(eseq, 0)] = 1 494 | 495 | local input_reshape = input:view(-1, Mp1):clone() 496 | self.gradInput:resizeAs(input_reshape):zero() -- reset to zeros 497 | 498 | local loss_vec = input_reshape:gather(2, eseq):maskedFill(mask,0) 499 | local n = mask:size(1) - torch.sum(mask) 500 | self.output = -torch.sum(loss_vec) / n 501 | 502 | self.gradInput:scatter(2,eseq,-1):maskedFill(mask:expandAs(self.gradInput),0):div(n) 503 | 504 | self.gradInput = self.gradInput:viewAs(input) 505 | ]]-- 506 | 507 | self.gradInput:resizeAs(input):zero() 508 | local loss = 0 509 | local n = 0 510 | for b=1,N do -- iterate over batches 511 | local first_time = true 512 | for t=1,L do -- iterate over sequence time (ignore t=1, dummy forward for the image) 513 | -- fetch the index of the next token in the sequence 514 | local target_index 515 | if t > D then -- we are out of bounds of the index sequence: pad with null tokens 516 | target_index = 0 517 | else 518 | target_index = seq[{t,b}] 519 | end 520 | -- the first time we see null token as next index, actually want the model to predict the END token 521 | if target_index == 0 and first_time then 522 | target_index = Mp1 523 | first_time = false 524 | end 525 | 526 | -- if there is a non-null next token, enforce loss! 527 | if target_index ~= 0 then 528 | -- accumulate loss 529 | loss = loss - input[{ t,b,target_index }] -- log(p) 530 | self.gradInput[{ t,b,target_index }] = -1 531 | n = n + 1 532 | end 533 | end 534 | end 535 | self.output = loss / n -- normalize by number of predictions that were made 536 | self.gradInput:div(n) 537 | 538 | return self.output 539 | end 540 | 541 | function crit:updateGradInput(inputs) 542 | return self.gradInput 543 | end 544 | -------------------------------------------------------------------------------- /visu/attention_visu.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | require 'nngraph' 3 | 4 | local attention = {} 5 | function attention.attention(input_size, rnn_size, output_size, dropout) 6 | local inputs = {} 7 | local outputs = {} 8 | table.insert(inputs, nn.Identity()()) -- top_h 9 | table.insert(inputs, nn.Identity()()) -- fake_region 10 | table.insert(inputs, nn.Identity()()) -- conv_feat 11 | table.insert(inputs, nn.Identity()()) -- conv_feat_embed 12 | 13 | local h_out = inputs[1] 14 | local fake_region = inputs[2] 15 | local conv_feat = inputs[3] 16 | local conv_feat_embed = inputs[4] 17 | 18 | local fake_region = nn.ReLU()(nn.Linear(rnn_size, input_size)(fake_region)) 19 | -- view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num 20 | if dropout > 0 then fake_region = nn.Dropout(dropout)(fake_region) end 21 | 22 | local fake_region_embed = nn.Linear(input_size, input_size)(fake_region) 23 | 24 | local h_out_linear = nn.Tanh()(nn.Linear(rnn_size, input_size)(h_out)) 25 | if dropout > 0 then h_out_linear = nn.Dropout(dropout)(h_out_linear) end 26 | 27 | local h_out_embed = nn.Linear(input_size, input_size)(h_out_linear) 28 | 29 | local txt_replicate = nn.Replicate(50,2)(h_out_embed) 30 | 31 | local img_all = 
nn.JoinTable(2)({nn.View(-1,1,input_size)(fake_region), conv_feat}) 32 | local img_all_embed = nn.JoinTable(2)({nn.View(-1,1,input_size)(fake_region_embed), conv_feat_embed}) 33 | 34 | local hA = nn.Tanh()(nn.CAddTable()({img_all_embed, txt_replicate})) 35 | if dropout > 0 then hA = nn.Dropout(dropout)(hA) end 36 | local hAflat = nn.Linear(input_size,1)(nn.View(input_size):setNumInputDims(2)(hA)) 37 | local PI = nn.SoftMax()(nn.View(50):setNumInputDims(2)(hAflat)) 38 | 39 | local probs3dim = nn.View(1,-1):setNumInputDims(1)(PI) 40 | local visAtt = nn.MM(false, false)({probs3dim, img_all}) 41 | local visAttdim = nn.View(input_size):setNumInputDims(2)(visAtt) 42 | local atten_out = nn.CAddTable()({visAttdim, h_out_linear}) 43 | 44 | local h = nn.Tanh()(nn.Linear(input_size, input_size)(atten_out)) 45 | if dropout > 0 then h = nn.Dropout(dropout)(h) end 46 | local proj = nn.Linear(input_size, output_size)(h) 47 | 48 | local logsoft = nn.LogSoftMax()(proj) 49 | --local logsoft = nn.SoftMax()(proj) 50 | table.insert(outputs, logsoft) 51 | table.insert(outputs, probs3dim) 52 | 53 | return nn.gModule(inputs, outputs) 54 | end 55 | return attention -------------------------------------------------------------------------------- /visu/visAtten.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | require 'hdf5' 3 | 4 | cjson=require('cjson') 5 | 6 | function read_json(path) 7 | local file = io.open(path, 'r') 8 | local text = file:read() 9 | file:close() 10 | local info = cjson.decode(text) 11 | return info 12 | end 13 | 14 | function write_json(path, j) 15 | -- API reference http://www.kyne.com.au/~mark/software/lua-cjson-manual.html#encode 16 | cjson.encode_sparse_array(true, 2, 10) 17 | local text = cjson.encode(j) 18 | local file = io.open(path, 'w') 19 | file:write(text) 20 | file:close() 21 | end 22 | local path = 'flickr30k_gt/' 23 | 24 | local atten = torch.load(path .. 
'atten.t7') 25 | local caption = read_json(path ..'visu.json') 26 | 27 | local json_file = read_json('/data/flickr30k/cocotalk.json') 28 | local id2path = {} 29 | for i = 1,5000 do 30 | local id = json_file['images'][i]['id'] 31 | local file_path = json_file['images'][i]['file_path'] 32 | id2path[id] = file_path 33 | end 34 | 35 | 36 | local num = #caption 37 | local cap_all = {} 38 | --local scale_map = torch.FloatTensor(num, 23, 224, 224) 39 | local atten_weight = torch.FloatTensor(num, 23) 40 | local atten_map_original = torch.FloatTensor(num, 23, 7, 7) 41 | 42 | local f = io.open('flickr30k.txt', 'w') 43 | for t = 1, num do 44 | local cap = caption[t]['caption'] 45 | local img_id = caption[t]['image_id'] 46 | local atten_map = atten[{{},{t},{}}]:contiguous():view(23,50) 47 | 48 | for i = 1, atten_map:size(1) do 49 | local map = atten_map:sub(i,i, 2, 50):view(7,7) 50 | atten_map_original:sub(t,t,i,i):copy(map) 51 | atten_weight:sub(t,t,i,i):copy(atten_map:sub(i,i, 1, 1)) 52 | --scale_map:sub(t, t, i,i):copy(image.scale(map, 224, 224, 'bicubic')) 53 | end 54 | 55 | f:write(tostring(id2path[img_id])) 56 | f:write('\t') 57 | f:write(cap) 58 | f:write('\n') 59 | table.insert(cap_all, {cap, img_id}) 60 | end 61 | f:close() 62 | 63 | print({atten_weight}) 64 | local myFile = hdf5.open('atten_img.h5', 'w') 65 | --myFile:write('map', scale_map) 66 | myFile:write('atten_weight', atten_weight) 67 | myFile:write('atten_map_original', atten_map_original) 68 | myFile:close() 69 | write_json('cap.json', cap_all) 70 | 71 | 72 | 73 | --------------------------------------------------------------------------------
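The script above writes the first attention column (the visual sentinel weight) and the 7x7 spatial maps to `atten_img.h5`, and the {caption, image_id} pairs to `cap.json`. As a minimal sketch (not part of the repository), these outputs could be read back in Torch roughly as follows, reusing the `hdf5` and `cjson` packages already required by `visu/visAtten.lua`:
```
-- sketch: load the attention outputs written by visu/visAtten.lua
require 'hdf5'
local cjson = require 'cjson'

-- read the HDF5 datasets created above
local f = hdf5.open('atten_img.h5', 'r')
local atten_weight = f:read('/atten_weight'):all()       -- num_captions x 23 sentinel weights
local atten_map = f:read('/atten_map_original'):all()    -- num_captions x 23 x 7 x 7 spatial maps
f:close()

-- read the {caption, image_id} pairs written to cap.json
local file = io.open('cap.json', 'r')
local caps = cjson.decode(file:read('*a'))
file:close()

-- inspect the first caption and the spatial attention of its first word
print(caps[1][1], caps[1][2])
print(atten_map[{1, 1, {}, {}}])
```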