├── Alexnet.py ├── DataLoader.py ├── Decoder.py ├── DenseNet.py ├── Inception.py ├── Preprocess.py ├── README.md ├── Resnet.py ├── Resnet152.py ├── SqueezeNet.py ├── Validation.py ├── Vgg.py ├── Vocabulary.py ├── check ├── 1.jpg ├── 10.jpg ├── 11.jpg ├── 12.jpg ├── 13.jpg ├── 14.jpg ├── 15.jpg ├── 16.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg └── 9.jpg ├── test.py ├── train.py ├── train_pic.png ├── train_valid_loss.png └── utils.py /Alexnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class Alexnet(nn.Module): 12 | def __init__(self, embedding_dim=512): 13 | super(Alexnet, self).__init__() 14 | self.alexnet = models.alexnet(pretrained=True) 15 | in_features = self.alexnet.classifier[6].in_features 16 | self.linear = nn.Linear(in_features, embedding_dim) 17 | self.alexnet.classifier[6] = self.linear 18 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 19 | self.init_weights() 20 | 21 | def init_weights(self): 22 | self.linear.weight.data.normal_(0.0, 0.02) 23 | self.linear.bias.data.fill_(0) 24 | 25 | def forward(self, images): 26 | embed = self.alexnet(images) 27 | # embed = Variable(embed.data) 28 | # embed = embed.view(embed.size(0), -1) 29 | # embed = self.linear(embed) 30 | # embed = self.batch_norm(embed) 31 | return embed 32 | -------------------------------------------------------------------------------- /DataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import nltk 4 | import time 5 | import torch 6 | from PIL import Image 7 | 8 | class DataLoader(): 9 | def __init__(self, dir_path, vocab, transform): 10 | self.images = None 11 | self.captions_dict = None 12 | # self.data = None 13 | self.vocab = vocab 14 | self.transform = transform 15 | self.load_captions(dir_path) 16 | self.load_images(dir_path) 17 | 18 | def load_captions(self, captions_dir): 19 | caption_file = os.path.join(captions_dir, 'captions.txt') 20 | captions_dict = {} 21 | with open(caption_file) as f: 22 | for line in f: 23 | cur_dict = json.loads(line) 24 | for k, v in cur_dict.items(): 25 | captions_dict[k] = v 26 | self.captions_dict = captions_dict 27 | 28 | def load_images(self, images_dir): 29 | files = os.listdir(images_dir) 30 | images = {} 31 | for cur_file in files: 32 | ext = cur_file.split('.')[1] 33 | if ext == 'jpg': 34 | images[cur_file] = self.transform(Image.open(os.path.join(images_dir, cur_file))) 35 | self.images = images 36 | 37 | def caption2ids(self, caption): 38 | vocab = self.vocab 39 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 40 | vec = [] 41 | vec.append(vocab.get_id('')) 42 | vec.extend([vocab.get_id(word) for word in tokens]) 43 | vec.append(vocab.get_id('')) 44 | return vec 45 | 46 | def gen_data(self): 47 | images = [] 48 | captions = [] 49 | for image_id, cur_captions in self.captions_dict.items(): 50 | num_captions = len(cur_captions) 51 | images.extend([image_id] * num_captions) 52 | for caption in cur_captions: 53 | captions.append(self.caption2ids(caption)) 54 | # self.data = images, captions 55 | data = images, captions 56 | return data 57 | 58 | def get_image(self, image_id): 59 | return self.images[image_id] 60 | 61 | def shuffle_data(data, seed=0): 62 | 
images, captions = data 63 | shuffled_images = [] 64 | shuffled_captions = [] 65 | num_images = len(images) 66 | torch.manual_seed(seed) 67 | perm = list(torch.randperm(num_images)) 68 | for i in range(num_images): 69 | shuffled_images.append(images[perm[i]]) 70 | shuffled_captions.append(captions[perm[i]]) 71 | return shuffled_images, shuffled_captions 72 | 73 | # def make_minibatches(self, data, minibatch_size=1, seed=0): 74 | 75 | # def get_batch(self,): -------------------------------------------------------------------------------- /Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class RNN(nn.Module): 7 | def __init__(self, embedding_dim, hidden_dim, vocab_size): 8 | super(RNN, self).__init__() 9 | self.hidden_dim = hidden_dim 10 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) 11 | self.lstm = nn.LSTM(embedding_dim, hidden_dim) 12 | self.linear = nn.Linear(hidden_dim, vocab_size) 13 | self.init_weights() 14 | 15 | def init_weights(self): 16 | self.word_embeddings.weight.data.uniform_(-0.1, 0.1) 17 | self.linear.weight.data.uniform_(-0.1, 0.1) 18 | self.linear.bias.data.fill_(0) 19 | 20 | def forward(self, features, caption): 21 | seq_length = len(caption) + 1 22 | embeds = self.word_embeddings(caption) 23 | embeds = torch.cat((features, embeds), 0) 24 | lstm_out, _ = self.lstm(embeds.unsqueeze(1)) 25 | out = self.linear(lstm_out.view(seq_length, -1)) 26 | return out 27 | 28 | def greedy(self, cnn_out, seq_len = 20): 29 | ip = cnn_out 30 | hidden = None 31 | ids_list = [] 32 | for t in range(seq_len): 33 | lstm_out, hidden = self.lstm(ip.unsqueeze(1), hidden) 34 | # generating single word at a time 35 | linear_out = self.linear(lstm_out.squeeze(1)) 36 | word_caption = linear_out.max(dim=1)[1] 37 | ids_list.append(word_caption) 38 | ip = self.word_embeddings(word_caption) 39 | return ids_list -------------------------------------------------------------------------------- /DenseNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class DenseNet(nn.Module): 12 | def __init__(self, embedding_dim=300): 13 | super(DenseNet, self).__init__() 14 | self.dense = models.densenet121(pretrained=True) 15 | self.linear = nn.Linear(self.dense.classifier.in_features, embedding_dim) 16 | self.dense.classifier = self.linear 17 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 18 | self.init_weights() 19 | 20 | def init_weights(self): 21 | self.linear.weight.data.normal_(0.0, 0.02) 22 | self.linear.bias.data.fill_(0) 23 | 24 | def forward(self, images): 25 | embed = self.dense(images) 26 | return embed 27 | -------------------------------------------------------------------------------- /Inception.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class Inception(nn.Module): 12 | def __init__(self, embedding_dim=300): 13 | super(Inception, self).__init__() 14 | self.inception = 
models.inception_v3(pretrained=True) 15 | in_features = self.inception.fc.in_features 16 | self.linear = nn.Linear(in_features, embedding_dim) 17 | self.inception.fc = self.linear 18 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 19 | self.init_weights() 20 | 21 | def init_weights(self): 22 | self.linear.weight.data.normal_(0.0, 0.02) 23 | self.linear.bias.data.fill_(0) 24 | 25 | def forward(self, images): 26 | embed = self.inception(images) 27 | return embed 28 | -------------------------------------------------------------------------------- /Preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | from PIL import Image 6 | from shutil import copyfile 7 | 8 | 9 | def read_captions(filepath): 10 | captions_dict = {} 11 | with open(filepath) as f: 12 | for line in f: 13 | line_split = line.split(sep='\t', maxsplit=1) 14 | caption = line_split[1][:-1] 15 | id_image = line_split[0].split(sep='#')[0] 16 | if id_image not in captions_dict: 17 | captions_dict[id_image] = [caption] 18 | else: 19 | captions_dict[id_image].append(caption) 20 | return captions_dict 21 | 22 | def get_ids(filepath): 23 | ids = [] 24 | with open(filepath) as f: 25 | for line in f: 26 | ids.append(line[:-1]) 27 | return ids 28 | 29 | def copyfiles(dir_output, dir_input, ids): 30 | if not os.path.exists(dir_output): 31 | os.makedirs(dir_output) 32 | for cur_id in ids: 33 | path_input = os.path.join(dir_input, cur_id) 34 | path_output = os.path.join(dir_output, cur_id) 35 | copyfile(path_input, path_output) 36 | 37 | def write_captions(dir_output, ids, captions_dict): 38 | output_path = os.path.join(dir_output, 'captions.txt') 39 | output = [] 40 | for cur_id in ids: 41 | cur_dict = {cur_id: captions_dict[cur_id]} 42 | output.append(json.dumps(cur_dict)) 43 | 44 | with open(output_path, mode='w') as f: 45 | f.write('\n'.join(output)) 46 | 47 | def segregate(dir_images, filepath_token, captions_path_input): 48 | dir_output = {'train': 'train', 49 | 'dev' : 'dev', 50 | 'test' : 'test' 51 | } 52 | 53 | # id [caption1, caption2, ..] 
54 | captions_dict = read_captions(filepath_token) 55 | 56 | # train, dev, test images mixture 57 | images = os.listdir(dir_images) 58 | 59 | # read ids 60 | ids_train = get_ids(captions_path_input['train']) 61 | ids_dev = get_ids(captions_path_input['dev']) 62 | ids_test = get_ids(captions_path_input['test']) 63 | 64 | # copy images to respective dirs 65 | copyfiles(dir_output['train'], dir_images, ids_train) 66 | copyfiles(dir_output['dev'], dir_images, ids_dev) 67 | copyfiles(dir_output['test'], dir_images, ids_test) 68 | 69 | # write id 70 | write_captions(dir_output['train'], ids_train, captions_dict) 71 | write_captions(dir_output['dev'], ids_dev, captions_dict) 72 | write_captions(dir_output['test'], ids_test, captions_dict) 73 | 74 | def load_captions(captions_dir): 75 | caption_file = os.path.join(captions_dir, 'captions.txt') 76 | captions_dict = {} 77 | with open(caption_file) as f: 78 | for line in f: 79 | cur_dict = json.loads(line) 80 | for k, v in cur_dict.items(): 81 | captions_dict[k] = v 82 | return captions_dict 83 | 84 | if __name__ == '__main__': 85 | dir_images = 'images' 86 | dir_text = 'text' 87 | filename_token = 'Flickr8k.token.txt' 88 | filename_train = 'Flickr_8k.trainImages.txt' 89 | filename_dev = 'Flickr_8k.devImages.txt' 90 | filename_test = 'Flickr_8k.testImages.txt' 91 | filepath_token = os.path.join(dir_text, filename_token) 92 | captions_path_input = {'train': os.path.join(dir_text, filename_train), 93 | 'dev': os.path.join(dir_text, filename_dev), 94 | 'test': os.path.join(dir_text, filename_test) 95 | } 96 | 97 | tic = time.time() 98 | segregate(dir_images, filepath_token, captions_path_input) 99 | toc = time.time() 100 | print('time: %.2f mins' %((toc-tic)/60)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-image-captioning 2 | 3 | ## Abstract 4 | In this project, I have implemented an end-to-end deep learning model for image captioning. The architecture consists of an encoder and a decoder network: the encoder is one of several pre-trained CNN architectures and produces an image embedding, while the decoder is an LSTM network whose word embeddings are uninitialized (learned from scratch). 5 | 6 | ## Requirements 7 | 1. python3.6 8 | 2. [pytorch](http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl) 9 | 3. torchvision 10 | 4. pillow 11 | 5. nltk 12 | 6. pickle (part of the Python standard library) 13 | 7. cuda version 9.0/9.1 14 | 8. cuDNN >=7.0 15 | 16 | ```bash 17 | pip install http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl torchvision pillow nltk 18 | ``` 19 | 20 | ## Dataset 21 | [Flickr8K](http://www.jair.org/papers/paper3994.html)
22 | #train : 6000
23 | #dev : 1000
24 | #test : 1000
25 | 26 | ## Instructions to run the code 27 | 28 | ### 1. Pre-Processing 29 | ```bash 30 | python3 Preprocess.py 31 | ``` 32 | 33 | ### 2. Train 34 | ```bash 35 | python3 train.py -model <model> -dir <train_dir> -save_iter <save_iter> -learning_rate <learning_rate> -epoch <epoch> -gpu_device <gpu_id> -hidden_dim <hidden_dim> -embedding_dim <embedding_dim> 36 | ``` 37 | ##### args: 38 | 39 | `-model` : one of the CNN encoder architectures - alexnet, resnet18, resnet152, vgg, inception, squeeze, dense
40 | `-dir` : path to the training directory, default = train
41 | `-save_iter` : save a model checkpoint every `save_iter` epochs, default = 10
42 | `-learning_rate` : learning rate for the Adam optimizer, default = 1e-5
43 | `-epoch` : epoch of the saved checkpoint to resume training from
44 | `-gpu_device` : GPU device index to use when multiple GPUs are installed
45 | `-hidden_dim` : size of the LSTM hidden state, default = 512
46 | `-embedding_dim` : dimensionality of the image embedding produced by the CNN encoder, default = 512
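For example, a training run with the ResNet-18 encoder might look like the following (values are illustrative; note that the `-model` name is also used as the output directory for `vocab.pkl` and the checkpoints, so it must exist before training starts):

```bash
mkdir -p resnet18
python3 train.py -model resnet18 -dir train -save_iter 10 -learning_rate 1e-5 -gpu_device 0 -hidden_dim 512 -embedding_dim 512
```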
47 | 48 | 49 | ### 3. Test 50 | ```bash 51 | python3 test.py -model <model> -i <image_path> -epoch <checkpoint_epoch> -gpu_device <gpu_id> 52 | ``` 53 | ##### args: 54 | 55 | `-i` : path of the image to generate a caption for
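For example, to caption one of the sample images in `check/` with the checkpoint saved at epoch 230 (the model directory and epoch are illustrative and must match an actual training run):

```bash
python3 test.py -model resnet18 -i check/1.jpg -epoch 230 -gpu_device 0
```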
56 | 57 | [Download trained model](https://drive.google.com/open?id=1xF8dfIDsz57ZrX7bKApOakyjm1GoelJm): Trained for ~24 hours (230 iterations) on a single NVIDIA GTX 1080 (8 GB) GPU. 58 | 59 | ## Results 60 | ### Check whether the model is able to learn by overfitting on a small dataset. 61 | ![Screen Shot](train_pic.png) 62 | Since the training error keeps decreasing, the model appears to be learning as expected. 63 | 64 | ### Train vs validation loss 65 | [![Screen Shot](train_valid_loss.png)](https://docs.google.com/spreadsheets/d/1VBz6r91D6P_9rybGmbaVNm-P1PM-xcnIWQ0wxQOWFzM/edit?usp=sharing) 66 | 67 | 68 | Image |Original Captions|Predicted Captions 69 | ----|----|---- 70 | ![Screen Shot](check/1.jpg) | 1. a beagle and a golden retriever wrestling in the grass
2. Two dogs are wrestling in the grass
3. Two puppies are playing in the green grass
4. two puppies playing around in the grass
5. Two puppies play in the grass | 50. a brown and white dog is running through a grassy field .
100. a brown dog in a field .
150. a brown dog is running through a grassy field .
200. **a brown and white dog is laying with its mouth open and people up in the grass .**
230. a brown dog running through grass .
71 | ![Screen Shot](check/2.jpg) | 1. a brightly decorated bicycle with cart with people walking around in the background
2. A street vending machine is parked while people walk by
3. A street vendor on the corner of a busy intersection
4. People on the city street walk past a puppet theater
5. People walk around a mobile puppet theater in a big city . | 50. a man with a green shirt is standing in front of a <unk> at a <unk> .
100. a group of people standing outside a building .
150. a group of people standing around a outside of building .
200. **a group of people are standing around a city street .**
230. a man in a green shirt <unk> a <unk> at a carnival .
72 | ![Screen Shot](check/3.jpg) | 1. A boat is on the water , with mountains in the background .
2. A boat on the water .
3. A lone boat sitting in the water .
4. A white boat on glassy water with mountains in the background .
5. This is a boat on the water with mountains in the background .
| 0. a man is on a <unk> .
30. a person on a surfboard is standing on a beach .
130. a person is standing on a mountain and overlooking the ocean .
230. **a person is standing on a rock and overlooking the ocean .**
73 | ![Screen Shot](check/4.jpg) | 1. A woman climbs up a cliff.
2. A woman rock climber scales a cliff far above pastures .
3. A woman rock-climbing on a cliff .
4. A woman rock-climbs in a rural area .
5. Woman climbing a cliff in a rural area
| 0. a man in a red and a <unk> is on a <unk>
30. **a man in a red shirt is climbing a rock .**
130. a man in a red shirt is rock climbing .
230. a man in a red shirt and green pants climbs a rock cliff .
74 | ![Screen Shot](check/5.jpg) | 1. Hikers cross a bridge over a fast moving stream and rocky scenery .
2. People crossing a long bridge over a canyon with a river .
3. People walk across a rope bridge over a rocky stream .
4. Some hikers are crossing a wood and wire bridge over a river .
5. Three people are looking across a rope and wood bridge over a river .
| 0. a man in a red of a <unk> .
30. a person in a blue jacket is jumping in the snow .
130. a person in the snow .
230. **a person on a snowboard in the air**
75 | ![Screen Shot](check/6.jpg) | 1. Two men in ethnic dress standing in a barren landscape .
2. Two men in keffiyahs stand next to car in the desert and wave at a passing vehicle .
3. Two men in robes wave at an approaching jeep traveling through the sand .
4. Two men in traditional Arab dress standing near a car wave at an SUV in the desert .
5. Two people with head coverings stand in a sandy field .
| 0. a man in a red and a white and a dog is on a <unk> .
30. a man and a woman are standing on a bench in a park .
130. **a man and a woman dressed in <unk> are walking along a dirt road .**
230. a man holding a camera and a woman is walking with her hands on a jumping away from a
76 | ![Screen Shot](check/10.jpg) | 1. A man mountain climbing up an icy mountain .
2. An climber is ascending an ice covered rock face .
3. A person in orange climbs a sheer cliff face covered in snow and ice .
4. Person in a yellow jacket is climbing up snow covered rocks .
5. There is a climber scaling a snowy mountainside .
| 0. a dog is in the water .
30. a man in a yellow shirt is standing in front of a waterfall .
130. a lone climber walks along a rocky path with mountains in the background .
230. **a man climbing a huge mountain .**
77 | ![Screen Shot](check/11.jpg) | 1. A boy with a stick kneeling in front of a goalie net
2. A child in a red jacket playing street hockey guarding a goal .
3. A young kid playing the goalie in a hockey rink .
4. A young male kneeling in front of a hockey goal with a hockey stick in his right hand .
5. Hockey goalie boy in red jacket crouches by goal , with stick .
| 0. a man in a red shirt and a red and a white dog is on a <unk> .
30. a man and a woman are sitting on a red bench .
130. a man in a red shirt and a white helmet is sitting on a red leash .
230. **a man in a red shirt and blue jeans is sitting on a green wall .**
78 | ![Screen Shot](check/12.jpg) | 1. A group of eight people are gathered around a table at night .
2. A group of people gathered around in the dark .
3. A group of people sit around a table outside on a porch at night .
4. A group of people sit outdoors together at night .
5. A group of people sitting at a table in a darkened room .
| 0. a man in a <unk> .
30. a man is sitting on a bench in front of a crowd .
130. a man in a <unk> room with his closeup of two women .
230. **a group of people are standing in front of a large window .**
79 | ## References 80 | * [Show and Tell: A Neural Image Caption Generator](https://arxiv.org/abs/1411.4555) 81 | 82 | * [Deep Visual-Semantic Alignments for Generating Image Descriptions](https://cs.stanford.edu/people/karpathy/cvpr2015.pdf) 83 | 84 | * [Flickr 8K Dataset](http://www.jair.org/papers/paper3994.html) 85 | 86 | * [yunjey](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning) 87 | -------------------------------------------------------------------------------- /Resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class Resnet(nn.Module): 12 | def __init__(self, embedding_dim=256): 13 | super(Resnet, self).__init__() 14 | self.resnet18 = models.resnet18(pretrained=True) 15 | in_features = self.resnet18.fc.in_features 16 | modules = list(self.resnet18.children())[:-1] 17 | self.resnet18 = nn.Sequential(*modules) 18 | self.linear = nn.Linear(in_features, embedding_dim) 19 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 20 | self.init_weights() 21 | 22 | def init_weights(self): 23 | self.linear.weight.data.normal_(0.0, 0.02) 24 | self.linear.bias.data.fill_(0) 25 | 26 | def forward(self, images): 27 | embed = self.resnet18(images) 28 | embed = Variable(embed.data) 29 | embed = embed.view(embed.size(0), -1) 30 | embed = self.linear(embed) 31 | # embed = self.batch_norm(embed) 32 | return embed 33 | -------------------------------------------------------------------------------- /Resnet152.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class Resnet152(nn.Module): 12 | def __init__(self, embedding_dim=512): 13 | super(Resnet152, self).__init__() 14 | self.resnet152 = models.resnet152(pretrained=True) 15 | self.linear = nn.Linear(self.resnet152.fc.in_features, embedding_dim) 16 | self.resnet152.fc = self.linear 17 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 18 | self.init_weights() 19 | 20 | def init_weights(self): 21 | self.linear.weight.data.normal_(0.0, 0.02) 22 | self.linear.bias.data.fill_(0) 23 | 24 | def forward(self, images): 25 | embed = self.resnet152(images) 26 | # embed = Variable(embed.data) 27 | # embed = embed.view(embed.size(0), -1) 28 | # embed = self.linear(embed) 29 | # embed = self.batch_norm(embed) 30 | return embed 31 | -------------------------------------------------------------------------------- /SqueezeNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class SqueezeNet(nn.Module): 12 | def __init__(self, embedding_dim=300): 13 | super(SqueezeNet, self).__init__() 14 | self.squeeze = models.squeezenet1_1(pretrained=True) 15 | self.squeeze.num_classes = embedding_dim 16 | final_conv = nn.Conv2d(512, self.squeeze.num_classes, 
kernel_size=1) 17 | self.squeeze.classifier[1] = final_conv 18 | self.linear = self.squeeze.classifier[1] 19 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 20 | self.init_weights() 21 | 22 | def init_weights(self): 23 | self.linear.weight.data.normal_(0.0, 0.02) 24 | self.linear.bias.data.fill_(0) 25 | 26 | def forward(self, images): 27 | embed = self.squeeze(images) 28 | return embed 29 | -------------------------------------------------------------------------------- /Validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import pickle 5 | import argparse 6 | from PIL import Image 7 | import torch.nn as nn 8 | from utils import get_cnn 9 | from Decoder import RNN 10 | from Vocabulary import Vocabulary 11 | from torch.autograd import Variable 12 | from torchvision import transforms 13 | from DataLoader import DataLoader, shuffle_data 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('-dir', type = str, default = 'dev') 18 | parser.add_argument('-model') 19 | parser.add_argument('-epoch', type=int) 20 | parser.add_argument('-gpu_device', type=int) 21 | args = parser.parse_args() 22 | 23 | with open(os.path.join(args.model, 'vocab.pkl'), 'rb') as f: 24 | vocab = pickle.load(f) 25 | 26 | transform = transforms.Compose([transforms.Resize((224, 224)), 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.5, 0.5, 0.5), 29 | (0.5, 0.5, 0.5)) 30 | ]) 31 | dataloader = DataLoader(args.dir, vocab, transform) 32 | data = dataloader.gen_data() 33 | print(args.dir + ' loaded') 34 | 35 | embedding_dim = 512 36 | vocab_size = vocab.index 37 | hidden_dim = 512 38 | model_name = args.model 39 | criterion = nn.CrossEntropyLoss() 40 | cnn = get_cnn(architecture = model_name, embedding_dim = embedding_dim) 41 | lstm = RNN(embedding_dim = embedding_dim, hidden_dim = hidden_dim, 42 | vocab_size = vocab_size) 43 | 44 | if torch.cuda.is_available(): 45 | with torch.cuda.device(args.gpu_device): 46 | cnn.cuda() 47 | lstm.cuda() 48 | 49 | for iteration in range(0, 240, 10): 50 | cnn_file = 'iter_' + str(iteration) + '_cnn.pkl' 51 | lstm_file = 'iter_' + str(iteration) + '_lstm.pkl' 52 | cnn.load_state_dict(torch.load(os.path.join(model_name, cnn_file))) 53 | lstm.load_state_dict(torch.load(os.path.join(model_name, lstm_file))) 54 | 55 | cnn.eval() 56 | lstm.eval() 57 | 58 | images, captions = data 59 | num_captions = len(captions) 60 | loss_list = [] 61 | # tic = time.time() 62 | with torch.no_grad(): 63 | for i in range(num_captions): 64 | image_id = images[i] 65 | image = dataloader.get_image(image_id) 66 | image = image.unsqueeze(0) 67 | 68 | if torch.cuda.is_available(): 69 | with torch.cuda.device(args.gpu_device): 70 | image = Variable(image).cuda() 71 | caption = torch.cuda.LongTensor(captions[i]) 72 | else: 73 | image = Variable(image) 74 | caption = torch.LongTensor(captions[i]) 75 | 76 | caption_train = caption[:-1] # remove 77 | 78 | loss = criterion(lstm(cnn(image), caption_train), caption) 79 | 80 | loss_list.append(loss) 81 | # avg_loss = torch.mean(torch.Tensor(loss_list)) 82 | # print('ex %d / %d avg_loss %f' %(i+1, num_captions, avg_loss), end='\r') 83 | # toc = time.time() 84 | avg_loss = torch.mean(torch.Tensor(loss_list)) 85 | print('%d %f' %(iteration, avg_loss)) -------------------------------------------------------------------------------- /Vgg.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | 10 | 11 | class Vgg(nn.Module): 12 | def __init__(self, embedding_dim=300): 13 | super(Vgg, self).__init__() 14 | self.vgg = models.vgg11(pretrained=True) 15 | in_features = self.vgg.classifier[6].in_features 16 | self.linear = nn.Linear(in_features, embedding_dim) 17 | self.vgg.classifier[6] = self.linear 18 | # self.batch_norm = nn.BatchNorm1d(embedding_dim, momentum=0.01) 19 | self.init_weights() 20 | 21 | def init_weights(self): 22 | self.linear.weight.data.normal_(0.0, 0.02) 23 | self.linear.bias.data.fill_(0) 24 | 25 | def forward(self, images): 26 | embed = self.vgg(images) 27 | return embed 28 | -------------------------------------------------------------------------------- /Vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import json 4 | from collections import Counter 5 | from Preprocess import load_captions 6 | 7 | 8 | class Vocabulary(): 9 | def __init__(self, captions_dict, threshold): 10 | self.word2id = {} 11 | self.id2word = {} 12 | self.index = 0 13 | self.build(captions_dict, threshold) 14 | 15 | def add_word(self, word): 16 | if word not in self.word2id: 17 | self.word2id[word] = self.index 18 | self.id2word[self.index] = word 19 | self.index += 1 20 | 21 | def get_id(self, word): 22 | if word in self.word2id: 23 | return self.word2id[word] 24 | return self.word2id[''] 25 | 26 | def get_word(self, index): 27 | return self.id2word[index] 28 | 29 | def build(self, captions_dict, threshold): 30 | counter = Counter() 31 | tokens = [] 32 | for k, captions in captions_dict.items(): 33 | for caption in captions: 34 | tokens.extend(nltk.tokenize.word_tokenize(caption.lower())) 35 | 36 | counter.update(tokens) 37 | 38 | words = [word for word, count in counter.items() if count >= threshold] 39 | 40 | self.add_word('') 41 | self.add_word('') 42 | self.add_word('') 43 | self.add_word('') 44 | 45 | for word in words: 46 | self.add_word(word) 47 | 48 | def get_sentence(self, ids_list): 49 | sent = '' 50 | for cur_id in ids_list: 51 | cur_word = self.id2word[cur_id.item()] 52 | sent += ' ' + cur_word 53 | if cur_word == '': 54 | break 55 | return sent 56 | 57 | if __name__ == '__main__': 58 | 59 | captions_dict = load_captions('train') 60 | vocab = Vocabulary(captions_dict, 5) 61 | print(vocab.index) -------------------------------------------------------------------------------- /check/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/1.jpg -------------------------------------------------------------------------------- /check/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/10.jpg -------------------------------------------------------------------------------- /check/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/11.jpg 
-------------------------------------------------------------------------------- /check/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/12.jpg -------------------------------------------------------------------------------- /check/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/13.jpg -------------------------------------------------------------------------------- /check/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/14.jpg -------------------------------------------------------------------------------- /check/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/15.jpg -------------------------------------------------------------------------------- /check/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/16.jpg -------------------------------------------------------------------------------- /check/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/2.jpg -------------------------------------------------------------------------------- /check/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/3.jpg -------------------------------------------------------------------------------- /check/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/4.jpg -------------------------------------------------------------------------------- /check/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/5.jpg -------------------------------------------------------------------------------- /check/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/6.jpg -------------------------------------------------------------------------------- /check/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/7.jpg -------------------------------------------------------------------------------- /check/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/8.jpg 
-------------------------------------------------------------------------------- /check/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/check/9.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | from PIL import Image 6 | import torch.nn as nn 7 | from utils import get_cnn 8 | from Decoder import RNN 9 | from Vocabulary import Vocabulary 10 | from torch.autograd import Variable 11 | from torchvision import transforms 12 | from DataLoader import DataLoader, shuffle_data 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-i') 17 | parser.add_argument('-model') 18 | parser.add_argument('-epoch', type=int) 19 | parser.add_argument('-gpu_device', type=int) 20 | args = parser.parse_args() 21 | 22 | with open(os.path.join(args.model, 'vocab.pkl'), 'rb') as f: 23 | vocab = pickle.load(f) 24 | 25 | transform = transforms.Compose([transforms.Resize((224, 224)), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), 28 | (0.5, 0.5, 0.5)) 29 | ]) 30 | image = transform(Image.open(args.i)) 31 | 32 | embedding_dim = 512 33 | vocab_size = vocab.index 34 | hidden_dim = 512 35 | model_name = args.model 36 | cnn = get_cnn(architecture = model_name, embedding_dim = embedding_dim) 37 | lstm = RNN(embedding_dim = embedding_dim, hidden_dim = hidden_dim, 38 | vocab_size = vocab_size) 39 | # cnn.eval() 40 | 41 | image = image.unsqueeze(0) 42 | 43 | # image = Variable(image) 44 | if torch.cuda.is_available(): 45 | with torch.cuda.device(args.gpu_device): 46 | cnn.cuda() 47 | lstm.cuda() 48 | image = Variable(image).cuda() 49 | else: 50 | image = Variable(image) 51 | 52 | iteration = args.epoch 53 | cnn_file = 'iter_' + str(iteration) + '_cnn.pkl' 54 | lstm_file = 'iter_' + str(iteration) + '_lstm.pkl' 55 | cnn.load_state_dict(torch.load(os.path.join(model_name, cnn_file))) 56 | lstm.load_state_dict(torch.load(os.path.join(model_name, lstm_file))) 57 | 58 | 59 | cnn_out = cnn(image) 60 | ids_list = lstm.greedy(cnn_out) 61 | print(vocab.get_sentence(ids_list)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import pickle 5 | import argparse 6 | import torch.nn as nn 7 | from Decoder import RNN 8 | from utils import get_cnn 9 | import matplotlib.pyplot as plt 10 | from Vocabulary import Vocabulary 11 | from torchvision import transforms 12 | from torch.autograd import Variable 13 | from Preprocess import load_captions 14 | from DataLoader import DataLoader, shuffle_data 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-model') 20 | parser.add_argument('-dir', type = str, default = 'train') 21 | parser.add_argument('-save_iter', type = int, default = 10) 22 | parser.add_argument('-learning_rate', type=float, default = 1e-5) 23 | parser.add_argument('-epoch', type=int) 24 | parser.add_argument('-gpu_device', type=int) 25 | parser.add_argument('-hidden_dim', type=int, default = 512) 26 | parser.add_argument('-embedding_dim', type=int, default = 512) 27 | 28 | args = parser.parse_args() 
29 | print(args) 30 | train_dir = args.dir 31 | threshold = 5 32 | 33 | captions_dict = load_captions(train_dir) 34 | vocab = Vocabulary(captions_dict, threshold) 35 | with open(os.path.join(args.model, 'vocab.pkl'), 'wb') as f: 36 | pickle.dump(vocab, f) 37 | print('dictionary dump') 38 | transform = transforms.Compose([transforms.Resize((224, 224)), 39 | transforms.ToTensor(), 40 | transforms.Normalize((0.5, 0.5, 0.5), 41 | (0.5, 0.5, 0.5)) 42 | ]) 43 | 44 | dataloader = DataLoader(train_dir, vocab, transform) 45 | data = dataloader.gen_data() 46 | print(train_dir + ' loaded') 47 | 48 | # embedding_dim = 512 49 | vocab_size = vocab.index 50 | hidden_dim = 512 51 | # learning_rate = 1e-3 52 | model_name = args.model 53 | cnn = get_cnn(architecture = model_name, embedding_dim = args.embedding_dim) 54 | lstm = RNN(embedding_dim = args.embedding_dim, hidden_dim = args.hidden_dim, 55 | vocab_size = vocab_size) 56 | 57 | if torch.cuda.is_available(): 58 | with torch.cuda.device(args.gpu_device): 59 | cnn.cuda() 60 | lstm.cuda() 61 | # iteration = args.epoch 62 | # cnn_file = 'iter_' + str(iteration) + '_cnn.pkl' 63 | # lstm_file = 'iter_' + str(iteration) + '_lstm.pkl' 64 | # cnn.load_state_dict(torch.load(os.path.join(model_name, cnn_file))) 65 | # lstm.load_state_dict(torch.load(os.path.join(model_name, lstm_file))) 66 | 67 | criterion = nn.CrossEntropyLoss() 68 | params = list(cnn.linear.parameters()) + list(lstm.parameters()) 69 | optimizer = torch.optim.Adam(params, lr = args.learning_rate) 70 | num_epochs = 100000 71 | 72 | for epoch in range(num_epochs): 73 | shuffled_images, shuffled_captions = shuffle_data(data, seed = epoch) 74 | num_captions = len(shuffled_captions) 75 | loss_list = [] 76 | tic = time.time() 77 | for i in range(num_captions): 78 | image_id = shuffled_images[i] 79 | image = dataloader.get_image(image_id) 80 | image = image.unsqueeze(0) 81 | 82 | if torch.cuda.is_available(): 83 | with torch.cuda.device(args.gpu_device): 84 | image = Variable(image).cuda() 85 | caption = torch.cuda.LongTensor(shuffled_captions[i]) 86 | else: 87 | image = Variable(image) 88 | caption = torch.LongTensor(shuffled_captions[i]) 89 | 90 | caption_train = caption[:-1] # remove 91 | cnn.zero_grad() 92 | lstm.zero_grad() 93 | 94 | cnn_out = cnn(image) 95 | lstm_out = lstm(cnn_out, caption_train) 96 | loss = criterion(lstm_out, caption) 97 | loss.backward() 98 | optimizer.step() 99 | loss_list.append(loss) 100 | toc = time.time() 101 | avg_loss = torch.mean(torch.Tensor(loss_list)) 102 | print('epoch %d avg_loss %f time %.2f mins' 103 | %(epoch, avg_loss, (toc-tic)/60)) 104 | if epoch % args.save_iter == 0: 105 | 106 | torch.save(cnn.state_dict(), os.path.join(model_name, 'iter_%d_cnn.pkl'%(epoch))) 107 | torch.save(lstm.state_dict(), os.path.join(model_name, 'iter_%d_lstm.pkl'%(epoch))) 108 | 109 | -------------------------------------------------------------------------------- /train_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/train_pic.png -------------------------------------------------------------------------------- /train_valid_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yudi09/pytorch-image-captioning/bd30bcb196fa9bb686fa652eab77b73133677dae/train_valid_loss.png -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | from Vgg import Vgg 2 | from Resnet import Resnet 3 | from Alexnet import Alexnet 4 | from DenseNet import DenseNet 5 | from Inception import Inception 6 | from Resnet152 import Resnet152 7 | from SqueezeNet import SqueezeNet 8 | 9 | 10 | def get_cnn(architecture = 'resnet18', embedding_dim = 300): 11 | if architecture == 'resnet18': 12 | cnn = Resnet(embedding_dim = embedding_dim) 13 | elif architecture == 'resnet152': 14 | cnn = Resnet152(embedding_dim = embedding_dim) 15 | elif architecture == 'alexnet': 16 | cnn = Alexnet(embedding_dim = embedding_dim) 17 | elif architecture == 'vgg': 18 | cnn = Vgg(embedding_dim = embedding_dim) 19 | elif architecture == 'inception': 20 | cnn = Inception(embedding_dim = embedding_dim) 21 | elif architecture == 'squeeze': 22 | cnn = SqueezeNet(embedding_dim = embedding_dim) 23 | elif architecture == 'dense': 24 | cnn = DenseNet(embedding_dim = embedding_dim) 25 | return cnn --------------------------------------------------------------------------------
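`get_cnn` is the switch that maps the `-model` string to one of the encoder classes above. As a usage note, the sketch below mirrors `test.py` and shows how it combines with `Decoder.RNN` for inference; the model directory, checkpoint epoch and architecture name are assumptions and must match an actual training run.

```python
import pickle
import torch
from PIL import Image
from torchvision import transforms

from utils import get_cnn
from Decoder import RNN
from Vocabulary import Vocabulary  # imported so pickle can resolve the Vocabulary class

# Assumed paths: 'resnet18/' is a model directory produced by train.py.
with open('resnet18/vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cnn = get_cnn(architecture='resnet18', embedding_dim=512)
lstm = RNN(embedding_dim=512, hidden_dim=512, vocab_size=vocab.index)
# map_location lets CUDA-trained checkpoints load on a CPU-only machine.
cnn.load_state_dict(torch.load('resnet18/iter_230_cnn.pkl', map_location='cpu'))
lstm.load_state_dict(torch.load('resnet18/iter_230_lstm.pkl', map_location='cpu'))
cnn.eval()
lstm.eval()

image = transform(Image.open('check/1.jpg')).unsqueeze(0)  # add a batch dimension
with torch.no_grad():
    ids = lstm.greedy(cnn(image))  # greedy decoding, at most 20 words
print(vocab.get_sentence(ids))
```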