├── __init__.py ├── artifacts └── bleu_scores.png ├── requirements.txt ├── data ├── coco2014 │ └── .gitignore ├── flickr8k │ └── .gitignore └── glove.6B │ └── .gitignore ├── saved_models └── .gitignore ├── glove.py ├── models └── torch │ ├── layers.py │ ├── vgg16_monolstm.py │ ├── resnet50_monolstm.py │ ├── resnext50_monolstm.py │ ├── incepv3_monolstm.py │ ├── densenet201_monolstm.py │ ├── decoders │ └── monolstm.py │ └── resnet101_attention.py ├── utils_plot.py ├── metrics.py ├── README.md ├── .gitignore ├── utils_torch.py ├── datasets └── flickr8k.py ├── train_torch.py ├── train_attntn.py └── main.ipynb /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /artifacts/bleu_scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Subangkar/Image-Captioning-Attention-PyTorch/HEAD/artifacts/bleu_scores.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tqdm 4 | pillow 5 | ipython 6 | numpy 7 | pandas 8 | nltk 9 | matplotlib 10 | wandb 11 | scikit-image -------------------------------------------------------------------------------- /data/coco2014/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # But not these files... 5 | !.gitignore 6 | 7 | # ...even if they are in subdirectories 8 | !*/ 9 | -------------------------------------------------------------------------------- /data/flickr8k/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # But not these files... 5 | !.gitignore 6 | 7 | # ...even if they are in subdirectories 8 | !*/ 9 | -------------------------------------------------------------------------------- /data/glove.6B/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # But not these files... 5 | !.gitignore 6 | 7 | # ...even if they are in subdirectories 8 | !*/ 9 | -------------------------------------------------------------------------------- /saved_models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # But not these files... 
5 | !.gitignore 6 | 7 | # ...even if they are in subdirectories 8 | !*/ 9 | -------------------------------------------------------------------------------- /glove.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | 6 | 7 | # GLOVE_DIR = path for glove.6B.100d.txt 8 | def glove_dictionary(GLOVE_DIR, dim=200): 9 | embeddings_index = {} 10 | f = open(os.path.join(GLOVE_DIR, f'glove.6B.{dim}d.txt'), encoding="utf8") 11 | for line in f: 12 | values = line.split() 13 | word = values[0] 14 | coefs = np.asarray(values[1:], dtype='float32') 15 | embeddings_index[word] = coefs 16 | f.close() 17 | return embeddings_index 18 | 19 | 20 | def embedding_matrix_creator(embedding_dim, word2idx, GLOVE_DIR='data/glove.6B/'): 21 | embeddings_index = glove_dictionary(GLOVE_DIR=GLOVE_DIR, dim=embedding_dim) 22 | embedding_matrix = np.zeros((len(word2idx), embedding_dim)) 23 | for word, i in tqdm(word2idx.items()): 24 | embedding_vector = embeddings_index.get(word.lower()) 25 | if embedding_vector is not None: 26 | # words not found in embedding index will be all-zeros. 27 | embedding_matrix[i] = embedding_vector 28 | return embedding_matrix 29 | -------------------------------------------------------------------------------- /models/torch/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TimeDistributed(nn.Module): 6 | def __init__(self, module, batch_first=False): 7 | super().__init__() 8 | self.module = module 9 | self.batch_first = batch_first 10 | 11 | def forward(self, x): 12 | 13 | if len(x.size()) <= 2: 14 | return self.module(x) 15 | 16 | # Squash samples and timesteps into a single axis 17 | x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size) 18 | 19 | y = self.module(x_reshape) 20 | 21 | # We have to reshape Y 22 | if self.batch_first: 23 | y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size) 24 | else: 25 | y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size) 26 | 27 | return y 28 | 29 | 30 | def embedding_layer(trainable=True, embedding_matrix=None, **kwargs): 31 | emb_layer = nn.Embedding(**kwargs) 32 | if embedding_matrix is not None: 33 | emb_layer.weight = nn.Parameter(torch.from_numpy(embedding_matrix).float()) 34 | trainable = (embedding_matrix is None) or trainable 35 | if not trainable: 36 | emb_layer.weight.requires_grad = False 37 | return emb_layer 38 | -------------------------------------------------------------------------------- /utils_plot.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from matplotlib import pyplot as plt, cm as cm 6 | import skimage 7 | import skimage.transform 8 | 9 | 10 | def visualize_att(image_path, seq, alphas, idx2word, endseq='', smooth=True): 11 | """ 12 | Visualizes caption with weights at every word. 13 | 14 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb 15 | 16 | :param image_path: path to image that has been captioned 17 | :param seq: caption 18 | :param alphas: weights 19 | :param idx2word: reverse word mapping, i.e. ix2word 20 | :param smooth: smooth weights? 
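:param endseq: end-of-sequence token; caption words from this token onward are not plotted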
21 | """ 22 | image = Image.open(image_path) 23 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 24 | 25 | # words = [idx2word[ind] for ind in seq] 26 | words = list(itertools.takewhile(lambda word: word != endseq, 27 | map(lambda idx: idx2word[idx], iter(seq)))) 28 | 29 | for t in range(len(words)): 30 | if t > 50: 31 | break 32 | plt.subplot(np.ceil(len(words) / 5.), 5, t + 1) 33 | 34 | plt.text(0, 1, '%s' % (words[t]), color='black', backgroundcolor='white', fontsize=12) 35 | plt.imshow(image) 36 | current_alpha = alphas[t, :] 37 | if smooth: 38 | alpha = skimage.transform.pyramid_expand(current_alpha.numpy(), upscale=24, sigma=8) 39 | else: 40 | alpha = skimage.transform.resize(current_alpha.numpy(), [14 * 24, 14 * 24]) 41 | if t == 0: 42 | plt.imshow(alpha, alpha=0) 43 | else: 44 | plt.imshow(alpha, alpha=0.8) 45 | plt.set_cmap(cm.Greys_r) 46 | plt.axis('off') 47 | plt.show() 48 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nltk.translate.bleu_score import SmoothingFunction 3 | from nltk.translate.bleu_score import corpus_bleu, sentence_bleu 4 | 5 | 6 | def bleu_score_fn(method_no: int = 4, ref_type='corpus'): 7 | """ 8 | :param method_no: 9 | :param ref_type: 'corpus' or 'sentence' 10 | :return: bleu score 11 | """ 12 | smoothing_method = getattr(SmoothingFunction(), f'method{method_no}') 13 | 14 | def bleu_score_corpus(reference_corpus: list, candidate_corpus: list, n: int = 4): 15 | """ 16 | :param reference_corpus: [b, 5, var_len] 17 | :param candidate_corpus: [b, var_len] 18 | :param n: size of n-gram 19 | """ 20 | weights = [1 / n] * n 21 | return corpus_bleu(reference_corpus, candidate_corpus, 22 | smoothing_function=smoothing_method, weights=weights) 23 | 24 | def bleu_score_sentence(reference_sentences: list, candidate_sentence: list, n: int = 4): 25 | """ 26 | :param reference_sentences: [5, var_len] 27 | :param candidate_sentence: [var_len] 28 | :param n: size of n-gram 29 | """ 30 | weights = [1 / n] * n 31 | return sentence_bleu(reference_sentences, candidate_sentence, 32 | smoothing_function=smoothing_method, weights=weights) 33 | 34 | if ref_type == 'corpus': 35 | return bleu_score_corpus 36 | elif ref_type == 'sentence': 37 | return bleu_score_sentence 38 | 39 | 40 | def accuracy_fn(ignore_value: int = 0): 41 | def accuracy_ignoring_value(source: torch.Tensor, target: torch.Tensor): 42 | mask = target != ignore_value 43 | return (source[mask] == target[mask]).sum().item() / mask.sum().item() 44 | 45 | return accuracy_ignoring_value 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image-Captioning-PyTorch 2 | This repo contains codes to preprocess, train and evaluate sequence models on Flickr8k Image dataset in pytorch. This repo was a part of a Deep Learning Project for the Machine Learning Sessional course of Department of CSE, BUET for the session January-2020. 
3 | 4 | **Models Experimented with**: 5 | - Pretrained CNN encoder & LSTM-based decoder 6 | - VGG-16, Inception-v3, Resnet-50, Resnet-101, Resnext-101, Densenet-201 7 | - Pretrained Resnet-101 & LSTM with Attention Mechanism 8 | 9 | Open [`Pretrained Attention Model's Notebook`](demo_attention_flickr8k.ipynb) or [`Pretrained MonoLSTM Model's Notebook`](demo_monolstm_flickr8k.ipynb) in Colab and execute from top to bottom. 10 | 11 | **Prerequisites**: 12 | - Datasets: 13 | - Flickr8k Dataset: [images](https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip) and [annotations](https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip) 14 | - Pre-trained word embeddings: 15 | - [GloVe 6B embeddings](http://nlp.stanford.edu/data/glove.6B.zip) 16 | 17 | **Data Folder Structure for training using [`train_torch.py`](train_torch.py) or [`train_attntn.py`](train_attntn.py):** 18 | ``` 19 | data/ 20 | flickr8k/ 21 | Flicker8k_Dataset/ 22 | *.jpg 23 | Flickr8k_text/ 24 | Flickr8k.token.txt 25 | Flickr_8k.devImages.txt 26 | Flickr_8k.testImages.txt 27 | Flickr_8k.trainImages.txt 28 | glove.6B/ 29 | glove.6B.50d.txt 30 | glove.6B.100d.txt 31 | glove.6B.200d.txt 32 | glove.6B.300d.txt 33 | ``` 34 | 35 | **Pretrained Models**: 36 | Some pre-trained weights are provided [here](https://drive.google.com/drive/folders/16e_bNz92M5g3Myp2kKbGZcXIkDTjasP-?usp=sharing). 37 | 38 | **BLEU score comparison of trained models**: 39 | ![alt text](artifacts/bleu_scores.png "BLEU score comparison of some trained models") 40 | -------------------------------------------------------------------------------- /models/torch/vgg16_monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.torch.decoders.monolstm import Decoder 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, embed_size): 9 | """Load the pretrained vgg-16 and replace top fc layer.""" 10 | super(Encoder, self).__init__() 11 | vgg16 = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True) 12 | vgg16.classifier = vgg16.classifier[:-1] 13 | self.vgg16 = vgg16 14 | self.embed = nn.Linear(vgg16.classifier[-3].out_features, embed_size) # FC-relu-dropout 15 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 16 | 17 | def forward(self, images): 18 | """Extract feature vectors from input images.""" 19 | with torch.no_grad(): 20 | features = self.vgg16(images) 21 | features = self.embed(features) 22 | features = self.bn(features) 23 | return features 24 | 25 | 26 | class Captioner(nn.Module): 27 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 28 | super().__init__() 29 | self.encoder = Encoder(embed_size) 30 | self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers, 31 | embedding_matrix=embedding_matrix, train_embd=train_embd) 32 | 33 | def forward(self, images, captions, lengths): 34 | features = self.encoder(images) 35 | outputs = self.decoder(features, captions, lengths) 36 | return outputs 37 | 38 | def sample(self, images, max_len=40, endseq_idx=-1): 39 | features = self.encoder(images) 40 | captions = self.decoder.sample(features=features, max_len=max_len, endseq_idx=endseq_idx) 41 | return captions 42 | 43 | def sample_beam_search(self, images, max_len=40, endseq_idx=-1, beam_width=5): 44 | features = self.encoder(images) 45 | captions = self.decoder.sample_beam_search(features=features,
max_len=max_len, beam_width=beam_width) 46 | return captions 47 | -------------------------------------------------------------------------------- /models/torch/resnet50_monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.torch.decoders.monolstm import Decoder 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, embed_size): 9 | """Load the pretrained ResNet-50 and replace top fc layer.""" 10 | super(Encoder, self).__init__() 11 | resnet = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True) 12 | modules = list(resnet.children())[:-1] 13 | self.resnet = nn.Sequential(*modules) 14 | self.embed = nn.Sequential( 15 | nn.Linear(resnet.fc.in_features, embed_size), 16 | nn.Dropout(p=0.5), 17 | ) 18 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 19 | 20 | def forward(self, images): 21 | """Extract feature vectors from input images.""" 22 | with torch.no_grad(): 23 | features = self.resnet(images) 24 | features = features.view(features.size(0), -1) 25 | features = self.embed(features) 26 | features = self.bn(features) 27 | return features 28 | 29 | 30 | class Captioner(nn.Module): 31 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 32 | super().__init__() 33 | self.encoder = Encoder(embed_size) 34 | self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers, 35 | embedding_matrix=embedding_matrix, train_embd=train_embd) 36 | 37 | def forward(self, images, captions, lengths): 38 | features = self.encoder(images) 39 | outputs = self.decoder(features, captions, lengths) 40 | return outputs 41 | 42 | def sample(self, images, max_len=40, endseq_idx=-1): 43 | features = self.encoder(images) 44 | captions = self.decoder.sample(features=features, max_len=max_len, endseq_idx=endseq_idx) 45 | return captions 46 | 47 | def sample_beam_search(self, images, max_len=40, endseq_idx=-1, beam_width=5): 48 | features = self.encoder(images) 49 | captions = self.decoder.sample_beam_search(features=features, max_len=max_len, beam_width=beam_width) 50 | return captions 51 | -------------------------------------------------------------------------------- /models/torch/resnext50_monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.torch.decoders.monolstm import Decoder 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, embed_size): 9 | """Load the pretrained ResNext-50 and replace top fc layer.""" 10 | super(Encoder, self).__init__() 11 | resnext = torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True) 12 | modules = list(resnext.children())[:-1] 13 | self.resnext = nn.Sequential(*modules) 14 | self.embed = nn.Sequential( 15 | nn.Linear(resnext.fc.in_features, embed_size), 16 | nn.Dropout(p=0.5), 17 | ) 18 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 19 | 20 | def forward(self, images): 21 | """Extract feature vectors from input images.""" 22 | with torch.no_grad(): 23 | features = self.resnext(images) 24 | features = features.view(features.size(0), -1) 25 | features = self.embed(features) 26 | features = self.bn(features) 27 | return features 28 | 29 | 30 | class Captioner(nn.Module): 31 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 32 | super().__init__() 33 | self.encoder = Encoder(embed_size) 34 | 
self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers, 35 | embedding_matrix=embedding_matrix, train_embd=train_embd) 36 | 37 | def forward(self, images, captions, lengths): 38 | features = self.encoder(images) 39 | outputs = self.decoder(features, captions, lengths) 40 | return outputs 41 | 42 | def sample(self, images, max_len=40, endseq_idx=-1): 43 | features = self.encoder(images) 44 | captions = self.decoder.sample(features=features, max_len=max_len, endseq_idx=endseq_idx) 45 | return captions 46 | 47 | def sample_beam_search(self, images, max_len=40, endseq_idx=-1, beam_width=5): 48 | features = self.encoder(images) 49 | captions = self.decoder.sample_beam_search(features=features, max_len=max_len, beam_width=beam_width) 50 | return captions 51 | -------------------------------------------------------------------------------- /models/torch/incepv3_monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torchvision 5 | 6 | from models.torch.decoders.monolstm import Decoder 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, embed_size): 11 | """Load the pretrained Inception-v3 and replace top fc layer.""" 12 | super(Encoder, self).__init__() 13 | inception_v3 = torchvision.models.inception_v3(pretrained=True, aux_logits=False) 14 | modules = list(inception_v3.children())[:-1] 15 | self.inception_v3 = nn.Sequential(*modules) 16 | self.embed = nn.Sequential( 17 | nn.Linear(inception_v3.fc.in_features, embed_size), 18 | nn.Dropout(p=0.5), 19 | ) 20 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 21 | 22 | def forward(self, images): 23 | """Extract feature vectors from input images.""" 24 | with torch.no_grad(): 25 | features = self.inception_v3(images) 26 | features = F.relu(features, inplace=True).view(features.size(0), -1) 27 | features = self.embed(features) 28 | features = self.bn(features) 29 | return features 30 | 31 | 32 | class Captioner(nn.Module): 33 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 34 | super().__init__() 35 | self.encoder = Encoder(embed_size) 36 | self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers, 37 | embedding_matrix=embedding_matrix, train_embd=train_embd) 38 | 39 | def forward(self, images, captions, lengths): 40 | features = self.encoder(images) 41 | outputs = self.decoder(features, captions, lengths) 42 | return outputs 43 | 44 | def sample(self, images, max_len=40, endseq_idx=-1): 45 | features = self.encoder(images) 46 | captions = self.decoder.sample(features=features, max_len=max_len, endseq_idx=endseq_idx) 47 | return captions 48 | 49 | def sample_beam_search(self, images, max_len=40, endseq_idx=-1, beam_width=5): 50 | features = self.encoder(images) 51 | captions = self.decoder.sample_beam_search(features=features, max_len=max_len, beam_width=beam_width) 52 | return captions 53 | -------------------------------------------------------------------------------- /models/torch/densenet201_monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from models.torch.decoders.monolstm import Decoder 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, embed_size): 10 | """Load the pretrained Densenet-201 and replace top classifier layer.""" 11 | super(Encoder, 
self).__init__() 12 | densenet = torch.hub.load('pytorch/vision:v0.6.0', 'densenet201', pretrained=True) 13 | modules = list(densenet.children())[:-1] 14 | self.densenet = nn.Sequential(*modules) 15 | self.embed = nn.Sequential( 16 | nn.Linear(densenet.classifier.in_features, embed_size), 17 | nn.Dropout(p=0.5), 18 | ) 19 | self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) 20 | 21 | def forward(self, images): 22 | """Extract feature vectors from input images.""" 23 | with torch.no_grad(): 24 | features = self.densenet(images) 25 | features = F.relu(features, inplace=True) 26 | features = F.avg_pool2d(features, kernel_size=7).view(features.size(0), -1) 27 | features = self.embed(features) 28 | features = self.bn(features) 29 | return features 30 | 31 | 32 | class Captioner(nn.Module): 33 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 34 | super().__init__() 35 | self.encoder = Encoder(embed_size) 36 | self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers, 37 | embedding_matrix=embedding_matrix, train_embd=train_embd) 38 | 39 | def forward(self, images, captions, lengths): 40 | features = self.encoder(images) 41 | outputs = self.decoder(features, captions, lengths) 42 | return outputs 43 | 44 | def sample(self, images, max_len=40, endseq_idx=-1): 45 | features = self.encoder(images) 46 | captions = self.decoder.sample(features=features, max_len=max_len, endseq_idx=endseq_idx) 47 | return captions 48 | 49 | def sample_beam_search(self, images, max_len=40, endseq_idx=-1, beam_width=5): 50 | features = self.encoder(images) 51 | captions = self.decoder.sample_beam_search(features=features, max_len=max_len, beam_width=beam_width) 52 | return captions 53 | -------------------------------------------------------------------------------- /models/torch/decoders/monolstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | 6 | from models.torch.layers import embedding_layer 7 | 8 | 9 | class Decoder(nn.Module): 10 | def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, embedding_matrix=None, train_embd=True): 11 | """Set the hyper-parameters and build the layers.""" 12 | super(Decoder, self).__init__() 13 | self.embed = embedding_layer(num_embeddings=vocab_size, embedding_dim=embed_size, 14 | embedding_matrix=embedding_matrix, trainable=train_embd) 15 | self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=0.5) 16 | self.linear = nn.Linear(hidden_size, vocab_size) 17 | 18 | def forward(self, features, captions, lengths): 19 | """Decode image feature vectors and generates captions. 
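lengths = [b] valid-token count of each caption (the batch is expected to be sorted by decreasing length for packing)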
20 | features = [b, 300] 21 | captions = [b, max_len] 22 | :return [sum_len, vocab_size] 23 | """ 24 | # [b, max_len] -> [b, max_len-1] 25 | captions = captions[:, :-1] 26 | # [b, max_len-1, embed_dim] 27 | embeddings = self.embed(captions) 28 | # [b, max_len, embed_dim] 29 | inputs = torch.cat((features.unsqueeze(1), embeddings), 1) 30 | # (0)[sum_len, embed_dim] 31 | inputs_packed = pack_padded_sequence(inputs, lengths=lengths, batch_first=True, enforce_sorted=True) 32 | # (0)[sum_len, embed_dim] 33 | hiddens, _ = self.lstm(inputs_packed) 34 | # [sum_len, vocab_size] 35 | outputs = self.linear(hiddens[0]) 36 | return outputs 37 | 38 | def sample(self, features, states=None, max_len=40, endseq_idx=-1): 39 | """Samples captions in batch for given image features (Greedy search). 40 | features = [b, embed_dim] 41 | inputs = [b, 1, embed_dim] 42 | :return [b, max_len] 43 | """ 44 | inputs = features.unsqueeze(1) 45 | sampled_ids = [] 46 | for i in range(max_len): 47 | # [b, 1, hidden_size] 48 | hiddens, states = self.lstm(inputs, states) 49 | # [b, 1, hidden_size] -> [b, hidden_size] -> [b, vocab_size] 50 | outputs = self.linear(hiddens.squeeze(1)) # (batch_size, vocab_size) 51 | # [b] 52 | predicted = outputs.argmax(1) 53 | sampled_ids.append(predicted) 54 | # [b] -> [b, embed_dim] -> [b, 1, embed_dim] 55 | inputs = self.embed(predicted).unsqueeze(1) 56 | # [b, max_len] 57 | sampled_ids = torch.stack(sampled_ids, 1) 58 | return sampled_ids 59 | 60 | def sample_beam_search(self, features, states=None, max_len=40, beam_width=5): 61 | """Accept a pre-processed image tensor and return the top predicted 62 | sentences. This is the beam search approach. 63 | features = [b, embed_dim] 64 | """ 65 | # [b, 1, embed_dim] 66 | inputs = features.unsqueeze(1) 67 | # Top word idx sequences and their corresponding inputs and states 68 | idx_sequences = [[[], 0.0, inputs, states]] 69 | for _ in range(max_len): 70 | # Store all the potential candidates at each step 71 | all_candidates = [] 72 | # Predict the next word idx for each of the top sequences 73 | for idx_seq in idx_sequences: 74 | # [b, 1, hidden_size] 75 | hiddens, states = self.lstm(idx_seq[2], idx_seq[3]) 76 | # [b, 1, hidden_size] -> [b, hidden_size] -> [b, vocab_size] 77 | outputs = self.linear(hiddens.squeeze(1)) 78 | # Transform outputs to log probabilities to avoid floating-point 79 | # underflow caused by multiplying very small probabilities 80 | # [b, vocab_size] 81 | log_probs = F.log_softmax(outputs, -1) 82 | # [b, k] 83 | top_log_probs, top_idx = log_probs.topk(beam_width, 1) 84 | # [k] 85 | top_idx = top_idx.squeeze(0) 86 | # create a new set of top sentences for next round 87 | for i in range(beam_width): 88 | # idx_seq = [[i1, i2], 0.5, [b, 1, embed_dim], [b, 1, hidden_size]] 89 | next_idx_seq, log_prob = idx_seq[0][:], idx_seq[1] 90 | # idx_seq = [[top_idx(i)..,], 0.0, [b, 1, embed_dim], [b, 1, hidden_size]] 91 | next_idx_seq.append(top_idx[i].item()) 92 | # idx_seq = [[top_idx(i)..,], top_log_probs(0)(i), [b, 1, embed_dim], [b, 1, hidden_size]] 93 | log_prob += top_log_probs[0][i].item() 94 | # Indexing 1-dimensional top_idx gives 0-dimensional tensors. 
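# (top_idx[i] is such a 0-dimensional tensor; note that this beam search effectively assumes a batch of one image, since only top_log_probs[0] is read.)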
95 | # We have to expand dimensions before embedding them 96 | # [1] -> [1, embed_dim]-> [1, 1, embed_dim] 97 | inputs = self.embed(top_idx[i].unsqueeze(0)).unsqueeze(0) 98 | all_candidates.append([next_idx_seq, log_prob, inputs, states]) 99 | # Keep only the top sequences according to their total log probability 100 | ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True) 101 | idx_sequences = ordered[:beam_width] 102 | return [idx_seq[0] for idx_seq in idx_sequences] 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm,python,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,python,jupyternotebooks 4 | 5 | *.zip 6 | wandb/** 7 | 8 | ### JupyterNotebooks ### 9 | # gitignore template for Jupyter Notebooks 10 | # website: http://jupyter.org/ 11 | 12 | .ipynb_checkpoints 13 | */.ipynb_checkpoints/* 14 | 15 | # IPython 16 | profile_default/ 17 | ipython_config.py 18 | 19 | # Remove previous ipynb_checkpoints 20 | # git rm -r .ipynb_checkpoints/ 21 | 22 | ### PyCharm ### 23 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 24 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 25 | 26 | # User-specific stuff 27 | .idea/**/workspace.xml 28 | .idea/**/tasks.xml 29 | .idea/**/usage.statistics.xml 30 | .idea/**/dictionaries 31 | .idea/**/shelf 32 | 33 | # Generated files 34 | .idea/**/contentModel.xml 35 | 36 | # Sensitive or high-churn files 37 | .idea/**/dataSources/ 38 | .idea/**/dataSources.ids 39 | .idea/**/dataSources.local.xml 40 | .idea/**/sqlDataSources.xml 41 | .idea/**/dynamic.xml 42 | .idea/**/uiDesigner.xml 43 | .idea/**/dbnavigator.xml 44 | 45 | # Gradle 46 | .idea/**/gradle.xml 47 | .idea/**/libraries 48 | 49 | # Gradle and Maven with auto-import 50 | # When using Gradle or Maven with auto-import, you should exclude module files, 51 | # since they will be recreated, and may cause churn. Uncomment if using 52 | # auto-import. 
53 | # .idea/artifacts 54 | # .idea/compiler.xml 55 | # .idea/jarRepositories.xml 56 | # .idea/modules.xml 57 | # .idea/*.iml 58 | # .idea/modules 59 | # *.iml 60 | # *.ipr 61 | 62 | # CMake 63 | cmake-build-*/ 64 | 65 | # Mongo Explorer plugin 66 | .idea/**/mongoSettings.xml 67 | 68 | # File-based project format 69 | *.iws 70 | 71 | # IntelliJ 72 | out/ 73 | 74 | # mpeltonen/sbt-idea plugin 75 | .idea_modules/ 76 | 77 | # JIRA plugin 78 | atlassian-ide-plugin.xml 79 | 80 | # Cursive Clojure plugin 81 | .idea/replstate.xml 82 | 83 | # Crashlytics plugin (for Android Studio and IntelliJ) 84 | com_crashlytics_export_strings.xml 85 | crashlytics.properties 86 | crashlytics-build.properties 87 | fabric.properties 88 | 89 | # Editor-based Rest Client 90 | .idea/httpRequests 91 | 92 | # Android studio 3.1+ serialized cache file 93 | .idea/caches/build_file_checksums.ser 94 | 95 | ### PyCharm Patch ### 96 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 97 | 98 | # *.iml 99 | # modules.xml 100 | # .idea/misc.xml 101 | # *.ipr 102 | 103 | # Sonarlint plugin 104 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 105 | .idea/**/sonarlint/ 106 | 107 | # SonarQube Plugin 108 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 109 | .idea/**/sonarIssues.xml 110 | 111 | # Markdown Navigator plugin 112 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 113 | .idea/**/markdown-navigator.xml 114 | .idea/**/markdown-navigator-enh.xml 115 | .idea/**/markdown-navigator/ 116 | 117 | # Cache file creation bug 118 | # See https://youtrack.jetbrains.com/issue/JBR-2257 119 | .idea/$CACHE_FILE$ 120 | 121 | # CodeStream plugin 122 | # https://plugins.jetbrains.com/plugin/12206-codestream 123 | .idea/codestream.xml 124 | 125 | ### Python ### 126 | # Byte-compiled / optimized / DLL files 127 | __pycache__/ 128 | *.py[cod] 129 | *$py.class 130 | 131 | # C extensions 132 | *.so 133 | 134 | # Distribution / packaging 135 | .Python 136 | build/ 137 | develop-eggs/ 138 | dist/ 139 | downloads/ 140 | eggs/ 141 | .eggs/ 142 | lib/ 143 | lib64/ 144 | parts/ 145 | sdist/ 146 | var/ 147 | wheels/ 148 | pip-wheel-metadata/ 149 | share/python-wheels/ 150 | *.egg-info/ 151 | .installed.cfg 152 | *.egg 153 | MANIFEST 154 | 155 | # PyInstaller 156 | # Usually these files are written by a python script from a template 157 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 158 | *.manifest 159 | *.spec 160 | 161 | # Installer logs 162 | pip-log.txt 163 | pip-delete-this-directory.txt 164 | 165 | # Unit test / coverage reports 166 | htmlcov/ 167 | .tox/ 168 | .nox/ 169 | .coverage 170 | .coverage.* 171 | .cache 172 | nosetests.xml 173 | coverage.xml 174 | *.cover 175 | *.py,cover 176 | .hypothesis/ 177 | .pytest_cache/ 178 | pytestdebug.log 179 | 180 | # Translations 181 | *.mo 182 | *.pot 183 | 184 | # Django stuff: 185 | *.log 186 | local_settings.py 187 | db.sqlite3 188 | db.sqlite3-journal 189 | 190 | # Flask stuff: 191 | instance/ 192 | .webassets-cache 193 | 194 | # Scrapy stuff: 195 | .scrapy 196 | 197 | # Sphinx documentation 198 | docs/_build/ 199 | doc/_build/ 200 | 201 | # PyBuilder 202 | target/ 203 | 204 | # Jupyter Notebook 205 | 206 | # IPython 207 | 208 | # pyenv 209 | .python-version 210 | 211 | # pipenv 212 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
213 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 214 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 215 | # install all needed dependencies. 216 | #Pipfile.lock 217 | 218 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 219 | __pypackages__/ 220 | 221 | # Celery stuff 222 | celerybeat-schedule 223 | celerybeat.pid 224 | 225 | # SageMath parsed files 226 | *.sage.py 227 | 228 | # Environments 229 | .env 230 | .venv 231 | env/ 232 | venv/ 233 | ENV/ 234 | env.bak/ 235 | venv.bak/ 236 | pythonenv* 237 | 238 | # Spyder project settings 239 | .spyderproject 240 | .spyproject 241 | 242 | # Rope project settings 243 | .ropeproject 244 | 245 | # mkdocs documentation 246 | /site 247 | 248 | # mypy 249 | .mypy_cache/ 250 | .dmypy.json 251 | dmypy.json 252 | 253 | # Pyre type checker 254 | .pyre/ 255 | 256 | # pytype static type analyzer 257 | .pytype/ 258 | 259 | # profiling data 260 | .prof 261 | 262 | # End of https://www.toptal.com/developers/gitignore/api/pycharm,python,jupyternotebooks -------------------------------------------------------------------------------- /utils_torch.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import wandb 7 | from PIL import Image 8 | from nltk.translate.bleu_score import sentence_bleu 9 | from tqdm.auto import tqdm 10 | 11 | 12 | def preprocess_input(x): 13 | x -= 0.5 14 | x *= 2. 15 | return x 16 | 17 | 18 | # returns (3, h, w) 19 | def preprocess(image_path, trans): 20 | img = Image.open(image_path).convert('RGB') 21 | x = trans(img) 22 | x = x.unsqueeze(0) 23 | # x = preprocess_input(x) 24 | return x 25 | 26 | 27 | def greedy_predictions_gen(encoding_dict, model, word2idx, idx2word, images, max_len, 28 | startseq="", endseq="", device=torch.device('cpu')): 29 | def greedy_search_predictions_util(image): 30 | start_word = [startseq] 31 | with torch.no_grad(): 32 | while True: 33 | par_caps = torch.LongTensor([word2idx[i] for i in start_word]) 34 | par_caps = padding_tensor([par_caps], maxlen=max_len).to(device=device) 35 | e = encoding_dict[image[len(images):]].unsqueeze(0) 36 | preds = model(e, par_caps).cpu().numpy() 37 | word_pred = idx2word[np.argmax(preds[0])] # [0] is for first elm of batch 38 | start_word.append(word_pred) 39 | 40 | if word_pred == endseq or len(start_word) > max_len: 41 | break 42 | return ' '.join(start_word[1:-1]) 43 | 44 | return greedy_search_predictions_util 45 | 46 | 47 | def beam_search_predictions_gen(beam_index, encoding_dict, model, word2idx, idx2word, images, max_len, 48 | startseq="", endseq="", device=torch.device('cpu')): 49 | def beam_search_predictions_util(image): 50 | start = [word2idx[startseq]] 51 | 52 | start_word = [[start, 0.0]] 53 | 54 | while len(start_word[0][0]) < max_len: 55 | temp = [] 56 | for s in start_word: 57 | with torch.no_grad(): 58 | par_caps = torch.LongTensor(s[0]) 59 | par_caps = padding_tensor([par_caps], maxlen=max_len).to(device=device) 60 | e = encoding_dict[image[len(images):]].unsqueeze(0) 61 | preds = model(e, par_caps).cpu().numpy() 62 | 63 | word_preds = np.argsort(preds[0])[-beam_index:] 64 | 65 | # Getting the top (n) predictions and creating a 66 | # new list so as to put them via the model again 67 | for w in word_preds: 68 | next_cap, prob = s[0][:], s[1] 69 | next_cap.append(w) 70 | prob += preds[0][w] 71 | temp.append([next_cap, prob]) 72 | 73 | start_word 
= temp 74 | # Sorting according to the probabilities 75 | start_word = sorted(start_word, reverse=False, key=lambda l: l[1]) 76 | # Getting the top words 77 | start_word = start_word[-beam_index:] 78 | 79 | start_word = start_word[-1][0] 80 | intermediate_caption = [idx2word[i] for i in start_word] 81 | 82 | final_caption = [] 83 | 84 | for i in intermediate_caption: 85 | if i != endseq: 86 | final_caption.append(i) 87 | else: 88 | break 89 | 90 | final_caption = ' '.join(final_caption[1:]) 91 | return final_caption 92 | 93 | return beam_search_predictions_util 94 | 95 | 96 | def split_data(l, img, images): 97 | temp = [] 98 | for i in img: 99 | if i[len(images):] in l: 100 | temp.append(i) 101 | return temp 102 | 103 | 104 | def get_bleu_score(img_to_caplist_dict, caption_gen_func, device=torch.device('cpu')): 105 | bleu_score = 0.0 106 | for k, v in tqdm(img_to_caplist_dict.items()): 107 | candidate = caption_gen_func(k).split() 108 | references = [s.split() for s in v] 109 | bleu_score += sentence_bleu(references, candidate) 110 | return bleu_score / len(img_to_caplist_dict) 111 | 112 | 113 | def print_eval_metrics(img_cap_dict, encoding_dict, model, word2idx, idx2word, images, max_len, 114 | device=torch.device('cpu')): 115 | print('\t\tGreedy: ', 116 | get_bleu_score(img_cap_dict, greedy_predictions_gen(encoding_dict=encoding_dict, model=model, 117 | word2idx=word2idx, idx2word=idx2word, 118 | images=images, max_len=max_len))) 119 | for k in [3, 5, 7]: 120 | print(f'\t\tBeam Search k={k}:', get_bleu_score(img_cap_dict, 121 | beam_search_predictions_gen(beam_index=k, 122 | encoding_dict=encoding_dict, 123 | model=model, 124 | word2idx=word2idx, 125 | idx2word=idx2word, 126 | images=images, max_len=max_len))) 127 | 128 | 129 | def padding_tensor(sequences, maxlen): 130 | """ 131 | :param sequences: list of tensors 132 | :param maxlen: fixed length of output tensors 133 | :return: 134 | """ 135 | num = len(sequences) 136 | # max_len = max([s.size(0) for s in sequences]) 137 | out_dims = (num, maxlen) 138 | out_tensor = sequences[0].data.new(*out_dims).fill_(0) 139 | for i, tensor in enumerate(sequences): 140 | length = tensor.size(0) 141 | out_tensor[i, :length] = tensor 142 | return out_tensor 143 | 144 | 145 | def words_from_tensors_fn(idx2word, max_len=40, startseq='', endseq=''): 146 | def words_from_tensors(captions: np.array) -> list: 147 | """ 148 | :param captions: [b, max_len] 149 | :return: 150 | """ 151 | captoks = [] 152 | for capidx in captions: 153 | # capidx = [1, max_len] 154 | captoks.append(list(itertools.takewhile(lambda word: word != endseq, 155 | map(lambda idx: idx2word[idx], iter(capidx))))[1:]) 156 | return captoks 157 | 158 | return words_from_tensors 159 | 160 | 161 | def sync_files_wandb(file_path_list): 162 | for path in file_path_list: 163 | if os.path.isfile(path) and os.access(path, os.R_OK): 164 | wandb.save(path) 165 | print(f'synced {path}') 166 | else: 167 | print("Either the file is missing or not readable") 168 | -------------------------------------------------------------------------------- /datasets/flickr8k.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import io 3 | import ntpath 4 | import os 5 | 6 | import nltk 7 | import pandas as pd 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | 13 | from utils_torch import split_data 14 | 15 | 16 | class Flickr8kDataset(Dataset): 17 | """ 18 | imgname: just image file 
name 19 | imgpath: full path to image file 20 | """ 21 | 22 | def __init__(self, dataset_base_path='data/flickr8k/', 23 | vocab_set=None, dist='val', 24 | startseq="", endseq="", unkseq="", padseq="", 25 | transformations=None, 26 | return_raw=False, 27 | load_img_to_memory=False, 28 | return_type='tensor', 29 | device=torch.device('cpu')): 30 | self.token = dataset_base_path + 'Flickr8k_text/Flickr8k.token.txt' 31 | self.images_path = dataset_base_path + 'Flicker8k_Dataset/' 32 | 33 | self.dist_list = { 34 | 'train': dataset_base_path + 'Flickr8k_text/Flickr_8k.trainImages.txt', 35 | 'val': dataset_base_path + 'Flickr8k_text/Flickr_8k.devImages.txt', 36 | 'test': dataset_base_path + 'Flickr8k_text/Flickr_8k.testImages.txt' 37 | } 38 | 39 | self.load_img_to_memory = load_img_to_memory 40 | self.pil_d = None 41 | 42 | self.device = torch.device(device) 43 | self.torch = torch.cuda if (self.device.type == 'cuda') else torch 44 | 45 | self.return_raw = return_raw 46 | self.return_type = return_type 47 | 48 | self.__get_item__fn = self.__getitem__corpus if return_type == 'corpus' else self.__getitem__tensor 49 | 50 | self.imgpath_list = glob.glob(self.images_path + '*.jpg') 51 | self.all_imgname_to_caplist = self.__all_imgname_to_caplist_dict() 52 | self.imgname_to_caplist = self.__get_imgname_to_caplist_dict(self.__get_imgpath_list(dist=dist)) 53 | 54 | self.transformations = transformations if transformations is not None else transforms.Compose([ 55 | transforms.ToTensor() 56 | ]) 57 | 58 | self.startseq = startseq.strip() 59 | self.endseq = endseq.strip() 60 | self.unkseq = unkseq.strip() 61 | self.padseq = padseq.strip() 62 | 63 | if vocab_set is None: 64 | self.vocab, self.word2idx, self.idx2word, self.max_len = self.__construct_vocab() 65 | else: 66 | self.vocab, self.word2idx, self.idx2word, self.max_len = vocab_set 67 | self.db = self.get_db() 68 | 69 | def __all_imgname_to_caplist_dict(self): 70 | captions = open(self.token, 'r').read().strip().split('\n') 71 | imgname_to_caplist = {} 72 | for i, row in enumerate(captions): 73 | row = row.split('\t') 74 | row[0] = row[0][:len(row[0]) - 2] # filename#0 caption 75 | if row[0] in imgname_to_caplist: 76 | imgname_to_caplist[row[0]].append(row[1]) 77 | else: 78 | imgname_to_caplist[row[0]] = [row[1]] 79 | return imgname_to_caplist 80 | 81 | def __get_imgname_to_caplist_dict(self, img_path_list): 82 | d = {} 83 | for i in img_path_list: 84 | if i[len(self.images_path):] in self.all_imgname_to_caplist: 85 | d[ntpath.basename(i)] = self.all_imgname_to_caplist[i[len(self.images_path):]] 86 | return d 87 | 88 | def __get_imgpath_list(self, dist='val'): 89 | dist_images = set(open(self.dist_list[dist], 'r').read().strip().split('\n')) 90 | dist_imgpathlist = split_data(dist_images, img=self.imgpath_list, images=self.images_path) 91 | return dist_imgpathlist 92 | 93 | def __construct_vocab(self): 94 | words = [self.startseq, self.endseq, self.unkseq, self.padseq] 95 | max_len = 0 96 | for _, caplist in self.imgname_to_caplist.items(): 97 | for cap in caplist: 98 | cap_words = nltk.word_tokenize(cap.lower()) 99 | words.extend(cap_words) 100 | max_len = max(max_len, len(cap_words) + 2) 101 | vocab = sorted(list(set(words))) 102 | 103 | word2idx = {word: index for index, word in enumerate(vocab)} 104 | idx2word = {index: word for index, word in enumerate(vocab)} 105 | 106 | return vocab, word2idx, idx2word, max_len 107 | 108 | def get_vocab(self): 109 | return self.vocab, self.word2idx, self.idx2word, self.max_len 110 | 111 | def get_db(self): 112 | 
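        # Builds the sample index that __getitem__ reads from:
        # - return_type='corpus': one entry per image -> [imgname, list of tokenized captions, list of caption lengths]
        # - return_type='tensor': one row per (image, caption) pair -> [image_id, lower-cased caption, caption token count]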
113 | if self.load_img_to_memory: 114 | self.pil_d = {} 115 | for imgname in self.imgname_to_caplist.keys(): 116 | self.pil_d[imgname] = Image.open(os.path.join(self.images_path, imgname)).convert('RGB') 117 | 118 | if self.return_type == 'corpus': 119 | df = [] 120 | for imgname, caplist in self.imgname_to_caplist.items(): 121 | cap_wordlist = [] 122 | cap_lenlist = [] 123 | for caption in caplist: 124 | toks = nltk.word_tokenize(caption.lower()) 125 | cap_wordlist.append(toks) 126 | cap_lenlist.append(len(toks)) 127 | df.append([imgname, cap_wordlist, cap_lenlist]) 128 | return df 129 | 130 | # ----- Forming a df to sample from ------ 131 | l = ["image_id\tcaption\tcaption_length\n"] 132 | 133 | for imgname, caplist in self.imgname_to_caplist.items(): 134 | for cap in caplist: 135 | l.append( 136 | f"{imgname}\t" 137 | f"{cap.lower()}\t" 138 | f"{len(nltk.word_tokenize(cap.lower()))}\n") 139 | img_id_cap_str = ''.join(l) 140 | 141 | df = pd.read_csv(io.StringIO(img_id_cap_str), delimiter='\t') 142 | return df.to_numpy() 143 | 144 | @property 145 | def pad_value(self): 146 | return 0 147 | 148 | def __getitem__(self, index: int): 149 | return self.__get_item__fn(index) 150 | 151 | def __len__(self): 152 | return len(self.db) 153 | 154 | def get_image_captions(self, index: int): 155 | """ 156 | :param index: [] index 157 | :returns: image_path, list_of_captions 158 | """ 159 | imgname = self.db[index][0] 160 | return os.path.join(self.images_path, imgname), self.imgname_to_caplist[imgname] 161 | 162 | def __getitem__tensor(self, index: int): 163 | imgname = self.db[index][0] 164 | caption = self.db[index][1] 165 | capt_ln = self.db[index][2] 166 | cap_toks = [self.startseq] + nltk.word_tokenize(caption) + [self.endseq] 167 | img_tens = self.pil_d[imgname] if self.load_img_to_memory else Image.open( 168 | os.path.join(self.images_path, imgname)).convert('RGB') 169 | img_tens = self.transformations(img_tens).to(self.device) 170 | cap_tens = self.torch.LongTensor(self.max_len).fill_(self.pad_value) 171 | cap_tens[:len(cap_toks)] = self.torch.LongTensor([self.word2idx[word] for word in cap_toks]) 172 | return img_tens, cap_tens, len(cap_toks) 173 | 174 | def __getitem__corpus(self, index: int): 175 | imgname = self.db[index][0] 176 | cap_wordlist = self.db[index][1] 177 | cap_lenlist = self.db[index][2] 178 | img_tens = self.pil_d[imgname] if self.load_img_to_memory else Image.open( 179 | os.path.join(self.images_path, imgname)).convert('RGB') 180 | img_tens = self.transformations(img_tens).to(self.device) 181 | return img_tens, cap_wordlist, cap_lenlist 182 | -------------------------------------------------------------------------------- /train_torch.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import pickle 3 | import wandb 4 | from matplotlib import pyplot as plt 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | 9 | from datasets.flickr8k import Flickr8kDataset 10 | from glove import embedding_matrix_creator 11 | from metrics import * 12 | from utils_torch import * 13 | 14 | # %% 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | device 17 | # %% 18 | 19 | DATASET_BASE_PATH = 'data/flickr8k/' 20 | 21 | # %% 22 | 23 | train_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', device=device, 24 | return_type='tensor', 25 | load_img_to_memory=False) 26 | vocab, word2idx, idx2word, max_len = 
vocab_set = train_set.get_vocab() 27 | val_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='val', vocab_set=vocab_set, device=device, 28 | return_type='corpus', 29 | load_img_to_memory=False) 30 | test_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='test', vocab_set=vocab_set, device=device, 31 | return_type='corpus', 32 | load_img_to_memory=False) 33 | train_eval_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', vocab_set=vocab_set, device=device, 34 | return_type='corpus', 35 | load_img_to_memory=False) 36 | with open('vocab_set.pkl', 'wb') as f: 37 | pickle.dump(train_set.get_vocab(), f) 38 | len(train_set), len(val_set), len(test_set) 39 | 40 | # %% 41 | vocab_size = len(vocab) 42 | vocab_size, max_len 43 | 44 | # %% 45 | 46 | MODEL = "resnet50_monolstm" 47 | EMBEDDING_DIM = 50 48 | EMBEDDING = f"GLV{EMBEDDING_DIM}" 49 | HIDDEN_SIZE = 256 50 | BATCH_SIZE = 16 51 | LR = 1e-2 52 | MODEL_NAME = f'saved_models/{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}' 53 | NUM_EPOCHS = 2 54 | SAVE_FREQ = 2 55 | LOG_INTERVAL = 25 56 | 57 | run = wandb.init(project='image-captioning', 58 | entity='datalab-buet', 59 | name=f"{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}-{1}", 60 | # tensorboard=True, sync_tensorboard=True, 61 | config={"learning_rate": LR, 62 | "epochs": NUM_EPOCHS, 63 | "batch_size": BATCH_SIZE, 64 | "model": MODEL, 65 | "embedding": EMBEDDING, 66 | "embedding_dim": EMBEDDING_DIM, 67 | "hidden_size": HIDDEN_SIZE, 68 | }, 69 | reinit=True) 70 | 71 | # %% 72 | embedding_matrix = embedding_matrix_creator(embedding_dim=EMBEDDING_DIM, word2idx=word2idx) 73 | embedding_matrix.shape 74 | 75 | 76 | # %% 77 | 78 | def train_model(train_loader, model, loss_fn, optimizer, vocab_size, acc_fn, desc=''): 79 | running_acc = 0.0 80 | running_loss = 0.0 81 | model.train() 82 | t = tqdm(iter(train_loader), desc=f'{desc}') 83 | for batch_idx, batch in enumerate(t): 84 | images, captions, lengths = batch 85 | sort_ind = torch.argsort(lengths, descending=True) 86 | images = images[sort_ind] 87 | captions = captions[sort_ind] 88 | lengths = lengths[sort_ind] 89 | 90 | optimizer.zero_grad() 91 | # [sum_len, vocab_size] 92 | outputs = model(images, captions, lengths) 93 | # [b, max_len] -> [sum_len] 94 | targets = pack_padded_sequence(captions, lengths=lengths, batch_first=True, enforce_sorted=True)[0] 95 | 96 | loss = loss_fn(outputs, targets) 97 | loss.backward() 98 | optimizer.step() 99 | 100 | running_acc += (torch.argmax(outputs, dim=1) == targets).sum().float().item() / targets.size(0) 101 | running_loss += loss.item() 102 | t.set_postfix({'loss': running_loss / (batch_idx + 1), 103 | 'acc': running_acc / (batch_idx + 1), 104 | }, refresh=True) 105 | if (batch_idx + 1) % LOG_INTERVAL == 0: 106 | print(f'{desc} {batch_idx + 1}/{len(train_loader)} ' 107 | f'train_loss: {running_loss / (batch_idx + 1):.4f} ' 108 | f'train_acc: {running_acc / (batch_idx + 1):.4f}') 109 | wandb.log({ 110 | 'train_loss': running_loss / (batch_idx + 1), 111 | 'train_acc': running_acc / (batch_idx + 1), 112 | }) 113 | 114 | return running_loss / len(train_loader) 115 | 116 | 117 | def evaluate_model(data_loader, model, loss_fn, vocab_size, bleu_score_fn, tensor_to_word_fn, desc=''): 118 | running_bleu = [0.0] * 5 119 | model.eval() 120 | t = tqdm(iter(data_loader), desc=f'{desc}') 121 | for batch_idx, batch in enumerate(t): 122 | images, captions, lengths = batch 123 | outputs = tensor_to_word_fn(model.sample(images).cpu().numpy()) 124 | 125 | for i in (1, 2, 3, 4): 126 | running_bleu[i] += 
bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=i) 127 | t.set_postfix({ 128 | 'bleu1': running_bleu[1] / (batch_idx + 1), 129 | 'bleu4': running_bleu[4] / (batch_idx + 1), 130 | }, refresh=True) 131 | for i in (1, 2, 3, 4): 132 | running_bleu[i] /= len(data_loader) 133 | return running_bleu 134 | 135 | 136 | # %% 137 | 138 | from models.torch.densenet201_monolstm import Captioner 139 | 140 | final_model = Captioner(EMBEDDING_DIM, HIDDEN_SIZE, vocab_size, num_layers=2, 141 | embedding_matrix=embedding_matrix, train_embd=False).to(device) 142 | 143 | loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_set.pad_value).to(device) 144 | acc_fn = accuracy_fn(ignore_value=train_set.pad_value) 145 | sentence_bleu_score_fn = bleu_score_fn(4, 'sentence') 146 | corpus_bleu_score_fn = bleu_score_fn(4, 'corpus') 147 | tensor_to_word_fn = words_from_tensors_fn(idx2word=idx2word) 148 | 149 | params = list(final_model.decoder.parameters()) + list(final_model.encoder.embed.parameters()) + list( 150 | final_model.encoder.bn.parameters()) 151 | 152 | optimizer = torch.optim.Adam(params=params, lr=LR) 153 | 154 | wandb.watch(final_model, log='all', log_freq=50) 155 | wandb.watch(final_model.encoder, log='all', log_freq=50) 156 | wandb.watch(final_model.decoder, log='all', log_freq=50) 157 | wandb.save('vocab_set.pkl') 158 | 159 | # %% 160 | train_transformations = transforms.Compose([ 161 | transforms.Resize(256), # smaller edge of image resized to 256 162 | transforms.RandomCrop(224), # get 224x224 crop from random location 163 | transforms.RandomHorizontalFlip(p=0.5), 164 | transforms.ToTensor(), # convert the PIL Image to a tensor 165 | transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model 166 | (0.229, 0.224, 0.225)) 167 | ]) 168 | eval_transformations = transforms.Compose([ 169 | transforms.Resize(256), # smaller edge of image resized to 256 170 | transforms.CenterCrop(224), # get 224x224 crop from random location 171 | transforms.ToTensor(), # convert the PIL Image to a tensor 172 | transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model 173 | (0.229, 0.224, 0.225)) 174 | ]) 175 | 176 | train_set.transformations = train_transformations 177 | val_set.transformations = eval_transformations 178 | test_set.transformations = eval_transformations 179 | train_eval_set.transformations = eval_transformations 180 | 181 | # %% 182 | eval_collate_fn = lambda batch: (torch.stack([x[0] for x in batch]), [x[1] for x in batch], [x[2] for x in batch]) 183 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False) 184 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 185 | collate_fn=eval_collate_fn) 186 | test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 187 | collate_fn=eval_collate_fn) 188 | train_eval_loader = DataLoader(train_eval_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 189 | collate_fn=eval_collate_fn) 190 | # %% 191 | train_loss_min = 100 192 | val_bleu4_max = 0.0 193 | for epoch in range(NUM_EPOCHS): 194 | train_loss = train_model(desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', model=final_model, 195 | optimizer=optimizer, loss_fn=loss_fn, acc_fn=acc_fn, 196 | train_loader=train_loader, vocab_size=vocab_size) 197 | with torch.no_grad(): 198 | train_bleu = evaluate_model(desc=f'\tTrain Bleu Score: ', model=final_model, 199 | loss_fn=loss_fn, 
bleu_score_fn=corpus_bleu_score_fn, 200 | tensor_to_word_fn=tensor_to_word_fn, 201 | data_loader=train_eval_loader, vocab_size=vocab_size) 202 | val_bleu = evaluate_model(desc=f'\tValidation Bleu Score: ', model=final_model, 203 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 204 | tensor_to_word_fn=tensor_to_word_fn, 205 | data_loader=val_loader, vocab_size=vocab_size) 206 | print(f'Epoch {epoch + 1}/{NUM_EPOCHS}', 207 | ''.join([f'train_bleu{i}: {train_bleu[i]:.4f} ' for i in (1, 4)]), 208 | ''.join([f'val_bleu{i}: {val_bleu[i]:.4f} ' for i in (1, 4)]), 209 | ) 210 | wandb.log({f'val_bleu{i}': val_bleu[i] for i in (1, 2, 3, 4)}) 211 | wandb.log({'train_bleu': train_bleu[4]}) 212 | wandb.log({'val_bleu': val_bleu[4]}) 213 | state = { 214 | 'epoch': epoch + 1, 215 | 'state_dict': final_model.state_dict(), 216 | 'optimizer': optimizer.state_dict(), 217 | 'train_loss_latest': train_loss, 218 | 'val_bleu4_latest': val_bleu[4], 219 | 'train_loss_min': min(train_loss, train_loss_min), 220 | 'val_bleu4_max': max(val_bleu[4], val_bleu4_max), 221 | 'train_bleus': train_bleu, 222 | 'val_bleus': val_bleu, 223 | } 224 | torch.save(state, f'{MODEL_NAME}_latest.pt') 225 | wandb.save(f'{MODEL_NAME}_latest.pt') 226 | if train_loss < train_loss_min: 227 | train_loss_min = train_loss 228 | torch.save(state, f'{MODEL_NAME}''_best_train.pt') 229 | wandb.save(f'{MODEL_NAME}''_best_train.pt') 230 | if val_bleu[4] > val_bleu4_max: 231 | val_bleu4_max = val_bleu[4] 232 | torch.save(state, f'{MODEL_NAME}''_best_val.pt') 233 | wandb.save(f'{MODEL_NAME}''_best_val.pt') 234 | 235 | torch.save(state, f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt') 236 | wandb.save(f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt') 237 | final_model.eval() 238 | 239 | # %% 240 | model = final_model 241 | 242 | # %% 243 | t_i = 1003 244 | dset = train_set 245 | im, cp, _ = dset[t_i] 246 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0))[0]])) 247 | print(dset.get_image_captions(t_i)[1]) 248 | 249 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation="bicubic") 250 | 251 | # %% 252 | t_i = 500 253 | dset = val_set 254 | im, cp, _ = dset[t_i] 255 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0))[0]])) 256 | print(cp) 257 | 258 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation="bicubic") 259 | 260 | # %% 261 | t_i = 500 262 | dset = test_set 263 | im, cp, _ = dset[t_i] 264 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0))[0]])) 265 | print(cp) 266 | 267 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation="bicubic") 268 | 269 | # %% 270 | with torch.no_grad(): 271 | model.eval() 272 | train_bleu = evaluate_model(desc=f'Train: ', model=final_model, 273 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 274 | tensor_to_word_fn=tensor_to_word_fn, 275 | data_loader=train_eval_loader, vocab_size=vocab_size) 276 | val_bleu = evaluate_model(desc=f'Val: ', model=final_model, 277 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 278 | tensor_to_word_fn=tensor_to_word_fn, 279 | data_loader=val_loader, vocab_size=vocab_size) 280 | test_bleu = evaluate_model(desc=f'Test: ', model=final_model, 281 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 282 | tensor_to_word_fn=tensor_to_word_fn, 283 | data_loader=test_loader, vocab_size=vocab_size) 284 | for setname, result in zip(('train', 'val', 'test'), (train_bleu, val_bleu, test_bleu)): 285 | print(setname, end=' ') 286 | 
for ngram in (1, 2, 3, 4): 287 | print(f'Bleu-{ngram}: {result[ngram]}', end=' ') 288 | wandb.run.summary[f"{setname}_bleu{ngram}"] = result[ngram] 289 | print() 290 | -------------------------------------------------------------------------------- /train_attntn.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import pickle 3 | import wandb 4 | from matplotlib import pyplot as plt 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | 9 | from datasets.flickr8k import Flickr8kDataset 10 | from glove import embedding_matrix_creator 11 | from metrics import * 12 | from utils_torch import * 13 | 14 | # %% 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | device 17 | # %% 18 | 19 | DATASET_BASE_PATH = 'data/flickr8k/' 20 | 21 | # %% 22 | 23 | train_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', device=device, 24 | return_type='tensor', 25 | load_img_to_memory=False) 26 | vocab, word2idx, idx2word, max_len = vocab_set = train_set.get_vocab() 27 | val_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='val', vocab_set=vocab_set, device=device, 28 | return_type='corpus', 29 | load_img_to_memory=False) 30 | test_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='test', vocab_set=vocab_set, device=device, 31 | return_type='corpus', 32 | load_img_to_memory=False) 33 | train_eval_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', vocab_set=vocab_set, device=device, 34 | return_type='corpus', 35 | load_img_to_memory=False) 36 | with open('vocab_set.pkl', 'wb') as f: 37 | pickle.dump(train_set.get_vocab(), f) 38 | len(train_set), len(val_set), len(test_set) 39 | 40 | # %% 41 | vocab_size = len(vocab) 42 | vocab_size, max_len 43 | 44 | # %% 45 | 46 | MODEL = "resnet101_attention" 47 | EMBEDDING_DIM = 300 48 | EMBEDDING = f"{EMBEDDING_DIM}" 49 | ATTENTION_DIM = 256 50 | DECODER_SIZE = 256 51 | BATCH_SIZE = 128 52 | LR = 1e-3 53 | MODEL_NAME = f'saved_models/{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}' 54 | NUM_EPOCHS = 2 55 | SAVE_FREQ = 10 56 | LOG_INTERVAL = 25 * (256 // BATCH_SIZE) 57 | 58 | run = wandb.init(project='image-captioning', 59 | entity='datalab-buet', 60 | name=f"{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}-{1}", 61 | # tensorboard=True, sync_tensorboard=True, 62 | config={"learning_rate": LR, 63 | "epochs": NUM_EPOCHS, 64 | "batch_size": BATCH_SIZE, 65 | "model": MODEL, 66 | "embedding": EMBEDDING, 67 | "embedding_dim": EMBEDDING_DIM, 68 | "attention_dim": ATTENTION_DIM, 69 | "decoder_dim": DECODER_SIZE, 70 | }, 71 | reinit=True) 72 | 73 | # %% 74 | embedding_matrix = embedding_matrix_creator(embedding_dim=EMBEDDING_DIM, word2idx=word2idx) 75 | embedding_matrix.shape 76 | 77 | 78 | # %% 79 | 80 | def train_model(train_loader, model, loss_fn, optimizer, vocab_size, acc_fn, desc=''): 81 | running_acc = 0.0 82 | running_loss = 0.0 83 | model.train() 84 | t = tqdm(iter(train_loader), desc=f'{desc}') 85 | for batch_idx, batch in enumerate(t): 86 | images, captions, lengths = batch 87 | 88 | optimizer.zero_grad() 89 | 90 | scores, caps_sorted, decode_lengths, alphas, sort_ind = model(images, captions, lengths) 91 | 92 | # Since decoding starts with , the targets are all words after , up to 93 | targets = caps_sorted[:, 1:] 94 | 95 | # Remove timesteps that we didn't decode at, or are pads 96 | # pack_padded_sequence is an easy trick to do this 97 | scores = 
pack_padded_sequence(scores, decode_lengths, batch_first=True)[0] 98 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0] 99 | 100 | loss = loss_fn(scores, targets) 101 | loss.backward() 102 | optimizer.step() 103 | 104 | running_acc += (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0) 105 | running_loss += loss.item() 106 | t.set_postfix({'loss': running_loss / (batch_idx + 1), 107 | 'acc': running_acc / (batch_idx + 1), 108 | }, refresh=True) 109 | if (batch_idx + 1) % LOG_INTERVAL == 0: 110 | print(f'{desc} {batch_idx + 1}/{len(train_loader)} ' 111 | f'train_loss: {running_loss / (batch_idx + 1):.4f} ' 112 | f'train_acc: {running_acc / (batch_idx + 1):.4f}') 113 | wandb.log({ 114 | 'train_loss': running_loss / (batch_idx + 1), 115 | 'train_acc': running_acc / (batch_idx + 1), 116 | }) 117 | 118 | return running_loss / len(train_loader) 119 | 120 | 121 | def evaluate_model(data_loader, model, loss_fn, vocab_size, bleu_score_fn, tensor_to_word_fn, desc=''): 122 | running_bleu = [0.0] * 5 123 | model.eval() 124 | t = tqdm(iter(data_loader), desc=f'{desc}') 125 | for batch_idx, batch in enumerate(t): 126 | images, captions, lengths = batch 127 | outputs = tensor_to_word_fn(model.sample(images, startseq_idx=word2idx['']).cpu().numpy()) 128 | 129 | for i in (1, 2, 3, 4): 130 | running_bleu[i] += bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=i) 131 | t.set_postfix({ 132 | 'bleu1': running_bleu[1] / (batch_idx + 1), 133 | 'bleu4': running_bleu[4] / (batch_idx + 1), 134 | }, refresh=True) 135 | for i in (1, 2, 3, 4): 136 | running_bleu[i] /= len(data_loader) 137 | return running_bleu 138 | 139 | 140 | # %% 141 | 142 | from models.torch.resnet101_attention import Captioner 143 | 144 | final_model = Captioner(encoded_image_size=14, encoder_dim=2048, 145 | attention_dim=ATTENTION_DIM, embed_dim=EMBEDDING_DIM, decoder_dim=DECODER_SIZE, 146 | vocab_size=vocab_size, 147 | embedding_matrix=embedding_matrix, train_embd=False).to(device) 148 | 149 | loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_set.pad_value).to(device) 150 | acc_fn = accuracy_fn(ignore_value=train_set.pad_value) 151 | sentence_bleu_score_fn = bleu_score_fn(4, 'sentence') 152 | corpus_bleu_score_fn = bleu_score_fn(4, 'corpus') 153 | tensor_to_word_fn = words_from_tensors_fn(idx2word=idx2word) 154 | 155 | params = final_model.parameters() 156 | 157 | optimizer = torch.optim.RMSprop(params=params, lr=LR) 158 | 159 | wandb.watch(final_model, log='all', log_freq=50) 160 | # wandb.watch(final_model.encoder, log='all', log_freq=50) 161 | wandb.watch(final_model.decoder, log='all', log_freq=50) 162 | wandb.save('vocab_set.pkl') 163 | 164 | sync_files_wandb(['main.ipynb', 'train_torch.py', 'models/torch/resnet101_attention.py']) 165 | 166 | # %% 167 | train_transformations = transforms.Compose([ 168 | transforms.Resize(256), # smaller edge of image resized to 256 169 | transforms.RandomCrop(256), # get 256x256 crop from random location 170 | transforms.RandomHorizontalFlip(p=0.5), 171 | transforms.ToTensor(), # convert the PIL Image to a tensor 172 | transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model 173 | (0.229, 0.224, 0.225)) 174 | ]) 175 | eval_transformations = transforms.Compose([ 176 | transforms.Resize(256), # smaller edge of image resized to 256 177 | transforms.CenterCrop(256), # get 256x256 crop from random location 178 | transforms.ToTensor(), # convert the PIL Image to a tensor 179 | transforms.Normalize((0.485, 
0.456, 0.406), # normalize image for pre-trained model 180 | (0.229, 0.224, 0.225)) 181 | ]) 182 | 183 | train_set.transformations = train_transformations 184 | val_set.transformations = eval_transformations 185 | test_set.transformations = eval_transformations 186 | train_eval_set.transformations = eval_transformations 187 | 188 | # %% 189 | eval_collate_fn = lambda batch: (torch.stack([x[0] for x in batch]), [x[1] for x in batch], [x[2] for x in batch]) 190 | train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False) 191 | val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 192 | collate_fn=eval_collate_fn) 193 | test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 194 | collate_fn=eval_collate_fn) 195 | train_eval_loader = DataLoader(train_eval_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False, 196 | collate_fn=eval_collate_fn) 197 | 198 | # %% 199 | train_loss_min = 100 200 | val_bleu4_max = 0.0 201 | for epoch in range(NUM_EPOCHS): 202 | train_loss = train_model(desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', model=final_model, 203 | optimizer=optimizer, loss_fn=loss_fn, acc_fn=acc_fn, 204 | train_loader=train_loader, vocab_size=vocab_size) 205 | with torch.no_grad(): 206 | train_bleu = evaluate_model(desc=f'\tTrain Bleu Score: ', model=final_model, 207 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 208 | tensor_to_word_fn=tensor_to_word_fn, 209 | data_loader=train_eval_loader, vocab_size=vocab_size) 210 | val_bleu = evaluate_model(desc=f'\tValidation Bleu Score: ', model=final_model, 211 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 212 | tensor_to_word_fn=tensor_to_word_fn, 213 | data_loader=val_loader, vocab_size=vocab_size) 214 | print(f'Epoch {epoch + 1}/{NUM_EPOCHS}', 215 | ''.join([f'train_bleu{i}: {train_bleu[i]:.4f} ' for i in (1, 4)]), 216 | ''.join([f'val_bleu{i}: {val_bleu[i]:.4f} ' for i in (1, 4)]), 217 | ) 218 | wandb.log({f'val_bleu{i}': val_bleu[i] for i in (1, 2, 3, 4)}) 219 | wandb.log({'train_bleu': train_bleu[4]}) 220 | wandb.log({'val_bleu': val_bleu[4]}) 221 | state = { 222 | 'epoch': epoch + 1, 223 | 'state_dict': final_model.state_dict(), 224 | 'optimizer': optimizer.state_dict(), 225 | 'train_loss_latest': train_loss, 226 | 'val_bleu4_latest': val_bleu[4], 227 | 'train_loss_min': min(train_loss, train_loss_min), 228 | 'val_bleu4_max': max(val_bleu[4], val_bleu4_max), 229 | 'train_bleus': train_bleu, 230 | 'val_bleus': val_bleu, 231 | } 232 | torch.save(state, f'{MODEL_NAME}_latest.pt') 233 | wandb.save(f'{MODEL_NAME}_latest.pt') 234 | if train_loss < train_loss_min: 235 | train_loss_min = train_loss 236 | torch.save(state, f'{MODEL_NAME}''_best_train.pt') 237 | wandb.save(f'{MODEL_NAME}''_best_train.pt') 238 | if val_bleu[4] > val_bleu4_max: 239 | val_bleu4_max = val_bleu[4] 240 | torch.save(state, f'{MODEL_NAME}''_best_val.pt') 241 | wandb.save(f'{MODEL_NAME}''_best_val.pt') 242 | 243 | torch.save(state, f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt') 244 | wandb.save(f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt') 245 | final_model.eval() 246 | 247 | # %% 248 | model = final_model 249 | 250 | # %% 251 | t_i = 1003 252 | dset = train_set 253 | im, cp, _ = dset[t_i] 254 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]])) 255 | print(dset.get_image_captions(t_i)[1]) 256 | 257 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 
0), interpolation="bicubic") 258 | 259 | # %% 260 | t_i = 500 261 | dset = val_set 262 | im, cp, _ = dset[t_i] 263 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]])) 264 | print(cp) 265 | 266 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation="bicubic") 267 | 268 | # %% 269 | t_i = 500 270 | dset = test_set 271 | im, cp, _ = dset[t_i] 272 | print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]])) 273 | print(cp) 274 | 275 | plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation="bicubic") 276 | 277 | # %% 278 | with torch.no_grad(): 279 | model.eval() 280 | train_bleu = evaluate_model(desc=f'Train: ', model=final_model, 281 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 282 | tensor_to_word_fn=tensor_to_word_fn, 283 | data_loader=train_eval_loader, vocab_size=vocab_size) 284 | val_bleu = evaluate_model(desc=f'Val: ', model=final_model, 285 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 286 | tensor_to_word_fn=tensor_to_word_fn, 287 | data_loader=val_loader, vocab_size=vocab_size) 288 | test_bleu = evaluate_model(desc=f'Test: ', model=final_model, 289 | loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn, 290 | tensor_to_word_fn=tensor_to_word_fn, 291 | data_loader=test_loader, vocab_size=vocab_size) 292 | for setname, result in zip(('train', 'val', 'test'), (train_bleu, val_bleu, test_bleu)): 293 | print(setname, end=' ') 294 | for ngram in (1, 2, 3, 4): 295 | print(f'Bleu-{ngram}: {result[ngram]}', end=' ') 296 | wandb.run.summary[f"{setname}_bleu{ngram}"] = result[ngram] 297 | print() 298 | -------------------------------------------------------------------------------- /models/torch/resnet101_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model's basic architecture has been adapted from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | import torchvision 8 | 9 | from models.torch.layers import embedding_layer 10 | 11 | 12 | class Encoder(nn.Module): 13 | """ 14 | Encoder. 15 | """ 16 | 17 | def __init__(self, encoded_image_size=14): 18 | super(Encoder, self).__init__() 19 | self.enc_image_size = encoded_image_size 20 | 21 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 22 | 23 | # Remove linear and pool layers (since we're not doing classification) 24 | modules = list(resnet.children())[:-2] 25 | self.resnet = nn.Sequential(*modules) 26 | 27 | # Resize image to fixed size to allow input images of variable size 28 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 29 | 30 | self.fine_tune() 31 | 32 | def forward(self, images): 33 | """ 34 | Forward propagation. 35 | 36 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 37 | :return: encoded images 38 | """ 39 | # (batch_size, 2048, image_size/32, image_size/32) 40 | out = self.resnet(images) 41 | # (batch_size, 2048, encoded_image_size, encoded_image_size) 42 | out = self.adaptive_pool(out) 43 | # (batch_size, encoded_image_size, encoded_image_size, 2048) 44 | out = out.permute(0, 2, 3, 1) 45 | return out 46 | 47 | def fine_tune(self, fine_tune=False): 48 | """ 49 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 50 | 51 | :param fine_tune: Allow? 
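Only conv blocks 2 through 4 of the truncated ResNet (children()[5:], i.e. layer2 through layer4) are ever unfrozen; the stem and conv block 1 always stay frozen.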
52 | """ 53 | for p in self.resnet.parameters(): 54 | p.requires_grad = False 55 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 56 | for c in list(self.resnet.children())[5:]: 57 | for p in c.parameters(): 58 | p.requires_grad = fine_tune 59 | 60 | 61 | class Attention(nn.Module): 62 | """ 63 | Attention Network. 64 | """ 65 | 66 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 67 | """ 68 | :param encoder_dim: feature size of encoded images 69 | :param decoder_dim: size of decoder's RNN 70 | :param attention_dim: size of the attention network 71 | """ 72 | super(Attention, self).__init__() 73 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 74 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 75 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 76 | self.relu = nn.ReLU() 77 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 78 | 79 | def forward(self, encoder_out, decoder_hidden): 80 | """ 81 | Forward propagation. 82 | 83 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 84 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 85 | :return: attention weighted encoding, weights 86 | """ 87 | # [b, num_pixels, attention_dim] 88 | att1 = self.encoder_att(encoder_out) 89 | # [b, attention_dim] 90 | att2 = self.decoder_att(decoder_hidden) 91 | # [b, num_pixels] 92 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) 93 | # [b, num_pixels] 94 | alpha = self.softmax(att) 95 | # [b, encoder_dim] 96 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) 97 | 98 | return attention_weighted_encoding, alpha 99 | 100 | 101 | class DecoderWithAttention(nn.Module): 102 | """ 103 | Decoder. 
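LSTM decoder with soft attention: at every timestep it attends over the encoded image regions, gates the attention-weighted encoding with a sigmoid, and feeds it together with the previous word embedding into an LSTMCell.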
104 | """ 105 | 106 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5, 107 | embedding_matrix=None, train_embd=True): 108 | """ 109 | :param attention_dim: size of attention network 110 | :param embed_dim: embedding size 111 | :param decoder_dim: size of decoder's RNN 112 | :param vocab_size: size of vocabulary 113 | :param encoder_dim: feature size of encoded images 114 | :param dropout: dropout 115 | """ 116 | super(DecoderWithAttention, self).__init__() 117 | 118 | self.encoder_dim = encoder_dim 119 | self.attention_dim = attention_dim 120 | self.embed_dim = embed_dim 121 | self.decoder_dim = decoder_dim 122 | self.vocab_size = vocab_size 123 | self.dropout = dropout 124 | 125 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 126 | 127 | self.embedding = embedding_layer(num_embeddings=vocab_size, embedding_dim=embed_dim, 128 | embedding_matrix=embedding_matrix, trainable=train_embd) 129 | self.dropout = nn.Dropout(p=self.dropout) 130 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 131 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 132 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 133 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 134 | self.sigmoid = nn.Sigmoid() 135 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 136 | self.init_weights() # initialize some layers with the uniform distribution 137 | 138 | def init_weights(self): 139 | """ 140 | Initializes some parameters with values from the uniform distribution, for easier convergence. 141 | """ 142 | self.embedding.weight.data.uniform_(-0.1, 0.1) 143 | self.fc.bias.data.fill_(0) 144 | self.fc.weight.data.uniform_(-0.1, 0.1) 145 | 146 | def init_hidden_state(self, encoder_out): 147 | """ 148 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 149 | 150 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 151 | :return: hidden state, cell state [b, decoder_dim] 152 | """ 153 | mean_encoder_out = encoder_out.mean(dim=1) 154 | h = self.init_h(mean_encoder_out) 155 | c = self.init_c(mean_encoder_out) 156 | return h, c 157 | 158 | def forward(self, encoder_out, encoded_captions, caption_lengths): 159 | """ 160 | Forward propagation. 161 | 162 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 163 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 164 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 165 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 166 | """ 167 | 168 | batch_size = encoder_out.size(0) 169 | encoder_dim = encoder_out.size(-1) 170 | vocab_size = self.vocab_size 171 | 172 | # Flatten image 173 | # [b, num_pixels, encoder_dim] 174 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) 175 | num_pixels = encoder_out.size(1) 176 | 177 | # Sort input data by decreasing lengths; why? 
apparent below 178 | # [b, 1] -> [b], [b] 179 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 180 | encoder_out = encoder_out[sort_ind] 181 | encoded_captions = encoded_captions[sort_ind] 182 | 183 | # Embedding 184 | # [b, max_len, embed_dim] 185 | embeddings = self.embedding(encoded_captions) 186 | 187 | # Initialize LSTM state 188 | # [b, decoder_dim] 189 | h, c = self.init_hidden_state(encoder_out) 190 | 191 | # We won't decode at the position, since we've finished generating as soon as we generate 192 | # So, decoding lengths are actual lengths - 1 193 | decode_lengths = (caption_lengths - 1).tolist() 194 | 195 | # Create tensors to hold word predicion scores and alphas 196 | # [b, max_len, vocab_size] 197 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(encoder_out.device) 198 | # [b, num_pixels, vocab_size] 199 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(encoder_out.device) 200 | 201 | # At each time-step, decode by 202 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 203 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 204 | for t in range(max(decode_lengths)): 205 | batch_size_t = sum([l > t for l in decode_lengths]) 206 | # [b, encoder_dim], [b, num_pixels] -> [batch_size_t, encoder_dim], [batch_size_t, num_pixels] 207 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 208 | h[:batch_size_t]) 209 | # [batch_size_t, encoder_dim] 210 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, 211 | attention_weighted_encoding = gate * attention_weighted_encoding 212 | # [batch_size_t, decoder_dim] 213 | h, c = self.decode_step( 214 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 215 | (h[:batch_size_t], c[:batch_size_t])) 216 | # [batch_size_t, vocab_size] 217 | preds = self.fc(self.dropout(h)) 218 | predictions[:batch_size_t, t, :] = preds 219 | alphas[:batch_size_t, t, :] = alpha 220 | 221 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind 222 | 223 | def sample(self, encoder_out, startseq_idx, endseq_idx=-1, max_len=40, return_alpha=False): 224 | """ 225 | Samples captions in batch for given image features (Greedy search). 
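At each of max_len steps the argmax word is fed back as the next input. Note that endseq_idx is currently unused, so decoding always runs for the full max_len steps instead of stopping at the end token.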
226 | :param encoder_out = [b, enc_image_size, enc_image_size, 2048] 227 | :return [b, max_len] 228 | """ 229 | enc_image_size = encoder_out.size(1) 230 | encoder_dim = encoder_out.size(3) 231 | batch_size = encoder_out.size(0) 232 | 233 | # decoder = self 234 | # [b, enc_image_size, enc_image_size, 2048] -> [b, num_pixels, 2048] 235 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) 236 | # [b, num_pixels, ] 237 | h, c = self.init_hidden_state(encoder_out) 238 | 239 | sampled_ids = [] # list of [b,] 240 | alphas = [] 241 | 242 | # [b, 1] 243 | prev_timestamp_words = torch.LongTensor([[startseq_idx]] * batch_size).to(encoder_out.device) 244 | for i in range(max_len): 245 | # [b, 1] -> [b, embed_dim] 246 | embeddings = self.embedding(prev_timestamp_words).squeeze(1) 247 | # ([b, encoder_dim], [b, num_pixels]) 248 | awe, alpha = self.attention(encoder_out, h) 249 | # [b, enc_image_size, enc_image_size] -> [b, 1, enc_image_size, enc_image_size] 250 | alpha = alpha.view(-1, enc_image_size, enc_image_size).unsqueeze(1) 251 | 252 | # [b, embed_dim] 253 | gate = self.sigmoid(self.f_beta(h)) # gating scalar 254 | # [b, embed_dim] 255 | awe = gate * awe 256 | 257 | # ([b, decoder_dim], ) 258 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) 259 | # [b, vocab_size] 260 | predicted_prob = self.fc(h) 261 | # [b] 262 | predicted = predicted_prob.argmax(1) 263 | 264 | sampled_ids.append(predicted) 265 | alphas.append(alpha) 266 | 267 | # [b] -> [b, 1] 268 | prev_timestamp_words = predicted.unsqueeze(1) 269 | # [b, max_len] 270 | sampled_ids = torch.stack(sampled_ids, 1) 271 | return (sampled_ids, torch.cat(alphas, 1)) if return_alpha else sampled_ids 272 | 273 | 274 | class Captioner(nn.Module): 275 | def __init__(self, encoded_image_size, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, 276 | dropout=0.5, **kwargs): 277 | super().__init__() 278 | self.encoder = Encoder(encoded_image_size=encoded_image_size) 279 | self.decoder = DecoderWithAttention(attention_dim, embed_dim, decoder_dim, vocab_size, 280 | encoder_dim, dropout) 281 | 282 | def forward(self, images, encoded_captions, caption_lengths): 283 | """ 284 | :param images: [b, 3, h, w] 285 | :param encoded_captions: [b, max_len] 286 | :param caption_lengths: [b,] 287 | :return: 288 | """ 289 | encoder_out = self.encoder(images) 290 | decoder_out = self.decoder(encoder_out, encoded_captions, caption_lengths.unsqueeze(1)) 291 | return decoder_out 292 | 293 | def sample(self, images, startseq_idx, endseq_idx=-1, max_len=40, return_alpha=False): 294 | encoder_out = self.encoder(images) 295 | return self.decoder.sample(encoder_out=encoder_out, startseq_idx=startseq_idx, max_len=max_len, 296 | return_alpha=return_alpha) 297 | -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "outputs": [], 7 | "source": [ 8 | "%load_ext autoreload\n", 9 | "%autoreload 2" 10 | ], 11 | "metadata": { 12 | "collapsed": false, 13 | "pycharm": { 14 | "name": "#%%\n" 15 | } 16 | } 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "outputs": [], 22 | "source": [ 23 | "import pickle\n", 24 | "import wandb\n", 25 | "from matplotlib import pyplot as plt\n", 26 | "from torch.nn.utils.rnn import pack_padded_sequence\n", 27 | "from torch.utils.data import DataLoader\n", 28 | "from torchvision import 
transforms\n", 29 | "\n", 30 | "from datasets.flickr8k import Flickr8kDataset\n", 31 | "from glove import embedding_matrix_creator\n", 32 | "from metrics import *\n", 33 | "from utils_torch import *" 34 | ], 35 | "metadata": { 36 | "collapsed": false, 37 | "pycharm": { 38 | "name": "#%%\n" 39 | } 40 | } 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "outputs": [], 46 | "source": [ 47 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 48 | "device" 49 | ], 50 | "metadata": { 51 | "collapsed": false, 52 | "pycharm": { 53 | "name": "#%%\n" 54 | } 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "outputs": [], 61 | "source": [ 62 | "DATASET_BASE_PATH = 'data/flickr8k/'" 63 | ], 64 | "metadata": { 65 | "collapsed": false, 66 | "pycharm": { 67 | "name": "#%%\n" 68 | } 69 | } 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "outputs": [], 75 | "source": [ 76 | "train_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', device=device,\n", 77 | " return_type='tensor',\n", 78 | " load_img_to_memory=False)\n", 79 | "vocab, word2idx, idx2word, max_len = vocab_set = train_set.get_vocab()\n", 80 | "val_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='val', vocab_set=vocab_set, device=device,\n", 81 | " return_type='corpus',\n", 82 | " load_img_to_memory=False)\n", 83 | "test_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='test', vocab_set=vocab_set, device=device,\n", 84 | " return_type='corpus',\n", 85 | " load_img_to_memory=False)\n", 86 | "train_eval_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='train', vocab_set=vocab_set, device=device,\n", 87 | " return_type='corpus',\n", 88 | " load_img_to_memory=False)\n", 89 | "with open('vocab_set.pkl', 'wb') as f:\n", 90 | " pickle.dump(train_set.get_vocab(), f)\n", 91 | "len(train_set), len(val_set), len(test_set)" 92 | ], 93 | "metadata": { 94 | "collapsed": false, 95 | "pycharm": { 96 | "name": "#%%\n" 97 | } 98 | } 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "outputs": [], 104 | "source": [ 105 | "vocab_size = len(vocab)\n", 106 | "vocab_size, max_len" 107 | ], 108 | "metadata": { 109 | "collapsed": false, 110 | "pycharm": { 111 | "name": "#%%\n" 112 | } 113 | } 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "outputs": [], 119 | "source": [ 120 | "MODEL = \"resnet101_attention\"\n", 121 | "EMBEDDING_DIM = 300\n", 122 | "EMBEDDING = f\"{EMBEDDING_DIM}\"\n", 123 | "ATTENTION_DIM = 256\n", 124 | "DECODER_SIZE = 256\n", 125 | "BATCH_SIZE = 128\n", 126 | "LR = 5e-4\n", 127 | "MODEL_NAME = f'saved_models/{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}'\n", 128 | "NUM_EPOCHS = 2\n", 129 | "SAVE_FREQ = 10\n", 130 | "LOG_INTERVAL = 25 * (256 // BATCH_SIZE)\n", 131 | "\n", 132 | "run = wandb.init(project='image-captioning',\n", 133 | " entity='datalab-buet',\n", 134 | " name=f\"{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}-{1}\",\n", 135 | " # tensorboard=True, sync_tensorboard=True,\n", 136 | " config={\"learning_rate\": LR,\n", 137 | " \"epochs\": NUM_EPOCHS,\n", 138 | " \"batch_size\": BATCH_SIZE,\n", 139 | " \"model\": MODEL,\n", 140 | " \"embedding\": EMBEDDING,\n", 141 | " \"embedding_dim\": EMBEDDING_DIM,\n", 142 | " \"attention_dim\": ATTENTION_DIM,\n", 143 | " \"decoder_dim\": DECODER_SIZE,\n", 144 | " },\n", 145 | " reinit=True)" 146 | ], 147 | "metadata": { 148 | "collapsed": false, 149 | "pycharm": { 150 | "name": 
"#%%\n" 151 | } 152 | } 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "outputs": [], 158 | "source": [ 159 | "embedding_matrix = embedding_matrix_creator(embedding_dim=EMBEDDING_DIM, word2idx=word2idx)\n", 160 | "embedding_matrix.shape\n" 161 | ], 162 | "metadata": { 163 | "collapsed": false, 164 | "pycharm": { 165 | "name": "#%%\n" 166 | } 167 | } 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "outputs": [], 173 | "source": [ 174 | "def train_model(train_loader, model, loss_fn, optimizer, vocab_size, acc_fn, desc=''):\n", 175 | " running_acc = 0.0\n", 176 | " running_loss = 0.0\n", 177 | " model.train()\n", 178 | " t = tqdm(iter(train_loader), desc=f'{desc}')\n", 179 | " for batch_idx, batch in enumerate(t):\n", 180 | " images, captions, lengths = batch\n", 181 | "\n", 182 | " optimizer.zero_grad()\n", 183 | "\n", 184 | " scores, caps_sorted, decode_lengths, alphas, sort_ind = model(images, captions, lengths)\n", 185 | "\n", 186 | " # Since decoding starts with , the targets are all words after , up to \n", 187 | " targets = caps_sorted[:, 1:]\n", 188 | "\n", 189 | " # Remove timesteps that we didn't decode at, or are pads\n", 190 | " # pack_padded_sequence is an easy trick to do this\n", 191 | " scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True)\n", 192 | " targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True)\n", 193 | "\n", 194 | " loss = loss_fn(scores, targets)\n", 195 | " loss.backward()\n", 196 | " optimizer.step()\n", 197 | "\n", 198 | " running_acc += (torch.argmax(scores, dim=1) == targets).sum().float().item() / targets.size(0)\n", 199 | " running_loss += loss.item()\n", 200 | " t.set_postfix({'loss': running_loss / (batch_idx + 1),\n", 201 | " 'acc': running_acc / (batch_idx + 1),\n", 202 | " }, refresh=True)\n", 203 | " if (batch_idx + 1) % LOG_INTERVAL == 0:\n", 204 | " print(f'{desc} {batch_idx + 1}/{len(train_loader)} '\n", 205 | " f'train_loss: {running_loss / (batch_idx + 1):.4f} '\n", 206 | " f'train_acc: {running_acc / (batch_idx + 1):.4f}')\n", 207 | " wandb.log({\n", 208 | " 'train_loss': running_loss / (batch_idx + 1),\n", 209 | " 'train_acc': running_acc / (batch_idx + 1),\n", 210 | " })\n", 211 | "\n", 212 | " return running_loss / len(train_loader)\n", 213 | "\n", 214 | "\n", 215 | "def evaluate_model(data_loader, model, loss_fn, vocab_size, bleu_score_fn, tensor_to_word_fn, desc=''):\n", 216 | " running_bleu = [0.0] * 5\n", 217 | " model.eval()\n", 218 | " t = tqdm(iter(data_loader), desc=f'{desc}')\n", 219 | " for batch_idx, batch in enumerate(t):\n", 220 | " images, captions, lengths = batch\n", 221 | " outputs = tensor_to_word_fn(model.sample(images, startseq_idx=word2idx['']).cpu().numpy())\n", 222 | "\n", 223 | " for i in (1, 2, 3, 4):\n", 224 | " running_bleu[i] += bleu_score_fn(reference_corpus=captions, candidate_corpus=outputs, n=i)\n", 225 | " t.set_postfix({\n", 226 | " 'bleu1': running_bleu[1] / (batch_idx + 1),\n", 227 | " 'bleu4': running_bleu[4] / (batch_idx + 1),\n", 228 | " }, refresh=True)\n", 229 | " for i in (1, 2, 3, 4):\n", 230 | " running_bleu[i] /= len(data_loader)\n", 231 | " return running_bleu\n" 232 | ], 233 | "metadata": { 234 | "collapsed": false, 235 | "pycharm": { 236 | "name": "#%%\n" 237 | } 238 | } 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "outputs": [], 244 | "source": [ 245 | "from models.torch.resnet101_attention import Captioner\n", 246 | "\n", 247 | "final_model = 
Captioner(encoded_image_size=14, encoder_dim=2048,\n", 248 | " attention_dim=ATTENTION_DIM, embed_dim=EMBEDDING_DIM, decoder_dim=DECODER_SIZE,\n", 249 | " vocab_size=vocab_size,\n", 250 | " embedding_matrix=embedding_matrix, train_embd=True).to(device)\n", 251 | "\n", 252 | "loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_set.pad_value).to(device)\n", 253 | "acc_fn = accuracy_fn(ignore_value=train_set.pad_value)\n", 254 | "sentence_bleu_score_fn = bleu_score_fn(4, 'sentence')\n", 255 | "corpus_bleu_score_fn = bleu_score_fn(4, 'corpus')\n", 256 | "tensor_to_word_fn = words_from_tensors_fn(idx2word=idx2word)\n", 257 | "\n", 258 | "params = final_model.parameters()\n", 259 | "\n", 260 | "optimizer = torch.optim.Adam(params=params, lr=LR)\n", 261 | "\n", 262 | "wandb.watch(final_model, log='all', log_freq=50)\n", 263 | "# wandb.watch(final_model.encoder, log='all', log_freq=50)\n", 264 | "wandb.watch(final_model.decoder, log='all', log_freq=50)\n", 265 | "wandb.save('vocab_set.pkl')\n", 266 | "\n", 267 | "sync_files_wandb(['main.ipynb', 'train_attntn.py', 'models/torch/resnet101_attention.py'])" 268 | ], 269 | "metadata": { 270 | "collapsed": false, 271 | "pycharm": { 272 | "name": "#%%\n" 273 | } 274 | } 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "outputs": [], 280 | "source": [ 281 | "train_transformations = transforms.Compose([\n", 282 | " transforms.Resize(256), # smaller edge of image resized to 256\n", 283 | " transforms.RandomCrop(256), # get 256x256 crop from random location\n", 284 | " transforms.RandomHorizontalFlip(p=0.5),\n", 285 | " transforms.ToTensor(), # convert the PIL Image to a tensor\n", 286 | " transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 287 | " (0.229, 0.224, 0.225))\n", 288 | "])\n", 289 | "eval_transformations = transforms.Compose([\n", 290 | " transforms.Resize(256), # smaller edge of image resized to 256\n", 291 | " transforms.CenterCrop(256), # get 256x256 crop from random location\n", 292 | " transforms.ToTensor(), # convert the PIL Image to a tensor\n", 293 | " transforms.Normalize((0.485, 0.456, 0.406), # normalize image for pre-trained model\n", 294 | " (0.229, 0.224, 0.225))\n", 295 | "])\n", 296 | "\n", 297 | "train_set.transformations = train_transformations\n", 298 | "val_set.transformations = eval_transformations\n", 299 | "test_set.transformations = eval_transformations\n", 300 | "train_eval_set.transformations = eval_transformations" 301 | ], 302 | "metadata": { 303 | "collapsed": false, 304 | "pycharm": { 305 | "name": "#%%\n" 306 | } 307 | } 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "outputs": [], 313 | "source": [ 314 | "eval_collate_fn = lambda batch: (torch.stack([x[0] for x in batch]), [x[1] for x in batch], [x[2] for x in batch])\n", 315 | "train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, sampler=None, pin_memory=False)\n", 316 | "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False,\n", 317 | " collate_fn=eval_collate_fn)\n", 318 | "test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False,\n", 319 | " collate_fn=eval_collate_fn)\n", 320 | "train_eval_loader = DataLoader(train_eval_set, batch_size=BATCH_SIZE, shuffle=False, sampler=None, pin_memory=False,\n", 321 | " collate_fn=eval_collate_fn)" 322 | ], 323 | "metadata": { 324 | "collapsed": false, 325 | "pycharm": { 326 | "name": "#%%\n" 327 | } 328 | } 
329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "outputs": [], 334 | "source": [ 335 | "train_loss_min = 100\n", 336 | "val_bleu4_max = 0.0\n", 337 | "for epoch in range(NUM_EPOCHS):\n", 338 | " train_loss = train_model(desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', model=final_model,\n", 339 | " optimizer=optimizer, loss_fn=loss_fn, acc_fn=acc_fn,\n", 340 | " train_loader=train_loader, vocab_size=vocab_size)\n", 341 | " with torch.no_grad():\n", 342 | " train_bleu = evaluate_model(desc=f'\\tTrain Bleu Score: ', model=final_model,\n", 343 | " loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn,\n", 344 | " tensor_to_word_fn=tensor_to_word_fn,\n", 345 | " data_loader=train_eval_loader, vocab_size=vocab_size)\n", 346 | " val_bleu = evaluate_model(desc=f'Validation Bleu Score: ', model=final_model,\n", 347 | " loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn,\n", 348 | " tensor_to_word_fn=tensor_to_word_fn,\n", 349 | " data_loader=val_loader, vocab_size=vocab_size)\n", 350 | " print(f'Epoch {epoch + 1}/{NUM_EPOCHS}',\n", 351 | " ''.join([f'train_bleu{i}: {train_bleu[i]:.4f} ' for i in (1, 4)]),\n", 352 | " ''.join([f'val_bleu{i}: {val_bleu[i]:.4f} ' for i in (1, 4)]),\n", 353 | " )\n", 354 | " wandb.log({f'val_bleu{i}': val_bleu[i] for i in (1, 2, 3, 4)})\n", 355 | " wandb.log({'val_bleu': val_bleu[4]})\n", 356 | " state = {\n", 357 | " 'epoch': epoch + 1,\n", 358 | " 'state_dict': final_model.state_dict(),\n", 359 | " 'optimizer': optimizer.state_dict(),\n", 360 | " 'train_loss_latest': train_loss,\n", 361 | " 'val_bleu4_latest': val_bleu[4],\n", 362 | " 'train_loss_min': min(train_loss, train_loss_min),\n", 363 | " 'val_bleu4_max': max(val_bleu[4], val_bleu4_max),\n", 364 | " 'train_bleus': train_bleu,\n", 365 | " 'val_bleus': val_bleu,\n", 366 | " }\n", 367 | " torch.save(state, f'{MODEL_NAME}_latest.pt')\n", 368 | " wandb.save(f'{MODEL_NAME}_latest.pt')\n", 369 | " if train_loss < train_loss_min:\n", 370 | " train_loss_min = train_loss\n", 371 | " torch.save(state, f'{MODEL_NAME}''_best_train.pt')\n", 372 | " wandb.save(f'{MODEL_NAME}''_best_train.pt')\n", 373 | " if val_bleu[4] > val_bleu4_max:\n", 374 | " val_bleu4_max = val_bleu[4]\n", 375 | " torch.save(state, f'{MODEL_NAME}''_best_val.pt')\n", 376 | " wandb.save(f'{MODEL_NAME}''_best_val.pt')\n", 377 | "\n", 378 | "torch.save(state, f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt')\n", 379 | "wandb.save(f'{MODEL_NAME}_ep{NUM_EPOCHS:02d}_weights.pt')\n", 380 | "final_model.eval()" 381 | ], 382 | "metadata": { 383 | "collapsed": false, 384 | "pycharm": { 385 | "name": "#%%\n" 386 | } 387 | } 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "outputs": [], 393 | "source": [ 394 | "model = final_model" 395 | ], 396 | "metadata": { 397 | "collapsed": false, 398 | "pycharm": { 399 | "name": "#%%\n" 400 | } 401 | } 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "outputs": [], 407 | "source": [ 408 | "t_i = 1003\n", 409 | "dset = train_set\n", 410 | "im, cp, _ = dset[t_i]\n", 411 | "print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]]))\n", 412 | "print(dset.get_image_captions(t_i)[1])\n", 413 | "\n", 414 | "plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation=\"bicubic\")" 415 | ], 416 | "metadata": { 417 | "collapsed": false, 418 | "pycharm": { 419 | "name": "#%%\n" 420 | } 421 | } 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "outputs": [], 427 | 
"source": [ 428 | "t_i = 500\n", 429 | "dset = val_set\n", 430 | "im, cp, _ = dset[t_i]\n", 431 | "print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]]))\n", 432 | "print(cp)\n", 433 | "\n", 434 | "plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation=\"bicubic\")" 435 | ], 436 | "metadata": { 437 | "collapsed": false, 438 | "pycharm": { 439 | "name": "#%%\n" 440 | } 441 | } 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "outputs": [], 447 | "source": [ 448 | "t_i = 500\n", 449 | "dset = test_set\n", 450 | "im, cp, _ = dset[t_i]\n", 451 | "print(''.join([idx2word[idx.item()] + ' ' for idx in model.sample(im.unsqueeze(0), word2idx[''])[0]]))\n", 452 | "print(cp)\n", 453 | "\n", 454 | "plt.imshow(dset[t_i][0].detach().cpu().permute(1, 2, 0), interpolation=\"bicubic\")" 455 | ], 456 | "metadata": { 457 | "collapsed": false, 458 | "pycharm": { 459 | "name": "#%%\n" 460 | } 461 | } 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "outputs": [], 467 | "source": [ 468 | "with torch.no_grad():\n", 469 | " model.eval()\n", 470 | " train_bleu = evaluate_model(desc=f'Train: ', model=final_model,\n", 471 | " loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn,\n", 472 | " tensor_to_word_fn=tensor_to_word_fn,\n", 473 | " data_loader=train_eval_loader, vocab_size=vocab_size)\n", 474 | " val_bleu = evaluate_model(desc=f'Val: ', model=final_model,\n", 475 | " loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn,\n", 476 | " tensor_to_word_fn=tensor_to_word_fn,\n", 477 | " data_loader=val_loader, vocab_size=vocab_size)\n", 478 | " test_bleu = evaluate_model(desc=f'Test: ', model=final_model,\n", 479 | " loss_fn=loss_fn, bleu_score_fn=corpus_bleu_score_fn,\n", 480 | " tensor_to_word_fn=tensor_to_word_fn,\n", 481 | " data_loader=test_loader, vocab_size=vocab_size)\n", 482 | " for setname, result in zip(('train', 'val', 'test'), (train_bleu, val_bleu, test_bleu)):\n", 483 | " print(setname, end=' ')\n", 484 | " for ngram in (1, 2, 3, 4):\n", 485 | " print(f'Bleu-{ngram}: {result[ngram]}', end=' ')\n", 486 | " wandb.run.summary[f\"{setname}_bleu{ngram}\"] = result[ngram]\n", 487 | " print()" 488 | ], 489 | "metadata": { 490 | "collapsed": false, 491 | "pycharm": { 492 | "name": "#%%\n" 493 | } 494 | } 495 | } 496 | ], 497 | "metadata": { 498 | "kernelspec": { 499 | "display_name": "Python 3", 500 | "language": "python", 501 | "name": "python3" 502 | }, 503 | "language_info": { 504 | "codemirror_mode": { 505 | "name": "ipython", 506 | "version": 3 507 | }, 508 | "file_extension": ".py", 509 | "mimetype": "text/x-python", 510 | "name": "python", 511 | "nbconvert_exporter": "python", 512 | "pygments_lexer": "ipython3", 513 | "version": "3.8.3" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 1 518 | } --------------------------------------------------------------------------------