├── LICENSE
├── README.md
├── annotation_data
│   └── readme.m
├── dataloader_unsupervised.py
├── id2idx_data
│   └── readme.m
├── images
│   ├── readme.m
│   └── schematic.png
├── model.py
├── process_dataset.py
├── process_dataset_embeddings.py
├── train_unsupervised_withval.py
├── utils.py
└── visualfeatures_data
    └── readme.m
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Aleix Cambray
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GroundeR PyTorch Implementation
2 | >*Note: This project is functional but not polished for release, and the description and documentation below are likewise unpolished. No further work is expected in the near future.
3 | It is uploaded in the hope that it is useful to you in its current form.*
4 |
5 | This is a PyTorch implementation of the supervised and unsupervised GroundeR model from [Grounding of Textual Phrases in Images by Reconstruction](https://arxiv.org/pdf/1511.03745.pdf).
6 |
7 | The task is to localize which region of an image a phrase refers to. For example, given the description "**A man** is jumping over **a fence**", we would like to ground both entities to specific regions of the image. Phrase localization (or phrase grounding) is useful for problems that depend on these unknown mappings between entities and regions, such as image captioning and multi-modal neural machine translation.
8 |
9 | ## Requisites
10 | - Python 3.6
11 | - PyTorch 1.0.1
12 | - torchvision
13 | - NumPy
14 | - PIL
15 | - pickle
16 | - matplotlib
17 |
18 | ## Supervised and Unsupervised versions
20 | Both versions are described with reference to the schematic shown below.
20 |
21 | The supervised version of the code assumes we have ground truth for which region of the image the phrase refers to. The loss is therefore computed as the cross-entropy between the visual attention vector (alpha) and the one-hot ground-truth vector. In the fully supervised setting there is no need to run the decoder part of the model.
22 |
23 | The unsupervised version, on the other hand, does not assume ground truth for the correct region and instead relies on the reconstruction loss: the cross-entropy between the original phrase and the output of the decoder.
24 |
25 | ## Model Design and Architecture
26 | Below is a schematic of the exact implementation (and design decisions) of the code in this repository.
27 |
28 | ![Schematic of the GroundeR implementation](images/schematic.png)
29 |
30 |
31 |
--------------------------------------------------------------------------------
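A minimal, self-contained sketch (not part of the repository) of the two training signals described in the README's "Supervised and Unsupervised versions" section above. It assumes log-space attention scores over regions and log-space decoder outputs with the shapes used elsewhere in this code (`att_log`: [batch, regions], `decoder_output`: [batch, seq_len, vocab_size]); all tensors and sizes here are illustrative.

```python
import torch
import torch.nn as nn

batch, regions, seq_len, vocab_size = 4, 25, 10, 400                         # illustrative sizes
att_log = torch.log_softmax(torch.randn(batch, regions), dim=1)              # attention (alpha) in log space
decoder_output = torch.log_softmax(torch.randn(batch, seq_len, vocab_size), dim=2)

# Supervised: cross-entropy between the attention distribution and the ground-truth region index
region_true = torch.randint(0, regions, (batch,))
loss_supervised = nn.NLLLoss()(att_log, region_true)

# Unsupervised: reconstruction cross-entropy between the decoder output and the original phrase,
# ignoring the padding index 0, as done in train_unsupervised_withval.py
decoder_target = torch.randint(1, vocab_size, (batch, seq_len))
loss_unsupervised = nn.NLLLoss(ignore_index=0)(decoder_output.view(-1, vocab_size), decoder_target.view(-1))
```

In train_unsupervised_withval.py the two terms are combined as `loss = L*loss_att + loss_rec`, with `L` weighting the attention term.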
/annotation_data/readme.m:
--------------------------------------------------------------------------------
1 | processed annotation and co-reference data (from process_dataset.py)
2 |
--------------------------------------------------------------------------------
/dataloader_unsupervised.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Aleix Cambray Roma
3 | Work done while at Imperial College London.
4 | """
5 | import torch
6 | import pickle
7 | import numpy as np
8 | from numpy.random import randint
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class UnsupervisedDataLoader(Dataset):
13 | """ Loads Flickr30k data for the unsupervised, reconstruction case. """
14 | def __init__(self, data_dir, sample_list_file, seq_length):
15 | def read_sample_list(sample_list_file):
16 | f = open(data_dir + sample_list_file)
17 | return np.array([sample_id.strip() for sample_id in f.readlines()])
18 |
19 | self.seq_length = seq_length
20 | self.sample_id_list = read_sample_list(sample_list_file)
21 | self.data_dir = data_dir
22 |
23 | with open(self.data_dir + "vocabulary.pkl", "rb") as f:
24 | self.word2idx, self.idx2word = pickle.load(f)
25 | self.vocab = list(self.word2idx.keys())
26 |
27 | def __len__(self):
28 | return len(self.sample_id_list)
29 |
30 | def __getitem__(self, i):
31 | sample_id = self.sample_id_list[i]
32 |
33 | # Get Visual feature matrix
34 | vis_features = np.zeros((25, 1000), dtype='float32')
35 | real_feat = np.load(self.data_dir + "visualfeatures_data/" + sample_id + ".npy")
36 | vis_features[:real_feat.shape[0], :] = real_feat
37 | real_feat = real_feat.shape[0]  # keep only the number of real (non-padded) regions
38 |
39 | # Get annotations (all phrases in this image)
40 | with open(self.data_dir + "annotation_data/" + sample_id + ".pkl", "rb") as f:
41 | annotations = pickle.load(f)
42 | phrases = annotations['seqs']
43 |
44 | # Select random phrase from this image
45 | # print(len(phrases))
46 | if len(phrases) == 0:
47 | print(sample_id)
48 | print(real_feat)
49 | print("")
50 |
51 | rand_int = randint(len(phrases))
52 | phrase = phrases[rand_int]
53 |
54 | if len(phrase) > self.seq_length:
55 | phrase[self.seq_length-1] = phrase[-1] # Add end tag at last step
56 | phrase = phrase[:self.seq_length] # Truncate phrase to seq_length
57 |
58 | # Initialise all phrase arrays with the padding symbol
59 | pad_idx = self.word2idx['<pad>']
60 | encoder_input = np.ones(self.seq_length, dtype='int64')*pad_idx
61 | decoder_input = np.ones(self.seq_length, dtype='int64')*pad_idx
62 | decoder_target = np.ones(self.seq_length, dtype='int64')*pad_idx
63 | mask = np.zeros(self.seq_length, dtype='int64')
64 |
65 | # Replace the first padding symbols by the real phrase
66 | encoder_input[0:len(phrase)] = np.array(phrase)           # Feed to encoder LSTM: both <s> and </s> tags
67 | decoder_input[0:len(phrase)-2] = np.array(phrase[1:-1])   # Feed to decoder LSTM: no <s>/</s> tags
68 | decoder_target[0:len(phrase)-1] = np.array(phrase[1:])    # Used as reconstruction ground truth: only </s> tag
69 | mask[0:len(phrase)-1] = 1
70 |
71 | with open(self.data_dir + 'id2idx_data/' + sample_id + '.pkl', 'rb') as f:
72 | id2idx = pickle.load(f)
73 |
74 | phrase_id = annotations['ids'][rand_int]
75 | true_region = id2idx[phrase_id]
76 | # Sentence as list of string-words (visualise)
77 | # phrase_word = np.array([self.idx2word[idx] for i, idx in enumerate(encoder_input) if i < len(phrase)])
78 |
79 | # print("Shapes:")
80 | # print(" - vis_features: {}".format(vis_features.shape))
81 | # print(" - encoder_input: {}".format(encoder_input.shape))
82 | # print(" - decoder_input: {}".format(decoder_input.shape))
83 | # print(" - decoder_target: {}".format(decoder_target.shape))
84 |
85 | # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
86 |
87 | # print('current memory allocated: {}'.format(torch.cuda.memory_allocated() / 1024 ** 2))
88 | # print('max memory allocated: {}'.format(torch.cuda.max_memory_allocated() / 1024 ** 2))
89 | # print('cached memory: {}'.format(torch.cuda.memory_cached() / 1024 ** 2))
90 |
91 | # vis_features = torch.from_numpy(vis_features)
92 | # real_feat = torch.tensor(real_feat).to(device)
93 | # encoder_input = torch.from_numpy(encoder_input)
94 | # decoder_input = torch.from_numpy(decoder_input)
95 | # decoder_target = torch.from_numpy(decoder_target)
96 | # mask = torch.from_numpy(mask)
97 |
98 | return vis_features, real_feat, encoder_input, decoder_input, decoder_target, mask, true_region, len(phrase), len(phrase)-1
99 |
--------------------------------------------------------------------------------
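For illustration only (not part of the repository): this is how `__getitem__` above lays out the phrase arrays for a short phrase, assuming `seq_length=6`, a padding index of 0, and a hypothetical token sequence that already carries the start/end tokens.

```python
import numpy as np

seq_length, pad_idx = 6, 0
phrase = [1, 7, 42, 2]                                     # hypothetical indices for <s>, "a", "man", </s>

encoder_input = np.ones(seq_length, dtype='int64') * pad_idx
decoder_input = np.ones(seq_length, dtype='int64') * pad_idx
decoder_target = np.ones(seq_length, dtype='int64') * pad_idx
mask = np.zeros(seq_length, dtype='int64')

encoder_input[0:len(phrase)] = np.array(phrase)            # [1, 7, 42, 2, 0, 0]  full phrase with both tags
decoder_input[0:len(phrase)-2] = np.array(phrase[1:-1])    # [7, 42, 0, 0, 0, 0]  words only, no tags
decoder_target[0:len(phrase)-1] = np.array(phrase[1:])     # [7, 42, 2, 0, 0, 0]  shifted by one, ends with the end tag
mask[0:len(phrase)-1] = 1                                  # [1, 1, 1, 0, 0, 0]  marks positions that count towards the loss
```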
/id2idx_data/readme.m:
--------------------------------------------------------------------------------
1 | from process_dataset.py
2 |
3 | Each file in this folder corresponds to an image of the dataset and contains a dictionary which maps an object_ID to its corresponding index in the visual feature matrix.
4 |
--------------------------------------------------------------------------------
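A hypothetical example (object IDs invented for illustration) of the mapping described above: each value indexes a row of the corresponding image's visual feature matrix.

```python
import numpy as np

id2idx = {'283585': 0, '283586': 1, '283590': 2}   # hypothetical Flickr30k Entities object IDs
vis_matrix = np.zeros((len(id2idx), 1000))         # one 1000-d ResNet-101 feature vector per region

region_features = vis_matrix[id2idx['283586']]     # feature vector of the region grounding that object
```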
/images/readme.m:
--------------------------------------------------------------------------------
1 | images for project readme
2 |
--------------------------------------------------------------------------------
/images/schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acambray/GroundeR-PyTorch/20f241f15ce31aa55ec059eb5be564c7c266ed2b/images/schematic.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Aleix Cambray Roma
3 | Work done while at Imperial College London.
4 | """
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
10 |
11 |
12 | class PhraseEncoder(nn.Module):
13 | """
14 | Phrase Encoder
15 | Input is sequence of word indices
16 | Output is last LSTM hidden state
17 | """
18 | def __init__(self, embedding_layer, vocab_size, embed_dim, h_dim):
19 | super(PhraseEncoder, self).__init__()
20 |
21 | # Word embedding layers
22 | # self.embedding = nn.Embedding(vocab_size, embed_dim)
23 | self.embedding = embedding_layer
24 |
25 | # LSTM layer
26 | self.lstm = nn.LSTM(embed_dim, h_dim, batch_first=True)
27 |
28 | def forward(self, phrase_batch, lengths_enc):
29 | """
30 | Embedding turns sequences of indices (phrase) into sequence of vectors
31 | :param phrase_batch: (BATCH, TIME) i.e. if batch_size=2 and seq_length=5 phrase=[[1,5,4,7,3], [8,2,5,2,4]]
32 | :return: last LSTM output
33 | """
34 |
35 | # phrase dimensions: (BATCH, TIME) i.e. if batch_size=2 and seq_length=5 phrase=[[1,5,4,7,3], [8,2,5,2,4]]
36 | # embeds dimensions:
37 |
38 | batch_size = phrase_batch.size()[0]
39 | seq_length = phrase_batch.size()[1]
40 | embeds = self.embedding(phrase_batch)
41 |
42 | # Pack
43 | # 1. sort sequences by length
44 | ordered_len, ordered_idx = lengths_enc.sort(0, descending=True)
45 | ordered_embeds = embeds[ordered_idx]
46 | # 2. pack
47 | input_packed = pack_padded_sequence(ordered_embeds, ordered_len, batch_first=True)
48 |
49 | # Feed embeddings to LSTM
50 | # embeds = embeds.view(seq_length, batch_size, -1)
51 | _, (ht, ct) = self.lstm(input_packed) # Hidden is none because we don't initialise the hidden state, could use random noise instead
52 |
53 | # Get final hidden state from LSTM and reverse descending ordering
54 | h = ht[0, :, :]
55 | h = torch.zeros_like(h).index_copy_(0, ordered_idx, h)  # restore original (pre-sort) batch order without aliased in-place indexing
56 | return h
57 |
58 |
59 | class AttentionModule(nn.Module):
60 | def __init__(self, h_dim, v_dim, embed_dim, out1, regions):
61 | super(AttentionModule, self).__init__()
62 | self.fc1 = nn.Linear(h_dim + v_dim, out1, bias=True)
63 | self.fc2 = nn.Linear(out1, 1, bias=True)
64 | self.fcREC = nn.Linear(v_dim, embed_dim, bias=True)
65 |
66 | def forward(self, h, v):
67 | """
68 | :param h: encoded phrases dimensions [batch, h_dim] [32, 100]
69 | :param v: visual features matrix dimensions [batch_size, regions, v_dim] [32, 25, 1000]
70 | :return:
71 | """
72 | batch_size = h.size()[0]
73 | regions = v.size()[1]
74 | h_dim = h.size()[1]
75 | v_dim = v.size()[2]
76 |
77 | # We want to turn h from [batch_size, h_dim] to [batch_size, regions, 1, h_dim]
78 | h1 = h[:, None, None, :] # shape: (batch_size, 1, 1, h_dim)
79 | h_reshaped = h1.repeat(1, regions, 1, 1).type(dtype=torch.float32) # shape: (batch_size, regions, 1, h_dim)
80 |
81 | # Add extra dimension to match
82 | v_reshaped = v[:, :, None, :].type(dtype=torch.float32) # shape: (batch_size, regions, 1, v_dim)
83 |
84 | hv = torch.cat((v_reshaped, h_reshaped), dim=3)
85 |
86 | # View hv (batch_size, regions, 1, hv_dim) to (batch_size*regions, hv_dim)
87 | x = hv.view((hv.size(0)*hv.size(1), hv.size(3)))
88 | x = self.fc1(x) # [ batch*reg, hidden ]
89 | x = F.relu(x)
90 | x = self.fc2(x) # [ batch*reg, 1 ]
91 | x = x.view(hv.size(0), hv.size(1)) # [ batch, reg ]
92 | alpha = F.softmax(x, dim=1) # [ batch, reg ]
93 | att_log = F.log_softmax(x, dim=1)
94 |
95 | # Sum of the elementwise product between alpha and v
96 | alpha_expanded = alpha[:, :, None].expand_as(v)
97 | v_att = torch.sum(torch.mul(v, alpha_expanded), dim=(1,))
98 |
99 | v_att_dec = self.fcREC(v_att)
100 | v_att_dec = F.relu(v_att_dec)
101 | return v_att, v_att_dec, alpha, att_log
102 |
103 |
104 | class PhraseDecoder(nn.Module):
105 | def __init__(self, embedding_layer, embed_dim, h_dim, vocab_size, v_dim):
106 | super(PhraseDecoder, self).__init__()
107 |
108 | self.embedding = embedding_layer
109 |
110 | self.decoder = nn.LSTM(embed_dim, h_dim, batch_first=True)
111 | self.hidden_to_prob = nn.Linear(h_dim, vocab_size)
112 |
113 | def forward(self, v_att_dec, decoder_input, lengths_dec, teacher_forcing=1):
114 | # TODO: Implement greedy strategy (using 'teacher forcing' at the moment)
115 |
116 | embeds = self.embedding(decoder_input)
117 |
118 | # Concatenate the visual attended context vector and the phrase embedding as the input to the decoder
119 | decoder_input = torch.cat((v_att_dec[:, None, :], embeds[:, :-1, :]), dim=1) # indexes: [batch, timestep, word]
120 |
121 | # Re-order axis to fit PyTorch LSTM convention (seq_length, batch, input_size)
122 | # decoder_input = decoder_input.permute((1, 0, 2))
123 |
124 | # Pack --------------------------------------------------------------------------------------
125 | # 1. sort sequences by length
126 | ordered_len, ordered_idx = lengths_dec.sort(0, descending=True)
127 | ordered_embeds = decoder_input[ordered_idx]
128 | # 2. pack
129 | input_packed = pack_padded_sequence(ordered_embeds, ordered_len, batch_first=True)
130 |
131 | # LSTM --------------------------------------------------------------------------------------
132 | output_packed, hidden = self.decoder(input_packed)
133 |
134 | # Unpack all hidden states-------------------------------------------------------------------
135 | output_sorted, _ = pad_packed_sequence(output_packed, batch_first=True)
136 |
137 | # Fully Connected (hidden state to vocabulary probabilities) --------------------------------
138 | output = torch.zeros_like(output_sorted)
139 | output[ordered_idx] = output_sorted # shape: batch_size x seq_len x hidden
140 |
141 | input_fc = output.contiguous().view(output.size(0)*output.size(1), output.size(2)) # shape: batch_size*seq_len x hidden
142 | out = self.hidden_to_prob(input_fc) # shape: batch_size*seq_len x vocab_size
143 | out = out.view(output.size(0), output.size(1), out.size(1))                       # shape: batch_size x seq_len x vocab_size
144 | decoder_output = F.log_softmax(out, dim=2)
145 |
146 | return decoder_output
147 |
148 |
149 | class GroundeR(nn.Module):
150 | def __init__(self, vocab, vocab_size, embed_dim, h_dim, v_dim, regions, embeddings_matrix=None, train_embeddings=False):
151 | super(GroundeR, self).__init__()
152 |
153 | if embeddings_matrix is None:
154 | self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=False)
155 | else:
156 | self.embedding = nn.Embedding.from_pretrained(torch.from_numpy(embeddings_matrix).type(torch.float32), freeze=not train_embeddings, sparse=True)
157 | if train_embeddings is False:
158 | self.embedding.weight.requires_grad = False
159 |
160 | self.phrase_encoder = PhraseEncoder(self.embedding, vocab_size, embed_dim, h_dim)
161 | self.attention = AttentionModule(h_dim, v_dim, embed_dim, 100, regions)
162 | self.phrase_decoder = PhraseDecoder(self.embedding, embed_dim, h_dim, vocab_size, v_dim)
163 |
164 | self.vocab = vocab
165 | self.vocab_size = len(vocab)
166 |
167 | def forward(self, encoder_input, lengths_enc, vis_features, decoder_input, lengths_dec):
168 |
169 | # Encode
170 | encoded_batch = self.phrase_encoder(encoder_input, lengths_enc)
171 |
172 | # Attend
173 | v_att, v_att_dec, att, att_log = self.attention(encoded_batch, vis_features)
174 |
175 | # Decode
176 | decoder_output = self.phrase_decoder(v_att_dec, decoder_input, lengths_dec, teacher_forcing=1)
177 |
178 | return decoder_output, att, att_log
179 |
180 | def masked_NLLLoss(self, pred, target, pad_token=0):
181 | max_pred_len = pred.size(1)
182 | Y = target[:, :max_pred_len]
183 | Y = Y.contiguous().view(-1)
184 |
185 | Y_hat = pred.view(-1, pred.size(2))
186 |
187 | mask = (Y != pad_token).float()
188 | Y_hat_masked = torch.zeros_like(Y, dtype=torch.float32)
189 |
190 | for i, idx in enumerate(Y):
191 | Y_hat_masked[i] = Y_hat[i, idx]
192 |
193 | Y_hat_masked = Y_hat_masked * mask
194 |
195 | n_tokens = torch.sum(mask)
196 |
197 | loss = -torch.sum(Y_hat_masked) / n_tokens
198 | return loss
199 |
--------------------------------------------------------------------------------
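A minimal usage sketch (not part of the repository) showing the tensor shapes that the GroundeR forward pass expects. The toy vocabulary, random inputs and sizes below are illustrative; `embeddings_matrix` is left as None so random embeddings are used.

```python
import torch
from model import GroundeR

vocab = ['<pad>', '<s>', '</s>', '<unk>', 'a', 'man', 'fence']    # toy vocabulary (hypothetical)
model = GroundeR(vocab, vocab_size=len(vocab), embed_dim=50, h_dim=100,
                 v_dim=1000, regions=25)

batch, seq_len = 2, 10
encoder_input = torch.randint(0, len(vocab), (batch, seq_len))    # padded phrases as word indices
decoder_input = torch.randint(0, len(vocab), (batch, seq_len))
lengths_enc = torch.tensor([6, 4])                                # phrase lengths incl. start/end tokens
lengths_dec = torch.tensor([5, 3])                                # phrase lengths minus one
vis_features = torch.randn(batch, 25, 1000)                       # one ResNet-101 vector per candidate region

decoder_output, att, att_log = model(encoder_input, lengths_enc, vis_features,
                                     decoder_input, lengths_dec)
# decoder_output: [batch, max(lengths_dec), vocab_size] log-probabilities over words
# att / att_log:  [batch, regions] attention over regions (softmax and log-softmax)
```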
/process_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Aleix Cambray Roma
3 | Work done while at Imperial College London.
4 |
5 | This code takes in the Flickr30k Entities dataset and generates the data structures necessary for the GroundeR implementation
6 | Output data structures (for each image in dataset)
7 | - Annotation dictionary   annotation_data/<image_id>.pkl      - A dictionary containing ids, bboxes and token_index_sequences for each phrase
8 | - Visual features matrix  visualfeatures_data/<image_id>.npy  - A matrix in which each row is the visual feature vector of one region
9 | - Link id2idx dict        id2idx_data/<image_id>.pkl          - Matches the object id of each phrase to the index of the correct region (bbox)
10 | """
11 |
12 | from utils import *
13 | import PIL
14 | from PIL import Image, ImageDraw
15 | import time
16 | import numpy as np
17 | import torchvision.models as models
18 | import torchvision.transforms as transforms
19 | import os
20 | import itertools
21 | import glob
22 | import pickle
23 |
24 |
25 | def crop_and_resize(img, coords, size):
26 | """
27 | PIL image 'img' is first cropped to a rectangle according to 'coords'. Then the resulting crop is re-scaled according to 'size'
28 | :param img: PIL image
29 | :param coords: rectangle coordinates in the form of a list [x_nw, y_nw, x_se, y_se]
30 | :param size: tuple or list [h_size, v_size]
31 | :return:
32 | """
33 | img = img.crop(coords)
34 | img = img.resize(size, PIL.Image.LANCZOS)
35 | return img
36 |
37 |
38 | def unify_boxes(boxes):
39 | """
40 | This function turns a bunch of bounding boxes into one bounding box that encompasses all of them
41 | :param boxes: List of lists, each sub-list is a bounding box [x_nw, y_nw, x_se, y_se]
42 | :return: 1 bounding box
43 | """
44 | boxes = np.array(boxes)
45 | xmin = np.amin(boxes[:, 0])
46 | ymin = np.amin(boxes[:, 1])
47 | xmax = np.amax(boxes[:, 2])
48 | ymax = np.amax(boxes[:, 3])
49 | return [xmin, ymin, xmax, ymax]
50 |
51 |
52 | def phrase2seq(phrase, word2idx, vocab):
53 | """
54 | This function turns a sequence of words into a sequence of indexed tokens according to a vocabulary and its word2idx mapping dictionary.
55 | :param phrase: List of strings
56 | :param word2idx: Dictionary
57 | :param vocab: Set or list containing entire vocabulary as strings
58 | :return: List of integers. Each integer being the token index of the word in the phrase according to the vocabulary.
59 | """
60 | # Add <s> and </s> tokens
61 | phrase = ['<s>'] + phrase
62 | phrase = phrase + ['</s>']
63 |
64 | phrase_seq = [0]*len(phrase)
65 | for i, word in enumerate(phrase):
66 | if word in vocab:
67 | phrase_seq[i] = word2idx[word]
68 | else:
69 | phrase_seq[i] = word2idx['<unk>']
70 | return phrase_seq
71 |
72 |
73 | if __name__ == "__main__":
74 | build_vocab = True # If True, the vocabulary will be built again and saved. If False, the last vocabulary will be loaded.
75 | generate_annotations = True
76 | generate_visual = False
77 |
78 | data_folder = "C:/Data/GroundeR/data/"
79 |
80 | train_txt = open(data_folder + 'train.txt', 'r')
81 | val_txt = open(data_folder + 'val.txt', 'r')
82 | test_txt = open(data_folder + 'test.txt', 'r')
83 |
84 | # Load RESNET
85 | transform = transforms.Compose([
86 | transforms.ToTensor(),
87 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 | std=[0.229, 0.224, 0.225]),
89 | ])
90 |
91 | resnet = models.resnet101(pretrained=True)
92 | resnet.eval()
93 | resnet.cuda()
94 |
95 | ####################################################################
96 | # BUILD VOCABULARY #
97 | ####################################################################
98 | if build_vocab:
99 | print("Building Vocabulary")
100 | vocabulary = ['<pad>', '<s>', '</s>']
101 | m = 1
102 | ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')]
103 | N = len(ids)
104 | for img_n, img_id in enumerate(ids):
105 | print("\rImg {}/{}".format(img_n, N), end="")
106 | sentence_path = data_folder + 'annotations/Sentences/' + img_id + '.txt'
107 | corefData = get_sentence_data(sentence_path)
108 | for description in corefData:
109 | for word in description['sentence'].lower().split(' '):
110 | if word not in vocabulary:
111 | vocabulary.append(word)
112 | m += 1
113 |
114 | vocabulary.append('<unk>')
115 |
116 | word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
117 | idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}
118 | with open('vocabulary.pkl', 'wb') as f:
119 | pickle.dump((word2idx, idx2word), f, protocol=pickle.HIGHEST_PROTOCOL)
120 |
121 | ####################################################################
122 | # BUILD ANNOTATION INPUT DATA STRUCTURE #
123 | ####################################################################
124 | if generate_annotations or generate_visual:
125 | with open(data_folder + "vocabulary.pkl", "rb") as f:
126 | word2idx, idx2word = pickle.load(f)
127 | vocab = set(word2idx.keys())
128 |
129 | # EXAMPLES LOOP
130 | phrase_count = 0
131 | ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')]
132 | N = len(ids)
133 | for img_n, img_id in enumerate(ids):
134 | img_id = img_id.replace('\n', '')
135 | img_path = data_folder + 'flickr30k-images/' + img_id + '.jpg'
136 |
137 | # Get Sentence and Annotation info for this image
138 | corefData = get_sentence_data(data_folder + 'annotations/Sentences/'+img_id+'.txt')
139 | annotationData = get_annotations(data_folder + 'annotations/Annotations/'+img_id+'.xml')
140 |
141 | ids = []      # object IDs of the phrases in this image
142 | bboxes = []
143 | seqs = []
144 | num_of_sentences = len(corefData)
145 | if generate_annotations:
146 | # DESCRIPTIONS LOOP
147 | for description in corefData:
148 |
149 | # PHRASES LOOP
150 | for phrase in description['phrases']:
151 | # Get object ID
152 | obj_id = phrase['phrase_id']
153 |
154 | # Check if this phrase has a box assigned. If not, then skip phrase.
155 | if obj_id not in list(annotationData['boxes'].keys()) or obj_id == '0':
156 | continue
157 | ids.append(obj_id)
158 |
159 | # Obtain box coordinates for this phrase
160 | boxes = annotationData['boxes'][obj_id]
161 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0]
162 | bboxes.append(box)
163 |
164 | # Turn phrase from sequence of strings into sequence of indexes
165 | phrase = phrase['phrase'].lower().split()
166 | phrase_seq = phrase2seq(phrase, word2idx, vocab)
167 | seqs.append(phrase_seq)
168 | phrase_count += 1
169 |
170 | image_annotations = {'bboxes': bboxes, 'ids': ids, 'seqs': seqs}
171 | with open('C:/Data/GroundeR/annotation_data/'+img_id+'.pkl', 'wb') as f:
172 | pickle.dump(image_annotations, f, protocol=pickle.HIGHEST_PROTOCOL)
173 | if img_n % 100 == 0:
174 | print(f"\rImage {img_n}/{N} - Phrases: {phrase_count}", end="")
175 |
176 | if generate_visual:
177 | #############################################################
178 | # EXTRACT IMAGE FEATURES (RESNET101) #
179 | #############################################################
180 |
181 | img = Image.open(data_folder + 'flickr30k-images/' + img_id + '.jpg')
182 |
183 | # Build id to idx dictionary
184 | id2idx = {id: idx for (idx, id) in enumerate(list(annotationData['boxes'].keys()))}
185 |
186 | # LOOP THROUGH ALL BOXES
187 | objects = list(annotationData['boxes'].keys())
188 | vis_matrix = np.zeros((len(objects), 1000), dtype=float)
189 | for id in objects:
190 | # For each Object: extract boxes and unify them
191 | boxes = annotationData['boxes'][id]
192 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0]
193 | # For each box: crop original img to box, resize crop to 224x224, normalise image
194 | box_img = crop_and_resize(img, box, (224, 224))
195 | box_img = transform(box_img)
196 | box_img = box_img.unsqueeze(0)
197 | box_img = box_img.cuda()
198 | # Feed image to ResNet-101 and add to visual feature matrix
199 | vis_feature = resnet(box_img)
200 | vis_matrix[id2idx[id], :] = vis_feature.cpu().detach().numpy()
201 |
202 | np.save('C:/Data/GroundeR/visualfeatures_data/'+img_id, vis_matrix)
203 | with open('C:/Data/GroundeR/id2idx_data/'+img_id+'.pkl', 'wb') as f:
204 | pickle.dump(id2idx, f)
205 |
--------------------------------------------------------------------------------
/process_dataset_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Aleix Cambray Roma
3 | Work done while at Imperial College London.
4 |
5 | This code takes in the Flickr30k Entities dataset and generates the data structures necessary for the GroundeR implementation
6 | Output data structures (for each image in dataset)
7 | - Annotation dictionary   annotation_data/<image_id>.pkl      - A dictionary containing ids, bboxes and token_index_sequences for each phrase
8 | - Visual features matrix  visualfeatures_data/<image_id>.npy  - A matrix in which each row is the visual feature vector of one region
9 | - Link id2idx dict        id2idx_data/<image_id>.pkl          - Matches the object id of each phrase to the index of the correct region (bbox)
10 | """
11 |
12 | from utils import *
13 | import PIL
14 | from PIL import Image, ImageDraw
15 | import time
16 | import numpy as np
17 | import torchvision.models as models
18 | import torchvision.transforms as transforms
19 | import os
20 | import itertools
21 | import glob
22 | import pickle
23 |
24 |
25 | def crop_and_resize(img, coords, size):
26 | """
27 | PIL image 'img' is first cropped to a rectangle according to 'coords'. Then the resulting crop is re-scaled according to 'size'
28 | :param img: PIL image
29 | :param coords: rectangle coordinates in the form of a list [x_nw, y_nw, x_se, y_se]
30 | :param size: tuple or list [h_size, v_size]
31 | :return:
32 | """
33 | img = img.crop(coords)
34 | img = img.resize(size, PIL.Image.LANCZOS)
35 | return img
36 |
37 |
38 | def unify_boxes(boxes):
39 | """
40 | This function turns a bunch of bounding boxes into one bounding box that encompasses all of them
41 | :param boxes: List of lists, each sub-list is a bounding box [x_nw, y_nw, x_se, y_se]
42 | :return: 1 bounding box
43 | """
44 | boxes = np.array(boxes)
45 | xmin = np.amin(boxes[:, 0])
46 | ymin = np.amin(boxes[:, 1])
47 | xmax = np.amax(boxes[:, 2])
48 | ymax = np.amax(boxes[:, 3])
49 | return [xmin, ymin, xmax, ymax]
50 |
51 |
52 | def phrase2seq(phrase, word2idx, vocab):
53 | """
54 | This function turns a sequence of words into a sequence of indexed tokens according to a vocabulary and its word2idx mapping dictionary.
55 | :param phrase: List of strings
56 | :param word2idx: Dictionary
57 | :param vocab: Set or list containing entire vocabulary as strings
58 | :return: List of integers. Each integer being the token index of the word in the phrase according to the vocabulary.
59 | """
60 | # Add <s> and </s> tokens
61 | phrase = ['<s>'] + phrase
62 | phrase = phrase + ['</s>']
63 |
64 | phrase_seq = [0]*len(phrase)
65 | for i, word in enumerate(phrase):
66 | if word in vocab:
67 | phrase_seq[i] = word2idx[word]
68 | else:
69 | phrase_seq[i] = word2idx['<unk>']
70 | return phrase_seq
71 |
72 |
73 | if __name__ == "__main__":
74 | build_vocab = True # If True, the vocabulary will be built again and saved. If False, the last vocabulary will be loaded.
75 | generate_annotations = True
76 | generate_visual = False
77 |
78 | data_folder = "C:/Data/GroundeR/data/"
79 |
80 | train_txt = open(data_folder + 'train.txt', 'r')
81 | val_txt = open(data_folder + 'val.txt', 'r')
82 | test_txt = open(data_folder + 'test.txt', 'r')
83 |
84 | # Load RESNET
85 | transform = transforms.Compose([
86 | transforms.ToTensor(),
87 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 | std=[0.229, 0.224, 0.225]),
89 | ])
90 |
91 | resnet = models.resnet101(pretrained=True)
92 | resnet.eval()
93 | resnet.cuda()
94 |
95 | ####################################################################
96 | # BUILD VOCABULARY #
97 | ####################################################################
98 | if build_vocab:
99 | embeddings_path = "C:/Data/Embeddings/" + "glove.6B/glove.6B.50d.txt"
100 | print("Loading Glove Model")
101 |
102 | f = open(embeddings_path, 'rb')
103 | vocabulary = ['<pad>', '<s>', '</s>', '<unk>']
104 |
105 | # Generate random vectors for start, end, pad and unk tokens
106 | word2vec = {}
107 | word2vec['<s>'] = np.random.normal(loc=0.0, scale=1, size=50)
108 | word2vec['</s>'] = np.random.normal(loc=0.0, scale=1, size=50)
109 | word2vec['<pad>'] = np.random.normal(loc=0.0, scale=1, size=50)
110 | word2vec['<unk>'] = np.random.normal(loc=0.0, scale=1, size=50)
111 | t0 = time.time()
112 | i = 0
113 | for line in f:
114 | splitLine = line.decode().split()
115 | word = splitLine[0]
116 | vocabulary.append(word)
117 | embedding = np.array([float(val) for val in splitLine[1:]])
118 | word2vec[word] = embedding
119 | if i % 100 == 0:
120 | print("\r{} - t={:0.10f}s".format(i, time.time() - t0), end="")
121 | t0 = time.time()
122 | i += 1
123 |
124 | word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
125 | idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}
126 |
127 | idx2vec = {}
128 | for idx in idx2word.keys():
129 | idx2vec[idx] = word2vec[idx2word[idx]]
130 |
131 | weight_matrix = np.zeros((len(vocabulary), 50))
132 | for idx in range(weight_matrix.shape[0]):
133 | weight_matrix[idx, :] = idx2vec[idx]
134 |
135 | pickle.dump((vocabulary, word2idx, idx2word, weight_matrix), open("C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl", 'wb'))
136 |
137 | print("\nDone.", len(vocabulary), " words loaded!")
138 |
139 | ####################################################################
140 | # BUILD ANNOTATION INPUT DATA STRUCTURE #
141 | ####################################################################
142 | if generate_annotations or generate_visual:
143 | with open("C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl", "rb") as f:
144 | vocab, word2idx, idx2word, _ = pickle.load(f)
145 | vocab = set(vocab)
146 |
147 | # EXAMPLES LOOP
148 | phrase_count = 0
149 | img_ids = [f[:-4] for f in os.listdir(data_folder + 'annotations/Sentences') if f.endswith('.txt')]
150 | N = len(img_ids)
151 | t0 = time.time()
152 | for img_n, img_id in enumerate(img_ids):
153 | img_id = img_id.replace('\n', '')
154 | img_path = data_folder + 'flickr30k-images/' + img_id + '.jpg'
155 |
156 | # Get Sentence and Annotation info for this image
157 | corefData = get_sentence_data(data_folder + 'annotations/Sentences/'+img_id+'.txt')
158 | annotationData = get_annotations(data_folder + 'annotations/Annotations/'+img_id+'.xml')
159 |
160 | ids = []
161 | bboxes = []
162 | seqs = []
163 | num_of_sentences = len(corefData)
164 |
165 | if generate_annotations:
166 | # DESCRIPTIONS LOOP
167 | for description in corefData:
168 |
169 | # PHRASES LOOP
170 | for phrase in description['phrases']:
171 | # Get object ID
172 | obj_id = phrase['phrase_id']
173 |
174 | # Check if this phrase has a box assigned. If not, then skip phrase.
175 | if obj_id not in list(annotationData['boxes'].keys()) or obj_id == '0':
176 | continue
177 | ids.append(obj_id)
178 |
179 | # Obtain box coordinates for this phrase
180 | boxes = annotationData['boxes'][obj_id]
181 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0]
182 | bboxes.append(box)
183 |
184 | # Turn phrase from sequence of strings into sequence of indexes
185 | phrase = phrase['phrase'].lower().split()
186 | phrase_seq = phrase2seq(phrase, word2idx, vocab)
187 | seqs.append(phrase_seq)
188 | phrase_count += 1
189 |
190 | image_annotations = {'bboxes': bboxes, 'ids': ids, 'seqs': seqs}
191 | with open('C:/Data/GroundeR/annotation_data/'+img_id+'.pkl', 'wb') as f:
192 | pickle.dump(image_annotations, f, protocol=pickle.HIGHEST_PROTOCOL)
193 | if img_n % 100 == 0:
194 | dt = time.time() - t0
195 | print(f"\rImage {img_n}/{N} - dt_100 = {dt} - Phrases: {phrase_count}", end="")
196 | t0 = time.time()
197 |
198 | if generate_visual:
199 | #############################################################
200 | # EXTRACT IMAGE FEATURES (RESNET101) #
201 | #############################################################
202 |
203 | img = Image.open(data_folder + 'flickr30k-images/' + img_id + '.jpg')
204 |
205 | # Build id to idx dictionary
206 | id2idx = {id: idx for (idx, id) in enumerate(list(annotationData['boxes'].keys()))}
207 |
208 | # LOOP THROUGH ALL BOXES
209 | objects = list(annotationData['boxes'].keys())
210 | vis_matrix = np.zeros((len(objects), 1000), dtype=float)
211 | for id in objects:
212 | # For each Object: extract boxes and unify them
213 | boxes = annotationData['boxes'][id]
214 | box = unify_boxes(boxes) if len(boxes) > 1 else boxes[0]
215 | # For each box: crop original img to box, resize crop to 224x224, normalise image
216 | box_img = crop_and_resize(img, box, (224, 224))
217 | box_img = transform(box_img)
218 | box_img = box_img.unsqueeze(0)
219 | box_img = box_img.cuda()
220 | # Feed image to ResNet-101 and add to visual feature matrix
221 | vis_feature = resnet(box_img)
222 | vis_matrix[id2idx[id], :] = vis_feature.cpu().detach().numpy()
223 |
224 | np.save('C:/Data/GroundeR/visualfeatures_data/'+img_id, vis_matrix)
225 | with open('C:/Data/GroundeR/id2idx_data/'+img_id+'.pkl', 'wb') as f:
226 | pickle.dump(id2idx, f)
227 |
--------------------------------------------------------------------------------
/train_unsupervised_withval.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Aleix Cambray Roma
3 | Work done while at Imperial College London.
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch import optim
9 | from torch.utils.data import DataLoader
10 |
11 | import time
12 | import copy
13 | import pickle
14 | import numpy as np
15 | from model import GroundeR
16 | import matplotlib.pyplot as plt
17 | from dataloader_unsupervised import UnsupervisedDataLoader
18 |
19 |
20 | def print_gpu_memory(first_string=None):
21 | if first_string is not None:
22 | print(first_string)
23 | print(' current memory allocated: {}'.format(torch.cuda.memory_allocated() / 1024 ** 2))
24 | print(' max memory allocated: {}'.format(torch.cuda.max_memory_allocated() / 1024 ** 2))
25 | print(' cached memory: {}'.format(torch.cuda.memory_cached() / 1024 ** 2))
26 |
27 |
28 | def seq2phrase(seq):
29 | return " ".join([idx2word[idx.item()] for idx in seq])
30 |
31 |
32 | if __name__ == '__main__':
33 | print("Running on {}".format(torch.cuda.get_device_name(0)))
34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
35 | # device = 'cpu'
36 | print_gpu_memory()
37 |
38 | # Load vocabulary and embeddings
39 | # vocab_file = "C:/Data/GroundeR/vocabulary.pkl"
40 | vocab_file = "C:/Data/Embeddings/glove.6B/GloVe_6B_50d_vocabulary.pkl"
41 | with open(vocab_file, "rb") as f:
42 | vocab, word2idx, idx2word, weight_matrix = pickle.load(f)
43 | vocab = list(word2idx.keys())
44 |
45 | # Config
46 | vocab_size = len(vocab)
47 | embed_dim = 50
48 | h_dim = 100
49 | v_dim = 1000
50 |
51 | # Hyper-parameters
52 | regions = 25
53 | learning_rate = 0.00025
54 | epochs = 100
55 | batch_size = 30
56 | max_seq_length = 10
57 | L = 5
58 |
59 | # Build data pipeline providers
60 | dataset = UnsupervisedDataLoader(data_dir="C:/Data/GroundeR/", sample_list_file="flickr30k_train_val.txt", seq_length=max_seq_length)
61 | dataset_length = len(dataset)
62 |
63 | train_length = int(0.8*dataset_length)
64 | val_length = dataset_length - train_length
65 | train_ds, val_ds = torch.utils.data.random_split(dataset, [train_length, val_length])
66 |
67 | print("Training dataset size: {}".format(len(train_ds)))
68 | print("Val dataset size: {}".format(len(val_ds)))
69 | train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=3)
70 | val_dataloader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=3)
71 |
72 | # Instantiate model and send to GPU if available
73 | model = GroundeR(vocab, vocab_size, embed_dim, h_dim, v_dim, regions, weight_matrix, train_embeddings=False)
74 | model = model.to(device)
75 | print_gpu_memory(" After model to GPU:")
76 | criterion_att = nn.NLLLoss()
77 | criterion_rec = nn.NLLLoss(ignore_index=0)
78 | # optimizer = optim.SGD(model.parameters(), lr=learning_rate)
79 | optimizer = optim.Adam(model.parameters(), lr=learning_rate)
80 |
81 | t0 = time.time()
82 | start = time.time()
83 | min_loss = 10000
84 | times = []
85 | times_total = []
86 | training_losses = []
87 | training_accuracies = []
88 | epoch_training_losses = []
89 | epoch_training_accuracies = []
90 | epoch_val_losses = []
91 | epoch_val_accuracies = []
92 | iter_num = []
93 | for epoch in range(epochs):
94 | for phase in ['train', 'val']:
95 | if phase == 'train':
96 | dataloader = train_dataloader
97 | model.train() # Set model to training mode
98 | else:
99 | dataloader = val_dataloader
100 | model.eval() # Set model to evaluation mode
101 |
102 | epoch_samples = 0
103 | running_loss = 0.0
104 | running_corrects = 0.0
105 | batch_i = 0
106 | for vis_features, real_feat, encoder_input, decoder_input, decoder_target, mask, region_true, lengths_enc, lengths_dec in dataloader:
107 | torch.cuda.empty_cache()
108 | dt_load = (time.time() - t0) * 1000
109 |
110 | # Send all data to GPU if available
111 | t1 = time.time()
112 | vis_features = vis_features.to(device)
113 | real_feat = real_feat.to(device)
114 | encoder_input = encoder_input.to(device)
115 | decoder_input = decoder_input.to(device)
116 | decoder_target = decoder_target.to(device)
117 | mask = mask.to(device)
118 | region_true = region_true.to(device)
119 | # lengths_enc = lengths_enc.to(device)
120 | # lengths_dec = lengths_dec.to(device)
121 | dt_togpu = (time.time() - t1) * 1000
122 | print_gpu_memory(" After data to GPU:")
123 |
124 | # Zero the parameter gradients
125 | optimizer.zero_grad()
126 |
127 | # Forward Pass
128 | t2 = time.time()
129 | decoder_output, att, att_log = model(encoder_input, lengths_enc, vis_features, decoder_input, lengths_dec)
130 | dt_forward = (time.time() - t2) * 1000
131 | print_gpu_memory(" After forward run:")
132 |
133 | # Loss
134 | # TODO: [done] Mask loss to ignore pad tokens (using the mask tensor for each sample)
135 | t3 = time.time()
136 | pred = decoder_output.view(decoder_output.size(0)*decoder_output.size(1), decoder_output.size(2))
137 | target = decoder_target[:, :decoder_output.size(1)] # truncate the decoder_target from length 10 to maximum sequence length in batch
138 | target = target.contiguous().view(-1)
139 | loss_att = criterion_att(att_log, region_true)
140 | # loss_rec = model.masked_NLLLoss(decoder_output, decoder_target)
141 | loss_rec = criterion_rec(pred, target)
142 | loss = L*loss_att + loss_rec
143 | dt_loss = (time.time() - t3) * 1000
144 | print_gpu_memory(" After loss calc:")
145 |
146 | # Accuracy
147 | region_pred = att.max(dim=1)[1]
148 | corrects = torch.sum(region_true == region_pred).item()
149 | running_corrects += corrects
150 | accuracy = corrects / region_true.size(0)
151 |
152 | if phase == 'train':
153 | # Backward pass and parameter update
154 | # TODO: [done] Find out why loss isn't getting updates (initialisation?)
155 | t4 = time.time()
156 | loss.backward()
157 | print_gpu_memory(" After backward run:")
158 | optimizer.step()
159 | print_gpu_memory(" After optimizer step:")
160 | dt_backward = (time.time() - t4) * 1000
161 | print("{:02.0f}.{:03.0f} - Sample {:05.0f}/30781 - Accuracy: {:0.2f}% - Loss: {:02.7f} - GPU: {:0.2f} MB - Load time = {:06.2f}ms - toGPU time = {:06.2f}ms - Forward time = {:06.2f}ms - Loss: {:06.2f}ms - Backward {:06.2f}ms - Time {:05.2f}s".format(epoch, batch_i + 1, (batch_i + 1) * batch_size, accuracy*100, loss.item(), torch.cuda.memory_allocated() / 1024 ** 2, dt_load, dt_togpu, dt_forward, dt_loss, dt_backward, (time.time()-start)))
162 | training_losses.append(loss.item()) # Appends loss over entire batch (reduce=mean)
163 | training_accuracies.append(accuracy)
164 | times.append(dt_load + dt_togpu)
165 | times_total.append(time.time() - t0)
166 |
167 | # statistics & counters
168 | epoch_samples += vis_features.size(0)
169 | running_loss += loss.item() * vis_features.size(0) # Sum of losses over all samples in batch
170 |
171 | batch_i += 1
172 | t0 = time.time()
173 |
174 | # Print real and predicted sentences:
175 | for i_print in range(3):
176 | print(" {} - Real: {}".format(i_print, seq2phrase(decoder_target[i_print])))
177 | print(" {} - Pred: {}".format(i_print, seq2phrase(decoder_output.max(dim=2)[1][i_print])))
178 |
179 | # Track epoch losses for both training and validation phases
180 | if phase == 'train':
181 | print("Training Epoch Performance on {} samples".format(epoch_samples))
182 | epoch_loss = running_loss / epoch_samples
183 | epoch_training_losses.append(epoch_loss)
184 | epoch_training_accuracies.append(running_corrects/epoch_samples)
185 | iter_num.append(len(training_losses))
186 | elif phase == 'val':
187 | print("Validation Performance on {} samples".format(epoch_samples))
188 | epoch_loss = running_loss / epoch_samples
189 | epoch_val_losses.append(epoch_loss)
190 | epoch_val_accuracies.append(running_corrects/epoch_samples)
191 | # Save weights of best val-performing model
192 | if epoch_loss < min_loss:
193 | min_loss = epoch_loss
194 | best_model_wts = copy.deepcopy(model.state_dict())
195 | torch.save(model.state_dict(), "best_model.pt")
196 |
197 | # Plot and save the learning profile after every validation phase
198 | if phase == 'val':
199 | plt.title("LR: {}".format(learning_rate))
200 | plt.plot(training_losses, label="Training Losses")
201 | plt.plot(iter_num, epoch_training_losses, marker='o', label="Epoch Training losses")
202 | plt.plot(iter_num, epoch_val_losses, marker='o', label="Epoch Val Losses")
203 | plt.legend()
204 | plt.savefig('learning_profile.png')
205 | plt.clf()
206 | plt.plot(training_accuracies, label="Training Accuracies per batch")
207 | plt.plot(iter_num, epoch_training_accuracies, label="Epoch Training Accuracies")
208 | plt.plot(iter_num, epoch_val_accuracies, label="Epoch Val Accuracies")
209 | plt.title("Region Accuracies")
210 | plt.savefig('accuracies.png')
211 | plt.clf()
212 |
213 | plt.plot(times)
214 | plt.savefig('final_times.png')
215 |
216 | plt.plot(times_total)
217 | plt.savefig('final_times_total.png')
218 |
219 | plt.plot(training_losses, label="Training Losses")
220 | plt.plot(iter_num, epoch_training_losses, label="Epoch Training losses")
221 | plt.plot(iter_num, epoch_val_losses, label="Epoch Val Losses")
222 | plt.legend()
223 | plt.savefig('final_learning_profile.png')
224 |
225 | print("\nTime per sample: {}".format((time.time()-t0)/30781))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree as ET
2 |
3 | def get_sentence_data(fn):
4 | """
5 | Parses a sentence file from the Flickr30K Entities dataset
6 |
7 | input:
8 | fn - full file path to the sentence file to parse
9 |
10 | output:
11 | a list of dictionaries for each sentence with the following fields:
12 | sentence - the original sentence
13 | phrases - a list of dictionaries for each phrase with the
14 | following fields:
15 | phrase - the text of the annotated phrase
16 | first_word_index - the position of the first word of
17 | the phrase in the sentence
18 | phrase_id - an identifier for this phrase
19 | phrase_type - a list of the coarse categories this
20 | phrase belongs to
21 |
22 | """
23 | with open(fn, 'r', encoding="utf8") as f:
24 | sentences = f.read().split('\n')
25 |
26 | annotations = []
27 | for sentence in sentences:
28 | if not sentence:
29 | continue
30 |
31 | first_word = []
32 | phrases = []
33 | phrase_id = []
34 | phrase_type = []
35 | words = []
36 | current_phrase = []
37 | add_to_phrase = False
38 | for token in sentence.split():
39 | if add_to_phrase:
40 | if token[-1] == ']':
41 | add_to_phrase = False
42 | token = token[:-1]
43 | current_phrase.append(token)
44 | phrases.append(' '.join(current_phrase))
45 | current_phrase = []
46 | else:
47 | current_phrase.append(token)
48 |
49 | words.append(token)
50 | else:
51 | if token[0] == '[':
52 | add_to_phrase = True
53 | first_word.append(len(words))
54 | parts = token.split('/')
55 | phrase_id.append(parts[1][3:])
56 | phrase_type.append(parts[2:])
57 | else:
58 | words.append(token)
59 |
60 | sentence_data = {'sentence' : ' '.join(words), 'phrases' : []}
61 | for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type):
62 | sentence_data['phrases'].append({'first_word_index' : index,
63 | 'phrase' : phrase,
64 | 'phrase_id' : p_id,
65 | 'phrase_type' : p_type})
66 |
67 | annotations.append(sentence_data)
68 |
69 | return annotations
70 |
71 | def get_annotations(fn):
72 | """
73 | Parses the xml files in the Flickr30K Entities dataset
74 |
75 | input:
76 | fn - full file path to the annotations file to parse
77 |
78 | output:
79 | dictionary with the following fields:
80 | scene - list of identifiers which were annotated as
81 | pertaining to the whole scene
82 | nobox - list of identifiers which were annotated as
83 | not being visible in the image
84 | boxes - a dictionary where the fields are identifiers
85 | and the values are its list of boxes in the
86 | [xmin ymin xmax ymax] format
87 | """
88 | tree = ET.parse(fn)
89 | root = tree.getroot()
90 | size_container = root.findall('size')[0]
91 | anno_info = {'boxes' : {}, 'scene' : [], 'nobox' : []}
92 | for size_element in size_container:
93 | anno_info[size_element.tag] = int(size_element.text)
94 |
95 | for object_container in root.findall('object'):
96 | for names in object_container.findall('name'):
97 | box_id = names.text
98 | box_container = object_container.findall('bndbox')
99 | if len(box_container) > 0:
100 | if box_id not in anno_info['boxes']:
101 | anno_info['boxes'][box_id] = []
102 | xmin = int(box_container[0].findall('xmin')[0].text) - 1
103 | ymin = int(box_container[0].findall('ymin')[0].text) - 1
104 | xmax = int(box_container[0].findall('xmax')[0].text) - 1
105 | ymax = int(box_container[0].findall('ymax')[0].text) - 1
106 | anno_info['boxes'][box_id].append([xmin, ymin, xmax, ymax])
107 | else:
108 | nobndbox = int(object_container.findall('nobndbox')[0].text)
109 | if nobndbox > 0:
110 | anno_info['nobox'].append(box_id)
111 |
112 | scene = int(object_container.findall('scene')[0].text)
113 | if scene > 0:
114 | anno_info['scene'].append(box_id)
115 |
116 | return anno_info
117 |
118 |
119 |
--------------------------------------------------------------------------------
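A short usage sketch for the two parsers above. The file paths and the example sentence are hypothetical; the phrase markup format is inferred from the parsing logic in get_sentence_data.

```python
from utils import get_sentence_data, get_annotations

# Sentences/*.txt lines mark phrases as [/EN#<id>/<type> words ...], e.g.
#   "[/EN#283585/people A man] is jumping over [/EN#283586/other a fence] ."
sentences = get_sentence_data('C:/Data/GroundeR/data/annotations/Sentences/36979.txt')    # hypothetical path
for p in sentences[0]['phrases']:
    print(p['phrase_id'], p['phrase_type'], '->', p['phrase'])

annotations = get_annotations('C:/Data/GroundeR/data/annotations/Annotations/36979.xml')  # hypothetical path
boxes = annotations['boxes']    # maps each phrase_id to its list of [xmin, ymin, xmax, ymax] boxes
```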
/visualfeatures_data/readme.m:
--------------------------------------------------------------------------------
1 | processed visual feature matrices (one per image in Flickr30k) from process_dataset.py
2 |
--------------------------------------------------------------------------------