├── LICENSE ├── README.md ├── annotation_data └── readme.m ├── dataloader_unsupervised.py ├── id2idx_data └── readme.m ├── images ├── readme.m └── schematic.png ├── model.py ├── process_dataset.py ├── process_dataset_embeddings.py ├── train_unsupervised_withval.py ├── utils.py └── visualfeatures_data └── readme.m /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Aleix Cambray 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GroundeR PyTorch Implementation 2 | >*Note: This project is functional but has not been polished for release, and neither has the description and documentation below. No more work is expected in the near future. 3 | It is uploaded in the hope that it is useful to you in its current form.* 4 | 5 | This is a PyTorch implementation of the supervised and unsupervised GroundeR model from [Grounding of Textual Phrases in Images by Reconstruction](https://arxiv.org/pdf/1511.03745.pdf). 6 | 7 | The task is to localize the region of an image that a phrase refers to. For example, given the description "**A man** is jumping over **a fence**", we would like to ground both entities to specific regions of the image. Phrase localization (or phrase grounding) is useful for problems that depend on these unknown mappings between entities and regions, such as image captioning and multi-modal neural machine translation. 8 | 9 | ## Requirements 10 | Python 3.6 11 | \- Framework: PyTorch 1.0.1 12 | \- torchvision 13 | \- Numpy 14 | \- PIL 15 | \- pickle 16 | \- matplotlib 17 | 18 | ## Supervised and Unsupervised versions 19 | Both versions are described with reference to the schematic shown below. 20 | 21 | The supervised version of the code assumes we have ground truth for which region of the image the phrase refers to. The loss is therefore computed as the cross-entropy between the visual attention vector (alpha) and the one-hot ground-truth vector. For a fully supervised version there is no need to run the decoder part of the model. 22 | 23 | The unsupervised version, on the other hand, does not assume ground truth for the correct region and therefore relies on the reconstruction loss, given by the cross-entropy between the original phrase and the output of the decoder.
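For concreteness, the snippet below illustrates the two objectives with dummy tensors standing in for the model outputs (`att_log` and `decoder_output` are the log-softmax outputs returned by the code in `model.py`). It mirrors the loss computation in `train_unsupervised_withval.py`, which combines both terms, but it is only a sketch and not the training code itself:

```python
import torch
import torch.nn as nn

batch, regions, steps, vocab = 2, 25, 4, 10

# Stand-ins for the model outputs: a log-softmax attention distribution over
# image regions and log-probabilities over the vocabulary at each decoder step.
att_log = torch.log_softmax(torch.randn(batch, regions), dim=1)
decoder_output = torch.log_softmax(torch.randn(batch, steps, vocab), dim=2)

# Supervised objective: cross-entropy between the attention distribution and
# the index of the ground-truth region (the decoder is not needed for this).
region_true = torch.tensor([3, 17])
loss_att = nn.NLLLoss()(att_log, region_true)

# Unsupervised objective: reconstruction cross-entropy between the decoder's
# word predictions and the target phrase (index 0 is ignored as padding,
# matching the training script).
decoder_target = torch.randint(1, vocab, (batch, steps))
loss_rec = nn.NLLLoss(ignore_index=0)(decoder_output.view(-1, vocab),
                                      decoder_target.view(-1))

# train_unsupervised_withval.py optimises L*loss_att + loss_rec with L=5.
loss = 5 * loss_att + loss_rec
```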
24 | 25 | ## Model Design and Architecture 26 | Below is a schematic of the exact implementation of (and design decisions behind) the code in this repository. 27 | 28 |
<img src="images/schematic.png" alt="GroundeR model schematic">
29 | 30 |

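In code, the flow shown in the schematic (phrase encoder, attention over region features, phrase decoder) can be exercised directly through `model.py`. The following is a minimal sketch with made-up dimensions and random inputs; the real pipeline feeds ResNet-101 region features and GloVe embeddings produced by the `process_dataset*.py` scripts:

```python
import torch
from model import GroundeR

# Toy setup: dimensions follow the defaults in train_unsupervised_withval.py
# (embed_dim=50, h_dim=100, v_dim=1000, 25 regions), with a dummy vocabulary.
dummy_vocab = ['word{}'.format(i) for i in range(200)]
model = GroundeR(dummy_vocab, len(dummy_vocab), embed_dim=50, h_dim=100,
                 v_dim=1000, regions=25)

vis_features  = torch.randn(2, 25, 1000)        # one feature vector per candidate region
encoder_input = torch.randint(0, 200, (2, 10))  # phrases as padded index sequences
decoder_input = torch.randint(0, 200, (2, 10))
lengths_enc   = torch.tensor([7, 5])            # true (unpadded) phrase lengths
lengths_dec   = torch.tensor([6, 4])

decoder_output, att, att_log = model(encoder_input, lengths_enc,
                                     vis_features, decoder_input, lengths_dec)
print(att.shape)             # (2, 25): attention over regions (alpha)
print(decoder_output.shape)  # (2, max_len, vocab): log-probs for reconstruction
print(att.argmax(dim=1))     # most-attended region per phrase
```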
31 | -------------------------------------------------------------------------------- /annotation_data/readme.m: -------------------------------------------------------------------------------- 1 | processed annotation and co-reference data (from process_dataset.py) 2 | -------------------------------------------------------------------------------- /dataloader_unsupervised.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Aleix Cambray Roma 3 | Work done while at Imperial College London. 4 | """ 5 | import torch 6 | import pickle 7 | import numpy as np 8 | from numpy.random import randint 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class UnsupervisedDataLoader(Dataset): 13 | """ Loads Flickr30k data for the unsupervised, reconstruction case. """ 14 | def __init__(self, data_dir, sample_list_file, seq_length): 15 | def read_sample_list(sample_list_file): 16 | f = open(data_dir + sample_list_file) 17 | return np.array([sample_id.strip() for sample_id in f.readlines()]) 18 | 19 | self.seq_length = seq_length 20 | self.sample_id_list = read_sample_list(sample_list_file) 21 | self.data_dir = data_dir 22 | 23 | with open(self.data_dir + "vocabulary.pkl", "rb") as f: 24 | self.word2idx, self.idx2word = pickle.load(f) 25 | self.vocab = list(self.word2idx.keys()) 26 | 27 | def __len__(self): 28 | return len(self.sample_id_list) 29 | 30 | def __getitem__(self, i): 31 | sample_id = self.sample_id_list[i] 32 | 33 | # Get Visual feature matrix 34 | vis_features = np.zeros((25, 1000), dtype='float32') 35 | real_feat = np.load(self.data_dir + "visualfeatures_data/" + sample_id + ".npy") 36 | vis_features[:real_feat.shape[0], :] = real_feat 37 | real_feat = real_feat.shape[0] 38 | 39 | # Get annotations (all phrases in this image) 40 | with open(self.data_dir + "annotation_data/" + sample_id + ".pkl", "rb") as f: 41 | annotations = pickle.load(f) 42 | phrases = annotations['seqs'] 43 | 44 | # Select random phrase from this image 45 | # print(len(phrases)) 46 | if len(phrases) == 0: 47 | print(sample_id) 48 | print(real_feat) 49 | print("") 50 | 51 | rand_int = randint(len(phrases)) 52 | phrase = phrases[rand_int] 53 | 54 | if len(phrase) > self.seq_length: 55 | phrase[self.seq_length-1] = phrase[-1] # Add end tag at last step 56 | phrase = phrase[:self.seq_length] # Truncate phrase to seq_length 57 | 58 | # Initialise all phrase arrays with the padding symbol 59 | pad_idx = self.word2idx[''] 60 | encoder_input = np.ones(self.seq_length, dtype='int64')*pad_idx 61 | decoder_input = np.ones(self.seq_length, dtype='int64')*pad_idx 62 | decoder_target = np.ones(self.seq_length, dtype='int64')*pad_idx 63 | mask = np.zeros(self.seq_length, dtype='int64') 64 | 65 | # Replace the first padding symbols by the real phrase 66 | encoder_input[0:len(phrase)] = np.array(phrase) # Feed to encoder LSTM: Both or tags 67 | decoder_input[0:len(phrase)-2] = np.array(phrase[1:-1]) # Feed to decoder LSTM: No tags 68 | decoder_target[0:len(phrase)-1] = np.array(phrase[1:]) # Used as reconstruction ground truth: Only tag 69 | mask[0:len(phrase)-1] = 1 70 | 71 | with open('C:/Data/GroundeR/id2idx_data/'+sample_id+'.pkl', 'rb') as f: 72 | id2idx = pickle.load(f) 73 | 74 | phrase_id = annotations['ids'][rand_int] 75 | true_region = id2idx[phrase_id] 76 | # Sentence as list of string-words (visualise) 77 | # phrase_word = np.array([self.idx2word[idx] for i, idx in enumerate(encoder_input) if i < len(phrase)]) 78 | 79 | # print("Shapes:") 80 | # print(" - 
vis_features: {}".format(vis_features.shape)) 81 | # print(" - encoder_input: {}".format(encoder_input.shape)) 82 | # print(" - decoder_input: {}".format(decoder_input.shape)) 83 | # print(" - decoder_target: {}".format(decoder_target.shape)) 84 | 85 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 86 | 87 | # print('current memory allocated: {}'.format(torch.cuda.memory_allocated() / 1024 ** 2)) 88 | # print('max memory allocated: {}'.format(torch.cuda.max_memory_allocated() / 1024 ** 2)) 89 | # print('cached memory: {}'.format(torch.cuda.memory_cached() / 1024 ** 2)) 90 | 91 | # vis_features = torch.from_numpy(vis_features) 92 | # real_feat = torch.tensor(real_feat).to(device) 93 | # encoder_input = torch.from_numpy(encoder_input) 94 | # decoder_input = torch.from_numpy(decoder_input) 95 | # decoder_target = torch.from_numpy(decoder_target) 96 | # mask = torch.from_numpy(mask) 97 | 98 | return vis_features, real_feat, encoder_input, decoder_input, decoder_target, mask, true_region, len(phrase), len(phrase)-1 99 | -------------------------------------------------------------------------------- /id2idx_data/readme.m: -------------------------------------------------------------------------------- 1 | from process_dataset.py 2 | 3 | Each file in this folder corresponds to an image of the dataset and contains a dictionary which maps an object_ID to its corresponding index in the visual feature matrix. 4 | -------------------------------------------------------------------------------- /images/readme.m: -------------------------------------------------------------------------------- 1 | images for project readme 2 | -------------------------------------------------------------------------------- /images/schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acambray/GroundeR-PyTorch/20f241f15ce31aa55ec059eb5be564c7c266ed2b/images/schematic.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Aleix Cambray Roma 3 | Work done while at Imperial College London. 4 | """ 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | 11 | 12 | class PhraseEncoder(nn.Module): 13 | """ 14 | Phrase Encoder 15 | Input is sequence of word indices 16 | Output is last LSTM hidden state 17 | """ 18 | def __init__(self, embedding_layer, vocab_size, embed_dim, h_dim): 19 | super(PhraseEncoder, self).__init__() 20 | 21 | # Word embedding layers 22 | # self.embedding = nn.Embedding(vocab_size, embed_dim) 23 | self.embedding = embedding_layer 24 | 25 | # LSTM layer 26 | self.lstm = nn.LSTM(embed_dim, h_dim, batch_first=True) 27 | 28 | def forward(self, phrase_batch, lengths_enc): 29 | """ 30 | Embedding turns sequences of indices (phrase) into sequence of vectors 31 | :param phrase_batch: (BATCH, TIME) i.e. if batch_size=2 and seq_length=5 phrase=[[1,5,4,7,3], [8,2,5,2,4]] 32 | :return: last LSTM output 33 | """ 34 | 35 | # phrase dimensions: (BATCH, TIME) i.e. if batch_size=2 and seq_length=5 phrase=[[1,5,4,7,3], [8,2,5,2,4]] 36 | # embeds dimensions: 37 | 38 | batch_size = phrase_batch.size()[0] 39 | seq_length = phrase_batch.size()[1] 40 | embeds = self.embedding(phrase_batch) 41 | 42 | # Pack 43 | # 1. 
sort sequences by length 44 | ordered_len, ordered_idx = lengths_enc.sort(0, descending=True) 45 | ordered_embeds = embeds[ordered_idx] 46 | # 2. pack 47 | input_packed = pack_padded_sequence(ordered_embeds, ordered_len, batch_first=True) 48 | 49 | # Feed embeddings to LSTM 50 | # embeds = embeds.view(seq_length, batch_size, -1) 51 | _, (ht, ct) = self.lstm(input_packed) # Hidden is none because we don't initialise the hidden state, could use random noise instead 52 | 53 | # Get final hidden state from LSTM and reverse descending ordering 54 | h = ht[0, :, :] 55 | h[ordered_idx] = h 56 | return h 57 | 58 | 59 | class AttentionModule(nn.Module): 60 | def __init__(self, h_dim, v_dim, embed_dim, out1, regions): 61 | super(AttentionModule, self).__init__() 62 | self.fc1 = nn.Linear(h_dim + v_dim, out1, bias=True) 63 | self.fc2 = nn.Linear(out1, 1, bias=True) 64 | self.fcREC = nn.Linear(v_dim, embed_dim, bias=True) 65 | 66 | def forward(self, h, v): 67 | """ 68 | :param h: encoded phrases dimensions [batch, h_dim] [32, 100] 69 | :param v: visual features matrix dimensions [batch_size, regions, v_dim] [32, 25, 1000] 70 | :return: 71 | """ 72 | batch_size = h.size()[0] 73 | regions = v.size()[1] 74 | h_dim = h.size()[1] 75 | v_dim = v.size()[2] 76 | 77 | # We want to turn h from [batch_size, h_dim] to [batch_size, regions, 1, h_dim] 78 | h1 = h[:, None, None, :] # shape: (batch_size, 1, 1, h_dim) 79 | h_reshaped = h1.repeat(1, regions, 1, 1).type(dtype=torch.float32) # shape: (batch_size, regions, 1, h_dim) 80 | 81 | # Add extra dimension to match 82 | v_reshaped = v[:, :, None, :].type(dtype=torch.float32) # shape: (batch_size, regions, 1, v_dim) 83 | 84 | hv = torch.cat((v_reshaped, h_reshaped), dim=3) 85 | 86 | # View hv (batch_size, regions, 1, hv_dim) to (batch_size*regions, hv_dim) 87 | x = hv.view((hv.size(0)*hv.size(1), hv.size(3))) 88 | x = self.fc1(x) # [ batch*reg, hidden ] 89 | x = F.relu(x) 90 | x = self.fc2(x) # [ batch*reg, 1 ] 91 | x = x.view(hv.size(0), hv.size(1)) # [ batch, reg ] 92 | alpha = F.softmax(x, dim=1) # [ batch, reg ] 93 | att_log = F.log_softmax(x, dim=1) 94 | 95 | # Sum of the elementwise product between alpha and v 96 | alpha_expanded = alpha[:, :, None].expand_as(v) 97 | v_att = torch.sum(torch.mul(v, alpha_expanded), dim=(1,)) 98 | 99 | v_att_dec = self.fcREC(v_att) 100 | v_att_dec = F.relu(v_att_dec) 101 | return v_att, v_att_dec, alpha, att_log 102 | 103 | 104 | class PhraseDecoder(nn.Module): 105 | def __init__(self, embedding_layer, embed_dim, h_dim, vocab_size, v_dim): 106 | super(PhraseDecoder, self).__init__() 107 | 108 | self.embedding = embedding_layer 109 | 110 | self.decoder = nn.LSTM(embed_dim, h_dim, batch_first=True) 111 | self.hidden_to_prob = nn.Linear(h_dim, vocab_size) 112 | 113 | def forward(self, v_att_dec, decoder_input, lengths_dec, teacher_forcing=1): 114 | # TODO: Implement greedy strategy (using 'teacher forcing' at the moment) 115 | 116 | embeds = self.embedding(decoder_input) 117 | 118 | # Concatenate the visual attended context vector and the phrase embedding as the input to the decoder 119 | decoder_input = torch.cat((v_att_dec[:, None, :], embeds[:, :-1, :]), dim=1) # indexes: [batch, timestep, word] 120 | 121 | # Re-order axis to fit PyTorch LSTM convention (seq_length, batch, input_size) 122 | # decoder_input = decoder_input.permute((1, 0, 2)) 123 | 124 | # Pack -------------------------------------------------------------------------------------- 125 | # 1. 
sort sequences by length 126 | ordered_len, ordered_idx = lengths_dec.sort(0, descending=True) 127 | ordered_embeds = decoder_input[ordered_idx] 128 | # 2. pack 129 | input_packed = pack_padded_sequence(ordered_embeds, ordered_len, batch_first=True) 130 | 131 | # LSTM -------------------------------------------------------------------------------------- 132 | output_packed, hidden = self.decoder(input_packed) 133 | 134 | # Unpack all hidden states------------------------------------------------------------------- 135 | output_sorted, _ = pad_packed_sequence(output_packed, batch_first=True) 136 | 137 | # Fully Connected (hidden state to vocabulary probabilities) -------------------------------- 138 | output = torch.zeros_like(output_sorted) 139 | output[ordered_idx] = output_sorted # shape: batch_size x seq_len x hidden 140 | 141 | input_fc = output.contiguous().view(output.size(0)*output.size(1), output.size(2)) # shape: batch_size*seq_len x hidden 142 | out = self.hidden_to_prob(input_fc) # shape: batch_size*seq_len x vocab_size 143 | out = out.view(output.size(0), output.size(1), out.size(1)) # shape: batch_size x seq_len x vocab_size] 144 | decoder_output = F.log_softmax(out, dim=2) 145 | 146 | return decoder_output 147 | 148 | 149 | class GroundeR(nn.Module): 150 | def __init__(self, vocab, vocab_size, embed_dim, h_dim, v_dim, regions, embeddings_matrix=None, train_embeddings=False): 151 | super(GroundeR, self).__init__() 152 | 153 | if embeddings_matrix is None: 154 | self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=False) 155 | else: 156 | self.embedding = nn.Embedding.from_pretrained(torch.from_numpy(embeddings_matrix).type(torch.float32), freeze=True, sparse=True) 157 | if train_embeddings is False: 158 | self.embedding.weight.requires_grad = False 159 | 160 | self.phrase_encoder = PhraseEncoder(self.embedding, vocab_size, embed_dim, h_dim) 161 | self.attention = AttentionModule(h_dim, v_dim, embed_dim, 100, regions) 162 | self.phrase_decoder = PhraseDecoder(self.embedding, embed_dim, h_dim, vocab_size, v_dim) 163 | 164 | self.vocab = vocab 165 | self.vocab_size = len(vocab) 166 | 167 | def forward(self, encoder_input, lengths_enc, vis_features, decoder_input, lengths_dec): 168 | 169 | # Encode 170 | encoded_batch = self.phrase_encoder(encoder_input, lengths_enc) 171 | 172 | # Attend 173 | v_att, v_att_dec, att, att_log = self.attention(encoded_batch, vis_features) 174 | 175 | # Decode 176 | decoder_output = self.phrase_decoder(v_att_dec, decoder_input, lengths_dec, teacher_forcing=1) 177 | 178 | return decoder_output, att, att_log 179 | 180 | def masked_NLLLoss(self, pred, target, pad_token=0): 181 | max_pred_len = pred.size(1) 182 | Y = target[:, :max_pred_len] 183 | Y = Y.contiguous().view(-1) 184 | 185 | Y_hat = pred.view(-1, pred.size(2)) 186 | 187 | mask = (Y != pad_token).float() 188 | Y_hat_masked = torch.zeros_like(Y, dtype=torch.float32) 189 | 190 | for i, idx in enumerate(Y): 191 | Y_hat_masked[i] = Y_hat[i, idx] 192 | 193 | Y_hat_masked = Y_hat_masked * mask 194 | 195 | n_tokens = torch.sum(mask) 196 | 197 | loss = -torch.sum(Y_hat_masked) / n_tokens 198 | return loss 199 | -------------------------------------------------------------------------------- /process_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Aleix Cambray Roma 3 | Work done while at Imperial College London. 
4 | 5 | This code takes in the Flickr30k Entities dataset and generates the data structures necessary for the GroundeR implementation 6 | Output data structures (for each image in dataset) 7 | - Annotation dictionary annotation_data/.pkl - A dictionary containing ids, bboxes and token_index_sequences for each phrase 8 | - Visual features matrix visualfeatures_data/.npy - A matrix in which each column is a visual feature vector for each region 9 | - Link id2idx dict id2idx/.pkl - Matches object id for each phrase to the idx of the correct region (bbox) 10 | """ 11 | 12 | from utils import * 13 | import PIL 14 | from PIL import Image, ImageDraw 15 | import time 16 | import numpy as np 17 | import torchvision.models as models 18 | import torchvision.transforms as transforms 19 | import os 20 | import itertools 21 | import glob 22 | import pickle 23 | 24 | 25 | def crop_and_resize(img, coords, size): 26 | """ 27 | PIL image 'img' is first cropped to a rectangle according to 'coords'. Then the resulting crop is re-scaled according to 'size' 28 | :param img: PIL image 29 | :param coords: rectangle coordinates in the form of a list [x_nw, y_nw, x_se, y_se] 30 | :param size: tuple or list [h_size, v_size] 31 | :return: 32 | """ 33 | img = img.crop(coords) 34 | img = img.resize(size, PIL.Image.LANCZOS) 35 | return img 36 | 37 | 38 | def unify_boxes(boxes): 39 | """ 40 | This function turns a bunch of bounding boxes into one bounding box that encompasses all of them 41 | :param boxes: List of lists, each sub-list is a bounding box [x_nw, y_nw, x_se, y_se] 42 | :return: 1 bounding box 43 | """ 44 | boxes = np.array(boxes) 45 | xmin = np.amin(boxes[:, 0]) 46 | ymin = np.amin(boxes[:, 1]) 47 | xmax = np.amax(boxes[:, 2]) 48 | ymax = np.amax(boxes[:, 3]) 49 | return [xmin, ymin, xmax, ymax] 50 | 51 | 52 | def phrase2seq(phrase, word2idx, vocab): 53 | """ 54 | This function turns a sequence of words into a sequence of indexed tokens according to a vocabulary and its word2idx mapping dictionary. 55 | :param phrase: List of strings 56 | :param word2idx: Dictionary 57 | :param vocab: Set or list containing entire vocabulary as strings 58 | :return: List of integers. Each integer being the token index of the word in the phrase according to the vocabulary. 59 | """ 60 | # and tokens 61 | phrase = [''] + phrase 62 | phrase = phrase + [''] 63 | 64 | phrase_seq = [0]*len(phrase) 65 | for i, word in enumerate(phrase): 66 | if word in vocab: 67 | phrase_seq[i] = word2idx[word] 68 | else: 69 | phrase_seq[i] = word2idx[''] 70 | return phrase_seq 71 | 72 | 73 | if __name__ == "__main__": 74 | build_vocab = True # If True, the vocabulary will be built again and saved. If False, the last vocabulary will be loaded. 
75 | generate_annotations = True 76 | generate_visual = False 77 | 78 | data_folder = "C:/Data/GroundeR/data/" 79 | 80 | train_txt = open(data_folder + 'train.txt', 'r') 81 | val_txt = open(data_folder + 'val.txt', 'r') 82 | test_txt = open(data_folder + 'test.txt', 'r') 83 | 84 | # Load RESNET 85 | transform = transforms.Compose([ 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 88 | std=[0.229, 0.224, 0.225]), 89 | ]) 90 | 91 | resnet = models.resnet101(pretrained=True) 92 | resnet.eval() 93 | resnet.cuda() 94 | 95 | #################################################################### 96 | # BUILD VOCABULARY # 97 | #################################################################### 98 | if build_vocab: 99 | print("Building Vocabulary") 100 | vocabulary = ['', '', ''] 101 | m = 1 102 | ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')] 103 | N = len(ids) 104 | for img_n, img_id in enumerate(ids): 105 | print("\rImg {}/{}".format(img_n, N), end="") 106 | sentence_path = data_folder + 'annotations/Sentences/' + img_id + '.txt' 107 | corefData = get_sentence_data(sentence_path) 108 | for description in corefData: 109 | for word in description['sentence'].lower().split(' '): 110 | if word not in vocabulary: 111 | vocabulary.append(word) 112 | m += 1 113 | 114 | vocabulary.append('') 115 | 116 | word2idx = {w: idx for (idx, w) in enumerate(vocabulary)} 117 | idx2word = {idx: w for (idx, w) in enumerate(vocabulary)} 118 | with open('vocabulary.pkl', 'wb') as f: 119 | pickle.dump((word2idx, idx2word), f, protocol=pickle.HIGHEST_PROTOCOL) 120 | 121 | #################################################################### 122 | # BUILD ANNOTATION INPUT DATA STRUCTURE # 123 | #################################################################### 124 | if generate_annotations or generate_visual: 125 | with open(data_folder + "vocabulary.pkl", "rb") as f: 126 | word2idx, idx2word = pickle.load(f) 127 | vocab = set(word2idx.keys()) 128 | 129 | # EXAMPLES LOOP 130 | phrase_count = 0 131 | ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')] 132 | N = len(ids) 133 | for img_n, img_id in enumerate(ids): 134 | img_id = img_id.replace('\n', '') 135 | img_path = data_folder + 'flickr30k-images' + img_id + '. jpg' 136 | 137 | # Get Sentence and Annotation info for this image 138 | corefData = get_sentence_data(data_folder + 'annotations/Sentences/'+img_id+'.txt') 139 | annotationData = get_annotations(data_folder + 'annotations/Annotations/'+img_id+'.xml') 140 | 141 | ids = [] 142 | bboxes = [] 143 | seqs = [] 144 | num_of_sentences = len(corefData) 145 | if generate_annotations: 146 | # DESCRIPTIONS LOOP 147 | for description in corefData: 148 | 149 | # PHRASES LOOP 150 | for phrase in description['phrases']: 151 | # Get object ID 152 | obj_id = phrase['phrase_id'] 153 | 154 | # Check if this phrase has a box assigned. If not, then skip phrase. 
155 | if obj_id not in list(annotationData['boxes'].keys()) or obj_id == '0': 156 | continue 157 | ids.append(obj_id) 158 | 159 | # Obtain box coordinates for this phrase 160 | boxes = annotationData['boxes'][obj_id] 161 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0] 162 | bboxes.append(box) 163 | 164 | # Turn phrase from sequence of strings into sequence of indexes 165 | phrase = phrase['phrase'].lower().split() 166 | phrase_seq = phrase2seq(phrase, word2idx, vocab) 167 | seqs.append(phrase_seq) 168 | phrase_count += 1 169 | 170 | image_annotations = {'bboxes': bboxes, 'ids': ids, 'seqs': seqs} 171 | with open('C:/Data/GroundeR/annotation_data/'+img_id+'.pkl', 'wb') as f: 172 | pickle.dump(image_annotations, f, protocol=pickle.HIGHEST_PROTOCOL) 173 | if img_n % 100 == 0: 174 | print(f"\rImage {img_n}/{N} - Phrases: {phrase_count}", end="") 175 | 176 | if generate_visual: 177 | ############################################################# 178 | # EXTRACT IMAGE FEATURES (RESNET101) # 179 | ############################################################# 180 | 181 | img = Image.open(data_folder + 'flickr30k-images/' + img_id + '.jpg') 182 | 183 | # Build id to idx dictionary 184 | id2idx = {id: idx for (idx, id) in enumerate(list(annotationData['boxes'].keys()))} 185 | 186 | # LOOP THROUGH ALL BOXES 187 | objects = list(annotationData['boxes'].keys()) 188 | vis_matrix = np.zeros((len(objects), 1000), dtype=float) 189 | for id in objects: 190 | # For each Object: extract boxes and unify them 191 | boxes = annotationData['boxes'][id] 192 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0] 193 | # For each box: crop original img to box, resize crop to 224z224, normalise image 194 | box_img = crop_and_resize(img, box, (224, 224)) 195 | box_img = transform(box_img) 196 | box_img = box_img.unsqueeze(0) 197 | box_img = box_img.cuda() 198 | # Feed image to ResNet-101 and add to visual feature matrix 199 | vis_feature = resnet(box_img) 200 | vis_matrix[id2idx[id], :] = vis_feature.cpu().detach().numpy() 201 | 202 | np.save('C:/Data/GroundeR/visualfeatures_data/'+img_id, vis_matrix) 203 | with open('C:/Data/GroundeR/id2idx_data/'+img_id+'.pkl', 'wb') as f: 204 | pickle.dump(id2idx, f) 205 | -------------------------------------------------------------------------------- /process_dataset_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Aleix Cambray Roma 3 | Work done while at Imperial College London. 4 | 5 | This code takes in the Flickr30k Entities dataset and generates the data structures necessary for the GroundeR implementation 6 | Output data structures (for each image in dataset) 7 | - Annotation dictionary annotation_data/.pkl - A dictionary containing ids, bboxes and token_index_sequences for each phrase 8 | - Visual features matrix visualfeatures_data/.npy - A matrix in which each column is a visual feature vector for each region 9 | - Link id2idx dict id2idx/.pkl - Matches object id for each phrase to the idx of the correct region (bbox) 10 | """ 11 | 12 | from utils import * 13 | import PIL 14 | from PIL import Image, ImageDraw 15 | import time 16 | import numpy as np 17 | import torchvision.models as models 18 | import torchvision.transforms as transforms 19 | import os 20 | import itertools 21 | import glob 22 | import pickle 23 | 24 | 25 | def crop_and_resize(img, coords, size): 26 | """ 27 | PIL image 'img' is first cropped to a rectangle according to 'coords'. 
Then the resulting crop is re-scaled according to 'size' 28 | :param img: PIL image 29 | :param coords: rectangle coordinates in the form of a list [x_nw, y_nw, x_se, y_se] 30 | :param size: tuple or list [h_size, v_size] 31 | :return: 32 | """ 33 | img = img.crop(coords) 34 | img = img.resize(size, PIL.Image.LANCZOS) 35 | return img 36 | 37 | 38 | def unify_boxes(boxes): 39 | """ 40 | This function turns a bunch of bounding boxes into one bounding box that encompasses all of them 41 | :param boxes: List of lists, each sub-list is a bounding box [x_nw, y_nw, x_se, y_se] 42 | :return: 1 bounding box 43 | """ 44 | boxes = np.array(boxes) 45 | xmin = np.amin(boxes[:, 0]) 46 | ymin = np.amin(boxes[:, 1]) 47 | xmax = np.amax(boxes[:, 2]) 48 | ymax = np.amax(boxes[:, 3]) 49 | return [xmin, ymin, xmax, ymax] 50 | 51 | 52 | def phrase2seq(phrase, word2idx, vocab): 53 | """ 54 | This function turns a sequence of words into a sequence of indexed tokens according to a vocabulary and its word2idx mapping dictionary. 55 | :param phrase: List of strings 56 | :param word2idx: Dictionary 57 | :param vocab: Set or list containing entire vocabulary as strings 58 | :return: List of integers. Each integer being the token index of the word in the phrase according to the vocabulary. 59 | """ 60 | # and tokens 61 | phrase = [''] + phrase 62 | phrase = phrase + [''] 63 | 64 | phrase_seq = [0]*len(phrase) 65 | for i, word in enumerate(phrase): 66 | if word in vocab: 67 | phrase_seq[i] = word2idx[word] 68 | else: 69 | phrase_seq[i] = word2idx[''] 70 | return phrase_seq 71 | 72 | 73 | if __name__ == "__main__": 74 | build_vocab = True # If True, the vocabulary will be built again and saved. If False, the last vocabulary will be loaded. 75 | generate_annotations = True 76 | generate_visual = False 77 | 78 | data_folder = "C:/Data/GroundeR/data/" 79 | 80 | train_txt = open(data_folder + 'train.txt', 'r') 81 | val_txt = open(data_folder + 'val.txt', 'r') 82 | test_txt = open(data_folder + 'test.txt', 'r') 83 | 84 | # Load RESNET 85 | transform = transforms.Compose([ 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 88 | std=[0.229, 0.224, 0.225]), 89 | ]) 90 | 91 | resnet = models.resnet101(pretrained=True) 92 | resnet.eval() 93 | resnet.cuda() 94 | 95 | #################################################################### 96 | # BUILD VOCABULARY # 97 | #################################################################### 98 | if build_vocab: 99 | embeddings_path = "C:/Data/Embeddings/" + "glove.6B/glove.6B.50d.txt" 100 | print("Loading Glove Model") 101 | 102 | f = open(embeddings_path, 'rb') 103 | vocabulary = ['', '', '', ''] 104 | 105 | # Generate random vectors for start, end, pad and unk tokens 106 | word2vec = {} 107 | word2vec[''] = np.random.normal(loc=0.0, scale=1, size=50) 108 | word2vec[''] = np.random.normal(loc=0.0, scale=1, size=50) 109 | word2vec[''] = np.random.normal(loc=0.0, scale=1, size=50) 110 | word2vec[''] = np.random.normal(loc=0.0, scale=1, size=50) 111 | t0 = time.time() 112 | i = 0 113 | for line in f: 114 | splitLine = line.decode().split() 115 | word = splitLine[0] 116 | vocabulary.append(word) 117 | embedding = np.array([float(val) for val in splitLine[1:]]) 118 | word2vec[word] = embedding 119 | if i % 100 == 0: 120 | print("\r{} - t={:0.10f}s".format(i, time.time() - t0), end="") 121 | t0 = time.time() 122 | i += 1 123 | 124 | word2idx = {w: idx for (idx, w) in enumerate(vocabulary)} 125 | idx2word = {idx: w for (idx, w) in enumerate(vocabulary)} 126 | 127 
| idx2vec = {} 128 | for idx in idx2word.keys(): 129 | idx2vec[idx] = word2vec[idx2word[idx]] 130 | 131 | weight_matrix = np.zeros((len(vocabulary), 50)) 132 | for idx in range(weight_matrix.shape[0]): 133 | weight_matrix[idx, :] = idx2vec[idx] 134 | 135 | pickle.dump((vocabulary, word2idx, idx2word, weight_matrix), open("C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl", 'wb')) 136 | 137 | print("\nDone.", len(vocabulary), " words loaded!") 138 | 139 | #################################################################### 140 | # BUILD ANNOTATION INPUT DATA STRUCTURE # 141 | #################################################################### 142 | if generate_annotations or generate_visual: 143 | with open("C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl", "rb") as f: 144 | vocab, word2idx, idx2word, _ = pickle.load(f) 145 | vocab = set(vocab) 146 | 147 | # EXAMPLES LOOP 148 | phrase_count = 0 149 | img_ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')] 150 | N = len(img_ids) 151 | t0 = time.time() 152 | for img_n, img_id in enumerate(img_ids): 153 | img_id = img_id.replace('\n', '') 154 | img_path = data_folder + 'flickr30k-images' + img_id + '. jpg' 155 | 156 | # Get Sentence and Annotation info for this image 157 | corefData = get_sentence_data(data_folder + 'annotations/Sentences/'+img_id+'.txt') 158 | annotationData = get_annotations(data_folder + 'annotations/Annotations/'+img_id+'.xml') 159 | 160 | ids = [] 161 | bboxes = [] 162 | seqs = [] 163 | num_of_sentences = len(corefData) 164 | 165 | if generate_annotations: 166 | # DESCRIPTIONS LOOP 167 | for description in corefData: 168 | 169 | # PHRASES LOOP 170 | for phrase in description['phrases']: 171 | # Get object ID 172 | obj_id = phrase['phrase_id'] 173 | 174 | # Check if this phrase has a box assigned. If not, then skip phrase. 
175 | if obj_id not in list(annotationData['boxes'].keys()) or obj_id == '0': 176 | continue 177 | ids.append(obj_id) 178 | 179 | # Obtain box coordinates for this phrase 180 | boxes = annotationData['boxes'][obj_id] 181 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0] 182 | bboxes.append(box) 183 | 184 | # Turn phrase from sequence of strings into sequence of indexes 185 | phrase = phrase['phrase'].lower().split() 186 | phrase_seq = phrase2seq(phrase, word2idx, vocab) 187 | seqs.append(phrase_seq) 188 | phrase_count += 1 189 | 190 | image_annotations = {'bboxes': bboxes, 'ids': ids, 'seqs': seqs} 191 | with open('C:/Data/GroundeR/annotation_data/'+img_id+'.pkl', 'wb') as f: 192 | pickle.dump(image_annotations, f, protocol=pickle.HIGHEST_PROTOCOL) 193 | if img_n % 100 == 0: 194 | dt = time.time() - t0 195 | print(f"\rImage {img_n}/{N} - dt_100 = {dt} - Phrases: {phrase_count}", end="") 196 | t0 = time.time() 197 | 198 | if generate_visual: 199 | ############################################################# 200 | # EXTRACT IMAGE FEATURES (RESNET101) # 201 | ############################################################# 202 | 203 | img = Image.open(data_folder + 'flickr30k-images/' + img_id + '.jpg') 204 | 205 | # Build id to idx dictionary 206 | id2idx = {id: idx for (idx, id) in enumerate(list(annotationData['boxes'].keys()))} 207 | 208 | # LOOP THROUGH ALL BOXES 209 | objects = list(annotationData['boxes'].keys()) 210 | vis_matrix = np.zeros((len(objects), 1000), dtype=float) 211 | for id in objects: 212 | # For each Object: extract boxes and unify them 213 | boxes = annotationData['boxes'][id] 214 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0] 215 | # For each box: crop original img to box, resize crop to 224z224, normalise image 216 | box_img = crop_and_resize(img, box, (224, 224)) 217 | box_img = transform(box_img) 218 | box_img = box_img.unsqueeze(0) 219 | box_img = box_img.cuda() 220 | # Feed image to ResNet-101 and add to visual feature matrix 221 | vis_feature = resnet(box_img) 222 | vis_matrix[id2idx[id], :] = vis_feature.cpu().detach().numpy() 223 | 224 | np.save('C:/Data/GroundeR/visualfeatures_data/'+img_id, vis_matrix) 225 | with open('C:/Data/GroundeR/id2idx_data/'+img_id+'.pkl', 'wb') as f: 226 | pickle.dump(id2idx, f) 227 | -------------------------------------------------------------------------------- /train_unsupervised_withval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Aleix Cambray Roma 3 | Work done while at Imperial College London. 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | from torch.utils.data import DataLoader 10 | 11 | import time 12 | import copy 13 | import pickle 14 | import numpy as np 15 | from model import GroundeR 16 | import matplotlib.pyplot as plt 17 | from dataloader_unsupervised import UnsupervisedDataLoader 18 | 19 | 20 | def print_gpu_memory(first_string=None): 21 | if first_string is not None: 22 | print(first_string) 23 | print(' current memory allocated: {}'.format(torch.cuda.memory_allocated() / 1024 ** 2)) 24 | print(' max memory allocated: {}'.format(torch.cuda.max_memory_allocated() / 1024 ** 2)) 25 | print(' cached memory: {}'.format(torch.cuda.memory_cached() / 1024 ** 2)) 26 | 27 | 28 | def seq2phrase(seq): 29 | return " ".join([idx2word[idx.item()] for idx in seq]) 30 | 31 | 32 | if __name__ == '__main__': 33 | print("Running on {}".format(torch.cuda.get_device_name(0))) 34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | # device = 'cpu' 36 | print_gpu_memory() 37 | 38 | # Load vocabulary and embeddings 39 | # vocab_file = "C:/Data/GroundeR/vocabulary.pkl" 40 | vocab_file = "C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl" 41 | with open(vocab_file, "rb") as f: 42 | vocab, word2idx, idx2word, weight_matrix = pickle.load(f) 43 | vocab = list(word2idx.keys()) 44 | 45 | # Config 46 | vocab_size = len(vocab) 47 | embed_dim = 50 48 | h_dim = 100 49 | v_dim = 1000 50 | 51 | # Hyper-parameters 52 | regions = 25 53 | learning_rate = 0.00025 54 | epochs = 100 55 | batch_size = 30 56 | max_seq_length = 10 57 | L = 5 58 | 59 | # Build data pipeline providers 60 | dataset = UnsupervisedDataLoader(data_dir="C:/Data/GroundeR/", sample_list_file="flickr30k_train_val.txt", seq_length=max_seq_length) 61 | dataset_length = len(dataset) 62 | 63 | train_length = int(0.8*dataset_length) 64 | val_length = dataset_length - train_length 65 | train_ds, val_ds = torch.utils.data.random_split(dataset, [train_length, val_length]) 66 | 67 | print("Training dataset size: {}".format(len(train_ds))) 68 | print("Val dataset size: {}".format(len(val_ds))) 69 | train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=3) 70 | val_dataloader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=3) 71 | 72 | # Instantiate model and send to GPU if available 73 | model = GroundeR(vocab, vocab_size, embed_dim, h_dim, v_dim, regions, weight_matrix, train_embeddings=False) 74 | model = model.to(device) 75 | print_gpu_memory(" After model to GPU:") 76 | criterion_att = nn.NLLLoss() 77 | criterion_rec = nn.NLLLoss(ignore_index=0) 78 | # optimizer = optim.SGD(model.parameters(), lr=learning_rate) 79 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 80 | 81 | t0 = time.time() 82 | start = time.time() 83 | min_loss = 10000 84 | times = [] 85 | times_total = [] 86 | training_losses = [] 87 | training_accuracies = [] 88 | epoch_training_losses = [] 89 | epoch_training_accuracies = [] 90 | epoch_val_losses = [] 91 | epoch_val_accuracies = [] 92 | iter_num = [] 93 | for epoch in range(epochs): 94 | for phase in ['train', 'val']: 95 | if phase == 'train': 96 | dataloader = train_dataloader 97 | model.train() # Set model to training mode 98 | else: 99 | dataloader = val_dataloader 100 | model.eval() # Set model to evaluation mode 101 | 102 | epoch_samples = 0 103 | running_loss = 0.0 104 | running_corrects = 0.0 105 | batch_i = 0 106 | for vis_features, real_feat, 
encoder_input, decoder_input, decoder_target, mask, region_true, lengths_enc, lengths_dec in dataloader: 107 | torch.cuda.empty_cache() 108 | dt_load = (time.time() - t0) * 1000 109 | 110 | # Send all data to GPU if available 111 | t1 = time.time() 112 | vis_features = vis_features.to(device) 113 | real_feat = real_feat.to(device) 114 | encoder_input = encoder_input.to(device) 115 | decoder_input = decoder_input.to(device) 116 | decoder_target = decoder_target.to(device) 117 | mask = mask.to(device) 118 | region_true = region_true.to(device) 119 | # lengths_enc = lengths_enc.to(device) 120 | # lengths_dec = lengths_dec.to(device) 121 | dt_togpu = (time.time() - t1) * 1000 122 | print_gpu_memory(" After data to GPU:") 123 | 124 | # Zero the parameter gradients 125 | optimizer.zero_grad() 126 | 127 | # Forward Pass 128 | t2 = time.time() 129 | decoder_output, att, att_log = model(encoder_input, lengths_enc, vis_features, decoder_input, lengths_dec) 130 | dt_forward = (time.time() - t2) * 1000 131 | print_gpu_memory(" After forward run:") 132 | 133 | # Loss 134 | # TODO: [done] Mask loss to ignore pad tokens (using the mask tensor for each sample) 135 | t3 = time.time() 136 | pred = decoder_output.view(decoder_output.size(0)*decoder_output.size(1), decoder_output.size(2)) 137 | target = decoder_target[:, :decoder_output.size(1)] # truncate the decoder_target from length 10 to maximum sequence length in batch 138 | target = target.contiguous().view(-1) 139 | loss_att = criterion_att(att_log, region_true) 140 | # loss_rec = model.masked_NLLLoss(decoder_output, decoder_target) 141 | loss_rec = criterion_rec(pred, target) 142 | loss = L*loss_att + loss_rec 143 | dt_loss = (time.time() - t3) * 1000 144 | print_gpu_memory(" After loss calc:") 145 | 146 | # Accuracy 147 | region_pred = att.max(dim=1)[1] 148 | corrects = torch.sum(region_true == region_pred).item() 149 | running_corrects += corrects 150 | accuracy = corrects / region_true.size(0) 151 | 152 | if phase == 'train': 153 | # Backward pass and parameter update 154 | # TODO: [done] Find out why loss isn't getting updates (initialisation?) 
155 | t4 = time.time() 156 | loss.backward() 157 | print_gpu_memory(" After backward run:") 158 | optimizer.step() 159 | print_gpu_memory(" After optimizer step:") 160 | dt_backward = (time.time() - t4) * 1000 161 | print("{:02.0f}.{:03.0f} - Sample {:05.0f}/30781 - Accuracy: {:0.2f}% - Loss: {:02.7f} - GPU: {:0.2f} MB - Load time = {:06.2f}ms - toGPU time = {:06.2f}ms - Forward time = {:06.2f}ms - Loss: {:06.2f}ms - Backward {:06.2f}ms - Time {:05.2f}s".format(epoch, batch_i + 1, (batch_i + 1) * batch_size, accuracy*100, loss.item(), torch.cuda.memory_allocated() / 1024 ** 2, dt_load, dt_togpu, dt_forward, dt_loss, dt_backward, (time.time()-start))) 162 | training_losses.append(loss.item()) # Appends loss over entire batch (reduce=mean) 163 | training_accuracies.append(accuracy) 164 | times.append(dt_load + dt_togpu) 165 | times_total.append(time.time() - t0) 166 | 167 | # statistics & counters 168 | epoch_samples += vis_features.size(0) 169 | running_loss += loss.item() * vis_features.size(0) # Sum of losses over all samples in batch 170 | 171 | batch_i += 1 172 | t0 = time.time() 173 | 174 | # Print real and predicted sentences: 175 | for i_print in range(3): 176 | print(" {} - Real: {}".format(i_print, seq2phrase(decoder_target[i_print]))) 177 | print(" {} - Pred: {}".format(i_print, seq2phrase(decoder_output.max(dim=2)[1][i_print]))) 178 | 179 | # Track epoch losses for both training and validation phases 180 | if phase == 'train': 181 | print("Training Epoch Performance on {} samples".format(epoch_samples)) 182 | epoch_loss = running_loss / epoch_samples 183 | epoch_training_losses.append(epoch_loss) 184 | epoch_training_accuracies.append(running_corrects/epoch_samples) 185 | iter_num.append(len(training_losses)) 186 | elif phase == 'val': 187 | print("Validation Performance on {} samples".format(epoch_samples)) 188 | epoch_loss = running_loss / epoch_samples 189 | epoch_val_losses.append(epoch_loss) 190 | epoch_val_accuracies.append(running_corrects/epoch_samples) 191 | # Save weights of best val-performing model 192 | if epoch_loss < min_loss: 193 | min_loss = epoch_loss 194 | best_model_wts = copy.deepcopy(model.state_dict()) 195 | torch.save(model.state_dict(), "best_model.pt") 196 | 197 | # Print learning profile every 5 epochs 198 | if phase == 'val': 199 | plt.title("LR: {}".format(learning_rate)) 200 | plt.plot(training_losses, label="Training Losses") 201 | plt.plot(iter_num, epoch_training_losses, marker='o', label="Epoch Training losses") 202 | plt.plot(iter_num, epoch_val_losses, marker='o', label="Epoch Val Losses") 203 | plt.legend() 204 | plt.savefig('learning_profile.png') 205 | plt.clf() 206 | plt.plot(training_accuracies, label="Training Accuracies per batch") 207 | plt.plot(iter_num, epoch_training_accuracies, label="Epoch Training Accuracies") 208 | plt.plot(iter_num, epoch_val_accuracies, label="Epoch Val Accuracies") 209 | plt.title("Region Accuracies") 210 | plt.savefig('accuracies.png') 211 | plt.clf() 212 | 213 | plt.plot(times) 214 | plt.savefig('final_times.png') 215 | 216 | plt.plot(times_total) 217 | plt.savefig('final_times_total.png') 218 | 219 | plt.plot(training_losses, label="Training Losses") 220 | plt.plot(iter_num, epoch_training_losses, label="Epoch Training losses") 221 | plt.plot(iter_num, epoch_val_losses, label="Epoch Val Losses") 222 | plt.legend() 223 | plt.savefig('final_learning_profile.png') 224 | 225 | print("\nTime per sample: {}".format((time.time()-t0)/30781)) 
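The training loop above saves its best validation weights to best_model.pt. Below is a rough sketch of how those weights could be used to ground a single phrase at inference time. Paths follow the conventions used in the scripts above; the sample id, the example phrase and the index-0 fallback for unknown words are hypothetical placeholders, and the real dataloader additionally adds start/end tags to the phrase.

```python
import pickle
import numpy as np
import torch
from model import GroundeR

data_dir = "C:/Data/GroundeR/"
with open("C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl", "rb") as f:
    vocab, word2idx, idx2word, weight_matrix = pickle.load(f)

model = GroundeR(vocab, len(vocab), embed_dim=50, h_dim=100, v_dim=1000,
                 regions=25, embeddings_matrix=weight_matrix, train_embeddings=False)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()

sample_id = "36979"                      # hypothetical Flickr30k image id
phrase = ["a", "man", "in", "a", "hat"]  # hypothetical query phrase

# Region features produced by process_dataset.py, zero-padded to 25 regions.
vis = np.zeros((1, 25, 1000), dtype="float32")
feats = np.load(data_dir + "visualfeatures_data/" + sample_id + ".npy")
vis[0, :feats.shape[0], :] = feats

# Encode the phrase as word indices (falling back to index 0 for unknown words).
idxs = [word2idx.get(w, 0) for w in phrase]
enc = torch.zeros((1, 10), dtype=torch.long)
enc[0, :len(idxs)] = torch.tensor(idxs)
lengths = torch.tensor([len(idxs)])

with torch.no_grad():
    _, att, _ = model(enc, lengths, torch.from_numpy(vis), enc, lengths)
region_idx = att.argmax(dim=1).item()
# id2idx_data/<sample_id>.pkl maps object ids to these region indices, and
# annotation_data/<sample_id>.pkl holds the corresponding bounding boxes.
print("predicted region index:", region_idx)
```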
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | 3 | def get_sentence_data(fn): 4 | """ 5 | Parses a sentence file from the Flickr30K Entities dataset 6 | 7 | input: 8 | fn - full file path to the sentence file to parse 9 | 10 | output: 11 | a list of dictionaries for each sentence with the following fields: 12 | sentence - the original sentence 13 | phrases - a list of dictionaries for each phrase with the 14 | following fields: 15 | phrase - the text of the annotated phrase 16 | first_word_index - the position of the first word of 17 | the phrase in the sentence 18 | phrase_id - an identifier for this phrase 19 | phrase_type - a list of the coarse categories this 20 | phrase belongs to 21 | 22 | """ 23 | with open(fn, 'r', encoding="utf8") as f: 24 | sentences = f.read().split('\n') 25 | 26 | annotations = [] 27 | for sentence in sentences: 28 | if not sentence: 29 | continue 30 | 31 | first_word = [] 32 | phrases = [] 33 | phrase_id = [] 34 | phrase_type = [] 35 | words = [] 36 | current_phrase = [] 37 | add_to_phrase = False 38 | for token in sentence.split(): 39 | if add_to_phrase: 40 | if token[-1] == ']': 41 | add_to_phrase = False 42 | token = token[:-1] 43 | current_phrase.append(token) 44 | phrases.append(' '.join(current_phrase)) 45 | current_phrase = [] 46 | else: 47 | current_phrase.append(token) 48 | 49 | words.append(token) 50 | else: 51 | if token[0] == '[': 52 | add_to_phrase = True 53 | first_word.append(len(words)) 54 | parts = token.split('/') 55 | phrase_id.append(parts[1][3:]) 56 | phrase_type.append(parts[2:]) 57 | else: 58 | words.append(token) 59 | 60 | sentence_data = {'sentence' : ' '.join(words), 'phrases' : []} 61 | for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type): 62 | sentence_data['phrases'].append({'first_word_index' : index, 63 | 'phrase' : phrase, 64 | 'phrase_id' : p_id, 65 | 'phrase_type' : p_type}) 66 | 67 | annotations.append(sentence_data) 68 | 69 | return annotations 70 | 71 | def get_annotations(fn): 72 | """ 73 | Parses the xml files in the Flickr30K Entities dataset 74 | 75 | input: 76 | fn - full file path to the annotations file to parse 77 | 78 | output: 79 | dictionary with the following fields: 80 | scene - list of identifiers which were annotated as 81 | pertaining to the whole scene 82 | nobox - list of identifiers which were annotated as 83 | not being visible in the image 84 | boxes - a dictionary where the fields are identifiers 85 | and the values are its list of boxes in the 86 | [xmin ymin xmax ymax] format 87 | """ 88 | tree = ET.parse(fn) 89 | root = tree.getroot() 90 | size_container = root.findall('size')[0] 91 | anno_info = {'boxes' : {}, 'scene' : [], 'nobox' : []} 92 | for size_element in size_container: 93 | anno_info[size_element.tag] = int(size_element.text) 94 | 95 | for object_container in root.findall('object'): 96 | for names in object_container.findall('name'): 97 | box_id = names.text 98 | box_container = object_container.findall('bndbox') 99 | if len(box_container) > 0: 100 | if box_id not in anno_info['boxes']: 101 | anno_info['boxes'][box_id] = [] 102 | xmin = int(box_container[0].findall('xmin')[0].text) - 1 103 | ymin = int(box_container[0].findall('ymin')[0].text) - 1 104 | xmax = int(box_container[0].findall('xmax')[0].text) - 1 105 | ymax = int(box_container[0].findall('ymax')[0].text) - 1 106 | 
anno_info['boxes'][box_id].append([xmin, ymin, xmax, ymax]) 107 | else: 108 | nobndbox = int(object_container.findall('nobndbox')[0].text) 109 | if nobndbox > 0: 110 | anno_info['nobox'].append(box_id) 111 | 112 | scene = int(object_container.findall('scene')[0].text) 113 | if scene > 0: 114 | anno_info['scene'].append(box_id) 115 | 116 | return anno_info 117 | 118 | 119 | -------------------------------------------------------------------------------- /visualfeatures_data/readme.m: -------------------------------------------------------------------------------- 1 | processed visual feature matrices (one per image on Flickr30k) from process_dataset.py 2 | --------------------------------------------------------------------------------