├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── construct_vocab.py ├── datasets.py ├── get_best.py ├── interactive.py ├── nets.py ├── prepare_scripts ├── download.sh ├── prepare_dataset.sh ├── unzip.sh └── vocab.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | data 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "coco-caption"] 2 | path = coco-caption 3 | url = https://github.com/tylin/coco-caption 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sosuke Kobayashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Captioning by Chainer
2 | 
3 | A Chainer implementation of [Neural Image Caption](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Vinyals_Show_and_Tell_2015_CVPR_paper.pdf), which generates captions for given images.
4 | 
5 | This implementation is fast because it uses the cuDNN-based LSTM (NStepLSTM) and its beam search processes a whole batch of images at once.
6 | 
7 | This code uses [coco-caption](https://github.com/tylin/coco-caption) as a submodule.
8 | Please clone this repository as follows:
9 | ```
10 | git clone --recursive https://github.com/soskek/captioning_chainer.git
11 | ```
12 | 
13 | Note that [coco-caption](https://github.com/tylin/coco-caption) works on Python 2.7 only, so this repository targets Python 2.7 as well.
14 | 
15 | 
16 | ## Train an Image Caption Generator
17 | 
18 | ```
19 | sh prepare_scripts/prepare_dataset.sh
20 | ```
21 | 
22 | ```
23 | # flickr8k, flickr30k, mscoco
24 | python -u train.py -g 0 --vocab data/flickr8k/vocab.txt --dataset flickr8k -b 64
25 | python -u train.py -g 0 --vocab data/flickr30k/vocab.txt --dataset flickr30k -b 64
26 | python -u train.py -g 0 --vocab data/coco/vocab.txt --dataset mscoco -b 64
27 | ```
28 | 
29 | On the MSCOCO dataset, with a beam size of 20, a trained model reached BLEU 25.9.
30 | The paper uses an ensemble and hyperparameters that are not fully specified, which may explain the gap between this result and the value reported in the paper.
31 | 
32 | ## Use the model
33 | 
34 | ```
35 | python interactive.py --resume result/best_model.npz --vocab data/flickr8k/vocab.txt
36 | ```
37 | 
38 | After it launches, enter the path to an image file.
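For example, a session looks like this (the image path, the generated caption, and the score below are illustrative placeholders, not actual outputs):

```
image>> path/to/an/image.jpg
a dog is running on the grass
-5.43
```

The first printed line is the caption found by beam search and the second is its accumulated log-probability score.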
39 | 
40 | 
41 | ## See Best Result and Plot Curve
42 | 
43 | ```
44 | python get_best.py --log result/log
45 | ```
46 | 
47 | 
48 | ## Citation
49 | 
50 | ```
51 | @article{Vinyals2015ShowAT,
52 |   title={Show and tell: A neural image caption generator},
53 |   author={Oriol Vinyals and Alexander Toshev and Samy Bengio and Dumitru Erhan},
54 |   journal={2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
55 |   year={2015},
56 |   pages={3156-3164}
57 | }
58 | ```
59 | 
--------------------------------------------------------------------------------
/construct_vocab.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import sys
4 | 
5 | import datasets
6 | 
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--dataset', '-d', default='mscoco')
9 | parser.add_argument('--threshold', '-t', type=int, default=5)
10 | parser.add_argument('--out', '-o', default='vocab.txt')
11 | args = parser.parse_args()
12 | 
13 | directory = datasets.get_default_dataset_path(args.dataset)
14 | vocab, count = datasets.construct_vocab(
15 |     directory, max_vocab_size=1e8, min_freq=args.threshold, with_count=True)
16 | 
17 | json.dump(vocab, open(args.out, 'w'))
18 | json.dump(count, open(args.out + '.count', 'w'))
19 | 
20 | sys.stderr.write('# of words: {}\n'.format(len(vocab)))
21 | 
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import collections
3 | import json
4 | import os
5 | 
6 | import numpy as np
7 | import scipy.io
8 | 
9 | import utils
10 | 
11 | 
12 | def get_default_dataset_path(dataset):
13 |     if 'coco' in dataset:
14 |         return 'data/coco'
15 |     elif 'flickr8k' in dataset:
16 |         return 'data/flickr8k'
17 |     elif 'flickr30k' in dataset:
18 |         return 'data/flickr30k'
19 |     else:
20 |         raise NotImplementedError()
21 | 
22 | 
23 | def construct_vocab(directory, max_vocab_size=50000, min_freq=5,
24 |                     with_count=False):
25 |     counts = collections.defaultdict(int)
26 |     caption_path = os.path.join(directory, 'dataset.json')
27 |     caption_dataset = json.load(open(caption_path))['images']
28 |     for cap_set in caption_dataset:
29 |         this_split = cap_set['split']
30 |         sentences = cap_set['sentences']
31 |         if this_split == 'train':
32 |             for sent in sentences:
33 |                 tokens = sent['tokens']
34 |                 for token in tokens:
35 |                     counts[token] += 1
36 | 
37 |     vocab = {'': 0, '': 1}
38 |     for w, c in sorted(counts.items(), key=lambda x: (-x[1], x[0])):
39 |         if len(vocab) >= max_vocab_size or c < min_freq:
40 |             break
41 |         vocab[w] = len(vocab)
42 |     if with_count:
43 |         return vocab, dict(counts)
44 |     else:
45 |         return vocab
46 | 
47 | 
48 | def load_caption_dataset(vocab, directory, split):
49 |     dataset = []
50 |     vec_path = os.path.join(directory, 'vgg_feats.mat')
51 |     vecs = scipy.io.loadmat(vec_path)['feats'].T
52 |     vecs = [v[0] for v in np.split(vecs, vecs.shape[0], axis=0)]
53 |     caption_path = os.path.join(directory, 'dataset.json')
54 |     caption_dataset = json.load(open(caption_path))['images']
55 |     for cap_set in caption_dataset:
56 |         this_split = cap_set['split']
57 |         img_id = cap_set['imgid']
58 |         sentences = cap_set['sentences']
59 |         other = {'img_id': img_id}
60 |         if this_split == split:
61 |             img_vec = vecs[img_id]
62 |             for sent in sentences:
63 |                 tokens = sent['tokens']
64 |                 sent_array = utils.make_array(
65 |                     tokens, vocab, add_eos=True, add_bos=True)
66 |                 dataset.append((img_vec,
sent_array, other)) 67 | return dataset 68 | 69 | 70 | def get_caption_dataset(vocab, directory, split): 71 | if isinstance(split, (list, tuple)): 72 | datas = [load_caption_dataset(vocab, directory, sp) 73 | for sp in split] 74 | return datas 75 | else: 76 | return load_caption_dataset(vocab, directory, split) 77 | -------------------------------------------------------------------------------- /get_best.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import json 4 | 5 | import numpy as np 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plot 9 | 10 | 11 | def plot_result(key, xy, out_path): 12 | f = plot.figure() 13 | a = f.add_subplot(111) 14 | a.set_xlabel(key) 15 | a.grid() 16 | 17 | xy = np.array(xy) 18 | a.plot(xy[:, 0], xy[:, 1], marker='.', label=key) 19 | 20 | l = a.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 21 | f.savefig(out_path, bbox_extra_artists=(l,), bbox_inches='tight') 22 | 23 | plot.close() 24 | 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser(description='Image Comprehension') 28 | parser.add_argument('--key', '-n', default='val/bleu') 29 | parser.add_argument('--log', '-l', required=True) 30 | parser.add_argument('--min', action='store_true') 31 | parser.add_argument('--plot-path', '-p', default='plot.png') 32 | args = parser.parse_args() 33 | print(json.dumps(args.__dict__, indent=2)) 34 | 35 | log = json.load(open(args.log)) 36 | best = None 37 | best_iter = -1 38 | xy = [] 39 | for out in log: 40 | i = out['iteration'] 41 | if args.key in out: 42 | v = out[args.key] 43 | xy.append((i, v)) 44 | else: 45 | continue 46 | update = False 47 | if best is None: 48 | update = True 49 | if args.min: 50 | if v < best: 51 | update = True 52 | else: 53 | if v > best: 54 | update = True 55 | if update: 56 | best = v 57 | best_iter = i 58 | print('iter', best_iter) 59 | print(args.key, best) 60 | 61 | plot_result(args.key, xy, args.plot_path) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import json 4 | 5 | import chainer 6 | import numpy as np 7 | from PIL import Image 8 | 9 | import nets 10 | 11 | try: 12 | # p2 13 | input_method = raw_input 14 | except Exception: 15 | # p3 16 | input_method = input 17 | 18 | print('import finished') 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='Image Comprehension') 23 | parser.add_argument('--gpu', '-g', type=int, default=-1, 24 | help='GPU ID (negative value indicates CPU)') 25 | parser.add_argument('--vocab', '-v', default='vocab.json') 26 | parser.add_argument('--resume', '-r', required=True, 27 | help='Resume the training from snapshot') 28 | args = parser.parse_args() 29 | print(json.dumps(args.__dict__, indent=2)) 30 | 31 | print('read vocab') 32 | vocab = json.load(open(args.vocab)) 33 | rev_vocab = {i: w for w, i in vocab.items()} 34 | 35 | print('setup model') 36 | tmp = np.load(args.resume) 37 | n_vocab, n_units = tmp['embed/W'].shape 38 | n_layer = max(int(key.split('/')[1]) 39 | for key in tmp.keys() if 'rnn' in key) + 1 40 | model = nets.RNNDecoder( 41 | n_layer, n_vocab, n_units, 42 | dropout=0., 43 | eos_id=vocab['']) 44 | cnn = chainer.links.VGG16Layers() 45 | if args.gpu >= 0: 46 | 
chainer.cuda.get_device_from_id(args.gpu).use() 47 | model.to_gpu() 48 | cnn.to_gpu() 49 | 50 | if args.resume: 51 | print('load model', args.resume) 52 | chainer.serializers.load_npz(args.resume, model) 53 | 54 | while 1: 55 | inp = input_method('image>>').strip() 56 | path = inp.strip().lower() 57 | if not path: 58 | continue 59 | 60 | vecs = cnn.extract([Image.open(path)]).popitem()[1].data 61 | result = model.decode(vecs)[0] 62 | outs = result['outs'] 63 | score = result['score'] 64 | print(' '.join(rev_vocab[wi] for wi in outs)) 65 | print(score) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy 4 | 5 | import chainer 6 | from chainer import cuda 7 | import chainer.functions as F 8 | import chainer.links as L 9 | from chainer import reporter 10 | 11 | 12 | def sequence_embed(embed, xs, dropout=0.): 13 | """Efficient embedding function for variable-length sequences 14 | 15 | This output is equally to 16 | "return [F.dropout(embed(x), ratio=dropout) for x in xs]". 17 | However, calling the functions is one-shot and faster. 18 | 19 | Args: 20 | embed (callable): A :func:`~chainer.functions.embed_id` function 21 | or :class:`~chainer.links.EmbedID` link. 22 | xs (list of :class:`~chainer.Variable` or :class:`numpy.ndarray` or \ 23 | :class:`cupy.ndarray`): i-th element in the list is an input variable, 24 | which is a :math:`(L_i, )`-shaped int array. 25 | dropout (float): Dropout ratio. 26 | 27 | Returns: 28 | list of ~chainer.Variable: Output variables. i-th element in the 29 | list is an output variable, which is a :math:`(L_i, N)`-shaped 30 | float array. :math:`(N)` is the number of dimensions of word embedding. 31 | 32 | """ 33 | x_len = [len(x) for x in xs] 34 | x_section = numpy.cumsum(x_len[:-1]) 35 | ex = embed(F.concat(xs, axis=0)) 36 | ex = F.dropout(ex, ratio=dropout) 37 | exs = F.split_axis(ex, x_section, 0) 38 | return exs 39 | 40 | 41 | def get_topk(output, k=20): 42 | batchsize, n_out = output.shape 43 | xp = cuda.get_array_module(output) 44 | argsort = xp.argsort(output, axis=1) 45 | argtopk = argsort[:, ::-1][:, :k] 46 | assert(argtopk.shape == (batchsize, k)), (argtopk.shape, (batchsize, k)) 47 | topk_score = output.take( 48 | argtopk + xp.arange(batchsize)[:, None] * n_out) 49 | return argtopk, topk_score 50 | 51 | 52 | def update_beam_state(outs, total_score, topk, topk_score, h, c, eos_id): 53 | xp = cuda.get_array_module(h) 54 | full = outs.shape[0] 55 | prev_full, k = topk.shape 56 | batch = full // k 57 | prev_k = prev_full // batch 58 | assert(prev_k in [1, k]) 59 | 60 | if total_score is None: 61 | total_score = topk_score 62 | else: 63 | is_end = xp.max(outs == eos_id, axis=1) 64 | is_end = xp.broadcast_to(is_end[:, None], topk_score.shape) 65 | bias = xp.zeros_like(topk_score, numpy.float32) 66 | bias[:, 1:] = -10000. 
# remove ended cands except for a consequence 67 | total_score = xp.where( 68 | is_end, 69 | total_score[:, None] + bias, 70 | total_score[:, None] + topk_score) 71 | assert(xp.all(total_score < 0.)) 72 | topk = xp.where(is_end, eos_id, topk) # this is not required 73 | total_score = total_score.reshape((prev_full // prev_k, prev_k * k)) 74 | argtopk, total_topk_score = get_topk(total_score, k=k) 75 | assert(argtopk.shape == (prev_full // prev_k, k)) 76 | assert(total_topk_score.shape == (prev_full // prev_k, k)) 77 | total_topk = topk.take( 78 | argtopk + xp.arange(prev_full // prev_k)[:, None] * prev_k * k) 79 | total_topk = total_topk.reshape((full, )) 80 | total_topk_score = total_topk_score.reshape((full, )) 81 | 82 | argtopk = argtopk // k + \ 83 | xp.arange(prev_full // prev_k)[:, None] * prev_k 84 | argtopk = argtopk.reshape((full, )).tolist() 85 | 86 | hs = F.separate(h, axis=1) 87 | cs = F.separate(c, axis=1) 88 | next_h = F.stack([hs[i] for i in argtopk], axis=1) 89 | next_c = F.stack([cs[i] for i in argtopk], axis=1) 90 | 91 | outs = xp.stack([outs[i] for i in argtopk], axis=0) 92 | outs = xp.concatenate([outs, total_topk[:, None]], 93 | axis=1).astype(numpy.int32) 94 | 95 | return outs, total_topk_score, next_h, next_c 96 | 97 | 98 | def finish_beam(outs, total_score, batchsize, eos_id): 99 | k = outs.shape[0] // batchsize 100 | result_batch = collections.defaultdict( 101 | lambda: {'outs': [], 'score': -1e8}) 102 | for i in range(batchsize): 103 | for j in range(k): 104 | score = total_score[i * k + j] 105 | if result_batch[i]['score'] < score: 106 | out = outs[i * k + j].tolist() 107 | if eos_id in out: 108 | out = out[:out.index(eos_id)] 109 | result_batch[i] = {'outs': out, 'score': score} 110 | 111 | result_batch = [ 112 | result for i, result in 113 | sorted(result_batch.items(), key=lambda x: x[0])] 114 | return result_batch 115 | 116 | 117 | class RNNDecoder(chainer.Chain): 118 | 119 | """A LSTM-RNN Decoder with Word Embedding. 120 | 121 | This model decodes a sentence sequentially using LSTM. 122 | 123 | Args: 124 | n_layers (int): The number of LSTM layers. 125 | n_vocab (int): The size of vocabulary. 126 | n_units (int): The number of units of a LSTM layer and word embedding. 127 | dropout (float): The dropout ratio. 
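        eos_id (int): The id of the EOS token. It is fed as the first
            input at decoding time and also terminates decoding.
        max_decode_length (int): The maximum number of decoding steps.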
128 | 129 | """ 130 | 131 | def __init__(self, n_layers, n_vocab, n_units, 132 | dropout=0.5, eos_id=0, max_decode_length=40): 133 | super(RNNDecoder, self).__init__( 134 | transform=L.Linear(None, n_units), 135 | embed=L.EmbedID(n_vocab, n_units), 136 | rnn=L.NStepLSTM(n_layers, n_units, n_units, dropout), 137 | output=L.Linear(n_units, n_vocab), 138 | ) 139 | self.n_layers = n_layers 140 | self.n_units = n_units 141 | self.dropout = dropout 142 | self.eos_id = eos_id 143 | self.max_decode_length = max_decode_length 144 | 145 | def __call__(self, xs, ys, others): 146 | return self.calculate_loss(xs, ys, others) 147 | 148 | def calculate_loss(self, xs, ys, others): 149 | h, c = self.prepare(xs) 150 | input_ys = [y[:-1] for y in ys] 151 | target_ys = [y[1:] for y in ys] 152 | es = sequence_embed(self.embed, input_ys, self.dropout) 153 | h, c, hs = self.rnn(h, c, es) 154 | concat_h = F.dropout(F.concat(hs, axis=0), self.dropout) 155 | concat_output = self.output(concat_h) 156 | concat_target = F.concat(target_ys, axis=0) 157 | loss = F.softmax_cross_entropy(concat_output, concat_target) 158 | accuracy = F.accuracy(concat_output, concat_target) 159 | reporter.report({'loss': loss.data}, self) 160 | reporter.report({'perp': self.xp.exp(loss.data)}, self) 161 | reporter.report({'acc': accuracy.data}, self) 162 | return loss 163 | 164 | def evaluate(self, *args, **kargs): 165 | return self.calculate_loss(*args, **kargs) 166 | 167 | def decode(self, xs, k=20): 168 | batchsize = len(xs) 169 | h, c = self.prepare(xs) 170 | input_ys = [self.xp.array([self.eos_id], 'i')] * batchsize 171 | outs = self.xp.array([[]] * batchsize * k, 'i') 172 | total_score = None 173 | for i in range(self.max_decode_length): 174 | es = sequence_embed(self.embed, input_ys, 0) 175 | h, c, hs = self.rnn(h, c, es) 176 | 177 | concat_h = F.concat(hs, axis=0) 178 | concat_output = self.output(concat_h) 179 | topk, topk_score = get_topk( 180 | F.log_softmax(concat_output).data, k=k) 181 | assert(self.xp.all(topk_score <= 0.)) 182 | 183 | outs, total_score, h, c = update_beam_state( 184 | outs, total_score, topk, topk_score, h, c, self.eos_id) 185 | assert(self.xp.all(total_score < 0.)), i 186 | input_ys = self.xp.split(outs[:, -1], outs.shape[0], axis=0) 187 | if self.xp.max(outs == self.eos_id, axis=1).sum() == outs.shape[0]: 188 | # all cands meet eos, end 189 | break 190 | result = finish_beam(outs, total_score, batchsize, self.eos_id) 191 | return result 192 | 193 | def prepare(self, xs): 194 | xs = F.split_axis(F.dropout(self.transform( 195 | self.xp.stack(xs, axis=0)), self.dropout), 196 | len(xs), axis=0) 197 | h, c, _ = self.rnn(None, None, xs) 198 | return h, c 199 | -------------------------------------------------------------------------------- /prepare_scripts/download.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | 4 | curl https://cs.stanford.edu/people/karpathy/deepimagesent/flickr8k.zip -o flickr8k.zip -k 5 | curl https://cs.stanford.edu/people/karpathy/deepimagesent/flickr30k.zip -o flickr30k.zip -k 6 | curl https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip -o coco.zip -k 7 | -------------------------------------------------------------------------------- /prepare_scripts/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | echo "\n\n\t# DOWNLOAD\n\n" 2 | sh prepare_scripts/download.sh 3 | 4 | echo "\n\n\t# UNZIP\n\n" 5 | sh prepare_scripts/unzip.sh 6 | 7 | echo "\n\n\t# COUNT 
and MAKE VOCABULARY\n\n" 8 | sh prepare_scripts/vocab.sh 9 | -------------------------------------------------------------------------------- /prepare_scripts/unzip.sh: -------------------------------------------------------------------------------- 1 | cd data 2 | unzip flickr8k.zip 3 | unzip flickr30k.zip 4 | unzip coco.zip 5 | -------------------------------------------------------------------------------- /prepare_scripts/vocab.sh: -------------------------------------------------------------------------------- 1 | python construct_vocab.py -d flickr8k -t 5 -o data/flickr8k/vocab.txt 2 | python construct_vocab.py -d flickr30k -t 5 -o data/flickr30k/vocab.txt 3 | python construct_vocab.py -d mscoco -t 5 -o data/coco/vocab.txt 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import json 4 | 5 | import numpy as np 6 | 7 | import chainer 8 | from chainer import training 9 | from chainer.training import extensions 10 | 11 | import datasets 12 | import nets 13 | import utils 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description='Image Comprehension') 18 | parser.add_argument('--batchsize', '-b', type=int, default=64, 19 | help='Number of images in each mini-batch') 20 | parser.add_argument('--learnrate', '-l', type=float, default=1e-4, 21 | help='Learning rate') 22 | parser.add_argument('--epoch', '-e', type=int, default=40, 23 | help='Number of sweeps over the dataset to train') 24 | parser.add_argument('--unit', '-u', type=int, default=512, 25 | help='Number of units') 26 | parser.add_argument('--layer', type=int, default=1) 27 | parser.add_argument('--dropout', '-d', type=float, default=0.3, 28 | help='Dropout rate for MLP') 29 | parser.add_argument('--gpu', '-g', type=int, default=0, 30 | help='GPU ID (negative value indicates CPU)') 31 | parser.add_argument('--vocab', '-v', required=True, 32 | help='Text file of vocab') 33 | parser.add_argument('--out', '-o', default='result', 34 | help='Directory to output the result') 35 | parser.add_argument('--dataset', default='mscoco', 36 | choices=['mscoco', 'flickr8k', 'flickr30k']) 37 | parser.add_argument('--beam', type=int, default=5) 38 | parser.add_argument('--print-sentence-mod', type=int, default=200) 39 | parser.add_argument('--resume') 40 | parser.add_argument('--resume-rnn') 41 | parser.add_argument('--resume-wordemb') 42 | parser.add_argument('--init-output-by-embed', action='store_true') 43 | args = parser.parse_args() 44 | print(json.dumps(args.__dict__, indent=2)) 45 | 46 | print('read vocab') 47 | vocab = json.load(open(args.vocab)) 48 | 49 | print('read dataset') 50 | directory = datasets.get_default_dataset_path(args.dataset) 51 | train, valid = datasets.get_caption_dataset( 52 | vocab, directory, ['train', 'val']) 53 | 54 | print('# train', len(train)) 55 | print('# valid', len(valid)) 56 | # this number is the number of references 57 | # it shuld be the number of images 58 | 59 | print('setup model') 60 | np.random.seed(777) 61 | model = nets.RNNDecoder( 62 | args.layer, len(vocab), args.unit, 63 | dropout=args.dropout, 64 | eos_id=vocab['']) 65 | if args.gpu >= 0: 66 | chainer.cuda.get_device_from_id(args.gpu).use() 67 | model.to_gpu() 68 | 69 | print('setup trainer') 70 | print(' optimizer') 71 | optimizer = chainer.optimizers.Adam(args.learnrate) 72 | optimizer.setup(model) 73 | 74 | if args.resume: 75 | 
print('load model', args.resume) 76 | chainer.serializers.load_npz(args.resume, model) 77 | if args.resume_rnn: 78 | print('load RNN model', args.resume_rnn) 79 | utils.load_npz_partially(args.resume_rnn, model, 80 | target_words=['rnn/']) 81 | if args.resume_wordemb: 82 | print('load Word Embedding model', args.resume_wordemb) 83 | utils.load_npz_partially(args.resume_wordemb, model, 84 | target_words=['embed/']) 85 | if args.init_output_by_embed: 86 | print('copy Word Embedding to Output Matrix') 87 | model.output.W.data[:] = model.embed.W.data 88 | 89 | print(' iterator') 90 | model.xp.random.seed(777) 91 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 92 | valid_iter = chainer.iterators.SerialIterator(valid, args.batchsize, 93 | repeat=False, shuffle=False) 94 | 95 | print(' updater') 96 | updater = training.StandardUpdater( 97 | train_iter, optimizer, 98 | converter=utils.convert, device=args.gpu, 99 | loss_func=model.calculate_loss) 100 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 101 | 102 | print(' extensions') 103 | iter_per_epoch = len(train) // args.batchsize 104 | log_trigger = (iter_per_epoch // 1, 'iteration') 105 | eval_trigger = (log_trigger[0] * 1, 'iteration') # every 1 epoch 106 | 107 | trainer.extend(extensions.Evaluator( 108 | valid_iter, model, 109 | converter=utils.convert, device=args.gpu, 110 | eval_func=model.evaluate), 111 | trigger=eval_trigger) 112 | 113 | trainer.extend(utils.SentenceEvaluater( 114 | model, valid, vocab, 'val/', 115 | batchsize=args.batchsize, 116 | device=args.gpu, 117 | k=args.beam, 118 | print_sentence_mod=args.print_sentence_mod), 119 | trigger=eval_trigger) 120 | 121 | record_trigger = training.triggers.MaxValueTrigger( 122 | 'val/bleu', 123 | trigger=eval_trigger) 124 | trainer.extend(extensions.snapshot_object( 125 | model, 'best_model.npz'), 126 | trigger=record_trigger) 127 | 128 | trainer.extend(extensions.LogReport(trigger=log_trigger), 129 | trigger=log_trigger) 130 | trainer.extend(extensions.PrintReport( 131 | ['epoch', 'iteration', 132 | 'main/perp', 'validation/main/perp', 133 | # 'main/acc', 'validation/main/acc', 134 | 'val/bleu', 135 | 'val/rouge', 136 | 'val/cider', 137 | 'val/meteor', 138 | 'elapsed_time']), 139 | trigger=log_trigger) 140 | 141 | if eval_trigger[0] % log_trigger[0] != 0: 142 | print('eval_trigger % log_trigger != 0.\n' 143 | 'So, some evaluation results can not be logged and shown.') 144 | 145 | trainer.extend(extensions.ProgressBar()) 146 | 147 | if args.resume: 148 | print(extensions.Evaluator( 149 | valid_iter, model, 150 | converter=utils.convert, device=args.gpu, 151 | eval_func=model.evaluate)()) 152 | 153 | print('log_trigger ', log_trigger) 154 | print('eval_trigger', eval_trigger) 155 | print('START training. 
# iter/epoch=', iter_per_epoch)
156 | 
157 |     trainer.run()
158 | 
159 | 
160 | if __name__ == '__main__':
161 |     main()
162 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./coco-caption/')
3 | sys.path.append('./coco-caption/pycocoevalcap/')
4 | 
5 | from bleu.bleu import Bleu
6 | from cider.cider import Cider
7 | from meteor.meteor import Meteor
8 | from rouge.rouge import Rouge
9 | 
10 | import collections
11 | import io
12 | import os
13 | 
14 | import numpy as np
15 | 
16 | import chainer
17 | from chainer import cuda
18 | from chainer import serializer
19 | 
20 | 
21 | def make_vocab(dataset, max_vocab_size=20000, min_freq=2):
22 |     counts = collections.defaultdict(int)
23 |     for tokens, _ in dataset:
24 |         for token in tokens:
25 |             counts[token] += 1
26 | 
27 |     vocab = {'': 0, '': 1}
28 |     for w, c in sorted(counts.items(), key=lambda x: (-x[1], x[0])):
29 |         if len(vocab) >= max_vocab_size or c < min_freq:
30 |             break
31 |         vocab[w] = len(vocab)
32 |     return vocab
33 | 
34 | 
35 | def make_array(tokens, vocab, add_eos=True, add_bos=True):
36 |     ids = [vocab.get(token, vocab['']) for token in tokens]
37 |     if add_eos:
38 |         ids.append(vocab[''])
39 |     if add_bos:
40 |         ids.insert(0, vocab[''])
41 |     return np.array(ids, 'i')
42 | 
43 | 
44 | def read_vocab(file_name):
45 |     vocab = {}
46 |     for l in io.open(file_name, encoding='utf-8', errors='ignore'):
47 |         sp = l.rstrip().split('\t')
48 |         if len(sp) == 2:
49 |             vocab[sp[0]] = int(sp[1])
50 |     necessary = ['', '', ' ']
51 |     for token in necessary:
52 |         if token not in vocab:
53 |             vocab[token] = len(vocab)
54 |     return vocab
55 | 
56 | 
57 | def convert(batch, device):
58 |     def to_device_batch(batch):
59 |         if device is None:
60 |             return batch
61 |         else:
62 |             xp = cuda.cupy.get_array_module(*batch)
63 |             concat = xp.concatenate(batch, axis=0)
64 |             sections = np.cumsum([len(x) for x in batch[:-1]], dtype='i')
65 |             concat_dev = chainer.dataset.to_device(device, concat)
66 |             batch_dev = xp.split(concat_dev, sections)
67 |             return batch_dev
68 | 
69 |     vecs, sentences, others = zip(*batch)
70 |     return {'xs': to_device_batch(list(vecs)),
71 |             'ys': to_device_batch(list(sentences)),
72 |             'others': list(others)}
73 | 
74 | 
75 | class PartialNpzDeserializer(serializer.Deserializer):
76 | 
77 |     """Partial Deserializer for NPZ format.
78 | 
79 |     This is a variant of Chainer's standard NPZ deserializer. It can be used
80 |     to read an object serialized by :func:`save_npz`, but it copies only the
81 |     parameters whose names contain one of the ``target_words`` substrings.
82 | 
83 |     Args:
84 |         npz: `npz` file object.
85 |         path: The base path that the deserialization starts from.
86 |         strict (bool): If ``True``, the deserializer raises an error when an
87 |             expected value is not found in the given NPZ file. Otherwise,
88 |             it ignores the value and skips deserialization.
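        target_words (list of str): Substrings used to select parameters.
            A parameter is deserialized only if its path contains at least
            one of these substrings; all other parameters are left untouched.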
89 | 90 | """ 91 | 92 | def __init__(self, npz, path='', strict=True, target_words=[]): 93 | assert(len(target_words) >= 1) 94 | self.target_words = target_words 95 | self.npz = npz 96 | self.path = path 97 | self.strict = strict 98 | 99 | def __getitem__(self, key): 100 | key = key.strip('/') 101 | return PartialNpzDeserializer( 102 | self.npz, self.path + key + '/', strict=self.strict, 103 | target_words=self.target_words) 104 | 105 | def __call__(self, key, value): 106 | key = self.path + key.lstrip('/') 107 | if not any(target in key for target in self.target_words): 108 | print('\t{}\tis NOT in targets'.format(key)) 109 | return None 110 | else: 111 | print('{}\tis in targets'.format(key)) 112 | 113 | if not self.strict and key not in self.npz: 114 | return value 115 | 116 | dataset = self.npz[key] 117 | if dataset[()] is None: 118 | return None 119 | 120 | if value is None: 121 | return dataset 122 | elif isinstance(value, np.ndarray): 123 | np.copyto(value, dataset) 124 | elif isinstance(value, cuda.ndarray): 125 | value.set(np.asarray(dataset)) 126 | else: 127 | value = type(value)(np.asarray(dataset)) 128 | return value 129 | 130 | 131 | def load_npz_partially(filename, obj, strict=True, target_words=[]): 132 | """Loads an object from the file in NPZ format. 133 | This is a short-cut function to load from an `.npz` file that contains only 134 | one object. 135 | Args: 136 | filename (str): Name of the file to be loaded. 137 | obj: Object to be deserialized. It must support serialization protocol. 138 | """ 139 | with np.load(filename) as f: 140 | d = PartialNpzDeserializer(f, strict=strict, target_words=target_words) 141 | d.load(obj) 142 | 143 | 144 | class SentenceEvaluaterUnit(object): 145 | def __init__(self, references, 146 | scorers=['bleu', 'rouge', 'cider', 'meteor']): 147 | self.scorers = {} 148 | for scorer in scorers: 149 | if scorer == 'bleu': 150 | self.scorers['bleu'] = Bleu(4) 151 | elif scorer == 'rouge': 152 | self.scorers['rouge'] = Rouge() 153 | elif scorer == 'cider': 154 | self.scorers['cider'] = Cider() 155 | elif scorer == 'meteor': 156 | self.scorers['meteor'] = Meteor() 157 | else: 158 | raise NotImplementedError() 159 | self.references = references 160 | 161 | def __call__(self, predictions): 162 | results = {} 163 | for name, scorer in self.scorers.items(): 164 | score, _ = scorer.compute_score( 165 | self.references, predictions) 166 | if name == 'bleu': 167 | score = score[-1] 168 | results[name] = score 169 | return results 170 | 171 | 172 | def arrange_test_dataset(dataset, vocab): 173 | inputs = {} 174 | references = collections.defaultdict(list) 175 | unk = vocab[''] 176 | for data in dataset: 177 | vec, sentence, other = data 178 | img_id = other['img_id'] 179 | inputs[img_id] = vec 180 | target = sentence[1:-1] # remove bos and eos 181 | # make unk impossible to match 182 | target = [wi if wi != unk else -1 183 | for wi in target] 184 | target = ' '.join(str(wi) for wi in target) 185 | references[img_id].append(target) 186 | assert(set(inputs.keys()) == set(references.keys())) 187 | inputs = [iv for iv in sorted(inputs.items(), key=lambda x:x[0])] 188 | return inputs, references 189 | 190 | 191 | class SentenceEvaluater(chainer.training.Extension): 192 | 193 | priority = chainer.training.PRIORITY_WRITER 194 | 195 | def __init__(self, model, test_data, vocab, base_key, 196 | batchsize=100, device=-1, k=20, 197 | print_sentence_mod=None): 198 | self.model = model 199 | self.inputs, self.references = arrange_test_dataset(test_data, vocab) 200 | 
self.evaluater = SentenceEvaluaterUnit(self.references) 201 | self.base_key = base_key 202 | self.batchsize = batchsize 203 | self.device = device 204 | self.k = k 205 | self.vocab = vocab 206 | self.rev_vocab = {i: w for w, i in vocab.items()} 207 | self.print_sentence_mod = print_sentence_mod 208 | 209 | def __call__(self, trainer): 210 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 211 | predictions = {} 212 | if self.print_sentence_mod is not None: 213 | print('\n') 214 | for i in range(0, len(self.inputs), self.batchsize): 215 | img_ids, vecs = zip(*self.inputs[i:i + self.batchsize]) 216 | 217 | vecs = [x for x in chainer.dataset.to_device( 218 | self.device, np.stack(vecs))] 219 | 220 | results = self.model.decode(vecs, k=self.k) 221 | for img_id, result in zip(img_ids, results): 222 | outs = result['outs'] 223 | while self.model.eos_id in outs: 224 | outs.remove(self.model.eos_id) 225 | outs_str = ' '.join(str(wi) for wi in outs) 226 | predictions[img_id] = [outs_str] 227 | 228 | if self.print_sentence_mod is not None: 229 | if img_id % self.print_sentence_mod == 0: 230 | print('\t#GEN {}: ({:.3f}) {}'.format( 231 | img_id, float(result['score']), 232 | ' '.join(self.rev_vocab[wi] for wi in outs))) 233 | score_results = self.evaluater(predictions) 234 | for name, score in score_results.items(): 235 | score *= 100. 236 | chainer.report({os.path.join(self.base_key, name): score}) 237 | --------------------------------------------------------------------------------