├── .gitmodules ├── ADVANCED.md ├── README.md ├── dataloader.py ├── dataloaderraw.py ├── eval.py ├── eval_ensemble.py ├── eval_utils.py ├── misc ├── __init__.py ├── resnet.py ├── resnet_utils.py ├── rewards.py └── utils.py ├── models ├── AttEnsemble.py ├── AttModel.py ├── CaptionModel.py ├── FCModel.py ├── OldModel.py ├── ShowTellModel.py ├── TransformerModel.py └── __init__.py ├── opts.py ├── scripts ├── copy_model.sh ├── make_bu_data.py ├── prepro_feats.py ├── prepro_labels.py └── prepro_ngrams.py ├── train.py └── vis ├── imgs └── dummy ├── index.html └── jquery-1.8.3.min.js /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cider"] 2 | path = cider 3 | url = https://github.com/ruotianluo/cider.git 4 | -------------------------------------------------------------------------------- /ADVANCED.md: -------------------------------------------------------------------------------- 1 | # Advanced 2 | 3 | ## Ensemble 4 | 5 | ## Batch normalization 6 | 7 | ## Box feature -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer for captioning 2 | 3 | # Note: This repository is deprecated, and the code has been merged into [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch). The same training script should work for self-critical too. 4 | 5 | This is an experiment in using a transformer model for captioning. Most of the code is copied from the [Harvard detailed tutorial for the transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html). 6 | 7 | Also note that this repository is a fork of my [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch) repository. Most of the code is shared. 8 | 9 | The additions to self-critical.pytorch are the following: 10 | - Transformer model 11 | - Warmup Adam for training the transformer (important) 12 | - reduce_on_plateau (not really useful) 13 | 14 | Below is a training script that achieves a CIDEr score of 1.25 on the validation set without beam search. 15 | 16 | ```bash 17 | id="transformer" 18 | ckpt_path="log_"$id 19 | if [ ! -d $ckpt_path ]; then 20 | mkdir $ckpt_path 21 | fi 22 | if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then 23 | start_from="" 24 | else 25 | start_from="--start_from "$ckpt_path 26 | fi 27 | 28 | python train.py --id $id --caption_model transformer --noamopt --noamopt_warmup 20000 --label_smoothing 0.0 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 15 29 | 30 | python train.py --id $id --caption_model transformer --reduce_on_plateau --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --self_critical_after 10 31 | ```
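The `--noamopt` flag corresponds to the warmup Adam schedule (the `NoamOpt` wrapper in `misc/utils.py`): the learning rate grows linearly for `--noamopt_warmup` steps and then decays roughly as 1/sqrt(step). Below is a minimal sketch of the rate it computes; tying `model_size` to `--input_encoding_size` and the `factor` value are assumptions for illustration, not values read from `train.py`:

```python
def noam_rate(step, model_size=512, factor=1.0, warmup=20000):
    # Mirrors NoamOpt.rate() in misc/utils.py; step starts at 1.
    # Linear warmup for `warmup` steps, then ~1/sqrt(step) decay.
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```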
32 | 33 | **Notice**: because I'm too lazy, I reuse the RNN option names to set the transformer hyperparameters: 34 | ``` 35 | N=num_layers 36 | d_model=input_encoding_size 37 | d_ff=rnn_size 38 | h is always 8 39 | ``` 40 | 41 | 42 | # Self-critical Sequence Training for Image Captioning (+ misc.) 43 | 44 | This repository includes unofficial implementations of [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998). 45 | 46 | The author of SCST helped me a lot when I tried to replicate the result. Great thanks. The att2in2 model can achieve a CIDEr score of more than 1.20 on Karpathy's test split (with self-critical training, bottom-up features, a large RNN hidden size, and no ensemble). 47 | 48 | This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch) repository. The modifications are: 49 | - Self-critical training. 50 | - Bottom-up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.) 51 | - Ensemble 52 | - Multi-GPU training 53 | 54 | ## Requirements 55 | Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3) 56 | PyTorch 0.4 (along with torchvision) 57 | cider (already added as a submodule) 58 | 59 | (**Skip if you are using bottom-up features**): If you want to use resnet to extract image features, you need to download a pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM) and should be placed in `data/imagenet_weights`. 60 | 61 | ## Pretrained models (using resnet101 feature) 62 | Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg). The performance of each model is maintained in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10). 63 | 64 | If you want to do evaluation only, you can follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101).
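For example, assuming the downloaded files are unpacked into a folder such as `log_fc/` (the folder name and the `-best` file names below are assumptions based on how other scripts in this repo name checkpoints; point the flags at whatever you actually downloaded), an evaluation-only run could look like:

```bash
$ python eval.py --model log_fc/model-best.pth --infos_path log_fc/infos_fc-best.pkl \
    --dump_images 0 --num_images 5000 --language_eval 1
```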
65 | 66 | ## Train your own network on COCO 67 | 68 | ### Download COCO captions and preprocess them 69 | 70 | Download the preprocessed COCO captions from [this link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides preprocessed captions and also standard train-val-test splits. 71 | 72 | Then do: 73 | 74 | ```bash 75 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk 76 | ``` 77 | 78 | `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`. 79 | 80 | ### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up features) 81 | 82 | Download the COCO images from [this link](http://mscoco.org/dataset/#download). We need the 2014 training and 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`. 83 | 84 | Then: 85 | 86 | ``` 87 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT 88 | ``` 89 | 90 | 91 | `prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files are about 200GB. 92 | 93 | (Check the prepro scripts for more options, like other resnet models or other attention sizes.) 94 | 95 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset. 96 | 97 | ### Download Bottom-up features (Skip if you are using resnet features) 98 | 99 | Download the pre-extracted features from [this link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version. 100 | 101 | For example: 102 | ``` 103 | mkdir data/bu_data; cd data/bu_data 104 | wget https://storage.googleapis.com/bottom-up-attention/trainval.zip 105 | unzip trainval.zip 106 | 107 | ``` 108 | 109 | Then: 110 | 111 | ```bash 112 | python scripts/make_bu_data.py --output_dir data/cocobu 113 | ``` 114 | 115 | This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the steps below and replace cocotalk with cocobu. 116 | 117 | ### Start training 118 | 119 | ```bash 120 | $ python train.py --id fc --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30 121 | ``` 122 | 123 | The train script will dump checkpoints into the folder specified by `--checkpoint_path` (default = `save/`). We only save the best-performing checkpoint on validation and the latest checkpoint to save disk space.
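For reference, after training with `--id fc --checkpoint_path log_fc`, the checkpoint folder should contain files along these lines (names inferred from the scripts in this repository; a tensorboard event file may also appear if you have tensorflow installed):

```
log_fc/
├── infos_fc.pkl        # options and data-loader state for the latest checkpoint
├── infos_fc-best.pkl   # same, for the best checkpoint on validation
├── model.pth           # latest model weights
└── model-best.pth      # best-performing model weights
```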
124 | 125 | To resume training, you can set the `--start_from` option to the path containing `infos.pkl` and `model.pth` (usually you can just set `--start_from` and `--checkpoint_path` to be the same). 126 | 127 | If you have tensorflow, the loss histories are automatically dumped into `--checkpoint_path` and can be visualized using tensorboard. 128 | 129 | The current command uses scheduled sampling; you can set scheduled_sampling_start to -1 to turn off scheduled sampling. 130 | 131 | If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition to the validation cross-entropy loss, use the `--language_eval 1` option, but don't forget to download the [coco-caption code](https://github.com/tylin/coco-caption) into the `coco-caption` directory. 132 | 133 | For more options, see `opts.py`. 134 | 135 | **A few notes on training.** To give you an idea, with the default settings one epoch of MS COCO images is about 11000 iterations. One epoch of training results in a validation loss of ~2.5 and a CIDEr score of ~0.68. By iteration 60,000, CIDEr climbs to about 0.84 (validation loss at about 2.4, under scheduled sampling). 136 | 137 | ### Train using self critical 138 | 139 | First, preprocess the dataset and build the cache for calculating the CIDEr score: 140 | ``` 141 | $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train 142 | ``` 143 | 144 | Then, copy the model pretrained with cross entropy. (Copying the model is not mandatory; it is just for backup.) 145 | ``` 146 | $ bash scripts/copy_model.sh fc fc_rl 147 | ``` 148 | 149 | Then 150 | ```bash 151 | $ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30 152 | ``` 153 | 154 | You will see a huge boost in CIDEr score, : ). 155 | 156 | **A few notes on training.** Starting self-critical training after 30 epochs, the CIDEr score goes up to 1.05 after 600k iterations (including the 30 epochs of pretraining). 157 | 158 | ### Caption images after training 159 | 160 | ## Generate image captions 161 | 162 | ### Evaluate on raw images 163 | Now place all your images of interest into a folder, e.g. `blah`, and run 164 | the eval script: 165 | 166 | ```bash 167 | $ python eval.py --model model.pth --infos_path infos.pkl --image_folder blah --num_images 10 168 | ``` 169 | 170 | This tells the `eval` script to run on up to 10 images from the given folder. If you have a big GPU you can speed up the evaluation by increasing `--batch_size`. Use `--num_images -1` to process all images. The eval script will create a `vis.json` file inside the `vis` folder, which can then be visualized with the provided HTML interface: 171 | 172 | ```bash 173 | $ cd vis 174 | $ python -m SimpleHTTPServer 175 | ``` 176 | 177 | Now visit `localhost:8000` in your browser and you should see your predicted captions. 178 | 179 | ### Evaluate on Karpathy's test split 180 | 181 | ```bash 182 | $ python eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl --language_eval 1 183 | ``` 184 | 185 | The default split to evaluate is test. The default inference method is greedy decoding (`--sample_max 1`); to sample from the posterior, set `--sample_max 0`.
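For example, to draw samples instead of taking the argmax at each step (flags as defined in `eval.py`; the model paths and the temperature value are placeholders):

```bash
$ python eval.py --model model.pth --infos_path infos.pkl --num_images 5000 \
    --sample_max 0 --temperature 0.7 --beam_size 1 --language_eval 1
```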
186 | 187 | **Beam Search**. Beam search can improve performance over greedy decoding by ~5%. However, it is a little more expensive. To turn on beam search, use `--beam_size N` with N greater than 1. 188 | 189 | ## Miscellanea 190 | **Using cpu**. The code currently runs on GPU by default; there is no option to switch. If you really need a CPU model, please open an issue; I can potentially create a CPU checkpoint and modify eval.py to run the model on CPU. However, there is no point in training the model on CPU. 191 | 192 | **Train on other datasets**. It should be trivial to port if you can create a file like `dataset_coco.json` for your own dataset. 193 | 194 | **Live demo**. Not supported for now. Pull requests are welcome. 195 | 196 | ## For more advanced features 197 | 198 | Check out `ADVANCED.md`. 199 | 200 | ## Reference 201 | 202 | If you find this repo useful, please consider citing (no obligation at all): 203 | 204 | ``` 205 | @article{luo2018discriminability, 206 | title={Discriminability objective for training descriptive captions}, 207 | author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory}, 208 | journal={arXiv preprint arXiv:1803.04376}, 209 | year={2018} 210 | } 211 | ``` 212 | 213 | Of course, please cite the original papers of the models you are using (you can find references in the model files). 214 | 215 | ## Acknowledgements 216 | 217 | Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team. 218 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | 11 | import torch 12 | import torch.utils.data as data 13 | 14 | import multiprocessing 15 | 16 | class DataLoader(data.Dataset): 17 | 18 | def reset_iterator(self, split): 19 | del self._prefetch_process[split] 20 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 21 | self.iterators[split] = 0 22 | 23 | def get_vocab_size(self): 24 | return self.vocab_size 25 | 26 | def get_vocab(self): 27 | return self.ix_to_word 28 | 29 | def get_seq_length(self): 30 | return self.seq_length 31 | 32 | def __init__(self, opt): 33 | self.opt = opt 34 | self.batch_size = self.opt.batch_size 35 | self.seq_per_img = opt.seq_per_img 36 | 37 | # feature related options 38 | self.use_att = getattr(opt, 'use_att', True) 39 | self.use_box = getattr(opt, 'use_box', 0) 40 | self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) 41 | self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) 42 | 43 | # load the json file which contains additional information about the dataset 44 | print('DataLoader loading json file: ', opt.input_json) 45 | self.info = json.load(open(self.opt.input_json)) 46 | self.ix_to_word = self.info['ix_to_word'] 47 | self.vocab_size = len(self.ix_to_word) 48 | print('vocab size is ', self.vocab_size) 49 | 50 | # open the hdf5 file 51 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) 52 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') 53 | 54 | 
self.input_fc_dir = self.opt.input_fc_dir 55 | self.input_att_dir = self.opt.input_att_dir 56 | self.input_box_dir = self.opt.input_box_dir 57 | 58 | # load in the sequence data 59 | seq_size = self.h5_label_file['labels'].shape 60 | self.seq_length = seq_size[1] 61 | print('max sequence length in data is', self.seq_length) 62 | # load the pointers in full to RAM (should be small enough) 63 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 64 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 65 | 66 | self.num_images = self.label_start_ix.shape[0] 67 | print('read %d image features' %(self.num_images)) 68 | 69 | # separate out indexes for each of the provided splits 70 | self.split_ix = {'train': [], 'val': [], 'test': []} 71 | for ix in range(len(self.info['images'])): 72 | img = self.info['images'][ix] 73 | if img['split'] == 'train': 74 | self.split_ix['train'].append(ix) 75 | elif img['split'] == 'val': 76 | self.split_ix['val'].append(ix) 77 | elif img['split'] == 'test': 78 | self.split_ix['test'].append(ix) 79 | elif opt.train_only == 0: # restval 80 | self.split_ix['train'].append(ix) 81 | 82 | print('assigned %d images to split train' %len(self.split_ix['train'])) 83 | print('assigned %d images to split val' %len(self.split_ix['val'])) 84 | print('assigned %d images to split test' %len(self.split_ix['test'])) 85 | 86 | self.iterators = {'train': 0, 'val': 0, 'test': 0} 87 | 88 | self._prefetch_process = {} # The three prefetch process 89 | for split in self.iterators.keys(): 90 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 91 | # Terminate the child process when the parent exists 92 | def cleanup(): 93 | print('Terminating BlobFetcher') 94 | for split in self.iterators.keys(): 95 | del self._prefetch_process[split] 96 | import atexit 97 | atexit.register(cleanup) 98 | 99 | def get_captions(self, ix, seq_per_img): 100 | # fetch the sequence labels 101 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 102 | ix2 = self.label_end_ix[ix] - 1 103 | ncap = ix2 - ix1 + 1 # number of captions available for this image 104 | assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' 105 | 106 | if ncap < seq_per_img: 107 | # we need to subsample (with replacement) 108 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') 109 | for q in range(seq_per_img): 110 | ixl = random.randint(ix1,ix2) 111 | seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] 112 | else: 113 | ixl = random.randint(ix1, ix2 - seq_per_img + 1) 114 | seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length] 115 | 116 | return seq 117 | 118 | def get_batch(self, split, batch_size=None, seq_per_img=None): 119 | batch_size = batch_size or self.batch_size 120 | seq_per_img = seq_per_img or self.seq_per_img 121 | 122 | fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32') 123 | att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32') 124 | label_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int') 125 | mask_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'float32') 126 | 127 | wrapped = False 128 | 129 | infos = [] 130 | gts = [] 131 | 132 | for i in range(batch_size): 133 | # fetch image 134 | tmp_fc, tmp_att,\ 135 | ix, tmp_wrapped = self._prefetch_process[split].get() 136 | fc_batch.append(tmp_fc) 137 | att_batch.append(tmp_att) 138 | 139 | label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = self.get_captions(ix, seq_per_img) 140 | 141 | if tmp_wrapped: 142 | wrapped = True 143 | 144 | # Used for reward evaluation 145 | gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) 146 | 147 | # record associated info as well 148 | info_dict = {} 149 | info_dict['ix'] = ix 150 | info_dict['id'] = self.info['images'][ix]['id'] 151 | info_dict['file_path'] = self.info['images'][ix]['file_path'] 152 | infos.append(info_dict) 153 | 154 | # #sort by att_feat length 155 | # fc_batch, att_batch, label_batch, gts, infos = \ 156 | # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) 157 | fc_batch, att_batch, label_batch, gts, infos = \ 158 | zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True)) 159 | data = {} 160 | data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch])) 161 | # merge att_feats 162 | max_att_len = max([_.shape[0] for _ in att_batch]) 163 | data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32') 164 | for i in range(len(att_batch)): 165 | data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i] 166 | data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') 167 | for i in range(len(att_batch)): 168 | data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1 169 | # set att_masks to None if attention features have same length 170 | if data['att_masks'].sum() == data['att_masks'].size: 171 | data['att_masks'] = None 172 | 173 | data['labels'] = np.vstack(label_batch) 174 | # generate mask 175 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) 176 | for ix, row in enumerate(mask_batch): 177 | row[:nonzeros[ix]] = 1 178 | data['masks'] = mask_batch 179 | 180 | data['gts'] = gts # all ground truth captions of each images 181 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': 
len(self.split_ix[split]), 'wrapped': wrapped} 182 | data['infos'] = infos 183 | 184 | return data 185 | 186 | # It's not coherent to make DataLoader a subclass of Dataset, but essentially, we only need to implement the following to functions, 187 | # so that the torch.utils.data.DataLoader can load the data according the index. 188 | # However, it's minimum change to switch to pytorch data loading. 189 | def __getitem__(self, index): 190 | """This function returns a tuple that is further passed to collate_fn 191 | """ 192 | ix = index #self.split_ix[index] 193 | if self.use_att: 194 | att_feat = np.load(os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'))['feat'] 195 | # Reshape to K x C 196 | att_feat = att_feat.reshape(-1, att_feat.shape[-1]) 197 | if self.norm_att_feat: 198 | att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) 199 | if self.use_box: 200 | box_feat = np.load(os.path.join(self.input_box_dir, str(self.info['images'][ix]['id']) + '.npy')) 201 | # devided by image width and height 202 | x1,y1,x2,y2 = np.hsplit(box_feat, 4) 203 | h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] 204 | box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? 205 | if self.norm_box_feat: 206 | box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) 207 | att_feat = np.hstack([att_feat, box_feat]) 208 | # sort the features by the size of boxes 209 | att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) 210 | else: 211 | att_feat = np.zeros((1,1,1)) 212 | return (np.load(os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy')), 213 | att_feat, 214 | ix) 215 | 216 | def __len__(self): 217 | return len(self.info['images']) 218 | 219 | class SubsetSampler(torch.utils.data.sampler.Sampler): 220 | r"""Samples elements randomly from a given list of indices, without replacement. 221 | Arguments: 222 | indices (list): a list of indices 223 | """ 224 | 225 | def __init__(self, indices): 226 | self.indices = indices 227 | 228 | def __iter__(self): 229 | return (self.indices[i] for i in range(len(self.indices))) 230 | 231 | def __len__(self): 232 | return len(self.indices) 233 | 234 | class BlobFetcher(): 235 | """Experimental class for prefetching blobs in a separate process.""" 236 | def __init__(self, split, dataloader, if_shuffle=False): 237 | """ 238 | db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname 239 | """ 240 | self.split = split 241 | self.dataloader = dataloader 242 | self.if_shuffle = if_shuffle 243 | 244 | # Add more in the queue 245 | def reset(self): 246 | """ 247 | Two cases for this function to be triggered: 248 | 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator 249 | 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already. 
250 | """ 251 | # batch_size is 1, the merge is done in DataLoader class 252 | self.split_loader = iter(data.DataLoader(dataset=self.dataloader, 253 | batch_size=1, 254 | sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]), 255 | shuffle=False, 256 | pin_memory=True, 257 | num_workers=4, # 4 is usually enough 258 | collate_fn=lambda x: x[0])) 259 | 260 | def _get_next_minibatch_inds(self): 261 | max_index = len(self.dataloader.split_ix[self.split]) 262 | wrapped = False 263 | 264 | ri = self.dataloader.iterators[self.split] 265 | ix = self.dataloader.split_ix[self.split][ri] 266 | 267 | ri_next = ri + 1 268 | if ri_next >= max_index: 269 | ri_next = 0 270 | if self.if_shuffle: 271 | random.shuffle(self.dataloader.split_ix[self.split]) 272 | wrapped = True 273 | self.dataloader.iterators[self.split] = ri_next 274 | 275 | return ix, wrapped 276 | 277 | def get(self): 278 | if not hasattr(self, 'split_loader'): 279 | self.reset() 280 | 281 | ix, wrapped = self._get_next_minibatch_inds() 282 | tmp = self.split_loader.next() 283 | if wrapped: 284 | self.reset() 285 | 286 | assert tmp[2] == ix, "ix not equal" 287 | 288 | return tmp + [wrapped] -------------------------------------------------------------------------------- /dataloaderraw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | import torch 11 | import skimage 12 | import skimage.io 13 | import scipy.misc 14 | 15 | from torchvision import transforms as trn 16 | preprocess = trn.Compose([ 17 | #trn.ToTensor(), 18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | from misc.resnet_utils import myResnet 22 | import misc.resnet 23 | 24 | class DataLoaderRaw(): 25 | 26 | def __init__(self, opt): 27 | self.opt = opt 28 | self.coco_json = opt.get('coco_json', '') 29 | self.folder_path = opt.get('folder_path', '') 30 | 31 | self.batch_size = opt.get('batch_size', 1) 32 | self.seq_per_img = 1 33 | 34 | # Load resnet 35 | self.cnn_model = opt.get('cnn_model', 'resnet101') 36 | self.my_resnet = getattr(misc.resnet, self.cnn_model)() 37 | self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth')) 38 | self.my_resnet = myResnet(self.my_resnet) 39 | self.my_resnet.cuda() 40 | self.my_resnet.eval() 41 | 42 | 43 | 44 | # load the json file which contains additional information about the dataset 45 | print('DataLoaderRaw loading images from folder: ', self.folder_path) 46 | 47 | self.files = [] 48 | self.ids = [] 49 | 50 | print(len(self.coco_json)) 51 | if len(self.coco_json) > 0: 52 | print('reading from ' + opt.coco_json) 53 | # read in filenames from the coco-style json file 54 | self.coco_annotation = json.load(open(self.coco_json)) 55 | for k,v in enumerate(self.coco_annotation['images']): 56 | fullpath = os.path.join(self.folder_path, v['file_name']) 57 | self.files.append(fullpath) 58 | self.ids.append(v['id']) 59 | else: 60 | # read in all the filenames from the folder 61 | print('listing all images in directory ' + self.folder_path) 62 | def isImage(f): 63 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM'] 64 | for ext in supportedExt: 65 | start_idx = f.rfind(ext) 66 | if start_idx >= 0 and start_idx + len(ext) == len(f): 67 | return True 68 | return False 69 | 
70 | n = 1 71 | for root, dirs, files in os.walk(self.folder_path, topdown=False): 72 | for file in files: 73 | fullpath = os.path.join(self.folder_path, file) 74 | if isImage(fullpath): 75 | self.files.append(fullpath) 76 | self.ids.append(str(n)) # just order them sequentially 77 | n = n + 1 78 | 79 | self.N = len(self.files) 80 | print('DataLoaderRaw found ', self.N, ' images') 81 | 82 | self.iterator = 0 83 | 84 | def get_batch(self, split, batch_size=None): 85 | batch_size = batch_size or self.batch_size 86 | 87 | # pick an index of the datapoint to load next 88 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32') 89 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32') 90 | max_index = self.N 91 | wrapped = False 92 | infos = [] 93 | 94 | for i in range(batch_size): 95 | ri = self.iterator 96 | ri_next = ri + 1 97 | if ri_next >= max_index: 98 | ri_next = 0 99 | wrapped = True 100 | # wrap back around 101 | self.iterator = ri_next 102 | 103 | img = skimage.io.imread(self.files[ri]) 104 | 105 | if len(img.shape) == 2: 106 | img = img[:,:,np.newaxis] 107 | img = np.concatenate((img, img, img), axis=2) 108 | 109 | img = img.astype('float32')/255.0 110 | img = torch.from_numpy(img.transpose([2,0,1])).cuda() 111 | img = preprocess(img) 112 | with torch.no_grad(): 113 | tmp_fc, tmp_att = self.my_resnet(img) 114 | 115 | fc_batch[i] = tmp_fc.data.cpu().float().numpy() 116 | att_batch[i] = tmp_att.data.cpu().float().numpy() 117 | 118 | info_struct = {} 119 | info_struct['id'] = self.ids[ri] 120 | info_struct['file_path'] = self.files[ri] 121 | infos.append(info_struct) 122 | 123 | data = {} 124 | data['fc_feats'] = fc_batch 125 | data['att_feats'] = att_batch 126 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped} 127 | data['infos'] = infos 128 | 129 | return data 130 | 131 | def reset_iterator(self, split): 132 | self.iterator = 0 133 | 134 | def get_vocab_size(self): 135 | return len(self.ix_to_word) 136 | 137 | def get_vocab(self): 138 | return self.ix_to_word 139 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | # Input arguments and options 22 | parser = argparse.ArgumentParser() 23 | # Input paths 24 | parser.add_argument('--model', type=str, default='', 25 | help='path to model to evaluate') 26 | parser.add_argument('--cnn_model', type=str, default='resnet101', 27 | help='resnet101, resnet152') 28 | parser.add_argument('--infos_path', type=str, default='', 29 | help='path to infos to evaluate') 30 | # Basic options 31 | parser.add_argument('--batch_size', type=int, default=0, 32 | help='if > 0 then overrule, otherwise load from checkpoint.') 33 | parser.add_argument('--num_images', type=int, default=-1, 34 | help='how many images to use when periodically evaluating the loss? (-1 = all)') 35 | parser.add_argument('--language_eval', type=int, default=0, 36 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? 
requires coco-caption code from Github.') 37 | parser.add_argument('--dump_images', type=int, default=1, 38 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') 39 | parser.add_argument('--dump_json', type=int, default=1, 40 | help='Dump json with predictions into vis folder? (1=yes,0=no)') 41 | parser.add_argument('--dump_path', type=int, default=0, 42 | help='Write image paths along with predictions into vis json? (1=yes,0=no)') 43 | 44 | # Sampling options 45 | parser.add_argument('--sample_max', type=int, default=1, 46 | help='1 = sample argmax words. 0 = sample from distributions.') 47 | parser.add_argument('--max_ppl', type=int, default=0, 48 | help='beam search by max perplexity or max probability.') 49 | parser.add_argument('--beam_size', type=int, default=2, 50 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 51 | parser.add_argument('--group_size', type=int, default=1, 52 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') 53 | parser.add_argument('--diversity_lambda', type=float, default=0.5, 54 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') 55 | parser.add_argument('--temperature', type=float, default=1.0, 56 | help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') 57 | parser.add_argument('--decoding_constraint', type=int, default=0, 58 | help='If 1, not allowing same word in a row') 59 | # For evaluation on a folder of images: 60 | parser.add_argument('--image_folder', type=str, default='', 61 | help='If this is nonempty then will predict on the images in this folder path') 62 | parser.add_argument('--image_root', type=str, default='', 63 | help='In case the image paths have to be preprended with a root path to an image folder') 64 | # For evaluation on MSCOCO images from some split: 65 | parser.add_argument('--input_fc_dir', type=str, default='', 66 | help='path to the h5file containing the preprocessed dataset') 67 | parser.add_argument('--input_att_dir', type=str, default='', 68 | help='path to the h5file containing the preprocessed dataset') 69 | parser.add_argument('--input_box_dir', type=str, default='', 70 | help='path to the h5file containing the preprocessed dataset') 71 | parser.add_argument('--input_label_h5', type=str, default='', 72 | help='path to the h5file containing the preprocessed dataset') 73 | parser.add_argument('--input_json', type=str, default='', 74 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') 75 | parser.add_argument('--split', type=str, default='test', 76 | help='if running on MSCOCO images, which split to use: val|test|train') 77 | parser.add_argument('--coco_json', type=str, default='', 78 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') 79 | # misc 80 | parser.add_argument('--id', type=str, default='', 81 | help='an id identifying this run/job. 
used only if language_eval = 1 for appending to intermediate files') 82 | parser.add_argument('--verbose_beam', type=int, default=1, 83 | help='if we need to print out all beam search beams.') 84 | parser.add_argument('--verbose_loss', type=int, default=0, 85 | help='if we need to calculate loss.') 86 | 87 | opt = parser.parse_args() 88 | 89 | # Load infos 90 | with open(opt.infos_path) as f: 91 | infos = cPickle.load(f) 92 | 93 | # override and collect parameters 94 | if len(opt.input_fc_dir) == 0: 95 | opt.input_fc_dir = infos['opt'].input_fc_dir 96 | opt.input_att_dir = infos['opt'].input_att_dir 97 | opt.input_box_dir = infos['opt'].input_box_dir 98 | opt.input_label_h5 = infos['opt'].input_label_h5 99 | if len(opt.input_json) == 0: 100 | opt.input_json = infos['opt'].input_json 101 | if opt.batch_size == 0: 102 | opt.batch_size = infos['opt'].batch_size 103 | if len(opt.id) == 0: 104 | opt.id = infos['opt'].id 105 | ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval"] 106 | for k in vars(infos['opt']).keys(): 107 | if k not in ignore: 108 | if k in vars(opt): 109 | assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent' 110 | else: 111 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 112 | 113 | vocab = infos['vocab'] # ix -> word mapping 114 | 115 | # Setup the model 116 | model = models.setup(opt) 117 | model.load_state_dict(torch.load(opt.model)) 118 | model.cuda() 119 | model.eval() 120 | crit = utils.LanguageModelCriterion() 121 | 122 | # Create the Data Loader instance 123 | if len(opt.image_folder) == 0: 124 | loader = DataLoader(opt) 125 | else: 126 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 127 | 'coco_json': opt.coco_json, 128 | 'batch_size': opt.batch_size, 129 | 'cnn_model': opt.cnn_model}) 130 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 131 | # So make sure to use the vocab in infos file. 
132 | loader.ix_to_word = infos['vocab'] 133 | 134 | 135 | # Set sample options 136 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 137 | vars(opt)) 138 | 139 | print('loss: ', loss) 140 | if lang_stats: 141 | print(lang_stats) 142 | 143 | if opt.dump_json == 1: 144 | # dump the json 145 | json.dump(split_predictions, open('vis/vis.json', 'w')) 146 | -------------------------------------------------------------------------------- /eval_ensemble.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | # Input arguments and options 22 | parser = argparse.ArgumentParser() 23 | # Input paths 24 | parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble') 25 | # parser.add_argument('--models', nargs='+', required=True 26 | # help='path to model to evaluate') 27 | # parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate') 28 | # Basic options 29 | parser.add_argument('--batch_size', type=int, default=0, 30 | help='if > 0 then overrule, otherwise load from checkpoint.') 31 | parser.add_argument('--num_images', type=int, default=-1, 32 | help='how many images to use when periodically evaluating the loss? (-1 = all)') 33 | parser.add_argument('--language_eval', type=int, default=0, 34 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 35 | parser.add_argument('--dump_images', type=int, default=1, 36 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') 37 | parser.add_argument('--dump_json', type=int, default=1, 38 | help='Dump json with predictions into vis folder? (1=yes,0=no)') 39 | parser.add_argument('--dump_path', type=int, default=0, 40 | help='Write image paths along with predictions into vis json? (1=yes,0=no)') 41 | 42 | # Sampling options 43 | parser.add_argument('--sample_max', type=int, default=1, 44 | help='1 = sample argmax words. 0 = sample from distributions.') 45 | parser.add_argument('--max_ppl', type=int, default=0, 46 | help='beam search by max perplexity or max probability.') 47 | parser.add_argument('--beam_size', type=int, default=2, 48 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 49 | parser.add_argument('--group_size', type=int, default=1, 50 | help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') 51 | parser.add_argument('--diversity_lambda', type=float, default=0.5, 52 | help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') 53 | parser.add_argument('--temperature', type=float, default=1.0, 54 | help='temperature when sampling from distributions (i.e. when sample_max = 0). 
Lower = "safer" predictions.') 55 | parser.add_argument('--decoding_constraint', type=int, default=0, 56 | help='If 1, not allowing same word in a row') 57 | # For evaluation on a folder of images: 58 | parser.add_argument('--image_folder', type=str, default='', 59 | help='If this is nonempty then will predict on the images in this folder path') 60 | parser.add_argument('--image_root', type=str, default='', 61 | help='In case the image paths have to be preprended with a root path to an image folder') 62 | # For evaluation on MSCOCO images from some split: 63 | parser.add_argument('--input_fc_dir', type=str, default='', 64 | help='path to the h5file containing the preprocessed dataset') 65 | parser.add_argument('--input_att_dir', type=str, default='', 66 | help='path to the h5file containing the preprocessed dataset') 67 | parser.add_argument('--input_box_dir', type=str, default='', 68 | help='path to the h5file containing the preprocessed dataset') 69 | parser.add_argument('--input_label_h5', type=str, default='', 70 | help='path to the h5file containing the preprocessed dataset') 71 | parser.add_argument('--input_json', type=str, default='', 72 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') 73 | parser.add_argument('--split', type=str, default='test', 74 | help='if running on MSCOCO images, which split to use: val|test|train') 75 | parser.add_argument('--coco_json', type=str, default='', 76 | help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') 77 | parser.add_argument('--seq_length', type=int, default=40, 78 | help='maximum sequence length during sampling') 79 | # misc 80 | parser.add_argument('--id', type=str, default='', 81 | help='an id identifying this run/job. 
used only if language_eval = 1 for appending to intermediate files') 82 | parser.add_argument('--verbose_beam', type=int, default=1, 83 | help='if we need to print out all beam search beams.') 84 | parser.add_argument('--verbose_loss', type=int, default=0, 85 | help='If calculate loss using ground truth during evaluation') 86 | 87 | opt = parser.parse_args() 88 | 89 | model_infos = [cPickle.load(open('log_%s/infos_%s-best.pkl' %(id, id))) for id in opt.ids] 90 | model_paths = ['log_%s/model-best.pth' %(id) for id in opt.ids] 91 | 92 | # Load one infos 93 | infos = model_infos[0] 94 | 95 | # override and collect parameters 96 | if len(opt.input_fc_dir) == 0: 97 | opt.input_fc_dir = infos['opt'].input_fc_dir 98 | opt.input_att_dir = infos['opt'].input_att_dir 99 | opt.input_box_dir = infos['opt'].input_box_dir 100 | opt.input_label_h5 = infos['opt'].input_label_h5 101 | if len(opt.input_json) == 0: 102 | opt.input_json = infos['opt'].input_json 103 | if opt.batch_size == 0: 104 | opt.batch_size = infos['opt'].batch_size 105 | if len(opt.id) == 0: 106 | opt.id = infos['opt'].id 107 | 108 | vars(opt).update({k: vars(infos['opt'])[k] for k in vars(infos['opt']).keys() if k not in vars(opt)}) # copy over options from model 109 | 110 | 111 | opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos]) 112 | assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Not support different norm_att_feat' 113 | assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Not support different norm_box_feat' 114 | 115 | vocab = infos['vocab'] # ix -> word mapping 116 | 117 | # Setup the model 118 | from models.AttEnsemble import AttEnsemble 119 | 120 | _models = [] 121 | for i in range(len(model_infos)): 122 | model_infos[i]['opt'].start_from = None 123 | tmp = models.setup(model_infos[i]['opt']) 124 | tmp.load_state_dict(torch.load(model_paths[i])) 125 | tmp.cuda() 126 | tmp.eval() 127 | _models.append(tmp) 128 | 129 | model = AttEnsemble(_models) 130 | model.seq_length = opt.seq_length 131 | model.eval() 132 | crit = utils.LanguageModelCriterion() 133 | 134 | # Create the Data Loader instance 135 | if len(opt.image_folder) == 0: 136 | loader = DataLoader(opt) 137 | else: 138 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 139 | 'coco_json': opt.coco_json, 140 | 'batch_size': opt.batch_size, 141 | 'cnn_model': opt.cnn_model}) 142 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 143 | # So make sure to use the vocab in infos file. 
144 | loader.ix_to_word = infos['vocab'] 145 | 146 | 147 | # Set sample options 148 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 149 | vars(opt)) 150 | 151 | print('loss: ', loss) 152 | if lang_stats: 153 | print(lang_stats) 154 | 155 | if opt.dump_json == 1: 156 | # dump the json 157 | json.dump(split_predictions, open('vis/vis.json', 'w')) 158 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import numpy as np 9 | import json 10 | from json import encoder 11 | import random 12 | import string 13 | import time 14 | import os 15 | import sys 16 | import misc.utils as utils 17 | 18 | def language_eval(dataset, preds, model_id, split): 19 | import sys 20 | sys.path.append("coco-caption") 21 | annFile = 'coco-caption/annotations/captions_val2014.json' 22 | from pycocotools.coco import COCO 23 | from pycocoevalcap.eval import COCOEvalCap 24 | 25 | # encoder.FLOAT_REPR = lambda o: format(o, '.3f') 26 | 27 | if not os.path.isdir('eval_results'): 28 | os.mkdir('eval_results') 29 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '.json') 30 | 31 | coco = COCO(annFile) 32 | valids = coco.getImgIds() 33 | 34 | # filter results to only those in MSCOCO validation set (will be about a third) 35 | preds_filt = [p for p in preds if p['image_id'] in valids] 36 | print('using %d/%d predictions' % (len(preds_filt), len(preds))) 37 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 
38 | 39 | cocoRes = coco.loadRes(cache_path) 40 | cocoEval = COCOEvalCap(coco, cocoRes) 41 | cocoEval.params['image_id'] = cocoRes.getImgIds() 42 | cocoEval.evaluate() 43 | 44 | # create output dictionary 45 | out = {} 46 | for metric, score in cocoEval.eval.items(): 47 | out[metric] = score 48 | 49 | imgToEval = cocoEval.imgToEval 50 | for p in preds_filt: 51 | image_id, caption = p['image_id'], p['caption'] 52 | imgToEval[image_id]['caption'] = caption 53 | with open(cache_path, 'w') as outfile: 54 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile) 55 | 56 | return out 57 | 58 | def eval_split(model, crit, loader, eval_kwargs={}): 59 | verbose = eval_kwargs.get('verbose', True) 60 | verbose_beam = eval_kwargs.get('verbose_beam', 1) 61 | verbose_loss = eval_kwargs.get('verbose_loss', 1) 62 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 63 | split = eval_kwargs.get('split', 'val') 64 | lang_eval = eval_kwargs.get('language_eval', 0) 65 | dataset = eval_kwargs.get('dataset', 'coco') 66 | beam_size = eval_kwargs.get('beam_size', 1) 67 | 68 | # Make sure in the evaluation mode 69 | model.eval() 70 | 71 | loader.reset_iterator(split) 72 | 73 | n = 0 74 | loss = 0 75 | loss_sum = 0 76 | loss_evals = 1e-8 77 | predictions = [] 78 | while True: 79 | data = loader.get_batch(split) 80 | n = n + loader.batch_size 81 | 82 | if data.get('labels', None) is not None and verbose_loss: 83 | # forward the model to get loss 84 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] 85 | tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp] 86 | fc_feats, att_feats, labels, masks, att_masks = tmp 87 | 88 | with torch.no_grad(): 89 | loss = crit(model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]).item() 90 | loss_sum = loss_sum + loss 91 | loss_evals = loss_evals + 1 92 | 93 | # forward the model to also get generated samples for each image 94 | # Only leave one feature for each image, in case duplicate sample 95 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 96 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 97 | data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img] if data['att_masks'] is not None else None] 98 | tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp] 99 | fc_feats, att_feats, att_masks = tmp 100 | # forward the model to also get generated samples for each image 101 | with torch.no_grad(): 102 | seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data 103 | 104 | # Print beam search 105 | if beam_size > 1 and verbose_beam: 106 | for i in range(loader.batch_size): 107 | print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) 108 | print('--' * 10) 109 | sents = utils.decode_sequence(loader.get_vocab(), seq) 110 | 111 | for k, sent in enumerate(sents): 112 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent} 113 | if eval_kwargs.get('dump_path', 0) == 1: 114 | entry['file_name'] = data['infos'][k]['file_path'] 115 | predictions.append(entry) 116 | if eval_kwargs.get('dump_images', 0) == 1: 117 | # dump the raw image to vis/ folder 118 | cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross 119 | print(cmd) 120 | os.system(cmd) 121 | 122 | if verbose: 123 | print('image %s: %s' 
%(entry['image_id'], entry['caption'])) 124 | 125 | # if we wrapped around the split or used up val imgs budget then bail 126 | ix0 = data['bounds']['it_pos_now'] 127 | ix1 = data['bounds']['it_max'] 128 | if num_images != -1: 129 | ix1 = min(ix1, num_images) 130 | for i in range(n - ix1): 131 | predictions.pop() 132 | 133 | if verbose: 134 | print('evaluating validation preformance... %d/%d (%f)' %(ix0 - 1, ix1, loss)) 135 | 136 | if data['bounds']['wrapped']: 137 | break 138 | if num_images >= 0 and n >= num_images: 139 | break 140 | 141 | lang_stats = None 142 | if lang_eval == 1: 143 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split) 144 | 145 | # Switch back to training mode 146 | model.train() 147 | return loss_sum/loss_evals, predictions, lang_stats 148 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/Transformer_Captioning/717fe48e4d808df771c2de37610b200c3eb78fe1/misc/__init__.py -------------------------------------------------------------------------------- /misc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.resnet 4 | from torchvision.models.resnet import BasicBlock, Bottleneck 5 | 6 | class ResNet(torchvision.models.resnet.ResNet): 7 | def __init__(self, block, layers, num_classes=1000): 8 | super(ResNet, self).__init__(block, layers, num_classes) 9 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 10 | for i in range(2, 5): 11 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 12 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 13 | 14 | def resnet18(pretrained=False): 15 | """Constructs a ResNet-18 model. 16 | 17 | Args: 18 | pretrained (bool): If True, returns a model pre-trained on ImageNet 19 | """ 20 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 21 | if pretrained: 22 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 23 | return model 24 | 25 | 26 | def resnet34(pretrained=False): 27 | """Constructs a ResNet-34 model. 28 | 29 | Args: 30 | pretrained (bool): If True, returns a model pre-trained on ImageNet 31 | """ 32 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 33 | if pretrained: 34 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 35 | return model 36 | 37 | 38 | def resnet50(pretrained=False): 39 | """Constructs a ResNet-50 model. 40 | 41 | Args: 42 | pretrained (bool): If True, returns a model pre-trained on ImageNet 43 | """ 44 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 45 | if pretrained: 46 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 47 | return model 48 | 49 | 50 | def resnet101(pretrained=False): 51 | """Constructs a ResNet-101 model. 52 | 53 | Args: 54 | pretrained (bool): If True, returns a model pre-trained on ImageNet 55 | """ 56 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 57 | if pretrained: 58 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 59 | return model 60 | 61 | 62 | def resnet152(pretrained=False): 63 | """Constructs a ResNet-152 model. 
64 | 65 | Args: 66 | pretrained (bool): If True, returns a model pre-trained on ImageNet 67 | """ 68 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 69 | if pretrained: 70 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 71 | return model -------------------------------------------------------------------------------- /misc/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class myResnet(nn.Module): 6 | def __init__(self, resnet): 7 | super(myResnet, self).__init__() 8 | self.resnet = resnet 9 | 10 | def forward(self, img, att_size=14): 11 | x = img.unsqueeze(0) 12 | 13 | x = self.resnet.conv1(x) 14 | x = self.resnet.bn1(x) 15 | x = self.resnet.relu(x) 16 | x = self.resnet.maxpool(x) 17 | 18 | x = self.resnet.layer1(x) 19 | x = self.resnet.layer2(x) 20 | x = self.resnet.layer3(x) 21 | x = self.resnet.layer4(x) 22 | 23 | fc = x.mean(3).mean(2).squeeze() 24 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 25 | 26 | return fc, att 27 | 28 | -------------------------------------------------------------------------------- /misc/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | import misc.utils as utils 8 | from collections import OrderedDict 9 | import torch 10 | 11 | import sys 12 | sys.path.append("cider") 13 | from pyciderevalcap.ciderD.ciderD import CiderD 14 | sys.path.append("coco-caption") 15 | from pycocoevalcap.bleu.bleu import Bleu 16 | 17 | CiderD_scorer = None 18 | Bleu_scorer = None 19 | #CiderD_scorer = CiderD(df='corpus') 20 | 21 | def init_scorer(cached_tokens): 22 | global CiderD_scorer 23 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 24 | global Bleu_scorer 25 | Bleu_scorer = Bleu_scorer or Bleu(4) 26 | 27 | def array_to_str(arr): 28 | out = '' 29 | for i in range(len(arr)): 30 | out += str(arr[i]) + ' ' 31 | if arr[i] == 0: 32 | break 33 | return out.strip() 34 | 35 | def get_self_critical_reward(model, fc_feats, att_feats, att_masks, data, gen_result, opt): 36 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 37 | seq_per_img = batch_size // len(data['gts']) 38 | 39 | # get greedy decoding baseline 40 | model.eval() 41 | with torch.no_grad(): 42 | greedy_res, _ = model(fc_feats, att_feats, att_masks=att_masks, mode='sample') 43 | model.train() 44 | 45 | res = OrderedDict() 46 | 47 | gen_result = gen_result.data.cpu().numpy() 48 | greedy_res = greedy_res.data.cpu().numpy() 49 | for i in range(batch_size): 50 | res[i] = [array_to_str(gen_result[i])] 51 | for i in range(batch_size): 52 | res[batch_size + i] = [array_to_str(greedy_res[i])] 53 | 54 | gts = OrderedDict() 55 | for i in range(len(data['gts'])): 56 | gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))] 57 | 58 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] 59 | res__ = {i: res[i] for i in range(2 * batch_size)} 60 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} 61 | if opt.cider_reward_weight > 0: 62 | _, cider_scores = CiderD_scorer.compute_score(gts, res_) 63 | print('Cider scores:', _) 64 | else: 65 | cider_scores = 0 66 | if opt.bleu_reward_weight > 0: 67 | _, bleu_scores = Bleu_scorer.compute_score(gts, 
res__) 68 | bleu_scores = np.array(bleu_scores[3]) 69 | print('Bleu scores:', _[3]) 70 | else: 71 | bleu_scores = 0 72 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores 73 | 74 | scores = scores[:batch_size] - scores[batch_size:] 75 | 76 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 77 | 78 | return rewards -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.optim as optim 10 | 11 | def if_use_att(caption_model): 12 | # Decide if load attention feature according to caption model 13 | if caption_model in ['show_tell', 'all_img', 'fc']: 14 | return False 15 | return True 16 | 17 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 18 | def decode_sequence(ix_to_word, seq): 19 | N, D = seq.size() 20 | out = [] 21 | for i in range(N): 22 | txt = '' 23 | for j in range(D): 24 | ix = seq[i,j] 25 | if ix > 0 : 26 | if j >= 1: 27 | txt = txt + ' ' 28 | txt = txt + ix_to_word[str(ix.item())] 29 | else: 30 | break 31 | out.append(txt) 32 | return out 33 | 34 | def to_contiguous(tensor): 35 | if tensor.is_contiguous(): 36 | return tensor 37 | else: 38 | return tensor.contiguous() 39 | 40 | class RewardCriterion(nn.Module): 41 | def __init__(self): 42 | super(RewardCriterion, self).__init__() 43 | 44 | def forward(self, input, seq, reward): 45 | input = to_contiguous(input).view(-1) 46 | reward = to_contiguous(reward).view(-1) 47 | mask = (seq>0).float() 48 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 49 | output = - input * reward * mask 50 | output = torch.sum(output) / torch.sum(mask) 51 | 52 | return output 53 | 54 | class LanguageModelCriterion(nn.Module): 55 | def __init__(self): 56 | super(LanguageModelCriterion, self).__init__() 57 | 58 | def forward(self, input, target, mask): 59 | # truncate to the same size 60 | target = target[:, :input.size(1)] 61 | mask = mask[:, :input.size(1)] 62 | 63 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask 64 | output = torch.sum(output) / torch.sum(mask) 65 | 66 | return output 67 | 68 | class LabelSmoothing(nn.Module): 69 | "Implement label smoothing." 
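# With smoothing eps, the target distribution built in forward() below puts probability
# (1 - eps) on the ground-truth word and spreads eps uniformly over the remaining
# vocabulary entries (eps / (size - 1) each), so it still sums to 1. The loss is the KL
# divergence between this smoothed target and the predicted log-probabilities, averaged
# over the unmasked (non-padding) time steps.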
70 | def __init__(self, size=0, padding_idx=0, smoothing=0.0): 71 | super(LabelSmoothing, self).__init__() 72 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False) 73 | # self.padding_idx = padding_idx 74 | self.confidence = 1.0 - smoothing 75 | self.smoothing = smoothing 76 | # self.size = size 77 | self.true_dist = None 78 | 79 | def forward(self, input, target, mask): 80 | # truncate to the same size 81 | target = target[:, :input.size(1)] 82 | mask = mask[:, :input.size(1)] 83 | 84 | input = to_contiguous(input).view(-1, input.size(-1)) 85 | target = to_contiguous(target).view(-1) 86 | mask = to_contiguous(mask).view(-1) 87 | 88 | # assert x.size(1) == self.size 89 | self.size = input.size(1) 90 | # true_dist = x.data.clone() 91 | true_dist = input.data.clone() 92 | # true_dist.fill_(self.smoothing / (self.size - 2)) 93 | true_dist.fill_(self.smoothing / (self.size - 1)) 94 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 95 | # true_dist[:, self.padding_idx] = 0 96 | # mask = torch.nonzero(target.data == self.padding_idx) 97 | # self.true_dist = true_dist 98 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() 99 | 100 | def set_lr(optimizer, lr): 101 | for group in optimizer.param_groups: 102 | group['lr'] = lr 103 | 104 | def get_lr(optimizer): 105 | for group in optimizer.param_groups: 106 | return group['lr'] 107 | 108 | def clip_gradient(optimizer, grad_clip): 109 | for group in optimizer.param_groups: 110 | for param in group['params']: 111 | param.grad.data.clamp_(-grad_clip, grad_clip) 112 | 113 | def build_optimizer(params, opt): 114 | if opt.optim == 'rmsprop': 115 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay) 116 | elif opt.optim == 'adagrad': 117 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) 118 | elif opt.optim == 'sgd': 119 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) 120 | elif opt.optim == 'sgdm': 121 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) 122 | elif opt.optim == 'sgdmom': 123 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) 124 | elif opt.optim == 'adam': 125 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) 126 | else: 127 | raise Exception("bad option opt.optim: {}".format(opt.optim)) 128 | 129 | 130 | class NoamOpt(object): 131 | "Optim wrapper that implements rate." 132 | def __init__(self, model_size, factor, warmup, optimizer): 133 | self.optimizer = optimizer 134 | self._step = 0 135 | self.warmup = warmup 136 | self.factor = factor 137 | self.model_size = model_size 138 | self._rate = 0 139 | 140 | def step(self): 141 | "Update parameters and rate" 142 | self._step += 1 143 | rate = self.rate() 144 | for p in self.optimizer.param_groups: 145 | p['lr'] = rate 146 | self._rate = rate 147 | self.optimizer.step() 148 | 149 | def rate(self, step = None): 150 | "Implement `lrate` above" 151 | if step is None: 152 | step = self._step 153 | return self.factor * \ 154 | (self.model_size ** (-0.5) * 155 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 156 | 157 | def __getattr__(self, name): 158 | return getattr(self.optimizer, name) 159 | 160 | class ReduceLROnPlateau(object): 161 | "Optim wrapper that implements rate." 
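# Thin wrapper pairing an optimizer with torch.optim.lr_scheduler.ReduceLROnPlateau: it
# exposes the same step()/state_dict()/load_state_dict() interface as the plain optimizers
# and NoamOpt above, and scheduler_step(val) feeds the validation score to the scheduler,
# which lowers the learning rate when that metric plateaus.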
162 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
163 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
164 | self.optimizer = optimizer
165 | self.current_lr = get_lr(optimizer)
166 | 
167 | def step(self):
168 | "Update parameters and rate"
169 | self.optimizer.step()
170 | 
171 | def scheduler_step(self, val):
172 | self.scheduler.step(val)
173 | self.current_lr = get_lr(self.optimizer)
174 | 
175 | def state_dict(self):
176 | return {'current_lr':self.current_lr,
177 | 'scheduler_state_dict': {key: value for key, value in self.scheduler.__dict__.items() if key not in {'optimizer', 'is_better'}},
178 | 'optimizer_state_dict': self.optimizer.state_dict()}
179 | 
180 | def load_state_dict(self, state_dict):
181 | if 'current_lr' not in state_dict:
182 | # it's a plain optimizer state dict
183 | self.optimizer.load_state_dict(state_dict)
184 | set_lr(self.optimizer, self.current_lr) # use the lr from the option
185 | else:
186 | # it's a scheduler state dict
187 | self.current_lr = state_dict['current_lr']
188 | self.scheduler.__dict__.update(state_dict['scheduler_state_dict'])
189 | self.scheduler._init_is_better(mode=self.scheduler.mode, threshold=self.scheduler.threshold, threshold_mode=self.scheduler.threshold_mode)
190 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
191 | # current_lr is actually useless in this case
192 | 
193 | def rate(self, step = None):
194 | "Copied from NoamOpt; unused here, since self._step, self.factor, self.model_size and self.warmup are never set on this wrapper"
195 | if step is None:
196 | step = self._step
197 | return self.factor * \
198 | (self.model_size ** (-0.5) *
199 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
200 | 
201 | def __getattr__(self, name):
202 | return getattr(self.optimizer, name)
203 | 
204 | def get_std_opt(model, factor=1, warmup=2000):
205 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
206 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
207 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup,
208 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
209 | 
-------------------------------------------------------------------------------- /models/AttEnsemble.py: --------------------------------------------------------------------------------
1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model
2 | 
3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4 | # https://arxiv.org/abs/1612.01887
5 | # AdaAttMO is a modified version with maxout lstm
6 | 
7 | # Att2in is from Self-critical Sequence Training for Image Captioning
8 | # https://arxiv.org/abs/1612.00563
9 | # In this file we only have Att2in2, which is a slightly different version of att2in,
10 | # in which the image feature embedding and the word embedding are the same as in AdaAtt.
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | from .CaptionModel import CaptionModel 26 | from .AttModel import pack_wrapper, AttModel 27 | 28 | class AttEnsemble(AttModel): 29 | def __init__(self, models): 30 | CaptionModel.__init__(self) 31 | # super(AttEnsemble, self).__init__() 32 | 33 | self.models = nn.ModuleList(models) 34 | self.vocab_size = models[0].vocab_size 35 | self.seq_length = models[0].seq_length 36 | self.ss_prob = 0 37 | 38 | def init_hidden(self, batch_size): 39 | return [m.init_hidden(batch_size) for m in self.models] 40 | 41 | def embed(self, it): 42 | return [m.embed(it) for m in self.models] 43 | 44 | def core(self, *args): 45 | return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) 46 | 47 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state): 48 | # 'it' contains a word index 49 | xt = self.embed(it) 50 | 51 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) 52 | logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() 53 | 54 | return logprobs, state 55 | 56 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 57 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 58 | 59 | # embed fc and att feats 60 | fc_feats = [m.fc_embed(fc_feats) for m in self.models] 61 | att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models] 62 | 63 | # Project the attention feats first to reduce memory and computation comsumptions. 64 | p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)] 65 | 66 | return fc_feats, att_feats, p_att_feats, [att_masks] * len(self.models) 67 | 68 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 69 | beam_size = opt.get('beam_size', 10) 70 | batch_size = fc_feats.size(0) 71 | 72 | fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 73 | 74 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 75 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 76 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 77 | # lets process every image independently for now, for simplicity 78 | 79 | self.done_beams = [[] for _ in range(batch_size)] 80 | for k in range(batch_size): 81 | state = self.init_hidden(beam_size) 82 | tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] 83 | tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 84 | tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] 85 | tmp_att_masks = [att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() for i,m in enumerate(self.models)] if att_masks[0] is not None else att_masks 86 | 87 | it = fc_feats[0].data.new(beam_size).long().zero_() 88 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) 89 | 90 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) 91 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 92 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 93 | # return the samples and their log likelihoods 94 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 95 | 96 | def beam_search(self, init_state, init_logprobs, *args, **kwargs): 97 | 98 | # function computes the similarity score to be augmented 99 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): 100 | local_time = t - divm 101 | unaug_logprobsf = logprobsf.clone() 102 | for prev_choice in range(divm): 103 | prev_decisions = beam_seq_table[prev_choice][local_time] 104 | for sub_beam in range(bdash): 105 | for prev_labels in range(bdash): 106 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda 107 | return unaug_logprobsf 108 | 109 | # does one step of classical beam search 110 | 111 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): 112 | #INPUTS: 113 | #logprobsf: probabilities augmented after diversity 114 | #beam_size: obvious 115 | #t : time instant 116 | #beam_seq : tensor contanining the beams 117 | #beam_seq_logprobs: tensor contanining the beam logprobs 118 | #beam_logprobs_sum: tensor contanining joint logprobs 119 | #OUPUTS: 120 | #beam_seq : tensor containing the word indices of the decoded captions 121 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq 122 | #beam_logprobs_sum : joint log-probability of each beam 123 | 124 | ys,ix = torch.sort(logprobsf,1,True) 125 | candidates = [] 126 | cols = min(beam_size, ys.size(1)) 127 | rows = beam_size 128 | if t == 0: 129 | rows = 1 130 | for c in range(cols): # for each column (word, essentially) 131 | for q in range(rows): # for each beam expansion 132 | #compute logprob of expanding beam q with word in (sorted) position c 133 | local_logprob = ys[q,c].item() 134 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 135 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] 136 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob}) 137 | candidates = sorted(candidates, key=lambda x: -x['p']) 138 | 139 | 
new_state = [[_.clone() for _ in state_] for state_ in state] 140 | #beam_seq_prev, beam_seq_logprobs_prev 141 | if t >= 1: 142 | #we''ll need these as reference when we fork beams around 143 | beam_seq_prev = beam_seq[:t].clone() 144 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() 145 | for vix in range(beam_size): 146 | v = candidates[vix] 147 | #fork beam index q into index vix 148 | if t >= 1: 149 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']] 150 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] 151 | #rearrange recurrent states 152 | for ii in range(len(new_state)): 153 | for state_ix in range(len(new_state[ii])): 154 | # copy over state in previous beam q to new beam at vix 155 | new_state[ii][state_ix][:, vix] = state[ii][state_ix][:, v['q']] # dimension one is time step 156 | #append new end terminal at the end of this beam 157 | beam_seq[t, vix] = v['c'] # c'th word is the continuation 158 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here 159 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 160 | state = new_state 161 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates 162 | 163 | # Start diverse_beam_search 164 | opt = kwargs['opt'] 165 | beam_size = opt.get('beam_size', 10) 166 | group_size = opt.get('group_size', 1) 167 | diversity_lambda = opt.get('diversity_lambda', 0.5) 168 | decoding_constraint = opt.get('decoding_constraint', 0) 169 | max_ppl = opt.get('max_ppl', 0) 170 | bdash = beam_size // group_size # beam per group 171 | 172 | # INITIALIZATIONS 173 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 174 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 175 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] 176 | 177 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) 178 | done_beams_table = [[] for _ in range(group_size)] 179 | state_table = zip(*[[list(torch.unbind(_)) for _ in torch.stack(init_state_).chunk(group_size, 2)] for init_state_ in init_state]) 180 | logprobs_table = list(init_logprobs.chunk(group_size, 0)) 181 | # END INIT 182 | 183 | # Chunk elements in the args 184 | args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name 185 | args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name 186 | 187 | for t in range(self.seq_length + group_size - 1): 188 | for divm in range(group_size): 189 | if t >= divm and t <= self.seq_length + divm - 1: 190 | # add diversity 191 | logprobsf = logprobs_table[divm].data.float() 192 | # suppress previous word 193 | if decoding_constraint and t-divm > 0: 194 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf')) 195 | # suppress UNK tokens in the decoding 196 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 197 | # diversity is added here 198 | # the function directly modifies the logprobsf values and hence, we need to return 199 | # the unaugmented ones for sorting the candidates in the end. 
# for historical 200 | # reasons :-) 201 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) 202 | 203 | # infer new beams 204 | beam_seq_table[divm],\ 205 | beam_seq_logprobs_table[divm],\ 206 | beam_logprobs_sum_table[divm],\ 207 | state_table[divm],\ 208 | candidates_divm = beam_step(logprobsf, 209 | unaug_logprobsf, 210 | bdash, 211 | t-divm, 212 | beam_seq_table[divm], 213 | beam_seq_logprobs_table[divm], 214 | beam_logprobs_sum_table[divm], 215 | state_table[divm]) 216 | 217 | # if time's up... or if end token is reached then copy beams 218 | for vix in range(bdash): 219 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1: 220 | final_beam = { 221 | 'seq': beam_seq_table[divm][:, vix].clone(), 222 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), 223 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), 224 | 'p': beam_logprobs_sum_table[divm][vix].item() 225 | } 226 | if max_ppl: 227 | final_beam['p'] = final_beam['p'] / (t-divm+1) 228 | done_beams_table[divm].append(final_beam) 229 | # don't continue beams from finished sequences 230 | beam_logprobs_sum_table[divm][vix] = -1000 231 | 232 | # move the current group one step forward in time 233 | 234 | it = beam_seq_table[divm][t-divm] 235 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]])) 236 | 237 | # all beams are sorted by their log-probabilities 238 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] 239 | done_beams = reduce(lambda a,b:a+b, done_beams_table) 240 | return done_beams 241 | -------------------------------------------------------------------------------- /models/AttModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | # However, it may not be identical to the author's architecture. 
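# Shape conventions shared by the models in this file (see AttModel._forward and
# AttModel._sample below):
#   fc_feats:  (batch, fc_feat_size)               pooled image feature
#   att_feats: (batch, num_regions, att_feat_size) spatial / bottom-up region features
#   att_masks: (batch, num_regions) or None, with 1 for valid regions and 0 for padding
#   seq:       (batch, T) word indices, where 0 is the END/padding token
# Training goes through model(fc_feats, att_feats, seq, att_masks) (the default
# mode='forward'), and sampling through model(fc_feats, att_feats, att_masks=att_masks,
# mode='sample'), optionally with a dict of sampling options passed as opt, as dispatched
# by CaptionModel.forward.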
15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import misc.utils as utils 24 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 25 | 26 | from .CaptionModel import CaptionModel 27 | 28 | def sort_pack_padded_sequence(input, lengths): 29 | sorted_lengths, indices = torch.sort(lengths, descending=True) 30 | tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) 31 | inv_ix = indices.clone() 32 | inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) 33 | return tmp, inv_ix 34 | 35 | def pad_unsort_packed_sequence(input, inv_ix): 36 | tmp, _ = pad_packed_sequence(input, batch_first=True) 37 | tmp = tmp[inv_ix] 38 | return tmp 39 | 40 | def pack_wrapper(module, att_feats, att_masks): 41 | if att_masks is not None: 42 | packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) 43 | return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) 44 | else: 45 | return module(att_feats) 46 | 47 | class AttModel(CaptionModel): 48 | def __init__(self, opt): 49 | super(AttModel, self).__init__() 50 | self.vocab_size = opt.vocab_size 51 | self.input_encoding_size = opt.input_encoding_size 52 | #self.rnn_type = opt.rnn_type 53 | self.rnn_size = opt.rnn_size 54 | self.num_layers = opt.num_layers 55 | self.drop_prob_lm = opt.drop_prob_lm 56 | self.seq_length = opt.seq_length 57 | self.fc_feat_size = opt.fc_feat_size 58 | self.att_feat_size = opt.att_feat_size 59 | self.att_hid_size = opt.att_hid_size 60 | 61 | self.use_bn = getattr(opt, 'use_bn', 0) 62 | 63 | self.ss_prob = 0.0 # Schedule sampling probability 64 | 65 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 66 | nn.ReLU(), 67 | nn.Dropout(self.drop_prob_lm)) 68 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 69 | nn.ReLU(), 70 | nn.Dropout(self.drop_prob_lm)) 71 | self.att_embed = nn.Sequential(*( 72 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ 73 | (nn.Linear(self.att_feat_size, self.rnn_size), 74 | nn.ReLU(), 75 | nn.Dropout(self.drop_prob_lm))+ 76 | ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) 77 | 78 | self.logit_layers = getattr(opt, 'logit_layers', 1) 79 | if self.logit_layers == 1: 80 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 81 | else: 82 | self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] 83 | self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) 84 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 85 | 86 | def init_hidden(self, bsz): 87 | weight = next(self.parameters()) 88 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 89 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 90 | 91 | def clip_att(self, att_feats, att_masks): 92 | # Clip the length of att_masks and att_feats to the maximum length 93 | if att_masks is not None: 94 | max_len = att_masks.data.long().sum(1).max() 95 | att_feats = att_feats[:, :max_len].contiguous() 96 | att_masks = att_masks[:, :max_len].contiguous() 97 | return att_feats, att_masks 98 | 99 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 100 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 101 | 102 | 
# embed fc and att feats 103 | fc_feats = self.fc_embed(fc_feats) 104 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 105 | 106 | # Project the attention feats first to reduce memory and computation comsumptions. 107 | p_att_feats = self.ctx2att(att_feats) 108 | 109 | return fc_feats, att_feats, p_att_feats, att_masks 110 | 111 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 112 | batch_size = fc_feats.size(0) 113 | state = self.init_hidden(batch_size) 114 | 115 | outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1) 116 | 117 | # Prepare the features 118 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 119 | # pp_att_feats is used for attention, we cache it in advance to reduce computation cost 120 | 121 | for i in range(seq.size(1) - 1): 122 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 123 | sample_prob = fc_feats.new(batch_size).uniform_(0, 1) 124 | sample_mask = sample_prob < self.ss_prob 125 | if sample_mask.sum() == 0: 126 | it = seq[:, i].clone() 127 | else: 128 | sample_ind = sample_mask.nonzero().view(-1) 129 | it = seq[:, i].data.clone() 130 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 131 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 132 | # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 133 | prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) 134 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 135 | else: 136 | it = seq[:, i].clone() 137 | # break if all the sequences end 138 | if i >= 1 and seq[:, i].sum() == 0: 139 | break 140 | 141 | output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) 142 | outputs[:, i] = output 143 | 144 | return outputs 145 | 146 | def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state): 147 | # 'it' contains a word index 148 | xt = self.embed(it) 149 | 150 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) 151 | logprobs = F.log_softmax(self.logit(output), dim=1) 152 | 153 | return logprobs, state 154 | 155 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 156 | beam_size = opt.get('beam_size', 10) 157 | batch_size = fc_feats.size(0) 158 | 159 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 160 | 161 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 162 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 163 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 164 | # lets process every image independently for now, for simplicity 165 | 166 | self.done_beams = [[] for _ in range(batch_size)] 167 | for k in range(batch_size): 168 | state = self.init_hidden(beam_size) 169 | tmp_fc_feats = p_fc_feats[k:k+1].expand(beam_size, p_fc_feats.size(1)) 170 | tmp_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 171 | tmp_p_att_feats = pp_att_feats[k:k+1].expand(*((beam_size,)+pp_att_feats.size()[1:])).contiguous() 172 | tmp_att_masks = p_att_masks[k:k+1].expand(*((beam_size,)+p_att_masks.size()[1:])).contiguous() if att_masks is not None else None 173 | 174 | for t in range(1): 175 | if t == 0: # input 176 | it = fc_feats.new_zeros([beam_size], dtype=torch.long) 177 | 178 | logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) 179 | 180 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) 181 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 182 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 183 | # return the samples and their log likelihoods 184 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 185 | 186 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 187 | 188 | sample_max = opt.get('sample_max', 1) 189 | beam_size = opt.get('beam_size', 1) 190 | temperature = opt.get('temperature', 1.0) 191 | decoding_constraint = opt.get('decoding_constraint', 0) 192 | if beam_size > 1: 193 | return self._sample_beam(fc_feats, att_feats, att_masks, opt) 194 | 195 | batch_size = fc_feats.size(0) 196 | state = self.init_hidden(batch_size) 197 | 198 | p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) 199 | 200 | seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) 201 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 202 | for t in range(self.seq_length + 1): 203 | if t == 0: # input 204 | it = fc_feats.new_zeros(batch_size, dtype=torch.long) 205 | 206 | logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state) 207 | 208 | if decoding_constraint and t > 0: 209 | tmp = logprobs.new_zeros(logprobs.size()) 210 | tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) 211 | logprobs = logprobs + tmp 212 | 213 | # sample the next word 214 | if t == self.seq_length: # skip if we achieve maximum length 215 | break 216 | if sample_max: 217 | sampleLogprobs, it = torch.max(logprobs.data, 1) 218 | it = it.view(-1).long() 219 | else: 220 | if temperature == 1.0: 221 | prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) 222 | else: 223 | # scale logprobs by temperature 224 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)) 225 | it = torch.multinomial(prob_prev, 1) 226 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 227 | it = it.view(-1).long() # and flatten indices for downstream processing 228 | 229 | # stop when all finished 230 | if t == 0: 231 | unfinished = it > 0 232 | else: 233 | unfinished = unfinished * (it > 0) 234 | it = it * unfinished.type_as(it) 235 | seq[:,t] = it 236 | seqLogprobs[:,t] = sampleLogprobs.view(-1) 237 | # quit loop if all 
sequences have finished 238 | if unfinished.sum() == 0: 239 | break 240 | 241 | return seq, seqLogprobs 242 | 243 | class AdaAtt_lstm(nn.Module): 244 | def __init__(self, opt, use_maxout=True): 245 | super(AdaAtt_lstm, self).__init__() 246 | self.input_encoding_size = opt.input_encoding_size 247 | #self.rnn_type = opt.rnn_type 248 | self.rnn_size = opt.rnn_size 249 | self.num_layers = opt.num_layers 250 | self.drop_prob_lm = opt.drop_prob_lm 251 | self.fc_feat_size = opt.fc_feat_size 252 | self.att_feat_size = opt.att_feat_size 253 | self.att_hid_size = opt.att_hid_size 254 | 255 | self.use_maxout = use_maxout 256 | 257 | # Build a LSTM 258 | self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size) 259 | self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) 260 | 261 | self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)]) 262 | self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)]) 263 | 264 | # Layers for getting the fake region 265 | if self.num_layers == 1: 266 | self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size) 267 | self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size) 268 | else: 269 | self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size) 270 | self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size) 271 | 272 | 273 | def forward(self, xt, img_fc, state): 274 | 275 | hs = [] 276 | cs = [] 277 | for L in range(self.num_layers): 278 | # c,h from previous timesteps 279 | prev_h = state[0][L] 280 | prev_c = state[1][L] 281 | # the input to this layer 282 | if L == 0: 283 | x = xt 284 | i2h = self.w2h(x) + self.v2h(img_fc) 285 | else: 286 | x = hs[-1] 287 | x = F.dropout(x, self.drop_prob_lm, self.training) 288 | i2h = self.i2h[L-1](x) 289 | 290 | all_input_sums = i2h+self.h2h[L](prev_h) 291 | 292 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 293 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 294 | # decode the gates 295 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 296 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 297 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 298 | # decode the write inputs 299 | if not self.use_maxout: 300 | in_transform = F.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size)) 301 | else: 302 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) 303 | in_transform = torch.max(\ 304 | in_transform.narrow(1, 0, self.rnn_size), 305 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 306 | # perform the LSTM update 307 | next_c = forget_gate * prev_c + in_gate * in_transform 308 | # gated cells form the output 309 | tanh_nex_c = F.tanh(next_c) 310 | next_h = out_gate * tanh_nex_c 311 | if L == self.num_layers-1: 312 | if L == 0: 313 | i2h = self.r_w2h(x) + self.r_v2h(img_fc) 314 | else: 315 | i2h = self.r_i2h(x) 316 | n5 = i2h+self.r_h2h(prev_h) 317 | fake_region = F.sigmoid(n5) * tanh_nex_c 318 | 319 | cs.append(next_c) 320 | hs.append(next_h) 321 | 322 | # set up the decoder 323 | top_h = hs[-1] 324 | top_h = F.dropout(top_h, self.drop_prob_lm, self.training) 325 | fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training) 326 | 327 | state = (torch.cat([_.unsqueeze(0) for _ in hs], 0), 328 | torch.cat([_.unsqueeze(0) for _ in cs], 0)) 329 | return top_h, fake_region, state 330 | 331 | class AdaAtt_attention(nn.Module): 332 | def 
__init__(self, opt): 333 | super(AdaAtt_attention, self).__init__() 334 | self.input_encoding_size = opt.input_encoding_size 335 | #self.rnn_type = opt.rnn_type 336 | self.rnn_size = opt.rnn_size 337 | self.drop_prob_lm = opt.drop_prob_lm 338 | self.att_hid_size = opt.att_hid_size 339 | 340 | # fake region embed 341 | self.fr_linear = nn.Sequential( 342 | nn.Linear(self.rnn_size, self.input_encoding_size), 343 | nn.ReLU(), 344 | nn.Dropout(self.drop_prob_lm)) 345 | self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 346 | 347 | # h out embed 348 | self.ho_linear = nn.Sequential( 349 | nn.Linear(self.rnn_size, self.input_encoding_size), 350 | nn.Tanh(), 351 | nn.Dropout(self.drop_prob_lm)) 352 | self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 353 | 354 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 355 | self.att2h = nn.Linear(self.rnn_size, self.rnn_size) 356 | 357 | def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None): 358 | 359 | # View into three dimensions 360 | att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size 361 | conv_feat = conv_feat.view(-1, att_size, self.rnn_size) 362 | conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size) 363 | 364 | # view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num 365 | fake_region = self.fr_linear(fake_region) 366 | fake_region_embed = self.fr_embed(fake_region) 367 | 368 | h_out_linear = self.ho_linear(h_out) 369 | h_out_embed = self.ho_embed(h_out_linear) 370 | 371 | txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1)) 372 | 373 | img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1) 374 | img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1) 375 | 376 | hA = F.tanh(img_all_embed + txt_replicate) 377 | hA = F.dropout(hA,self.drop_prob_lm, self.training) 378 | 379 | hAflat = self.alpha_net(hA.view(-1, self.att_hid_size)) 380 | PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1) 381 | 382 | if att_masks is not None: 383 | att_masks = att_masks.view(-1, att_size) 384 | PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step. 
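# att_masks[:, :1] is reused for the prepended fake-region (sentinel) slot, which is
# assumed to be valid; the weights are renormalized below so that masked regions
# receive zero attention.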
385 | PI = PI / PI.sum(1, keepdim=True) 386 | 387 | visAtt = torch.bmm(PI.unsqueeze(1), img_all) 388 | visAttdim = visAtt.squeeze(1) 389 | 390 | atten_out = visAttdim + h_out_linear 391 | 392 | h = F.tanh(self.att2h(atten_out)) 393 | h = F.dropout(h, self.drop_prob_lm, self.training) 394 | return h 395 | 396 | class AdaAttCore(nn.Module): 397 | def __init__(self, opt, use_maxout=False): 398 | super(AdaAttCore, self).__init__() 399 | self.lstm = AdaAtt_lstm(opt, use_maxout) 400 | self.attention = AdaAtt_attention(opt) 401 | 402 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 403 | h_out, p_out, state = self.lstm(xt, fc_feats, state) 404 | atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks) 405 | return atten_out, state 406 | 407 | class TopDownCore(nn.Module): 408 | def __init__(self, opt, use_maxout=False): 409 | super(TopDownCore, self).__init__() 410 | self.drop_prob_lm = opt.drop_prob_lm 411 | 412 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 413 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 414 | self.attention = Attention(opt) 415 | 416 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 417 | prev_h = state[0][-1] 418 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) 419 | 420 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) 421 | 422 | att = self.attention(h_att, att_feats, p_att_feats, att_masks) 423 | 424 | lang_lstm_input = torch.cat([att, h_att], 1) 425 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 426 | 427 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 428 | 429 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 430 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 431 | 432 | return output, state 433 | 434 | 435 | ############################################################################ 436 | # Notice: 437 | # StackAtt and DenseAtt are models that I randomly designed. 438 | # They are not related to any paper. 
439 | ############################################################################ 440 | 441 | from .FCModel import LSTMCore 442 | class StackAttCore(nn.Module): 443 | def __init__(self, opt, use_maxout=False): 444 | super(StackAttCore, self).__init__() 445 | self.drop_prob_lm = opt.drop_prob_lm 446 | 447 | # self.att0 = Attention(opt) 448 | self.att1 = Attention(opt) 449 | self.att2 = Attention(opt) 450 | 451 | opt_input_encoding_size = opt.input_encoding_size 452 | opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size 453 | self.lstm0 = LSTMCore(opt) # att_feat + word_embedding 454 | opt.input_encoding_size = opt.rnn_size * 2 455 | self.lstm1 = LSTMCore(opt) 456 | self.lstm2 = LSTMCore(opt) 457 | opt.input_encoding_size = opt_input_encoding_size 458 | 459 | # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) 460 | self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) 461 | 462 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 463 | # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) 464 | h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) 465 | att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) 466 | h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) 467 | att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) 468 | h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]]) 469 | 470 | return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] 471 | 472 | class DenseAttCore(nn.Module): 473 | def __init__(self, opt, use_maxout=False): 474 | super(DenseAttCore, self).__init__() 475 | self.drop_prob_lm = opt.drop_prob_lm 476 | 477 | # self.att0 = Attention(opt) 478 | self.att1 = Attention(opt) 479 | self.att2 = Attention(opt) 480 | 481 | opt_input_encoding_size = opt.input_encoding_size 482 | opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size 483 | self.lstm0 = LSTMCore(opt) # att_feat + word_embedding 484 | opt.input_encoding_size = opt.rnn_size * 2 485 | self.lstm1 = LSTMCore(opt) 486 | self.lstm2 = LSTMCore(opt) 487 | opt.input_encoding_size = opt_input_encoding_size 488 | 489 | # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) 490 | self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) 491 | 492 | # fuse h_0 and h_1 493 | self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size), 494 | nn.ReLU(), 495 | nn.Dropout(opt.drop_prob_lm)) 496 | # fuse h_0, h_1 and h_2 497 | self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size), 498 | nn.ReLU(), 499 | nn.Dropout(opt.drop_prob_lm)) 500 | 501 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 502 | # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) 503 | h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) 504 | att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) 505 | h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) 506 | att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) 507 | h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]]) 508 | 509 | return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] 510 | 511 | class Attention(nn.Module): 512 | def __init__(self, opt): 513 | super(Attention, self).__init__() 514 | 
self.rnn_size = opt.rnn_size 515 | self.att_hid_size = opt.att_hid_size 516 | 517 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 518 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 519 | 520 | def forward(self, h, att_feats, p_att_feats, att_masks=None): 521 | # The p_att_feats here is already projected 522 | att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1) 523 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 524 | 525 | att_h = self.h2att(h) # batch * att_hid_size 526 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 527 | dot = att + att_h # batch * att_size * att_hid_size 528 | dot = F.tanh(dot) # batch * att_size * att_hid_size 529 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 530 | dot = self.alpha_net(dot) # (batch * att_size) * 1 531 | dot = dot.view(-1, att_size) # batch * att_size 532 | 533 | weight = F.softmax(dot, dim=1) # batch * att_size 534 | if att_masks is not None: 535 | weight = weight * att_masks.view(-1, att_size).float() 536 | weight = weight / weight.sum(1, keepdim=True) # normalize to 1 537 | att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size 538 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 539 | 540 | return att_res 541 | 542 | class Att2in2Core(nn.Module): 543 | def __init__(self, opt): 544 | super(Att2in2Core, self).__init__() 545 | self.input_encoding_size = opt.input_encoding_size 546 | #self.rnn_type = opt.rnn_type 547 | self.rnn_size = opt.rnn_size 548 | #self.num_layers = opt.num_layers 549 | self.drop_prob_lm = opt.drop_prob_lm 550 | self.fc_feat_size = opt.fc_feat_size 551 | self.att_feat_size = opt.att_feat_size 552 | self.att_hid_size = opt.att_hid_size 553 | 554 | # Build a LSTM 555 | self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size) 556 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 557 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 558 | self.dropout = nn.Dropout(self.drop_prob_lm) 559 | 560 | self.attention = Attention(opt) 561 | 562 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 563 | att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) 564 | 565 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 566 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 567 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 568 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 569 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 570 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 571 | 572 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ 573 | self.a2c(att_res) 574 | in_transform = torch.max(\ 575 | in_transform.narrow(1, 0, self.rnn_size), 576 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 577 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 578 | next_h = out_gate * F.tanh(next_c) 579 | 580 | output = self.dropout(next_h) 581 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 582 | return output, state 583 | 584 | class Att2inCore(Att2in2Core): 585 | def __init__(self, opt): 586 | super(Att2inCore, self).__init__(opt) 587 | del self.a2c 588 | self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size) 589 | 590 | """ 591 | Note this is my attempt to replicate att2all model in self-critical paper. 592 | However, this is not a correct replication actually. Will fix it. 
593 | """ 594 | class Att2all2Core(nn.Module): 595 | def __init__(self, opt): 596 | super(Att2all2Core, self).__init__() 597 | self.input_encoding_size = opt.input_encoding_size 598 | #self.rnn_type = opt.rnn_type 599 | self.rnn_size = opt.rnn_size 600 | #self.num_layers = opt.num_layers 601 | self.drop_prob_lm = opt.drop_prob_lm 602 | self.fc_feat_size = opt.fc_feat_size 603 | self.att_feat_size = opt.att_feat_size 604 | self.att_hid_size = opt.att_hid_size 605 | 606 | # Build a LSTM 607 | self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 608 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 609 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 610 | self.dropout = nn.Dropout(self.drop_prob_lm) 611 | 612 | self.attention = Attention(opt) 613 | 614 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): 615 | att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) 616 | 617 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res) 618 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 619 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 620 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 621 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 622 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 623 | 624 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) 625 | in_transform = torch.max(\ 626 | in_transform.narrow(1, 0, self.rnn_size), 627 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 628 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 629 | next_h = out_gate * F.tanh(next_c) 630 | 631 | output = self.dropout(next_h) 632 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 633 | return output, state 634 | 635 | class AdaAttModel(AttModel): 636 | def __init__(self, opt): 637 | super(AdaAttModel, self).__init__(opt) 638 | self.core = AdaAttCore(opt) 639 | 640 | # AdaAtt with maxout lstm 641 | class AdaAttMOModel(AttModel): 642 | def __init__(self, opt): 643 | super(AdaAttMOModel, self).__init__(opt) 644 | self.core = AdaAttCore(opt, True) 645 | 646 | class Att2in2Model(AttModel): 647 | def __init__(self, opt): 648 | super(Att2in2Model, self).__init__(opt) 649 | self.core = Att2in2Core(opt) 650 | delattr(self, 'fc_embed') 651 | self.fc_embed = lambda x : x 652 | 653 | class Att2all2Model(AttModel): 654 | def __init__(self, opt): 655 | super(Att2all2Model, self).__init__(opt) 656 | self.core = Att2all2Core(opt) 657 | delattr(self, 'fc_embed') 658 | self.fc_embed = lambda x : x 659 | 660 | class TopDownModel(AttModel): 661 | def __init__(self, opt): 662 | super(TopDownModel, self).__init__(opt) 663 | self.num_layers = 2 664 | self.core = TopDownCore(opt) 665 | 666 | class StackAttModel(AttModel): 667 | def __init__(self, opt): 668 | super(StackAttModel, self).__init__(opt) 669 | self.num_layers = 3 670 | self.core = StackAttCore(opt) 671 | 672 | class DenseAttModel(AttModel): 673 | def __init__(self, opt): 674 | super(DenseAttModel, self).__init__(opt) 675 | self.num_layers = 3 676 | self.core = DenseAttCore(opt) 677 | 678 | class Att2inModel(AttModel): 679 | def __init__(self, opt): 680 | super(Att2inModel, self).__init__(opt) 681 | del self.embed, self.fc_embed, self.att_embed 682 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 683 | self.fc_embed = self.att_embed = lambda x: x 684 | del self.ctx2att 685 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 
686 | self.core = Att2inCore(opt) 687 | self.init_weights() 688 | 689 | def init_weights(self): 690 | initrange = 0.1 691 | self.embed.weight.data.uniform_(-initrange, initrange) 692 | self.logit.bias.data.fill_(0) 693 | self.logit.weight.data.uniform_(-initrange, initrange) 694 | -------------------------------------------------------------------------------- /models/CaptionModel.py: -------------------------------------------------------------------------------- 1 | # This file contains ShowAttendTell and AllImg model 2 | 3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention 4 | # https://arxiv.org/abs/1502.03044 5 | 6 | # AllImg is a model where 7 | # img feature is concatenated with word embedding at every time step as the input of lstm 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | 19 | class CaptionModel(nn.Module): 20 | def __init__(self): 21 | super(CaptionModel, self).__init__() 22 | 23 | # implements beam search 24 | # calls beam_step and returns the final set of beams 25 | # augments log-probabilities with diversity terms when number of groups > 1 26 | 27 | def forward(self, *args, **kwargs): 28 | mode = kwargs.get('mode', 'forward') 29 | if 'mode' in kwargs: 30 | del kwargs['mode'] 31 | return getattr(self, '_'+mode)(*args, **kwargs) 32 | 33 | def beam_search(self, init_state, init_logprobs, *args, **kwargs): 34 | 35 | # function computes the similarity score to be augmented 36 | def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): 37 | local_time = t - divm 38 | unaug_logprobsf = logprobsf.clone() 39 | for prev_choice in range(divm): 40 | prev_decisions = beam_seq_table[prev_choice][local_time] 41 | for sub_beam in range(bdash): 42 | for prev_labels in range(bdash): 43 | logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda 44 | return unaug_logprobsf 45 | 46 | # does one step of classical beam search 47 | 48 | def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): 49 | #INPUTS: 50 | #logprobsf: probabilities augmented after diversity 51 | #beam_size: obvious 52 | #t : time instant 53 | #beam_seq : tensor contanining the beams 54 | #beam_seq_logprobs: tensor contanining the beam logprobs 55 | #beam_logprobs_sum: tensor contanining joint logprobs 56 | #OUPUTS: 57 | #beam_seq : tensor containing the word indices of the decoded captions 58 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq 59 | #beam_logprobs_sum : joint log-probability of each beam 60 | 61 | ys,ix = torch.sort(logprobsf,1,True) 62 | candidates = [] 63 | cols = min(beam_size, ys.size(1)) 64 | rows = beam_size 65 | if t == 0: 66 | rows = 1 67 | for c in range(cols): # for each column (word, essentially) 68 | for q in range(rows): # for each beam expansion 69 | #compute logprob of expanding beam q with word in (sorted) position c 70 | local_logprob = ys[q,c].item() 71 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 72 | local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] 73 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob}) 74 | candidates = sorted(candidates, key=lambda x: -x['p']) 75 | 76 | new_state = 
[_.clone() for _ in state] 77 | #beam_seq_prev, beam_seq_logprobs_prev 78 | if t >= 1: 79 | #we''ll need these as reference when we fork beams around 80 | beam_seq_prev = beam_seq[:t].clone() 81 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() 82 | for vix in range(beam_size): 83 | v = candidates[vix] 84 | #fork beam index q into index vix 85 | if t >= 1: 86 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']] 87 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] 88 | #rearrange recurrent states 89 | for state_ix in range(len(new_state)): 90 | # copy over state in previous beam q to new beam at vix 91 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step 92 | #append new end terminal at the end of this beam 93 | beam_seq[t, vix] = v['c'] # c'th word is the continuation 94 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here 95 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 96 | state = new_state 97 | return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates 98 | 99 | # Start diverse_beam_search 100 | opt = kwargs['opt'] 101 | beam_size = opt.get('beam_size', 10) 102 | group_size = opt.get('group_size', 1) 103 | diversity_lambda = opt.get('diversity_lambda', 0.5) 104 | decoding_constraint = opt.get('decoding_constraint', 0) 105 | max_ppl = opt.get('max_ppl', 0) 106 | bdash = beam_size // group_size # beam per group 107 | 108 | # INITIALIZATIONS 109 | beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 110 | beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] 111 | beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] 112 | 113 | # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) 114 | done_beams_table = [[] for _ in range(group_size)] 115 | state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] 116 | logprobs_table = list(init_logprobs.chunk(group_size, 0)) 117 | # END INIT 118 | 119 | # Chunk elements in the args 120 | args = list(args) 121 | args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args] 122 | args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] 123 | 124 | for t in range(self.seq_length + group_size - 1): 125 | for divm in range(group_size): 126 | if t >= divm and t <= self.seq_length + divm - 1: 127 | # add diversity 128 | logprobsf = logprobs_table[divm].data.float() 129 | # suppress previous word 130 | if decoding_constraint and t-divm > 0: 131 | logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf')) 132 | # suppress UNK tokens in the decoding 133 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 134 | # diversity is added here 135 | # the function directly modifies the logprobsf values and hence, we need to return 136 | # the unaugmented ones for sorting the candidates in the end. 
# for historical 137 | # reasons :-) 138 | unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) 139 | 140 | # infer new beams 141 | beam_seq_table[divm],\ 142 | beam_seq_logprobs_table[divm],\ 143 | beam_logprobs_sum_table[divm],\ 144 | state_table[divm],\ 145 | candidates_divm = beam_step(logprobsf, 146 | unaug_logprobsf, 147 | bdash, 148 | t-divm, 149 | beam_seq_table[divm], 150 | beam_seq_logprobs_table[divm], 151 | beam_logprobs_sum_table[divm], 152 | state_table[divm]) 153 | 154 | # if time's up... or if end token is reached then copy beams 155 | for vix in range(bdash): 156 | if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1: 157 | final_beam = { 158 | 'seq': beam_seq_table[divm][:, vix].clone(), 159 | 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), 160 | 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(), 161 | 'p': beam_logprobs_sum_table[divm][vix].item() 162 | } 163 | if max_ppl: 164 | final_beam['p'] = final_beam['p'] / (t-divm+1) 165 | done_beams_table[divm].append(final_beam) 166 | # don't continue beams from finished sequences 167 | beam_logprobs_sum_table[divm][vix] = -1000 168 | 169 | # move the current group one step forward in time 170 | 171 | it = beam_seq_table[divm][t-divm] 172 | logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]])) 173 | 174 | # all beams are sorted by their log-probabilities 175 | done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] 176 | done_beams = reduce(lambda a,b:a+b, done_beams_table) 177 | return done_beams -------------------------------------------------------------------------------- /models/FCModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class LSTMCore(nn.Module): 14 | def __init__(self, opt): 15 | super(LSTMCore, self).__init__() 16 | self.input_encoding_size = opt.input_encoding_size 17 | self.rnn_size = opt.rnn_size 18 | self.drop_prob_lm = opt.drop_prob_lm 19 | 20 | # Build a LSTM 21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 23 | self.dropout = nn.Dropout(self.drop_prob_lm) 24 | 25 | def forward(self, xt, state): 26 | 27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 29 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 33 | 34 | in_transform = torch.max(\ 35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), 36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) 37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 38 | next_h = out_gate * F.tanh(next_c) 39 | 40 | output = self.dropout(next_h) 41 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 42 | return output, state 43 | 44 | class FCModel(CaptionModel): 45 | def __init__(self, opt): 46 | super(FCModel, self).__init__() 47 | 
self.vocab_size = opt.vocab_size 48 | self.input_encoding_size = opt.input_encoding_size 49 | self.rnn_type = opt.rnn_type 50 | self.rnn_size = opt.rnn_size 51 | self.num_layers = opt.num_layers 52 | self.drop_prob_lm = opt.drop_prob_lm 53 | self.seq_length = opt.seq_length 54 | self.fc_feat_size = opt.fc_feat_size 55 | 56 | self.ss_prob = 0.0 # Schedule sampling probability 57 | 58 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 59 | self.core = LSTMCore(opt) 60 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 61 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 62 | 63 | self.init_weights() 64 | 65 | def init_weights(self): 66 | initrange = 0.1 67 | self.embed.weight.data.uniform_(-initrange, initrange) 68 | self.logit.bias.data.fill_(0) 69 | self.logit.weight.data.uniform_(-initrange, initrange) 70 | 71 | def init_hidden(self, bsz): 72 | weight = next(self.parameters()) 73 | if self.rnn_type == 'lstm': 74 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 75 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 76 | else: 77 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 78 | 79 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 80 | batch_size = fc_feats.size(0) 81 | state = self.init_hidden(batch_size) 82 | outputs = [] 83 | 84 | for i in range(seq.size(1)): 85 | if i == 0: 86 | xt = self.img_embed(fc_feats) 87 | else: 88 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 89 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 90 | sample_mask = sample_prob < self.ss_prob 91 | if sample_mask.sum() == 0: 92 | it = seq[:, i-1].clone() 93 | else: 94 | sample_ind = sample_mask.nonzero().view(-1) 95 | it = seq[:, i-1].data.clone() 96 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 97 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 98 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 99 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 100 | else: 101 | it = seq[:, i-1].clone() 102 | # break if all the sequences end 103 | if i >= 2 and seq[:, i-1].sum() == 0: 104 | break 105 | xt = self.embed(it) 106 | 107 | output, state = self.core(xt, state) 108 | output = F.log_softmax(self.logit(output), dim=1) 109 | outputs.append(output) 110 | 111 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 112 | 113 | def get_logprobs_state(self, it, state): 114 | # 'it' is contains a word index 115 | xt = self.embed(it) 116 | 117 | output, state = self.core(xt, state) 118 | logprobs = F.log_softmax(self.logit(output), dim=1) 119 | 120 | return logprobs, state 121 | 122 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 123 | beam_size = opt.get('beam_size', 10) 124 | batch_size = fc_feats.size(0) 125 | 126 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 127 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 128 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 129 | # lets process every image independently for now, for simplicity 130 | 131 | self.done_beams = [[] for _ in range(batch_size)] 132 | for k in range(batch_size): 133 | state = self.init_hidden(beam_size) 134 | for t in range(2): 135 | if t == 0: 136 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 137 | elif t == 1: # input 138 | it = fc_feats.data.new(beam_size).long().zero_() 139 | xt = self.embed(it) 140 | 141 | output, state = self.core(xt, state) 142 | logprobs = F.log_softmax(self.logit(output), dim=1) 143 | 144 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 145 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 146 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 147 | # return the samples and their log likelihoods 148 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 149 | 150 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 151 | sample_max = opt.get('sample_max', 1) 152 | beam_size = opt.get('beam_size', 1) 153 | temperature = opt.get('temperature', 1.0) 154 | if beam_size > 1: 155 | return self.sample_beam(fc_feats, att_feats, opt) 156 | 157 | batch_size = fc_feats.size(0) 158 | state = self.init_hidden(batch_size) 159 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 160 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 161 | for t in range(self.seq_length + 2): 162 | if t == 0: 163 | xt = self.img_embed(fc_feats) 164 | else: 165 | if t == 1: # input 166 | it = fc_feats.data.new(batch_size).long().zero_() 167 | xt = self.embed(it) 168 | 169 | output, state = self.core(xt, state) 170 | logprobs = F.log_softmax(self.logit(output), dim=1) 171 | 172 | # sample the next_word 173 | if t == self.seq_length + 1: # skip if we achieve maximum length 174 | break 175 | if sample_max: 176 | sampleLogprobs, it = torch.max(logprobs.data, 1) 177 | it = it.view(-1).long() 178 | else: 179 | if temperature == 1.0: 180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 181 | else: 182 | # scale logprobs by temperature 183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 184 | it = torch.multinomial(prob_prev, 1).cuda() 185 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 186 | it = it.view(-1).long() # and flatten indices for downstream processing 187 | 188 | if t >= 1: 189 | # stop when all finished 190 | if t == 1: 191 | unfinished = it > 0 192 | else: 193 | unfinished = unfinished * (it > 0) 194 | it = it * unfinished.type_as(it) 195 | seq[:,t-1] = it #seq[t] the input of t+2 time step 196 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 197 | if unfinished.sum() == 0: 198 | break 199 | 200 | return seq, seqLogprobs 201 | -------------------------------------------------------------------------------- /models/OldModel.py: -------------------------------------------------------------------------------- 1 | # This file contains ShowAttendTell and AllImg model 2 | 3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention 4 | # https://arxiv.org/abs/1502.03044 5 | 6 | # AllImg is a model where 7 | # img feature is concatenated with word embedding at every time step as the input of lstm 8 | from __future__ import absolute_import 9 | 
from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | from .CaptionModel import CaptionModel 19 | 20 | class OldModel(CaptionModel): 21 | def __init__(self, opt): 22 | super(OldModel, self).__init__() 23 | self.vocab_size = opt.vocab_size 24 | self.input_encoding_size = opt.input_encoding_size 25 | self.rnn_type = opt.rnn_type 26 | self.rnn_size = opt.rnn_size 27 | self.num_layers = opt.num_layers 28 | self.drop_prob_lm = opt.drop_prob_lm 29 | self.seq_length = opt.seq_length 30 | self.fc_feat_size = opt.fc_feat_size 31 | self.att_feat_size = opt.att_feat_size 32 | 33 | self.ss_prob = 0.0 # Schedule sampling probability 34 | 35 | self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size 36 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 37 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 38 | self.dropout = nn.Dropout(self.drop_prob_lm) 39 | 40 | self.init_weights() 41 | 42 | def init_weights(self): 43 | initrange = 0.1 44 | self.embed.weight.data.uniform_(-initrange, initrange) 45 | self.logit.bias.data.fill_(0) 46 | self.logit.weight.data.uniform_(-initrange, initrange) 47 | 48 | def init_hidden(self, fc_feats): 49 | image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1) 50 | if self.rnn_type == 'lstm': 51 | return (image_map, image_map) 52 | else: 53 | return image_map 54 | 55 | def forward(self, fc_feats, att_feats, seq): 56 | batch_size = fc_feats.size(0) 57 | state = self.init_hidden(fc_feats) 58 | 59 | outputs = [] 60 | 61 | for i in range(seq.size(1) - 1): 62 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 63 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 64 | sample_mask = sample_prob < self.ss_prob 65 | if sample_mask.sum() == 0: 66 | it = seq[:, i].clone() 67 | else: 68 | sample_ind = sample_mask.nonzero().view(-1) 69 | it = seq[:, i].data.clone() 70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 74 | else: 75 | it = seq[:, i].clone() 76 | # break if all the sequences end 77 | if i >= 1 and seq[:, i].sum() == 0: 78 | break 79 | 80 | xt = self.embed(it) 81 | 82 | output, state = self.core(xt, fc_feats, att_feats, state) 83 | output = F.log_softmax(self.logit(self.dropout(output)), dim=1) 84 | outputs.append(output) 85 | 86 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 87 | 88 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state): 89 | # 'it' contains a word index 90 | xt = self.embed(it) 91 | 92 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 93 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 94 | 95 | return logprobs, state 96 | 97 | def sample_beam(self, fc_feats, att_feats, opt={}): 98 | beam_size = opt.get('beam_size', 10) 99 | batch_size = fc_feats.size(0) 100 | 101 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 102 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 103 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 104 | # lets process every image independently for now, for simplicity 105 | 106 | self.done_beams = [[] for _ in range(batch_size)] 107 | for k in range(batch_size): 108 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size) 109 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 110 | 111 | state = self.init_hidden(tmp_fc_feats) 112 | 113 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 114 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 115 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 116 | done_beams = [] 117 | for t in range(1): 118 | if t == 0: # input 119 | it = fc_feats.data.new(beam_size).long().zero_() 120 | xt = self.embed(it) 121 | 122 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 123 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 124 | 125 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt) 126 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 127 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 128 | # return the samples and their log likelihoods 129 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 130 | 131 | def sample(self, fc_feats, att_feats, opt={}): 132 | sample_max = opt.get('sample_max', 1) 133 | beam_size = opt.get('beam_size', 1) 134 | temperature = opt.get('temperature', 1.0) 135 | if beam_size > 1: 136 | return self.sample_beam(fc_feats, att_feats, opt) 137 | 138 | batch_size = fc_feats.size(0) 139 | state = self.init_hidden(fc_feats) 140 | 141 | seq = [] 142 | seqLogprobs = [] 143 | for t in range(self.seq_length + 1): 144 | if t == 0: # input 145 | it = fc_feats.data.new(batch_size).long().zero_() 146 | elif sample_max: 147 | sampleLogprobs, it = torch.max(logprobs.data, 1) 148 | it = it.view(-1).long() 149 | else: 150 | if temperature == 1.0: 151 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 152 | else: 153 | # scale logprobs by temperature 154 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 155 | it = torch.multinomial(prob_prev, 1).cuda() 156 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 157 | it = it.view(-1).long() # and flatten indices for downstream processing 158 | 159 | xt = self.embed(it) 160 | 161 | if t >= 1: 162 | # stop when all finished 163 | if t == 1: 164 | unfinished = it > 0 165 | else: 166 | unfinished = unfinished * (it > 0) 167 | if unfinished.sum() == 0: 168 | break 169 | it = it * unfinished.type_as(it) 170 | seq.append(it) #seq[t] the input of t+2 time step 171 | seqLogprobs.append(sampleLogprobs.view(-1)) 172 | 173 | output, state = self.core(xt, fc_feats, att_feats, state) 174 | logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) 175 | 176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 177 | 178 | 179 | class ShowAttendTellCore(nn.Module): 180 | def __init__(self, opt): 181 | super(ShowAttendTellCore, self).__init__() 182 | self.input_encoding_size = opt.input_encoding_size 183 | self.rnn_type = opt.rnn_type 184 | self.rnn_size = opt.rnn_size 185 | self.num_layers = opt.num_layers 186 | self.drop_prob_lm = opt.drop_prob_lm 
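        # Illustrative sketch (not from the original code): with att_hid_size > 0,
        # forward() below computes standard soft attention, roughly
        #   e_i     = alpha_net(tanh(ctx2att(att_i) + h2att(h)))    # one score per region
        #   weight  = softmax(e)                                    # (batch, att_size)
        #   att_res = sum_i weight_i * att_feat_i                   # (batch, att_feat_size)
        # Shape walk-through under assumed sizes batch=10, att_size=196 (14x14 grid),
        # att_feat_size=2048, att_hid_size=512 (all hypothetical):
        #   att_feats (10, 196, 2048) --ctx2att--> (10, 196, 512)
        #   h         (10, rnn_size)  --h2att----> (10, 512) -> expanded to (10, 196, 512)
        #   dot       (10, 196) --softmax--> weight --bmm--> att_res (10, 2048)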
187 | self.fc_feat_size = opt.fc_feat_size 188 | self.att_feat_size = opt.att_feat_size 189 | self.att_hid_size = opt.att_hid_size 190 | 191 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size, 192 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 193 | 194 | if self.att_hid_size > 0: 195 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 196 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 197 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 198 | else: 199 | self.ctx2att = nn.Linear(self.att_feat_size, 1) 200 | self.h2att = nn.Linear(self.rnn_size, 1) 201 | 202 | def forward(self, xt, fc_feats, att_feats, state): 203 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size 204 | att = att_feats.view(-1, self.att_feat_size) 205 | if self.att_hid_size > 0: 206 | att = self.ctx2att(att) # (batch * att_size) * att_hid_size 207 | att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size 208 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size 209 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 210 | dot = att + att_h # batch * att_size * att_hid_size 211 | dot = F.tanh(dot) # batch * att_size * att_hid_size 212 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 213 | dot = self.alpha_net(dot) # (batch * att_size) * 1 214 | dot = dot.view(-1, att_size) # batch * att_size 215 | else: 216 | att = self.ctx2att(att)(att) # (batch * att_size) * 1 217 | att = att.view(-1, att_size) # batch * att_size 218 | att_h = self.h2att(state[0][-1]) # batch * 1 219 | att_h = att_h.expand_as(att) # batch * att_size 220 | dot = att_h + att # batch * att_size 221 | 222 | weight = F.softmax(dot, dim=1) 223 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size 224 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 225 | 226 | output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state) 227 | return output.squeeze(0), state 228 | 229 | class AllImgCore(nn.Module): 230 | def __init__(self, opt): 231 | super(AllImgCore, self).__init__() 232 | self.input_encoding_size = opt.input_encoding_size 233 | self.rnn_type = opt.rnn_type 234 | self.rnn_size = opt.rnn_size 235 | self.num_layers = opt.num_layers 236 | self.drop_prob_lm = opt.drop_prob_lm 237 | self.fc_feat_size = opt.fc_feat_size 238 | 239 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size, 240 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 241 | 242 | def forward(self, xt, fc_feats, att_feats, state): 243 | output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state) 244 | return output.squeeze(0), state 245 | 246 | class ShowAttendTellModel(OldModel): 247 | def __init__(self, opt): 248 | super(ShowAttendTellModel, self).__init__(opt) 249 | self.core = ShowAttendTellCore(opt) 250 | 251 | class AllImgModel(OldModel): 252 | def __init__(self, opt): 253 | super(AllImgModel, self).__init__(opt) 254 | self.core = AllImgCore(opt) 255 | 256 | -------------------------------------------------------------------------------- /models/ShowTellModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import 
torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class ShowTellModel(CaptionModel): 14 | def __init__(self, opt): 15 | super(ShowTellModel, self).__init__() 16 | self.vocab_size = opt.vocab_size 17 | self.input_encoding_size = opt.input_encoding_size 18 | self.rnn_type = opt.rnn_type 19 | self.rnn_size = opt.rnn_size 20 | self.num_layers = opt.num_layers 21 | self.drop_prob_lm = opt.drop_prob_lm 22 | self.seq_length = opt.seq_length 23 | self.fc_feat_size = opt.fc_feat_size 24 | 25 | self.ss_prob = 0.0 # Schedule sampling probability 26 | 27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 31 | self.dropout = nn.Dropout(self.drop_prob_lm) 32 | 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | initrange = 0.1 37 | self.embed.weight.data.uniform_(-initrange, initrange) 38 | self.logit.bias.data.fill_(0) 39 | self.logit.weight.data.uniform_(-initrange, initrange) 40 | 41 | def init_hidden(self, bsz): 42 | weight = next(self.parameters()).data 43 | if self.rnn_type == 'lstm': 44 | return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), 45 | weight.new_zeros(self.num_layers, bsz, self.rnn_size)) 46 | else: 47 | return weight.new_zeros(self.num_layers, bsz, self.rnn_size) 48 | 49 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 50 | batch_size = fc_feats.size(0) 51 | state = self.init_hidden(batch_size) 52 | outputs = [] 53 | 54 | for i in range(seq.size(1)): 55 | if i == 0: 56 | xt = self.img_embed(fc_feats) 57 | else: 58 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 59 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 60 | sample_mask = sample_prob < self.ss_prob 61 | if sample_mask.sum() == 0: 62 | it = seq[:, i-1].clone() 63 | else: 64 | sample_ind = sample_mask.nonzero().view(-1) 65 | it = seq[:, i-1].data.clone() 66 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 67 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 68 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 69 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 70 | else: 71 | it = seq[:, i-1].clone() 72 | # break if all the sequences end 73 | if i >= 2 and seq[:, i-1].data.sum() == 0: 74 | break 75 | xt = self.embed(it) 76 | 77 | output, state = self.core(xt.unsqueeze(0), state) 78 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 79 | outputs.append(output) 80 | 81 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 82 | 83 | def get_logprobs_state(self, it, state): 84 | # 'it' contains a word index 85 | xt = self.embed(it) 86 | 87 | output, state = self.core(xt.unsqueeze(0), state) 88 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 89 | 90 | return logprobs, state 91 | 92 | def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): 93 | beam_size = opt.get('beam_size', 10) 94 | batch_size = fc_feats.size(0) 95 | 96 | assert beam_size <= self.vocab_size + 1, 'lets assume this for 
now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 97 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 98 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 99 | # lets process every image independently for now, for simplicity 100 | 101 | self.done_beams = [[] for _ in range(batch_size)] 102 | for k in range(batch_size): 103 | state = self.init_hidden(beam_size) 104 | for t in range(2): 105 | if t == 0: 106 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 107 | elif t == 1: # input 108 | it = fc_feats.data.new(beam_size).long().zero_() 109 | xt = self.embed(it) 110 | 111 | output, state = self.core(xt.unsqueeze(0), state) 112 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 113 | 114 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 115 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 116 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 117 | # return the samples and their log likelihoods 118 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 119 | 120 | def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): 121 | sample_max = opt.get('sample_max', 1) 122 | beam_size = opt.get('beam_size', 1) 123 | temperature = opt.get('temperature', 1.0) 124 | if beam_size > 1: 125 | return self.sample_beam(fc_feats, att_feats, opt) 126 | 127 | batch_size = fc_feats.size(0) 128 | state = self.init_hidden(batch_size) 129 | seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) 130 | seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) 131 | for t in range(self.seq_length + 2): 132 | if t == 0: 133 | xt = self.img_embed(fc_feats) 134 | else: 135 | if t == 1: # input 136 | it = fc_feats.data.new(batch_size).long().zero_() 137 | xt = self.embed(it) 138 | 139 | output, state = self.core(xt.unsqueeze(0), state) 140 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) 141 | 142 | # sample the next word 143 | if t == self.seq_length + 1: # skip if we achieve maximum length 144 | break 145 | if sample_max: 146 | sampleLogprobs, it = torch.max(logprobs.data, 1) 147 | it = it.view(-1).long() 148 | else: 149 | if temperature == 1.0: 150 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 151 | else: 152 | # scale logprobs by temperature 153 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 154 | it = torch.multinomial(prob_prev, 1).cuda() 155 | sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions 156 | it = it.view(-1).long() # and flatten indices for downstream processing 157 | 158 | if t >= 1: 159 | # stop when all finished 160 | if t == 1: 161 | unfinished = it > 0 162 | else: 163 | unfinished = unfinished * (it > 0) 164 | it = it * unfinished.type_as(it) 165 | seq[:,t-1] = it #seq[t] the input of t+2 time step 166 | seqLogprobs[:,t-1] = sampleLogprobs.view(-1) 167 | if unfinished.sum() == 0: 168 | break 169 | 170 | return seq, seqLogprobs -------------------------------------------------------------------------------- /models/TransformerModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Transformer network 2 | # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html 3 | 4 | # The cfg name correspondance: 5 | # N=num_layers 6 | # d_model=input_encoding_size 7 | # 
d_ff=rnn_size 8 | # h is always 8 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import misc.utils as utils 18 | 19 | import copy 20 | import math 21 | import numpy as np 22 | 23 | from .CaptionModel import CaptionModel 24 | from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel 25 | 26 | class EncoderDecoder(nn.Module): 27 | """ 28 | A standard Encoder-Decoder architecture. Base for this and many 29 | other models. 30 | """ 31 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): 32 | super(EncoderDecoder, self).__init__() 33 | self.encoder = encoder 34 | self.decoder = decoder 35 | self.src_embed = src_embed 36 | self.tgt_embed = tgt_embed 37 | self.generator = generator 38 | 39 | def forward(self, src, tgt, src_mask, tgt_mask): 40 | "Take in and process masked src and target sequences." 41 | return self.decode(self.encode(src, src_mask), src_mask, 42 | tgt, tgt_mask) 43 | 44 | def encode(self, src, src_mask): 45 | return self.encoder(self.src_embed(src), src_mask) 46 | 47 | def decode(self, memory, src_mask, tgt, tgt_mask): 48 | return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 49 | 50 | class Generator(nn.Module): 51 | "Define standard linear + softmax generation step." 52 | def __init__(self, d_model, vocab): 53 | super(Generator, self).__init__() 54 | self.proj = nn.Linear(d_model, vocab) 55 | 56 | def forward(self, x): 57 | return F.log_softmax(self.proj(x), dim=-1) 58 | 59 | def clones(module, N): 60 | "Produce N identical layers." 61 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 62 | 63 | class Encoder(nn.Module): 64 | "Core encoder is a stack of N layers" 65 | def __init__(self, layer, N): 66 | super(Encoder, self).__init__() 67 | self.layers = clones(layer, N) 68 | self.norm = LayerNorm(layer.size) 69 | 70 | def forward(self, x, mask): 71 | "Pass the input (and mask) through each layer in turn." 72 | for layer in self.layers: 73 | x = layer(x, mask) 74 | return self.norm(x) 75 | 76 | class LayerNorm(nn.Module): 77 | "Construct a layernorm module (See citation for details)." 78 | def __init__(self, features, eps=1e-6): 79 | super(LayerNorm, self).__init__() 80 | self.a_2 = nn.Parameter(torch.ones(features)) 81 | self.b_2 = nn.Parameter(torch.zeros(features)) 82 | self.eps = eps 83 | 84 | def forward(self, x): 85 | mean = x.mean(-1, keepdim=True) 86 | std = x.std(-1, keepdim=True) 87 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 88 | 89 | class SublayerConnection(nn.Module): 90 | """ 91 | A residual connection followed by a layer norm. 92 | Note for code simplicity the norm is first as opposed to last. 93 | """ 94 | def __init__(self, size, dropout): 95 | super(SublayerConnection, self).__init__() 96 | self.norm = LayerNorm(size) 97 | self.dropout = nn.Dropout(dropout) 98 | 99 | def forward(self, x, sublayer): 100 | "Apply residual connection to any sublayer with the same size." 
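        # Illustrative note (not from the original code): this is the "pre-norm"
        # residual used throughout this file, i.e. roughly
        #   out = x + Dropout(Sublayer(LayerNorm(x)))
        # whereas the original Transformer paper writes LayerNorm(x + Sublayer(x));
        # the class docstring above flags this simplification.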
101 | return x + self.dropout(sublayer(self.norm(x))) 102 | 103 | class EncoderLayer(nn.Module): 104 | "Encoder is made up of self-attn and feed forward (defined below)" 105 | def __init__(self, size, self_attn, feed_forward, dropout): 106 | super(EncoderLayer, self).__init__() 107 | self.self_attn = self_attn 108 | self.feed_forward = feed_forward 109 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 110 | self.size = size 111 | 112 | def forward(self, x, mask): 113 | "Follow Figure 1 (left) for connections." 114 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 115 | return self.sublayer[1](x, self.feed_forward) 116 | 117 | class Decoder(nn.Module): 118 | "Generic N layer decoder with masking." 119 | def __init__(self, layer, N): 120 | super(Decoder, self).__init__() 121 | self.layers = clones(layer, N) 122 | self.norm = LayerNorm(layer.size) 123 | 124 | def forward(self, x, memory, src_mask, tgt_mask): 125 | for layer in self.layers: 126 | x = layer(x, memory, src_mask, tgt_mask) 127 | return self.norm(x) 128 | 129 | class DecoderLayer(nn.Module): 130 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 131 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 132 | super(DecoderLayer, self).__init__() 133 | self.size = size 134 | self.self_attn = self_attn 135 | self.src_attn = src_attn 136 | self.feed_forward = feed_forward 137 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 138 | 139 | def forward(self, x, memory, src_mask, tgt_mask): 140 | "Follow Figure 1 (right) for connections." 141 | m = memory 142 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 143 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 144 | return self.sublayer[2](x, self.feed_forward) 145 | 146 | def subsequent_mask(size): 147 | "Mask out subsequent positions." 148 | attn_shape = (1, size, size) 149 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 150 | return torch.from_numpy(subsequent_mask) == 0 151 | 152 | def attention(query, key, value, mask=None, dropout=None): 153 | "Compute 'Scaled Dot Product Attention'" 154 | d_k = query.size(-1) 155 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 156 | / math.sqrt(d_k) 157 | if mask is not None: 158 | scores = scores.masked_fill(mask == 0, -1e9) 159 | p_attn = F.softmax(scores, dim = -1) 160 | if dropout is not None: 161 | p_attn = dropout(p_attn) 162 | return torch.matmul(p_attn, value), p_attn 163 | 164 | class MultiHeadedAttention(nn.Module): 165 | def __init__(self, h, d_model, dropout=0.1): 166 | "Take in model size and number of heads." 167 | super(MultiHeadedAttention, self).__init__() 168 | assert d_model % h == 0 169 | # We assume d_v always equals d_k 170 | self.d_k = d_model // h 171 | self.h = h 172 | self.linears = clones(nn.Linear(d_model, d_model), 4) 173 | self.attn = None 174 | self.dropout = nn.Dropout(p=dropout) 175 | 176 | def forward(self, query, key, value, mask=None): 177 | "Implements Figure 2" 178 | if mask is not None: 179 | # Same mask applied to all h heads. 180 | mask = mask.unsqueeze(1) 181 | nbatches = query.size(0) 182 | 183 | # 1) Do all the linear projections in batch from d_model => h x d_k 184 | query, key, value = \ 185 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 186 | for l, x in zip(self.linears, (query, key, value))] 187 | 188 | # 2) Apply attention on all the projected vectors in batch. 
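        # Illustrative note (not from the original code): after step 1 each of
        # query/key/value has shape (nbatches, h, seq_len, d_k); with the assumed
        # defaults d_model=512, h=8 that is (nbatches, 8, seq_len, 64). attention()
        # then returns a (nbatches, 8, query_len, 64) context tensor together with
        # the (nbatches, 8, query_len, key_len) attention weights cached in self.attn.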
189 | x, self.attn = attention(query, key, value, mask=mask, 190 | dropout=self.dropout) 191 | 192 | # 3) "Concat" using a view and apply a final linear. 193 | x = x.transpose(1, 2).contiguous() \ 194 | .view(nbatches, -1, self.h * self.d_k) 195 | return self.linears[-1](x) 196 | 197 | class PositionwiseFeedForward(nn.Module): 198 | "Implements FFN equation." 199 | def __init__(self, d_model, d_ff, dropout=0.1): 200 | super(PositionwiseFeedForward, self).__init__() 201 | self.w_1 = nn.Linear(d_model, d_ff) 202 | self.w_2 = nn.Linear(d_ff, d_model) 203 | self.dropout = nn.Dropout(dropout) 204 | 205 | def forward(self, x): 206 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 207 | 208 | class Embeddings(nn.Module): 209 | def __init__(self, d_model, vocab): 210 | super(Embeddings, self).__init__() 211 | self.lut = nn.Embedding(vocab, d_model) 212 | self.d_model = d_model 213 | 214 | def forward(self, x): 215 | return self.lut(x) * math.sqrt(self.d_model) 216 | 217 | class PositionalEncoding(nn.Module): 218 | "Implement the PE function." 219 | def __init__(self, d_model, dropout, max_len=5000): 220 | super(PositionalEncoding, self).__init__() 221 | self.dropout = nn.Dropout(p=dropout) 222 | 223 | # Compute the positional encodings once in log space. 224 | pe = torch.zeros(max_len, d_model) 225 | position = torch.arange(0, max_len).unsqueeze(1).float() 226 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 227 | -(math.log(10000.0) / d_model)) 228 | pe[:, 0::2] = torch.sin(position * div_term) 229 | pe[:, 1::2] = torch.cos(position * div_term) 230 | pe = pe.unsqueeze(0) 231 | self.register_buffer('pe', pe) 232 | 233 | def forward(self, x): 234 | x = x + self.pe[:, :x.size(1)] 235 | return self.dropout(x) 236 | 237 | class TransformerModel(AttModel): 238 | 239 | def make_model(self, src_vocab, tgt_vocab, N=6, 240 | d_model=512, d_ff=2048, h=8, dropout=0.1): 241 | "Helper: Construct a model from hyperparameters." 242 | c = copy.deepcopy 243 | attn = MultiHeadedAttention(h, d_model) 244 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 245 | position = PositionalEncoding(d_model, dropout) 246 | model = EncoderDecoder( 247 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 248 | Decoder(DecoderLayer(d_model, c(attn), c(attn), 249 | c(ff), dropout), N), 250 | lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), 251 | nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), 252 | Generator(d_model, tgt_vocab)) 253 | 254 | # This was important from their code. 255 | # Initialize parameters with Glorot / fan_avg. 
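        # Illustrative note (not from the original code): the loop below applies
        # Xavier/Glorot uniform init to every weight matrix (p.dim() > 1), sampling
        # from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)); biases and LayerNorm
        # parameters are 1-D and therefore keep PyTorch's default initialization.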
256 | for p in model.parameters(): 257 | if p.dim() > 1: 258 | nn.init.xavier_uniform_(p) 259 | return model 260 | 261 | def __init__(self, opt): 262 | super(TransformerModel, self).__init__(opt) 263 | self.opt = opt 264 | # self.config = yaml.load(open(opt.config_file)) 265 | # d_model = self.input_encoding_size # 512 266 | 267 | delattr(self, 'att_embed') 268 | self.att_embed = nn.Sequential(*( 269 | ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ 270 | (nn.Linear(self.att_feat_size, self.input_encoding_size), 271 | nn.ReLU(), 272 | nn.Dropout(self.drop_prob_lm))+ 273 | ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ()))) 274 | 275 | delattr(self, 'embed') 276 | self.embed = lambda x : x 277 | delattr(self, 'fc_embed') 278 | self.fc_embed = lambda x : x 279 | delattr(self, 'logit') 280 | del self.ctx2att 281 | 282 | tgt_vocab = self.vocab_size + 1 283 | self.model = self.make_model(0, tgt_vocab, 284 | N=opt.num_layers, 285 | d_model=opt.input_encoding_size, 286 | d_ff=opt.rnn_size) 287 | 288 | def logit(self, x): # unsafe way 289 | return self.model.generator.proj(x) 290 | 291 | def init_hidden(self, bsz): 292 | return None 293 | 294 | def _prepare_feature(self, fc_feats, att_feats, att_masks): 295 | 296 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) 297 | memory = self.model.encode(att_feats, att_masks) 298 | 299 | return fc_feats[...,:1], att_feats[...,:1], memory, att_masks 300 | 301 | def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): 302 | att_feats, att_masks = self.clip_att(att_feats, att_masks) 303 | 304 | att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) 305 | 306 | if att_masks is None: 307 | att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) 308 | att_masks = att_masks.unsqueeze(-2) 309 | 310 | if seq is not None: 311 | # crop the last one 312 | seq = seq[:,:-1] 313 | seq_mask = (seq.data > 0) 314 | seq_mask[:,0] += 1 315 | 316 | seq_mask = seq_mask.unsqueeze(-2) 317 | seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) 318 | else: 319 | seq_mask = None 320 | 321 | return att_feats, seq, att_masks, seq_mask 322 | 323 | def _forward(self, fc_feats, att_feats, seq, att_masks=None): 324 | att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) 325 | 326 | out = self.model(att_feats, seq, att_masks, seq_mask) 327 | 328 | outputs = self.model.generator(out) 329 | return outputs 330 | # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 331 | 332 | def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): 333 | """ 334 | state = [ys.unsqueeze(0)] 335 | """ 336 | if state is None: 337 | ys = it.unsqueeze(1) 338 | else: 339 | ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) 340 | out = self.model.decode(memory, mask, 341 | ys, 342 | subsequent_mask(ys.size(1)) 343 | .to(memory.device)) 344 | return out[:, -1], [ys.unsqueeze(0)] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import misc.utils as utils 10 | import torch 11 | 12 | from .ShowTellModel import ShowTellModel 13 | from .FCModel import FCModel 14 | from .OldModel import ShowAttendTellModel, AllImgModel 15 | 
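# Illustrative sketch (not from the original code) of how setup() below is
# typically driven; it assumes opt.vocab_size and opt.seq_length have been filled
# in from the data loader, as the training script does (values are hypothetical):
#
#   import opts, models
#   opt = opts.parse_opt()
#   opt.vocab_size, opt.seq_length = 9487, 16
#   model = models.setup(opt).cuda()
#   model.train()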
from .AttModel import * 16 | from .TransformerModel import TransformerModel 17 | 18 | def setup(opt): 19 | 20 | if opt.caption_model == 'fc': 21 | model = FCModel(opt) 22 | if opt.caption_model == 'show_tell': 23 | model = ShowTellModel(opt) 24 | # Att2in model in self-critical 25 | elif opt.caption_model == 'att2in': 26 | model = Att2inModel(opt) 27 | # Att2in model with two-layer MLP img embedding and word embedding 28 | elif opt.caption_model == 'att2in2': 29 | model = Att2in2Model(opt) 30 | elif opt.caption_model == 'att2all2': 31 | model = Att2all2Model(opt) 32 | # Adaptive Attention model from Knowing when to look 33 | elif opt.caption_model == 'adaatt': 34 | model = AdaAttModel(opt) 35 | # Adaptive Attention with maxout lstm 36 | elif opt.caption_model == 'adaattmo': 37 | model = AdaAttMOModel(opt) 38 | # Top-down attention model 39 | elif opt.caption_model == 'topdown': 40 | model = TopDownModel(opt) 41 | # StackAtt 42 | elif opt.caption_model == 'stackatt': 43 | model = StackAttModel(opt) 44 | # DenseAtt 45 | elif opt.caption_model == 'denseatt': 46 | model = DenseAttModel(opt) 47 | # Transformer 48 | elif opt.caption_model == 'transformer': 49 | model = TransformerModel(opt) 50 | else: 51 | raise Exception("Caption model not supported: {}".format(opt.caption_model)) 52 | 53 | # check compatibility if training is continued from previously saved model 54 | if vars(opt).get('start_from', None) is not None: 55 | # check if all necessary files exist 56 | assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from 57 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from 58 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_opt(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--input_json', type=str, default='data/coco.json', 7 | help='path to the json file containing additional info and vocab') 8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', 9 | help='path to the directory containing the preprocessed fc feats') 10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', 11 | help='path to the directory containing the preprocessed att feats') 12 | parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', 13 | help='path to the directory containing the boxes of att feats') 14 | parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', 15 | help='path to the h5file containing the preprocessed dataset') 16 | parser.add_argument('--start_from', type=str, default=None, 17 | help="""continue training from saved model at this path. Path must contain files saved by previous training process: 18 | 'infos.pkl' : configuration; 19 | 'checkpoint' : paths to model file(s) (created by tf). 
20 | Note: this file contains absolute paths, be careful when moving files around; 21 | 'model.ckpt-*' : file(s) with model definition (created by tf) 22 | """) 23 | parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', 24 | help='Cached token file for calculating cider score during self critical training.') 25 | 26 | # Model settings 27 | parser.add_argument('--caption_model', type=str, default="show_tell", 28 | help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer') 29 | parser.add_argument('--rnn_size', type=int, default=512, 30 | help='size of the rnn in number of hidden nodes in each layer') 31 | parser.add_argument('--num_layers', type=int, default=1, 32 | help='number of layers in the RNN') 33 | parser.add_argument('--rnn_type', type=str, default='lstm', 34 | help='rnn, gru, or lstm') 35 | parser.add_argument('--input_encoding_size', type=int, default=512, 36 | help='the encoding size of each token in the vocabulary, and the image.') 37 | parser.add_argument('--att_hid_size', type=int, default=512, 38 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') 39 | parser.add_argument('--fc_feat_size', type=int, default=2048, 40 | help='2048 for resnet, 4096 for vgg') 41 | parser.add_argument('--att_feat_size', type=int, default=2048, 42 | help='2048 for resnet, 512 for vgg') 43 | parser.add_argument('--logit_layers', type=int, default=1, 44 | help='number of layers in the RNN') 45 | 46 | 47 | parser.add_argument('--use_bn', type=int, default=0, 48 | help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') 49 | 50 | # feature manipulation 51 | parser.add_argument('--norm_att_feat', type=int, default=0, 52 | help='If normalize attention features') 53 | parser.add_argument('--use_box', type=int, default=0, 54 | help='If use box features') 55 | parser.add_argument('--norm_box_feat', type=int, default=0, 56 | help='If use box, do we normalize box feature') 57 | 58 | # Optimization: General 59 | parser.add_argument('--max_epochs', type=int, default=-1, 60 | help='number of epochs') 61 | parser.add_argument('--batch_size', type=int, default=16, 62 | help='minibatch size') 63 | parser.add_argument('--grad_clip', type=float, default=0.1, #5., 64 | help='clip gradients at this value') 65 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, 66 | help='strength of dropout in the Language Model RNN') 67 | parser.add_argument('--self_critical_after', type=int, default=-1, 68 | help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)') 69 | parser.add_argument('--seq_per_img', type=int, default=5, 70 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') 71 | parser.add_argument('--beam_size', type=int, default=1, 72 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 73 | 74 | #Optimization: for the Language Model 75 | parser.add_argument('--optim', type=str, default='adam', 76 | help='what update to use? 
rmsprop|sgd|sgdmom|adagrad|adam') 77 | parser.add_argument('--learning_rate', type=float, default=4e-4, 78 | help='learning rate') 79 | parser.add_argument('--learning_rate_decay_start', type=int, default=-1, 80 | help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)') 81 | parser.add_argument('--learning_rate_decay_every', type=int, default=3, 82 | help='every how many iterations thereafter to drop LR?(in epoch)') 83 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, 84 | help='every how many iterations thereafter to drop LR?(in epoch)') 85 | parser.add_argument('--optim_alpha', type=float, default=0.9, 86 | help='alpha for adam') 87 | parser.add_argument('--optim_beta', type=float, default=0.999, 88 | help='beta used for adam') 89 | parser.add_argument('--optim_epsilon', type=float, default=1e-8, 90 | help='epsilon that goes into denominator for smoothing') 91 | parser.add_argument('--weight_decay', type=float, default=0, 92 | help='weight_decay') 93 | 94 | parser.add_argument('--scheduled_sampling_start', type=int, default=-1, 95 | help='at what iteration to start decay gt probability') 96 | parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, 97 | help='every how many iterations thereafter to gt probability') 98 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, 99 | help='How much to update the prob') 100 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, 101 | help='Maximum scheduled sampling prob.') 102 | 103 | 104 | # Evaluation/Checkpointing 105 | parser.add_argument('--val_images_use', type=int, default=3200, 106 | help='how many images to use when periodically evaluating the validation loss? (-1 = all)') 107 | parser.add_argument('--save_checkpoint_every', type=int, default=2500, 108 | help='how often to save a model checkpoint (in iterations)?') 109 | parser.add_argument('--checkpoint_path', type=str, default='save', 110 | help='directory to store checkpointed models') 111 | parser.add_argument('--language_eval', type=int, default=0, 112 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 113 | parser.add_argument('--losses_log_every', type=int, default=25, 114 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') 115 | parser.add_argument('--load_best_score', type=int, default=1, 116 | help='Do we load previous best score when resuming training.') 117 | 118 | # misc 119 | parser.add_argument('--id', type=str, default='', 120 | help='an id identifying this run/job. 
used in cross-val and appended when writing progress files') 121 | parser.add_argument('--train_only', type=int, default=0, 122 | help='if true then use 80k, else use 110k') 123 | 124 | 125 | # Reward 126 | parser.add_argument('--cider_reward_weight', type=float, default=1, 127 | help='The reward weight from cider') 128 | parser.add_argument('--bleu_reward_weight', type=float, default=0, 129 | help='The reward weight from bleu4') 130 | 131 | # Transformer 132 | parser.add_argument('--label_smoothing', type=float, default=0, 133 | help='') 134 | parser.add_argument('--noamopt', action='store_true', 135 | help='') 136 | parser.add_argument('--noamopt_warmup', type=int, default=2000, 137 | help='') 138 | parser.add_argument('--noamopt_factor', type=float, default=1, 139 | help='') 140 | 141 | parser.add_argument('--reduce_on_plateau', action='store_true', 142 | help='') 143 | 144 | args = parser.parse_args() 145 | 146 | # Check if args are valid 147 | assert args.rnn_size > 0, "rnn_size should be greater than 0" 148 | assert args.num_layers > 0, "num_layers should be greater than 0" 149 | assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" 150 | assert args.batch_size > 0, "batch_size should be greater than 0" 151 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" 152 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0" 153 | assert args.beam_size > 0, "beam_size should be greater than 0" 154 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" 155 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0" 156 | assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" 157 | assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1" 158 | assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1" 159 | 160 | return args -------------------------------------------------------------------------------- /scripts/copy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cp -r log_$1 log_$2 4 | cd log_$2 5 | mv infos_$1-best.pkl infos_$2-best.pkl 6 | mv infos_$1.pkl infos_$2.pkl 7 | cd ../ -------------------------------------------------------------------------------- /scripts/make_bu_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import base64 7 | import numpy as np 8 | import csv 9 | import sys 10 | import zlib 11 | import time 12 | import mmap 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | # output_dir 18 | parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') 19 | parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') 20 | 21 | args = parser.parse_args() 22 | 23 | csv.field_size_limit(sys.maxsize) 24 | 25 | 26 | FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] 27 | infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', 28 | 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ 29 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ 30 | 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] 31 | 32 | 
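# Illustrative note (not from the original code): the loop below turns every TSV
# row into three per-image files, roughly
#   <output_dir>_att/<image_id>.npz : 'feat' array, shape (num_boxes, 2048)
#   <output_dir>_fc/<image_id>.npy  : mean-pooled feature, shape (2048,)
#   <output_dir>_box/<image_id>.npy : box coordinates, shape (num_boxes, 4)
# The 2048/4 widths assume the released resnet101 Faster R-CNN features; the
# reshape only fixes num_boxes and infers the second dimension. Also note that
# os.makedirs() raises if an output directory already exists.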
os.makedirs(args.output_dir+'_att') 33 | os.makedirs(args.output_dir+'_fc') 34 | os.makedirs(args.output_dir+'_box') 35 | 36 | for infile in infiles: 37 | print('Reading ' + infile) 38 | with open(os.path.join(args.downloaded_feats, infile), "r+b") as tsv_in_file: 39 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 40 | for item in reader: 41 | item['image_id'] = int(item['image_id']) 42 | item['num_boxes'] = int(item['num_boxes']) 43 | for field in ['boxes', 'features']: 44 | item[field] = np.frombuffer(base64.decodestring(item[field]), 45 | dtype=np.float32).reshape((item['num_boxes'],-1)) 46 | np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) 47 | np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0)) 48 | np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /scripts/prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | from six.moves import cPickle 38 | import numpy as np 39 | import torch 40 | import torchvision.models as models 41 | import skimage.io 42 | 43 | from torchvision import transforms as trn 44 | preprocess = trn.Compose([ 45 | #trn.ToTensor(), 46 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 47 | ]) 48 | 49 | from misc.resnet_utils import myResnet 50 | import misc.resnet as resnet 51 | 52 | def main(params): 53 | net = getattr(resnet, params['model'])() 54 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 55 | my_resnet = myResnet(net) 56 | my_resnet.cuda() 57 | my_resnet.eval() 58 | 59 | imgs = json.load(open(params['input_json'], 'r')) 60 | imgs = imgs['images'] 61 | N = len(imgs) 62 | 63 | seed(123) # make reproducible 64 | 65 | dir_fc = params['output_dir']+'_fc' 66 | dir_att = params['output_dir']+'_att' 67 | if not os.path.isdir(dir_fc): 68 | os.mkdir(dir_fc) 69 | if not os.path.isdir(dir_att): 70 | os.mkdir(dir_att) 71 | 72 | for i,img in enumerate(imgs): 73 | # load the image 74 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 75 | # handle grayscale input images 76 | if len(I.shape) == 2: 77 | I = I[:,:,np.newaxis] 78 | I = np.concatenate((I,I,I), axis=2) 79 | 80 | I = I.astype('float32')/255.0 81 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 82 | I = preprocess(I) 83 | with torch.no_grad(): 84 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 85 | # write to pkl 86 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 87 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 88 | 89 | if i % 1000 == 0: 90 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 91 | print('wrote ', params['output_dir']) 92 | 93 | if __name__ == "__main__": 94 | 95 | parser = argparse.ArgumentParser() 96 | 97 | # input json 98 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 99 | parser.add_argument('--output_dir', default='data', help='output h5 file') 100 | 101 | # options 102 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 103 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 104 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 105 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 106 | 107 | args = parser.parse_args() 108 | params = vars(args) # convert to ordinary dict 109 | print('parsed input parameters:') 110 | print(json.dumps(params, indent = 2)) 111 | main(params) 112 | -------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 
6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | import numpy as np 38 | import torch 39 | import torchvision.models as models 40 | import skimage.io 41 | from PIL import Image 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if bad_count > 0: 82 | 
# additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | def encode_captions(imgs, params, wtoi): 96 | """ 97 | encode all captions into one large array, which will be 1-indexed. 98 | also produces label_start_ix and label_end_ix which store 1-indexed 99 | and inclusive (Lua-style) pointers to the first and last caption for 100 | each image in the dataset. 101 | """ 102 | 103 | max_length = params['max_length'] 104 | N = len(imgs) 105 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 106 | 107 | label_arrays = [] 108 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 109 | label_end_ix = np.zeros(N, dtype='uint32') 110 | label_length = np.zeros(M, dtype='uint32') 111 | caption_counter = 0 112 | counter = 1 113 | for i,img in enumerate(imgs): 114 | n = len(img['final_captions']) 115 | assert n > 0, 'error: some image has no captions' 116 | 117 | Li = np.zeros((n, max_length), dtype='uint32') 118 | for j,s in enumerate(img['final_captions']): 119 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 120 | caption_counter += 1 121 | for k,w in enumerate(s): 122 | if k < max_length: 123 | Li[j,k] = wtoi[w] 124 | 125 | # note: word indices are 1-indexed, and captions are padded with zeros 126 | label_arrays.append(Li) 127 | label_start_ix[i] = counter 128 | label_end_ix[i] = counter + n - 1 129 | 130 | counter += n 131 | 132 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 133 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 134 | assert np.all(label_length > 0), 'error: some caption had no words?' 
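# A toy illustration of the pointer layout built above (hypothetical numbers):
# with two images of 5 captions each, label_start_ix = [1, 6] and
# label_end_ix = [5, 10]; the pointers are 1-indexed and inclusive, so the
# rows of the 0-indexed array L that belong to image i are
# L[label_start_ix[i]-1 : label_end_ix[i]], each row zero-padded to max_length.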
135 | 136 | print('encoded captions to array of size ', L.shape) 137 | return L, label_start_ix, label_end_ix, label_length 138 | 139 | def main(params): 140 | 141 | imgs = json.load(open(params['input_json'], 'r')) 142 | imgs = imgs['images'] 143 | 144 | seed(123) # make reproducible 145 | 146 | # create the vocab 147 | vocab = build_vocab(imgs, params) 148 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 149 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 150 | 151 | # encode captions in large arrays, ready to ship to hdf5 file 152 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 153 | 154 | # create output h5 file 155 | N = len(imgs) 156 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 157 | f_lb.create_dataset("labels", dtype='uint32', data=L) 158 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 159 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 160 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 161 | f_lb.close() 162 | 163 | # create output json file 164 | out = {} 165 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 166 | out['images'] = [] 167 | for i,img in enumerate(imgs): 168 | 169 | jimg = {} 170 | jimg['split'] = img['split'] 171 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might be needed 172 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & maintain an id, if present (e.g. coco ids, useful) 173 | 174 | if params['images_root'] != '': 175 | with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: 176 | jimg['width'], jimg['height'] = _img.size 177 | 178 | out['images'].append(jimg) 179 | 180 | json.dump(out, open(params['output_json'], 'w')) 181 | print('wrote ', params['output_json']) 182 | 183 | if __name__ == "__main__": 184 | 185 | parser = argparse.ArgumentParser() 186 | 187 | # input json 188 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 189 | parser.add_argument('--output_json', default='data.json', help='output json file') 190 | parser.add_argument('--output_h5', default='data', help='output h5 file') 191 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 192 | 193 | # options 194 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 195 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 196 | 197 | args = parser.parse_args() 198 | params = vars(args) # convert to ordinary dict 199 | print('parsed input parameters:') 200 | print(json.dumps(params, indent = 2)) 201 | main(params) 202 | -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: --------------------------------------------------------------------------------
1 | """ 2 | Precompute the n-gram document frequencies that the CIDEr reward uses 3 | during self-critical training (they are loaded by init_scorer via the 4 | --cached_tokens option of train.py). 5 | 6 | Input: 7 | - the Karpathy-split dataset json (e.g. dataset_coco.json), in which every 8 | image carries a 'split' field and a list of 'sentences' holding 9 | pre-tokenized captions 10 | - the cocotalk.json written by prepro_labels.py, whose 'ix_to_word' table 11 | supplies the vocabulary; tokens outside it are mapped to 'UNK' 12 | 13 | For the requested split (plus 'restval' when split is 'train'), the script 14 | gathers every reference caption, counts its 1- to 4-grams, and records for 15 | each n-gram the number of images whose references contain it. 16 | 17 | Output: two pickle files 18 | - <output_pkl>-words.p keys the statistics by word n-grams 19 | - <output_pkl>-idxs.p keys the same statistics by word-index n-grams 20 | Each stores a dict with 'document_frequency' ({ngram: number of images 21 | containing it}) and 'ref_len' (the total number of reference images). 22 | Pass the resulting file (typically the -idxs one) to train.py through 23 | --cached_tokens when running self-critical training. 24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from six.moves import cPickle 30 | from collections import defaultdict 31 | 32 | def precook(s, n=4, out=False): 33 | """ 34 | Takes a string as input and returns an object that can be given to 35 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 36 | can take string arguments as well. 37 | :param s: string : sentence to be converted into ngrams 38 | :param n: int : number of ngrams for which representation is calculated 39 | :return: term frequency vector for occurring ngrams 40 | """ 41 | words = s.split() 42 | counts = defaultdict(int) 43 | for k in xrange(1,n+1): 44 | for i in xrange(len(words)-k+1): 45 | ngram = tuple(words[i:i+k]) 46 | counts[ngram] += 1 47 | return counts 48 | 49 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 50 | '''Takes a list of reference sentences for a single segment 51 | and returns an object that encapsulates everything that BLEU 52 | needs to know about them. 53 | :param refs: list of string : reference sentences for some image 54 | :param n: int : number of ngrams for which (ngram) representation is calculated 55 | :return: result (list of dict) 56 | ''' 57 | return [precook(ref, n) for ref in refs] 58 | 59 | def create_crefs(refs): 60 | crefs = [] 61 | for ref in refs: 62 | # ref is a list of 5 captions 63 | crefs.append(cook_refs(ref)) 64 | return crefs 65 | 66 | def compute_doc_freq(crefs): 67 | ''' 68 | Compute document frequency for reference data.
69 | This will be used to compute idf (inverse document frequency later) 70 | The term frequency is stored in the object 71 | :return: None 72 | ''' 73 | document_frequency = defaultdict(float) 74 | for refs in crefs: 75 | # refs, k ref captions of one image 76 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 77 | document_frequency[ngram] += 1 78 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 79 | return document_frequency 80 | 81 | def build_dict(imgs, wtoi, params): 82 | wtoi[''] = 0 83 | 84 | count_imgs = 0 85 | 86 | refs_words = [] 87 | refs_idxs = [] 88 | for img in imgs: 89 | if (params['split'] == img['split']) or \ 90 | (params['split'] == 'train' and img['split'] == 'restval') or \ 91 | (params['split'] == 'all'): 92 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 93 | ref_words = [] 94 | ref_idxs = [] 95 | for sent in img['sentences']: 96 | tmp_tokens = sent['tokens'] + [''] 97 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 98 | ref_words.append(' '.join(tmp_tokens)) 99 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 100 | refs_words.append(ref_words) 101 | refs_idxs.append(ref_idxs) 102 | count_imgs += 1 103 | print('total imgs:', count_imgs) 104 | 105 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 106 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 107 | return ngram_words, ngram_idxs, count_imgs 108 | 109 | def main(params): 110 | 111 | imgs = json.load(open(params['input_json'], 'r')) 112 | itow = json.load(open(params['dict_json'], 'r'))['ix_to_word'] 113 | wtoi = {w:i for i,w in itow.items()} 114 | 115 | imgs = imgs['images'] 116 | 117 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 118 | 119 | cPickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 120 | cPickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 121 | 122 | if __name__ == "__main__": 123 | 124 | parser = argparse.ArgumentParser() 125 | 126 | # input json 127 | parser.add_argument('--input_json', default='/home-nfs/rluo/rluo/nips/code/prepro/dataset_coco.json', help='input json file to process into hdf5') 128 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') 129 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') 130 | parser.add_argument('--split', default='all', help='test, val, train, all') 131 | args = parser.parse_args() 132 | params = vars(args) # convert to ordinary dict 133 | 134 | main(params) 135 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | import numpy as np 10 | 11 | import time 12 | import os 13 | from six.moves import cPickle 14 | 15 | import opts 16 | import models 17 | from dataloader import * 18 | import eval_utils 19 | import misc.utils as utils 20 | from misc.rewards import init_scorer, get_self_critical_reward 21 | 22 | try: 23 | import tensorboardX as tb 24 | except ImportError: 25 | print("tensorboardX is not installed") 26 | tb = None 27 | 28 | def 
add_summary_value(writer, key, value, iteration): 29 | if writer: 30 | writer.add_scalar(key, value, iteration) 31 | 32 | def train(opt): 33 | # Deal with feature things before anything 34 | opt.use_att = utils.if_use_att(opt.caption_model) 35 | if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 36 | 37 | loader = DataLoader(opt) 38 | opt.vocab_size = loader.vocab_size 39 | opt.seq_length = loader.seq_length 40 | 41 | tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path) 42 | 43 | infos = {} 44 | histories = {} 45 | if opt.start_from is not None: 46 | # open old infos and check if models are compatible 47 | with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: 48 | infos = cPickle.load(f) 49 | saved_model_opt = infos['opt'] 50 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] 51 | for checkme in need_be_same: 52 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme 53 | 54 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): 55 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: 56 | histories = cPickle.load(f) 57 | 58 | iteration = infos.get('iter', 0) 59 | epoch = infos.get('epoch', 0) 60 | 61 | val_result_history = histories.get('val_result_history', {}) 62 | loss_history = histories.get('loss_history', {}) 63 | lr_history = histories.get('lr_history', {}) 64 | ss_prob_history = histories.get('ss_prob_history', {}) 65 | 66 | loader.iterators = infos.get('iterators', loader.iterators) 67 | loader.split_ix = infos.get('split_ix', loader.split_ix) 68 | if opt.load_best_score == 1: 69 | best_val_score = infos.get('best_val_score', None) 70 | 71 | model = models.setup(opt).cuda() 72 | dp_model = torch.nn.DataParallel(model) 73 | 74 | epoch_done = True 75 | # Assure in training mode 76 | dp_model.train() 77 | 78 | if opt.label_smoothing > 0: 79 | crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) 80 | else: 81 | crit = utils.LanguageModelCriterion() 82 | rl_crit = utils.RewardCriterion() 83 | 84 | if opt.noamopt: 85 | assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' 86 | optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) 87 | optimizer._step = iteration 88 | elif opt.reduce_on_plateau: 89 | optimizer = utils.build_optimizer(model.parameters(), opt) 90 | optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) 91 | else: 92 | optimizer = utils.build_optimizer(model.parameters(), opt) 93 | # Load the optimizer 94 | if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): 95 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) 96 | 97 | while True: 98 | if epoch_done: 99 | if not opt.noamopt and not opt.reduce_on_plateau: 100 | # Assign the learning rate 101 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 102 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every 103 | decay_factor = opt.learning_rate_decay_rate ** frac 104 | opt.current_lr = opt.learning_rate * decay_factor 105 | else: 106 | opt.current_lr = opt.learning_rate 107 | utils.set_lr(optimizer, opt.current_lr) # set the decayed rate 108 | # Assign the scheduled sampling prob 109 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: 110 | frac = (epoch - opt.scheduled_sampling_start) 
// opt.scheduled_sampling_increase_every 111 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) 112 | model.ss_prob = opt.ss_prob 113 | 114 | # If start self critical training 115 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: 116 | sc_flag = True 117 | init_scorer(opt.cached_tokens) 118 | else: 119 | sc_flag = False 120 | 121 | epoch_done = False 122 | 123 | start = time.time() 124 | # Load data from train split (0) 125 | data = loader.get_batch('train') 126 | print('Read data:', time.time() - start) 127 | 128 | torch.cuda.synchronize() 129 | start = time.time() 130 | 131 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] 132 | tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] 133 | fc_feats, att_feats, labels, masks, att_masks = tmp 134 | 135 | optimizer.zero_grad() 136 | if not sc_flag: 137 | loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]) 138 | else: 139 | gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') 140 | reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) 141 | loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) 142 | 143 | loss.backward() 144 | utils.clip_gradient(optimizer, opt.grad_clip) 145 | optimizer.step() 146 | train_loss = loss.item() 147 | torch.cuda.synchronize() 148 | end = time.time() 149 | if not sc_flag: 150 | print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ 151 | .format(iteration, epoch, train_loss, end - start)) 152 | else: 153 | print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ 154 | .format(iteration, epoch, np.mean(reward[:,0]), end - start)) 155 | 156 | # Update the iteration and epoch 157 | iteration += 1 158 | if data['bounds']['wrapped']: 159 | epoch += 1 160 | epoch_done = True 161 | 162 | # Write the training loss summary 163 | if (iteration % opt.losses_log_every == 0): 164 | add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) 165 | if opt.noamopt: 166 | opt.current_lr = optimizer.rate() 167 | elif opt.reduce_on_plateau: 168 | opt.current_lr = optimizer.current_lr 169 | add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) 170 | add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) 171 | if sc_flag: 172 | add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration) 173 | 174 | loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0]) 175 | lr_history[iteration] = opt.current_lr 176 | ss_prob_history[iteration] = model.ss_prob 177 | 178 | # make evaluation on validation set, and save model 179 | if (iteration % opt.save_checkpoint_every == 0): 180 | # eval model 181 | eval_kwargs = {'split': 'val', 182 | 'dataset': opt.input_json} 183 | eval_kwargs.update(vars(opt)) 184 | val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) 185 | 186 | if opt.reduce_on_plateau: 187 | if 'CIDEr' in lang_stats: 188 | optimizer.scheduler_step(-lang_stats['CIDEr']) 189 | else: 190 | optimizer.scheduler_step(val_loss) 191 | 192 | # Write validation result into summary 193 | add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) 194 | for k,v in lang_stats.items(): 195 | add_summary_value(tb_summary_writer, k, v, iteration) 196 | 
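# Sketch of the model-selection signal computed just below (hypothetical
# numbers): with language_eval == 1 the CIDEr score is used directly
# (e.g. 1.12 -> 1.12); otherwise the negative validation loss is used
# (e.g. val_loss 2.35 -> -2.35), so "higher is better" holds either way and
# model-best.pth is only refreshed when current_score beats best_val_score.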
val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} 197 | 198 | # Save model if it is improving on validation result 199 | if opt.language_eval == 1: 200 | current_score = lang_stats['CIDEr'] 201 | else: 202 | current_score = - val_loss 203 | 204 | best_flag = False 205 | if True: # if true 206 | if best_val_score is None or current_score > best_val_score: 207 | best_val_score = current_score 208 | best_flag = True 209 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') 210 | torch.save(model.state_dict(), checkpoint_path) 211 | print("model saved to {}".format(checkpoint_path)) 212 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') 213 | torch.save(optimizer.state_dict(), optimizer_path) 214 | 215 | # Dump miscellaneous information 216 | infos['iter'] = iteration 217 | infos['epoch'] = epoch 218 | infos['iterators'] = loader.iterators 219 | infos['split_ix'] = loader.split_ix 220 | infos['best_val_score'] = best_val_score 221 | infos['opt'] = opt 222 | infos['vocab'] = loader.get_vocab() 223 | 224 | histories['val_result_history'] = val_result_history 225 | histories['loss_history'] = loss_history 226 | histories['lr_history'] = lr_history 227 | histories['ss_prob_history'] = ss_prob_history 228 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: 229 | cPickle.dump(infos, f) 230 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: 231 | cPickle.dump(histories, f) 232 | 233 | if best_flag: 234 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') 235 | torch.save(model.state_dict(), checkpoint_path) 236 | print("model saved to {}".format(checkpoint_path)) 237 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: 238 | cPickle.dump(infos, f) 239 | 240 | # Stop if reaching max epochs 241 | if epoch >= opt.max_epochs and opt.max_epochs != -1: 242 | break 243 | 244 | opt = opts.parse_opt() 245 | train(opt) 246 | -------------------------------------------------------------------------------- /vis/imgs/dummy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruotianluo/Transformer_Captioning/717fe48e4d808df771c2de37610b200c3eb78fe1/vis/imgs/dummy -------------------------------------------------------------------------------- /vis/index.html: -------------------------------------------------------------------------------- (static results-viewer page titled "neuraltalk2 results visualization"; the HTML/JavaScript markup was not preserved in this dump)
--------------------------------------------------------------------------------