├── README.md
├── dataloader.py
├── dataloaderraw.py
├── eval.py
├── eval_utils.py
├── misc
│   ├── __init__.py
│   ├── resnet.py
│   ├── resnet_utils.py
│   ├── rewards.py
│   └── utils.py
├── models
│   ├── Att2inModel.py
│   ├── AttModel.py
│   ├── CaptionModel.py
│   ├── FCModel.py
│   ├── OldModel.py
│   ├── ShowTellModel.py
│   └── __init__.py
├── opts.py
├── scripts
│   ├── copy_model.sh
│   ├── prepro_feats.py
│   ├── prepro_labels.py
│   └── prepro_ngrams.py
├── train.py
└── vis
    ├── imgs
    │   └── dummy
    ├── index.html
    └── jquery-1.8.3.min.js

/README.md:
--------------------------------------------------------------------------------
1 | # Self-critical Sequence Training for Image Captioning
2 | 
3 | This is an unofficial implementation of [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563). The results of the FC model can be replicated. (We were not able to replicate the Att2in results.)
4 | 
5 | The author helped me a lot when I tried to replicate the results. Great thanks. The latest topdown and att2in2 models can achieve a CIDEr score of 1.12 on Karpathy's test split after self-critical training.
6 | 
7 | This is based on my [neuraltalk2.pytorch](https://github.com/ruotianluo/neuraltalk2.pytorch) repository. The main modification is:
8 | - Added self-critical training.
9 | 
10 | ## Requirements
11 | Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for Python 3)
12 | PyTorch 0.2 (along with torchvision)
13 | 
14 | You need to download a pretrained ResNet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM) and should be placed in `data/imagenet_weights`.
15 | 
16 | ## Pretrained models
17 | Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg). The performance of each model is tracked in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10).
18 | 
19 | If you only want to run evaluation, follow [this section](#generate-image-captions) after downloading the pretrained models.
20 | 
21 | ## Train your own network on the ImageCLEF dataset
22 | ImageCLEF 2018 caption: http://www.imageclef.org/2018/caption
23 | 
24 | ImageCLEF 2017 caption review: http://ceur-ws.org/Vol-1866/
25 | 
26 | ImageCLEF 2017 leaderboard: http://www.imageclef.org/2017/caption
27 | 
28 | 
29 | ## Train your own network on the dataset from Eric Xing's medical image captioning paper (https://arxiv.org/abs/1711.08195)
30 | 
31 | ### Download the Chest X-ray dataset and preprocess it
32 | Download: https://www.dropbox.com/s/uanpmhfvpe7gbhk/chest_xray.zip?dl=0
33 | 
34 | Preprocessed JSON: https://www.dropbox.com/sh/zar9xooe0557wqp/AABagbPxDB3SAUBZoR1umguEa?dl=0
35 | 
36 | ## Train your own network on IU X-Ray
37 | 
38 | ### Download the dataset and preprocess it
39 | 
40 | First, download the COCO images from [this link](http://mscoco.org/dataset/#download). We need the 2014 training and 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
41 | 
42 | Download the preprocessed COCO captions from [this link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides preprocessed captions and the standard train-val-test splits.
43 | 
44 | Once we have these, we can invoke the `prepro_*.py` scripts, which read all of this in and create the dataset (two feature folders, an hdf5 label file and a json file). A sketch of the expected input json is shown below, followed by the commands.
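For reference, here is a minimal sketch of what one entry in such a Karpathy-style json might look like. The field names below are an assumption based on the COCO `dataset_coco.json` format and may not match `Chest_Xray/data_labels.json` exactly; check `scripts/prepro_labels.py` for the keys it actually reads.

```python
# Hypothetical, minimal Karpathy-style entry (field names are assumptions; the
# authoritative reference is scripts/prepro_labels.py and dataset_coco.json itself).
import json

example = {
    "images": [
        {
            "filepath": "train2014",                        # sub-folder under $IMAGE_ROOT (assumed)
            "filename": "COCO_train2014_000000000001.jpg",  # image file name (assumed)
            "split": "train",                               # one of train / val / test
            "sentences": [
                {"tokens": ["a", "normal", "chest", "x", "ray"]},  # tokenized caption (assumed)
            ],
        }
    ]
}
print(json.dumps(example, indent=2))
```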
45 | 
46 | ```bash
47 | $ python scripts/prepro_labels.py --input_json Chest_Xray/data_labels.json --output_json Chest_Xray/cocotalk.json --output_h5 Chest_Xray/cocotalk
48 | $ python scripts/prepro_feats.py --input_json Chest_Xray/data_labels.json --output_dir Chest_Xray/cocotalk --images_root $IMAGE_ROOT
49 | ```
50 | 
51 | `prepro_labels.py` maps all words that occur <= 5 times to a special `UNK` token and creates a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.
52 | 
53 | `prepro_feats.py` extracts the ResNet-101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`; the resulting files are about 200GB.
54 | 
55 | (Check the prepro scripts for more options, like other ResNet models or other attention sizes.)
56 | 
57 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.
58 | 
59 | ### Start training
60 | 
61 | ```bash
62 | $ python train.py --id fc --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
63 | ```
64 | 
65 | The train script dumps checkpoints into the folder specified by `--checkpoint_path` (default = `save/`). Only the best-performing checkpoint on validation and the latest checkpoint are saved, to save disk space.
66 | 
67 | To resume training, set the `--start_from` option to the path containing `infos.pkl` and `model.pth` (usually you can just set `--start_from` and `--checkpoint_path` to the same value).
68 | 
69 | If you have TensorFlow installed, the loss histories are automatically dumped into `--checkpoint_path` and can be visualized with TensorBoard.
70 | 
71 | The command above uses scheduled sampling; set `scheduled_sampling_start` to -1 to turn it off.
72 | 
73 | If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition to the validation cross-entropy loss, use the `--language_eval 1` option, but don't forget to download the [coco-caption code](https://github.com/tylin/coco-caption) into the `coco-caption` directory.
74 | 
75 | For more options, see `opts.py`.
76 | 
77 | **A few notes on training.** To give you an idea, with the default settings one epoch of MS COCO images is about 11000 iterations. One epoch of training results in a validation loss of ~2.5 and a CIDEr score of ~0.68. By iteration 60,000 CIDEr climbs to about 0.84 (validation loss about 2.4, under scheduled sampling).
78 | 
79 | ### Train using self-critical
80 | 
81 | First, preprocess the dataset and build the cache for calculating the CIDEr score:
82 | ```
83 | $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
84 | ```
85 | 
86 | You also need to clone my forked [cider](https://github.com/ruotianluo/cider) repository.
87 | 
88 | Then, copy the model pretrained with cross entropy.
(Copying the model is not mandatory; it is just a backup.)
89 | ```
90 | $ bash scripts/copy_model.sh fc fc_rl
91 | ```
92 | 
93 | Then
94 | ```bash
95 | $ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30
96 | ```
97 | 
98 | You will see a huge boost in CIDEr score, : ).
99 | 
100 | **A few notes on training.** Starting self-critical training after 30 epochs, the CIDEr score goes up to 1.05 after 600k iterations (including the 30 epochs of pretraining).
101 | 
102 | ### Caption images after training
103 | 
104 | ## Generate image captions
105 | 
106 | ### Evaluate on raw images
107 | Now place all your images of interest into a folder, e.g. `blah`, and run
108 | the eval script:
109 | 
110 | ```bash
111 | $ python eval.py --model model.pth --infos_path infos.pkl --image_folder blah --num_images 10
112 | ```
113 | 
114 | This tells the `eval` script to run up to 10 images from the given folder. If you have a big GPU you can speed up the evaluation by increasing `batch_size`. Use `--num_images -1` to process all images. The eval script will create a `vis.json` file inside the `vis` folder, which can then be visualized with the provided HTML interface:
115 | 
116 | ```bash
117 | $ cd vis
118 | $ python -m SimpleHTTPServer
119 | ```
120 | 
121 | Now visit `localhost:8000` in your browser and you should see your predicted captions.
122 | 
123 | ### Evaluate on Karpathy's test split
124 | 
125 | ```bash
126 | $ python eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl --language_eval 1
127 | ```
128 | 
129 | The default split to evaluate is test. The default inference method is greedy decoding (`--sample_max 1`); to sample from the posterior, set `--sample_max 0`.
130 | 
131 | **Beam Search**. Beam search can improve performance over greedy decoding by ~5%, but it is a little more expensive. To turn on beam search, use `--beam_size N`, where N should be greater than 1.
132 | 
133 | 
134 | 
135 | ## Train your own network on COCO
136 | 
137 | ### Download the COCO dataset and preprocess it
138 | 
139 | First, download the COCO images from [this link](http://mscoco.org/dataset/#download). We need the 2014 training and 2014 validation images. Put `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
140 | 
141 | Download the preprocessed COCO captions from [this link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides preprocessed captions and the standard train-val-test splits.
142 | 
143 | Once we have these, we can invoke the `prepro_*.py` scripts, which read all of this in and create the dataset (two feature folders, an hdf5 label file and a json file). A quick sanity check of the resulting files is sketched below (run it after the two commands that follow).
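A minimal sanity-check sketch for the preprocessing outputs: the hdf5/json keys follow what `dataloader.py` in this repository reads, and the file names assume the default `--output_json`/`--output_h5` arguments used in the commands below.

```python
# Sketch: inspect the preprocessed label files (paths assume the default
# arguments in the prepro commands below).
import json
import h5py

info = json.load(open('data/cocotalk.json'))
print('vocab size :', len(info['ix_to_word']))
print('num images :', len(info['images']))

with h5py.File('data/cocotalk_label.h5', 'r') as f:
    print('labels shape   :', f['labels'].shape)          # (num captions, max seq length)
    print('label_start_ix :', f['label_start_ix'].shape)  # per-image caption ranges
    print('label_end_ix   :', f['label_end_ix'].shape)
```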
144 | 
145 | ```bash
146 | $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
147 | $ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
148 | ```
149 | 
150 | `prepro_labels.py` maps all words that occur <= 5 times to a special `UNK` token and creates a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json`, and the discretized caption data are dumped into `data/cocotalk_label.h5`.
151 | 
152 | `prepro_feats.py` extracts the ResNet-101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`; the resulting files are about 200GB.
153 | 
154 | (Check the prepro scripts for more options, like other ResNet models or other attention sizes.)
155 | 
156 | **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.
157 | 
158 | ### Start training
159 | 
160 | ```bash
161 | $ python train.py --id fc --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path log_fc --save_checkpoint_every 6000 --val_images_use 5000 --max_epochs 30
162 | ```
163 | 
164 | The train script dumps checkpoints into the folder specified by `--checkpoint_path` (default = `save/`). Only the best-performing checkpoint on validation and the latest checkpoint are saved, to save disk space.
165 | 
166 | To resume training, set the `--start_from` option to the path containing `infos.pkl` and `model.pth` (usually you can just set `--start_from` and `--checkpoint_path` to the same value).
167 | 
168 | If you have TensorFlow installed, the loss histories are automatically dumped into `--checkpoint_path` and can be visualized with TensorBoard.
169 | 
170 | The command above uses scheduled sampling; set `scheduled_sampling_start` to -1 to turn it off.
171 | 
172 | If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition to the validation cross-entropy loss, use the `--language_eval 1` option, but don't forget to download the [coco-caption code](https://github.com/tylin/coco-caption) into the `coco-caption` directory.
173 | 
174 | For more options, see `opts.py`.
175 | 
176 | **A few notes on training.** To give you an idea, with the default settings one epoch of MS COCO images is about 11000 iterations. One epoch of training results in a validation loss of ~2.5 and a CIDEr score of ~0.68. By iteration 60,000 CIDEr climbs to about 0.84 (validation loss about 2.4, under scheduled sampling).
177 | 
178 | ### Train using self-critical
179 | 
180 | First, preprocess the dataset and build the cache for calculating the CIDEr score:
181 | ```
182 | $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
183 | ```
184 | 
185 | You also need to clone my forked [cider](https://github.com/ruotianluo/cider) repository.
186 | 
187 | Then, copy the model pretrained with cross entropy.
(Copying the model is not mandatory; it is just a backup.)
188 | ```
189 | $ bash scripts/copy_model.sh fc fc_rl
190 | ```
191 | 
192 | Then
193 | ```bash
194 | $ python train.py --id fc_rl --caption_model fc --input_json data/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --batch_size 10 --learning_rate 5e-5 --start_from log_fc_rl --checkpoint_path log_fc_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30
195 | ```
196 | 
197 | You will see a huge boost in CIDEr score, : ).
198 | 
199 | **A few notes on training.** Starting self-critical training after 30 epochs, the CIDEr score goes up to 1.05 after 600k iterations (including the 30 epochs of pretraining).
200 | 
201 | ### Caption images after training
202 | 
203 | ## Generate image captions
204 | 
205 | ### Evaluate on raw images
206 | Now place all your images of interest into a folder, e.g. `blah`, and run
207 | the eval script:
208 | 
209 | ```bash
210 | $ python eval.py --model model.pth --infos_path infos.pkl --image_folder blah --num_images 10
211 | ```
212 | 
213 | This tells the `eval` script to run up to 10 images from the given folder. If you have a big GPU you can speed up the evaluation by increasing `batch_size`. Use `--num_images -1` to process all images. The eval script will create a `vis.json` file inside the `vis` folder, which can then be visualized with the provided HTML interface:
214 | 
215 | ```bash
216 | $ cd vis
217 | $ python -m SimpleHTTPServer
218 | ```
219 | 
220 | Now visit `localhost:8000` in your browser and you should see your predicted captions.
221 | 
222 | ### Evaluate on Karpathy's test split
223 | 
224 | ```bash
225 | $ python eval.py --dump_images 0 --num_images 5000 --model model.pth --infos_path infos.pkl --language_eval 1
226 | ```
227 | 
228 | The default split to evaluate is test. The default inference method is greedy decoding (`--sample_max 1`); to sample from the posterior, set `--sample_max 0`.
229 | 
230 | **Beam Search**. Beam search can improve performance over greedy decoding by ~5%, but it is a little more expensive. To turn on beam search, use `--beam_size N`, where N should be greater than 1.
231 | 
232 | ## Miscellanea
233 | **Using cpu**. The code currently uses the GPU by default; there is no option to switch. If someone really needs a CPU model, please open an issue; I can potentially create a CPU checkpoint and modify eval.py to run the model on CPU. However, there is no point in using a CPU to train the model.
234 | 
235 | **Train on other datasets**. Porting should be trivial if you can create a file like `dataset_coco.json` for your own dataset.
236 | 
237 | **Live demo**. Not supported yet. Pull requests are welcome.
238 | 
239 | ## Acknowledgements
240 | 
241 | Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team.
242 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | 11 | import torch 12 | import torch.utils.data as data 13 | 14 | import multiprocessing 15 | 16 | def get_npy_data(ix, fc_file, att_file, use_att): 17 | if use_att == True: 18 | return (np.load(fc_file), np.load(att_file)['feat'], ix) 19 | else: 20 | return (np.load(fc_file), np.zeros((1,1,1)), ix) 21 | 22 | class DataLoader(data.Dataset): 23 | 24 | def reset_iterator(self, split): 25 | del self._prefetch_process[split] 26 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 27 | self.iterators[split] = 0 28 | 29 | def get_vocab_size(self): 30 | return self.vocab_size 31 | 32 | def get_vocab(self): 33 | return self.ix_to_word 34 | 35 | def get_seq_length(self): 36 | return self.seq_length 37 | 38 | def __init__(self, opt): 39 | self.opt = opt 40 | self.batch_size = self.opt.batch_size 41 | self.seq_per_img = opt.seq_per_img 42 | self.use_att = getattr(opt, 'use_att', True) 43 | 44 | # load the json file which contains additional information about the dataset 45 | print('DataLoader loading json file: ', opt.input_json) 46 | self.info = json.load(open(self.opt.input_json)) 47 | self.ix_to_word = self.info['ix_to_word'] 48 | self.vocab_size = len(self.ix_to_word) 49 | print('vocab size is ', self.vocab_size) 50 | 51 | # open the hdf5 file 52 | print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5) 53 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') 54 | 55 | self.input_fc_dir = self.opt.input_fc_dir 56 | self.input_att_dir = self.opt.input_att_dir 57 | 58 | # load in the sequence data 59 | seq_size = self.h5_label_file['labels'].shape 60 | self.seq_length = seq_size[1] 61 | print('max sequence length in data is', self.seq_length) 62 | # load the pointers in full to RAM (should be small enough) 63 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 64 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 65 | 66 | self.num_images = self.label_start_ix.shape[0] 67 | print('read %d image features' %(self.num_images)) 68 | 69 | # separate out indexes for each of the provided splits 70 | self.split_ix = {'train': [], 'val': [], 'test': []} 71 | for ix in range(len(self.info['images'])): 72 | img = self.info['images'][ix] 73 | if img['split'] == 'train': 74 | self.split_ix['train'].append(ix) 75 | elif img['split'] == 'val': 76 | self.split_ix['val'].append(ix) 77 | elif img['split'] == 'test': 78 | self.split_ix['test'].append(ix) 79 | elif opt.train_only == 0: # restval 80 | self.split_ix['train'].append(ix) 81 | 82 | print('assigned %d images to split train' %len(self.split_ix['train'])) 83 | print('assigned %d images to split val' %len(self.split_ix['val'])) 84 | print('assigned %d images to split test' %len(self.split_ix['test'])) 85 | 86 | self.iterators = {'train': 0, 'val': 0, 'test': 0} 87 | 88 | self._prefetch_process = {} # The three prefetch process 89 | for split in self.iterators.keys(): 90 | self._prefetch_process[split] = BlobFetcher(split, self, split=='train') 91 | # Terminate the child process when the parent exists 92 | def cleanup(): 93 | print('Terminating BlobFetcher') 94 | 
for split in self.iterators.keys(): 95 | del self._prefetch_process[split] 96 | import atexit 97 | atexit.register(cleanup) 98 | 99 | def get_batch(self, split, batch_size=None, seq_per_img=None): 100 | batch_size = batch_size or self.batch_size 101 | seq_per_img = seq_per_img or self.seq_per_img 102 | 103 | fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32') 104 | att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32') 105 | label_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int') 106 | mask_batch = np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'float32') 107 | 108 | wrapped = False 109 | 110 | infos = [] 111 | gts = [] 112 | 113 | for i in range(batch_size): 114 | import time 115 | t_start = time.time() 116 | # fetch image 117 | tmp_fc, tmp_att,\ 118 | ix, tmp_wrapped = self._prefetch_process[split].get() 119 | fc_batch += [tmp_fc] * seq_per_img 120 | att_batch += [tmp_att] * seq_per_img 121 | 122 | # fetch the sequence labels 123 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 124 | ix2 = self.label_end_ix[ix] - 1 125 | ncap = ix2 - ix1 + 1 # number of captions available for this image 126 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' 127 | 128 | if ncap < seq_per_img: 129 | # we need to subsample (with replacement) 130 | seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') 131 | for q in range(seq_per_img): 132 | ixl = random.randint(ix1,ix2) 133 | seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] 134 | else: 135 | ixl = random.randint(ix1, ix2 - seq_per_img + 1) 136 | seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length] 137 | 138 | label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq 139 | 140 | if tmp_wrapped: 141 | wrapped = True 142 | 143 | # Used for reward evaluation 144 | gts.append(self.h5_label_file['labels'][self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) 145 | 146 | # record associated info as well 147 | info_dict = {} 148 | info_dict['ix'] = ix 149 | info_dict['id'] = self.info['images'][ix]['id'] 150 | info_dict['file_path'] = self.info['images'][ix]['file_path'] 151 | infos.append(info_dict) 152 | #print(i, time.time() - t_start) 153 | 154 | # generate mask 155 | t_start = time.time() 156 | nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch))) 157 | for ix, row in enumerate(mask_batch): 158 | row[:nonzeros[ix]] = 1 159 | #print('mask', time.time() - t_start) 160 | 161 | data = {} 162 | data['fc_feats'] = np.stack(fc_batch) 163 | data['att_feats'] = np.stack(att_batch) 164 | data['labels'] = label_batch 165 | data['gts'] = gts 166 | data['masks'] = mask_batch 167 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} 168 | data['infos'] = infos 169 | 170 | return data 171 | 172 | # It's not coherent to make DataLoader a subclass of Dataset, but essentially, we only need to implement the following to functions, 173 | # so that the torch.utils.data.DataLoader can load the data according the index. 174 | # However, it's minimum change to switch to pytorch data loading. 
175 | def __getitem__(self, index): 176 | """This function returns a tuple that is further passed to collate_fn 177 | """ 178 | ix = index #self.split_ix[index] 179 | return get_npy_data(ix, \ 180 | os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'), 181 | os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'), 182 | self.use_att 183 | ) 184 | 185 | def __len__(self): 186 | return len(self.info['images']) 187 | 188 | class BlobFetcher(): 189 | """Experimental class for prefetching blobs in a separate process.""" 190 | def __init__(self, split, dataloader, if_shuffle=False): 191 | """ 192 | db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname 193 | """ 194 | self.split = split 195 | self.dataloader = dataloader 196 | self.if_shuffle = if_shuffle 197 | 198 | # Add more in the queue 199 | def reset(self): 200 | """ 201 | Two cases: 202 | 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator 203 | 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already. 204 | """ 205 | # batch_size is 0, the merge is done in DataLoader class 206 | self.split_loader = iter(data.DataLoader(dataset=self.dataloader, 207 | batch_size=1, 208 | sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:], 209 | shuffle=False, 210 | pin_memory=True, 211 | num_workers=multiprocessing.cpu_count(), 212 | collate_fn=lambda x: x[0])) 213 | 214 | def _get_next_minibatch_inds(self): 215 | max_index = len(self.dataloader.split_ix[self.split]) 216 | wrapped = False 217 | 218 | ri = self.dataloader.iterators[self.split] 219 | ix = self.dataloader.split_ix[self.split][ri] 220 | 221 | ri_next = ri + 1 222 | if ri_next >= max_index: 223 | ri_next = 0 224 | if self.if_shuffle: 225 | random.shuffle(self.dataloader.split_ix[self.split]) 226 | wrapped = True 227 | self.dataloader.iterators[self.split] = ri_next 228 | 229 | return ix, wrapped 230 | 231 | def get(self): 232 | if not hasattr(self, 'split_loader'): 233 | self.reset() 234 | 235 | ix, wrapped = self._get_next_minibatch_inds() 236 | tmp = self.split_loader.next() 237 | if wrapped: 238 | self.reset() 239 | 240 | assert tmp[2] == ix, "ix not equal" 241 | 242 | return tmp + [wrapped] -------------------------------------------------------------------------------- /dataloaderraw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import h5py 7 | import os 8 | import numpy as np 9 | import random 10 | import torch 11 | from torch.autograd import Variable 12 | import skimage 13 | import skimage.io 14 | import scipy.misc 15 | 16 | from torchvision import transforms as trn 17 | preprocess = trn.Compose([ 18 | #trn.ToTensor(), 19 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]) 21 | 22 | from misc.resnet_utils import myResnet 23 | import misc.resnet 24 | 25 | class DataLoaderRaw(): 26 | 27 | def __init__(self, opt): 28 | self.opt = opt 29 | self.coco_json = opt.get('coco_json', '') 30 | self.folder_path = opt.get('folder_path', '') 31 | 32 | self.batch_size = opt.get('batch_size', 1) 33 | self.seq_per_img = 1 34 | 35 | # Load resnet 36 | self.cnn_model = opt.get('cnn_model', 'resnet101') 37 | self.my_resnet = getattr(misc.resnet, self.cnn_model)() 38 | 
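        # Load the ImageNet-pretrained backbone weights from data/imagenet_weights (see the README).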
self.my_resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth')) 39 | self.my_resnet = myResnet(self.my_resnet) 40 | self.my_resnet.cuda() 41 | self.my_resnet.eval() 42 | 43 | 44 | 45 | # load the json file which contains additional information about the dataset 46 | print('DataLoaderRaw loading images from folder: ', self.folder_path) 47 | 48 | self.files = [] 49 | self.ids = [] 50 | 51 | print(len(self.coco_json)) 52 | if len(self.coco_json) > 0: 53 | print('reading from ' + opt.coco_json) 54 | # read in filenames from the coco-style json file 55 | self.coco_annotation = json.load(open(self.coco_json)) 56 | for k,v in enumerate(self.coco_annotation['images']): 57 | fullpath = os.path.join(self.folder_path, v['file_name']) 58 | self.files.append(fullpath) 59 | self.ids.append(v['id']) 60 | else: 61 | # read in all the filenames from the folder 62 | print('listing all images in directory ' + self.folder_path) 63 | def isImage(f): 64 | supportedExt = ['.jpg','.JPG','.jpeg','.JPEG','.png','.PNG','.ppm','.PPM'] 65 | for ext in supportedExt: 66 | start_idx = f.rfind(ext) 67 | if start_idx >= 0 and start_idx + len(ext) == len(f): 68 | return True 69 | return False 70 | 71 | n = 1 72 | for root, dirs, files in os.walk(self.folder_path, topdown=False): 73 | for file in files: 74 | fullpath = os.path.join(self.folder_path, file) 75 | if isImage(fullpath): 76 | self.files.append(fullpath) 77 | self.ids.append(str(n)) # just order them sequentially 78 | n = n + 1 79 | 80 | self.N = len(self.files) 81 | print('DataLoaderRaw found ', self.N, ' images') 82 | 83 | self.iterator = 0 84 | 85 | def get_batch(self, split, batch_size=None): 86 | batch_size = batch_size or self.batch_size 87 | 88 | # pick an index of the datapoint to load next 89 | fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32') 90 | att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32') 91 | max_index = self.N 92 | wrapped = False 93 | infos = [] 94 | 95 | for i in range(batch_size): 96 | ri = self.iterator 97 | ri_next = ri + 1 98 | if ri_next >= max_index: 99 | ri_next = 0 100 | wrapped = True 101 | # wrap back around 102 | self.iterator = ri_next 103 | 104 | img = skimage.io.imread(self.files[ri]) 105 | 106 | if len(img.shape) == 2: 107 | img = img[:,:,np.newaxis] 108 | img = np.concatenate((img, img, img), axis=2) 109 | 110 | img = img.astype('float32')/255.0 111 | img = torch.from_numpy(img.transpose([2,0,1])).cuda() 112 | img = Variable(preprocess(img), volatile=True) 113 | tmp_fc, tmp_att = self.my_resnet(img) 114 | 115 | fc_batch[i] = tmp_fc.data.cpu().float().numpy() 116 | att_batch[i] = tmp_att.data.cpu().float().numpy() 117 | 118 | info_struct = {} 119 | info_struct['id'] = self.ids[ri] 120 | info_struct['file_path'] = self.files[ri] 121 | infos.append(info_struct) 122 | 123 | data = {} 124 | data['fc_feats'] = fc_batch 125 | data['att_feats'] = att_batch 126 | data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped} 127 | data['infos'] = infos 128 | 129 | return data 130 | 131 | def reset_iterator(self, split): 132 | self.iterator = 0 133 | 134 | def get_vocab_size(self): 135 | return len(self.ix_to_word) 136 | 137 | def get_vocab(self): 138 | return self.ix_to_word 139 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import 
print_function 4 | 5 | import json 6 | import numpy as np 7 | 8 | import time 9 | import os 10 | from six.moves import cPickle 11 | 12 | import opts 13 | import models 14 | from dataloader import * 15 | from dataloaderraw import * 16 | import eval_utils 17 | import argparse 18 | import misc.utils as utils 19 | import torch 20 | 21 | # Input arguments and options 22 | parser = argparse.ArgumentParser() 23 | # Input paths 24 | parser.add_argument('--model', type=str, default='', 25 | help='path to model to evaluate') 26 | parser.add_argument('--cnn_model', type=str, default='resnet101', 27 | help='resnet101, resnet152') 28 | parser.add_argument('--infos_path', type=str, default='', 29 | help='path to infos to evaluate') 30 | # Basic options 31 | parser.add_argument('--batch_size', type=int, default=0, 32 | help='if > 0 then overrule, otherwise load from checkpoint.') 33 | parser.add_argument('--num_images', type=int, default=-1, 34 | help='how many images to use when periodically evaluating the loss? (-1 = all)') 35 | parser.add_argument('--language_eval', type=int, default=0, 36 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 37 | parser.add_argument('--dump_images', type=int, default=1, 38 | help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') 39 | parser.add_argument('--dump_json', type=int, default=1, 40 | help='Dump json with predictions into vis folder? (1=yes,0=no)') 41 | parser.add_argument('--dump_path', type=int, default=0, 42 | help='Write image paths along with predictions into vis json? (1=yes,0=no)') 43 | 44 | # Sampling options 45 | parser.add_argument('--sample_max', type=int, default=1, 46 | help='1 = sample argmax words. 0 = sample from distributions.') 47 | parser.add_argument('--beam_size', type=int, default=2, 48 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 49 | parser.add_argument('--temperature', type=float, default=1.0, 50 | help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') 51 | # For evaluation on a folder of images: 52 | parser.add_argument('--image_folder', type=str, default='', 53 | help='If this is nonempty then will predict on the images in this folder path') 54 | parser.add_argument('--image_root', type=str, default='', 55 | help='In case the image paths have to be preprended with a root path to an image folder') 56 | # For evaluation on MSCOCO images from some split: 57 | parser.add_argument('--input_fc_dir', type=str, default='', 58 | help='path to the h5file containing the preprocessed dataset') 59 | parser.add_argument('--input_att_dir', type=str, default='', 60 | help='path to the h5file containing the preprocessed dataset') 61 | parser.add_argument('--input_label_h5', type=str, default='', 62 | help='path to the h5file containing the preprocessed dataset') 63 | parser.add_argument('--input_json', type=str, default='', 64 | help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') 65 | parser.add_argument('--split', type=str, default='test', 66 | help='if running on MSCOCO images, which split to use: val|test|train') 67 | parser.add_argument('--coco_json', type=str, default='', 68 | help='if nonempty then use this file in DataLoaderRaw (see docs there). 
Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') 69 | # misc 70 | parser.add_argument('--id', type=str, default='', 71 | help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') 72 | 73 | opt = parser.parse_args() 74 | 75 | # Load infos 76 | with open(opt.infos_path) as f: 77 | infos = cPickle.load(f) 78 | 79 | # override and collect parameters 80 | if len(opt.input_fc_dir) == 0: 81 | opt.input_fc_dir = infos['opt'].input_fc_dir 82 | opt.input_att_dir = infos['opt'].input_att_dir 83 | opt.input_label_h5 = infos['opt'].input_label_h5 84 | if len(opt.input_json) == 0: 85 | opt.input_json = infos['opt'].input_json 86 | if opt.batch_size == 0: 87 | opt.batch_size = infos['opt'].batch_size 88 | if len(opt.id) == 0: 89 | opt.id = infos['opt'].id 90 | ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval"] 91 | for k in vars(infos['opt']).keys(): 92 | if k not in ignore: 93 | if k in vars(opt): 94 | assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent' 95 | else: 96 | vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model 97 | 98 | vocab = infos['vocab'] # ix -> word mapping 99 | 100 | # Setup the model 101 | model = models.setup(opt) 102 | model.load_state_dict(torch.load(opt.model)) 103 | model.cuda() 104 | model.eval() 105 | crit = utils.LanguageModelCriterion() 106 | 107 | # Create the Data Loader instance 108 | if len(opt.image_folder) == 0: 109 | loader = DataLoader(opt) 110 | else: 111 | loader = DataLoaderRaw({'folder_path': opt.image_folder, 112 | 'coco_json': opt.coco_json, 113 | 'batch_size': opt.batch_size, 114 | 'cnn_model': opt.cnn_model}) 115 | # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json 116 | # So make sure to use the vocab in infos file. 
117 | loader.ix_to_word = infos['vocab'] 118 | 119 | 120 | # Set sample options 121 | loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, 122 | vars(opt)) 123 | 124 | print('loss: ', loss) 125 | if lang_stats: 126 | print(lang_stats) 127 | 128 | if opt.dump_json == 1: 129 | # dump the json 130 | json.dump(split_predictions, open('vis/vis.json', 'w')) 131 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | import json 11 | from json import encoder 12 | import random 13 | import string 14 | import time 15 | import os 16 | import sys 17 | import misc.utils as utils 18 | 19 | def language_eval(dataset, preds, model_id, split): 20 | import sys 21 | sys.path.append("coco-caption") 22 | annFile = 'coco-caption/annotations/captions_val2014.json' 23 | from pycocotools.coco import COCO 24 | from pycocoevalcap.eval import COCOEvalCap 25 | 26 | encoder.FLOAT_REPR = lambda o: format(o, '.3f') 27 | 28 | if not os.path.isdir('eval_results'): 29 | os.mkdir('eval_results') 30 | cache_path = os.path.join('eval_results/', model_id + '_' + split + '.json') 31 | 32 | coco = COCO(annFile) 33 | valids = coco.getImgIds() 34 | 35 | # filter results to only those in MSCOCO validation set (will be about a third) 36 | preds_filt = [p for p in preds if p['image_id'] in valids] 37 | print('using %d/%d predictions' % (len(preds_filt), len(preds))) 38 | json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 
39 | 40 | cocoRes = coco.loadRes(cache_path) 41 | cocoEval = COCOEvalCap(coco, cocoRes) 42 | cocoEval.params['image_id'] = cocoRes.getImgIds() 43 | cocoEval.evaluate() 44 | 45 | # create output dictionary 46 | out = {} 47 | for metric, score in cocoEval.eval.items(): 48 | out[metric] = score 49 | 50 | imgToEval = cocoEval.imgToEval 51 | for p in preds_filt: 52 | image_id, caption = p['image_id'], p['caption'] 53 | imgToEval[image_id]['caption'] = caption 54 | with open(cache_path, 'w') as outfile: 55 | json.dump({'overall': out, 'imgToEval': imgToEval}, outfile) 56 | 57 | return out 58 | 59 | def eval_split(model, crit, loader, eval_kwargs={}): 60 | verbose = eval_kwargs.get('verbose', True) 61 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 62 | split = eval_kwargs.get('split', 'val') 63 | lang_eval = eval_kwargs.get('language_eval', 0) 64 | dataset = eval_kwargs.get('dataset', 'coco') 65 | beam_size = eval_kwargs.get('beam_size', 1) 66 | 67 | # Make sure in the evaluation mode 68 | model.eval() 69 | 70 | loader.reset_iterator(split) 71 | 72 | n = 0 73 | loss = 0 74 | loss_sum = 0 75 | loss_evals = 1e-8 76 | predictions = [] 77 | while True: 78 | data = loader.get_batch(split) 79 | n = n + loader.batch_size 80 | 81 | if data.get('labels', None) is not None: 82 | # forward the model to get loss 83 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] 84 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 85 | fc_feats, att_feats, labels, masks = tmp 86 | 87 | loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]).data[0] 88 | loss_sum = loss_sum + loss 89 | loss_evals = loss_evals + 1 90 | 91 | # forward the model to also get generated samples for each image 92 | # Only leave one feature for each image, in case duplicate sample 93 | tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], 94 | data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]] 95 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 96 | fc_feats, att_feats = tmp 97 | # forward the model to also get generated samples for each image 98 | seq, _ = model.sample(fc_feats, att_feats, eval_kwargs) 99 | 100 | #set_trace() 101 | sents = utils.decode_sequence(loader.get_vocab(), seq) 102 | 103 | for k, sent in enumerate(sents): 104 | entry = {'image_id': data['infos'][k]['id'], 'caption': sent} 105 | if eval_kwargs.get('dump_path', 0) == 1: 106 | entry['file_name'] = data['infos'][k]['file_path'] 107 | predictions.append(entry) 108 | if eval_kwargs.get('dump_images', 0) == 1: 109 | # dump the raw image to vis/ folder 110 | cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross 111 | print(cmd) 112 | os.system(cmd) 113 | 114 | if verbose: 115 | print('image %s: %s' %(entry['image_id'], entry['caption'])) 116 | 117 | # if we wrapped around the split or used up val imgs budget then bail 118 | ix0 = data['bounds']['it_pos_now'] 119 | ix1 = data['bounds']['it_max'] 120 | if num_images != -1: 121 | ix1 = min(ix1, num_images) 122 | for i in range(n - ix1): 123 | predictions.pop() 124 | 125 | if verbose: 126 | print('evaluating validation preformance... 
%d/%d (%f)' %(ix0 - 1, ix1, loss)) 127 | 128 | if data['bounds']['wrapped']: 129 | break 130 | if num_images >= 0 and n >= num_images: 131 | break 132 | 133 | lang_stats = None 134 | if lang_eval == 1: 135 | lang_stats = language_eval(dataset, predictions, eval_kwargs['id'], split) 136 | 137 | # Switch back to training mode 138 | model.train() 139 | return loss_sum/loss_evals, predictions, lang_stats 140 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALISCIFP/Medical-Image-Caption/e94a3a88834219ff7bdb05914e25c868e33fa4e6/misc/__init__.py -------------------------------------------------------------------------------- /misc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out 
+= residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AvgPool2d(7) 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | def resnet18(pretrained=False): 156 | """Constructs a ResNet-18 model. 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 162 | if pretrained: 163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 164 | return model 165 | 166 | 167 | def resnet34(pretrained=False): 168 | """Constructs a ResNet-34 model. 169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 174 | if pretrained: 175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 176 | return model 177 | 178 | 179 | def resnet50(pretrained=False): 180 | """Constructs a ResNet-50 model. 181 | 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 188 | return model 189 | 190 | 191 | def resnet101(pretrained=False): 192 | """Constructs a ResNet-101 model. 
193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 198 | if pretrained: 199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 200 | return model 201 | 202 | 203 | def resnet152(pretrained=False): 204 | """Constructs a ResNet-152 model. 205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 212 | return model -------------------------------------------------------------------------------- /misc/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | class myResnet(nn.Module): 7 | def __init__(self, resnet): 8 | super(myResnet, self).__init__() 9 | self.resnet = resnet 10 | 11 | def forward(self, img, att_size=14): 12 | x = img.unsqueeze(0) 13 | 14 | x = self.resnet.conv1(x) 15 | x = self.resnet.bn1(x) 16 | x = self.resnet.relu(x) 17 | x = self.resnet.maxpool(x) 18 | 19 | x = self.resnet.layer1(x) 20 | x = self.resnet.layer2(x) 21 | x = self.resnet.layer3(x) 22 | x = self.resnet.layer4(x) 23 | 24 | fc = x.mean(3).mean(2).squeeze() 25 | att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0) 26 | 27 | return fc, att 28 | 29 | -------------------------------------------------------------------------------- /misc/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | import misc.utils as utils 8 | from collections import OrderedDict 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | import sys 13 | sys.path.append("cider") 14 | from pyciderevalcap.ciderD.ciderD import CiderD 15 | 16 | CiderD_scorer = None 17 | #CiderD_scorer = CiderD(df='corpus') 18 | 19 | def init_cider_scorer(cached_tokens): 20 | global CiderD_scorer 21 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 22 | 23 | def array_to_str(arr): 24 | out = '' 25 | for i in range(len(arr)): 26 | out += str(arr[i]) + ' ' 27 | if arr[i] == 0: 28 | break 29 | return out.strip() 30 | 31 | def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result): 32 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 33 | seq_per_img = batch_size // len(data['gts']) 34 | 35 | # get greedy decoding baseline 36 | greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True), Variable(att_feats.data, volatile=True)) 37 | 38 | res = OrderedDict() 39 | 40 | gen_result = gen_result.cpu().numpy() 41 | greedy_res = greedy_res.cpu().numpy() 42 | for i in range(batch_size): 43 | res[i] = [array_to_str(gen_result[i])] 44 | for i in range(batch_size): 45 | res[batch_size + i] = [array_to_str(greedy_res[i])] 46 | 47 | gts = OrderedDict() 48 | for i in range(len(data['gts'])): 49 | gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))] 50 | 51 | #_, scores = Bleu(4).compute_score(gts, res) 52 | #scores = np.array(scores[3]) 53 | res = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] 54 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} 55 | 
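    # CiderD scores all 2*batch_size captions in one call: the first half are the
    # sampled captions, the second half their greedy-decoded baselines, and both
    # halves share the same ground-truth captions (via i % batch_size). The
    # difference scores[:batch_size] - scores[batch_size:] computed below is the
    # per-sequence self-critical advantage, repeated across time steps as the reward.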
_, scores = CiderD_scorer.compute_score(gts, res) 56 | print('Cider scores:', _) 57 | 58 | scores = scores[:batch_size] - scores[batch_size:] 59 | 60 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 61 | 62 | return rewards 63 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | def if_use_att(caption_model): 12 | # Decide if load attention feature according to caption model 13 | if caption_model in ['show_tell', 'all_img', 'fc']: 14 | return False 15 | return True 16 | 17 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 18 | def decode_sequence(ix_to_word, seq): 19 | N, D = seq.size() 20 | out = [] 21 | for i in range(N): 22 | txt = '' 23 | for j in range(D): 24 | ix = seq[i,j] 25 | if ix > 0 : 26 | if j >= 1: 27 | txt = txt + ' ' 28 | txt = txt + ix_to_word[str(ix)] 29 | else: 30 | break 31 | out.append(txt) 32 | return out 33 | 34 | def to_contiguous(tensor): 35 | if tensor.is_contiguous(): 36 | return tensor 37 | else: 38 | return tensor.contiguous() 39 | 40 | class RewardCriterion(nn.Module): 41 | def __init__(self): 42 | super(RewardCriterion, self).__init__() 43 | 44 | def forward(self, input, seq, reward): 45 | input = to_contiguous(input).view(-1) 46 | reward = to_contiguous(reward).view(-1) 47 | mask = (seq>0).float() 48 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 49 | output = - input * reward * Variable(mask) 50 | output = torch.sum(output) / torch.sum(mask) 51 | 52 | return output 53 | class LanguageModelCriterion(nn.Module): 54 | def __init__(self): 55 | super(LanguageModelCriterion, self).__init__() 56 | 57 | def forward(self, input, target, mask): 58 | # truncate to the same size 59 | target = target[:, :input.size(1)] 60 | mask = mask[:, :input.size(1)] 61 | input = to_contiguous(input).view(-1, input.size(2)) 62 | target = to_contiguous(target).view(-1, 1) 63 | mask = to_contiguous(mask).view(-1, 1) 64 | output = - input.gather(1, target) * mask 65 | output = torch.sum(output) / torch.sum(mask) 66 | 67 | return output 68 | 69 | def set_lr(optimizer, lr): 70 | for group in optimizer.param_groups: 71 | group['lr'] = lr 72 | 73 | def clip_gradient(optimizer, grad_clip): 74 | for group in optimizer.param_groups: 75 | for param in group['params']: 76 | param.grad.data.clamp_(-grad_clip, grad_clip) -------------------------------------------------------------------------------- /models/Att2inModel.py: -------------------------------------------------------------------------------- 1 | # This file contains att2in model 2 | # Att2in is from Self-critical Sequence Training for Image Captioning 3 | # https://arxiv.org/abs/1612.00563 4 | # In this file we only have Att2in2, which is a slightly different version of att2in, 5 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
6 | 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | from .CaptionModel import CaptionModel 19 | 20 | class Att2inCore(nn.Module): 21 | def __init__(self, opt): 22 | super(Att2inCore, self).__init__() 23 | self.input_encoding_size = opt.input_encoding_size 24 | #self.rnn_type = opt.rnn_type 25 | self.rnn_size = opt.rnn_size 26 | #self.num_layers = opt.num_layers 27 | self.drop_prob_lm = opt.drop_prob_lm 28 | self.fc_feat_size = opt.fc_feat_size 29 | self.att_feat_size = opt.att_feat_size 30 | self.att_hid_size = opt.att_hid_size 31 | 32 | # Build a LSTM 33 | self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size) 34 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 35 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 36 | self.dropout = nn.Dropout(self.drop_prob_lm) 37 | 38 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 39 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 40 | 41 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 42 | # The p_att_feats here is already projected 43 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size 44 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 45 | 46 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size 47 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 48 | dot = att + att_h # batch * att_size * att_hid_size 49 | dot = F.tanh(dot) # batch * att_size * att_hid_size 50 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 51 | dot = self.alpha_net(dot) # (batch * att_size) * 1 52 | dot = dot.view(-1, att_size) # batch * att_size 53 | 54 | weight = F.softmax(dot) # batch * att_size 55 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size 56 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 57 | 58 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 59 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 60 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 61 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 62 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 63 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 64 | 65 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ 66 | self.a2c(att_res) 67 | in_transform = torch.max(\ 68 | in_transform.narrow(1, 0, self.rnn_size), 69 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 70 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 71 | next_h = out_gate * F.tanh(next_c) 72 | 73 | output = self.dropout(next_h) 74 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 75 | return output, state 76 | 77 | class Att2inModel(CaptionModel): 78 | def __init__(self, opt): 79 | super(Att2inModel, self).__init__() 80 | self.vocab_size = opt.vocab_size 81 | self.input_encoding_size = opt.input_encoding_size 82 | #self.rnn_type = opt.rnn_type 83 | self.rnn_size = opt.rnn_size 84 | self.num_layers = 1 85 | self.drop_prob_lm = opt.drop_prob_lm 86 | self.seq_length = opt.seq_length 87 | self.fc_feat_size = opt.fc_feat_size 88 | self.att_feat_size = opt.att_feat_size 89 | self.att_hid_size = opt.att_hid_size 90 | 91 | self.ss_prob = 0.0 # Schedule sampling 
probability 92 | 93 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 94 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 95 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 96 | self.core = Att2inCore(opt) 97 | 98 | self.init_weights() 99 | 100 | def init_weights(self): 101 | initrange = 0.1 102 | self.embed.weight.data.uniform_(-initrange, initrange) 103 | self.logit.bias.data.fill_(0) 104 | self.logit.weight.data.uniform_(-initrange, initrange) 105 | 106 | def init_hidden(self, bsz): 107 | weight = next(self.parameters()).data 108 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 109 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 110 | 111 | def forward(self, fc_feats, att_feats, seq): 112 | batch_size = fc_feats.size(0) 113 | state = self.init_hidden(batch_size) 114 | 115 | outputs = [] 116 | 117 | # Project the attention feats first to reduce memory and computation comsumptions. 118 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 119 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 120 | 121 | for i in range(seq.size(1) - 1): 122 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 123 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 124 | sample_mask = sample_prob < self.ss_prob 125 | if sample_mask.sum() == 0: 126 | it = seq[:, i].clone() 127 | else: 128 | sample_ind = sample_mask.nonzero().view(-1) 129 | it = seq[:, i].data.clone() 130 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 131 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 132 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 133 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 134 | it = Variable(it, requires_grad=False) 135 | else: 136 | it = seq[:, i].clone() 137 | # break if all the sequences end 138 | if i >= 1 and seq[:, i].data.sum() == 0: 139 | break 140 | 141 | xt = self.embed(it) 142 | 143 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 144 | output = F.log_softmax(self.logit(output)) 145 | outputs.append(output) 146 | 147 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 148 | 149 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state): 150 | # 'it' is Variable contraining a word index 151 | xt = self.embed(it) 152 | 153 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 154 | logprobs = F.log_softmax(self.logit(output)) 155 | 156 | return logprobs, state 157 | 158 | def sample_beam(self, fc_feats, att_feats, opt={}): 159 | beam_size = opt.get('beam_size', 10) 160 | batch_size = fc_feats.size(0) 161 | 162 | # Project the attention feats first to reduce memory and computation comsumptions. 163 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 164 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 165 | 166 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 167 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 168 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 169 | # lets process every image independently for now, for simplicity 170 | 171 | self.done_beams = [[] for _ in range(batch_size)] 172 | for k in range(batch_size): 173 | state = self.init_hidden(beam_size) 174 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size) 175 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 176 | tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 177 | 178 | for t in range(1): 179 | if t == 0: # input 180 | it = fc_feats.data.new(beam_size).long().zero_() 181 | xt = self.embed(Variable(it, requires_grad=False)) 182 | 183 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 184 | logprobs = F.log_softmax(self.logit(output)) 185 | 186 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, opt=opt) 187 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 188 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 189 | # return the samples and their log likelihoods 190 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 191 | 192 | def sample(self, fc_feats, att_feats, opt={}): 193 | sample_max = opt.get('sample_max', 1) 194 | beam_size = opt.get('beam_size', 1) 195 | temperature = opt.get('temperature', 1.0) 196 | if beam_size > 1: 197 | return self.sample_beam(fc_feats, att_feats, opt) 198 | 199 | batch_size = fc_feats.size(0) 200 | state = self.init_hidden(batch_size) 201 | 202 | # Project the attention feats first to reduce memory and computation comsumptions. 
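# (Added note, not in the original source:) ctx2att maps every spatial feature
# from att_feat_size to att_hid_size once up front, so inside the decoding loop
# the attention in Att2inCore only has to add the projected hidden state and
# apply alpha_net at each step. Shapes: att_feats is (batch, att_size,
# att_feat_size) and p_att_feats is (batch, att_size, att_hid_size).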
203 | p_att_feats = self.ctx2att(att_feats.view(-1, self.att_feat_size)) 204 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 205 | 206 | seq = [] 207 | seqLogprobs = [] 208 | for t in range(self.seq_length + 1): 209 | if t == 0: # input 210 | it = fc_feats.data.new(batch_size).long().zero_() 211 | elif sample_max: 212 | sampleLogprobs, it = torch.max(logprobs.data, 1) 213 | it = it.view(-1).long() 214 | else: 215 | if temperature == 1.0: 216 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 217 | else: 218 | # scale logprobs by temperature 219 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 220 | it = torch.multinomial(prob_prev, 1).cuda() 221 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 222 | it = it.view(-1).long() # and flatten indices for downstream processing 223 | 224 | xt = self.embed(Variable(it, requires_grad=False)) 225 | 226 | if t >= 1: 227 | # stop when all finished 228 | if t == 1: 229 | unfinished = it > 0 230 | else: 231 | unfinished = unfinished * (it > 0) 232 | if unfinished.sum() == 0: 233 | break 234 | it = it * unfinished.type_as(it) 235 | seq.append(it) #seq[t] the input of t+2 time step 236 | 237 | seqLogprobs.append(sampleLogprobs.view(-1)) 238 | 239 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 240 | logprobs = F.log_softmax(self.logit(output)) 241 | 242 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) -------------------------------------------------------------------------------- /models/AttModel.py: -------------------------------------------------------------------------------- 1 | # This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model 2 | 3 | # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning 4 | # https://arxiv.org/abs/1612.01887 5 | # AdaAttMO is a modified version with maxout lstm 6 | 7 | # Att2in is from Self-critical Sequence Training for Image Captioning 8 | # https://arxiv.org/abs/1612.00563 9 | # In this file we only have Att2in2, which is a slightly different version of att2in, 10 | # in which the img feature embedding and word embedding is the same as what in adaatt. 
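# (Added clarification, not in the original:) concretely, Att2in2Core below
# consumes attention features that AttModel has already embedded to rnn_size
# (its a2c layer maps rnn_size -> 2*rnn_size, whereas Att2inCore in
# models/Att2inModel.py maps att_feat_size -> 2*rnn_size), the word embedding
# gains ReLU and dropout, and Att2in2Model replaces fc_embed with the identity.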
11 | 12 | # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA 13 | # https://arxiv.org/abs/1707.07998 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.autograd import * 23 | import misc.utils as utils 24 | 25 | from .CaptionModel import CaptionModel 26 | 27 | class AttModel(CaptionModel): 28 | def __init__(self, opt): 29 | super(AttModel, self).__init__() 30 | self.vocab_size = opt.vocab_size 31 | self.input_encoding_size = opt.input_encoding_size 32 | #self.rnn_type = opt.rnn_type 33 | self.rnn_size = opt.rnn_size 34 | self.num_layers = opt.num_layers 35 | self.drop_prob_lm = opt.drop_prob_lm 36 | self.seq_length = opt.seq_length 37 | self.fc_feat_size = opt.fc_feat_size 38 | self.att_feat_size = opt.att_feat_size 39 | self.att_hid_size = opt.att_hid_size 40 | 41 | self.ss_prob = 0.0 # Schedule sampling probability 42 | 43 | self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), 44 | nn.ReLU(), 45 | nn.Dropout(self.drop_prob_lm)) 46 | self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), 47 | nn.ReLU(), 48 | nn.Dropout(self.drop_prob_lm)) 49 | self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), 50 | nn.ReLU(), 51 | nn.Dropout(self.drop_prob_lm)) 52 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 53 | self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) 54 | 55 | def init_hidden(self, bsz): 56 | weight = next(self.parameters()).data 57 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 58 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 59 | 60 | def forward(self, fc_feats, att_feats, seq): 61 | batch_size = fc_feats.size(0) 62 | state = self.init_hidden(batch_size) 63 | 64 | outputs = [] 65 | 66 | # embed fc and att feats 67 | fc_feats = self.fc_embed(fc_feats) 68 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 69 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 70 | 71 | # Project the attention feats first to reduce memory and computation comsumptions. 
72 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 73 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 74 | 75 | for i in range(seq.size(1) - 1): 76 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 77 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 78 | sample_mask = sample_prob < self.ss_prob 79 | if sample_mask.sum() == 0: 80 | it = seq[:, i].clone() 81 | else: 82 | sample_ind = sample_mask.nonzero().view(-1) 83 | it = seq[:, i].data.clone() 84 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 85 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 86 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 87 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 88 | it = Variable(it, requires_grad=False) 89 | else: 90 | it = seq[:, i].clone() 91 | # break if all the sequences end 92 | if i >= 1 and seq[:, i].data.sum() == 0: 93 | break 94 | 95 | xt = self.embed(it) 96 | 97 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 98 | output = F.log_softmax(self.logit(output)) 99 | outputs.append(output) 100 | 101 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 102 | 103 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state): 104 | # 'it' is Variable contraining a word index 105 | xt = self.embed(it) 106 | 107 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 108 | logprobs = F.log_softmax(self.logit(output)) 109 | 110 | return logprobs, state 111 | 112 | def sample_beam(self, fc_feats, att_feats, opt={}): 113 | beam_size = opt.get('beam_size', 10) 114 | batch_size = fc_feats.size(0) 115 | 116 | # embed fc and att feats 117 | fc_feats = self.fc_embed(fc_feats) 118 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 119 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 120 | 121 | # Project the attention feats first to reduce memory and computation comsumptions. 122 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 123 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 124 | 125 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 126 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 127 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 128 | # lets process every image independently for now, for simplicity 129 | 130 | self.done_beams = [[] for _ in range(batch_size)] 131 | for k in range(batch_size): 132 | state = self.init_hidden(beam_size) 133 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) 134 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 135 | tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() 136 | 137 | for t in range(1): 138 | if t == 0: # input 139 | it = fc_feats.data.new(beam_size).long().zero_() 140 | xt = self.embed(Variable(it, requires_grad=False)) 141 | 142 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) 143 | logprobs = F.log_softmax(self.logit(output)) 144 | 145 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, opt=opt) 146 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 147 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 148 | # return the samples and their log likelihoods 149 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 150 | 151 | def sample(self, fc_feats, att_feats, opt={}): 152 | sample_max = opt.get('sample_max', 1) 153 | beam_size = opt.get('beam_size', 1) 154 | temperature = opt.get('temperature', 1.0) 155 | if beam_size > 1: 156 | return self.sample_beam(fc_feats, att_feats, opt) 157 | 158 | batch_size = fc_feats.size(0) 159 | state = self.init_hidden(batch_size) 160 | 161 | # embed fc and att feats 162 | fc_feats = self.fc_embed(fc_feats) 163 | _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) 164 | att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) 165 | 166 | # Project the attention feats first to reduce memory and computation comsumptions. 
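# (Added overview of the decoding loop below, not in the original:) t == 0 feeds
# the START token (index 0); afterwards, when sample_max is set the argmax word
# is taken greedily, otherwise a word is drawn from the softmax distribution
# scaled by `temperature`. `unfinished` tracks which captions have not yet
# emitted the END token (index 0), and generation stops once all of them have.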
167 | p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) 168 | p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) 169 | 170 | seq = [] 171 | seqLogprobs = [] 172 | for t in range(self.seq_length + 1): 173 | if t == 0: # input 174 | it = fc_feats.data.new(batch_size).long().zero_() 175 | elif sample_max: 176 | sampleLogprobs, it = torch.max(logprobs.data, 1) 177 | it = it.view(-1).long() 178 | else: 179 | if temperature == 1.0: 180 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 181 | else: 182 | # scale logprobs by temperature 183 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 184 | it = torch.multinomial(prob_prev, 1).cuda() 185 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 186 | it = it.view(-1).long() # and flatten indices for downstream processing 187 | 188 | xt = self.embed(Variable(it, requires_grad=False)) 189 | 190 | if t >= 1: 191 | # stop when all finished 192 | if t == 1: 193 | unfinished = it > 0 194 | else: 195 | unfinished = unfinished * (it > 0) 196 | if unfinished.sum() == 0: 197 | break 198 | it = it * unfinished.type_as(it) 199 | seq.append(it) #seq[t] the input of t+2 time step 200 | 201 | seqLogprobs.append(sampleLogprobs.view(-1)) 202 | 203 | output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) 204 | logprobs = F.log_softmax(self.logit(output)) 205 | 206 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 207 | 208 | class AdaAtt_lstm(nn.Module): 209 | def __init__(self, opt, use_maxout=True): 210 | super(AdaAtt_lstm, self).__init__() 211 | self.input_encoding_size = opt.input_encoding_size 212 | #self.rnn_type = opt.rnn_type 213 | self.rnn_size = opt.rnn_size 214 | self.num_layers = opt.num_layers 215 | self.drop_prob_lm = opt.drop_prob_lm 216 | self.fc_feat_size = opt.fc_feat_size 217 | self.att_feat_size = opt.att_feat_size 218 | self.att_hid_size = opt.att_hid_size 219 | 220 | self.use_maxout = use_maxout 221 | 222 | # Build a LSTM 223 | self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size) 224 | self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) 225 | 226 | self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)]) 227 | self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)]) 228 | 229 | # Layers for getting the fake region 230 | if self.num_layers == 1: 231 | self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size) 232 | self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size) 233 | else: 234 | self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size) 235 | self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size) 236 | 237 | 238 | def forward(self, xt, img_fc, state): 239 | 240 | hs = [] 241 | cs = [] 242 | for L in range(self.num_layers): 243 | # c,h from previous timesteps 244 | prev_h = state[0][L] 245 | prev_c = state[1][L] 246 | # the input to this layer 247 | if L == 0: 248 | x = xt 249 | i2h = self.w2h(x) + self.v2h(img_fc) 250 | else: 251 | x = hs[-1] 252 | x = F.dropout(x, self.drop_prob_lm, self.training) 253 | i2h = self.i2h[L-1](x) 254 | 255 | all_input_sums = i2h+self.h2h[L](prev_h) 256 | 257 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 258 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 259 | # decode the 
gates 260 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 261 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 262 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 263 | # decode the write inputs 264 | if not self.use_maxout: 265 | in_transform = F.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size)) 266 | else: 267 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) 268 | in_transform = torch.max(\ 269 | in_transform.narrow(1, 0, self.rnn_size), 270 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 271 | # perform the LSTM update 272 | next_c = forget_gate * prev_c + in_gate * in_transform 273 | # gated cells form the output 274 | tanh_nex_c = F.tanh(next_c) 275 | next_h = out_gate * tanh_nex_c 276 | if L == self.num_layers-1: 277 | if L == 0: 278 | i2h = self.r_w2h(x) + self.r_v2h(img_fc) 279 | else: 280 | i2h = self.r_i2h(x) 281 | n5 = i2h+self.r_h2h(prev_h) 282 | fake_region = F.sigmoid(n5) * tanh_nex_c 283 | 284 | cs.append(next_c) 285 | hs.append(next_h) 286 | 287 | # set up the decoder 288 | top_h = hs[-1] 289 | top_h = F.dropout(top_h, self.drop_prob_lm, self.training) 290 | fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training) 291 | 292 | state = (torch.cat([_.unsqueeze(0) for _ in hs], 0), 293 | torch.cat([_.unsqueeze(0) for _ in cs], 0)) 294 | return top_h, fake_region, state 295 | 296 | class AdaAtt_attention(nn.Module): 297 | def __init__(self, opt): 298 | super(AdaAtt_attention, self).__init__() 299 | self.input_encoding_size = opt.input_encoding_size 300 | #self.rnn_type = opt.rnn_type 301 | self.rnn_size = opt.rnn_size 302 | self.drop_prob_lm = opt.drop_prob_lm 303 | self.att_hid_size = opt.att_hid_size 304 | 305 | # fake region embed 306 | self.fr_linear = nn.Sequential( 307 | nn.Linear(self.rnn_size, self.input_encoding_size), 308 | nn.ReLU(), 309 | nn.Dropout(self.drop_prob_lm)) 310 | self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 311 | 312 | # h out embed 313 | self.ho_linear = nn.Sequential( 314 | nn.Linear(self.rnn_size, self.input_encoding_size), 315 | nn.Tanh(), 316 | nn.Dropout(self.drop_prob_lm)) 317 | self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size) 318 | 319 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 320 | self.att2h = nn.Linear(self.rnn_size, self.rnn_size) 321 | 322 | def forward(self, h_out, fake_region, conv_feat, conv_feat_embed): 323 | 324 | # View into three dimensions 325 | att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size 326 | conv_feat = conv_feat.view(-1, att_size, self.rnn_size) 327 | conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size) 328 | 329 | # view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num 330 | fake_region = self.fr_linear(fake_region) 331 | fake_region_embed = self.fr_embed(fake_region) 332 | 333 | h_out_linear = self.ho_linear(h_out) 334 | h_out_embed = self.ho_embed(h_out_linear) 335 | 336 | txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1)) 337 | 338 | img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1) 339 | img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1) 340 | 341 | hA = F.tanh(img_all_embed + txt_replicate) 342 | hA = F.dropout(hA,self.drop_prob_lm, self.training) 343 | 344 | hAflat = self.alpha_net(hA.view(-1, self.att_hid_size)) 345 | PI = 
F.softmax(hAflat.view(-1, att_size + 1)) 346 | 347 | visAtt = torch.bmm(PI.unsqueeze(1), img_all) 348 | visAttdim = visAtt.squeeze(1) 349 | 350 | atten_out = visAttdim + h_out_linear 351 | 352 | h = F.tanh(self.att2h(atten_out)) 353 | h = F.dropout(h, self.drop_prob_lm, self.training) 354 | return h 355 | 356 | class AdaAttCore(nn.Module): 357 | def __init__(self, opt, use_maxout=False): 358 | super(AdaAttCore, self).__init__() 359 | self.lstm = AdaAtt_lstm(opt, use_maxout) 360 | self.attention = AdaAtt_attention(opt) 361 | 362 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 363 | h_out, p_out, state = self.lstm(xt, fc_feats, state) 364 | atten_out = self.attention(h_out, p_out, att_feats, p_att_feats) 365 | return atten_out, state 366 | 367 | class TopDownCore(nn.Module): 368 | def __init__(self, opt, use_maxout=False): 369 | super(TopDownCore, self).__init__() 370 | self.drop_prob_lm = opt.drop_prob_lm 371 | 372 | self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1 373 | self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v 374 | self.attention = Attention(opt) 375 | 376 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 377 | prev_h = state[0][-1] 378 | att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) 379 | 380 | h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) 381 | 382 | att = self.attention(h_att, att_feats, p_att_feats) 383 | 384 | lang_lstm_input = torch.cat([att, h_att], 1) 385 | # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? 386 | 387 | h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) 388 | 389 | output = F.dropout(h_lang, self.drop_prob_lm, self.training) 390 | state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 391 | 392 | return output, state 393 | 394 | class Attention(nn.Module): 395 | def __init__(self, opt): 396 | super(Attention, self).__init__() 397 | self.rnn_size = opt.rnn_size 398 | self.att_hid_size = opt.att_hid_size 399 | 400 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 401 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 402 | 403 | def forward(self, h, att_feats, p_att_feats): 404 | # The p_att_feats here is already projected 405 | att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size 406 | att = p_att_feats.view(-1, att_size, self.att_hid_size) 407 | 408 | att_h = self.h2att(h) # batch * att_hid_size 409 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 410 | dot = att + att_h # batch * att_size * att_hid_size 411 | dot = F.tanh(dot) # batch * att_size * att_hid_size 412 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 413 | dot = self.alpha_net(dot) # (batch * att_size) * 1 414 | dot = dot.view(-1, att_size) # batch * att_size 415 | 416 | weight = F.softmax(dot) # batch * att_size 417 | att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size 418 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 419 | 420 | return att_res 421 | 422 | 423 | class Att2in2Core(nn.Module): 424 | def __init__(self, opt): 425 | super(Att2in2Core, self).__init__() 426 | self.input_encoding_size = opt.input_encoding_size 427 | #self.rnn_type = opt.rnn_type 428 | self.rnn_size = opt.rnn_size 429 | #self.num_layers = opt.num_layers 430 | self.drop_prob_lm = opt.drop_prob_lm 431 | self.fc_feat_size = 
opt.fc_feat_size 432 | self.att_feat_size = opt.att_feat_size 433 | self.att_hid_size = opt.att_hid_size 434 | 435 | # Build a LSTM 436 | self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size) 437 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 438 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 439 | self.dropout = nn.Dropout(self.drop_prob_lm) 440 | 441 | self.attention = Attention(opt) 442 | 443 | def forward(self, xt, fc_feats, att_feats, p_att_feats, state): 444 | att_res = self.attention(state[0][-1], att_feats, p_att_feats) 445 | 446 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 447 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 448 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 449 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 450 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 451 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 452 | 453 | in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \ 454 | self.a2c(att_res) 455 | in_transform = torch.max(\ 456 | in_transform.narrow(1, 0, self.rnn_size), 457 | in_transform.narrow(1, self.rnn_size, self.rnn_size)) 458 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 459 | next_h = out_gate * F.tanh(next_c) 460 | 461 | output = self.dropout(next_h) 462 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 463 | return output, state 464 | 465 | class AdaAttModel(AttModel): 466 | def __init__(self, opt): 467 | super(AdaAttModel, self).__init__(opt) 468 | self.core = AdaAttCore(opt) 469 | 470 | # AdaAtt with maxout lstm 471 | class AdaAttMOModel(AttModel): 472 | def __init__(self, opt): 473 | super(AdaAttMOModel, self).__init__(opt) 474 | self.core = AdaAttCore(opt, True) 475 | 476 | class Att2in2Model(AttModel): 477 | def __init__(self, opt): 478 | super(Att2in2Model, self).__init__(opt) 479 | self.core = Att2in2Core(opt) 480 | delattr(self, 'fc_embed') 481 | self.fc_embed = lambda x : x 482 | 483 | class TopDownModel(AttModel): 484 | def __init__(self, opt): 485 | super(TopDownModel, self).__init__(opt) 486 | self.num_layers = 2 487 | self.core = TopDownCore(opt) 488 | -------------------------------------------------------------------------------- /models/CaptionModel.py: -------------------------------------------------------------------------------- 1 | # This file contains ShowAttendTell and AllImg model 2 | 3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention 4 | # https://arxiv.org/abs/1502.03044 5 | 6 | # AllImg is a model where 7 | # img feature is concatenated with word embedding at every time step as the input of lstm 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | 19 | class CaptionModel(nn.Module): 20 | def __init__(self): 21 | super(CaptionModel, self).__init__() 22 | 23 | def beam_search(self, state, logprobs, *args, **kwargs): 24 | # args are the miscelleous inputs to the core in addition to embedded word and state 25 | # kwargs only accept opt 26 | 27 | def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): 28 | #INPUTS: 29 | #logprobsf: probabilities augmented after diversity 30 | #beam_size: obvious 31 | #t : time instant 32 | #beam_seq : tensor 
contanining the beams 33 | #beam_seq_logprobs: tensor contanining the beam logprobs 34 | #beam_logprobs_sum: tensor contanining joint logprobs 35 | #OUPUTS: 36 | #beam_seq : tensor containing the word indices of the decoded captions 37 | #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq 38 | #beam_logprobs_sum : joint log-probability of each beam 39 | 40 | ys,ix = torch.sort(logprobsf,1,True) 41 | candidates = [] 42 | cols = min(beam_size, ys.size(1)) 43 | rows = beam_size 44 | if t == 0: 45 | rows = 1 46 | for c in range(cols): # for each column (word, essentially) 47 | for q in range(rows): # for each beam expansion 48 | #compute logprob of expanding beam q with word in (sorted) position c 49 | local_logprob = ys[q,c] 50 | candidate_logprob = beam_logprobs_sum[q] + local_logprob 51 | candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_logprob}) 52 | candidates = sorted(candidates, key=lambda x: -x['p']) 53 | 54 | new_state = [_.clone() for _ in state] 55 | #beam_seq_prev, beam_seq_logprobs_prev 56 | if t >= 1: 57 | #we''ll need these as reference when we fork beams around 58 | beam_seq_prev = beam_seq[:t].clone() 59 | beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() 60 | for vix in range(beam_size): 61 | v = candidates[vix] 62 | #fork beam index q into index vix 63 | if t >= 1: 64 | beam_seq[:t, vix] = beam_seq_prev[:, v['q']] 65 | beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] 66 | #rearrange recurrent states 67 | for state_ix in range(len(new_state)): 68 | # copy over state in previous beam q to new beam at vix 69 | new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step 70 | #append new end terminal at the end of this beam 71 | beam_seq[t, vix] = v['c'] # c'th word is the continuation 72 | beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here 73 | beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam 74 | state = new_state 75 | return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates 76 | 77 | # start beam search 78 | opt = kwargs['opt'] 79 | beam_size = opt.get('beam_size', 10) 80 | 81 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 82 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 83 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 84 | done_beams = [] 85 | 86 | for t in range(self.seq_length): 87 | """pem a beam merge. that is, 88 | for every previous beam we now many new possibilities to branch out 89 | we need to resort our beams to maintain the loop invariant of keeping 90 | the top beam_size most likely sequences.""" 91 | logprobsf = logprobs.data.float() # lets go to CPU for more efficiency in indexing operations 92 | # suppress UNK tokens in the decoding 93 | logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 94 | 95 | beam_seq,\ 96 | beam_seq_logprobs,\ 97 | beam_logprobs_sum,\ 98 | state,\ 99 | candidates_divm = beam_step(logprobsf, 100 | beam_size, 101 | t, 102 | beam_seq, 103 | beam_seq_logprobs, 104 | beam_logprobs_sum, 105 | state) 106 | 107 | for vix in range(beam_size): 108 | # if time's up... 
or if end token is reached then copy beams 109 | if beam_seq[t, vix] == 0 or t == self.seq_length - 1: 110 | final_beam = { 111 | 'seq': beam_seq[:, vix].clone(), 112 | 'logps': beam_seq_logprobs[:, vix].clone(), 113 | 'p': beam_logprobs_sum[vix] 114 | } 115 | done_beams.append(final_beam) 116 | # don't continue beams from finished sequences 117 | beam_logprobs_sum[vix] = -1000 118 | 119 | # encode as vectors 120 | it = beam_seq[t] 121 | logprobs, state = self.get_logprobs_state(Variable(it.cuda()), *(args + (state,))) 122 | 123 | done_beams = sorted(done_beams, key=lambda x: -x['p'])[:beam_size] 124 | return done_beams 125 | -------------------------------------------------------------------------------- /models/FCModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class LSTMCore(nn.Module): 14 | def __init__(self, opt): 15 | super(LSTMCore, self).__init__() 16 | self.input_encoding_size = opt.input_encoding_size 17 | self.rnn_size = opt.rnn_size 18 | self.drop_prob_lm = opt.drop_prob_lm 19 | 20 | # Build a LSTM 21 | self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) 22 | self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) 23 | self.dropout = nn.Dropout(self.drop_prob_lm) 24 | 25 | def forward(self, xt, state): 26 | 27 | all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) 28 | sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) 29 | sigmoid_chunk = F.sigmoid(sigmoid_chunk) 30 | in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) 31 | forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) 32 | out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) 33 | 34 | in_transform = torch.max(\ 35 | all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size), 36 | all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size)) 37 | next_c = forget_gate * state[1][-1] + in_gate * in_transform 38 | next_h = out_gate * F.tanh(next_c) 39 | 40 | next_h = self.dropout(next_h) 41 | 42 | output = next_h 43 | state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) 44 | return output, state 45 | 46 | class FCModel(CaptionModel): 47 | def __init__(self, opt): 48 | super(FCModel, self).__init__() 49 | self.vocab_size = opt.vocab_size 50 | self.input_encoding_size = opt.input_encoding_size 51 | self.rnn_type = opt.rnn_type 52 | self.rnn_size = opt.rnn_size 53 | self.num_layers = opt.num_layers 54 | self.drop_prob_lm = opt.drop_prob_lm 55 | self.seq_length = opt.seq_length 56 | self.fc_feat_size = opt.fc_feat_size 57 | 58 | self.ss_prob = 0.0 # Schedule sampling probability 59 | 60 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 61 | self.core = LSTMCore(opt) 62 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 63 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 64 | 65 | self.init_weights() 66 | 67 | def init_weights(self): 68 | initrange = 0.1 69 | self.embed.weight.data.uniform_(-initrange, initrange) 70 | self.logit.bias.data.fill_(0) 71 | self.logit.weight.data.uniform_(-initrange, initrange) 72 | 73 | def init_hidden(self, bsz): 74 | weight = next(self.parameters()).data 75 | if self.rnn_type == 'lstm': 76 | return 
(Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 77 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 78 | else: 79 | return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) 80 | 81 | def forward(self, fc_feats, att_feats, seq): 82 | batch_size = fc_feats.size(0) 83 | state = self.init_hidden(batch_size) 84 | outputs = [] 85 | 86 | for i in range(seq.size(1)): 87 | if i == 0: 88 | xt = self.img_embed(fc_feats) 89 | else: 90 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 91 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 92 | sample_mask = sample_prob < self.ss_prob 93 | if sample_mask.sum() == 0: 94 | it = seq[:, i-1].clone() 95 | else: 96 | sample_ind = sample_mask.nonzero().view(-1) 97 | it = seq[:, i-1].data.clone() 98 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 99 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 100 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 101 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 102 | it = Variable(it, requires_grad=False) 103 | else: 104 | it = seq[:, i-1].clone() 105 | # break if all the sequences end 106 | if i >= 2 and seq[:, i-1].data.sum() == 0: 107 | break 108 | xt = self.embed(it) 109 | 110 | output, state = self.core(xt, state) 111 | output = F.log_softmax(self.logit(output)) 112 | outputs.append(output) 113 | 114 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 115 | 116 | def get_logprobs_state(self, it, state): 117 | # 'it' is Variable contraining a word index 118 | xt = self.embed(it) 119 | 120 | output, state = self.core(xt, state) 121 | logprobs = F.log_softmax(self.logit(output)) 122 | 123 | return logprobs, state 124 | 125 | def sample_beam(self, fc_feats, att_feats, opt={}): 126 | beam_size = opt.get('beam_size', 10) 127 | batch_size = fc_feats.size(0) 128 | 129 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 130 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 131 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 132 | # lets process every image independently for now, for simplicity 133 | 134 | self.done_beams = [[] for _ in range(batch_size)] 135 | for k in range(batch_size): 136 | state = self.init_hidden(beam_size) 137 | for t in range(2): 138 | if t == 0: 139 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 140 | elif t == 1: # input 141 | it = fc_feats.data.new(beam_size).long().zero_() 142 | xt = self.embed(Variable(it, requires_grad=False)) 143 | 144 | output, state = self.core(xt, state) 145 | logprobs = F.log_softmax(self.logit(output)) 146 | 147 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 148 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 149 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 150 | # return the samples and their log likelihoods 151 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 152 | 153 | def sample(self, fc_feats, att_feats, opt={}): 154 | sample_max = opt.get('sample_max', 1) 155 | beam_size = opt.get('beam_size', 1) 156 | temperature = opt.get('temperature', 1.0) 157 | if beam_size > 1: 158 | return self.sample_beam(fc_feats, att_feats, opt) 159 | 160 | batch_size = fc_feats.size(0) 161 | state = self.init_hidden(batch_size) 162 | seq = [] 163 | seqLogprobs = [] 164 | for t in range(self.seq_length + 2): 165 | if t == 0: 166 | xt = self.img_embed(fc_feats) 167 | else: 168 | if t == 1: # input 169 | it = fc_feats.data.new(batch_size).long().zero_() 170 | elif sample_max: 171 | sampleLogprobs, it = torch.max(logprobs.data, 1) 172 | it = it.view(-1).long() 173 | else: 174 | if temperature == 1.0: 175 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 176 | else: 177 | # scale logprobs by temperature 178 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 179 | it = torch.multinomial(prob_prev, 1).cuda() 180 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 181 | it = it.view(-1).long() # and flatten indices for downstream processing 182 | 183 | xt = self.embed(Variable(it, requires_grad=False)) 184 | 185 | if t >= 2: 186 | # stop when all finished 187 | if t == 2: 188 | unfinished = it > 0 189 | else: 190 | unfinished = unfinished * (it > 0) 191 | if unfinished.sum() == 0: 192 | break 193 | it = it * unfinished.type_as(it) 194 | seq.append(it) #seq[t] the input of t+2 time step 195 | seqLogprobs.append(sampleLogprobs.view(-1)) 196 | 197 | output, state = self.core(xt, state) 198 | logprobs = F.log_softmax(self.logit(output)) 199 | 200 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 201 | 202 | 203 | -------------------------------------------------------------------------------- /models/OldModel.py: -------------------------------------------------------------------------------- 1 | # This file contains ShowAttendTell and AllImg model 2 | 3 | # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention 4 | # https://arxiv.org/abs/1502.03044 5 | 6 | # AllImg is a model where 7 | # img feature is concatenated with word embedding at every time step as the input of lstm 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import 
print_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.autograd import * 16 | import misc.utils as utils 17 | 18 | from .CaptionModel import CaptionModel 19 | 20 | class OldModel(CaptionModel): 21 | def __init__(self, opt): 22 | super(OldModel, self).__init__() 23 | self.vocab_size = opt.vocab_size 24 | self.input_encoding_size = opt.input_encoding_size 25 | self.rnn_type = opt.rnn_type 26 | self.rnn_size = opt.rnn_size 27 | self.num_layers = opt.num_layers 28 | self.drop_prob_lm = opt.drop_prob_lm 29 | self.seq_length = opt.seq_length 30 | self.fc_feat_size = opt.fc_feat_size 31 | self.att_feat_size = opt.att_feat_size 32 | 33 | self.ss_prob = 0.0 # Schedule sampling probability 34 | 35 | self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size 36 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 37 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 38 | self.dropout = nn.Dropout(self.drop_prob_lm) 39 | 40 | self.init_weights() 41 | 42 | def init_weights(self): 43 | initrange = 0.1 44 | self.embed.weight.data.uniform_(-initrange, initrange) 45 | self.logit.bias.data.fill_(0) 46 | self.logit.weight.data.uniform_(-initrange, initrange) 47 | 48 | def init_hidden(self, fc_feats): 49 | image_map = self.linear(fc_feats).view(-1, self.num_layers, self.rnn_size).transpose(0, 1) 50 | if self.rnn_type == 'lstm': 51 | return (image_map, image_map) 52 | else: 53 | return image_map 54 | 55 | def forward(self, fc_feats, att_feats, seq): 56 | batch_size = fc_feats.size(0) 57 | state = self.init_hidden(fc_feats) 58 | 59 | outputs = [] 60 | 61 | for i in range(seq.size(1) - 1): 62 | if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample 63 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 64 | sample_mask = sample_prob < self.ss_prob 65 | if sample_mask.sum() == 0: 66 | it = seq[:, i].clone() 67 | else: 68 | sample_ind = sample_mask.nonzero().view(-1) 69 | it = seq[:, i].data.clone() 70 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 71 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 72 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 73 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 74 | it = Variable(it, requires_grad=False) 75 | else: 76 | it = seq[:, i].clone() 77 | # break if all the sequences end 78 | if i >= 1 and seq[:, i].data.sum() == 0: 79 | break 80 | 81 | xt = self.embed(it) 82 | 83 | output, state = self.core(xt, fc_feats, att_feats, state) 84 | output = F.log_softmax(self.logit(self.dropout(output))) 85 | outputs.append(output) 86 | 87 | return torch.cat([_.unsqueeze(1) for _ in outputs], 1) 88 | 89 | def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, state): 90 | # 'it' is Variable contraining a word index 91 | xt = self.embed(it) 92 | 93 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 94 | logprobs = F.log_softmax(self.logit(self.dropout(output))) 95 | 96 | return logprobs, state 97 | 98 | def sample_beam(self, fc_feats, att_feats, opt={}): 99 | beam_size = opt.get('beam_size', 10) 100 | batch_size = fc_feats.size(0) 101 | 102 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' 103 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 104 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 105 | # lets process every image independently for now, for simplicity 106 | 107 | self.done_beams = [[] for _ in range(batch_size)] 108 | for k in range(batch_size): 109 | tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size) 110 | tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() 111 | 112 | state = self.init_hidden(tmp_fc_feats) 113 | 114 | beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() 115 | beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() 116 | beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam 117 | done_beams = [] 118 | for t in range(1): 119 | if t == 0: # input 120 | it = fc_feats.data.new(beam_size).long().zero_() 121 | xt = self.embed(Variable(it, requires_grad=False)) 122 | 123 | output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) 124 | logprobs = F.log_softmax(self.logit(self.dropout(output))) 125 | 126 | self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt) 127 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 128 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 129 | # return the samples and their log likelihoods 130 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 131 | 132 | def sample(self, fc_feats, att_feats, opt={}): 133 | sample_max = opt.get('sample_max', 1) 134 | beam_size = opt.get('beam_size', 1) 135 | temperature = opt.get('temperature', 1.0) 136 | if beam_size > 1: 137 | return self.sample_beam(fc_feats, att_feats, opt) 138 | 139 | batch_size = fc_feats.size(0) 140 | state = self.init_hidden(fc_feats) 141 | 142 | seq = [] 143 | seqLogprobs = [] 144 | for t in range(self.seq_length + 1): 145 | if t == 0: # input 146 | it = fc_feats.data.new(batch_size).long().zero_() 147 | elif sample_max: 148 | sampleLogprobs, it = torch.max(logprobs.data, 1) 149 | it = it.view(-1).long() 150 | else: 151 | if temperature == 1.0: 152 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 153 | else: 154 | # scale logprobs by temperature 155 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 156 | it = torch.multinomial(prob_prev, 1).cuda() 157 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 158 | it = it.view(-1).long() # and flatten indices for downstream processing 159 | 160 | xt = self.embed(Variable(it, requires_grad=False)) 161 | 162 | if t >= 1: 163 | # stop when all finished 164 | if t == 1: 165 | unfinished = it > 0 166 | else: 167 | unfinished = unfinished * (it > 0) 168 | if unfinished.sum() == 0: 169 | break 170 | it = it * unfinished.type_as(it) 171 | seq.append(it) #seq[t] the input of t+2 time step 172 | seqLogprobs.append(sampleLogprobs.view(-1)) 173 | 174 | output, state = self.core(xt, fc_feats, att_feats, state) 175 | logprobs = F.log_softmax(self.logit(self.dropout(output))) 176 | 177 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) 178 | 179 | 180 | class ShowAttendTellCore(nn.Module): 181 | def __init__(self, opt): 182 | super(ShowAttendTellCore, self).__init__() 183 | self.input_encoding_size = opt.input_encoding_size 184 | self.rnn_type = opt.rnn_type 185 | self.rnn_size = opt.rnn_size 
186 | self.num_layers = opt.num_layers 187 | self.drop_prob_lm = opt.drop_prob_lm 188 | self.fc_feat_size = opt.fc_feat_size 189 | self.att_feat_size = opt.att_feat_size 190 | self.att_hid_size = opt.att_hid_size 191 | 192 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.att_feat_size, 193 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 194 | 195 | if self.att_hid_size > 0: 196 | self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size) 197 | self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) 198 | self.alpha_net = nn.Linear(self.att_hid_size, 1) 199 | else: 200 | self.ctx2att = nn.Linear(self.att_feat_size, 1) 201 | self.h2att = nn.Linear(self.rnn_size, 1) 202 | 203 | def forward(self, xt, fc_feats, att_feats, state): 204 | att_size = att_feats.numel() // att_feats.size(0) // self.att_feat_size 205 | att = att_feats.view(-1, self.att_feat_size) 206 | if self.att_hid_size > 0: 207 | att = self.ctx2att(att) # (batch * att_size) * att_hid_size 208 | att = att.view(-1, att_size, self.att_hid_size) # batch * att_size * att_hid_size 209 | att_h = self.h2att(state[0][-1]) # batch * att_hid_size 210 | att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size 211 | dot = att + att_h # batch * att_size * att_hid_size 212 | dot = F.tanh(dot) # batch * att_size * att_hid_size 213 | dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size 214 | dot = self.alpha_net(dot) # (batch * att_size) * 1 215 | dot = dot.view(-1, att_size) # batch * att_size 216 | else: 217 | att = self.ctx2att(att)(att) # (batch * att_size) * 1 218 | att = att.view(-1, att_size) # batch * att_size 219 | att_h = self.h2att(state[0][-1]) # batch * 1 220 | att_h = att_h.expand_as(att) # batch * att_size 221 | dot = att_h + att # batch * att_size 222 | 223 | weight = F.softmax(dot) 224 | att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size 225 | att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size 226 | 227 | output, state = self.rnn(torch.cat([xt, att_res], 1).unsqueeze(0), state) 228 | return output.squeeze(0), state 229 | 230 | class AllImgCore(nn.Module): 231 | def __init__(self, opt): 232 | super(AllImgCore, self).__init__() 233 | self.input_encoding_size = opt.input_encoding_size 234 | self.rnn_type = opt.rnn_type 235 | self.rnn_size = opt.rnn_size 236 | self.num_layers = opt.num_layers 237 | self.drop_prob_lm = opt.drop_prob_lm 238 | self.fc_feat_size = opt.fc_feat_size 239 | 240 | self.rnn = getattr(nn, self.rnn_type.upper())(self.input_encoding_size + self.fc_feat_size, 241 | self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 242 | 243 | def forward(self, xt, fc_feats, att_feats, state): 244 | output, state = self.rnn(torch.cat([xt, fc_feats], 1).unsqueeze(0), state) 245 | return output.squeeze(0), state 246 | 247 | class ShowAttendTellModel(OldModel): 248 | def __init__(self, opt): 249 | super(ShowAttendTellModel, self).__init__(opt) 250 | self.core = ShowAttendTellCore(opt) 251 | 252 | class AllImgModel(OldModel): 253 | def __init__(self, opt): 254 | super(AllImgModel, self).__init__(opt) 255 | self.core = AllImgCore(opt) 256 | 257 | -------------------------------------------------------------------------------- /models/ShowTellModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import 
print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import * 9 | import misc.utils as utils 10 | 11 | from .CaptionModel import CaptionModel 12 | 13 | class ShowTellModel(CaptionModel): 14 | def __init__(self, opt): 15 | super(ShowTellModel, self).__init__() 16 | self.vocab_size = opt.vocab_size 17 | self.input_encoding_size = opt.input_encoding_size 18 | self.rnn_type = opt.rnn_type 19 | self.rnn_size = opt.rnn_size 20 | self.num_layers = opt.num_layers 21 | self.drop_prob_lm = opt.drop_prob_lm 22 | self.seq_length = opt.seq_length 23 | self.fc_feat_size = opt.fc_feat_size 24 | 25 | self.ss_prob = 0.0 # Schedule sampling probability 26 | 27 | self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size) 28 | self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm) 29 | self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size) 30 | self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) 31 | self.dropout = nn.Dropout(self.drop_prob_lm) 32 | 33 | self.init_weights() 34 | 35 | def init_weights(self): 36 | initrange = 0.1 37 | self.embed.weight.data.uniform_(-initrange, initrange) 38 | self.logit.bias.data.fill_(0) 39 | self.logit.weight.data.uniform_(-initrange, initrange) 40 | 41 | def init_hidden(self, bsz): 42 | weight = next(self.parameters()).data 43 | if self.rnn_type == 'lstm': 44 | return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), 45 | Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) 46 | else: 47 | return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) 48 | 49 | def forward(self, fc_feats, att_feats, seq): 50 | batch_size = fc_feats.size(0) 51 | state = self.init_hidden(batch_size) 52 | outputs = [] 53 | 54 | for i in range(seq.size(1)): 55 | if i == 0: 56 | xt = self.img_embed(fc_feats) 57 | else: 58 | if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample 59 | sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) 60 | sample_mask = sample_prob < self.ss_prob 61 | if sample_mask.sum() == 0: 62 | it = seq[:, i-1].clone() 63 | else: 64 | sample_ind = sample_mask.nonzero().view(-1) 65 | it = seq[:, i-1].data.clone() 66 | #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) 67 | #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) 68 | prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) 69 | it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) 70 | it = Variable(it, requires_grad=False) 71 | else: 72 | it = seq[:, i-1].clone() 73 | # break if all the sequences end 74 | if i >= 2 and seq[:, i-1].data.sum() == 0: 75 | break 76 | xt = self.embed(it) 77 | 78 | output, state = self.core(xt.unsqueeze(0), state) 79 | output = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) 80 | outputs.append(output) 81 | 82 | return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 83 | 84 | def get_logprobs_state(self, it, state): 85 | # 'it' is Variable contraining a word index 86 | xt = self.embed(it) 87 | 88 | output, state = self.core(xt.unsqueeze(0), state) 89 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) 90 | 91 | return logprobs, state 92 | 93 | def sample_beam(self, fc_feats, att_feats, opt={}): 94 | beam_size = 
opt.get('beam_size', 10) 95 | batch_size = fc_feats.size(0) 96 | 97 | assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' 98 | seq = torch.LongTensor(self.seq_length, batch_size).zero_() 99 | seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) 100 | # lets process every image independently for now, for simplicity 101 | 102 | self.done_beams = [[] for _ in range(batch_size)] 103 | for k in range(batch_size): 104 | state = self.init_hidden(beam_size) 105 | for t in range(2): 106 | if t == 0: 107 | xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) 108 | elif t == 1: # input 109 | it = fc_feats.data.new(beam_size).long().zero_() 110 | xt = self.embed(Variable(it, requires_grad=False)) 111 | 112 | output, state = self.core(xt.unsqueeze(0), state) 113 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) 114 | 115 | self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) 116 | seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score 117 | seqLogprobs[:, k] = self.done_beams[k][0]['logps'] 118 | # return the samples and their log likelihoods 119 | return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) 120 | 121 | def sample(self, fc_feats, att_feats, opt={}): 122 | sample_max = opt.get('sample_max', 1) 123 | beam_size = opt.get('beam_size', 1) 124 | temperature = opt.get('temperature', 1.0) 125 | if beam_size > 1: 126 | return self.sample_beam(fc_feats, att_feats, opt) 127 | 128 | batch_size = fc_feats.size(0) 129 | state = self.init_hidden(batch_size) 130 | seq = [] 131 | seqLogprobs = [] 132 | for t in range(self.seq_length + 2): 133 | if t == 0: 134 | xt = self.img_embed(fc_feats) 135 | else: 136 | if t == 1: # input 137 | it = fc_feats.data.new(batch_size).long().zero_() 138 | elif sample_max: 139 | sampleLogprobs, it = torch.max(logprobs.data, 1) 140 | it = it.view(-1).long() 141 | else: 142 | if temperature == 1.0: 143 | prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) 144 | else: 145 | # scale logprobs by temperature 146 | prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() 147 | it = torch.multinomial(prob_prev, 1).cuda() 148 | sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions 149 | it = it.view(-1).long() # and flatten indices for downstream processing 150 | 151 | xt = self.embed(Variable(it, requires_grad=False)) 152 | 153 | if t >= 2: 154 | # stop when all finished 155 | if t == 2: 156 | unfinished = it > 0 157 | else: 158 | unfinished = unfinished * (it > 0) 159 | if unfinished.sum() == 0: 160 | break 161 | it = it * unfinished.type_as(it) 162 | seq.append(it) #seq[t] the input of t+2 time step 163 | seqLogprobs.append(sampleLogprobs.view(-1)) 164 | 165 | output, state = self.core(xt.unsqueeze(0), state) 166 | logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) 167 | 168 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | 
import misc.utils as utils 10 | import torch 11 | 12 | from .ShowTellModel import ShowTellModel 13 | from .FCModel import FCModel 14 | from .OldModel import ShowAttendTellModel, AllImgModel 15 | from .Att2inModel import Att2inModel 16 | from .AttModel import * 17 | 18 | def setup(opt): 19 | 20 | if opt.caption_model == 'fc': 21 | model = FCModel(opt) 22 | # Att2in model in self-critical 23 | elif opt.caption_model == 'att2in': 24 | model = Att2inModel(opt) 25 | # Att2in model with two-layer MLP img embedding and word embedding 26 | elif opt.caption_model == 'att2in2': 27 | model = Att2in2Model(opt) 28 | # Adaptive Attention model from Knowing when to look 29 | elif opt.caption_model == 'adaatt': 30 | model = AdaAttModel(opt) 31 | # Adaptive Attention with maxout lstm 32 | elif opt.caption_model == 'adaattmo': 33 | model = AdaAttMOModel(opt) 34 | # Top-down attention model 35 | elif opt.caption_model == 'topdown': 36 | model = TopDownModel(opt) 37 | else: 38 | raise Exception("Caption model not supported: {}".format(opt.caption_model)) 39 | 40 | # check compatibility if training is continued from previously saved model 41 | if vars(opt).get('start_from', None) is not None: 42 | # check if all necessary files exist 43 | assert os.path.isdir(opt.start_from)," %s must be a a path" % opt.start_from 44 | assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from 45 | model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) 46 | 47 | return model -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_opt(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--input_json', type=str, default='data/coco.json', 7 | help='path to the json file containing additional info and vocab') 8 | parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc', 9 | help='path to the directory containing the preprocessed fc feats') 10 | parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att', 11 | help='path to the directory containing the preprocessed att feats') 12 | parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', 13 | help='path to the h5file containing the preprocessed dataset') 14 | parser.add_argument('--start_from', type=str, default=None, 15 | help="""continue training from saved model at this path. Path must contain files saved by previous training process: 16 | 'infos.pkl' : configuration; 17 | 'checkpoint' : paths to model file(s) (created by tf). 
18 | Note: this file contains absolute paths, be careful when moving files around; 19 | 'model.ckpt-*' : file(s) with model definition (created by tf) 20 | """) 21 | parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs', 22 | help='Cached token file for calculating cider score during self critical training.') 23 | 24 | # Model settings 25 | parser.add_argument('--caption_model', type=str, default="show_tell", 26 | help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, adaatt, adaattmo, topdown') 27 | parser.add_argument('--rnn_size', type=int, default=512, 28 | help='size of the rnn in number of hidden nodes in each layer') 29 | parser.add_argument('--num_layers', type=int, default=1, 30 | help='number of layers in the RNN') 31 | parser.add_argument('--rnn_type', type=str, default='lstm', 32 | help='rnn, gru, or lstm') 33 | parser.add_argument('--input_encoding_size', type=int, default=512, 34 | help='the encoding size of each token in the vocabulary, and the image.') 35 | parser.add_argument('--att_hid_size', type=int, default=512, 36 | help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer') 37 | parser.add_argument('--fc_feat_size', type=int, default=2048, 38 | help='2048 for resnet, 4096 for vgg') 39 | parser.add_argument('--att_feat_size', type=int, default=2048, 40 | help='2048 for resnet, 512 for vgg') 41 | 42 | # Optimization: General 43 | parser.add_argument('--max_epochs', type=int, default=-1, 44 | help='number of epochs') 45 | parser.add_argument('--batch_size', type=int, default=16, 46 | help='minibatch size') 47 | parser.add_argument('--grad_clip', type=float, default=0.1, #5., 48 | help='clip gradients at this value') 49 | parser.add_argument('--drop_prob_lm', type=float, default=0.5, 50 | help='strength of dropout in the Language Model RNN') 51 | parser.add_argument('--self_critical_after', type=int, default=-1, 52 | help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)') 53 | parser.add_argument('--seq_per_img', type=int, default=5, 54 | help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image') 55 | parser.add_argument('--beam_size', type=int, default=1, 56 | help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') 57 | 58 | #Optimization: for the Language Model 59 | parser.add_argument('--optim', type=str, default='adam', 60 | help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam') 61 | parser.add_argument('--learning_rate', type=float, default=4e-4, 62 | help='learning rate') 63 | parser.add_argument('--learning_rate_decay_start', type=int, default=-1, 64 | help='at what iteration to start decaying learning rate? 
(-1 = dont) (in epoch)') 65 | parser.add_argument('--learning_rate_decay_every', type=int, default=3, 66 | help='every how many iterations thereafter to drop LR?(in epoch)') 67 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8, 68 | help='every how many iterations thereafter to drop LR?(in epoch)') 69 | parser.add_argument('--optim_alpha', type=float, default=0.9, 70 | help='alpha for adam') 71 | parser.add_argument('--optim_beta', type=float, default=0.999, 72 | help='beta used for adam') 73 | parser.add_argument('--optim_epsilon', type=float, default=1e-8, 74 | help='epsilon that goes into denominator for smoothing') 75 | parser.add_argument('--weight_decay', type=float, default=0, 76 | help='weight_decay') 77 | 78 | parser.add_argument('--scheduled_sampling_start', type=int, default=-1, 79 | help='at what iteration to start decay gt probability') 80 | parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5, 81 | help='every how many iterations thereafter to gt probability') 82 | parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05, 83 | help='How much to update the prob') 84 | parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25, 85 | help='Maximum scheduled sampling prob.') 86 | 87 | 88 | # Evaluation/Checkpointing 89 | parser.add_argument('--val_images_use', type=int, default=3200, 90 | help='how many images to use when periodically evaluating the validation loss? (-1 = all)') 91 | parser.add_argument('--save_checkpoint_every', type=int, default=2500, 92 | help='how often to save a model checkpoint (in iterations)?') 93 | parser.add_argument('--checkpoint_path', type=str, default='save', 94 | help='directory to store checkpointed models') 95 | parser.add_argument('--language_eval', type=int, default=0, 96 | help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') 97 | parser.add_argument('--losses_log_every', type=int, default=25, 98 | help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)') 99 | parser.add_argument('--load_best_score', type=int, default=1, 100 | help='Do we load previous best score when resuming training.') 101 | 102 | # misc 103 | parser.add_argument('--id', type=str, default='', 104 | help='an id identifying this run/job. 
used in cross-val and appended when writing progress files') 105 | parser.add_argument('--train_only', type=int, default=0, 106 | help='if true then use 80k, else use 110k') 107 | 108 | args = parser.parse_args() 109 | 110 | # Check if args are valid 111 | assert args.rnn_size > 0, "rnn_size should be greater than 0" 112 | assert args.num_layers > 0, "num_layers should be greater than 0" 113 | assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0" 114 | assert args.batch_size > 0, "batch_size should be greater than 0" 115 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" 116 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0" 117 | assert args.beam_size > 0, "beam_size should be greater than 0" 118 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" 119 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0" 120 | assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1" 121 | assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1" 122 | assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1" 123 | 124 | return args -------------------------------------------------------------------------------- /scripts/copy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cp -r log_$1 log_$2 4 | cd log_$2 5 | mv infos_$1-best.pkl infos_$2-best.pkl 6 | mv infos_$1.pkl infos_$2.pkl 7 | cd ../ -------------------------------------------------------------------------------- /scripts/prepro_feats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 
24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | from six.moves import cPickle 38 | import numpy as np 39 | import torch 40 | import torchvision.models as models 41 | from torch.autograd import Variable 42 | import skimage.io 43 | 44 | from torchvision import transforms as trn 45 | preprocess = trn.Compose([ 46 | #trn.ToTensor(), 47 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 48 | ]) 49 | 50 | from misc.resnet_utils import myResnet 51 | import misc.resnet as resnet 52 | 53 | def main(params): 54 | net = getattr(resnet, params['model'])() 55 | net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) 56 | my_resnet = myResnet(net) 57 | my_resnet.cuda() 58 | my_resnet.eval() 59 | 60 | imgs = json.load(open(params['input_json'], 'r')) 61 | imgs = imgs['images'] 62 | N = len(imgs) 63 | 64 | seed(123) # make reproducible 65 | 66 | dir_fc = params['output_dir']+'_fc' 67 | dir_att = params['output_dir']+'_att' 68 | if not os.path.isdir(dir_fc): 69 | os.mkdir(dir_fc) 70 | if not os.path.isdir(dir_att): 71 | os.mkdir(dir_att) 72 | 73 | for i,img in enumerate(imgs): 74 | # load the image 75 | I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) 76 | # handle grayscale input images 77 | if len(I.shape) == 2: 78 | I = I[:,:,np.newaxis] 79 | I = np.concatenate((I,I,I), axis=2) 80 | 81 | I = I.astype('float32')/255.0 82 | I = torch.from_numpy(I.transpose([2,0,1])).cuda() 83 | I = Variable(preprocess(I), volatile=True) 84 | tmp_fc, tmp_att = my_resnet(I, params['att_size']) 85 | # write to pkl 86 | np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) 87 | np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) 88 | 89 | if i % 1000 == 0: 90 | print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) 91 | print('wrote ', params['output_dir']) 92 | 93 | if __name__ == "__main__": 94 | 95 | parser = argparse.ArgumentParser() 96 | 97 | # input json 98 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 99 | parser.add_argument('--output_dir', default='data', help='output h5 file') 100 | 101 | # options 102 | parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') 103 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 104 | parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') 105 | parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') 106 | 107 | args = parser.parse_args() 108 | params = vars(args) # convert to ordinary dict 109 | print('parsed input parameters:') 110 | print(json.dumps(params, indent = 2)) 111 | main(params) 112 | -------------------------------------------------------------------------------- /scripts/prepro_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 
6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import json 32 | import argparse 33 | from random import shuffle, seed 34 | import string 35 | # non-standard dependencies: 36 | import h5py 37 | import numpy as np 38 | import torch 39 | import torchvision.models as models 40 | from torch.autograd import Variable 41 | import skimage.io 42 | 43 | def build_vocab(imgs, params): 44 | count_thr = params['word_count_threshold'] 45 | 46 | # count up the number of words 47 | counts = {} 48 | for img in imgs: 49 | for sent in img['sentences']: 50 | for w in sent['tokens']: 51 | counts[w] = counts.get(w, 0) + 1 52 | cw = sorted([(count,w) for w,count in counts.items()], reverse=True) 53 | print('top words and their counts:') 54 | print('\n'.join(map(str,cw[:20]))) 55 | 56 | # print some stats 57 | total_words = sum(counts.values()) 58 | print('total words:', total_words) 59 | bad_words = [w for w,n in counts.items() if n <= count_thr] 60 | vocab = [w for w,n in counts.items() if n > count_thr] 61 | bad_count = sum(counts[w] for w in bad_words) 62 | print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) 63 | print('number of words in vocab would be %d' % (len(vocab), )) 64 | print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) 65 | 66 | # lets look at the distribution of lengths as well 67 | sent_lengths = {} 68 | for img in imgs: 69 | for sent in img['sentences']: 70 | txt = sent['tokens'] 71 | nw = len(txt) 72 | sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 73 | max_len = max(sent_lengths.keys()) 74 | print('max length sentence in raw data: ', max_len) 75 | print('sentence length distribution (count, number of words):') 76 | sum_len = sum(sent_lengths.values()) 77 | for i in range(max_len+1): 78 | print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) 79 | 80 | # lets now produce the final annotations 81 | if 
bad_count > 0: 82 | # additional special UNK token we will use below to map infrequent words to 83 | print('inserting the special UNK token') 84 | vocab.append('UNK') 85 | 86 | for img in imgs: 87 | img['final_captions'] = [] 88 | for sent in img['sentences']: 89 | txt = sent['tokens'] 90 | caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] 91 | img['final_captions'].append(caption) 92 | 93 | return vocab 94 | 95 | def encode_captions(imgs, params, wtoi): 96 | """ 97 | encode all captions into one large array, which will be 1-indexed. 98 | also produces label_start_ix and label_end_ix which store 1-indexed 99 | and inclusive (Lua-style) pointers to the first and last caption for 100 | each image in the dataset. 101 | """ 102 | 103 | max_length = params['max_length'] 104 | N = len(imgs) 105 | M = sum(len(img['final_captions']) for img in imgs) # total number of captions 106 | 107 | label_arrays = [] 108 | label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed 109 | label_end_ix = np.zeros(N, dtype='uint32') 110 | label_length = np.zeros(M, dtype='uint32') 111 | caption_counter = 0 112 | counter = 1 113 | for i,img in enumerate(imgs): 114 | n = len(img['final_captions']) 115 | assert n > 0, 'error: some image has no captions' 116 | 117 | Li = np.zeros((n, max_length), dtype='uint32') 118 | for j,s in enumerate(img['final_captions']): 119 | label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence 120 | caption_counter += 1 121 | for k,w in enumerate(s): 122 | if k < max_length: 123 | Li[j,k] = wtoi[w] 124 | 125 | # note: word indices are 1-indexed, and captions are padded with zeros 126 | label_arrays.append(Li) 127 | label_start_ix[i] = counter 128 | label_end_ix[i] = counter + n - 1 129 | 130 | counter += n 131 | 132 | L = np.concatenate(label_arrays, axis=0) # put all the labels together 133 | assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' 134 | assert np.all(label_length > 0), 'error: some caption had no words?' 
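# For example (hypothetical numbers, for illustration only): an image whose n = 5
# captions occupy rows 11..15 of L gets label_start_ix = 11 and label_end_ix = 15;
# because the pointers are 1-indexed and inclusive, a loader recovers its captions
# with the Python slice L[label_start_ix - 1 : label_end_ix].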
135 | 136 | print('encoded captions to array of size ', L.shape) 137 | return L, label_start_ix, label_end_ix, label_length 138 | 139 | def main(params): 140 | 141 | imgs = json.load(open(params['input_json'], 'r')) 142 | imgs = imgs['images'] 143 | 144 | seed(123) # make reproducible 145 | 146 | # create the vocab 147 | vocab = build_vocab(imgs, params) 148 | itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table 149 | wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table 150 | 151 | # encode captions in large arrays, ready to ship to hdf5 file 152 | L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) 153 | 154 | # create output h5 file 155 | N = len(imgs) 156 | f_lb = h5py.File(params['output_h5']+'_label.h5', "w") 157 | f_lb.create_dataset("labels", dtype='uint32', data=L) 158 | f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) 159 | f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) 160 | f_lb.create_dataset("label_length", dtype='uint32', data=label_length) 161 | f_lb.close() 162 | 163 | # create output json file 164 | out = {} 165 | out['ix_to_word'] = itow # encode the (1-indexed) vocab 166 | out['images'] = [] 167 | for i,img in enumerate(imgs): 168 | 169 | jimg = {} 170 | jimg['split'] = img['split'] 171 | if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need 172 | if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. coco ids, useful) 173 | 174 | out['images'].append(jimg) 175 | 176 | json.dump(out, open(params['output_json'], 'w')) 177 | print('wrote ', params['output_json']) 178 | 179 | if __name__ == "__main__": 180 | 181 | parser = argparse.ArgumentParser() 182 | 183 | # input json 184 | parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') 185 | parser.add_argument('--output_json', default='data.json', help='output json file') 186 | parser.add_argument('--output_h5', default='data', help='output h5 file') 187 | 188 | # options 189 | parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') 190 | parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') 191 | 192 | args = parser.parse_args() 193 | params = vars(args) # convert to ordinary dict 194 | print('parsed input parameters:') 195 | print(json.dumps(params, indent = 2)) 196 | main(params) 197 | -------------------------------------------------------------------------------- /scripts/prepro_ngrams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua 3 | 4 | Input: json file that has the form 5 | [{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] 6 | example element in this list would look like 7 | {'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} 8 | 9 | This script reads this json, does some basic preprocessing on the captions 10 | (e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays 11 | 12 | Output: a json file and an hdf5 file 13 | The hdf5 file contains several fields: 14 | /images is (N,3,256,256) uint8 array of raw image data in RGB format 15 | /labels is (M,max_length) uint32 array of encoded labels, zero padded 16 | /label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the 17 | first and last indices (in range 1..M) of labels for each image 18 | /label_length stores the length of the sequence for each of the M sequences 19 | 20 | The json file has a dict that contains: 21 | - an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed 22 | - an 'images' field that is a list holding auxiliary information for each image, 23 | such as in particular the 'split' it was assigned to. 24 | """ 25 | 26 | import os 27 | import json 28 | import argparse 29 | from six.moves import cPickle 30 | from collections import defaultdict 31 | 32 | def precook(s, n=4, out=False): 33 | """ 34 | Takes a string as input and returns an object that can be given to 35 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 36 | can take string arguments as well. 37 | :param s: string : sentence to be converted into ngrams 38 | :param n: int : number of ngrams for which representation is calculated 39 | :return: term frequency vector for occuring ngrams 40 | """ 41 | words = s.split() 42 | counts = defaultdict(int) 43 | for k in xrange(1,n+1): 44 | for i in xrange(len(words)-k+1): 45 | ngram = tuple(words[i:i+k]) 46 | counts[ngram] += 1 47 | return counts 48 | 49 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 50 | '''Takes a list of reference sentences for a single segment 51 | and returns an object that encapsulates everything that BLEU 52 | needs to know about them. 53 | :param refs: list of string : reference sentences for some image 54 | :param n: int : number of ngrams for which (ngram) representation is calculated 55 | :return: result (list of dict) 56 | ''' 57 | return [precook(ref, n) for ref in refs] 58 | 59 | def create_crefs(refs): 60 | crefs = [] 61 | for ref in refs: 62 | # ref is a list of 5 captions 63 | crefs.append(cook_refs(ref)) 64 | return crefs 65 | 66 | def compute_doc_freq(crefs): 67 | ''' 68 | Compute term frequency for reference data. 
69 | This will be used to compute idf (inverse document frequency later) 70 | The term frequency is stored in the object 71 | :return: None 72 | ''' 73 | document_frequency = defaultdict(float) 74 | for refs in crefs: 75 | # refs, k ref captions of one image 76 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 77 | document_frequency[ngram] += 1 78 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 79 | return document_frequency 80 | 81 | def build_dict(imgs, wtoi, params): 82 | wtoi[''] = 0 83 | 84 | count_imgs = 0 85 | 86 | refs_words = [] 87 | refs_idxs = [] 88 | for img in imgs: 89 | if (params['split'] == img['split']) or \ 90 | (params['split'] == 'train' and img['split'] == 'restval') or \ 91 | (params['split'] == 'all'): 92 | #(params['split'] == 'val' and img['split'] == 'restval') or \ 93 | ref_words = [] 94 | ref_idxs = [] 95 | for sent in img['sentences']: 96 | tmp_tokens = sent['tokens'] + [''] 97 | tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] 98 | ref_words.append(' '.join(tmp_tokens)) 99 | ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) 100 | refs_words.append(ref_words) 101 | refs_idxs.append(ref_idxs) 102 | count_imgs += 1 103 | print('total imgs:', count_imgs) 104 | 105 | ngram_words = compute_doc_freq(create_crefs(refs_words)) 106 | ngram_idxs = compute_doc_freq(create_crefs(refs_idxs)) 107 | return ngram_words, ngram_idxs, count_imgs 108 | 109 | def main(params): 110 | 111 | imgs = json.load(open(params['input_json'], 'r')) 112 | itow = json.load(open(params['dict_json'], 'r'))['ix_to_word'] 113 | wtoi = {w:i for i,w in itow.items()} 114 | 115 | imgs = imgs['images'] 116 | 117 | ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) 118 | 119 | cPickle.dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 120 | cPickle.dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','w'), protocol=cPickle.HIGHEST_PROTOCOL) 121 | 122 | if __name__ == "__main__": 123 | 124 | parser = argparse.ArgumentParser() 125 | 126 | # input json 127 | parser.add_argument('--input_json', default='/home-nfs/rluo/rluo/nips/code/prepro/dataset_coco.json', help='input json file to process into hdf5') 128 | parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') 129 | parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') 130 | parser.add_argument('--split', default='all', help='test, val, train, all') 131 | args = parser.parse_args() 132 | params = vars(args) # convert to ordinary dict 133 | 134 | main(params) 135 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.optim as optim 9 | 10 | import numpy as np 11 | 12 | import time 13 | import os 14 | from six.moves import cPickle 15 | 16 | import opts 17 | import models 18 | from dataloader import * 19 | import eval_utils 20 | import misc.utils as utils 21 | from misc.rewards import init_cider_scorer, get_self_critical_reward 22 | 23 | try: 24 | import tensorflow as tf 25 | except ImportError: 26 | print("Tensorflow not installed; No 
tensorboard logging.") 27 | tf = None 28 | 29 | def add_summary_value(writer, key, value, iteration): 30 | summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)]) 31 | writer.add_summary(summary, iteration) 32 | 33 | def train(opt): 34 | opt.use_att = utils.if_use_att(opt.caption_model) 35 | loader = DataLoader(opt) 36 | opt.vocab_size = loader.vocab_size 37 | opt.seq_length = loader.seq_length 38 | 39 | tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path) 40 | 41 | infos = {} 42 | histories = {} 43 | if opt.start_from is not None: 44 | # open old infos and check if models are compatible 45 | with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f: 46 | infos = cPickle.load(f) 47 | saved_model_opt = infos['opt'] 48 | need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"] 49 | for checkme in need_be_same: 50 | assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme 51 | 52 | if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')): 53 | with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f: 54 | histories = cPickle.load(f) 55 | 56 | iteration = infos.get('iter', 0) 57 | epoch = infos.get('epoch', 0) 58 | 59 | val_result_history = histories.get('val_result_history', {}) 60 | loss_history = histories.get('loss_history', {}) 61 | lr_history = histories.get('lr_history', {}) 62 | ss_prob_history = histories.get('ss_prob_history', {}) 63 | 64 | loader.iterators = infos.get('iterators', loader.iterators) 65 | loader.split_ix = infos.get('split_ix', loader.split_ix) 66 | if opt.load_best_score == 1: 67 | best_val_score = infos.get('best_val_score', None) 68 | 69 | model = models.setup(opt) 70 | model.cuda() 71 | 72 | update_lr_flag = True 73 | # Assure in training mode 74 | model.train() 75 | 76 | crit = utils.LanguageModelCriterion() 77 | rl_crit = utils.RewardCriterion() 78 | 79 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 80 | 81 | # Load the optimizer 82 | if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): 83 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) 84 | 85 | while True: 86 | if update_lr_flag: 87 | # Assign the learning rate 88 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 89 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every 90 | decay_factor = opt.learning_rate_decay_rate ** frac 91 | opt.current_lr = opt.learning_rate * decay_factor 92 | else: 93 | opt.current_lr = opt.learning_rate 94 | utils.set_lr(optimizer, opt.current_lr) 95 | # Assign the scheduled sampling prob 96 | if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: 97 | frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every 98 | opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob) 99 | model.ss_prob = opt.ss_prob 100 | 101 | # If start self critical training 102 | if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: 103 | sc_flag = True 104 | init_cider_scorer(opt.cached_tokens) 105 | else: 106 | sc_flag = False 107 | 108 | update_lr_flag = False 109 | 110 | start = time.time() 111 | # Load data from train split (0) 112 | data = loader.get_batch('train') 113 | print('Read data:', time.time() - start) 114 | 115 | 
torch.cuda.synchronize() 116 | start = time.time() 117 | 118 | tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] 119 | tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp] 120 | fc_feats, att_feats, labels, masks = tmp 121 | 122 | optimizer.zero_grad() 123 | if not sc_flag: 124 | loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]) 125 | else: 126 | gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0}) 127 | reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result) 128 | loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) 129 | 130 | loss.backward() 131 | utils.clip_gradient(optimizer, opt.grad_clip) 132 | optimizer.step() 133 | train_loss = loss.data[0] 134 | torch.cuda.synchronize() 135 | end = time.time() 136 | if not sc_flag: 137 | print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \ 138 | .format(iteration, epoch, train_loss, end - start)) 139 | else: 140 | print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \ 141 | .format(iteration, epoch, np.mean(reward[:,0]), end - start)) 142 | 143 | # Update the iteration and epoch 144 | iteration += 1 145 | if data['bounds']['wrapped']: 146 | epoch += 1 147 | update_lr_flag = True 148 | 149 | # Write the training loss summary 150 | if (iteration % opt.losses_log_every == 0): 151 | if tf is not None: 152 | add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration) 153 | add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration) 154 | add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) 155 | if sc_flag: 156 | add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration) 157 | tf_summary_writer.flush() 158 | 159 | loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0]) 160 | lr_history[iteration] = opt.current_lr 161 | ss_prob_history[iteration] = model.ss_prob 162 | 163 | # make evaluation on validation set, and save model 164 | if (iteration % opt.save_checkpoint_every == 0): 165 | # eval model 166 | eval_kwargs = {'split': 'val', 167 | 'dataset': opt.input_json} 168 | eval_kwargs.update(vars(opt)) 169 | val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs) 170 | 171 | # Write validation result into summary 172 | if tf is not None: 173 | add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration) 174 | if lang_stats is not None: 175 | for k,v in lang_stats.items(): 176 | add_summary_value(tf_summary_writer, k, v, iteration) 177 | tf_summary_writer.flush() 178 | val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions} 179 | 180 | # Save model if is improving on validation result 181 | if opt.language_eval == 1: 182 | current_score = lang_stats['CIDEr'] 183 | else: 184 | current_score = - val_loss 185 | 186 | best_flag = False 187 | if True: # if true 188 | if best_val_score is None or current_score > best_val_score: 189 | best_val_score = current_score 190 | best_flag = True 191 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth') 192 | torch.save(model.state_dict(), checkpoint_path) 193 | print("model saved to {}".format(checkpoint_path)) 194 | optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth') 195 | torch.save(optimizer.state_dict(), optimizer_path) 196 | 197 | # Dump 
miscalleous informations 198 | infos['iter'] = iteration 199 | infos['epoch'] = epoch 200 | infos['iterators'] = loader.iterators 201 | infos['split_ix'] = loader.split_ix 202 | infos['best_val_score'] = best_val_score 203 | infos['opt'] = opt 204 | infos['vocab'] = loader.get_vocab() 205 | 206 | histories['val_result_history'] = val_result_history 207 | histories['loss_history'] = loss_history 208 | histories['lr_history'] = lr_history 209 | histories['ss_prob_history'] = ss_prob_history 210 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f: 211 | cPickle.dump(infos, f) 212 | with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f: 213 | cPickle.dump(histories, f) 214 | 215 | if best_flag: 216 | checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth') 217 | torch.save(model.state_dict(), checkpoint_path) 218 | print("model saved to {}".format(checkpoint_path)) 219 | with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f: 220 | cPickle.dump(infos, f) 221 | 222 | # Stop if reaching max epochs 223 | if epoch >= opt.max_epochs and opt.max_epochs != -1: 224 | break 225 | 226 | opt = opts.parse_opt() 227 | train(opt) 228 | -------------------------------------------------------------------------------- /vis/imgs/dummy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALISCIFP/Medical-Image-Caption/e94a3a88834219ff7bdb05914e25c868e33fa4e6/vis/imgs/dummy -------------------------------------------------------------------------------- /vis/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | neuraltalk2 results visualization 7 | 8 | 42 | 43 | 44 |
--------------------------------------------------------------------------------
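
For reference, the snippet below is a minimal decoding sketch (an illustration, not a file from the repository) showing how `ShowTellModel.sample` is controlled by the `sample_max` and `temperature` options. It assumes the layout above, the PyTorch 0.2-era `Variable` API and a CUDA device, and it substitutes random tensors plus hard-coded `vocab_size`/`seq_length` values (normally supplied by the DataLoader) for real preprocessed features.

```python
import torch
from torch.autograd import Variable

import opts
from models.ShowTellModel import ShowTellModel

opt = opts.parse_opt()
opt.vocab_size = 1000    # normally supplied by the DataLoader (loader.vocab_size)
opt.seq_length = 16      # normally supplied by the DataLoader (loader.seq_length)

model = ShowTellModel(opt).cuda()
model.eval()

# Dummy features stand in for the preprocessed cocotalk_fc / cocotalk_att arrays.
# ShowTellModel only consumes the fc feature; att_feats is kept for interface parity.
fc_feats = Variable(torch.randn(2, opt.fc_feat_size).cuda(), volatile=True)
att_feats = Variable(torch.randn(2, 14 * 14, opt.att_feat_size).cuda(), volatile=True)

# Greedy decoding: sample_max=1 takes the argmax word at every step.
seq, seq_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 1})

# Stochastic decoding: sample_max=0 draws from the word distribution, optionally
# sharpened by a temperature below 1.0; train.py uses this mode ({'sample_max': 0})
# once self-critical training starts.
seq_sampled, _ = model.sample(fc_feats, att_feats, {'sample_max': 0, 'temperature': 0.8})

print(seq.size())        # (batch_size, decoded length), zero-padded word ids
```

The `beam_size` option works through the same dictionary: any value greater than 1 makes `sample` dispatch to `sample_beam` instead of step-by-step decoding.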