├── build_vocab.py
├── data_loader.py
├── download.sh
├── model.py
├── resize.py
├── sample.py
└── train.py

/build_vocab.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import pickle
3 | import argparse
4 | from collections import Counter
5 | from pycocotools.coco import COCO
6 | 
7 | 
8 | class Vocabulary(object):
9 |     """Simple vocabulary wrapper."""
10 |     def __init__(self):
11 |         self.word2idx = {}
12 |         self.idx2word = {}
13 |         self.idx = 0
14 | 
15 |     def add_word(self, word):
16 |         if word not in self.word2idx:
17 |             self.word2idx[word] = self.idx
18 |             self.idx2word[self.idx] = word
19 |             self.idx += 1
20 | 
21 |     def __call__(self, word):
22 |         if word not in self.word2idx:
23 |             return self.word2idx['<unk>']
24 |         return self.word2idx[word]
25 | 
26 |     def __len__(self):
27 |         return len(self.word2idx)
28 | 
29 | def build_vocab(json, threshold):
30 |     """Build a simple vocabulary wrapper."""
31 |     coco = COCO(json)
32 |     counter = Counter()
33 |     ids = coco.anns.keys()
34 |     for i, id in enumerate(ids):
35 |         caption = str(coco.anns[id]['caption'])
36 |         tokens = nltk.tokenize.word_tokenize(caption.lower())
37 |         counter.update(tokens)
38 | 
39 |         if i % 1000 == 0:
40 |             print("[%d/%d] Tokenized the captions." %(i, len(ids)))
41 | 
42 |     # If the word frequency is less than 'threshold', then the word is discarded.
43 |     words = [word for word, cnt in counter.items() if cnt >= threshold]
44 | 
45 |     # Create a vocab wrapper and add some special tokens.
46 |     vocab = Vocabulary()
47 |     vocab.add_word('<pad>')
48 |     vocab.add_word('<start>')
49 |     vocab.add_word('<end>')
50 |     vocab.add_word('<unk>')
51 | 
52 |     # Add the words to the vocabulary.
53 |     for i, word in enumerate(words):
54 |         vocab.add_word(word)
55 |     return vocab
56 | 
57 | def main(args):
58 |     vocab = build_vocab(json=args.caption_path,
59 |                         threshold=args.threshold)
60 |     vocab_path = args.vocab_path
61 |     with open(vocab_path, 'wb') as f:
62 |         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
63 |     print("Total vocabulary size: %d" %len(vocab))
64 |     print("Saved the vocabulary wrapper to '%s'" %vocab_path)
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     parser = argparse.ArgumentParser()
69 |     parser.add_argument('--caption_path', type=str,
70 |                         default='./data/annotations/captions_train2014.json',
71 |                         help='path for train annotation file')
72 |     parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
73 |                         help='path for saving vocabulary wrapper')
74 |     parser.add_argument('--threshold', type=int, default=4,
75 |                         help='minimum word count threshold')
76 |     args = parser.parse_args()
77 |     main(args)
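
For reference, a minimal sketch (not part of the repository) of how the Vocabulary wrapper above behaves once the special tokens are in place; the extra words and the printed indices are illustrative only:

    # Special tokens are added first, so <pad>=0, <start>=1, <end>=2, <unk>=3.
    from build_vocab import Vocabulary

    vocab = Vocabulary()
    for token in ['<pad>', '<start>', '<end>', '<unk>', 'a', 'dog']:
        vocab.add_word(token)

    print(len(vocab))         # 6
    print(vocab('dog'))       # 5
    print(vocab('zebra'))     # 3 -- out-of-vocabulary words map to <unk>
    print(vocab.idx2word[1])  # '<start>'
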
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import torch.utils.data as data
4 | import os
5 | import pickle
6 | import numpy as np
7 | import nltk
8 | from PIL import Image
9 | from build_vocab import Vocabulary
10 | from pycocotools.coco import COCO
11 | 
12 | 
13 | class CocoDataset(data.Dataset):
14 |     """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
15 |     def __init__(self, root, json, vocab, transform=None):
16 |         """Set the path for images, captions and vocabulary wrapper.
17 | 
18 |         Args:
19 |             root: image directory.
20 |             json: coco annotation file path.
21 |             vocab: vocabulary wrapper.
22 |             transform: image transformer.
23 |         """
24 |         self.root = root
25 |         self.coco = COCO(json)
26 |         self.ids = list(self.coco.anns.keys())
27 |         self.vocab = vocab
28 |         self.transform = transform
29 | 
30 |     def __getitem__(self, index):
31 |         """Returns one data pair (image and caption)."""
32 |         coco = self.coco
33 |         vocab = self.vocab
34 |         ann_id = self.ids[index]
35 |         caption = coco.anns[ann_id]['caption']
36 |         img_id = coco.anns[ann_id]['image_id']
37 |         path = coco.loadImgs(img_id)[0]['file_name']
38 | 
39 |         image = Image.open(os.path.join(self.root, path)).convert('RGB')
40 |         if self.transform is not None:
41 |             image = self.transform(image)
42 | 
43 |         # Convert caption (string) to word ids.
44 |         tokens = nltk.tokenize.word_tokenize(str(caption).lower())
45 |         caption = []
46 |         caption.append(vocab('<start>'))
47 |         caption.extend([vocab(token) for token in tokens])
48 |         caption.append(vocab('<end>'))
49 |         target = torch.Tensor(caption)
50 |         return image, target
51 | 
52 |     def __len__(self):
53 |         return len(self.ids)
54 | 
55 | 
56 | def collate_fn(data):
57 |     """Creates mini-batch tensors from the list of tuples (image, caption).
58 | 
59 |     We should build a custom collate_fn rather than use the default collate_fn,
60 |     because merging captions (including padding) is not supported by the default.
61 | 
62 |     Args:
63 |         data: list of tuples (image, caption).
64 |             - image: torch tensor of shape (3, 256, 256).
65 |             - caption: torch tensor of shape (?); variable length.
66 | 
67 |     Returns:
68 |         images: torch tensor of shape (batch_size, 3, 256, 256).
69 |         targets: torch tensor of shape (batch_size, padded_length).
70 |         lengths: list; valid length for each padded caption.
71 |     """
72 |     # Sort the data list by caption length (descending order).
73 |     data.sort(key=lambda x: len(x[1]), reverse=True)
74 |     images, captions = zip(*data)
75 | 
76 |     # Merge images (from tuple of 3D tensors to 4D tensor).
77 |     images = torch.stack(images, 0)
78 | 
79 |     # Merge captions (from tuple of 1D tensors to 2D tensor).
80 |     lengths = [len(cap) for cap in captions]
81 |     targets = torch.zeros(len(captions), max(lengths)).long()
82 |     for i, cap in enumerate(captions):
83 |         end = lengths[i]
84 |         targets[i, :end] = cap[:end]
85 |     return images, targets, lengths
86 | 
87 | 
88 | def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
89 |     """Returns torch.utils.data.DataLoader for the custom COCO dataset."""
90 |     # COCO caption dataset
91 |     coco = CocoDataset(root=root,
92 |                        json=json,
93 |                        vocab=vocab,
94 |                        transform=transform)
95 | 
96 |     # Data loader for the COCO dataset.
97 |     # This will return (images, captions, lengths) for every iteration.
98 |     # images: tensor of shape (batch_size, 3, 224, 224).
99 |     # captions: tensor of shape (batch_size, padded_length).
100 |     # lengths: list indicating the valid length of each caption; its length is batch_size.
101 |     data_loader = torch.utils.data.DataLoader(dataset=coco,
102 |                                               batch_size=batch_size,
103 |                                               shuffle=shuffle,
104 |                                               num_workers=num_workers,
105 |                                               collate_fn=collate_fn)
106 |     return data_loader
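
For reference, a minimal sketch (not part of the repository) of what collate_fn does to a batch: two fake (image, caption) pairs with caption lengths 5 and 3 are merged into a single zero-padded caption tensor, sorted by length. The caption ids here are made up:

    import torch
    from data_loader import collate_fn

    fake_batch = [
        (torch.zeros(3, 256, 256), torch.Tensor([1, 4, 5, 6, 2])),  # caption length 5
        (torch.zeros(3, 256, 256), torch.Tensor([1, 7, 2])),        # caption length 3
    ]
    images, targets, lengths = collate_fn(fake_batch)
    print(images.size())   # (2, 3, 256, 256)
    print(targets)         # [[1, 4, 5, 6, 2], [1, 7, 2, 0, 0]] -- 0 is the <pad> index
    print(lengths)         # [5, 3]
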
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | mkdir data
2 | wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
3 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -P ./data/
4 | wget http://msvocds.blob.core.windows.net/coco2014/val2014.zip -P ./data/
5 | 
6 | unzip ./data/captions_train-val2014.zip -d ./data/
7 | rm ./data/captions_train-val2014.zip
8 | unzip ./data/train2014.zip -d ./data/
9 | rm ./data/train2014.zip
10 | unzip ./data/val2014.zip -d ./data/
11 | rm ./data/val2014.zip
12 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | from torch.nn.utils.rnn import pack_padded_sequence
5 | from torch.autograd import Variable
6 | 
7 | 
8 | class EncoderCNN(nn.Module):
9 |     def __init__(self, embed_size):
10 |         """Load the pretrained ResNet-152 and replace the top fc layer."""
11 |         super(EncoderCNN, self).__init__()
12 |         self.resnet = models.resnet152(pretrained=True)
13 |         for param in self.resnet.parameters():
14 |             param.requires_grad = False
15 |         self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
16 |         self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
17 |         self.init_weights()
18 | 
19 |     def init_weights(self):
20 |         """Initialize the weights."""
21 |         self.resnet.fc.weight.data.normal_(0.0, 0.02)
22 |         self.resnet.fc.bias.data.fill_(0)
23 | 
24 |     def forward(self, images):
25 |         """Extract the image feature vectors."""
26 |         features = self.resnet(images)
27 |         features = self.bn(features)
28 |         return features
29 | 
30 | 
31 | class DecoderRNN(nn.Module):
32 |     def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
33 |         """Set the hyper-parameters and build the layers."""
34 |         super(DecoderRNN, self).__init__()
35 |         self.embed = nn.Embedding(vocab_size, embed_size)
36 |         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
37 |         self.linear = nn.Linear(hidden_size, vocab_size)
38 |         self.init_weights()
39 | 
40 |     def init_weights(self):
41 |         """Initialize weights."""
42 |         self.embed.weight.data.uniform_(-0.1, 0.1)
43 |         self.linear.weight.data.uniform_(-0.1, 0.1)
44 |         self.linear.bias.data.fill_(0)
45 | 
46 |     def forward(self, features, captions, lengths):
47 |         """Decode image feature vectors and generate captions."""
48 |         embeddings = self.embed(captions)
49 |         embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
50 |         packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
51 |         hiddens, _ = self.lstm(packed)
52 |         outputs = self.linear(hiddens[0])
53 |         return outputs
54 | 
55 |     def sample(self, features, states):
56 |         """Samples captions for given image features (greedy search)."""
57 |         sampled_ids = []
58 |         inputs = features.unsqueeze(1)
59 |         for i in range(20):                                 # maximum sampling length
60 |             hiddens, states = self.lstm(inputs, states)     # (batch_size, 1, hidden_size)
61 |             outputs = self.linear(hiddens.squeeze(1))       # (batch_size, vocab_size)
62 |             predicted = outputs.max(1)[1]
63 |             sampled_ids.append(predicted)
64 |             inputs = self.embed(predicted)
65 |         sampled_ids = torch.cat(sampled_ids, 1)             # (batch_size, 20)
66 |         return sampled_ids.squeeze()
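
For reference, a minimal shape-check (not part of the repository) of DecoderRNN.forward with fake inputs; the sizes mirror the defaults used in train.py (embed_size=256, hidden_size=512), while the vocabulary size and the caption ids are made up:

    import torch
    from torch.autograd import Variable
    from model import DecoderRNN

    decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=1000, num_layers=1)
    features = Variable(torch.randn(2, 256))                # fake encoder output for 2 images
    captions = Variable(torch.LongTensor([[1, 4, 5, 6, 2],  # padded captions (batch_size, 5)
                                          [1, 7, 2, 0, 0]]))
    lengths = [5, 3]                                         # valid lengths, descending order
    outputs = decoder(features, captions, lengths)
    print(outputs.size())                                    # (5 + 3, 1000): one row of logits
                                                             # per non-padded time step
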
--------------------------------------------------------------------------------
/resize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from PIL import Image
4 | 
5 | 
6 | def resize_image(image, size):
7 |     """Resize an image to the given size."""
8 |     return image.resize(size, Image.ANTIALIAS)
9 | 
10 | def resize_images(image_dir, output_dir, size):
11 |     """Resize the images in 'image_dir' and save them into 'output_dir'."""
12 |     if not os.path.exists(output_dir):
13 |         os.makedirs(output_dir)
14 | 
15 |     images = os.listdir(image_dir)
16 |     num_images = len(images)
17 |     for i, image in enumerate(images):
18 |         with open(os.path.join(image_dir, image), 'r+b') as f:
19 |             with Image.open(f) as img:
20 |                 img = resize_image(img, size)
21 |                 img.save(os.path.join(output_dir, image), img.format)
22 |         if i % 100 == 0:
23 |             print ("[%d/%d] Resized the images and saved into '%s'."
24 |                    %(i, num_images, output_dir))
25 | 
26 | def main(args):
27 |     image_dir = args.image_dir
28 |     output_dir = args.output_dir
29 |     image_size = [args.image_size, args.image_size]
30 |     resize_images(image_dir, output_dir, image_size)
31 | 
32 | 
33 | if __name__ == '__main__':
34 |     parser = argparse.ArgumentParser()
35 |     parser.add_argument('--image_dir', type=str, default='./data/train2014/',
36 |                         help='directory for train images')
37 |     parser.add_argument('--output_dir', type=str, default='./data/resized2014/',
38 |                         help='directory for saving resized images')
39 |     parser.add_argument('--image_size', type=int, default=256,
40 |                         help='size for image after processing')
41 |     args = parser.parse_args()
42 |     main(args)
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import argparse
5 | import pickle
6 | import os
7 | from torch.autograd import Variable
8 | from torchvision import transforms
9 | from build_vocab import Vocabulary
10 | from model import EncoderCNN, DecoderRNN
11 | from PIL import Image
12 | 
13 | 
14 | def main(args):
15 |     # Image preprocessing
16 |     transform = transforms.Compose([
17 |         transforms.Scale(args.crop_size),
18 |         transforms.CenterCrop(args.crop_size),
19 |         transforms.ToTensor(),
20 |         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
21 | 
22 |     # Load vocabulary wrapper
23 |     with open(args.vocab_path, 'rb') as f:
24 |         vocab = pickle.load(f)
25 | 
26 |     # Build models
27 |     encoder = EncoderCNN(args.embed_size)
28 |     encoder.eval()  # evaluation mode (BN uses moving mean/variance)
29 |     decoder = DecoderRNN(args.embed_size, args.hidden_size,
30 |                          len(vocab), args.num_layers)
31 | 
32 | 
33 |     # Load the trained model parameters
34 |     encoder.load_state_dict(torch.load(args.encoder_path))
35 |     decoder.load_state_dict(torch.load(args.decoder_path))
36 | 
37 |     # Prepare the image
38 |     image = Image.open(args.image)
39 |     image_tensor = Variable(transform(image).unsqueeze(0))
40 | 
41 |     # Set initial states
42 |     state = (Variable(torch.zeros(args.num_layers, 1, args.hidden_size)),
43 |              Variable(torch.zeros(args.num_layers, 1, args.hidden_size)))
44 | 
45 |     # If using gpu
46 |     if torch.cuda.is_available():
47 |         encoder.cuda()
48 |         decoder.cuda()
49 |         state = [s.cuda() for s in state]
50 |         image_tensor = image_tensor.cuda()
51 | 
52 |     # Generate caption from image
53 |     feature = encoder(image_tensor)
54 |     sampled_ids = decoder.sample(feature, state)
55 |     sampled_ids = sampled_ids.cpu().data.numpy()
56 | 
57 |     # Decode word_ids to words
58 |     sampled_caption = []
59 |     for word_id in sampled_ids:
60 |         word = vocab.idx2word[word_id]
61 |         sampled_caption.append(word)
62 |         if word == '<end>':
63 |             break
64 |     sentence = ' '.join(sampled_caption)
65 | 
66 |     # Print out the image and the generated caption.
67 |     print (sentence)
68 |     plt.imshow(np.asarray(image))
69 | 
70 | if __name__ == '__main__':
71 |     parser = argparse.ArgumentParser()
72 |     parser.add_argument('--image', type=str, required=True,
73 |                         help='input image for generating caption')
74 |     parser.add_argument('--encoder_path', type=str, default='./models/encoder-5-3000.pkl',
75 |                         help='path for trained encoder')
76 |     parser.add_argument('--decoder_path', type=str, default='./models/decoder-5-3000.pkl',
77 |                         help='path for trained decoder')
78 |     parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
79 |                         help='path for vocabulary wrapper')
80 |     parser.add_argument('--crop_size', type=int, default=224,
81 |                         help='size for center cropping images')
82 | 
83 |     # Model parameters (should be the same as the parameters in train.py)
84 |     parser.add_argument('--embed_size', type=int, default=256,
85 |                         help='dimension of word embedding vectors')
86 |     parser.add_argument('--hidden_size', type=int, default=512,
87 |                         help='dimension of lstm hidden states')
88 |     parser.add_argument('--num_layers', type=int, default=1,
89 |                         help='number of layers in lstm')
90 |     args = parser.parse_args()
91 |     main(args)
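
For reference, a minimal sketch (not part of the repository) of the id-to-word loop above, using a toy vocabulary and a fake output of decoder.sample(); generation is cut off at the first '<end>' token:

    import numpy as np
    from build_vocab import Vocabulary

    vocab = Vocabulary()
    for token in ['<pad>', '<start>', '<end>', '<unk>', 'a', 'dog', 'runs']:
        vocab.add_word(token)

    sampled_ids = np.array([1, 4, 5, 6, 2, 0, 0])  # fake greedy output
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    print(' '.join(sampled_caption))               # <start> a dog runs <end>
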
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import os
6 | import pickle
7 | from data_loader import get_loader
8 | from build_vocab import Vocabulary
9 | from model import EncoderCNN, DecoderRNN
10 | from torch.autograd import Variable
11 | from torch.nn.utils.rnn import pack_padded_sequence
12 | from torchvision import transforms
13 | 
14 | 
15 | def main(args):
16 |     # Create model directory
17 |     if not os.path.exists(args.model_path):
18 |         os.makedirs(args.model_path)
19 | 
20 |     # Image preprocessing
21 |     transform = transforms.Compose([
22 |         transforms.RandomCrop(args.crop_size),
23 |         transforms.RandomHorizontalFlip(),
24 |         transforms.ToTensor(),
25 |         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
26 | 
27 |     # Load vocabulary wrapper.
28 |     with open(args.vocab_path, 'rb') as f:
29 |         vocab = pickle.load(f)
30 | 
31 |     # Build data loader
32 |     data_loader = get_loader(args.image_dir, args.caption_path, vocab,
33 |                              transform, args.batch_size,
34 |                              shuffle=True, num_workers=args.num_workers)
35 | 
36 |     # Build the models
37 |     encoder = EncoderCNN(args.embed_size)
38 |     decoder = DecoderRNN(args.embed_size, args.hidden_size,
39 |                          len(vocab), args.num_layers)
40 | 
41 |     if torch.cuda.is_available():
42 |         encoder.cuda()
43 |         decoder.cuda()
44 | 
45 |     # Loss and optimizer (only the decoder, the new fc layer and the batch norm layer are trained)
46 |     criterion = nn.CrossEntropyLoss()
47 |     params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters()) + list(encoder.bn.parameters())
48 |     optimizer = torch.optim.Adam(params, lr=args.learning_rate)
49 | 
50 |     # Train the models
51 |     total_step = len(data_loader)
52 |     for epoch in range(args.num_epochs):
53 |         for i, (images, captions, lengths) in enumerate(data_loader):
54 | 
55 |             # Set mini-batch dataset
56 |             images = Variable(images)
57 |             captions = Variable(captions)
58 |             if torch.cuda.is_available():
59 |                 images = images.cuda()
60 |                 captions = captions.cuda()
61 |             targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
62 | 
63 |             # Forward, backward and optimize
64 |             decoder.zero_grad()
65 |             encoder.zero_grad()
66 |             features = encoder(images)
67 |             outputs = decoder(features, captions, lengths)
68 |             loss = criterion(outputs, targets)
69 |             loss.backward()
70 |             optimizer.step()
71 | 
72 |             # Print log info
73 |             if i % args.log_step == 0:
74 |                 print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
75 |                       %(epoch, args.num_epochs, i, total_step,
76 |                         loss.data[0], np.exp(loss.data[0])))
77 | 
78 |             # Save the models
79 |             if (i+1) % args.save_step == 0:
80 |                 torch.save(decoder.state_dict(),
81 |                            os.path.join(args.model_path,
82 |                                         'decoder-%d-%d.pkl' %(epoch+1, i+1)))
83 |                 torch.save(encoder.state_dict(),
84 |                            os.path.join(args.model_path,
85 |                                         'encoder-%d-%d.pkl' %(epoch+1, i+1)))
86 | 
87 | if __name__ == '__main__':
88 |     parser = argparse.ArgumentParser()
89 |     parser.add_argument('--model_path', type=str, default='./models/',
90 |                         help='path for saving trained models')
91 |     parser.add_argument('--crop_size', type=int, default=224,
92 |                         help='size for randomly cropping images')
93 |     parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl',
94 |                         help='path for vocabulary wrapper')
95 |     parser.add_argument('--image_dir', type=str, default='./data/resized2014',
96 |                         help='directory for resized images')
97 |     parser.add_argument('--caption_path', type=str,
98 |                         default='./data/annotations/captions_train2014.json',
99 |                         help='path for train annotation json file')
100 |     parser.add_argument('--log_step', type=int, default=10,
101 |                         help='step size for printing log info')
102 |     parser.add_argument('--save_step', type=int, default=1000,
103 |                         help='step size for saving trained models')
104 | 
105 |     # Model parameters
106 |     parser.add_argument('--embed_size', type=int, default=256,
107 |                         help='dimension of word embedding vectors')
108 |     parser.add_argument('--hidden_size', type=int, default=512,
109 |                         help='dimension of lstm hidden states')
110 |     parser.add_argument('--num_layers', type=int, default=1,
111 |                         help='number of layers in lstm')
112 | 
113 |     parser.add_argument('--num_epochs', type=int, default=5)
114 |     parser.add_argument('--batch_size', type=int, default=128)
115 |     parser.add_argument('--num_workers', type=int, default=2)
116 |     parser.add_argument('--learning_rate', type=float, default=0.001)
117 |     args = parser.parse_args()
118 |     print(args)
119 |     main(args)
120 | 
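
For reference, a minimal sketch (not part of the repository) of how the targets built in train.py line up with the decoder outputs: packing the padded captions with their true lengths flattens them time step by time step and drops the padding, which is the same ordering the packed LSTM outputs use. The caption ids are made up:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    captions = torch.LongTensor([[1, 4, 5, 6, 2],   # (batch_size, padded_length)
                                 [1, 7, 2, 0, 0]])
    lengths = [5, 3]                                 # valid lengths, descending order
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
    print(targets)   # [1, 1, 4, 7, 5, 2, 6, 2] -- time-major, padding dropped
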
--------------------------------------------------------------------------------