├── utility
│   ├── __init__.py
│   ├── utility.py
│   ├── language_encoder.py
│   ├── download.py
│   ├── cache.py
│   └── coco.py
├── model.png
├── README.md
├── LICENSE
├── .gitignore
├── model.py
├── generate_features.py
├── eval.py
└── train.py

/utility/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zimmerrol/show-attend-and-tell-keras/HEAD/model.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Show, Attend and Tell (Keras) [WIP]
Keras implementation of the paper [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](http://arxiv.org/abs/1502.03044), which introduces an attention-based image caption generator. The model shifts its attention to the relevant part of the image as it generates each word.

This figure from the original paper gives a short explanation of the network's structure.
![Architecture](model.png)

This project depends on the [Keras Utility & Layer Collection (kulc)](https://github.com/FlashTek/keras-layer-collection), which implements many useful layers and utility functions for attention-based models.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Roland Zimmermann

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/utility/utility.py:
--------------------------------------------------------------------------------
import utility.coco as coco
import h5py

def load_validation_data(maximum_caption_length):
    # maximum_caption_length is currently unused; kept for API symmetry with the callers
    coco.set_data_dir("./data/coco")
    coco.maybe_download_and_extract()

    _, _, captions_val_raw = coco.load_records(train=False)

    h5 = h5py.File("image.features.val.VGG19.block5_conv4.h5", "r")
    get_data = lambda i: h5[i]

    return captions_val_raw, get_data


def load_training_data(maximum_caption_length):
    coco.set_data_dir("./data/coco")
    coco.maybe_download_and_extract()

    _, _, captions_train_raw = coco.load_records(train=True)

    h5 = h5py.File("image.features.train.VGG19.block5_conv4.h5", "r")
    get_data = lambda i: h5[i]

    return captions_train_raw, get_data


def create_vocabulary(maximum_size, text_sets):
    # count the occurrences of every word over all text sets
    words = dict()
    for texts in text_sets:
        for text in texts:
            for word in text.lower().split():
                if word in words:
                    words[word] += 1
                else:
                    words[word] = 1

    # keep the most frequent words; the special tokens come first.
    # NOTE: the literal token strings were lost in the source dump (the
    # angle-bracket names were stripped by HTML rendering); <NULL>/<START>/
    # <STOP>/<UNK> are restored placeholders used consistently below.
    words = [item[0] for item in sorted(words.items(), key=lambda item: item[1], reverse=True)]
    words = ["<NULL>", "<START>", "<STOP>", "<UNK>"] + words
    words = words[:maximum_size]

    word_index_map = {}
    index_word_map = {}
    for i, word in enumerate(words):
        word_index_map[word] = i
        index_word_map[i] = word

    return word_index_map, index_word_map

def encode_text_sets(text_sets, word_index_map):
    encoded_text_sets = []
    for i, texts in enumerate(text_sets):
        encoded_texts = []
        for j, text in enumerate(texts):
            encoded_text = []
            for word in text.split():
                if word.lower() in word_index_map:
                    encoded_text.append(word_index_map[word.lower()])
                else:
                    # out-of-vocabulary words fall back to the <UNK> token
                    encoded_text.append(word_index_map["<UNK>"])

            encoded_texts.append(encoded_text)
        encoded_text_sets.append(encoded_texts)

    return encoded_text_sets
--------------------------------------------------------------------------------
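Usage sketch (an editor's illustration, not a file in this repository; the special-token names follow the reconstruction above and may differ from the original source):

import utility.utility as utility

text_sets = [
    ["a dog runs", "a dog sleeps"],  # captions for image 0
    ["a cat sits"],                  # captions for image 1
]
word_index_map, index_word_map = utility.create_vocabulary(100, text_sets)
encoded = utility.encode_text_sets(text_sets, word_index_map)
# encoded[0][0] is now a list of integer word indices for "a dog runs";
# out-of-vocabulary words map to the <UNK> index.
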
LanguageEncoder(object): 6 | def __init__(self, maximum_vocabulary_size=10000, lower=True, forbidden_characters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'): 7 | self._maximum_vocabulary_size = maximum_vocabulary_size 8 | self._forbidden_characters = forbidden_characters 9 | self._lower = lower 10 | 11 | translate_forbidding_characters_dict = dict((c, " ") for c in self._forbidden_characters) 12 | self._translate_forbidding_characters_map = str.maketrans(translate_forbidding_characters_dict) 13 | 14 | def _filter(self, text): 15 | if self._lower: 16 | return text.lower().translate(self._translate_forbidding_characters_map) 17 | else: 18 | return text.translate(self._translate_forbidding_characters_map) 19 | 20 | def _build_vocabulary(self, text_sets): 21 | self._word_index_map, self._index_word_map = create_vocabulary(self._maximum_vocabulary_size, text_sets) 22 | self._vocabulary_size = len(self._word_index_map) 23 | 24 | def fit(self, x): 25 | # x: list of lists containing the texts 26 | 27 | # clean texts 28 | for i, texts in enumerate(x): 29 | for j, text in enumerate(texts): 30 | x[i][j] = self._filter(text) 31 | 32 | self._build_vocabulary(x) 33 | 34 | def transform(self, x, oh_encode=False): 35 | # clean texts 36 | for i, texts in enumerate(x): 37 | for j, text in enumerate(texts): 38 | x[i][j] = self._filter(text) 39 | 40 | return encode_text_sets(x, self._word_index_map) 41 | 42 | def transform_word(self, word): 43 | if word in self._word_index_map: 44 | return self._word_index_map[word] 45 | else: 46 | raise ValueError("word not in the dictionary") 47 | 48 | def fit_transform(self, x): 49 | self.fit(x) 50 | return self.transform(x) 51 | 52 | def save(self, filename): 53 | with open(filename, "wb") as file: 54 | pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | @staticmethod 57 | def load(filename): 58 | with open(filename, "rb") as file: 59 | return pickle.load(file) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Input, Dense, LSTM, TimeDistributed, Embedding, Lambda 2 | from kulc.attention import ExternalAttentionRNNWrapper 3 | from keras.models import Model 4 | import keras.backend as K 5 | import tensorflow as tf 6 | 7 | W = 14 8 | H = 14 9 | L = W*H 10 | D = 512 11 | 12 | """ 13 | - use features [BS, H, W, D] 14 | - flatten [BS, H*W, D] 15 | - linear transformation [BS, H*W, D]: 16 | - flatten/reshape -> [BS*H*W, D] 17 | # Dense(D) -> [BS*H*W, D] 18 | - reshape -> [BS, H*W, D] 19 | 20 | """ 21 | def create_model(vocabulary_size, embedding_size, T, L, D): 22 | image_features_input = Input(shape=(L, D), name="image_features_input") 23 | captions_input = Input(shape=(T,), name="captions_input") 24 | captions = Embedding(vocabulary_size, embedding_size, input_length=T)(captions_input) 25 | 26 | averaged_image_features = Lambda(lambda x: K.mean(x, axis=1)) 27 | averaged_image_features = averaged_image_features(image_features_input) 28 | initial_state_h = Dense(embedding_size)(averaged_image_features) 29 | initial_state_c = Dense(embedding_size)(averaged_image_features) 30 | 31 | image_features = TimeDistributed(Dense(D, activation="relu"))(image_features_input) 32 | 33 | encoder = LSTM(embedding_size, return_sequences=True, return_state=True, recurrent_dropout=0.1) 34 | attented_encoder = ExternalAttentionRNNWrapper(encoder, return_attention=True) 35 | 36 | output = TimeDistributed(Dense(vocabulary_size, 
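Usage sketch (illustration only; assumes the utility package is importable):

from utility.language_encoder import LanguageEncoder

le = LanguageEncoder(maximum_vocabulary_size=10000)
captions = [["A dog runs!", "A dog sleeps."], ["A cat sits."]]
encoded = le.fit_transform(captions)        # lists of integer word indices
le.save("language.pkl")                     # pickle the fitted encoder
le2 = LanguageEncoder.load("language.pkl")  # restore it later, e.g. in eval.py
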
activation="softmax"), name="output") 37 | 38 | # for training purpose 39 | attented_encoder_training_data, _, _ , _= attented_encoder([captions, image_features], initial_state=[initial_state_h, initial_state_c]) 40 | training_output_data = output(attented_encoder_training_data) 41 | 42 | training_model = Model(inputs=[captions_input, image_features_input], outputs=training_output_data) 43 | 44 | initial_state_inference_model = Model(inputs=[image_features_input], outputs=[initial_state_h, initial_state_c]) 45 | 46 | inference_initial_state_h = Input(shape=(embedding_size,)) 47 | inference_initial_state_c = Input(shape=(embedding_size,)) 48 | attented_encoder_inference_data, inference_encoder_state_h, inference_encoder_state_c, inference_attention = attented_encoder( 49 | [captions, image_features], 50 | initial_state=[inference_initial_state_h, inference_initial_state_c] 51 | ) 52 | 53 | inference_output_data = output(attented_encoder_inference_data) 54 | 55 | inference_model = Model( 56 | inputs=[image_features_input, captions_input, inference_initial_state_h, inference_initial_state_c], 57 | outputs=[inference_output_data, inference_encoder_state_h, inference_encoder_state_c, inference_attention] 58 | ) 59 | 60 | return training_model, inference_model, initial_state_inference_model -------------------------------------------------------------------------------- /utility/download.py: -------------------------------------------------------------------------------- 1 | ######################################################################## 2 | # 3 | # Functions for downloading and extracting data-files from the internet. 4 | # 5 | # Implemented in Python 3.5 6 | # 7 | ######################################################################## 8 | # 9 | # This file is part of the TensorFlow Tutorials available at: 10 | # 11 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials 12 | # 13 | # Published under the MIT License. See the file LICENSE for details. 14 | # 15 | # Copyright 2016 by Magnus Erik Hvass Pedersen 16 | # 17 | ######################################################################## 18 | 19 | import sys 20 | import os 21 | import urllib.request 22 | import tarfile 23 | import zipfile 24 | 25 | ######################################################################## 26 | 27 | 28 | def _print_download_progress(count, block_size, total_size): 29 | """ 30 | Function used for printing the download progress. 31 | Used as a call-back function in maybe_download_and_extract(). 32 | """ 33 | 34 | # Percentage completion. 35 | pct_complete = float(count * block_size) / total_size 36 | 37 | # Status-message. Note the \r which means the line should overwrite itself. 38 | msg = "\r- Download progress: {0:.1%}".format(pct_complete) 39 | 40 | # Print it. 41 | sys.stdout.write(msg) 42 | sys.stdout.flush() 43 | 44 | 45 | ######################################################################## 46 | 47 | 48 | def maybe_download_and_extract(url, download_dir): 49 | """ 50 | Download and extract the data if it doesn't already exist. 51 | Assumes the url is a tar-ball file. 52 | 53 | :param url: 54 | Internet URL for the tar-file to download. 55 | Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 56 | 57 | :param download_dir: 58 | Directory where the downloaded file is saved. 59 | Example: "data/CIFAR-10/" 60 | 61 | :return: 62 | Nothing. 63 | """ 64 | 65 | # Filename for saving the file downloaded from the internet. 
/utility/download.py:
--------------------------------------------------------------------------------
########################################################################
#
# Functions for downloading and extracting data-files from the internet.
#
# Implemented in Python 3.5
#
########################################################################
#
# This file is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
#
########################################################################

import sys
import os
import urllib.request
import tarfile
import zipfile

########################################################################


def _print_download_progress(count, block_size, total_size):
    """
    Function used for printing the download progress.
    Used as a call-back function in maybe_download_and_extract().
    """

    # Percentage completion.
    pct_complete = float(count * block_size) / total_size

    # Status-message. Note the \r which means the line should overwrite itself.
    msg = "\r- Download progress: {0:.1%}".format(pct_complete)

    # Print it.
    sys.stdout.write(msg)
    sys.stdout.flush()


########################################################################


def maybe_download_and_extract(url, download_dir):
    """
    Download and extract the data if it doesn't already exist.
    Assumes the url is a tar-ball file.

    :param url:
        Internet URL for the tar-file to download.
        Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

    :param download_dir:
        Directory where the downloaded file is saved.
        Example: "data/CIFAR-10/"

    :return:
        Nothing.
    """

    # Filename for saving the file downloaded from the internet.
    # Use the filename from the URL and add it to the download_dir.
    filename = url.split('/')[-1]
    file_path = os.path.join(download_dir, filename)

    # Check if the file already exists.
    # If it exists then we assume it has also been extracted,
    # otherwise we need to download and extract it now.
    if not os.path.exists(file_path):
        # Check if the download directory exists, otherwise create it.
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)

        # Download the file from the internet.
        file_path, _ = urllib.request.urlretrieve(url=url,
                                                  filename=file_path,
                                                  reporthook=_print_download_progress)

        print()
        print("Download finished. Extracting files.")

        if file_path.endswith(".zip"):
            # Unpack the zip-file.
            zipfile.ZipFile(file=file_path, mode="r").extractall(download_dir)
        elif file_path.endswith((".tar.gz", ".tgz")):
            # Unpack the tar-ball.
            tarfile.open(name=file_path, mode="r:gz").extractall(download_dir)

        print("Done.")
    else:
        print("Data has apparently already been downloaded and unpacked.")


########################################################################
--------------------------------------------------------------------------------
@click.option("--encoder", "-e", default="VGG19", required=False, type=click.STRING) 57 | @click.option("--layer-name", "-l", default="block5_conv4", required=False, type=click.STRING) 58 | @click.option("--output-folder", "-o", default=".", required=False, type=click.Path(exists=True, file_okay=False, dir_okay=True)) 59 | @click.option("--batch-size", "-b", default=64, required=False, type=click.INT) 60 | def cmd(data_path, encoder, layer_name, output_folder, batch_size): 61 | # create data directory if it does not exist 62 | os.makedirs(data_path, exist_ok=True) 63 | 64 | # download the files now if required 65 | coco.set_data_dir(data_path) 66 | coco.maybe_download_and_extract() 67 | 68 | # load the data now 69 | _, filenames_train, captions_train_raw = coco.load_records(train=True) 70 | _, filenames_val, captions_val_raw = coco.load_records(train=False) 71 | 72 | # encoded the data and save it 73 | model = setup_model(encoder.strip().lower(), layer_name) 74 | 75 | with h5py.File(os.path.join(output_folder, "image.features.train.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5: 76 | index = 0 77 | for batch in encode_features(model, filenames_train, os.path.join(data_path, "train2017"), batch_size=batch_size): 78 | for item in batch: 79 | h5.create_dataset(str(index), data=item, compression="lzf") 80 | index += 1 81 | 82 | with h5py.File(os.path.join(output_folder, "image.features.val.{0}.{1}.h5".format(encoder, layer_name)), "w") as h5: 83 | index = 0 84 | for batch in encode_features(model, filenames_val, os.path.join(data_path, "val2017"), batch_size=batch_size): 85 | for item in batch: 86 | h5.create_dataset(str(index), data=item, compression="lzf") 87 | index += 1 88 | 89 | # pylint: disable=no-value-for-parameter 90 | if __name__ == "__main__": 91 | cmd() -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from model import create_model 2 | from utility.utility import load_training_data, load_validation_data 3 | from utility.language_encoder import LanguageEncoder 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | from keras.callbacks import TensorBoard 7 | import keras.backend as K 8 | 9 | from scipy.misc import imresize 10 | import skimage.transform 11 | import matplotlib.pyplot as plt 12 | from keras.models import load_model 13 | from kulc.attention import ExternalAttentionRNNWrapper 14 | import pathlib 15 | from keras.optimizers import adam 16 | 17 | MAXIMUM_CAPTION_LENGTH = 16 18 | 19 | def generator(batch_size, captions, get_image): 20 | while True: 21 | batch_indices = np.random.randint(0, len(captions), size=batch_size, dtype=np.int) 22 | batch_image_features = np.empty((len(batch_indices), 14*14, 512)) 23 | for i, j in enumerate(batch_indices): 24 | batch_image_features[i] = get_image(str(j)).value.reshape((14*14, 512)) 25 | 26 | batch_captions = [captions[item] for item in batch_indices] 27 | 28 | batch_captions = [x[np.random.randint(0, len(x))][:MAXIMUM_CAPTION_LENGTH-1] for x in batch_captions] 29 | input_captions = [[le.transform_word("")] + x for x in batch_captions] 30 | output_captions = [x + [le.transform_word("")] for x in batch_captions] 31 | 32 | output_captions = one_hot_encode(output_captions, MAXIMUM_CAPTION_LENGTH, MAXIMUM_CAPTION_LENGTH) 33 | input_captions = np.array([x+[le.transform_word("")]*(MAXIMUM_CAPTION_LENGTH-len(x)) for x in input_captions]).astype(np.float32) 34 | 35 | batch_image_features = 
/eval.py:
--------------------------------------------------------------------------------
from utility.utility import load_validation_data
from utility.language_encoder import LanguageEncoder
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform
from keras.models import load_model
from kulc.attention import ExternalAttentionRNNWrapper
import pathlib

MAXIMUM_CAPTION_LENGTH = 16

def generator(batch_size, captions, get_image):
    while True:
        batch_indices = np.random.randint(0, len(captions), size=batch_size)
        batch_image_features = np.empty((len(batch_indices), 14*14, 512))
        for i, j in enumerate(batch_indices):
            # read the pre-computed feature map and flatten it to (L, D)
            batch_image_features[i] = get_image(str(j))[()].reshape((14*14, 512))

        batch_captions = [captions[item] for item in batch_indices]

        # pick one random caption per image and truncate it
        batch_captions = [x[np.random.randint(0, len(x))][:MAXIMUM_CAPTION_LENGTH-1] for x in batch_captions]
        input_captions = [[le.transform_word("<START>")] + x for x in batch_captions]
        output_captions = [x + [le.transform_word("<STOP>")] for x in batch_captions]

        # one-hot encode the targets over the vocabulary
        output_captions = one_hot_encode(output_captions, MAXIMUM_CAPTION_LENGTH, le._vocabulary_size)
        input_captions = np.array([x + [le.transform_word("<NULL>")]*(MAXIMUM_CAPTION_LENGTH-len(x)) for x in input_captions]).astype(np.float32)

        batch_image_features = np.array(batch_image_features, dtype=np.float32)

        x_data = {
            "image_features_input": batch_image_features,
            "captions_input": input_captions
        }

        y_data = {
            "output": output_captions
        }

        yield (x_data, y_data)

def one_hot_encode(data, maximum_caption_length, n_classes):
    result = np.zeros((len(data), maximum_caption_length, n_classes))
    for i, item in enumerate(data):
        for j, word in enumerate(item):
            result[i, j, word] = 1.0
        # pad the remaining timesteps with the <NULL> token
        for k in range(j+1, maximum_caption_length):
            result[i, k, le.transform_word("<NULL>")] = 1.0

    return result

def inference(image_features, plot_attention):
    image_features = np.array([image_features])
    state_h, state_c = initial_state_inference_model.predict(image_features)

    caption = [le.transform_word("<START>")]
    attentions = []

    current_word = None
    for t in range(MAXIMUM_CAPTION_LENGTH):
        caption_array = np.array(caption).reshape(1, -1)
        output, state_h, state_c, attention = inference_model.predict([image_features, caption_array, state_h, state_c])
        attentions.append(attention[0, -1].reshape((14, 14)))

        current_word = np.argmax(output[0, -1])
        caption.append(current_word)

        if current_word == le.transform_word("<STOP>"):
            break

    sentence = [le._index_word_map[i] for i in caption[1:]]

    if plot_attention:
        print(len(attentions))
        x = int(np.sqrt(len(attentions)))
        y = int(np.ceil(len(attentions) / x))
        _, axes = plt.subplots(y, x, sharex="col", sharey="row")
        axes = axes.flatten()
        for i in range(len(attentions)):
            # upsample the 14x14 attention map to the image resolution
            atn = skimage.transform.pyramid_expand(attentions[i], upscale=16, sigma=20)
            axes[i].set_title(sentence[i])
            axes[i].imshow(atn, cmap="gray")

        plt.show()

    return " ".join(sentence) + " ({0})".format(len(caption)-1)

pathlib.Path('./models').mkdir(exist_ok=True)

maximum_caption_length = 16

le = LanguageEncoder.load("./models/language.pkl")
captions_val_raw, get_image_features_val = load_validation_data(maximum_caption_length)
captions_val = le.transform(captions_val_raw)

model_id = input("Model ID: ")
model_id = int(model_id)
inference_model = load_model(f"./models/sat_inf_{model_id}.h5", custom_objects={"ExternalAttentionRNNWrapper": ExternalAttentionRNNWrapper})
initial_state_inference_model = load_model(f"./models/sat_inf_init_{model_id}.h5", custom_objects={"ExternalAttentionRNNWrapper": ExternalAttentionRNNWrapper})

while True:
    max_idx = len(captions_val) - 1
    image_idx = input(f"Enter the image index (0-{max_idx}): ")
    image_idx = int(image_idx)

    print("output:")
    print("\t {0}".format(inference(get_image_features_val(str(image_idx))[()].reshape(14*14, 512), plot_attention=False)))
    print("target: ")
    for i in range(len(captions_val_raw[image_idx])):
        print("\t{0}".format(captions_val_raw[image_idx][i]))
    input("done. ")
--------------------------------------------------------------------------------
/utility/cache.py:
--------------------------------------------------------------------------------
########################################################################
#
# Cache-wrapper for a function or class.
#
# Save the result of calling a function or creating an object-instance
# to harddisk. This is used to persist the data so it can be reloaded
# very quickly and easily.
#
# Implemented in Python 3.5
#
########################################################################
#
# This file is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
#
########################################################################

import os
import pickle
import numpy as np

########################################################################


def cache(cache_path, fn, *args, **kwargs):
    """
    Cache-wrapper for a function or class. If the cache-file exists
    then the data is reloaded and returned, otherwise the function
    is called and the result is saved to cache. The fn-argument can
    also be a class instead, in which case an object-instance is
    created and saved to the cache-file.

    :param cache_path:
        File-path for the cache-file.

    :param fn:
        Function or class to be called.

    :param args:
        Arguments to the function or class-init.

    :param kwargs:
        Keyword arguments to the function or class-init.

    :return:
        The result of calling the function or creating the object-instance.
    """

    # If the cache-file exists.
    if os.path.exists(cache_path):
        # Load the cached data from the file.
        with open(cache_path, mode='rb') as file:
            obj = pickle.load(file)

        print("- Data loaded from cache-file: " + cache_path)
    else:
        # The cache-file does not exist.

        # Call the function / class-init with the supplied arguments.
        obj = fn(*args, **kwargs)

        # Save the data to a cache-file.
        with open(cache_path, mode='wb') as file:
            pickle.dump(obj, file)

        print("- Data saved to cache-file: " + cache_path)

    return obj


########################################################################


def convert_numpy2pickle(in_path, out_path):
    """
    Convert a numpy-file to pickle-file.

    The first version of the cache-function used numpy for saving the data.
    Instead of re-calculating all the data, you can just convert the
    cache-file using this function.

    :param in_path:
        Input file in numpy-format written using numpy.save().

    :param out_path:
        Output file written as a pickle-file.

    :return:
        Nothing.
    """

    # Load the data using numpy.
    data = np.load(in_path)

    # Save the data using pickle.
    with open(out_path, mode='wb') as file:
        pickle.dump(data, file)


########################################################################

if __name__ == '__main__':
    # This is a short example of using a cache-file.

    # This is the function that will only get called if the result
    # is not already saved in the cache-file. This would normally
    # be a function that takes a long time to compute, or if you
    # need persistent data for some other reason.
    def expensive_function(a, b):
        return a * b

    print('Computing expensive_function() ...')

    # Either load the result from a cache-file if it already exists,
    # otherwise calculate expensive_function(a=123, b=456) and
    # save the result to the cache-file for next time.
    result = cache(cache_path='cache_expensive_function.pkl',
                   fn=expensive_function, a=123, b=456)

    print('result =', result)

    # Newline.
    print()

    # This is another example which saves an object to a cache-file.

    # We want to cache an object-instance of this class.
    # The motivation is to do an expensive computation only once,
    # or if we need to persist the data for some other reason.
    class ExpensiveClass:
        def __init__(self, c, d):
            self.c = c
            self.d = d
            self.result = c * d

        def print_result(self):
            print('c =', self.c)
            print('d =', self.d)
            print('result = c * d =', self.result)

    print('Creating object from ExpensiveClass() ...')

    # Either load the object from a cache-file if it already exists,
    # otherwise make an object-instance ExpensiveClass(c=123, d=456)
    # and save the object to the cache-file for the next time.
    obj = cache(cache_path='cache_ExpensiveClass.pkl',
                fn=ExpensiveClass, c=123, d=456)

    obj.print_result()

########################################################################
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from model import create_model
from utility.utility import load_training_data, load_validation_data
from utility.language_encoder import LanguageEncoder
import numpy as np
import matplotlib.pyplot as plt
from keras.callbacks import TensorBoard
import keras.backend as K
import skimage.transform
from keras.models import load_model
from kulc.attention import ExternalAttentionRNNWrapper
import pathlib
from keras.optimizers import Adam

LOAD_MODEL = False

def generator(batch_size, captions, get_image):
    while True:
        batch_indices = np.random.randint(0, len(captions), size=batch_size)
        batch_image_features = np.empty((len(batch_indices), 14*14, 512))
        for i, j in enumerate(batch_indices):
            # read the pre-computed feature map and flatten it to (L, D)
            batch_image_features[i] = get_image(str(j))[()].reshape((14*14, 512))

        batch_captions = [captions[item] for item in batch_indices]

        # pick one random caption per image and truncate it
        batch_captions = [x[np.random.randint(0, len(x))][:MAXIMUM_CAPTION_LENGTH-1] for x in batch_captions]
        input_captions = [[le.transform_word("<START>")] + x for x in batch_captions]
        output_captions = [x + [le.transform_word("<STOP>")] for x in batch_captions]

        input_captions = np.array([x + [le.transform_word("<NULL>")]*(MAXIMUM_CAPTION_LENGTH-len(x)) for x in input_captions]).astype(np.float32)
        # one-hot encode the targets over the actual vocabulary size used by the model
        output_captions = one_hot_encode(output_captions, MAXIMUM_CAPTION_LENGTH, le._vocabulary_size)

        batch_image_features = np.array(batch_image_features, dtype=np.float32)

        x_data = {
            "image_features_input": batch_image_features,
            "captions_input": input_captions
        }

        y_data = {
            "output": output_captions
        }

        yield (x_data, y_data)

def one_hot_encode(data, maximum_caption_length, n_classes):
    result = np.zeros((len(data), maximum_caption_length, n_classes))
    for i, item in enumerate(data):
        for j, word in enumerate(item):
            result[i, j, word] = 1.0
        # pad the remaining timesteps with the <NULL> token
        for k in range(j+1, maximum_caption_length):
            result[i, k, le.transform_word("<NULL>")] = 1.0

    return result

def inference(image_features, plot_attention):
    image_features = np.array([image_features])
    state_h, state_c = initial_state_inference_model.predict(image_features)

    caption = [le.transform_word("<START>")]
    attentions = []

    current_word = None
    for t in range(MAXIMUM_CAPTION_LENGTH):
        caption_array = np.array(caption).reshape(1, -1)
        output, state_h, state_c, attention = inference_model.predict([image_features, caption_array, state_h, state_c])
        attentions.append(attention[0, -1].reshape((14, 14)))

        current_word = np.argmax(output[0, -1])
        caption.append(current_word)

        if current_word == le.transform_word("<STOP>"):
            break

    sentence = [le._index_word_map[i] for i in caption[1:]]

    if plot_attention:
        print(len(attentions))
        x = int(np.sqrt(len(attentions)))
        y = int(np.ceil(len(attentions) / x))
        _, axes = plt.subplots(y, x, sharex="col", sharey="row")
        axes = axes.flatten()
        for i in range(len(attentions)):
            # upsample the 14x14 attention map to the image resolution
            atn = skimage.transform.pyramid_expand(attentions[i], upscale=16, sigma=20)
            axes[i].set_title(sentence[i])
            axes[i].imshow(atn, cmap="gray")

        plt.show()

    return " ".join(sentence) + " ({0})".format(len(caption)-1)

pathlib.Path('./models').mkdir(exist_ok=True)

MAXIMUM_VOCABULARY_SIZE = 10000
EMBEDDING_SIZE = 512  # 1024
MAXIMUM_CAPTION_LENGTH = 16

"""
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
translate_dict = dict((c, " ") for c in filters)
translate_map = str.maketrans(translate_dict)

captions_train_raw, get_image_features_train = load_training_data(MAXIMUM_CAPTION_LENGTH)
captions_val_raw, get_image_features_val = load_validation_data(MAXIMUM_CAPTION_LENGTH)
for i, entry in enumerate(captions_val_raw):
    for j, item in enumerate(entry):
        captions_val_raw[i][j] = item.translate(translate_map).lower()
for i, entry in enumerate(captions_train_raw):
    for j, item in enumerate(entry):
        captions_train_raw[i][j] = item.translate(translate_map).lower()

word_index_map, index_word_map = create_vocabulary(MAXIMUM_VOCABULARY_SIZE, captions_train_raw)
MAXIMUM_VOCABULARY_SIZE = len(word_index_map)

captions_train = encode_annotations(captions_train_raw, word_index_map, MAXIMUM_CAPTION_LENGTH)
captions_val = encode_annotations(captions_val_raw, word_index_map, MAXIMUM_CAPTION_LENGTH)
"""


captions_train_raw, get_image_features_train = load_training_data(MAXIMUM_CAPTION_LENGTH)
captions_val_raw, get_image_features_val = load_validation_data(MAXIMUM_CAPTION_LENGTH)
le = LanguageEncoder(MAXIMUM_VOCABULARY_SIZE)
captions_train = le.fit_transform(captions_train_raw)
captions_val = le.transform(captions_val_raw)
le.save("./models/language.pkl")

def masked_categorical_crossentropy(y_true, y_pred):
    # mask out all timesteps whose target is the <NULL> padding token
    mask_value = le._word_index_map["<NULL>"]
    y_true_id = K.argmax(y_true)
    mask = K.cast(K.equal(y_true_id, mask_value), K.floatx())
    mask = 1.0 - mask
    loss = K.categorical_crossentropy(y_true, y_pred) * mask

    # take the average w.r.t. the number of unmasked entries
    return K.sum(loss) / K.sum(mask)

training_model, inference_model, initial_state_inference_model = create_model(le._vocabulary_size, EMBEDDING_SIZE, None, 14*14, 512)
training_model.compile(Adam(0.001), loss=masked_categorical_crossentropy, metrics=["accuracy"])

batch_size = 64

def train(epochs=100):
    tb_callback = TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=batch_size, write_graph=True, write_grads=False, write_images=False)
    history = training_model.fit_generator(generator(batch_size=batch_size, captions=captions_train, get_image=get_image_features_train), steps_per_epoch=len(captions_train)//batch_size, epochs=epochs, verbose=1, callbacks=[tb_callback])

    training_model.save("./models/sat_train_{0}.h5".format(epochs))
    inference_model.save("./models/sat_inf_{0}.h5".format(epochs))
    initial_state_inference_model.save("./models/sat_inf_init_{0}.h5".format(epochs))

    for key in history.history.keys():
        plt.figure()
        plt.plot(history.history[key])
        plt.show()

epochs = input("Number of epochs: ")
epochs = int(epochs)
train(epochs=epochs)
input("done. ")
--------------------------------------------------------------------------------
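A toy numpy re-implementation of masked_categorical_crossentropy (illustration only; index 0 stands for the <NULL> padding token, as in the reconstructed vocabulary above):

import numpy as np

y_true = np.zeros((1, 3, 4))       # batch of 1, 3 timesteps, 4 classes
y_pred = np.full((1, 3, 4), 0.25)  # uniform predictions
y_true[0, 0, 3] = 1.0              # a real word
y_true[0, 1, 2] = 1.0              # <STOP>
y_true[0, 2, 0] = 1.0              # <NULL> padding -> masked out
mask = 1.0 - (np.argmax(y_true, axis=-1) == 0)
loss = -(y_true * np.log(y_pred)).sum(axis=-1) * mask
print(loss.sum() / mask.sum())     # averages over the 2 unmasked timesteps
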
/utility/coco.py:
--------------------------------------------------------------------------------
########################################################################
#
# Functions for downloading the COCO data-set from the internet
# and loading it into memory. This data-set contains images and
# various associated data such as text-captions describing the images.
#
# http://cocodataset.org
#
# Implemented in Python 3.6
#
# Usage:
# 1) Call set_data_dir() to set the desired storage directory.
# 2) Call maybe_download_and_extract() to download the data-set
#    if it is not already located in the given data_dir.
# 3) Call load_records(train=True) and load_records(train=False)
#    to load the data-records for the training- and validation sets.
# 4) Use the returned data in your own program.
#
# Format:
# The COCO data-set contains a large number of images and various
# data for each image stored in a JSON-file.
# Functionality is provided for getting a list of image-filenames
# (but not actually loading the images) along with their associated
# data such as text-captions describing the contents of the images.
#
########################################################################
#
# This file is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2018 by Magnus Erik Hvass Pedersen
#
########################################################################

import json
import os
import utility.download as download
from utility.cache import cache

########################################################################

# Directory where you want to download and save the data-set.
# Set this before you start calling any of the functions below.
# Use the function set_data_dir() to also update train_dir and val_dir.
data_dir = "data/coco/"

# Sub-directories for the training- and validation-sets.
train_dir = "data/coco/train2017"
val_dir = "data/coco/val2017"

# Base-URL for the data-sets on the internet.
data_url = "http://images.cocodataset.org/"


########################################################################
# Private helper-functions.

def _load_records(train=True):
    """
    Load the image-filenames and captions
    for either the training-set or the validation-set.
    """

    if train:
        # Training-set.
        filename = "captions_train2017.json"
    else:
        # Validation-set.
        filename = "captions_val2017.json"

    # Full path for the data-file.
    path = os.path.join(data_dir, "annotations", filename)

    # Load the file.
    with open(path, "r", encoding="utf-8") as file:
        data_raw = json.load(file)

    # Convenience variables.
    images = data_raw['images']
    annotations = data_raw['annotations']

    # Initialize the dict for holding our data.
    # The lookup-key is the image-id.
    records = dict()

    # Collect all the filenames for the images.
    for image in images:
        # Get the id and filename for this image.
        image_id = image['id']
        filename = image['file_name']

        # Initialize a new data-record.
        record = dict()

        # Set the image-filename in the data-record.
        record['filename'] = filename

        # Initialize an empty list of image-captions
        # which will be filled further below.
        record['captions'] = list()

        # Save the record using the image-id as the lookup-key.
        records[image_id] = record

    # Collect all the captions for the images.
    for ann in annotations:
        # Get the id and caption for an image.
        image_id = ann['image_id']
        caption = ann['caption']

        # Lookup the data-record for this image-id.
        # This data-record should already exist from the loop above.
        record = records[image_id]

        # Append the current caption to the list of captions in the
        # data-record that was initialized in the loop above.
        record['captions'].append(caption)

    # Convert the records-dict to a list of tuples.
    records_list = [(key, record['filename'], record['captions'])
                    for key, record in sorted(records.items())]

    # Convert the list of tuples to separate tuples with the data.
    ids, filenames, captions = zip(*records_list)

    return ids, filenames, captions


########################################################################
# Public functions that you may call to download the data-set from
# the internet and load the data into memory.


def set_data_dir(new_data_dir):
    """
    Set the base-directory for data-files and then
    set the sub-dirs for training and validation data.
    """

    # Ensure we update the global variables.
    global data_dir, train_dir, val_dir

    data_dir = new_data_dir
    train_dir = os.path.join(new_data_dir, "train2017")
    val_dir = os.path.join(new_data_dir, "val2017")


def maybe_download_and_extract():
    """
    Download and extract the COCO data-set if the data-files don't
    already exist in data_dir.
    """

    # Filenames to download from the internet.
    filenames = ["zips/train2017.zip", "zips/val2017.zip",
                 "annotations/annotations_trainval2017.zip"]

    # Download these files.
    for filename in filenames:
        # Create the full URL for the given file.
        url = data_url + filename

        print("Downloading " + url)

        download.maybe_download_and_extract(url=url, download_dir=data_dir)


def load_records(train=True):
    """
    Load the data-records for the data-set. This returns the image ids,
    filenames and text-captions for either the training-set or validation-set.

    This wraps _load_records() above with a cache, so if the cache-file already
    exists then it is loaded instead of processing the original data-file.

    :param train:
        Bool whether to load the training-set (True) or validation-set (False).

    :return:
        ids, filenames, captions for the images in the data-set.
    """

    if train:
        # Cache-file for the training-set data.
        cache_filename = "records_train.pkl"
    else:
        # Cache-file for the validation-set data.
        cache_filename = "records_val.pkl"

    # Path for the cache-file.
    cache_path = os.path.join(data_dir, cache_filename)

    # If the data-records already exist in a cache-file then load it,
    # otherwise call the _load_records() function and save its
    # return-values to the cache-file so it can be loaded the next time.
    records = cache(cache_path=cache_path,
                    fn=_load_records,
                    train=train)

    return records

########################################################################
--------------------------------------------------------------------------------
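Usage sketch (illustration only), mirroring how generate_features.py drives this module:

import utility.coco as coco

coco.set_data_dir("data/coco/")
coco.maybe_download_and_extract()
ids, filenames, captions = coco.load_records(train=True)
# ids[i] is the COCO image id, filenames[i] the image's file name under
# train2017/, and captions[i] a list of human-written caption strings.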