├── .gitignore
├── CapGenerator
│   ├── __init__.py
│   ├── attention.py
│   ├── def_attention.ipynb
│   ├── eval_model.py
│   ├── generate_model.py
│   ├── load_data.py
│   ├── prepare_data.py
│   └── train_model.py
├── Devel.ipynb
├── Flickr8k_Dataset
│   └── README.md
├── Flickr8k_text
│   └── README.md
├── LICENSE
├── README.md
├── imgs
│   ├── dog.jpg
│   ├── people.jpg
│   ├── ski.jpg
│   └── worker.jpg
├── model.png
├── models
│   └── README.md
├── requirements.txt
└── train_attention.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Ignore OSX Default files
10 | .DS_Store
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # dotenv
86 | .env
87 |
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # Datasets
107 | Flickr*/*
108 | models/*
109 |
110 | # Don't ignore README
111 | !Flickr*/README.md
112 | !models/README.md
--------------------------------------------------------------------------------
/CapGenerator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/CapGenerator/__init__.py
--------------------------------------------------------------------------------
/CapGenerator/attention.py:
--------------------------------------------------------------------------------
1 | from os import listdir
2 | from pickle import dump
3 | from keras.applications.vgg16 import VGG16
4 | from keras.preprocessing.image import load_img
5 | from keras.preprocessing.image import img_to_array
6 | from keras.applications.vgg16 import preprocess_input
7 | from keras.layers import Input, Reshape, Concatenate, Activation, Dense
8 | import keras
9 | import numpy as np
10 | import string
11 | # from progressbar import progressbar
12 | from keras.models import Model
13 | import tensorflow as tf
14 |
15 | class AttentionModel:
16 |
17 | def __init__(self):
18 |
19 | # load model
20 | model = VGG16()
21 | model.layers.pop()
22 | # extract final 49x512 conv layer for context vectors
23 | final_conv = Reshape([49,512])(model.layers[-4].output)
24 | self.model = Model(inputs=model.inputs, outputs=final_conv)
25 | print(self.model.summary())
26 |
27 | # model parameters
28 | self.dim_ctx = 512
29 | self.n_ctx = 49
30 | self.lstm_cell_dim = 128
31 | self.lstm_hidden_dim = 128
32 |
33 | # cell state MLP
34 | self.c_mlp_hidden = 256
35 |
36 | self.inputs_c = Input(shape=(self.dim_ctx,))
37 | f_c = Dense(self.c_mlp_hidden,activation="relu")(self.inputs_c)
38 | self.f_c = Dense(self.lstm_cell_dim,activation=None)(f_c)
39 |
40 | # hidden state MLP
41 | self.h_mlp_hidden = 256
42 |
43 | self.inputs_h = Input(shape=(self.dim_ctx,))
44 | f_h = Dense(self.h_mlp_hidden,activation="relu")(self.inputs_h)
45 | self.f_h = Dense(self.lstm_hidden_dim,activation=None)(f_h)
46 |
47 | # attention/alphas MLP
48 | self.att_mlp_hidden = 256
49 |
50 | self.inputs_att = Input(shape=(self.dim_ctx+self.lstm_hidden_dim,))
51 | x = Dense(self.att_mlp_hidden,activation="relu")(self.inputs_att)
52 | x = Dense(1,activation=None)(x)
53 | self.alphas = Activation("softmax")(x)
54 |
55 | self.sess = tf.Session()
56 |
57 | # Returns the initial LSTM cell_state and hidden_state computed from the context vectors
58 | def init_lstm_states(self,contexts):
59 | cell_state = self.sess.run(self.f_c,feed_dict={self.inputs_c:contexts})
60 | hidden_state = self.sess.run(self.f_h,feed_dict={self.inputs_h:contexts})
61 | return cell_state,hidden_state
62 |
63 | # Computes alpha values (attention weights) by passing context vectors + hidden state through MLP
64 | # The hidden state is tiled and concatenated onto each context vector before the MLP
65 | def generate_alphas(self,contexts,hidden_state):
66 | batch_size = contexts.shape[0]
67 | tiled_hidden_state = tf.tile([[hidden_state]],[batch_size,self.n_ctx,1])
68 | concat_input = Concatenate(axis=-1)((contexts,tiled_hidden_state))
69 | return self.sess.run(self.alphas,feed_dict={self.inputs_att:concat_input})
70 |
71 | # Generates a soft-attention attention vector from alphas & context vectors
72 | def get_soft_attention_vec(self,contexts,alphas):
73 | return contexts*tf.reshape(alphas,[1,-1,1])
74 |
75 | # Generates VGG16 features from a batch of images
76 | def get_features(self,images):
77 | return self.sess.run(self.model.output,feed_dict={self.model.input: images})
78 |
79 |
80 |
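81 | # Rough usage sketch of how the pieces above are meant to fit together
82 | # (kept as comments; the variable names and shapes here are illustrative assumptions):
83 | #   att = AttentionModel()
84 | #   contexts = att.get_features(images)              # (batch, 49, 512) context vectors
85 | #   mean_ctx = contexts.mean(axis=1)                  # (batch, 512) summary, e.g. for state init
86 | #   cell, hidden = att.init_lstm_states(mean_ctx)     # initial LSTM cell/hidden states
87 | #   alphas = att.generate_alphas(contexts, hidden)    # attention weights over the 49 locations
88 | #   z = att.get_soft_attention_vec(contexts, alphas)  # alpha-weighted context vectors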
--------------------------------------------------------------------------------
/CapGenerator/def_attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from os import listdir\n",
10 | "from pickle import dump\n",
11 | "from keras.applications.vgg16 import VGG16\n",
12 | "from keras.preprocessing.image import load_img\n",
13 | "from keras.preprocessing.image import img_to_array\n",
14 | "from keras.applications.vgg16 import preprocess_input\n",
15 | "from keras.layers import Input, Reshape, Concatenate, Activation, Dense\n",
16 | "import keras\n",
17 | "import numpy as np\n",
18 | "import string\n",
19 | "# from progressbar import progressbar\n",
20 | "from keras.models import Model\n",
21 | "import tensorflow as tf"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 7,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "_________________________________________________________________\n",
34 | "Layer (type) Output Shape Param # \n",
35 | "=================================================================\n",
36 | "input_4 (InputLayer) (None, 224, 224, 3) 0 \n",
37 | "_________________________________________________________________\n",
38 | "block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 \n",
39 | "_________________________________________________________________\n",
40 | "block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 \n",
41 | "_________________________________________________________________\n",
42 | "block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 \n",
43 | "_________________________________________________________________\n",
44 | "block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 \n",
45 | "_________________________________________________________________\n",
46 | "block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 \n",
47 | "_________________________________________________________________\n",
48 | "block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 \n",
49 | "_________________________________________________________________\n",
50 | "block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 \n",
51 | "_________________________________________________________________\n",
52 | "block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 \n",
53 | "_________________________________________________________________\n",
54 | "block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 \n",
55 | "_________________________________________________________________\n",
56 | "block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 \n",
57 | "_________________________________________________________________\n",
58 | "block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 \n",
59 | "_________________________________________________________________\n",
60 | "block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 \n",
61 | "_________________________________________________________________\n",
62 | "block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 \n",
63 | "_________________________________________________________________\n",
64 | "block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 \n",
65 | "_________________________________________________________________\n",
66 | "block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 \n",
67 | "_________________________________________________________________\n",
68 | "block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 \n",
69 | "_________________________________________________________________\n",
70 | "block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 \n",
71 | "_________________________________________________________________\n",
72 | "block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 \n",
73 | "_________________________________________________________________\n",
74 | "flatten (Flatten) (None, 25088) 0 \n",
75 | "_________________________________________________________________\n",
76 | "fc1 (Dense) (None, 4096) 102764544 \n",
77 | "_________________________________________________________________\n",
78 | "fc2 (Dense) (None, 4096) 16781312 \n",
79 | "=================================================================\n",
80 | "Total params: 134,260,544\n",
81 | "Trainable params: 134,260,544\n",
82 | "Non-trainable params: 0\n",
83 | "_________________________________________________________________\n",
84 | "None\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "from IPython.display import SVG\n",
90 | "from keras.utils.vis_utils import model_to_dot\n",
91 | "\n",
92 | "# load model\n",
93 | "model = VGG16()\n",
94 | "model.layers.pop()\n",
95 | "# extract final 49x512 conv layer for context vectors\n",
96 | "final_conv = Reshape([49,512])(model.layers[-4].output)\n",
97 | "print(model.summary())"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 8,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "image/svg+xml": [
108 | ""
505 | ],
506 | "text/plain": [
507 | ""
508 | ]
509 | },
510 | "execution_count": 8,
511 | "metadata": {},
512 | "output_type": "execute_result"
513 | }
514 | ],
515 | "source": [
516 | "SVG(model_to_dot(model_to_dot, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))"
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "execution_count": 9,
522 | "metadata": {},
523 | "outputs": [
524 | {
525 | "name": "stdout",
526 | "output_type": "stream",
527 | "text": [
528 | "_________________________________________________________________\n",
529 | "Layer (type) Output Shape Param # \n",
530 | "=================================================================\n",
531 | "input_5 (InputLayer) (None, 224, 224, 3) 0 \n",
532 | "_________________________________________________________________\n",
533 | "block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 \n",
534 | "_________________________________________________________________\n",
535 | "block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 \n",
536 | "_________________________________________________________________\n",
537 | "block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 \n",
538 | "_________________________________________________________________\n",
539 | "block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 \n",
540 | "_________________________________________________________________\n",
541 | "block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 \n",
542 | "_________________________________________________________________\n",
543 | "block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 \n",
544 | "_________________________________________________________________\n",
545 | "block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 \n",
546 | "_________________________________________________________________\n",
547 | "block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 \n",
548 | "_________________________________________________________________\n",
549 | "block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 \n",
550 | "_________________________________________________________________\n",
551 | "block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 \n",
552 | "_________________________________________________________________\n",
553 | "block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 \n",
554 | "_________________________________________________________________\n",
555 | "block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 \n",
556 | "_________________________________________________________________\n",
557 | "block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 \n",
558 | "_________________________________________________________________\n",
559 | "block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 \n",
560 | "_________________________________________________________________\n",
561 | "block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 \n",
562 | "_________________________________________________________________\n",
563 | "block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 \n",
564 | "_________________________________________________________________\n",
565 | "block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 \n",
566 | "_________________________________________________________________\n",
567 | "block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 \n",
568 | "_________________________________________________________________\n",
569 | "flatten (Flatten) (None, 25088) 0 \n",
570 | "_________________________________________________________________\n",
571 | "fc1 (Dense) (None, 4096) 102764544 \n",
572 | "_________________________________________________________________\n",
573 | "fc2 (Dense) (None, 4096) 16781312 \n",
574 | "=================================================================\n",
575 | "Total params: 134,260,544\n",
576 | "Trainable params: 134,260,544\n",
577 | "Non-trainable params: 0\n",
578 | "_________________________________________________________________\n",
579 | "None\n"
580 | ]
581 | }
582 | ],
583 | "source": [
584 | "model = VGG16()\n",
585 | "# re-structure the model\n",
586 | "model.layers.pop()\n",
587 | "model = Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
588 | "print(model.summary())"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "SVG(model_to_dot(model, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": null,
603 | "metadata": {},
604 | "outputs": [],
605 | "source": [
606 | "class AttentionModel:\n",
607 | " \n",
608 | " def __init__(self):\n",
609 | " \n",
610 | " # load model\n",
611 | " model = VGG16()\n",
612 | " model.layers.pop()\n",
613 | " # extract final 49x512 conv layer for context vectors\n",
614 | " final_conv = Reshape([49,512])(model.layers[-4].output)\n",
615 | " self.model = Model(inputs=model.inputs, outputs=final_conv)\n",
616 | " print(self.model.summary())\n",
617 | " \n",
618 | " # model parameters\n",
619 | " self.dim_ctx = 512\n",
620 | " self.n_ctx = 49\n",
621 | " self.lstm_cell_dim = 128\n",
622 | " self.lstm_hidden_dim = 128\n",
623 | " \n",
624 | " # cell state MLP\n",
625 | " self.c_mlp_hidden = 256\n",
626 | " \n",
627 | " self.inputs_c = Input(shape=(self.dim_ctx,))\n",
628 | " f_c = Dense(self.c_mlp_hidden,activation=\"relu\")(self.inputs_c)\n",
629 | " self.f_c = Dense(self.lstm_cell_dim,activation=None)(f_c)\n",
630 | " \n",
631 | " # hidden state MLP\n",
632 | " self.h_mlp_hidden = 256\n",
633 | " \n",
634 | " self.inputs_h = Input(shape=(self.dim_ctx,))\n",
635 | " f_h = Dense(self.h_mlp_hidden,activation=\"relu\")(self.inputs_h)\n",
636 | " self.f_h = Dense(self.lstm_hidden_dim,activation=None)(f_h)\n",
637 | " \n",
638 | " # attention/alphas MLP\n",
639 | " self.att_mlp_hidden = 256\n",
640 | " \n",
641 | " self.inputs_att = Input(shape=(self.dim_ctx+self.lstm_hidden_dim,))\n",
642 | " x = Dense(self.att_mlp_hidden,activation=\"relu\")(self.inputs_att)\n",
643 | " x = Dense(1,activation=None)(x)\n",
644 | " self.alphas = Activation(\"softmax\")(x)\n",
645 | " \n",
646 | " self.sess = tf.Session()\n",
647 | " \n",
648 | " # Returns tensors for the initial cell_state and hidden_states\n",
649 | " def init_lstm_states(self,contexts):\n",
650 | " cell_state = self.sess.run(self.f_c,feed_dict={self.inputs_c:contexts})\n",
651 | " hidden_state = self.sess.run(self.f_h,feed_dict={self.inputs_h:contexts})\n",
652 | " return cell_state,hidden_state\n",
653 | " \n",
654 | " # Computes alpha values (attention weights) by passing context vectors + hidden state through MLP\n",
655 | " # Includes hidden state by concatenating to end of alpha values\n",
656 | " def generate_alphas(self,contexts,hidden_state):\n",
657 | " batch_size = contexts.shape[0]\n",
658 | " tiled_hidden_state = tf.tile([[hidden_state]],[batch_size,n_ctx,1])\n",
659 | " concat_input = Concatenate(axis=-1)((contexts,tiled_hidden_state))\n",
660 | " return self.sess.run(self.alphas,feed_dict={self.inputs_att:concat_input})\n",
661 | "\n",
662 | " # Generates a soft-attention attention vector from alphas & context vectors\n",
663 | " def get_soft_attention_vec(contexts,alphas):\n",
664 | " return contexts*tf.reshape(alphas,[1,-1,1])\n",
665 | " \n",
666 | " # Generates VGG16 features from a batch of images\n",
667 | " def get_features(images):\n",
668 | " return self.sess.run(self.model.output,feed_dict={})\n",
669 | " \n",
670 | " "
671 | ]
672 | }
673 | ],
674 | "metadata": {
675 | "kernelspec": {
676 | "display_name": "Anaconda",
677 | "language": "python",
678 | "name": "anaconda"
679 | },
680 | "language_info": {
681 | "codemirror_mode": {
682 | "name": "ipython",
683 | "version": 3
684 | },
685 | "file_extension": ".py",
686 | "mimetype": "text/x-python",
687 | "name": "python",
688 | "nbconvert_exporter": "python",
689 | "pygments_lexer": "ipython3",
690 | "version": "3.6.2"
691 | }
692 | },
693 | "nbformat": 4,
694 | "nbformat_minor": 2
695 | }
696 |
--------------------------------------------------------------------------------
/CapGenerator/eval_model.py:
--------------------------------------------------------------------------------
1 | from pickle import load
2 | import numpy as np
3 | from keras.preprocessing.sequence import pad_sequences
4 | from keras.applications.vgg16 import VGG16
5 | from keras.preprocessing.image import load_img
6 | from keras.preprocessing.image import img_to_array
7 | from keras.applications.vgg16 import preprocess_input
8 | from keras.models import Model
9 | from keras.models import load_model
10 | from nltk.translate.bleu_score import corpus_bleu
11 |
12 | import load_data as ld
13 | import generate_model as gen
14 | import argparse
15 |
16 | # extract features from each photo in the directory
17 | def extract_features(filename):
18 | # load the model
19 | model = VGG16()
20 | # re-structure the model
21 | model.layers.pop()
22 | model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
23 | # load the photo
24 | image = load_img(filename, target_size=(224, 224))
25 | # convert the image pixels to a numpy array
26 | image = img_to_array(image)
27 | # reshape data for the model
28 | image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
29 | # prepare the image for the VGG model
30 | image = preprocess_input(image)
31 | # get features
32 | feature = model.predict(image, verbose=0)
33 | return feature
34 |
35 | # generate a description for an image
36 | def generate_desc(model, tokenizer, photo, index_word, max_length, beam_size=5):
37 |
38 | captions = [['startseq', 0.0]]
39 | # seed the generation process
40 | in_text = 'startseq'
41 | # iterate over the whole length of the sequence
42 | for i in range(max_length):
43 | all_caps = []
44 | # expand each current candidate
45 | for cap in captions:
46 | sentence, score = cap
47 | # if final word is 'end' token, just add the current caption
48 | if sentence.split()[-1] == 'endseq':
49 | all_caps.append(cap)
50 | continue
51 | # integer encode input sequence
52 | sequence = tokenizer.texts_to_sequences([sentence])[0]
53 | # pad input
54 | sequence = pad_sequences([sequence], maxlen=max_length)
55 | # predict next words
56 | y_pred = model.predict([photo,sequence], verbose=0)[0]
57 | # convert probability to integer
58 | yhats = np.argsort(y_pred)[-beam_size:]
59 |
60 | for j in yhats:
61 | # map integer to word
62 | word = index_word.get(j)
63 | # stop if we cannot map the word
64 | if word is None:
65 | continue
66 | # Add word to caption, and generate log prob
67 | caption = [sentence + ' ' + word, score + np.log(y_pred[j])]
68 | all_caps.append(caption)
69 |
70 | # order all candidates by score
71 | ordered = sorted(all_caps, key=lambda tup:tup[1], reverse=True)
72 | captions = ordered[:beam_size]
73 |
74 | return captions
75 |
76 | # evaluate the skill of the model
77 | def evaluate_model(model, descriptions, photos, tokenizer, index_word, max_length):
78 | actual, predicted = list(), list()
79 | # step over the whole set
80 | for key, desc_list in descriptions.items():
81 | # generate description
82 | yhat = generate_desc(model, tokenizer, photos[key], index_word, max_length)[0]
83 | # store actual and predicted
84 | references = [d.split() for d in desc_list]
85 | actual.append(references)
86 | # Use best caption
87 | predicted.append(yhat[0].split())
88 | # calculate BLEU score
89 | print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0)))
90 | print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0)))
91 | print('BLEU-3: %f' % corpus_bleu(actual, predicted, weights=(0.3, 0.3, 0.3, 0)))
92 | print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25)))
93 |
94 | def eval_test_set(model, descriptions, photos, tokenizer, index_word, max_length):
95 | actual, predicted = list(), list()
96 | # step over the whole set
97 | for key, desc_list in descriptions.items():
98 | # generate description
99 | yhat = generate_desc(model, tokenizer, photos[key], index_word, max_length)[0]
100 | # store actual and predicted
101 | references = [d.split() for d in desc_list]
102 | actual.append(references)
103 | # Use best caption
104 | predicted.append(yhat[0].split())
105 | # keep each prediction aligned with its reference captions while sorting
106 | predicted, actual = map(list, zip(*sorted(zip(predicted, actual))))
107 |
108 | if __name__ == '__main__':
109 |
110 | parser = argparse.ArgumentParser(description='Generate image captions')
111 | parser.add_argument("-i", "--image", help="Input image path")
112 | parser.add_argument("-m", "--model", help="model checkpoint")
113 | args = parser.parse_args()
114 |
115 |
116 | # load the tokenizer
117 | tokenizer = load(open('models/tokenizer.pkl', 'rb'))
118 | index_word = load(open('models/index_word.pkl', 'rb'))
119 | # pre-define the max sequence length (from training)
120 | max_length = 34
121 |
122 | # load the model
123 | if args.model:
124 | filename = args.model
125 | else:
126 | filename = 'models/model_weight.h5'
127 | model = load_model(filename)
128 |
129 | if args.image:
130 | # load and prepare the photograph
131 | photo = extract_features(args.image)
132 | # generate description
133 | captions = generate_desc(model, tokenizer, photo, index_word, max_length)
134 | for cap in captions:
135 | # remove start and end tokens
136 | seq = cap[0].split()[1:-1]
137 | desc = ' '.join(seq)
138 | print('{} [log prob: {:1.2f}]'.format(desc,cap[1]))
139 | else:
140 | # load test set
141 | test_features, test_descriptions = ld.prepare_dataset('test')[1]
142 |
143 | # evaluate model
144 | evaluate_model(model, test_descriptions, test_features, tokenizer, index_word, max_length)
145 |
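146 | # Example invocation (assuming a trained checkpoint and the pickled tokenizer exist
147 | # under models/, and the script is run from the repository root):
148 | #   python CapGenerator/eval_model.py -m models/model_weight.h5 -i imgs/dog.jpg
149 | # Omitting -i evaluates BLEU-1..4 on the Flickr8k test split instead.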
--------------------------------------------------------------------------------
/CapGenerator/generate_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from pickle import load
4 | from keras.preprocessing.text import Tokenizer
5 | from keras.preprocessing.sequence import pad_sequences
6 | from keras.utils import to_categorical
7 | from keras.utils import plot_model
8 | from keras.models import Model, Sequential
9 | from keras.layers import Input
10 | from keras.layers import Dense
11 | from keras.layers import LSTM
12 | from keras.layers import Embedding
13 | from keras.layers import Dropout
14 | from keras.layers import RepeatVector
15 | from keras.layers import TimeDistributed
16 | from keras.layers import concatenate
17 | from keras.callbacks import ModelCheckpoint
18 | from keras.optimizers import Adam
19 |
20 | EMBEDDING_DIM = 256
21 |
22 | lstm_layers = 2
23 | dropout_rate = 0.2
24 | learning_rate = 0.001
25 |
26 | # convert a dictionary of clean descriptions to a list of descriptions
27 | def to_lines(descriptions):
28 | all_desc = list()
29 | for key in descriptions.keys():
30 | [all_desc.append(d) for d in descriptions[key]]
31 | return all_desc
32 |
33 | # fit a tokenizer given caption descriptions
34 | def create_tokenizer(descriptions):
35 | lines = to_lines(descriptions)
36 | tokenizer = Tokenizer()
37 | tokenizer.fit_on_texts(lines)
38 | return tokenizer
39 |
40 |
41 | # calculate the length of the description with the most words
42 | def max_length(descriptions):
43 | lines = to_lines(descriptions)
44 | return max(len(d.split()) for d in lines)
45 |
46 | # create sequences of images, input sequences and output words for an image
47 | def create_sequences(tokenizer, max_length, desc_list, photo):
48 | vocab_size = len(tokenizer.word_index) + 1
49 |
50 | X1, X2, y = [], [], []
51 | # walk through each description for the image
52 | for desc in desc_list:
53 | # encode the sequence
54 | seq = tokenizer.texts_to_sequences([desc])[0]
55 | # split one sequence into multiple X,y pairs
56 | for i in range(1, len(seq)):
57 | # split into input and output pair
58 | in_seq, out_seq = seq[:i], seq[i]
59 | # pad input sequence
60 | in_seq = pad_sequences([in_seq], maxlen=max_length)[0]
61 | # encode output sequence
62 | out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
63 | # store
64 | X1.append(photo)
65 | X2.append(in_seq)
66 | y.append(out_seq)
67 | return np.array(X1), np.array(X2), np.array(y)
68 |
69 | # data generator, intended to be used in a call to model.fit_generator()
70 | def data_generator(descriptions, photos, tokenizer, max_length, n_step = 1):
71 | # loop for ever over images
72 | while 1:
73 | # loop over photo identifiers in the dataset
74 | keys = list(descriptions.keys())
75 | for i in range(0, len(keys), n_step):
76 | Ximages, XSeq, y = list(), list(),list()
77 | for j in range(i, min(len(keys), i+n_step)):
78 | image_id = keys[j]
79 | # retrieve the photo feature
80 | photo = photos[image_id][0]
81 | desc_list = descriptions[image_id]
82 | in_img, in_seq, out_word = create_sequences(tokenizer, max_length, desc_list, photo)
83 | for k in range(len(in_img)):
84 | Ximages.append(in_img[k])
85 | XSeq.append(in_seq[k])
86 | y.append(out_word[k])
87 | yield [[np.array(Ximages), np.array(XSeq)], np.array(y)]
88 |
89 | def categorical_crossentropy_from_logits(y_true, y_pred):
90 | y_true = y_true[:, :-1, :] # Discard the last timestep
91 | y_pred = y_pred[:, :-1, :] # Discard the last timestep
92 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
93 | logits=y_pred)
94 | return loss
95 |
96 | def categorical_accuracy_with_variable_timestep(y_true, y_pred):
97 | y_true = y_true[:, :-1, :] # Discard the last timestep
98 | y_pred = y_pred[:, :-1, :] # Discard the last timestep
99 |
100 | # Flatten the timestep dimension
101 | shape = tf.shape(y_true)
102 | y_true = tf.reshape(y_true, [-1, shape[-1]])
103 | y_pred = tf.reshape(y_pred, [-1, shape[-1]])
104 |
105 | # Discard rows that are all zeros as they represent padding words.
106 | is_zero_y_true = tf.equal(y_true, 0)
107 | is_zero_row_y_true = tf.reduce_all(is_zero_y_true, axis=-1)
108 | y_true = tf.boolean_mask(y_true, ~is_zero_row_y_true)
109 | y_pred = tf.boolean_mask(y_pred, ~is_zero_row_y_true)
110 |
111 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_true, axis=1),
112 | tf.argmax(y_pred, axis=1)),
113 | dtype=tf.float32))
114 | return accuracy
115 |
116 | # define the captioning model
117 | def define_model(vocab_size, max_length):
118 | # feature extractor (encoder)
119 | inputs1 = Input(shape=(4096,))
120 | fe1 = Dropout(0.5)(inputs1)
121 | fe2 = Dense(EMBEDDING_DIM, activation='relu')(fe1)
122 | fe3 = RepeatVector(max_length)(fe2)
123 |
124 | # embedding
125 | inputs2 = Input(shape=(max_length,))
126 | emb2 = Embedding(vocab_size, EMBEDDING_DIM, mask_zero=True)(inputs2)
127 |
128 | # merge inputs
129 | merged = concatenate([fe3, emb2])
130 | # language model (decoder)
131 | lm2 = LSTM(500, return_sequences=False)(merged)
132 | #lm3 = Dense(500, activation='relu')(lm2)
133 | outputs = Dense(vocab_size, activation='softmax')(lm2)
134 |
135 | # tie it together [image, seq] [word]
136 | model = Model(inputs=[inputs1, inputs2], outputs=outputs)
137 | model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
138 | print(model.summary())
139 | plot_model(model, show_shapes=True, to_file='model.png')
140 | return model
141 |
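142 | if __name__ == '__main__':
143 |   # Illustrative only: build the captioning model for a hypothetical vocabulary size
144 |   # and the max caption length used elsewhere in this repo, to inspect the architecture.
145 |   define_model(vocab_size=1000, max_length=34)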
--------------------------------------------------------------------------------
/CapGenerator/load_data.py:
--------------------------------------------------------------------------------
1 | from pickle import load
2 | import argparse
3 |
4 | # load doc into memory
5 | def load_doc(filename):
6 | # open the file as read only
7 | file = open(filename, 'r')
8 | # read all text
9 | text = file.read()
10 | # close the file
11 | file.close()
12 | return text
13 |
14 | # load a pre-defined list of photo identifiers
15 | def load_set(filename):
16 | doc = load_doc(filename)
17 | dataset = list()
18 | # process line by line
19 | for line in doc.split('\n'):
20 | # skip empty lines
21 | if len(line) < 1:
22 | continue
23 | # get the image identifier
24 | identifier = line.split('.')[0]
25 | dataset.append(identifier)
26 | return set(dataset)
27 |
28 | # split a dataset into train/test elements
29 | def train_test_split(dataset):
30 | # order keys so the split is consistent
31 | ordered = sorted(dataset)
32 | # return split dataset as two new sets
33 | return set(ordered[:100]), set(ordered[100:200])
34 |
35 | # load clean descriptions into memory
36 | def load_clean_descriptions(filename, dataset):
37 | # load document
38 | doc = load_doc(filename)
39 | descriptions = dict()
40 | for line in doc.split('\n'):
41 | # split line by white space
42 | tokens = line.split()
43 | # split id from description
44 | image_id, image_desc = tokens[0], tokens[1:]
45 | # skip images not in the set
46 | if image_id in dataset:
47 | # create list
48 | if image_id not in descriptions:
49 | descriptions[image_id] = list()
50 | # wrap description in tokens
51 | desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
52 | # store
53 | descriptions[image_id].append(desc)
54 | return descriptions
55 |
56 | # load photo features
57 | def load_photo_features(filename, dataset):
58 | # load all features
59 | all_features = load(open(filename, 'rb'))
60 | # filter features
61 | features = {k: all_features[k] for k in dataset}
62 | return features
63 |
64 | def prepare_dataset(data='dev'):
65 |
66 | assert data in ['dev', 'train', 'test']
67 |
68 | train_features = None
69 | train_descriptions = None
70 |
71 | if data == 'dev':
72 | # load dev set (1K)
73 | filename = 'Flickr8k_text/Flickr_8k.devImages.txt'
74 | dataset = load_set(filename)
75 | print('Dataset: %d' % len(dataset))
76 |
77 | # train-test split
78 | train, test = train_test_split(dataset)
79 | #print('Train=%d, Test=%d' % (len(train), len(test)))
80 |
81 | # descriptions
82 | train_descriptions = load_clean_descriptions('models/descriptions.txt', train)
83 | test_descriptions = load_clean_descriptions('models/descriptions.txt', test)
84 | print('Descriptions: train=%d, test=%d' % (len(train_descriptions), len(test_descriptions)))
85 |
86 | # photo features
87 | train_features = load_photo_features('models/features.pkl', train)
88 | test_features = load_photo_features('models/features.pkl', test)
89 | print('Photos: train=%d, test=%d' % (len(train_features), len(test_features)))
90 |
91 | elif data == 'train':
92 | # load training dataset (6K)
93 | filename = 'Flickr8k_text/Flickr_8k.trainImages.txt'
94 | train = load_set(filename)
95 |
96 | filename = 'Flickr8k_text/Flickr_8k.devImages.txt'
97 | test = load_set(filename)
98 | print('Dataset: %d' % len(train))
99 |
100 | # descriptions
101 | train_descriptions = load_clean_descriptions('models/descriptions.txt', train)
102 | test_descriptions = load_clean_descriptions('models/descriptions.txt', test)
103 | print('Descriptions: train=%d, test=%d' % (len(train_descriptions), len(test_descriptions)))
104 |
105 | # photo features
106 | train_features = load_photo_features('models/features.pkl', train)
107 | test_features = load_photo_features('models/features.pkl', test)
108 | print('Photos: train=%d, test=%d' % (len(train_features), len(test_features)))
109 |
110 | elif data == 'test':
111 | # load test set
112 | filename = 'Flickr8k_text/Flickr_8k.testImages.txt'
113 | test = load_set(filename)
114 | print('Dataset: %d' % len(test))
115 | # descriptions
116 | test_descriptions = load_clean_descriptions('models/descriptions.txt', test)
117 | print('Descriptions: test=%d' % len(test_descriptions))
118 | # photo features
119 | test_features = load_photo_features('models/features.pkl', test)
120 | print('Photos: test=%d' % len(test_features))
121 |
122 | return (train_features, train_descriptions), (test_features, test_descriptions)
123 |
124 | if __name__ == '__main__':
125 | parser = argparse.ArgumentParser(description='Generate dataset features')
126 | parser.add_argument("-t", "--train", action='store_const', const='train',
127 | default = 'dev', help="Use large 6K training set")
128 | args = parser.parse_args()
129 | prepare_dataset(args.train)
130 |
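131 | # Return layout, for reference:
132 | #   prepare_dataset(...) -> ((train_features, train_descriptions), (test_features, test_descriptions))
133 | # For data='test' the train pair is (None, None), which is why eval_model.py indexes [1].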
--------------------------------------------------------------------------------
/CapGenerator/prepare_data.py:
--------------------------------------------------------------------------------
1 | from os import listdir
2 | from pickle import dump
3 | from keras.applications.vgg16 import VGG16
4 | from keras.preprocessing.image import load_img
5 | from keras.preprocessing.image import img_to_array
6 | from keras.applications.vgg16 import preprocess_input
7 | from keras.layers import Input, Reshape, Concatenate
8 | import numpy as np
9 | import string
10 | from progressbar import progressbar
11 | from keras.models import Model
12 |
13 | # load an image from filepath
14 | def load_image(path):
15 | img = load_img(path, target_size=(224,224))
16 | img = img_to_array(img)
17 | img = np.expand_dims(img, axis=0)
18 | img = preprocess_input(img)
19 | return np.asarray(img)
20 |
21 | # extract features from each photo in the directory
22 | def extract_features(directory,is_attention=False):
23 | # load the model
24 | if is_attention:
25 | model = VGG16()
26 | model.layers.pop()
27 | # extract final 49x512 conv layer for context vectors
28 | final_conv = Reshape([49,512])(model.layers[-4].output)
29 | model = Model(inputs=model.inputs, outputs=final_conv)
30 | print(model.summary())
31 | features = dict()
32 | else:
33 | model = VGG16()
34 | # re-structure the model
35 | model.layers.pop()
36 | model = Model(inputs=model.inputs, outputs=model.layers[-1].output)
37 | print(model.summary())
38 | # extract features from each photo
39 | features = dict()
40 |
41 | for name in progressbar(listdir(directory)):
42 | # ignore README
43 | if name == 'README.md':
44 | continue
45 | filename = directory + '/' + name
46 | image = load_image(filename)
47 | # extract features
48 | feature = model.predict(image, verbose=0)
49 | # get image id
50 | image_id = name.split('.')[0]
51 | # store feature
52 | features[image_id] = feature
53 | print('>%s' % name)
54 | return features
55 |
56 | # load doc into memory
57 | def load_doc(filename):
58 | # open the file as read only
59 | file = open(filename, 'r')
60 | # read all text
61 | text = file.read()
62 | # close the file
63 | file.close()
64 | return text
65 |
66 | # extract descriptions for images
67 | def load_descriptions(doc):
68 | mapping = dict()
69 | # process lines
70 | for line in doc.split('\n'):
71 | # split line by white space
72 | tokens = line.split()
73 | if len(line) < 2:
74 | continue
75 | # take the first token as the image id, the rest as the description
76 | image_id, image_desc = tokens[0], tokens[1:]
77 | # remove filename from image id
78 | image_id = image_id.split('.')[0]
79 | # convert description tokens back to string
80 | image_desc = ' '.join(image_desc)
81 | # create the list if needed
82 | if image_id not in mapping:
83 | mapping[image_id] = list()
84 | # store description
85 | mapping[image_id].append(image_desc)
86 | return mapping
87 |
88 | def clean_descriptions(descriptions):
89 | # prepare translation table for removing punctuation
90 | table = str.maketrans('', '', string.punctuation)
91 | for key, desc_list in descriptions.items():
92 | for i in range(len(desc_list)):
93 | desc = desc_list[i]
94 | # tokenize
95 | desc = desc.split()
96 | # convert to lower case
97 | desc = [word.lower() for word in desc]
98 | # remove punctuation from each token
99 | desc = [w.translate(table) for w in desc]
100 | # remove hanging 's' and 'a'
101 | desc = [word for word in desc if len(word)>1]
102 | # remove tokens with numbers in them
103 | desc = [word for word in desc if word.isalpha()]
104 | # store as string
105 | desc_list[i] = ' '.join(desc)
106 |
107 | # convert the loaded descriptions into a vocabulary of words
108 | def to_vocabulary(descriptions):
109 | # build a list of all description strings
110 | all_desc = set()
111 | for key in descriptions.keys():
112 | [all_desc.update(d.split()) for d in descriptions[key]]
113 | return all_desc
114 |
115 | # save descriptions to file, one per line
116 | def save_descriptions(descriptions, filename):
117 | lines = list()
118 | for key, desc_list in descriptions.items():
119 | for desc in desc_list:
120 | lines.append(key + ' ' + desc)
121 | data = '\n'.join(lines)
122 | file = open(filename, 'w')
123 | file.write(data)
124 | file.close()
125 |
126 | # extract features from all images
127 |
128 | directory = 'Flickr8k_Dataset'
129 | features = extract_features(directory)
130 | print('Extracted Features: %d' % len(features))
131 | # save to file
132 | dump(features, open('models/features.pkl', 'wb'))
133 |
134 | # prepare descriptions
135 |
136 | filename = 'Flickr8k_text/Flickr8k.token.txt'
137 | # load descriptions
138 | doc = load_doc(filename)
139 | # parse descriptions
140 | descriptions = load_descriptions(doc)
141 | print('Loaded: %d ' % len(descriptions))
142 | # clean descriptions
143 | clean_descriptions(descriptions)
144 | # summarize vocabulary
145 | vocabulary = to_vocabulary(descriptions)
146 | print('Vocabulary Size: %d' % len(vocabulary))
147 | # save to file
148 | save_descriptions(descriptions, 'models/descriptions.txt')
149 |
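150 | # Running this script (e.g. from the repository root, with the Flickr8k images in
151 | # Flickr8k_Dataset/ and the caption file in Flickr8k_text/) writes two artifacts:
152 | #   models/features.pkl      dict mapping image_id -> VGG16 fc2 feature of shape (1, 4096)
153 | #   models/descriptions.txt  one cleaned caption per line, formatted "<image_id> <caption>"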
--------------------------------------------------------------------------------
/CapGenerator/train_model.py:
--------------------------------------------------------------------------------
1 | import load_data as ld
2 | import generate_model as gen
3 | from keras.callbacks import ModelCheckpoint
4 | from pickle import dump
5 |
6 | def train_model(weight = None, epochs = 10):
7 | # load dataset
8 | data = ld.prepare_dataset('train')
9 | train_features, train_descriptions = data[0]
10 | test_features, test_descriptions = data[1]
11 |
12 | # prepare tokenizer
13 | tokenizer = gen.create_tokenizer(train_descriptions)
14 | # save the tokenizer
15 | dump(tokenizer, open('models/tokenizer.pkl', 'wb'))
16 | # index_word dict
17 | index_word = {v: k for k, v in tokenizer.word_index.items()}
18 | # save dict
19 | dump(index_word, open('models/index_word.pkl', 'wb'))
20 |
21 | vocab_size = len(tokenizer.word_index) + 1
22 | print('Vocabulary Size: %d' % vocab_size)
23 |
24 | # determine the maximum sequence length
25 | max_length = gen.max_length(train_descriptions)
26 | print('Description Length: %d' % max_length)
27 |
28 | # generate model
29 | model = gen.define_model(vocab_size, max_length)
30 |
31 | # Check if pre-trained weights to be used
32 | if weight != None:
33 | model.load_weights(weight)
34 |
35 | # define checkpoint callback
36 | filepath = 'models/model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
37 | checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1,
38 | save_best_only=True, mode='min')
39 |
40 | steps = len(train_descriptions)
41 | val_steps = len(test_descriptions)
42 | # create the data generator
43 | train_generator = gen.data_generator(train_descriptions, train_features, tokenizer, max_length)
44 | val_generator = gen.data_generator(test_descriptions, test_features, tokenizer, max_length)
45 |
46 | # fit model
47 | model.fit_generator(train_generator, epochs=epochs, steps_per_epoch=steps, verbose=1,
48 | callbacks=[checkpoint], validation_data=val_generator, validation_steps=val_steps)
49 |
50 | try:
51 | model.save('models/wholeModel.h5', overwrite=True)
52 | model.save_weights('models/weights.h5',overwrite=True)
53 | except Exception as e:
54 | print("Error in saving model:", e)
55 | print("Training complete...\n")
56 |
57 | if __name__ == '__main__':
58 | train_model(epochs=20)
59 |
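60 | # To resume training from previously saved weights one could call, for example:
61 | #   train_model(weight='models/weights.h5', epochs=10)
62 | # Checkpoints matching models/model-ep{epoch}-loss{loss}-val_loss{val_loss}.h5 are written
63 | # whenever the validation loss improves (save_best_only=True above).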
--------------------------------------------------------------------------------
/Devel.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Dataset: 6000\n",
13 | "Descriptions: train=6000, test=1000\n",
14 | "Photos: train=6000, test=1000\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "from CapGenerator import load_data as ld\n",
20 | "\n",
21 | "data = ld.prepare_dataset('train')"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 3,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "train_features, train_descriptions = data[0]"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 4,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "ename": "SyntaxError",
40 | "evalue": "invalid syntax (, line 11)",
41 | "output_type": "error",
42 | "traceback": [
43 | "\u001b[0;36m File \u001b[0;32m\"\"\u001b[0;36m, line \u001b[0;32m11\u001b[0m\n\u001b[0;31m for key Pin descriptions.keys():\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
44 | ]
45 | }
46 | ],
47 | "source": [
48 | "from keras.preprocessing.text import Tokenizer\n",
49 | "from keras.preprocessing.sequence import pad_sequences\n",
50 | "from keras.utils import to_categorical\n",
51 | "import numpy as np\n",
52 | "from CapGenerator import generate_model as gen\n",
53 | "\n",
54 | "\n",
55 | "# convert a dictionary of clean descriptions to a list of descriptions\n",
56 | "def to_lines(descriptions):\n",
57 | " all_desc = list()\n",
58 | " for key Pin descriptions.keys():\n",
59 | " [all_desc.append(d) for d in descriptions[key]]\n",
60 | " return all_desc\n",
61 | "\n",
62 | " # fit a tokenizer given caption descriptions\n",
63 | " def create_tokenizer(descriptions):\n",
64 | " lines = to_lines(descriptions)\n",
65 | " tokenizer = Tokenizer()\n",
66 | " tokenizer.fit_on_texts(lines)\n",
67 | " return tokenizer\n",
68 | "\n",
69 | " # calculate the length of the description with the most words\n",
70 | " def max_length(descriptions):\n",
71 | " lines = to_lines(descriptions)\n",
72 | " return max(len(d.split()) for d in lines)\n",
73 | "\n",
74 | " # create sequences of images, input sequences and output words for an image\n",
75 | " def create_sequences(tokenizer, max_length, desc_list, photo):\n",
76 | " X1, X2, y = list(), list(), list()\n",
77 | " # walk through each description for the image\n",
78 | " for desc in desc_list:\n",
79 | " # encode the sequence\n",
80 | " seq = tokenizer.texts_to_sequences([desc])[0]\n",
81 | " # split one sequence into multiple X,y pairs\n",
82 | " for i in range(1, len(seq)):\n",
83 | " s\n",
84 | " # split into input and output pair\n",
85 | " in_seq, out_seq = seq[:i], seq[i]\n",
86 | " # pad input sequence\n",
87 | " in_seq = pad_sequences([in_seq], maxlen=max_length)[0]\n",
88 | " # encode output sequence\n",
89 | " out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]\n",
90 | " # store\n",
91 | " X1.append(photo)\n",
92 | " X2.append(in_seq)\n",
93 | " y.append(out_seq)\n",
94 | " return np.array(X1), np.array(X2), np.array(y)\n",
95 | "\n",
96 | " # data generator, intended to be used in a call to model.fit_generator()\n",
97 | " def data_generator(descriptions, photos, tokenizer, max_length):\n",
98 | " # loop for ever over images\n",
99 | " while 1:\n",
100 | " for key, desc_list in descriptions.items():\n",
101 | " # retrieve the photo feature\n",
102 | " photo = photos[key][0]\n",
103 | " print('Photo:')\n",
104 | " print(photo.shape)\n",
105 | " in_img, in_seq, out_word = create_sequences(tokenizer, max_length, desc_list, photo)\n",
106 | " yield [[in_img, in_seq], out_word]\n",
107 | "\n",
108 | " # prepare tokenizer\n",
109 | " tokenizer = create_tokenizer(train_descriptions)\n",
110 | " vocab_size = len(tokenizer.word_index) + 1\n",
111 | " print('Vocabulary Size: %d' % vocab_size)\n",
112 | "\n",
113 | " # determine the maximum sequence length\n",
114 | " max_length = max_length(train_descriptions)\n",
115 | " print('Description Length: %d' % max_length)\n",
116 | "\n",
117 | " # test the data generator\n",
118 | " generator = gen.data_generator(train_descriptions, train_features, tokenizer, max_length)\n"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "inputs, outputs = next(generator)\n",
128 | "print(inputs[0].shape)\n",
129 | "print(inputs[1].shape)\n",
130 | "print(outputs.shape)"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "from keras.utils import to_categorical\n",
140 | "from keras.utils import plot_model\n",
141 | "from keras.models import Model, Sequential\n",
142 | "from keras.layers import Input\n",
143 | "from keras.layers import Dense\n",
144 | "from keras.layers import LSTM\n",
145 | "from keras.layers import Embedding\n",
146 | "from keras.layers import Dropout\n",
147 | "from keras.layers import RepeatVector\n",
148 | "from keras.layers import TimeDistributed\n",
149 | "from keras.layers import concatenate\n",
150 | "from keras.layers import Reshape\n",
151 | "from keras.layers import merge\n",
152 | "from keras.layers import GRU\n",
153 | "from keras.layers import BatchNormalization\n",
154 | "\n",
155 | "from IPython.display import SVG\n",
156 | "from keras.utils.vis_utils import model_to_dot\n",
157 | "\n",
158 | "from keras.optimizers import Adam\n",
159 | "\n",
160 | "\n",
161 | "EMBEDDING_DIM = 128\n",
162 | "lstm_layers = 3\n",
163 | "dropout_rate = 0.22\n",
164 | "learning_rate = 0.001\n",
165 | "\n",
166 | "# define the captioning model\n",
167 | "def define_model(vocab_size, max_length):\n",
168 | " # feature extractor (encoder)\n",
169 | " inputs1 = Input(shape=(4096, ))\n",
170 | " fe1 = Dropout(0.5)(inputs1)\n",
171 | " fe2 = Dense(EMBEDDING_DIM, activation='relu')(fe1)\n",
172 | " fe3 = RepeatVector(max_length)(fe2)\n",
173 | "\n",
174 | " # embedding\n",
175 | " inputs2 = Input(shape=(max_length, ))\n",
176 | " emb2 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)\n",
177 | "\n",
178 | " # merge inputs\n",
179 | " merged = concatenate([fe3, emb2])\n",
180 | " # language model (decoder)\n",
181 | " input_ = merged\n",
182 | " for _ in range(lstm_layers):\n",
183 | " input_ = BatchNormalization()(input_)\n",
184 | " lstm_out = LSTM(\n",
185 | " 300,\n",
186 | " return_sequences=True,\n",
187 | " dropout=dropout_rate,\n",
188 | " recurrent_dropout=dropout_rate)(input_)\n",
189 | " input_ = lstm_out\n",
190 | " outputs = Dense(vocab_size, activation='softmax')(lstm_out)\n",
191 | "\n",
192 | " # tie it together [image, seq] [word]\n",
193 | " model = Model(inputs=[inputs1, inputs2], outputs=outputs)\n",
194 | " model.compile(\n",
195 | " loss='categorical_crossentropy',\n",
196 | " optimizer=Adam(lr=learning_rate),\n",
197 | " metrics=['accuracy'])\n",
198 | " print(model.summary())\n",
199 | " plot_model(model, show_shapes=True, to_file='model.png')\n",
200 | " return model"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "model = define_model(200, 34)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "display(SVG(model_to_dot(model, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg')))"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [],
226 | "source": [
227 | "\n",
228 | "def image_caption_model(vocab_size=2187, embedding_matrix=None, lang_dim=100,\n",
229 | " max_caplen=53, img_dim=2048, clipnorm=1):\n",
230 | " print('generating vocab_history model v5')\n",
231 | " # text: current word\n",
232 | " lang_input = Input(shape=(1,))\n",
233 | " img_input = Input(shape=(img_dim,))\n",
234 | " seq_input = Input(shape=(max_caplen,))\n",
235 | " vhist_input = Input(shape=(vocab_size,))\n",
236 | "\n",
237 | " if embedding_matrix is not None:\n",
238 | " x = Embedding(output_dim=lang_dim, input_dim=vocab_size, init='glorot_uniform', input_length=1, weights=[embedding_matrix])(lang_input)\n",
239 | " else:\n",
240 | " x = Embedding(output_dim=lang_dim, input_dim=vocab_size, init='glorot_uniform', input_length=1)(lang_input)\n",
241 | "\n",
242 | " lang_embed = Reshape((lang_dim,))(x)\n",
243 | " lang_embed = merge([lang_embed, seq_input], mode='concat', concat_axis=-1)\n",
244 | " lang_embed = Dense(lang_dim)(lang_embed)\n",
245 | " lang_embed = Dropout(0.25)(lang_embed)\n",
246 | "\n",
247 | " merge_layer = merge([img_input, lang_embed, vhist_input], mode='concat', concat_axis=-1)\n",
248 | " merge_layer = Reshape((1, lang_dim+img_dim+vocab_size))(merge_layer)\n",
249 | "\n",
250 | " gru_1 = GRU(img_dim)(merge_layer)\n",
251 | " gru_1 = Dropout(0.25)(gru_1)\n",
252 | " gru_1 = Dense(img_dim)(gru_1)\n",
253 | " gru_1 = BatchNormalization()(gru_1)\n",
254 | " gru_1 = Activation('softmax')(gru_1)\n",
255 | "\n",
256 | " attention_1 = merge([img_input, gru_1], mode='mul', concat_axis=-1)\n",
257 | " attention_1 = merge([attention_1, lang_embed, vhist_input], mode='concat', concat_axis=-1)\n",
258 | " attention_1 = Reshape((1, lang_dim + img_dim + vocab_size))(attention_1)\n",
259 | " gru_2 = GRU(1024)(attention_1)\n",
260 | " gru_2 = Dropout(0.25)(gru_2)\n",
261 | " gru_2 = Dense(vocab_size)(gru_2)\n",
262 | " gru_2 = BatchNormalization()(gru_2)\n",
263 | " out = Activation('softmax')(gru_2)\n",
264 | " \n",
265 | " model = Model(input=[img_input, lang_input, seq_input, vhist_input], output=out)\n",
266 | " model.compile(loss='categorical_crossentropy', optimizer=RMSprop(lr=0.0001, clipnorm=1.))\n",
267 | " plot_model(model, show_shapes=True, to_file='plot.png')\n",
268 | " return model"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "from keras.layers import Activation\n",
278 | "from keras.optimizers import RMSprop\n",
279 | "\n",
280 | "model = image_caption_model()"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "SVG(model_to_dot(model, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {},
296 | "outputs": [],
297 | "source": [
298 | "from keras.models import Model\n",
299 | "from keras.layers import Input, Dropout, TimeDistributed, Masking, Dense, Lambda, Permute\n",
300 | "from keras.layers import BatchNormalization, Embedding, Activation, Reshape, Multiply\n",
301 | "from keras.layers.merge import Add, Concatenate, Average\n",
302 | "from keras.layers.recurrent import LSTM, GRU, SimpleRNN\n",
303 | "from keras.regularizers import l2\n",
304 | "from keras import backend as K\n",
305 | "\n",
306 | "def NIC(max_token_length, vocabulary_size, rnn='lstm' ,num_image_features=2048,\n",
307 | " hidden_size=512, embedding_size=512, regularizer=1e-8, batch_size= 20):\n",
308 | "\n",
309 | " # word embedding\n",
310 | " text_input = Input(shape=(max_token_length, vocabulary_size), name='text')#batch_shape=batch_size,\n",
311 | " text_mask = Masking(mask_value=0.0, name='text_mask')(text_input)\n",
312 | " text_to_embedding = TimeDistributed(Dense(units=embedding_size,\n",
313 | " kernel_regularizer=l2(regularizer),\n",
314 | " name='text_embedding'))(text_mask)\n",
315 | "\n",
316 | " text_dropout = Dropout(.5, name='text_dropout')(text_to_embedding)\n",
317 | "\n",
318 | " # image embedding\n",
319 | " image_input = Input(shape=(max_token_length, 14, 14, 512), #batch_shape=batch_size,\n",
320 | " name='image')\n",
321 | " image_input2 = Reshape((max_token_length, 196, 512), name='reshape1')(image_input)\n",
322 | " # denselastlayer\n",
323 | " image_embedding = TimeDistributed(Dense(units=embedding_size,\n",
324 | " kernel_regularizer=l2(regularizer),\n",
325 | " name='image_embedding'))(image_input2)\n",
326 | " image_dropout = Dropout(.5, name='image_dropout')(image_embedding)\n",
327 | "\n",
328 | " # language model\n",
329 | " #recurrent_inputs = [text_dropout, text_dropout]\n",
330 | " # merged_input = Add()(recurrent_inputs)\n",
331 | " if rnn == 'lstm':\n",
332 | " for i in range(max_token_length):\n",
333 | " if i == 0:\n",
334 | " # first_input = AVE(units=1,\n",
335 | " ## name='initial_zero')(image_dropout[:, i, :, :])\n",
336 | " getlistlayer = Lambda(lambda x: [x[:, j, :] for j in range(196)])\n",
337 | " getim0layer = Lambda(lambda x: x[:, 0, :, :])\n",
338 | " temp_layer = getim0layer(image_dropout)\n",
339 | " #order_input = getlistlayer(temp_layer)\n",
340 | " #order_input = [image_dropout[:, i, j, :] for j in range(196)]\n",
341 | " avelayer = Lambda(lambda x:K.mean(x, axis=1))\n",
342 | " first_input = avelayer(temp_layer)\n",
343 | " #first_input = Average()(order_input)\n",
344 | " #first_input = Dense(units=100,\n",
345 | " # kernel_regularizer=l2(regularizer),\n",
346 | " # name='initial_zero_embed')(first_input)\n",
347 | " gettx0layer = Lambda(lambda x: x[:, 0, :])\n",
348 | " first_input2 = gettx0layer(text_dropout)\n",
349 | " recurrent_inputs = [first_input2, first_input]\n",
350 | " merged_input_temp = Add()(recurrent_inputs)\n",
351 | " merged_input = Reshape((1, embedding_size), name='reshape2')(merged_input_temp)\n",
352 | " else:\n",
353 | " getim1layer = Lambda(lambda x: x[:, 1, :, :])\n",
354 | " getim2layer = Lambda(lambda x: x[:, 2, :, :])\n",
355 | " getim3layer = Lambda(lambda x: x[:, 3, :, :])\n",
356 | " getim4layer = Lambda(lambda x: x[:, 4, :, :])\n",
357 | " getim5layer = Lambda(lambda x: x[:, 5, :, :])\n",
358 | " getim6layer = Lambda(lambda x: x[:, 6, :, :])\n",
359 | " getim7layer = Lambda(lambda x: x[:, 7, :, :])\n",
360 | " getim8layer = Lambda(lambda x: x[:, 8, :, :])\n",
361 | " getim9layer = Lambda(lambda x: x[:, 9, :, :])\n",
362 | " getim10layer = Lambda(lambda x: x[:, 10, :, :])\n",
363 | " getim11layer = Lambda(lambda x: x[:, 11, :, :])\n",
364 | " getim12layer = Lambda(lambda x: x[:, 12, :, :])\n",
365 | " getim13layer = Lambda(lambda x: x[:, 13, :, :])\n",
366 | " getim14layer = Lambda(lambda x: x[:, 14, :, :])\n",
367 | " getim15layer = Lambda(lambda x: x[:, 15, :, :])\n",
368 | " getim16layer = Lambda(lambda x: x[:, 16, :, :])\n",
369 | " getim17layer = Lambda(lambda x: x[:, 17, :, :])\n",
370 | " getim18layer = Lambda(lambda x: x[:, 18, :, :])\n",
371 | " getim19layer = Lambda(lambda x: x[:, 19, :, :])\n",
372 | " getim20layer = Lambda(lambda x: x[:, 20, :, :])\n",
373 | " getim21layer = Lambda(lambda x: x[:, 21, :, :])\n",
374 | " getim22layer = Lambda(lambda x: x[:, 22, :, :])\n",
375 | " getim23layer = Lambda(lambda x: x[:, 23, :, :])\n",
376 | " getim24layer = Lambda(lambda x: x[:, 24, :, :])\n",
377 | " getim25layer = Lambda(lambda x: x[:, 25, :, :])\n",
378 | " getim26layer = Lambda(lambda x: x[:, 26, :, :])\n",
379 | " getim27layer = Lambda(lambda x: x[:, 27, :, :])\n",
380 | " getim28layer = Lambda(lambda x: x[:, 28, :, :])\n",
381 | " getim29layer = Lambda(lambda x: x[:, 29, :, :])\n",
382 | " getim30layer = Lambda(lambda x: x[:, 30, :, :])\n",
383 | " getim31layer = Lambda(lambda x: x[:, 31, :, :])\n",
384 | " if i == 1:\n",
385 | " outputimsplit = getim1layer(image_dropout)\n",
386 | " elif i == 2:\n",
387 | " outputimsplit = getim2layer(image_dropout)\n",
388 | " elif i==3:\n",
389 | " outputimsplit = getim3layer(image_dropout)\n",
390 | " elif i == 4:\n",
391 | " outputimsplit = getim4layer(image_dropout)\n",
392 | " elif i==5:\n",
393 | " outputimsplit = getim5layer(image_dropout)\n",
394 | " elif i == 6:\n",
395 | " outputimsplit = getim6layer(image_dropout)\n",
396 | " elif i==7:\n",
397 | " outputimsplit = getim7layer(image_dropout)\n",
398 | " elif i == 8:\n",
399 | " outputimsplit = getim8layer(image_dropout)\n",
400 | " elif i==9:\n",
401 | " outputimsplit = getim9layer(image_dropout)\n",
402 | " elif i == 10:\n",
403 | " outputimsplit = getim10layer(image_dropout)\n",
404 | " elif i==11:\n",
405 | " outputimsplit = getim11layer(image_dropout)\n",
406 | " elif i == 12:\n",
407 | " outputimsplit = getim12layer(image_dropout)\n",
408 | " elif i==13:\n",
409 | " outputimsplit = getim13layer(image_dropout)\n",
410 | " elif i == 14:\n",
411 | " outputimsplit = getim14layer(image_dropout)\n",
412 | " elif i==15:\n",
413 | " outputimsplit = getim15layer(image_dropout)\n",
414 | " elif i == 16:\n",
415 | " outputimsplit = getim16layer(image_dropout)\n",
416 | " elif i==17:\n",
417 | " outputimsplit = getim17layer(image_dropout)\n",
418 | " elif i == 18:\n",
419 | " outputimsplit = getim18layer(image_dropout)\n",
420 | " elif i==19:\n",
421 | " outputimsplit = getim19layer(image_dropout)\n",
422 | " elif i == 20:\n",
423 | " outputimsplit = getim20layer(image_dropout)\n",
424 | " elif i==21:\n",
425 | " outputimsplit = getim21layer(image_dropout)\n",
426 | " elif i == 22:\n",
427 | " outputimsplit = getim22layer(image_dropout)\n",
428 | " elif i==23:\n",
429 | " outputimsplit = getim23layer(image_dropout)\n",
430 | " elif i == 24:\n",
431 | " outputimsplit = getim24layer(image_dropout)\n",
432 | " elif i==25:\n",
433 | " outputimsplit = getim25layer(image_dropout)\n",
434 | " elif i == 26:\n",
435 | " outputimsplit = getim26layer(image_dropout)\n",
436 | " elif i==27:\n",
437 | " outputimsplit = getim27layer(image_dropout)\n",
438 | " elif i == 28:\n",
439 | " outputimsplit = getim28layer(image_dropout)\n",
440 | " elif i==29:\n",
441 | " outputimsplit = getim29layer(image_dropout)\n",
442 | " elif i == 30:\n",
443 | " outputimsplit = getim30layer(image_dropout)\n",
444 | " else:\n",
445 | " outputimsplit = getim31layer(image_dropout)\n",
446 | " per_out = Permute((2, 1))(outputimsplit)\n",
447 | " per_out1 = Dense(units=1,\n",
448 | " kernel_regularizer=l2(regularizer))(per_out)\n",
449 | " dim_change = Dense(units=128,\n",
450 | " kernel_regularizer=l2(regularizer))(lstm_out)\n",
451 | " dim_change2 = Permute((2, 1))(dim_change)\n",
452 | " #per_out2 = Reshape((196, 1))(dim_change)\n",
453 | " attendout3 = Multiply()([per_out1, dim_change2])\n",
454 | " pre_merge = Reshape((1, embedding_size))(attendout3)\n",
455 | " gettx1layer = Lambda(lambda x: x[:, 1, :])\n",
456 | " gettx2layer = Lambda(lambda x: x[:, 2, :])\n",
457 | " gettx3layer = Lambda(lambda x: x[:, 3, :])\n",
458 | " gettx4layer = Lambda(lambda x: x[:, 4, :])\n",
459 | " gettx5layer = Lambda(lambda x: x[:, 5, :])\n",
460 | " gettx6layer = Lambda(lambda x: x[:, 6, :])\n",
461 | " gettx7layer = Lambda(lambda x: x[:, 7, :])\n",
462 | " gettx8layer = Lambda(lambda x: x[:, 8, :])\n",
463 | " gettx9layer = Lambda(lambda x: x[:, 9, :])\n",
464 | " gettx10layer = Lambda(lambda x: x[:, 10, :])\n",
465 | " gettx11layer = Lambda(lambda x: x[:, 11, :])\n",
466 | " gettx12layer = Lambda(lambda x: x[:, 12, :])\n",
467 | " gettx13layer = Lambda(lambda x: x[:, 13, :])\n",
468 | " gettx14layer = Lambda(lambda x: x[:, 14, :])\n",
469 | " gettx15layer = Lambda(lambda x: x[:, 15, :])\n",
470 | " gettx16layer = Lambda(lambda x: x[:, 16, :])\n",
471 | " gettx17layer = Lambda(lambda x: x[:, 17, :])\n",
472 | " gettx18layer = Lambda(lambda x: x[:, 18, :])\n",
473 | " gettx19layer = Lambda(lambda x: x[:, 19, :])\n",
474 | " gettx20layer = Lambda(lambda x: x[:, 20, :])\n",
475 | " gettx21layer = Lambda(lambda x: x[:, 21, :])\n",
476 | " gettx22layer = Lambda(lambda x: x[:, 22, :])\n",
477 | " gettx23layer = Lambda(lambda x: x[:, 23, :])\n",
478 | " gettx24layer = Lambda(lambda x: x[:, 24, :])\n",
479 | " gettx25layer = Lambda(lambda x: x[:, 25, :])\n",
480 | " gettx26layer = Lambda(lambda x: x[:, 26, :])\n",
481 | " gettx27layer = Lambda(lambda x: x[:, 27, :])\n",
482 | " gettx28layer = Lambda(lambda x: x[:, 28, :])\n",
483 | " gettx29layer = Lambda(lambda x: x[:, 29, :])\n",
484 | " gettx30layer = Lambda(lambda x: x[:, 30, :])\n",
485 | " gettx31layer = Lambda(lambda x: x[:, 31, :])\n",
486 | " if i == 1:\n",
487 | " outputtxsplit = gettx1layer(image_dropout)\n",
488 | " elif i == 2:\n",
489 | " outputtxsplit = gettx2layer(image_dropout)\n",
490 | " elif i==3:\n",
491 | " outputtxsplit = gettx3layer(image_dropout)\n",
492 | " elif i == 4:\n",
493 | " outputtxsplit = gettx4layer(image_dropout)\n",
494 | " elif i==5:\n",
495 | " outputtxsplit = gettx5layer(image_dropout)\n",
496 | " elif i == 6:\n",
497 | " outputtxsplit = gettx6layer(image_dropout)\n",
498 | " elif i==7:\n",
499 | " outputtxsplit = gettx7layer(image_dropout)\n",
500 | " elif i == 8:\n",
501 | " outputtxsplit = gettx8layer(image_dropout)\n",
502 | " elif i==9:\n",
503 | " outputtxsplit = gettx9layer(image_dropout)\n",
504 | " elif i == 10:\n",
505 | " outputtxsplit = gettx10layer(image_dropout)\n",
506 | " elif i==11:\n",
507 | " outputtxsplit = gettx11layer(image_dropout)\n",
508 | " elif i == 12:\n",
509 | " outputtxsplit = gettx12layer(image_dropout)\n",
510 | " elif i==13:\n",
511 | " outputtxsplit = gettx13layer(image_dropout)\n",
512 | " elif i == 14:\n",
513 | " outputtxsplit = gettx14layer(image_dropout)\n",
514 | " elif i==15:\n",
515 | " outputtxsplit = gettx15layer(image_dropout)\n",
516 | " elif i == 16:\n",
517 | " outputtxsplit = gettx16layer(image_dropout)\n",
518 | " elif i==17:\n",
519 | " outputtxsplit = gettx17layer(image_dropout)\n",
520 | " elif i == 18:\n",
521 | " outputtxsplit = gettx18layer(image_dropout)\n",
522 | " elif i==19:\n",
523 | " outputtxsplit = gettx19layer(image_dropout)\n",
524 | " elif i == 20:\n",
525 | " outputtxsplit = gettx20layer(image_dropout)\n",
526 | " elif i==21:\n",
527 | " outputtxsplit = gettx21layer(image_dropout)\n",
528 | " elif i == 22:\n",
529 | " outputtxsplit = gettx22layer(image_dropout)\n",
530 | " elif i==23:\n",
531 | " outputtxsplit = gettx23layer(image_dropout)\n",
532 | " elif i == 24:\n",
533 | " outputtxsplit = gettx24layer(image_dropout)\n",
534 | " elif i==25:\n",
535 | " outputtxsplit = gettx25layer(image_dropout)\n",
536 | " elif i == 26:\n",
537 | " outputtxsplit = gettx26layer(image_dropout)\n",
538 | " elif i==27:\n",
539 | " outputtxsplit = gettx27layer(image_dropout)\n",
540 | " elif i == 28:\n",
541 | " outputtxsplit = gettx28layer(image_dropout)\n",
542 | " elif i==29:\n",
543 | " outputtxsplit = gettx29layer(image_dropout)\n",
544 | " elif i == 30:\n",
545 | " outputtxsplit = gettx30layer(image_dropout)\n",
546 | " else:\n",
547 | " outputtxsplit = gettx31layer(image_dropout)\n",
548 | " shape_im = Permute((2, 1))(outputtxsplit)\n",
549 | " dim_change_im = Dense(units=1,\n",
550 | " kernel_regularizer=l2(regularizer))(shape_im)\n",
551 | " pre_merge_txt = Permute((2, 1))(dim_change_im)\n",
552 | " #pre_merge_txt = Reshape((1, embedding_size))(outputtxsplit)\n",
553 | " recurrent_inputs = [pre_merge_txt, pre_merge]\n",
554 | " merged_input = Add()(recurrent_inputs)\n",
555 | " #merged_input = Reshape((1, embedding_size), name='reshape3')(merged_input_temp)\n",
556 | " lstm_out = LSTM(units=hidden_size,#[:, i, :]\n",
557 | " recurrent_regularizer=l2(regularizer),\n",
558 | " kernel_regularizer=l2(regularizer),\n",
559 | " bias_regularizer=l2(regularizer),\n",
560 | " return_sequences=True,\n",
561 | " name='recurrent_network' + str(i))(merged_input)\n",
562 | " if i == 0:\n",
563 | " lstm_out_final = Concatenate(axis=1)([lstm_out, lstm_out])\n",
564 | " else:\n",
565 | " lstm_out_final = Concatenate(axis=1)([lstm_out_final, lstm_out])\n",
566 | " else:\n",
567 | " raise Exception('Invalid rnn name')\n",
568 | " getoutbelayer = Lambda(lambda x: x[:, 1:, :])\n",
569 | " output_be = Reshape((max_token_length+1, embedding_size))(lstm_out_final)\n",
570 | " output_bee = getoutbelayer(output_be)\n",
571 | " output = TimeDistributed(Dense(units=vocabulary_size,\n",
572 | " kernel_regularizer=l2(regularizer),\n",
573 | " activation='softmax'),\n",
574 | " name='output')(output_bee)\n",
575 | "\n",
576 | " inputs = [text_input, image_input]\n",
577 | " model = Model(inputs=inputs, outputs=output)\n",
578 | " return model"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "metadata": {},
585 | "outputs": [],
586 | "source": [
587 | "model = NIC(16, 1024)\n",
588 | "plot_model(model, '../images/NIC.png')"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "metadata": {},
595 | "outputs": [],
596 | "source": [
597 | "from keras.applications.vgg16 import VGG16\n",
598 | "\n",
599 | "model = VGG16(weights='imagenet', include_top=True, input_shape = (224, 224, 3))"
600 | ]
601 | },
602 | {
603 | "cell_type": "code",
604 | "execution_count": null,
605 | "metadata": {},
606 | "outputs": [],
607 | "source": [
608 | "plot_model(model, to_file='model.png', show_shapes=True)"
609 | ]
610 | },
611 | {
612 | "cell_type": "code",
613 | "execution_count": null,
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "for k, v in train_descriptions.items():\n",
618 | " "
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "execution_count": null,
624 | "metadata": {},
625 | "outputs": [],
626 | "source": [
627 | "from keras.preprocessing.text import Tokenizer\n",
628 | "\n",
629 | " t = Tokenizer() # all without .\n",
630 | " text = \" Tomorrow will be cold. \"\n",
631 | " text = text.replace(\".\", \" .\")\n",
632 | " t.fit_on_texts([text])\n",
633 | " print(t.word_index)"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "# extract features from each photo in the directory\n",
643 | "def extract_features(filename):\n",
644 | " # load the model\n",
645 | " model = VGG16()\n",
646 | " # re-structure the model\n",
647 | " model.layers.pop()\n",
648 | " model = Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
649 | " # load the photo\n",
650 | " image = load_img(filename, target_size=(224, 224))\n",
651 | " # convert the image pixels to a numpy array\n",
652 | " image = img_to_array(image)\n",
653 | " # reshape data for the model\n",
654 | " image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))\n",
655 | " # prepare the image for the VGG model\n",
656 | " image = preprocess_input(image)\n",
657 | " # get features\n",
658 | " feature = model.predict(image, verbose=0)\n",
659 | " return feature"
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": null,
665 | "metadata": {},
666 | "outputs": [],
667 | "source": [
668 | "from keras.layers import Embedding, Input\n",
669 | "from keras.layers import BatchNormalization, Dense, RepeatVector\n",
670 | "from keras.applications.vgg16 import VGG16\n",
671 | "\n",
672 | "# The top layer is the last layer\n",
673 | "image_model = VGG16(weights='imagenet')\n",
674 | "# Fix the weights\n",
675 | "for layer in image_model.layers:\n",
676 | " layer.trainable = False\n",
677 | "\n",
678 | "embedding_size = 300\n",
679 | "dense_input = BatchNormalization(axis=-1)(image_model.output)\n",
680 | "image_dense = Dense(units=embedding_size)(dense_input) # FC layer\n",
681 | "# Add a timestep dimension to match LSTM's input size\n",
682 | "image_embedding = RepeatVector(1)(image_dense)\n",
683 | "image_input = image_model.input\n",
684 | "\n",
685 | "vocab_size = 2536\n",
686 | "embedding_size = 300\n",
687 | "\n",
688 | "sentence_input = Input(shape=[34])\n",
689 | "word_embedding = Embedding(\n",
690 | " input_dim=vocab_size, output_dim=embedding_size)(sentence_input)\n",
691 | "\n",
692 | "from keras.layers import (BatchNormalization, Concatenate, Dense, LSTM,\n",
693 | " TimeDistributed)\n",
694 | "from keras.models import Model\n",
695 | "from keras.optimizers import Adam\n",
696 | "\n",
697 | "sequence_input = Concatenate(axis=1)([image_embedding, word_embedding])\n",
698 | "\n",
699 | "learning_rate = 0.00051\n",
700 | "lstm_output_size = 300\n",
701 | "vocab_size = 2536\n",
702 | "lstm_layers = 3\n",
703 | "dropout_rate = 0.22\n",
704 | "input_ = sequence_input\n",
705 | "\n",
706 | "for _ in range(lstm_layers):\n",
707 | " input_ = BatchNormalization(axis=-1)(input_)\n",
708 | " lstm_out = LSTM(\n",
709 | " units=lstm_output_size,\n",
710 | " return_sequences=True,\n",
711 | " dropout=dropout_rate,\n",
712 | " recurrent_dropout=dropout_rate)(input_)\n",
713 | " input_ = lstm_out\n",
714 | "sequence_output = TimeDistributed(Dense(units=vocab_size))(lstm_out)\n",
715 | "\n",
716 | "model = Model(inputs=[image_input, sentence_input], outputs=sequence_output)\n"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": null,
722 | "metadata": {},
723 | "outputs": [],
724 | "source": [
725 | "SVG(model_to_dot(model, show_shapes=True, show_layer_names=True).create(prog='dot', format='svg'))"
726 | ]
727 | },
728 | {
729 | "cell_type": "code",
730 | "execution_count": 6,
731 | "metadata": {},
732 | "outputs": [
733 | {
734 | "name": "stderr",
735 | "output_type": "stream",
736 | "text": [
737 | "usage: ipykernel_launcher.py [-h] [-i IMAGE] [-m MODEL]\n",
738 | "ipykernel_launcher.py: error: unrecognized arguments: -f /Users/Divyansh/Library/Jupyter/runtime/kernel-421658b4-8964-4bee-84bf-e58e86f851f8.json\n"
739 | ]
740 | },
741 | {
742 | "ename": "SystemExit",
743 | "evalue": "2",
744 | "output_type": "error",
745 | "traceback": [
746 | "An exception has occurred, use %tb to see the full traceback.\n",
747 | "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
748 | ]
749 | },
750 | {
751 | "name": "stderr",
752 | "output_type": "stream",
753 | "text": [
754 | "/Users/Divyansh/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py:2870: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
755 | " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
756 | ]
757 | }
758 | ],
759 | "source": [
760 | "from pickle import load\n",
761 | "import numpy as np\n",
762 | "from keras.preprocessing.sequence import pad_sequences\n",
763 | "from keras.applications.vgg16 import VGG16\n",
764 | "from keras.preprocessing.image import load_img\n",
765 | "from keras.preprocessing.image import img_to_array\n",
766 | "from keras.applications.vgg16 import preprocess_input\n",
767 | "from keras.models import Model\n",
768 | "from keras.models import load_model\n",
769 | "from nltk.translate.bleu_score import corpus_bleu\n",
770 | "\n",
771 | "from CapGenerator import load_data as ld\n",
772 | "from CapGenerator import generate_model as gen\n",
773 | "import argparse\n",
774 | "\n",
775 | "# extract features from each photo in the directory\n",
776 | "def extract_features(filename):\n",
777 | " # load the model\n",
778 | " model = VGG16()\n",
779 | " # re-structure the model\n",
780 | " model.layers.pop()\n",
781 | " model = Model(inputs=model.inputs, outputs=model.layers[-1].output)\n",
782 | " # load the photo\n",
783 | " image = load_img(filename, target_size=(224, 224))\n",
784 | " # convert the image pixels to a numpy array\n",
785 | " image = img_to_array(image)\n",
786 | " # reshape data for the model\n",
787 | " image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))\n",
788 | " # prepare the image for the VGG model\n",
789 | " image = preprocess_input(image)\n",
790 | " # get features\n",
791 | " feature = model.predict(image, verbose=0)\n",
792 | " return feature\n",
793 | "\n",
794 | "# generate a description for an image\n",
795 | "def generate_desc(model, tokenizer, photo, index_word, max_length, beam_size=5):\n",
796 | "\n",
797 | " captions = [['startseq', 0.0]]\n",
798 | " # seed the generation process\n",
799 | " in_text = 'startseq'\n",
800 | " # iterate over the whole length of the sequence\n",
801 | " for i in range(max_length):\n",
802 | " all_caps = []\n",
803 | " # expand each current candidate\n",
804 | " for cap in captions:\n",
805 | " sentence, score = cap\n",
806 | " # if final word is 'end' token, just add the current caption\n",
807 | " if sentence.split()[-1] == 'endseq':\n",
808 | " all_caps.append(cap)\n",
809 | " continue\n",
810 | " # integer encode input sequence\n",
811 | " sequence = tokenizer.texts_to_sequences([sentence])[0]\n",
812 | " # pad input\n",
813 | " sequence = pad_sequences([sequence], maxlen=max_length)\n",
814 | " # predict next words\n",
815 | " y_pred = model.predict([photo,sequence], verbose=0)[0]\n",
816 | " # convert probability to integer\n",
817 | " yhats = np.argsort(y_pred)[-beam_size:]\n",
818 | "\n",
819 | " for j in yhats:\n",
820 | " # map integer to word\n",
821 | " word = index_word.get(j)\n",
822 | " # stop if we cannot map the word\n",
823 | " if word is None:\n",
824 | " continue\n",
825 | " # Add word to caption, and generate log prob\n",
826 | " caption = [sentence + ' ' + word, score + np.log(y_pred[j])]\n",
827 | " all_caps.append(caption)\n",
828 | "\n",
829 | " # order all candidates by score\n",
830 | " ordered = sorted(all_caps, key=lambda tup:tup[1], reverse=True)\n",
831 | " captions = ordered[:beam_size]\n",
832 | "\n",
833 | " return captions\n",
834 | "\n",
835 | "# evaluate the skill of the model\n",
836 | "def evaluate_model(model, descriptions, photos, tokenizer, index_word, max_length):\n",
837 | " actual, predicted = list(), list()\n",
838 | " # step over the whole set\n",
839 | " for key, desc_list in descriptions.items():\n",
840 | " # generate description\n",
841 | " yhat = generate_desc(model, tokenizer, photos[key], index_word, max_length)[0]\n",
842 | " # store actual and predicted\n",
843 | " references = [d.split() for d in desc_list]\n",
844 | " actual.append(references)\n",
845 | " # Use best caption\n",
846 | " predicted.append(yhat[0].split())\n",
847 | " # calculate BLEU score\n",
848 | " print('BLEU-1: %f' % corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0)))\n",
849 | " print('BLEU-2: %f' % corpus_bleu(actual, predicted, weights=(0.5, 0.5, 0, 0)))\n",
850 | " print('BLEU-3: %f' % corpus_bleu(actual, predicted, weights=(0.3, 0.3, 0.3, 0)))\n",
851 | " print('BLEU-4: %f' % corpus_bleu(actual, predicted, weights=(0.25, 0.25, 0.25, 0.25)))\n",
852 | "\n",
853 | "def eval_test_set(model, descriptions, photos, tokenizer, index_word, max_length):\n",
854 | " actual, predicted = list(), list()\n",
855 | " # step over the whole set\n",
856 | " for key, desc_list in descriptions.items():\n",
857 | " # generate description\n",
858 | " yhat = generate_desc(model, tokenizer, photos[key], index_word, max_length)[0]\n",
859 | " # store actual and predicted\n",
860 | " references = [d.split() for d in desc_list]\n",
861 | " actual.append(references)\n",
862 | " # Use best caption\n",
863 | " predicted.append(yhat[0].split())\n",
864 | " predicted = sorted(predicted)\n",
865 | " actual = [x for _,x in sorted(zip(actual,predicted))]\n",
866 | "\n",
867 | "if __name__ == '__main__':\n",
868 | "\n",
869 | " parser = argparse.ArgumentParser(description='Generate image captions')\n",
870 | " parser.add_argument(\"-i\", \"--image\", help=\"Input image path\")\n",
871 | " parser.add_argument(\"-m\", \"--model\", help=\"model checkpoint\")\n",
872 | " args = parser.parse_args()\n",
873 | "\n",
874 | "\n",
875 | " # load the tokenizer\n",
876 | " tokenizer = load(open('models/tokenizer.pkl', 'rb'))\n",
877 | " index_word = load(open('models/index_word.pkl', 'rb'))\n",
878 | " # pre-define the max sequence length (from training)\n",
879 | " max_length = 34\n",
880 | "\n",
881 | " # load the model\n",
882 | " if args.model:\n",
883 | " filename = args.model\n",
884 | " else:\n",
885 | " filename = 'models/model-ep005-loss3.504-val_loss3.893.h5'\n",
886 | " model = load_model(filename)\n",
887 | "\n",
888 | " if args.image:\n",
889 | " # load and prepare the photograph\n",
890 | " photo = extract_features(args.image)\n",
891 | " # generate description\n",
892 | " captions = generate_desc(model, tokenizer, photo, index_word, max_length)\n",
893 | " for cap in captions:\n",
894 | " # remove start and end tokens\n",
895 | " seq = cap[0].split()[1:-1]\n",
896 | " desc = ' '.join(seq)\n",
897 | " print('{} [log prob: {:1.2f}]'.format(desc,cap[1]))\n",
898 | " else:\n",
899 | " # load test set\n",
900 | " test_features, test_descriptions = ld.prepare_dataset('test')[1]\n",
901 | "\n",
902 | " # evaluate model\n",
903 | " evaluate_model(model, test_descriptions, test_features, tokenizer, index_word, max_length)\n"
904 | ]
905 | },
906 | {
907 | "cell_type": "code",
908 | "execution_count": null,
909 | "metadata": {},
910 | "outputs": [],
911 | "source": []
912 | }
913 | ],
914 | "metadata": {
915 | "kernelspec": {
916 | "display_name": "Anaconda",
917 | "language": "python",
918 | "name": "anaconda"
919 | }
920 | },
921 | "nbformat": 4,
922 | "nbformat_minor": 2
923 | }
924 |
--------------------------------------------------------------------------------
/Flickr8k_Dataset/README.md:
--------------------------------------------------------------------------------
1 | All the images are to be stored in this directory.
2 |
3 |
--------------------------------------------------------------------------------
/Flickr8k_text/README.md:
--------------------------------------------------------------------------------
1 | The annotations are placed in this directory.
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Divyansh Garg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Image Captioning (Keras)
3 |
4 | Image Captioning System that generates natural language captions for any image.
5 |
6 | The architecture of the model is inspired by "Show and Tell" [1] by Vinyals et al. The model is built using the [Keras](https://keras.io/) library.
7 |
8 | The project also contains code for an Attention LSTM layer, although it is not integrated into the model.
9 |
10 | ## Dataset
11 | The model is trained on the [Flickr8k Dataset](https://illinois.edu/fb/sec/1713398).
12 |
13 | It can also be trained on larger datasets such as Flickr30k or MS COCO.
14 |
15 | ## Model
16 |
17 | 
18 |
19 |
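A minimal sketch of this kind of encoder-decoder captioning model in Keras is shown below (an illustrative "merge" variant; the layer sizes and variable names are assumptions, not necessarily the exact layers built by the scripts in `CapGenerator/`):

```python
from keras.layers import Input, Dense, Dropout, Embedding, LSTM, add
from keras.models import Model

vocab_size = 7579   # example value; comes from the fitted tokenizer in practice
max_length = 34     # example value; length of the longest training caption

# image branch: a pre-extracted 4096-d VGG16 fc2 feature vector per photo
img_in = Input(shape=(4096,))
img_feat = Dense(256, activation='relu')(Dropout(0.5)(img_in))

# text branch: the partial caption generated so far
txt_in = Input(shape=(max_length,))
txt_emb = Embedding(vocab_size, 256, mask_zero=True)(txt_in)
txt_feat = LSTM(256)(Dropout(0.5)(txt_emb))

# merge both branches and predict the next word of the caption
decoder = Dense(256, activation='relu')(add([img_feat, txt_feat]))
next_word = Dense(vocab_size, activation='softmax')(decoder)

model = Model(inputs=[img_in, txt_in], outputs=next_word)
model.compile(loss='categorical_crossentropy', optimizer='adam')
```

At inference time the model is called repeatedly while feeding back the partial caption, which is what the beam search in `eval_model.py` does.
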
20 | ## Performance
21 | The model has been trained for 20 epochs on 6000 training samples of the Flickr8k Dataset. It achieves a `BLEU-1 = ~0.59` on 1000 test samples.
22 |
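The BLEU scores are computed with NLTK's `corpus_bleu`, roughly as sketched below (the toy reference and hypothesis tokens are made up purely for illustration; the evaluation script runs this over all test captions):

```python
from nltk.translate.bleu_score import corpus_bleu

# one list of tokenised reference captions per image, one tokenised hypothesis per image
references = [[['a', 'dog', 'runs', 'through', 'the', 'water'],
               ['the', 'dog', 'is', 'running', 'in', 'water']]]
hypotheses = [['a', 'dog', 'is', 'running', 'through', 'the', 'water']]

print('BLEU-1: %f' % corpus_bleu(references, hypotheses, weights=(1.0, 0, 0, 0)))
print('BLEU-2: %f' % corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0)))
print('BLEU-4: %f' % corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25)))
```
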
23 | ----------------------------------
24 |
25 | ## Requirements
26 | - tensorflow
27 | - keras
28 | - numpy
29 | - h5py
30 | - progressbar2
31 |
32 | These requirements can be installed by running:
33 | `pip install -r requirements.txt`
34 |
35 |
36 | ## Scripts
37 |
38 | - __generate_model.py__: The base script that contains functions for model creation, the batch data generator, etc.
39 | - __prepare_data.py__: Extracts features from images using the VGG16 ImageNet model and prepares the annotations for training. This script must be modified if a new dataset is to be used.
40 | - __train_model.py__: Module for training the caption generator.
41 | - __eval_model.py__: Module for evaluating and testing the performance of the caption generator; it currently uses the [BLEU](https://en.wikipedia.org/wiki/BLEU) metric.
42 |
43 | ## Usage
44 |
45 | ### Pre-trained model
46 | 1. Download pre-trained weights from [releases](https://github.com/Div99/Image-Captioning/releases)
47 | 2. Move `model_weight.h5` to `models` directory
48 | 3. Prepare data using `python prepare_data.py`
49 | 4. For inference on an example image, run: `python eval_model.py -i [img-path]`
50 |
51 | ### From scratch
52 | After the requirements have been installed, the process from training to testing is straightforward. Run the following commands:
53 | 1. `python prepare_data.py`
54 | 2. `python train_model.py`
55 | 3. `python eval_model.py`
56 |
57 | After training, evaluation on an example image can be done by running:
58 | `python eval_model.py -m [model-checkpoint] -i [img-path]`
59 |
60 | ## Results
61 |
62 | Image | Caption
63 | --- | ---
64 |  | **Generated Caption:** A white and black dog is running through the water
65 |  | **Generated Caption:** man is skiing on snowy hill
66 |  | **Generated Caption:** man in red shirt is walking down the street
67 |
68 | ----------------------------------
69 |
70 | ## References
71 | [1] Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan. [Show and Tell: A Neural Image Caption Generator](https://arxiv.org/pdf/1411.4555.pdf)
72 |
73 | [2] Kelvin Xu, Jimmy Ba, Ryan Kiros, Kyunghyun Cho, Aaron Courville, Ruslan Salakhutdinov, Richard Zemel, Yoshua Bengio. [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/pdf/1502.03044.pdf)
74 |
75 | ----------------------------------
76 |
77 | ## License
78 | MIT License. See LICENSE file for details.
79 |
--------------------------------------------------------------------------------
/imgs/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/imgs/dog.jpg
--------------------------------------------------------------------------------
/imgs/people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/imgs/people.jpg
--------------------------------------------------------------------------------
/imgs/ski.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/imgs/ski.jpg
--------------------------------------------------------------------------------
/imgs/worker.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/imgs/worker.jpg
--------------------------------------------------------------------------------
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Div99/Image-Captioning/d7465d54e02ace69abc395e250176a1936f705b5/model.png
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | Save trained models and processed datasets here.
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py
2 | Keras
3 | numpy
4 | pydot
5 | Pillow
6 | nltk
7 | tensorflow
8 | progressbar2
9 |
--------------------------------------------------------------------------------
/train_attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "Using TensorFlow backend.\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "from os import listdir\n",
18 | "from pickle import dump\n",
19 | "from keras.applications.vgg16 import VGG16\n",
20 | "from keras.preprocessing.image import load_img\n",
21 | "from keras.preprocessing.image import img_to_array\n",
22 | "from keras.applications.vgg16 import preprocess_input\n",
23 | "from keras.layers import Input, Reshape, Concatenate, Activation, Dense\n",
24 | "import keras\n",
25 | "import numpy as np\n",
26 | "import string\n",
27 | "# from progressbar import progressbar\n",
28 | "from keras.models import Model"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 8,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "def extract_features():\n",
38 | " # load model\n",
39 | " model = VGG16()\n",
40 | " model.layers.pop()\n",
41 | " # extract final 49x512 conv layer for context vectors\n",
42 | " final_conv = Reshape([49,512])(model.layers[-4].output)\n",
43 | " model = Model(inputs=model.inputs, outputs=final_conv)\n",
44 | " print(model.summary())\n",
45 | " features = dict()"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 9,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "dim_ctx = -1\n",
55 | "n_ctx = -1\n",
56 | "lstm_cell_dim = -1\n",
57 | "lstm_hidden_dim = -1\n",
58 | "\n",
59 | "# Initializes the cell & hidden states of the LSTM by passing the mean context vector through MLPs\n",
60 | "# [contexts] : n_ctx x dim_ctx matrix\n",
61 | "def init_lstm_states(contexts):\n",
62 | " mean_context = tf.reduce_mean(contexts,)\n",
63 | " \n",
64 | " # num hidden units in MLP\n",
65 | " mlp_hidden = 256\n",
66 | " \n",
67 | " # cell_state MLP\n",
68 | " inputs_c = Input(shape=(dim_ctx,))\n",
69 | " f_c = Dense(mlp_hidden,activation=\"relu\")(inputs_c)\n",
70 | " f_c = Dense(lstm_cell_dim,activation=None)(f_c)\n",
71 | " \n",
72 | " # hidden_state MLP\n",
73 | " inputs_h = Input(shape=(dim_ctx,))\n",
74 | " f_h = Dense(mlp_hidden,activation=\"relu\")(inputs_h)\n",
75 | " f_h = Dense(lstm_hidden_dim,activation=None)(f_h)\n",
76 | " \n",
77 | " return f_c,f_h\n",
78 | "\n",
79 | "# Computes alpha values (attention weights) by passing context vectors + hidden state through MLP\n",
80 | "# Includes hidden state by concatenating to end of alpha values\n",
81 | "def generate_alphas(contexts,hidden_state):\n",
82 | " mlp_hidden = 256\n",
83 | " \n",
84 | " # tile and concatenate inputs\n",
85 | " batch_size = contexts.shape[0]\n",
86 | " tiled_hidden_state = tf.tile([[hidden_state]],[batch_size,n_ctx,1])\n",
87 | " concat_input = Concatenate(axis=-1)((contexts,tiled_hidden_state))\n",
88 | " \n",
89 | " # feed into MLP\n",
90 | " inputs = Input(shape=(dim_ctx+lstm_hidden_dim,))\n",
91 | " x = Dense(mlp_hidden,activation=\"relu\")(inputs)\n",
92 | " x = Dense(1,activation=None)(x)\n",
93 | " x = Activation(\"softmax\")(x)\n",
94 | " \n",
95 | " return x\n",
96 | "\n",
97 | "# Generates a soft-attention attention vector from alphas & context vectors\n",
98 | "def get_soft_attention_vec(contexts,alphas):\n",
99 | " return contexts*tf.reshape(alphas,[1,-1,1])"
100 | ]
101 | },
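  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helpers above sketch the soft attention of Xu et al., *Show, Attend and Tell*: a small MLP scores each of the $L$ context vectors $a_i$ (here $L$ = `n_ctx`, e.g. 49 locations of the $7 \\times 7 \\times 512$ VGG16 feature map) against the previous hidden state $h_{t-1}$, the scores are softmax-normalised into weights $\\alpha_{ti}$, and the attended context is their weighted combination:\n",
    "\n",
    "$$e_{ti} = f_{att}(a_i, h_{t-1}), \\qquad \\alpha_{ti} = \\frac{\\exp(e_{ti})}{\\sum_{k=1}^{L} \\exp(e_{tk})}, \\qquad z_t = \\sum_{i=1}^{L} \\alpha_{ti}\\, a_i$$\n",
    "\n",
    "`generate_alphas` corresponds to computing $e_{ti}$ and $\\alpha_{ti}$, while `get_soft_attention_vec` returns the $\\alpha$-weighted context vectors (summing them over $i$ gives $z_t$)."
   ]
  },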
102 | {
103 | "cell_type": "code",
104 | "execution_count": 4,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "alphas = np.random.randint(10,size=[49])\n",
109 | "contexts = np.ones([10,49,512])"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 5,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | "(10, 49, 512)"
121 | ]
122 | },
123 | "execution_count": 5,
124 | "metadata": {},
125 | "output_type": "execute_result"
126 | }
127 | ],
128 | "source": [
129 | "(contexts*np.reshape(alphas,[1,-1,1])).shape"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": []
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 6,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "dim_ctx = 10\n",
146 | "lstm_hidden_dim = 10\n",
147 | "mlp_hidden = 1\n",
148 | "inputs = Input(shape=(dim_ctx+lstm_hidden_dim,))\n",
149 | "x = Dense(mlp_hidden,activation=\"relu\")(inputs)\n",
150 | "x = Dense(1,activation=None)(x)\n",
151 | "x = Activation(\"softmax\")(x)"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 7,
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "data": {
161 | "text/plain": [
162 | ""
163 | ]
164 | },
165 | "execution_count": 7,
166 | "metadata": {},
167 | "output_type": "execute_result"
168 | }
169 | ],
170 | "source": [
171 | "x"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": []
180 | }
181 | ],
182 | "metadata": {
183 | "kernelspec": {
184 | "display_name": "Anaconda",
185 | "language": "python",
186 | "name": "anaconda"
187 | },
188 | "language_info": {
189 | "codemirror_mode": {
190 | "name": "ipython",
191 | "version": 3
192 | },
193 | "file_extension": ".py",
194 | "mimetype": "text/x-python",
195 | "name": "python",
196 | "nbconvert_exporter": "python",
197 | "pygments_lexer": "ipython3",
198 | "version": "3.6.2"
199 | }
200 | },
201 | "nbformat": 4,
202 | "nbformat_minor": 2
203 | }
204 |
--------------------------------------------------------------------------------