├── .gitattributes
├── .gitignore
├── README.md
├── app.py
├── model.py
└── saved_models
    ├── image_captioning_transformer_weights.h5
    ├── model.h5
    └── vocab.file

/.gitattributes:
--------------------------------------------------------------------------------
*.h5 filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dataset/*
__pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Image-Captioning

HuggingFace Space: https://huggingface.co/spaces/pritish/Image-Captioning

![image](https://user-images.githubusercontent.com/55872694/218245040-fea19824-c082-47c1-bc21-0b814e61cc9a.png)
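
## Usage

A minimal way to try the model outside the Streamlit app (a sketch, assuming TensorFlow, Pillow and the files under `saved_models/` are available locally; `example.jpg` is a placeholder path):

```python
from model import get_caption_model, generate_caption

caption_model = get_caption_model()   # builds the model and loads the saved weights
print(generate_caption('example.jpg', caption_model))
```

`app.py` wraps the same two calls in a small Streamlit UI that captions an image fetched from a URL.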
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import streamlit as st
import requests
import numpy as np
from PIL import Image
from model import get_caption_model, generate_caption


@st.cache(allow_output_mutation=True)
def get_model():
    return get_caption_model()


caption_model = get_model()

img_url = st.text_input(label='Enter Image URL')

# Only fetch and caption the image once a non-empty URL has been entered.
if img_url:
    img = Image.open(requests.get(img_url, stream=True).raw)
    st.image(img)

    img = np.array(img)
    pred_caption = generate_caption(img, caption_model)
    st.write(pred_caption)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import pickle
import tensorflow as tf
import numpy as np


# CONSTANTS
MAX_LENGTH = 40
VOCABULARY_SIZE = 10000
BATCH_SIZE = 32
BUFFER_SIZE = 1000
EMBEDDING_DIM = 512
UNITS = 512


# LOADING DATA
vocab = pickle.load(open('saved_models/vocab.file', 'rb'))

tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=VOCABULARY_SIZE,
    standardize=None,
    output_sequence_length=MAX_LENGTH,
    vocabulary=vocab
)

# Inverse lookup used to turn predicted token ids back into words.
idx2word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)


# MODEL
def CNN_Encoder():
    # Frozen InceptionV3 (without its classification head) extracts a spatial
    # feature map, which is flattened into a sequence of image-patch features.
    inception_v3 = tf.keras.applications.InceptionV3(
        include_top=False,
        weights='imagenet'
    )
    inception_v3.trainable = False

    output = inception_v3.output
    output = tf.keras.layers.Reshape(
        (-1, output.shape[-1]))(output)

    cnn_model = tf.keras.models.Model(inception_v3.input, output)
    return cnn_model


class TransformerEncoderLayer(tf.keras.layers.Layer):

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    def call(self, x, training):
        x = self.layer_norm_1(x)
        x = self.dense(x)

        # Self-attention over the image-patch features.
        attn_output = self.attention(
            query=x,
            value=x,
            key=x,
            attention_mask=None,
            training=training
        )

        x = self.layer_norm_2(x + attn_output)
        return x


class Embeddings(tf.keras.layers.Layer):

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim)
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim)

    def call(self, input_ids):
        # Learned positional embeddings are added to the token embeddings.
        length = tf.shape(input_ids)[-1]
        position_ids = tf.range(start=0, limit=length, delta=1)
        position_ids = tf.expand_dims(position_ids, axis=0)

        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        return token_embeddings + position_embeddings


class TransformerDecoderLayer(tf.keras.layers.Layer):

    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)

        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )

        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()

        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

        self.out = tf.keras.layers.Dense(
            tokenizer.vocabulary_size(), activation="softmax")

        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)

    def call(self, input_ids, encoder_output, training, mask=None):
        embeddings = self.embedding(input_ids)

        combined_mask = None
        padding_mask = None

        if mask is not None:
            # Combine the padding mask with a causal mask so each position can
            # attend only to itself and earlier, non-padding positions.
            causal_mask = self.get_causal_attention_mask(embeddings)
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        # Masked self-attention over the partial caption.
        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training
        )

        out_1 = self.layernorm_1(embeddings + attn_output_1)

        # Cross-attention over the encoded image features.
        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training
        )

        out_2 = self.layernorm_2(out_1 + attn_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0
        )
        return tf.tile(mask, mult)
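
# Note on get_causal_attention_mask (illustrative, not executed): for a sequence
# of length 4 it produces the lower-triangular matrix
#
#     [[1, 0, 0, 0],
#      [1, 1, 0, 0],
#      [1, 1, 1, 0],
#      [1, 1, 1, 1]]
#
# tiled across the batch dimension, so each caption position can attend only to
# itself and earlier positions during decoding.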

class ImageCaptioningModel(tf.keras.Model):

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        # Average the per-token loss over non-padding positions only.
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        encoder_output = self.encoder(img_embed, training=training)
        # Teacher forcing: the decoder input is the caption without its last
        # token, the target is the caption shifted one position to the left.
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        imgs, captions = batch

        if self.image_aug:
            imgs = self.image_aug(imgs)

        img_embed = self.cnn_model(imgs)

        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions
            )

        train_vars = (
            self.encoder.trainable_variables + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        imgs, captions = batch

        img_embed = self.cnn_model(imgs)

        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )

        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        return [self.loss_tracker, self.acc_tracker]


def load_image_from_path(img_path):
    img = tf.io.read_file(img_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.keras.layers.Resizing(299, 299)(img)
    img = img / 255.
    return img


def generate_caption(img, caption_model):
    # Accept either a file path or an in-memory image array.
    if isinstance(img, str):
        img = load_image_from_path(img)

    if isinstance(img, np.ndarray):
        img = tf.convert_to_tensor(img)

    img = tf.expand_dims(img, axis=0)
    img_embed = caption_model.cnn_model(img)
    img_encoded = caption_model.encoder(img_embed, training=False)

    # Greedy decoding: repeatedly feed the caption generated so far and pick
    # the most likely next word until '[end]' or MAX_LENGTH is reached.
    y_inp = '[start]'
    for i in range(MAX_LENGTH - 1):
        tokenized = tokenizer([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = caption_model.decoder(
            tokenized, img_encoded, training=False, mask=mask)

        pred_idx = np.argmax(pred[0, i, :])
        pred_word = idx2word(pred_idx).numpy().decode('utf-8')
        if pred_word == '[end]':
            break

        y_inp += ' ' + pred_word

    y_inp = y_inp.replace('[start] ', '')
    return y_inp


def get_caption_model():
    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)

    cnn_model = CNN_Encoder()

    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
    )

    # Workaround: give the model a pass-through call and push dummy inputs
    # through each sub-model so that all weights are created before loading
    # the pretrained weights from disk.
    def call_fn(batch, training=None):
        return batch

    caption_model.call = call_fn
    sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))

    caption_model((sample_x, sample_y))

    sample_img_embed = caption_model.cnn_model(sample_x)
    sample_enc_out = caption_model.encoder(sample_img_embed, training=False)
    caption_model.decoder(sample_y, sample_enc_out, training=False)

    caption_model.load_weights('saved_models/image_captioning_transformer_weights.h5')

    return caption_model
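

# Minimal smoke test (illustrative sketch): assumes the weight and vocab files
# under saved_models/ are present. It builds the model, loads the weights and
# decodes a caption for a random image tensor, just to confirm the pipeline runs.
if __name__ == '__main__':
    model = get_caption_model()
    dummy_img = tf.random.uniform((299, 299, 3))
    print(generate_caption(dummy_img, model))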
--------------------------------------------------------------------------------
/saved_models/image_captioning_transformer_weights.h5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:4feab5df7dc83396210b152594e0abb31ef7a9a584a9146461aa585752a37ffb
size 201652392
--------------------------------------------------------------------------------
/saved_models/model.h5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:e927884d1ad5adc141cdfffb429ba2aba1ad0a1e42e7d9d999972eaf3e5e81e8
size 201651096
--------------------------------------------------------------------------------
/saved_models/vocab.file:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pritishmishra703/Image-Captioning/d75538d958d538a2f3f150302a98c0d5e1d1c069/saved_models/vocab.file
--------------------------------------------------------------------------------