├── .gitignore
├── .idea
│   ├── Tex2Vis.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── README.md
├── captions
│   ├── Test_val2014.sentences.txt.bz2
│   ├── Test_val2014.sentences.txt.ngrams.bz2
│   ├── Validation_val2014.sentences.txt.bz2
│   ├── Validation_val2014.sentences.txt.ngrams.bz2
│   ├── train2014.sentences.txt.bz2
│   └── train2014.sentences.txt.ngrams.bz2
├── pca
│   ├── fc6
│   │   └── pca.zip
│   └── fc7
│       └── pca.zip
├── src
│   ├── __init__.py
│   ├── batcher_lstmbased.py
│   ├── batcher_sparse.py
│   ├── batcher_word2visualvec.py
│   ├── caption_search_bow_ngram.py
│   ├── caption_search_rouge.py
│   ├── evaluation_measure.py
│   ├── flags.py
│   ├── helpers.py
│   ├── lstm_text2vis.py
│   ├── mscoco_captions_reader.py
│   ├── paths.py
│   ├── pca_reader.py
│   ├── sparse_text2vis.py
│   ├── visual_embeddings_reader.py
│   ├── word2visualvec.py
│   └── yfcc100m_extractor.py
├── visualembeddings
│   └── readme.txt
└── wordembeddings
    └── readme.txt
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | ModelParameters/checkpoint
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text2Vis
2 |
3 | Text2Vis is a family of neural network models aimed at learning a mapping from short textual descriptions to visual features, so that one can search for images simply by providing a short textual description of them.
4 | 
5 | Text2Vis includes: (i) a sparse version, which takes as input a one-hot vector representing the textual description; (ii) a dense version, in which words are embedded and fed to an LSTM whose last memory state is projected onto the visual space; and (iii) a Wide & Deep model, which combines the sparse and dense representations. We also include our reimplementation of the Word2VisualVec model with the MSE loss.
6 | 
7 | Note: to train the models, you need the visual features associated with the MS-COCO image collection. In our experiments, we considered the fc6 and fc7 layers of the Hybrid CNN. These files were too large to include in the repository, but if you need them we will be very happy to share ours with you!
8 |
9 |
10 |
11 |
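12 | ## Usage
13 | 
14 | As a rough sketch (the common flags are defined in `src/flags.py`, each model script adds its own, and data paths are resolved in `src/paths.py`), a run of the LSTM-based model might look like:
15 | 
16 | ```
17 | # train on fc6 features (the default) and evaluate afterwards
18 | python src/lstm_text2vis.py --fc_layer 6 --batch_size 64
19 | 
20 | # test only, restoring a previously saved checkpoint
21 | python src/lstm_text2vis.py --notrain --test --checkpoint <checkpoint-file>
22 | ```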
--------------------------------------------------------------------------------
/captions/Test_val2014.sentences.txt.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/Test_val2014.sentences.txt.bz2
--------------------------------------------------------------------------------
/captions/Test_val2014.sentences.txt.ngrams.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/Test_val2014.sentences.txt.ngrams.bz2
--------------------------------------------------------------------------------
/captions/Validation_val2014.sentences.txt.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/Validation_val2014.sentences.txt.bz2
--------------------------------------------------------------------------------
/captions/Validation_val2014.sentences.txt.ngrams.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/Validation_val2014.sentences.txt.ngrams.bz2
--------------------------------------------------------------------------------
/captions/train2014.sentences.txt.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/train2014.sentences.txt.bz2
--------------------------------------------------------------------------------
/captions/train2014.sentences.txt.ngrams.bz2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/captions/train2014.sentences.txt.ngrams.bz2
--------------------------------------------------------------------------------
/pca/fc6/pca.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/pca/fc6/pca.zip
--------------------------------------------------------------------------------
/pca/fc7/pca.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/pca/fc7/pca.zip
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexMoreo/tensorflow-Text2Vis/fa613dcd3011c476d38fc667b82230a2fd8b1ba5/src/__init__.py
--------------------------------------------------------------------------------
/src/batcher_lstmbased.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import random
3 | import numpy as np
4 | from mscoco_captions_reader import MSCocoCaptions
5 | from visual_embeddings_reader import VisualEmbeddingsReader
6 |
7 | class Batcher:
8 |
9 | def __init__(self,
10 | captions_file,
11 | visual_file,
12 | buckets_def,
13 | batch_size,
14 | lemma=False,
15 | dowide=True,
16 | word2id=None, id2word=None):
17 | self._captions = MSCocoCaptions(captions_file, word2id=word2id, id2word=id2word, lemma=lemma)
18 | self._visual = VisualEmbeddingsReader(visual_file)
19 | self._batch_size = batch_size
20 | self._build_buckets(buckets_def) # list of bucket sizes
21 | self.epoch = 0 #automatically increases when all examples have been seen
22 | self._is_cached_pads = True
23 | self._dowide=dowide
24 | self._cache_pads = dict()
25 |
26 | def _get_bucket(self, caption, buckets_def):
27 | l = len(caption)
28 | for x in buckets_def:
29 | if x >= l: return x
30 | #if no bucket can contain it, returns the biggest
31 | return max(buckets_def)
32 |
33 | def _build_buckets(self, buckets_def):
34 | print('Building buckets...')
35 | self.buckets = dict([(x, []) for x in buckets_def])
36 | for img_label in self._captions.get_image_ids():
37 | for cap_pos, cap in enumerate(self._captions.get_captions(img_label)):
38 | bucket = self._get_bucket(cap, buckets_def)
39 | self.buckets[bucket].append([img_label, cap_pos])
40 | print('\t' + ' '.join([('[size=%d, %d]'%(x,len(self.buckets[x]))) for x in buckets_def]))
41 | self.buckets_def = []
42 | for bucket_size in buckets_def:
43 | bucket_length = len(self.buckets[bucket_size])
44 | if bucket_length < self._batch_size:
45 | print('Warning: bucket %d contains only %d elements.' % (bucket_size, bucket_length))
46 | # del self.buckets[bucket_size]
47 | # print('Removing bucket %d, it contains only %d elements [%d required]' % (bucket_size, bucket_length, self._batch_size))
48 | #else:
49 | self.buckets_def.append(bucket_size)
50 | self.buckets_def.sort()
51 | self._curr_bucket_pos = 0 # current bucket position
52 | self._offset_bucket = 0 # offset position in the current bucket block
53 |
54 | def current_bucket_elements(self):
55 | return self.buckets[self.current_bucket_size()]
56 |
57 | def current_bucket_size(self):
58 | return self.buckets_def[self._curr_bucket_pos]
59 |
60 | def num_buckets(self):
61 | return len(self.buckets_def)
62 |
63 | def _next_bucket(self):
64 | self._curr_bucket_pos = (self._curr_bucket_pos + 1) % self.num_buckets()
65 | self._offset_bucket=0
66 | random.shuffle(self.current_bucket_elements())
67 | if self._curr_bucket_pos==0: self.epoch += 1
68 |
69 | def _get_pad(self, img_label, cap_pos, bucket_size):
70 | if self._is_cached_pads:
71 | if img_label not in self._cache_pads:
72 | self._cache_pads[img_label]=dict()
73 | if cap_pos not in self._cache_pads[img_label]:
74 | self._cache_pads[img_label][cap_pos]=self._gen_pad(img_label, cap_pos, bucket_size)
75 | return self._cache_pads[img_label][cap_pos]
76 | else:
77 | return self._gen_pad(img_label, cap_pos, bucket_size)
78 |
79 | """
80 | Applies padding to the caption up to bucket_size.
81 | If the caption length is greater than bucket_size, the ending part of the caption is cut.
82 | """
83 | def _gen_pad(self, img_label, cap_pos, bucket_size):
84 | caption = self._captions.get_captions(img_label)[cap_pos]
85 | num_pads = max(bucket_size - len(caption), 0)
86 | return [self._captions.get_pad()] * num_pads + caption[:bucket_size]
87 |
88 | def wide1hot(self, img_label, cap_pos):
89 | if not self._dowide: return None
90 | caption = self._captions.get_captions(img_label)[cap_pos]
91 | wide1hot = np.zeros(self.vocabulary_size())
92 | wide1hot[caption] = 1
93 | return wide1hot
94 |
95 | def next(self):
96 | caps = self.current_bucket_elements()[self._offset_bucket: self._offset_bucket + self._batch_size]
97 | img_labels, caps_pos, pads, visuals = [],[],[],[]
98 | wide1hot = []
99 | for (img_label, cap_pos) in caps:
100 | img_labels.append(img_label)
101 | caps_pos.append(cap_pos)
102 | wide1hot.append(self.wide1hot(img_label, cap_pos))
103 | pads += self._get_pad(img_label, cap_pos, self.buckets_def[self._curr_bucket_pos])
104 | visuals.append(self._visual.get(img_label))
105 | pads = np.array(pads).reshape(len(caps), -1)
106 | current_bucket_size = self.current_bucket_size()
107 | self._offset_bucket += self._batch_size
108 | if self._offset_bucket >= len(self.current_bucket_elements()): # check if the bucket has been processed
109 | self._next_bucket()
110 | return img_labels, caps_pos, pads, wide1hot, visuals, current_bucket_size
111 |
112 | def vocabulary_size(self):
113 | return self._captions.vocabulary_size()
114 |
115 | def get_word2id(self):
116 | return self._captions.word2id
117 |
118 | def get_id2word(self):
119 | return self._captions.id2word
120 |
121 | def get_caption_txt(self, img_id, cap_offset):
122 | return self._captions.captions[img_id][cap_offset]
123 |
124 | def get_captions_txt(self, img_id):
125 | return self._captions.captions[img_id]
126 |
127 | def from_batchlabel2batch_onehot(self, batch_labels):
128 | batch_onehot = np.zeros(shape=(len(batch_labels), self.vocabulary_size()), dtype=np.float64)
129 | for i, label in enumerate(batch_labels):
130 | batch_onehot[i, label] = 1.0
131 | return batch_onehot
132 |
133 |
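134 | # Usage sketch (mirrors the construction in lstm_text2vis.py; file paths are placeholders):
135 | # batches = Batcher(captions_file, visual_file, buckets_def=[15, 20, 40], batch_size=64)
136 | # img_labels, caps_pos, pads, wide1hot, visuals, bucket_size = batches.next()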
--------------------------------------------------------------------------------
/src/batcher_sparse.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import random
3 | import numpy as np
4 | from mscoco_captions_reader import MSCocoCaptions
5 | from visual_embeddings_reader import VisualEmbeddingsReader
6 |
7 | class Batcher:
8 |
9 | def __init__(self,
10 | captions_file,
11 | visual_file,
12 | batch_size,
13 | max_vocabulary_size = 10000,
14 | word2id=None, id2word=None):
15 | self._captions = MSCocoCaptions(captions_file, max_vocabulary_size=max_vocabulary_size, word2id=word2id, id2word=id2word)
16 | self._visual = VisualEmbeddingsReader(visual_file)
17 | self._batch_size = batch_size
18 | self.epoch = 0 #automatically increases when all examples have been seen
19 | self._samples = self._get_samples_coords()
20 | self._offset = 0
21 |
22 | """
23 | Returns the coordinates of all samples,
24 | i.e., [(img_id0,cap0), (img_id0,cap1), ..., (img_id0,cap4), ..., (img_idN,cap4)]
25 | """
26 | def _get_samples_coords(self):
27 | all_coords = []
28 | for img_id in self._visual.visual_embeddings.keys():
29 | all_coords += [(img_id, cap_id) for cap_id in self.caption_ids(img_id)]
30 | return all_coords
31 |
32 | def wide1hot(self, img_label, cap_pos):
33 | caption = self._captions.get_captions(img_label)[cap_pos]
34 | wide1hot = np.zeros(self.vocabulary_size())
35 | wide1hot[caption] = 1
36 | return wide1hot
37 |
38 | def rand_cap(self, img_label):
39 | return random.choice(range(len(self._captions.captions[img_label])))
40 |
41 | def next(self):
42 | batch_samples = self._samples[self._offset:self._offset + self._batch_size]
43 | self._offset += self._batch_size
44 | if self._offset > len(self._samples):
45 | self._offset = 0
46 | self.epoch += 1
47 | random.shuffle(self._samples)
48 | if not batch_samples:
49 | return self.next()
50 |
51 | img_labels, caps_pos, wide_in, wide_out, visual_embeddings = [], [], [], [], []
52 | for img_id,cap_pos in batch_samples:
53 | img_labels.append(img_id)
54 | caps_pos.append(cap_pos)
55 | wide_in.append(self.wide1hot(img_id, cap_pos))
56 | wide_out.append(self.wide1hot(img_id, self.rand_cap(img_id)))
57 | visual_embeddings.append(self._visual.get(img_id))
58 | return img_labels, caps_pos, wide_in, wide_out, visual_embeddings
59 |
60 | def vocabulary_size(self):
61 | return self._captions.vocabulary_size()
62 |
63 | def get_word2id(self):
64 | return self._captions.word2id
65 |
66 | def get_id2word(self):
67 | return self._captions.id2word
68 |
69 | def get_caption_txt(self, img_id, cap_offset):
70 | return self._captions.captions[img_id][cap_offset]
71 |
72 | def get_captions_txt(self, img_id):
73 | return self._captions.captions[img_id]
74 |
75 | def caption_ids(self, img_id):
76 | return range(self.num_captions(img_id))
77 |
78 | def num_captions(self, img_id):
79 | return len(self._captions.get_captions(img_id))
80 |
81 |
82 |
83 |
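84 | # Usage sketch (file paths are placeholders):
85 | # batches = Batcher(captions_file, visual_file, batch_size=64)
86 | # img_labels, caps_pos, wide_in, wide_out, visuals = batches.next()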
--------------------------------------------------------------------------------
/src/batcher_word2visualvec.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import random
3 | import numpy as np
4 | import bz2
5 | sys.path.append(os.getcwd())
6 | from nltk.stem import WordNetLemmatizer
7 | from visual_embeddings_reader import VisualEmbeddingsReader
8 |
9 | class Batcher:
10 |
11 | def __init__(self,
12 | captions_file,
13 | visual_file,
14 | we_dim,
15 | batch_size,
16 | lemmatize,
17 | model):
18 | self._captions = self.read_mscococaptions(captions_file, lemmatize)
19 | self._visual = VisualEmbeddingsReader(visual_file)
20 | self._batch_size = batch_size
21 | self.we_dim = we_dim
22 | self.epoch = 0 #automatically increases when all examples have been seen
23 | self._model=model
24 | self._samples = self._get_samples_coords()
25 | self._offset = 0
26 |
27 | """
28 | Returns the coordinates of all samples,
29 | i.e., [(img_id0,cap0), (img_id0,cap1), ..., (img_id0,cap4), ..., (img_idN,cap4)]
30 | """
31 | def _get_samples_coords(self):
32 | all_coords=[]
33 | for img_id in self._visual.visual_embeddings.keys():
34 | all_coords+=[(img_id,cap_id) for cap_id in self.caption_ids(img_id)]
35 | return all_coords
36 |
37 | def next(self):
38 | batch_samples = self._samples[self._offset:self._offset + self._batch_size]
39 | self._offset += self._batch_size
40 | if self._offset > len(self._samples):
41 | self._offset = 0
42 | self.epoch += 1
43 | random.shuffle(self._samples)
44 | if not batch_samples:
45 | return self.next()
46 |
47 | img_labels,caps_pos, pooled_embeddings, visual_embeddings = [], [], [], []
48 | for img_id,cap_pos in batch_samples:
49 | img_labels.append(img_id)
50 | caps_pos.append(cap_pos)
51 | pooled_embeddings.append(self.pool_sentence(self._captions[img_id][cap_pos]))
52 | visual_embeddings.append(self._visual.get(img_id))
53 |
54 | return img_labels, caps_pos, pooled_embeddings, visual_embeddings
55 |
56 | def get_caption_txt(self, img_id, cap_offset):
57 | return self._captions[img_id][cap_offset]
58 |
59 | def get_captions_txt(self, img_id):
60 | return self._captions[img_id]
61 |
62 | def num_captions(self, img_id):
63 | return len(self._captions[img_id])
64 |
65 | def caption_ids(self, img_id):
66 | return range(self.num_captions(img_id))
67 |
68 | def pool_sentence(self, sentence):
69 | pooled = np.zeros(self.we_dim)
70 | items = 0
71 | for w in sentence.split():
72 | if w in self._model:
73 | pooled += self._model[w]
74 | items += 1
75 | if not items:
76 | print('warning: no model found for sentence %s.' %sentence)
77 | return pooled / items if items > 0 else pooled  # return the zero vector if no word was found in the model
78 |
79 | def read_mscococaptions(self, captions_file, lemmatize=False):
80 | lemmatizer = WordNetLemmatizer() if lemmatize else None
81 | print("Reading captions file <%s>" % captions_file)
82 | captions = dict()
83 | with bz2.BZ2File(captions_file, 'r', buffering=10000000) as fin:
84 | for line in fin:
85 | line = line.decode("utf-8")
86 | fields = line.split("\t")
87 | imageID = int(fields[0])
88 | sentence = fields[2][:-1].lower()
89 | if lemmatize:
90 | sentence = ' '.join(lemmatizer.lemmatize(w) for w in sentence.split())  # lemmatize token by token
91 | if imageID not in captions:
92 | captions[imageID] = []
93 | captions[imageID].append(sentence)
94 | return captions
95 |
96 |
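97 | # Usage sketch (placeholders: 'model' is a {word: vector} lookup, e.g. a gensim-style
98 | # word2vec model; we_dim must match the dimensionality of its vectors):
99 | # batches = Batcher(captions_file, visual_file, we_dim=500, batch_size=64, lemmatize=False, model=w2v_model)
100 | # img_labels, caps_pos, pooled, visuals = batches.next()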
--------------------------------------------------------------------------------
/src/caption_search_bow_ngram.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from collections import defaultdict
4 |
5 | import time
6 | from sklearn.preprocessing import normalize
7 |
8 | from evaluation_measure import compute_tdcg
9 | from sklearn.metrics.pairwise import euclidean_distances
10 | from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
11 |
12 | stopwords = ['of', 'the', 'a', 'an']
13 |
14 |
15 | class MyBatcher():
16 | def __init__(self, ids, labels, docs):
17 | self._data = defaultdict(list)
18 | for id_, doc in zip(ids, docs):
19 | self._data[id_].append(doc)
20 |
21 | def get_captions_txt(self, id_):
22 | return self._data[id_]
23 |
24 |
25 | def read_data(filename, count=0, filter_ids=None, filter_label=None):
26 | ids = list()
27 | labels = list()
28 | docs = list()
29 | i = 0
30 | with open(filename, 'r') as file:
31 | for line in file:
32 | line = line.strip()
33 | id_, label, text = line.split('\t')[:3]
34 | if (not filter_ids or id_ in filter_ids) and (not filter_label or label in filter_label):
35 | ids.append(id_)
36 | labels.append(label)
37 | docs.append(text)
38 | i += 1
39 | if i == count:
40 | break
41 | return ids, labels, docs
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('-q', '--queries', type=str, default='Test_val2014.sentences.txt')
47 | # parser.add_argument('-c', '--corpus', type=str, default='Test_val2014_generated_captions_neuraltalk_v2.txt')
48 | # parser.add_argument('-c', '--corpus', type=str, default='val2014_generated_captions_show_and_tell_iter_1M.txt')
49 | parser.add_argument('-c', '--corpus', type=str, default='val2014_generated_captions_show_and_tell_iter_2M.txt')
50 | args = parser.parse_args(sys.argv[1:])
51 |
52 | # count = 400
53 | # query_ids, query_labels, queries = read_data(args.queries, count)
54 | # ids, labels, docs = read_data(args.corpus, count, set(query_ids))
55 |
56 | # query_ids, query_labels, queries = read_data(args.queries)
57 | # ids, labels, docs = read_data(args.corpus)
58 |
59 | # query_ids, query_labels, queries = read_data(args.queries, count)
60 | # ids, labels, docs = read_data(args.corpus, count, set(query_ids))
61 |
62 | query_ids, query_labels, queries = read_data(args.queries)
63 | ids, labels, docs = read_data(args.corpus, filter_ids=set(query_ids), filter_label=['0'])
64 |
65 | print('queries', set(query_labels), len(query_ids))
66 | print('corpus', set(labels), len(ids))
67 |
68 | batcher = MyBatcher(query_ids, query_labels, queries)
69 |
70 | # vectorizer = CountVectorizer(min_df=1, stop_words=stopwords)
71 | # vectorizer = CountVectorizer(min_df=5, stop_words=stopwords)
72 | vectorizer = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(3, 4), analyzer='char')
73 | tfidfer = TfidfTransformer(norm='l2')
74 |
75 | tfidf_corpus = tfidfer.fit_transform(vectorizer.fit_transform(docs))
76 |
77 | tfidf_queries = tfidfer.transform(vectorizer.transform(queries))
78 |
79 | tfidf_corpus = normalize(tfidf_corpus, norm='l2', axis=1, copy=True)
80 | tfidf_queries = normalize(tfidf_queries, norm='l2', axis=1, copy=True)
81 |
82 | k = min(25, tfidf_corpus.shape[0])
83 |
84 | print(tfidf_corpus.shape)
85 | print(tfidf_queries.shape)
86 |
87 | start = time.time()
88 |
89 | atdcg = 0.0
90 | block_size = 1000
91 | num_queries = tfidf_queries.shape[0]
92 | # process queries in fixed-size blocks to bound the size of the distance matrix
93 | for offset in range(0, num_queries, block_size):
94 | block = tfidf_queries[offset:offset + block_size]
95 | block_distances = euclidean_distances(block, tfidf_corpus)
96 | for query_id, query, query_distances in zip(
97 | query_ids[offset:offset + block_size],
98 | queries[offset:offset + block_size],
99 | block_distances):
100 | # rank corpus ids by increasing distance to the query
101 | ranks = [id_ for (distance, id_) in sorted(zip(query_distances, ids))]
102 | value = compute_tdcg(batcher, query, ranks, sorted(query_distances), k)
103 | atdcg += value
104 | 
105 | done = time.time()
106 | elapsed = done - start
107 | print(elapsed)
108 | 
109 | atdcg /= len(queries)
110 | 
111 | print('eucl ATDCG', atdcg)
--------------------------------------------------------------------------------
/src/caption_search_rouge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from collections import defaultdict
4 | from multiprocessing import Pool, Process, Manager
5 | import functools
6 | import time
7 |
8 | from evaluation_measure import compute_tdcg, calc_rouge
9 |
10 | stopwords = ['of', 'the', 'a', 'an']
11 |
12 |
13 | class MyBatcher():
14 | def __init__(self, ids, labels, docs):
15 | self._data = defaultdict(list)
16 | for id_, doc in zip(ids, docs):
17 | self._data[id_].append(doc)
18 |
19 | def get_captions_txt(self, id_):
20 | return self._data[id_]
21 |
22 |
23 | def read_data(filename, count=0, filter_ids=None, filter_label=None):
24 | ids = list()
25 | labels = list()
26 | docs = list()
27 | i = 0
28 | with open(filename, 'r') as file:
29 | for line in file:
30 | line = line.strip()
31 | id_, label, text = line.split('\t')[:3]
32 | if (not filter_ids or id_ in filter_ids) and (not filter_label or label in filter_label):
33 | ids.append(id_)
34 | labels.append(label)
35 | docs.append(text)
36 | i += 1
37 | if i == count:
38 | break
39 | return ids, labels, docs
40 |
41 |
42 | def process_query(query, docs, ids, batcher, k, reverse, q):
43 | rouges = list()
44 | for doc in docs:
45 | rouges.append(calc_rouge([query], [doc]))
46 | ranks = [id_ for (score, id_) in sorted(zip(rouges, ids), reverse=reverse)]
47 | tdcg = compute_tdcg(batcher, query, ranks, sorted(rouges, reverse=reverse), k)
48 | q.put(tdcg)
49 | return tdcg
50 |
51 |
52 | def incremental_average(q):
53 | count = 0
54 | total = 0  # avoid shadowing the built-in sum
55 | for value in iter(q.get, None):
56 | count += 1
57 | total += value
58 | print(total / count, count, flush=True)
59 |
60 |
61 | if __name__ == '__main__':
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('-q', '--queries', type=str, default='Test_val2014.sentences.txt')
64 | # parser.add_argument('-c', '--corpus', type=str, default='Test_val2014_generated_captions_neuraltalk_v2.txt')
65 | # parser.add_argument('-c', '--corpus', type=str, default='val2014_generated_captions_show_and_tell_iter_1M.txt')
66 | parser.add_argument('-c', '--corpus', type=str, default='val2014_generated_captions_show_and_tell_iter_2M.txt')
67 | args = parser.parse_args(sys.argv[1:])
68 |
69 | # count = 400
70 | # query_ids, query_labels, queries = read_data(args.queries, count)
71 | # ids, labels, docs = read_data(args.corpus, count, set(query_ids))
72 |
73 | # query_ids, query_labels, queries = read_data(args.queries)
74 | # ids, labels, docs = read_data(args.corpus)
75 |
76 | query_ids, query_labels, queries = read_data(args.queries)
77 | ids, labels, docs = read_data(args.corpus, filter_ids=set(query_ids), filter_label=['0'])
78 |
79 | print('queries', set(query_labels), len(query_ids))
80 | print('corpus', set(labels), len(ids))
81 |
82 | batcher = MyBatcher(query_ids, query_labels, queries)
83 |
84 | k = 25
85 | den = len(queries)
86 | processes = 1
87 | reverse = True
88 |
89 | manager = Manager()
90 | q = manager.Queue()
91 |
92 | p = Process(target=incremental_average, args=(q,))
93 | p.start()
94 |
95 | start = time.time()
96 |
97 | pool = Pool(processes=processes)
98 | ret = pool.map(functools.partial(process_query, docs=docs, ids=ids, batcher=batcher, k=k, reverse=reverse, q=q),
99 | queries[:den])
100 | print('len(ret)', len(ret))
101 | atdcg = sum(ret)
102 |
103 | atdcg /= den
104 |
105 | done = time.time()
106 | elapsed = done - start
107 | print(elapsed)
108 | print('roug ATDCG', atdcg)
109 |
110 | q.put(None)
111 |
--------------------------------------------------------------------------------
/src/evaluation_measure.py:
--------------------------------------------------------------------------------
1 | import sys, math
2 | from io import open
3 | from pca_reader import PCAprojector
4 | from sklearn.neighbors import NearestNeighbors
5 | import sklearn.preprocessing
6 |
7 | #code from https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/rouge/rouge.py
8 | def my_lcs(string, sub):
9 | """
10 | Calculates longest common subsequence for a pair of tokenized strings
11 | :param string : list of str : tokens from a string split using whitespace
12 | :param sub : list of str : shorter string, also split using whitespace
13 | :returns: length (int): length of the longest common subsequence between the two strings
14 |
15 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
16 | """
17 | if(len(string)< len(sub)):
18 | sub, string = string, sub
19 |
20 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
21 |
22 | for j in range(1,len(sub)+1):
23 | for i in range(1,len(string)+1):
24 | if(string[i-1] == sub[j-1]):
25 | lengths[i][j] = lengths[i-1][j-1] + 1
26 | else:
27 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
28 |
29 | return lengths[len(string)][len(sub)]
30 |
31 | #code from https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/rouge/rouge.py
32 | def calc_rouge(candidate, refs, beta = 1.2):
33 | """
34 | Compute ROUGE-L score given one candidate and references for an image
35 | :param candidate: list of str : single-element list containing the candidate sentence to be evaluated
36 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
37 | :returns score: float (ROUGE-L score for the candidate evaluated against references)
38 | """
39 | assert(len(candidate)==1)
40 | assert(len(refs)>0)
41 | prec = []
42 | rec = []
43 |
44 | # split into tokens
45 | token_c = candidate[0].split(" ")
46 |
47 | for reference in refs:
48 | # split into tokens
49 | token_r = reference.split(" ")
50 | # compute the longest common subsequence
51 | lcs = my_lcs(token_r, token_c)
52 | prec.append(lcs/float(len(token_c)))
53 | rec.append(lcs/float(len(token_r)))
54 |
55 | prec_max = max(prec)
56 | rec_max = max(rec)
57 |
58 | if(prec_max!=0 and rec_max !=0):
59 | score = ((1 + beta**2)*prec_max*rec_max)/float(rec_max + beta**2*prec_max)
60 | else:
61 | score = 0.0
62 | return score
63 |
64 | def compute_dcg(batcher, input_sentence, ranks, k):
65 | dcg = 0.
66 | assert (len(ranks) >= k)
67 | for pos,ranki_id in enumerate(ranks[:k]):
68 | i = pos+1
69 | reference_sentences = batcher.get_captions_txt(ranki_id)
70 | rel_i = calc_rouge([input_sentence], reference_sentences) # compute relevance as the ROUGE-L score
71 | dcg += (math.pow(2., rel_i) - 1.) / math.log(i + 1., 2.)
72 | return dcg
73 |
74 |
75 | def compute_tdcg(batcher, input_sentence, ranks, distances, k):
76 | tdcg = 0.0
77 | assert (len(ranks) >= k)
78 | last_distance = distances[0]
79 | tie_size = 0
80 | gain = 0.0
81 | for pos, ranki_id in enumerate(ranks[:k]):
82 |
83 | distance = distances[pos]
84 | if distance == last_distance:
85 | tie_size += 1
86 | else:
87 | i = pos + 1
88 | for j in range(i - tie_size, i):
89 | tdcg += (math.pow(2.0, gain / tie_size) - 1.0) / math.log(j + 1.0, 2.0)
90 | last_distance = distance
91 | tie_size = 1
92 | gain = 0.0
93 | reference_sentences = batcher.get_captions_txt(ranki_id)
94 | gain += calc_rouge([input_sentence], reference_sentences) # compute relevance as the ROUGE-L score
95 | for j in range(k - tie_size + 1, k + 1):
96 | tdcg += (math.pow(2.0, gain / tie_size) - 1.0) / math.log(j + 1.0, 2.0)
97 |
98 | return tdcg
99 |
100 | # method should be either 'pca' or 'cosine'
101 | def evaluation(test_batches, visual_ids, visual_vectors,
102 | predictions, test_img_ids, test_cap_id,
103 | predictions_file,
104 | method='pca',
105 | mean_file=None, eigen_file=None, pca_num_eig=256, test_loss=None, find_nearest = 25, save_predictions=True):
106 |
107 | if method not in ['pca','cosine']:
108 | print("Error: method should be either 'pca' or 'cosine' [Abort]")
109 | sys.exit()
110 |
111 | proc_predictions = predictions
112 | nbrs = None
113 |
114 | if method == 'pca':
115 | print('Normalizing visual features...')
116 | sklearn.preprocessing.normalize(visual_vectors, norm='l2', axis=1, copy=False)
117 | proc_predictions = sklearn.preprocessing.normalize(predictions, norm='l2', axis=1, copy=True)
118 |
119 | print('Projecting with PCA')
120 | pca = PCAprojector(mean_file, eigen_file, visual_vectors.shape[1], num_eig=pca_num_eig)
121 | visual_vectors = pca.project(visual_vectors)
122 | proc_predictions = pca.project(proc_predictions)
123 |
124 | nbrs = NearestNeighbors(n_neighbors=find_nearest, n_jobs=-1).fit(visual_vectors)
125 | else:
126 | nbrs = NearestNeighbors(n_neighbors=find_nearest, n_jobs=8, algorithm='brute', metric='cosine').fit(visual_vectors)
127 |
128 | print('Getting nearest neighbors...')
129 | _, indices = nbrs.kneighbors(proc_predictions)
130 |
131 | print('Getting DCG_rouge...')
132 | dcg_rouge_ave = 0
133 | tests_processed = 0
134 | with open(predictions_file, 'w', encoding="utf-8", buffering=100000000) as vis:
135 | for i in range(len(predictions)):
136 | pred_str = (' '.join(('%.3f' % x) for x in predictions[i])).replace(" 0.000", " 0") if save_predictions else ''
137 | img_id = test_img_ids[i]
138 | cap_id = test_cap_id[i]
139 | cap_txt = test_batches.get_caption_txt(img_id,cap_id)
140 | nneighbours_ids = visual_ids[indices[i]]
141 | nneighbours_ids_str = ' '.join([("%d" % x) for x in nneighbours_ids]) if save_predictions else ''
142 | dcg_rouge = compute_dcg(batcher=test_batches, input_sentence=cap_txt, ranks=nneighbours_ids, k=find_nearest)
143 | dcg_rouge_ave += dcg_rouge
144 | if save_predictions:
145 | vis.write("%s\t%d\t%s\t%s\t%s\t%0.4f\n" % (img_id, cap_id, cap_txt, pred_str, nneigbours_ids_str, dcg_rouge))
146 | tests_processed += 1
147 | if tests_processed % 1000 == 0:
148 | print('Processed %d predictions. DCGAve=%f' % (tests_processed, dcg_rouge_ave / tests_processed))
149 | dcg_rouge_ave /= tests_processed
150 |
151 | vis.write(u'Test completed: %s DCGrouge=%.4f\n' % (test_loss, dcg_rouge_ave))
152 | print('Test completed: %s DCGrouge=%.4f' % (test_loss, dcg_rouge_ave))
153 |
154 |
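155 | # Usage sketch for the ROUGE-L scorer above (sentences are made-up examples):
156 | # calc_rouge(['a man riding a horse'],
157 | #            ['a man rides a brown horse', 'a person on a horse'])
158 | # returns an F-like combination (beta=1.2) of LCS-based precision and recall, in [0, 1].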
--------------------------------------------------------------------------------
/src/flags.py:
--------------------------------------------------------------------------------
1 | from helpers import err_exit
2 |
3 |
4 | def define_commom_flags(flags, num_steps=50001, summary_frequency=10):
5 | flags.DEFINE_boolean('debug', False, 'Activates the debug mode.')
6 |
7 | # training settings
8 | flags.DEFINE_integer('num_steps', num_steps, 'Maximum number of training steps (default '+str(num_steps)+').')
9 | flags.DEFINE_integer('batch_size', 64, 'Batch size (default 64).')
10 | flags.DEFINE_integer('summary_frequency', summary_frequency, 'Determines the number of steps after which to show the summaries (default '+str(summary_frequency)+').')
11 | flags.DEFINE_integer('validation_frequency', 1000, 'Number of steps after which to validate next batch (default 1000).')
12 | flags.DEFINE_integer('validation_batch', 64, 'Batch size for the validation set (default 64).')
13 |
14 | # net settings
15 | flags.DEFINE_integer('fc_layer', 6, 'fc feature layer from the AlexNet. "6" for fc6 (default), "7" for fc7.')
16 | flags.DEFINE_integer('visual_dim', 4096, 'Dimensionality of the visual embeddings space (default 4096).')
17 | flags.DEFINE_boolean('train', True, 'Set the model to be trained (default True).')
18 | flags.DEFINE_boolean('test', True, 'Set the model to be tested after training, if training is enabled (default True).')
19 | flags.DEFINE_boolean('boarddata', True, 'Set to False to deactivate the Tensorboard data generation (default True).')
20 | flags.DEFINE_string('run', '', 'Specifies a name for the run (defaults to the date and time when it is run).')
21 | flags.DEFINE_boolean('save_predictions', True, 'Set to True to save all predictions on a file.')
22 | flags.DEFINE_string('checkpoint', None, 'Path for a custom checkpoint file to restore the net parameters.')
23 | flags.DEFINE_string('retrieval', 'pca', "Chooses the retrieval algorithm ['pca'(default), 'cosine']")
24 |
25 | FLAGS = flags.FLAGS
26 | err_exit(FLAGS.fc_layer not in [6, 7], "Error, parameter 'fc_layer' should be either '6' or '7'.")
27 | if not FLAGS.checkpoint is None:
28 | err_exit(FLAGS.train != False or FLAGS.test == False, "Error, a checkpoint for testing was specified but the run is not"
29 | " set for testing only. Use --notrain --test when specifying a checkpoint.")
30 |
31 | return flags
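32 | 
33 | # Typical use (as in lstm_text2vis.py):
34 | #   flags = tf.app.flags
35 | #   define_commom_flags(flags, num_steps=50001, summary_frequency=10)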
--------------------------------------------------------------------------------
/src/helpers.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import tensorflow as tf
3 | import signal
4 | import math, random
5 | import shutil
6 |
7 | #--------------------------------------------------------------
8 | # Model helpers
9 | #--------------------------------------------------------------
10 | """
11 | This class abstracts the model losses into a single instance, to ease their accumulation and averaging.
12 | """
13 | class model_losses:
14 | def __init__(self, name=''):
15 | self.clear()
16 | self.__name__ = name
17 |
18 | def accum(self, l, llstm, lvisual, items=1):
19 | self.l+=l*items
20 | self.llstm+=llstm*items
21 | self.lvisual+=lvisual*items
22 | self.items+=items
23 |
24 | def mean(self):
25 | l = self.l / self.items
26 | llstm = self.llstm / self.items
27 | lvisual = self.lvisual / self.items
28 | return l, llstm, lvisual
29 |
30 | def mean_flush(self):
31 | means = self.mean()
32 | self.clear()
33 | return means
34 |
35 | def clear(self):
36 | self.l = self.llstm = self.lvisual = self.items = 0
37 |
38 | def __str__(self):
39 | return self.__name__ + ('[L=%.04f Ll=%.04f Lv=%.04f]' % self.mean())
40 |
41 | class Model:
42 | def __init__(self, bucket_size,
43 | train_tokens, train_labels, visual_outputs,
44 | visual_prediction,
45 | loss, lstm_loss, visual_loss,
46 | optimizer_loss, optimizer_lstm_loss, optimizer_visual_loss,
47 | summaries):
48 | self.bucket_size = bucket_size
49 | self.train_tokens = train_tokens
50 | self.train_labels = train_labels
51 | self.visual_outputs = visual_outputs
52 | self.visual_prediction = visual_prediction
53 | self.loss = loss
54 | self.lstm_loss = lstm_loss
55 | self.visual_loss = visual_loss
56 | self.optimizer_loss = optimizer_loss
57 | self.optimizer_lstm_loss = optimizer_lstm_loss
58 | self.optimizer_visual_loss = optimizer_visual_loss
59 | self.summaries = summaries
60 |
61 | def variable_summaries(var, scope_name, name):
62 | """Attach summaries to a Tensor."""
63 | embeddings_summaries = []
64 | with tf.name_scope(scope_name):
65 | mean = tf.reduce_mean(var)
66 | embeddings_summaries.append(tf.scalar_summary(scope_name + '/mean', mean))
67 | with tf.name_scope('stddev'):
68 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
69 | embeddings_summaries.append(tf.scalar_summary(scope_name + '/stddev', stddev))
70 | embeddings_summaries.append(tf.scalar_summary(scope_name + '/max', tf.reduce_max(var)))
71 | embeddings_summaries.append(tf.scalar_summary(scope_name + '/min', tf.reduce_min(var)))
72 | embeddings_summaries.append(tf.histogram_summary(scope_name + '/' + name, var))
73 | return embeddings_summaries
74 |
75 | def savemodel(session, step, saver, checkpoint_dir, run_name, posfix=""):
76 | sys.stdout.write('Saving model...')
77 | sys.stdout.flush()
78 | save_path = saver.save(session, checkpoint_dir + '/' + run_name + posfix, global_step=step + 1)
79 | print('[Done]')
80 | return save_path
81 |
82 | def projection_weights(orig_size, target_size, name):
83 | weight = tf.Variable(tf.truncated_normal([orig_size, target_size], stddev=1.0 / math.sqrt(target_size)), name=name + '/weight')
84 | bias = tf.Variable(tf.zeros([target_size]), name=name + '/bias')
85 | return weight, bias
86 |
87 |
88 | #--------------------------------------------------------------
89 | # Run helpers
90 | #--------------------------------------------------------------
91 | def err_exit(exit_condition=True, err_msg=""):
92 | if exit_condition:
93 | if not err_msg:
94 | err_msg = (sys.argv[0]+ " Error")
95 | print(err_msg)
96 | sys.exit()
97 |
98 | def notexist_exit(path):
99 | if isinstance(path, list):
100 | [notexist_exit(p) for p in path]
101 | elif not os.path.exists(path):
102 | print("Error. Path <%s> does not exist or is not accessible." %path)
103 | sys.exit()
104 |
105 | def create_if_not_exists(dir):
106 | if not os.path.exists(dir): os.makedirs(dir)
107 | return dir
108 |
109 | def restore_checkpoint(saver, session, checkpoint_dir, checkpoint_path=None):
110 | if not checkpoint_path:
111 | print('Restoring last checkpoint in %s' % checkpoint_dir)
112 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
113 | err_exit(not (ckpt and ckpt.model_checkpoint_path),
114 | 'Error: no checkpoint found in directory %s.' % checkpoint_dir)
115 | saver.restore(session, ckpt.model_checkpoint_path)
116 | else:
117 | print('Restoring checkpoint %s' % os.path.join(checkpoint_dir,checkpoint_path))
118 | saver.restore(session, os.path.join(checkpoint_dir,checkpoint_path))
119 |
120 | class TensorboardData:
121 | def __init__(self, generate_tensorboard_data):
122 | self.generate_tensorboard_data=generate_tensorboard_data
123 |
124 | def open(self, summaries_dir, run_name, graph):
125 | if self.generate_tensorboard_data:
126 | train_path = summaries_dir + '/train_' + run_name
127 | valid_path = summaries_dir + '/valid_' + run_name
128 | if os.path.exists(train_path): shutil.rmtree(train_path)
129 | if os.path.exists(valid_path): shutil.rmtree(valid_path)
130 | self.train_writer = tf.train.SummaryWriter(summaries_dir + '/train_' + run_name, graph, flush_secs=30)
131 | self.valid_writer = tf.train.SummaryWriter(summaries_dir + '/valid_' + run_name, graph, flush_secs=120)
132 |
133 | def add_train_summary(self, summary, step):
134 | if self.generate_tensorboard_data:
135 | self.train_writer.add_summary(summary, step)
136 |
137 | def add_valid_summary(self, summary, step):
138 | if self.generate_tensorboard_data:
139 | self.valid_writer.add_summary(summary, step)
140 |
141 | def close(self):
142 | if self.generate_tensorboard_data:
143 | self.train_writer.close()
144 | self.valid_writer.close()
145 |
146 |
147 | def get_optimizer(op_loss, op_vis, op_tex, stochastic_loss = False, visual_prob=0.5):
148 | if not stochastic_loss: return op_loss
149 | return op_vis if random.random() < visual_prob else op_tex
150 |
151 | class InteractiveOptions:
152 | SaveQuit, Quit, Save, Continue, SaveGoTest, GoTest = range(6)
153 | _available_signals = set({SaveQuit, Quit, Save, Continue, SaveGoTest, GoTest})
154 |
155 | def __init__(self):
156 | self.status = self.Continue
157 | signal.signal(signal.SIGINT, self._signal_handler)
158 |
159 | def _signal_handler(self, signal, frame):
160 | self.status = self.get_option_keyboard()
161 |
162 | def get_option_keyboard(self):
163 | print('Process interrupted by user:')
164 | print('\t' + str(self.SaveQuit) + '. Save model and quit.')
165 | print('\t' + str(self.Quit) + '. Quit without saving.')
166 | print('\t' + str(self.Save) + '. Save model (posfix "_usersaved").')
167 | print('\t' + str(self.Continue) + '. Continue.')
168 | print('\t' + str(self.SaveGoTest) + '. Save and Test the model.')
169 | print('\t' + str(self.GoTest) + '. Test the last saved model.')
170 | option = int(raw_input('Choose option [0-5]:'))
171 | if option not in self._available_signals:
172 | print("Wrong option.")
173 | option = self.get_option_keyboard()
174 | return option
175 |
176 |
177 |
178 |
179 |
180 |
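181 | # model_losses usage sketch (mirrors its use in lstm_text2vis.py):
182 | #   losses = model_losses('tr')
183 | #   losses.accum(l, llstm, lvisual, items=batch_items)
184 | #   print(losses)  # reports the running means; losses.clear() resets the accumulators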
--------------------------------------------------------------------------------
/src/lstm_text2vis.py:
--------------------------------------------------------------------------------
1 | import sys, os, getopt
2 | import random
3 | import time
4 | import numpy as np
5 | import tensorflow as tf
6 | from batcher_lstmbased import Batcher
7 | from evaluation_measure import evaluation
8 | from helpers import *
9 | from paths import PATH
10 | from flags import define_commom_flags
11 |
12 |
13 | def main(argv=None):
14 |
15 | err_exit(argv[1:], "Error in parameters %s (--help for documentation)." % argv[1:])
16 |
17 | Path = PATH(FLAGS.fc_layer, FLAGS.debug)
18 |
19 | do_training = FLAGS.train
20 | do_test = FLAGS.test
21 | run_name = FLAGS.run
22 | if not run_name:
23 | run_name=time.strftime("%d_%b_%Y")+'::'+time.strftime("%H:%M:%Sh")
24 | if FLAGS.debug:
25 | run_name+='-debug'
26 |
27 | sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0) # set stdout to unbuffered
28 |
29 | #---------------------------------------------------------------------
30 | # Inits
31 | #---------------------------------------------------------------------
32 |
33 | buckets_def = [15, 20, 40] if not FLAGS.debug else [15]
34 | batches = Batcher(Path.tr_captions_file, Path.tr_visual_embeddings_file, buckets_def=buckets_def, batch_size=FLAGS.batch_size, lemma=FLAGS.lemma, dowide=FLAGS.dowide)
35 |
36 | vocabulary_size=batches.vocabulary_size()
37 | print('Vocabulary size %d' % vocabulary_size)
38 |
39 | graph = tf.Graph()
40 | with graph.as_default():
41 | # Placeholders
42 | train_wide1hot = tf.placeholder(tf.float32, shape=[None, vocabulary_size], name='wide_1hot') if FLAGS.dowide else None
43 |
44 | # Model parameters
45 | embeddings = tf.Variable(tf.random_uniform([vocabulary_size, FLAGS.embedding_size], -1.0, 1.0), name='word_embeddings')
46 | state2text_weight, state2text_bias = projection_weights(FLAGS.num_nodes, vocabulary_size, 'output-text')
47 | state2vis_weight, state2vis_bias = projection_weights(FLAGS.num_nodes*FLAGS.lstm_stacked_layers, FLAGS.visual_dim, 'hidden-visual')
48 | if FLAGS.dowide:
49 | wide_weight, wide_bias = projection_weights(vocabulary_size, FLAGS.visual_dim, 'wide-projection')
50 | else:
51 | wide_weight, wide_bias = None,None
52 |
53 | # Bucket-independent computations
54 | wide_prediction = (tf.matmul(train_wide1hot, wide_weight) + wide_bias) if FLAGS.dowide else None
55 |
56 | # ----------------------------------------------------------------------------------------------
57 | # Generate a bucket-specific unrolled net
58 | def bucket_net(bucket_size, stacked_lstm):
59 | # ----------------------------------------------------------------------------------------------
60 | # Placeholders
61 | train_tokens = list()
62 | train_labels = list()
63 | for i in range(bucket_size-1):
64 | train_tokens.append(tf.placeholder(tf.int64, shape=[None], name='x_'+str(i)))
65 | train_labels.append(tf.placeholder(tf.float32, shape=[None, vocabulary_size], name='x_'+str(i+1)))
66 |
67 | embedding_inputs = list()
68 | for i in range(len(train_tokens)):
69 | embedding_inputs.append(tf.nn.embedding_lookup(embeddings, train_tokens[i]))
70 |
71 | visual_outputs = tf.placeholder(tf.float32, shape=[None, FLAGS.visual_dim], name='visual-embedding')
72 |
73 | # ----------------------------------------------------------------------------------------------
74 | # Unrolled LSTM loop.
75 | outputs, final_state = tf.nn.rnn(stacked_lstm, embedding_inputs, dtype=tf.float32)
76 | final_state = tf.concat(1,[_state.c for _state in final_state]) if FLAGS.lstm_stacked_layers > 1 else final_state.c
77 | tf.get_variable_scope().reuse_variables()
78 | logits = tf.matmul(tf.concat(0, outputs), state2text_weight) + state2text_bias
79 | deep_prediction = tf.matmul(final_state, state2vis_weight) + state2vis_bias
80 | visual_prediction = tf.nn.relu(wide_prediction+deep_prediction) if FLAGS.dowide else tf.nn.relu(deep_prediction)
81 |
82 | # ----------------------------------------------------------------------------------------------
83 | # Losses
84 | lstm_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, tf.concat(0, train_labels)))
85 | visual_loss = tf.reduce_mean(tf.square(visual_prediction - visual_outputs))
86 | loss = lstm_loss + visual_loss
87 |
88 | # ----------------------------------------------------------------------------------------------
89 | # Tensorboard data: loss summaries
90 | if FLAGS.boarddata:
91 | lstm_loss_summary = tf.scalar_summary('loss/lstm_loss', lstm_loss)
92 | visual_loss_summary = tf.scalar_summary('loss/visual_loss', visual_loss)
93 | loss_summary = tf.scalar_summary('loss/loss', loss)
94 | summaries = tf.merge_summary([loss_summary, lstm_loss_summary, visual_loss_summary])
95 | else: summaries = None
96 |
97 | #----------------------------------------------------------------------------------------------
98 | # Optimizer.
99 | def optimizer(someloss):
100 | global_step = tf.Variable(0)
101 | optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
102 | gradients, v = zip(*optimizer.compute_gradients(someloss))
103 | gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
104 | optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)
105 | return optimizer
106 |
107 | return Model(bucket_size, train_tokens, train_labels, visual_outputs, visual_prediction, loss, lstm_loss, visual_loss,
108 | optimizer(loss), optimizer(lstm_loss), optimizer(visual_loss), summaries)
109 |
110 | # ----------------------------------------------------------------------------------------------
111 | # Defines the computation graph
112 | print('Creating bucket-specific unrolled nets:')
113 | models = dict()
114 | lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=FLAGS.num_nodes, use_peepholes=True, state_is_tuple=True)
115 | if FLAGS.lstm_stacked_layers > 1:
116 | lstm_cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * FLAGS.lstm_stacked_layers)
117 | for bucket_i in buckets_def:
118 | with tf.name_scope('bucket-net-'+str(bucket_i)):
119 | models[bucket_i] = bucket_net(bucket_i, lstm_cell)
120 | print('\tcreated model for bucket %d' % bucket_i)
121 |
122 | # ----------------------------------------------------------------------------------------------
123 | # Add ops to save and restore all the variables.
124 | saver = tf.train.Saver(max_to_keep=1) #defaults are: max_to_keep=5, keep_checkpoint_every_n_hours=10000.0
125 |
126 | #---------------------------------------------------------------------
127 | # Model Params
128 | #---------------------------------------------------------------------
129 |
130 | def get_model_params(batches_):
131 | img_labels, caps_pos, caps, wide, viss, bucket_size = batches_.next()
132 | model = models[bucket_size]
133 | params = dict()
134 | if FLAGS.dowide: params[train_wide1hot] = wide
135 | for i in range(bucket_size - 1):
136 | params[model.train_tokens[i]] = caps[:, i]
137 | params[model.train_labels[i]] = batches_.from_batchlabel2batch_onehot(caps[:, i + 1])
138 | params[model.visual_outputs] = viss
139 | return model, params, img_labels, caps_pos
140 |
141 | # ---------------------------------------------------------------------
142 | # Graph run
143 | # ---------------------------------------------------------------------
144 |
145 | with tf.Session(graph=graph) as session:
146 | tf.initialize_all_variables().run()
147 |
148 | # Train the net
149 | if do_training:
150 | #interactive mode: allows the user to save & quit or quit w/o saving
151 | interactive = InteractiveOptions()
152 | tensorboard = TensorboardData(FLAGS.boarddata)
153 |
154 | val_batches = Batcher(Path.val_caption_file, Path.val_visual_embeddings_file,
155 | buckets_def=buckets_def, batch_size=FLAGS.validation_batch,
156 | lemma = FLAGS.lemma,
157 | dowide = FLAGS.dowide,
158 | word2id=batches.get_word2id(), id2word=batches.get_id2word())
159 |
160 | tensorboard.open(Path.summaries_dir, run_name, session.graph)
161 |
162 | best_val, val_lv = None, None
163 | tr_losses = model_losses('tr')
164 | val_losses = model_losses('val')
165 | last_improve = 0
166 | for step in range(1, FLAGS.num_steps):
167 | tr_model,tr_feed_dict,img_labels,caps_pos = get_model_params(batches)
168 | optimizer = get_optimizer(op_loss=tr_model.optimizer_loss, op_vis=tr_model.optimizer_visual_loss, op_tex=tr_model.optimizer_lstm_loss,
169 | stochastic_loss=FLAGS.stochastic_loss, visual_prob=FLAGS.stochastic_visual_prob)
170 |
171 | _,tr_l,tr_ll,tr_lv = session.run([optimizer, tr_model.loss, tr_model.lstm_loss, tr_model.visual_loss], feed_dict=tr_feed_dict)
172 | tr_losses.accum(tr_l,tr_ll,tr_lv, len(img_labels))
173 | tensorboard.add_train_summary(tr_model.summaries.eval(feed_dict=tr_feed_dict), step)
174 |
175 | if step % FLAGS.summary_frequency == 0:
176 | print('[Step %d][Bucket=%d] %s' % (step, tr_model.bucket_size, tr_losses))
177 | tr_losses.clear()
178 |
179 | if step % FLAGS.validation_frequency == 0:
180 | val_losses.clear()
181 | samples_proc = 0
182 | val_epoch = val_batches.epoch
183 | while True:
184 | val_model, val_feed_dict,img_ids,cap_ids = get_model_params(val_batches)
185 | val_l, val_ll, val_lv = session.run([val_model.loss, val_model.lstm_loss, val_model.visual_loss], feed_dict=val_feed_dict)
186 | val_losses.accum(val_l,val_ll,val_lv,len(img_ids))
187 | samples_proc+=len(img_ids)
188 | if val_batches.epoch != val_epoch: break
189 | print('\t[Step %d] %s' % (step, val_losses))
190 | _,_,valid_loss = val_losses.mean_flush()
191 | tensorboard.add_valid_summary(val_model.summaries.eval(feed_dict=val_feed_dict), step)
192 |
193 | if not best_val or valid_loss < best_val:
194 | best_val = valid_loss
195 | last_improve = step
196 | if step > 2500:
197 | savemodel(session, step, saver, Path.checkpoint_dir, run_name)
198 |
199 | if interactive.status != interactive.Continue:
200 | if interactive.status in [interactive.SaveQuit, interactive.Save, interactive.SaveGoTest]:
201 | savemodel(session, step, saver, Path.checkpoint_dir, run_name, posfix='_usersaved')
202 | if interactive.status in [interactive.SaveQuit, interactive.Quit]:
203 | do_test = False
204 | break
205 | if interactive.status in [interactive.GoTest, interactive.SaveGoTest]:
206 | do_test = True
207 | break
208 | interactive.status = interactive.Continue
209 |
210 | # early-stop condition
211 | if step - last_improve >= 10000:
212 | print ("Early stop at step %d" % step)
213 | break
214 |
215 | tensorboard.close()
216 |
217 | # Test the net
218 | if do_test:
219 | restore_checkpoint(saver, session, Path.checkpoint_dir, checkpoint_path=FLAGS.checkpoint)
220 |
221 | print('Starting evaluation...')
222 | test_batches = Batcher(Path.test_caption_file, Path.test_visual_embeddings_file,
223 | buckets_def=buckets_def, batch_size=FLAGS.batch_size,
224 | lemma=FLAGS.lemma,
225 | dowide=FLAGS.dowide,
226 | word2id=batches.get_word2id(), id2word=batches.get_id2word())
227 |
228 | print('Getting predictions...')
229 | test_img_ids, test_cap_id = [], []
230 | buckets_processed, samples_processed = 0, 0
231 | test_losses = model_losses('test')
232 | predictions = []
233 | while test_batches.epoch == 0:
234 | test_model, test_feed_dict, img_ids, cap_ids = get_model_params(test_batches)
235 | batch_predictions, te_l, te_ll, te_lv = session.run([test_model.visual_prediction,
236 | test_model.loss, test_model.lstm_loss, test_model.visual_loss],
237 | feed_dict=test_feed_dict)
238 | test_losses.accum(te_l, te_ll, te_lv, len(img_ids))
239 | predictions.append(batch_predictions)
240 | test_img_ids += img_ids
241 | test_cap_id += cap_ids
242 | buckets_processed += 1
243 | samples_processed += len(img_ids)
244 | if buckets_processed % 100 == 0:
245 | print('Processed %d examples' % (samples_processed))
246 | predictions = np.concatenate(predictions, axis=0)
247 |
248 | predictions_file = Path.predictions_dir + '/' + run_name + '.txt'
249 | visual_ids, visual_vectors = test_batches._visual.get_all_vectors()
250 |
251 | evaluation(test_batches, visual_ids, visual_vectors, predictions, test_img_ids, test_cap_id, predictions_file,
252 | method=FLAGS.retrieval,
253 | mean_file=Path.mean_file, eigen_file=Path.eigen_file,
254 | test_loss=str(test_losses), save_predictions=FLAGS.save_predictions)
255 |
256 | #-------------------------------------
257 | if __name__ == '__main__':
258 | flags = tf.app.flags
259 | FLAGS = flags.FLAGS
260 |
261 | define_commom_flags(flags, num_steps=50001, summary_frequency=10)
262 |
263 | # net settings
264 | flags.DEFINE_integer('num_nodes', 512, 'Number of nodes of the LSTM internal representation (default 512).')
265 | flags.DEFINE_integer('lstm_stacked_layers', 2, 'Number of stacked layers for the LSTM cell (default 2).')
266 | flags.DEFINE_boolean('stochastic_loss', False, 'Determines if the Stochastic-loss heuristic is activated (default False).')
267 | flags.DEFINE_float('stochastic_visual_prob', 0.5, 'If stochastic_loss is active, determines the probability of optimizing the visual loss (default 0.5).')
268 | flags.DEFINE_boolean('lemma', True, 'Lemmatizes the captions before processing them (default True).')
269 | flags.DEFINE_integer('embedding_size', 100, 'Dimensionality of the word embeddings space (default 100).')
270 |
271 | tf.app.run()
272 |
273 |
--------------------------------------------------------------------------------
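
All three training scripts in this repository share the validation/early-stopping pattern seen above: track the best validation loss and abort once no improvement has been observed for a fixed number of steps. A minimal sketch of that pattern, with hypothetical train_step/validate callables standing in for the session.run calls:

    def train_with_early_stopping(train_step, validate, num_steps,
                                  validation_frequency, patience):
        # Track the best validation loss seen so far; stop when no
        # improvement has been observed for `patience` steps.
        best_val, last_improve = None, 0
        for step in range(1, num_steps):
            train_step(step)
            if step % validation_frequency == 0:
                valid_loss = validate()
                if best_val is None or valid_loss < best_val:
                    best_val, last_improve = valid_loss, step
            if step - last_improve >= patience:
                print("Early stop at step %d" % step)
                break
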
/src/mscoco_captions_reader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import collections
3 | import bz2
4 | import numpy as np
5 | from nltk.stem import WordNetLemmatizer
6 |
7 | class MSCocoCaptions(object):
8 |
9 | def __init__(self, captions_file, max_vocabulary_size=10000, word2id=None, id2word=None, lemma=False):
10 | self.sos_char = '[SOS]' # start of sentence
11 | self.eos_char = '[EOS]' # end of sentence
12 | self.unk_word = '[UNK]' # unknown word
13 | self.pad_word = '[PAD]' # padding token
14 | self.captions, text = self._read_captions(captions_file, lemma)
15 | self.word2id=word2id
16 | self.id2word=id2word
17 | if self.word2id is None or self.id2word is None:
18 | self.word2id, self.id2word = self._build_dataset(max_vocabulary_size, text)
19 | self._index_captions()
20 |
21 | def _from_caption2ids(self, caption):
22 | return [self._fromword2id(x) for x in ((self.sos_char + ' ') + caption + (' ' + self.eos_char) * 2).split()]
23 |
24 | def _from_ids2caption(self, ids):
25 | return ' '.join([self._fromid2word(x) for x in ids])
26 |
27 | def _fromword2id(self, word):
28 | return self.word2id[word] if word in self.word2id else self.word2id[self.unk_word]
29 |
30 | def _fromid2word(self, wordid):
31 | return self.id2word[wordid] if wordid in self.id2word else self.unk_word
32 |
33 | def _index_captions(self):
34 | self.indexed_captions=dict()
35 | for imageID in self.captions.keys():
36 | self.indexed_captions[imageID]=[]
37 | for cap_i in self.captions[imageID]:
38 | cap_ids = self._from_caption2ids(cap_i)
39 | self.indexed_captions[imageID].append(cap_ids)
40 |
41 | def _build_dataset(self, max_vocabulary_size, text):
42 | sys.stdout.write('Building indexes...')
43 | sys.stdout.flush()
44 | tokens = text.split()
45 | count = [[self.unk_word, -1], [self.sos_char, -1], [self.eos_char, -1], [self.pad_word, -1]]
46 | count.extend(collections.Counter(tokens).most_common(max_vocabulary_size-len(count)))
47 | word2id = dict()
48 | for word, _ in count:
49 | word2id[word] = len(word2id)
50 | id2word = dict(zip(word2id.values(), word2id.keys()))
51 | print("[Done]")
52 | return word2id, id2word
53 |
54 | def vocabulary_size(self):
55 | return len(self.id2word)
56 |
57 | def _read_captions(self, captions_file, lemmatize=False):
58 | lemmatizer = WordNetLemmatizer() if lemmatize else None
59 | print("Reading captions file <%s>" % captions_file)
60 | captions=dict()
61 | text = []
62 | with bz2.BZ2File(captions_file, 'r', buffering=10000000) as fin:
63 | for line in fin:
64 | line = line.decode("utf-8")
65 | fields = line.split("\t")
66 | imageID = int(fields[0])
67 | sentence = fields[2][:-1].lower()
68 | if lemmatize:
69 | sentence = ' '.join(lemmatizer.lemmatize(w) for w in sentence.split()) # lemmatize token by token: WordNetLemmatizer expects single words
70 | if imageID not in captions:
71 | captions[imageID]=[]
72 | captions[imageID].append(sentence)
73 | text.append(sentence)
74 | text=' '.join(text)
75 | return captions, text
76 |
77 | def _words(self, probabilities):
78 | """Turn a 1-hot encoding or a probability distribution over the possible
79 | words back into its (most likely) word representation."""
80 | return [self._fromid2word(c) for c in np.argmax(probabilities, 1)]
81 |
82 |
83 | def num_images(self):
84 | return len(self.indexed_captions)
85 |
86 | def get_image_ids(self):
87 | return self.indexed_captions.keys()
88 |
89 | def get_captions(self, img_id):
90 | return self.indexed_captions[img_id]
91 |
92 | def get_captions_txt(self, img_id):
93 | return self.captions[img_id]
94 |
95 | def get_caption_txt(self, img_id, cap_pos):
96 | return self.captions[img_id][cap_pos]
97 |
98 |
99 | def get_pad(self):
100 | return self._fromword2id(self.pad_word)
--------------------------------------------------------------------------------
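
MSCocoCaptions brackets each caption with one [SOS] marker and two [EOS] markers, and maps out-of-vocabulary tokens to [UNK]. A minimal usage sketch of the round trip (the printed ids are illustrative, not actual vocabulary indices):

    caps = MSCocoCaptions('../captions/train2014.sentences.txt.bz2',
                          max_vocabulary_size=10000)
    ids = caps._from_caption2ids('a dog runs on the beach')
    # e.g. [1, 4, 812, 0, 9, 5, 97, 2, 2]
    print(caps._from_ids2caption(ids))
    # '[SOS] a dog [UNK] on the beach [EOS] [EOS]'
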
/src/paths.py:
--------------------------------------------------------------------------------
1 | from helpers import notexist_exit, create_if_not_exists
2 |
3 | #-------------------------------------
4 | class PATH(object):
5 | def __init__(self, fclayer, debug_mode=False, use_ngrams=False):
6 | debug='.debug' if debug_mode else ''
7 | ngrams='.ngrams' if use_ngrams else ''
8 |
9 | captions_path = '../captions'
10 | self.tr_captions_file = captions_path+'/train2014.sentences.txt'+debug+ngrams+'.bz2'
11 | self.val_caption_file = captions_path+'/Validation_val2014.sentences.txt'+debug+ngrams+'.bz2'
12 | self.test_caption_file= captions_path+'/Test_val2014.sentences.txt'+debug+ngrams+'.bz2'
13 | notexist_exit([self.tr_captions_file, self.val_caption_file, self.test_caption_file])
14 |
15 | fcx = "fc"+str(fclayer)
16 | visual_path = '../visualembeddings/' + fcx
17 | num_examples= '1' if debug_mode else '20'
18 | self.tr_visual_embeddings_file = visual_path+'/COCO_train2014_hybridCNN_'+fcx+'.sparse'+debug+'.txt'
19 | self.val_visual_embeddings_file = visual_path+'/COCO_val2014_hybridCNN_'+fcx+'.sparse.'+num_examples+'K_Validation'+debug+'.txt'
20 | self.test_visual_embeddings_file= visual_path+'/COCO_val2014_hybridCNN_'+fcx+'.sparse.'+num_examples+'K_Test'+debug+'.txt'
21 | notexist_exit([self.tr_visual_embeddings_file, self.val_visual_embeddings_file, self.test_visual_embeddings_file])
22 |
23 | pca_path = '../pca/'+fcx
24 | self.mean_file = pca_path+'/COCO_train2014_hybridCNN_'+fcx+'_ReLu_L2Norm_PC_from65536.mean.dat.txt'
25 | self.eigen_file = pca_path+'/COCO_train2014_hybridCNN_'+fcx+'_ReLu_L2Norm_PC_from65536.eigen.dat.txt'
26 | notexist_exit([self.mean_file, self.eigen_file])
27 |
28 | self.checkpoint_dir = create_if_not_exists('../models')
29 | self.predictions_dir = create_if_not_exists('../predictions')
30 | self.summaries_dir = create_if_not_exists('../summaries')
31 |
--------------------------------------------------------------------------------
/src/pca_reader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from io import open
3 |
4 | class PCAprojector():
5 | def __init__(self, mean_file, eigen_file, num_dims, num_eig=256):
6 | self.num_dims=num_dims
7 | self.read_mean(mean_file)
8 | self.read_eigen(eigen_file, num_eig)
9 |
10 | def read_mean(self, mean_file):
11 | with open(mean_file, "r") as mean_row:
12 | mean_str = mean_row.readlines()
13 | assert(len(mean_str)==1)
14 | self.mean = self.read_row(mean_str[0])
15 | print("mean read with shape %s" % str(self.mean.shape))
16 |
17 | def read_eigen(self, eigen_file, num_eig):
18 | self.num_eigen=num_eig
19 | self.eigen = np.empty([self.num_eigen, self.num_dims])
20 | current_row = 0
21 | with open(eigen_file, "r") as eigen_rows:
22 | vectors = eigen_rows.readlines()
23 | assert (len(vectors) >= self.num_eigen)
24 | for vec in vectors[:self.num_eigen]:
25 | self.eigen[current_row]=self.read_row(vec)
26 | current_row += 1
27 | print("eigen read with shape %s" % str(self.eigen.shape))
28 |
29 | def read_row(self,row):
30 | vals = row.split()
31 | assert (len(vals) == self.num_dims)
32 | return np.array([float(x) for x in vals])
33 |
34 | def project(self, matrix):
35 | return np.dot(matrix - self.mean, self.eigen.transpose())
36 |
37 |
--------------------------------------------------------------------------------
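
PCAprojector.project centers each feature vector with the precomputed mean and projects it onto the num_eig leading eigenvectors, e.g. reducing a 4096-dimensional fc7 feature to 256 components. The same operation in plain NumPy, with random data standing in for the .mean.dat and .eigen.dat files:

    import numpy as np

    num_dims, num_eig = 4096, 256
    features = np.random.rand(10, num_dims)    # a batch of visual features
    mean = np.random.rand(num_dims)            # stands in for the mean file
    eigen = np.random.rand(num_eig, num_dims)  # stands in for the eigen file

    projected = np.dot(features - mean, eigen.transpose())  # same as PCAprojector.project
    assert projected.shape == (10, num_eig)
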
/src/sparse_text2vis.py:
--------------------------------------------------------------------------------
1 | import sys, getopt, os
2 | import numpy as np
3 | import tensorflow as tf
4 | from batcher_sparse import Batcher
5 | from helpers import *
6 | from paths import PATH
7 | from evaluation_measure import evaluation
8 | from flags import define_commom_flags
9 | from mscoco_captions_reader import MSCocoCaptions
10 |
11 |
12 | def main(argv=None):
13 |
14 | err_exit(argv[1:], "Error in parameters %s (--help for documentation)." % argv[1:])
15 | Path = PATH(FLAGS.fc_layer, debug_mode=FLAGS.debug, use_ngrams=FLAGS.ngrams)
16 | run_name = FLAGS.run+('-N' if FLAGS.ngrams else '-U')
17 | sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0) # set stdout to unbuffered
18 |
19 | # Default parameters
20 | l2factor = 1e-8 # weight of the L2 regularization on the visual projection
21 |
22 | vocabulary_size = 25000 if FLAGS.ngrams else 10000
23 | # The training files (captions and visual embeddings) are used as training data, whereas the validation files (captions and visual embeddings) are split into validation and test sets
24 | batches = Batcher(captions_file = Path.tr_captions_file,
25 | visual_file = Path.tr_visual_embeddings_file,
26 | batch_size=FLAGS.batch_size,
27 | max_vocabulary_size=vocabulary_size)
28 |
29 | input_size = batches.vocabulary_size()
30 | print("Input size: %d" % input_size )
31 |
32 | # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
33 | # GRAPH
34 | # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
35 |
36 | graph = tf.Graph()
37 | with graph.as_default():
38 | # Input/Output data.
39 | # -------------------------------------------------------
40 | caption_input = tf.placeholder(tf.float32, shape=[None, input_size])
41 | caption_output = tf.placeholder(tf.float32, shape=[None, input_size])
42 | visual_embedding_output = tf.placeholder(tf.float32, shape=[None, FLAGS.visual_dim])
43 |
44 | global_step = tf.placeholder(tf.float32) # training iteration
45 |
46 | # Model parameters
47 | # -------------------------------------------------------
48 | # caption-embedding
49 | cap2vec_weights, cap2vec_biases = projection_weights(input_size, FLAGS.hidden_size, 'cap2hid')
50 |
51 | # embedding-caption
52 | vec2cap_weights, vec2cap_biases = projection_weights(FLAGS.hidden_size, input_size, 'hid2cap')
53 |
54 | # embedding-visual
55 | vec2vis_weights, vec2vis_biases = projection_weights(FLAGS.hidden_size, FLAGS.visual_dim, 'hid2vis')
56 |
57 | # NNet
58 | # -------------------------------------------------------
59 | hidden_layer = tf.nn.relu(tf.matmul(caption_input, cap2vec_weights) + cap2vec_biases)
60 | caption_reconstruc = tf.nn.relu(tf.matmul(hidden_layer, vec2cap_weights) + vec2cap_biases)
61 | visual_prediction = tf.nn.relu(tf.matmul(hidden_layer, vec2vis_weights) + vec2vis_biases)
62 |
63 | # Losses functions
64 | # -------------------------------------------------------
65 | l2loss = l2factor * (tf.nn.l2_loss(vec2vis_weights) + tf.nn.l2_loss(vec2vis_biases))
66 | visual_loss = tf.reduce_mean(tf.square(visual_prediction - visual_embedding_output)) + l2loss
67 | caption_loss = tf.reduce_mean(tf.square(caption_output - caption_reconstruc))
68 | loss = visual_loss + caption_loss
69 |
70 | # Optimizers
71 | # -------------------------------------------------------
72 | visual_optimizer = tf.train.AdamOptimizer().minimize(visual_loss)
73 | caption_optimizer = tf.train.AdamOptimizer().minimize(caption_loss)
74 | full_optimizer = tf.train.AdamOptimizer().minimize(loss)
75 |
76 | # Add ops to save and restore all the variables.
77 | saver = tf.train.Saver(max_to_keep=1) # defaults are: max_to_keep=5, keep_checkpoint_every_n_hours=10000.0
78 |
79 | # Tensorboard data
80 | auto_loss_summary = tf.scalar_summary('loss/auto_loss', caption_loss)
81 | visual_loss_summary = tf.scalar_summary('loss/visual_loss', visual_loss)
82 | loss_summary = tf.scalar_summary('loss/loss', loss)
83 | summaries = tf.merge_summary([loss_summary, auto_loss_summary, visual_loss_summary])
84 |
85 | print("Graph built!")
86 |
87 | # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
88 | # RUN
89 | # ---------------------------------------------------------------------------------------------------------------------------------------------------------------------
90 |
91 | with tf.Session(graph=graph) as session:
92 | tf.initialize_all_variables().run()
93 |
94 | do_training = FLAGS.train
95 | do_test = FLAGS.test
96 |
97 | # Train the net
98 | if do_training:
99 | # interactive mode: allows the user to save & quit or quit w/o saving
100 | interactive = InteractiveOptions()
101 | tensorboard = TensorboardData(FLAGS.boarddata)
102 |
103 | val_batches = Batcher(captions_file=Path.val_caption_file, visual_file=Path.val_visual_embeddings_file,
104 | batch_size=FLAGS.batch_size,
105 | word2id = batches.get_word2id(), id2word = batches.get_id2word())
106 |
107 | tensorboard.open(Path.summaries_dir, run_name, session.graph)
108 |
109 | best_val, val_lv = None, None
110 | tr_losses = model_losses('tr')
111 | val_losses = model_losses('val')
112 | last_improve = 0
113 | for step in range(1, FLAGS.num_steps):
114 | img_labels, caps_pos, wide_in, wide_out, visual_embeddings = batches.next()
115 | optimizer = get_optimizer(op_loss=full_optimizer, op_vis=visual_optimizer, op_tex=caption_optimizer,
116 | stochastic_loss=FLAGS.stochastic_loss,
117 | visual_prob=FLAGS.v_loss_prob)
118 | tr_feed_dict = {caption_input:wide_in, caption_output:wide_out, visual_embedding_output:visual_embeddings}
119 | _, tr_l, tr_ll, tr_lv = session.run([optimizer, loss, caption_loss, visual_loss], feed_dict=tr_feed_dict)
120 | tr_losses.accum(tr_l, tr_ll, tr_lv, len(img_labels))
121 | tensorboard.add_train_summary(summaries.eval(feed_dict=tr_feed_dict), step)
122 |
123 | if step % FLAGS.summary_frequency == 0:
124 | print('[Step %d] %s' % (step, tr_losses))
125 | tr_losses.clear()
126 |
127 | if step % FLAGS.validation_frequency == 0:
128 | val_losses.clear()
129 | samples_proc = 0
130 | val_epoch = val_batches.epoch
131 | while True:
132 | img_labels, caps_pos, wide_in, wide_out, visual_embeddings = val_batches.next()
133 | val_feed_dict = {caption_input:wide_in, caption_output:wide_out, visual_embedding_output:visual_embeddings}
134 | val_l, val_ll, val_lv = session.run([loss, caption_loss, visual_loss], feed_dict=val_feed_dict)
135 | val_losses.accum(val_l, val_ll, val_lv, len(img_labels))
136 | samples_proc += len(img_labels)
137 | if val_batches.epoch != val_epoch: break
138 | print('\t[Step %d] %s' % (step, val_losses))
139 | _, _, valid_loss = val_losses.mean_flush()
140 | tensorboard.add_valid_summary(summaries.eval(feed_dict=val_feed_dict), step)
141 |
142 | if best_val is None or valid_loss < best_val:
143 | best_val = valid_loss
144 | last_improve = step
145 | if step > 2500:
146 | savemodel(session, step, saver, Path.checkpoint_dir, run_name)
147 |
148 | if interactive.status != interactive.Continue:
149 | if interactive.status in [interactive.SaveQuit, interactive.Save, interactive.SaveGoTest]:
150 | savemodel(session, step, saver, Path.checkpoint_dir, run_name, posfix='_usersaved')
151 | if interactive.status in [interactive.SaveQuit, interactive.Quit]:
152 | do_test = False
153 | break
154 | if interactive.status in [interactive.GoTest, interactive.SaveGoTest]:
155 | do_test = True
156 | break
157 | interactive.status = interactive.Continue
158 |
159 | # early-stop condition
160 | if step - last_improve >= 20000:
161 | print ("Early stop at step %d" % step)
162 | break
163 |
164 | tensorboard.close()
165 |
166 | # Test the net
167 | if do_test:
168 | restore_checkpoint(saver, session, Path.checkpoint_dir, checkpoint_path=FLAGS.checkpoint)
169 |
170 | print('Starting evaluation...')
171 | test_batches = Batcher(captions_file=Path.test_caption_file, visual_file=Path.test_visual_embeddings_file,
172 | batch_size=FLAGS.batch_size,
173 | word2id = batches.get_word2id(), id2word = batches.get_id2word())
174 |
175 | print('Getting predictions...')
176 | test_img_ids, test_cap_id = [], []
177 | samples_processed = 0
178 | batch_processed = 0
179 | test_losses = model_losses('test')
180 | predictions = []
181 | while test_batches.epoch == 0:
182 | img_ids, cap_ids, wide_in, wide_out, visual_embeddings = test_batches.next()
183 | test_feed_dict = {caption_input: wide_in, caption_output: wide_out, visual_embedding_output: visual_embeddings}
184 | batch_predictions, te_l, te_ll, te_lv = session.run([visual_prediction, loss, caption_loss, visual_loss], feed_dict=test_feed_dict)
185 | test_losses.accum(te_l, te_ll, te_lv, len(img_ids))
186 | predictions.append(batch_predictions)
187 | test_img_ids += img_ids
188 | test_cap_id += cap_ids
189 | batch_processed += 1
190 | samples_processed += len(img_ids)
191 | if batch_processed % 100 == 0:
192 | print('Processed %d examples' % (samples_processed))
193 | predictions = np.concatenate(predictions, axis=0)
194 |
195 | predictions_file = Path.predictions_dir + '/' + run_name + '.txt'
196 | visual_ids, visual_vectors = test_batches._visual.get_all_vectors()
197 |
198 | ref_test_captions = test_batches._captions
199 | if FLAGS.ngrams:
200 | # the reference captions must be taken without n-grams; otherwise the ROUGE score is computed with a different bias
201 | uni_captions_file = Path.test_caption_file.replace('.ngrams','')
202 | ref_test_captions = MSCocoCaptions(uni_captions_file)
203 |
204 | evaluation(ref_test_captions, visual_ids, visual_vectors, predictions, test_img_ids, test_cap_id, predictions_file,
205 | method=FLAGS.retrieval,
206 | mean_file=Path.mean_file, eigen_file=Path.eigen_file,
207 | test_loss=str(test_losses), save_predictions=FLAGS.save_predictions)
208 |
209 | #-------------------------------------
210 | if __name__ == '__main__':
211 | flags = tf.app.flags
212 | FLAGS = flags.FLAGS
213 |
214 | flags = define_commom_flags(flags, num_steps=100001, summary_frequency=100)
215 |
216 | flags.DEFINE_boolean('ngrams', True, 'If True, uses the (previously extracted) n-grams file (default True).')
217 | flags.DEFINE_boolean('stochastic_loss', True, 'Activates the stochastic loss heuristic (default True).')
218 | flags.DEFINE_float('v_loss_prob', 0.5, 'Visual loss probability (default=0.5).')
219 | flags.DEFINE_integer('hidden_size', 1024, 'Hidden size (default 1024).')
220 |
221 | tf.app.run()
222 |
223 |
224 |
225 |
--------------------------------------------------------------------------------
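
get_optimizer is defined in helpers.py (not shown in this file); given its arguments, the stochastic-loss heuristic presumably runs the joint optimizer by default and, when active, samples the visual objective with probability v_loss_prob and the caption (autoencoding) objective otherwise. A sketch under that assumption, not the actual helpers.py implementation:

    import random

    def get_optimizer(op_loss, op_vis, op_tex, stochastic_loss, visual_prob):
        # Joint loss unless the stochastic heuristic is active; otherwise
        # pick one of the two partial objectives at random for this step.
        if not stochastic_loss:
            return op_loss
        return op_vis if random.random() < visual_prob else op_tex
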
/src/visual_embeddings_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from io import open
4 |
5 | class VisualEmbeddingsReader:
6 | def __init__(self, visualembeddingsFile, visual_dim=4096):
7 | self.visual_dim = visual_dim
8 | if os.path.exists(visualembeddingsFile+'.npy'):
9 | print('Loading binary file from %s ...' % (visualembeddingsFile+'.npy'))
10 | self.visual_embeddings = np.load(visualembeddingsFile+'.npy').item()
11 | return
12 | print('Reading images from %s ...' % visualembeddingsFile)
13 | with open(visualembeddingsFile, "r", buffering=100000000, encoding="utf-8", errors='ignore') as fv:
14 | self.visual_embeddings = dict()
15 | for line in fv:
16 | fields = line.split("\t")
17 | imgID = int(fields[0])
18 | embedding = np.zeros(shape=self.visual_dim, dtype=np.float32)
19 | dim_val = [x.split(':') for x in fields[1].split()]
20 | for dim, val in dim_val: embedding[int(dim)] = float(val)
21 | self.visual_embeddings[imgID] = embedding
22 | if len(self.visual_embeddings) % 1000 == 0: print('\t%d images read' % len(self.visual_embeddings))
23 | print('Saving binary file %s.npy ...' % visualembeddingsFile)
24 | np.save(visualembeddingsFile, self.visual_embeddings)
25 |
26 | def get(self, img_label): return self.visual_embeddings[img_label]
27 |
28 | def get_all_vectors(self):
29 | unzip = zip(*self.visual_embeddings.items())
30 | img_ids = np.asarray(unzip[0])
31 | vectors = np.asarray(unzip[1])
32 | return img_ids, vectors
33 |
34 |
--------------------------------------------------------------------------------
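
get_all_vectors relies on zip() returning an indexable list, which holds under the Python 2 interpreter this code targets. Under Python 3, where zip() is lazy, the method could be written as follows (a sketch, not part of the original file):

    def get_all_vectors(self):
        # Materialize ids and vectors explicitly; the keys() and values()
        # of the same dict iterate in matching order.
        img_ids = np.asarray(list(self.visual_embeddings.keys()))
        vectors = np.asarray(list(self.visual_embeddings.values()))
        return img_ids, vectors
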
/src/word2visualvec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from batcher_word2visualvec import Batcher
4 | from evaluation_measure import evaluation
5 | from helpers import *
6 | import gc
7 | from paths import PATH
8 | from flags import define_commom_flags
9 | import gensim, logging
10 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
11 |
12 |
13 | def load_wordembeddings(sentences_file, we_dim):
14 | saved_model = sentences_file + ".model"
15 | if not os.path.exists(saved_model):
16 | print('Pre-trained model for %s not found. Generating word embeddings...' % sentences_file)
17 | sentences = gensim.models.word2vec.LineSentence(sentences_file)
18 | model = gensim.models.Word2Vec(sentences, size=we_dim, sg=1, workers=8) # default: window=5, min_count=5, iter=5
19 | model.save(saved_model)
20 | print('Done! file saved in %s' % saved_model)
21 | return gensim.models.Word2Vec.load(saved_model)
22 |
23 | class PATH_WE(PATH):
24 | def __init__(self, fclayer, debug_mode=False, lemma=False):
25 | super(PATH_WE, self).__init__(fclayer, debug_mode)
26 | self.wordembeddings_file = '../wordembeddings/YFCC100Muser_tags'+('_lemma' if lemma else '')+'.bz2'
27 | notexist_exit(self.wordembeddings_file)
28 |
29 | def main(argv=None):
30 |
31 | err_exit(argv[1:], "Error in parameters %s (--help for documentation)." % argv[1:])
32 | Path = PATH_WE(FLAGS.fc_layer, debug_mode=FLAGS.debug, lemma=FLAGS.lemma)
33 | sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0) # set stdout to unbuffered
34 |
35 | #---------------------------------------------------------------------
36 | # Inits
37 | #---------------------------------------------------------------------
38 | do_training = FLAGS.train
39 | do_test = FLAGS.test
40 | run_name = 'word2visualvec'+FLAGS.run
41 |
42 | we_dim = FLAGS.embedding_size
43 | wordembeddings = load_wordembeddings(Path.wordembeddings_file, we_dim)
44 |
45 | batches = Batcher(captions_file=Path.tr_captions_file,
46 | visual_file=Path.tr_visual_embeddings_file,
47 | we_dim=we_dim,
48 | batch_size = FLAGS.batch_size,
49 | lemmatize=FLAGS.lemma,
50 | model=wordembeddings)
51 |
52 | hidden_sizes = [1000, 2000, 3000] if FLAGS.large else [1000, 2000]
53 |
54 | graph = tf.Graph()
55 | with graph.as_default():
56 | # Placeholders
57 | output = tf.placeholder(tf.float32, shape=[None, FLAGS.visual_dim], name='visual-embedding')
58 | input = tf.placeholder(tf.float32, shape=[None, we_dim], name='pooled_word_embeddings')
59 | keep_p = tf.placeholder(tf.float32)
60 |
61 | def add_layer(layer, hidden_size, drop, name):
62 | weight, bias = projection_weights(layer.get_shape().as_list()[1], hidden_size, name=name)
63 | activation = tf.nn.relu(tf.matmul(layer,weight)+bias)
64 | if drop:
65 | return tf.nn.dropout(activation, keep_prob=keep_p)
66 | else:
67 | return activation
68 |
69 | current_layer = input
70 | for i,hidden_dim in enumerate(hidden_sizes):
71 | current_layer = add_layer(current_layer, hidden_dim, drop=True, name='layer'+str(i))
72 | out_layer = add_layer(current_layer, FLAGS.visual_dim, drop=False, name='output_layer')
73 |
74 |
75 | # losses
76 | loss = tf.reduce_mean(tf.square(out_layer - output))
77 |
78 | # Optimizer.
79 | optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9, epsilon=1e-6).minimize(loss)
80 |
81 | # Add ops to save and restore all the variables.
82 | saver = tf.train.Saver(max_to_keep=1) #defaults are: max_to_keep=5, keep_checkpoint_every_n_hours=10000.0
83 |
84 | # Tensorboard data
85 | summaries = tf.merge_summary([tf.scalar_summary('loss/loss', loss)]) if FLAGS.boarddata else None
86 |
87 |
88 | with tf.Session(graph=graph) as session:
89 | tf.initialize_all_variables().run()
90 |
91 | # Train the net
92 | if do_training:
93 | # interactive mode: allows the user to save & quit or quit w/o saving
94 | interactive = InteractiveOptions()
95 | tensorboard = TensorboardData(FLAGS.boarddata)
96 |
97 | val_batches = Batcher(captions_file=Path.val_caption_file, visual_file=Path.val_visual_embeddings_file,
98 | we_dim=we_dim,
99 | batch_size=FLAGS.batch_size,
100 | lemmatize=FLAGS.lemma,
101 | model=wordembeddings)
102 |
103 | tensorboard.open(Path.summaries_dir, run_name, session.graph)
104 |
105 | train_loss = 0.0
106 | valid_loss = 0.0
107 | best_val = None
108 | last_improve = 0
109 | for step in range(1, FLAGS.num_steps):
110 | _, _, embedding_pool, vis_embeddings = batches.next()
111 | feed_dict = {input: embedding_pool, output: vis_embeddings, keep_p: FLAGS.drop_keep_p}
112 | _, tr_l = session.run([optimizer, loss], feed_dict=feed_dict)
113 | train_loss += tr_l
114 | tensorboard.add_train_summary(summaries.eval(feed_dict=feed_dict), step)
115 |
116 | if step % FLAGS.summary_frequency == 0:
117 | print('[epoch=%d][Step %d] %0.5f' % (batches.epoch, step, (train_loss / FLAGS.summary_frequency)))
118 | train_loss = 0.0
119 |
120 | if step % FLAGS.validation_frequency == 0:
121 | valid_loss = 0.0
122 | samples_proc = 0
123 | val_epoch = val_batches.epoch
124 | while True:
125 | _, _, pooled_embeddings, visual_embeddings = val_batches.next()
126 | samples = len(pooled_embeddings)
127 | feed_dict = {input: pooled_embeddings, output: visual_embeddings, keep_p: 1.0}
128 | val_l = loss.eval(feed_dict=feed_dict)
129 | valid_loss += val_l*samples
130 | samples_proc += samples
131 | if val_batches.epoch != val_epoch: break
132 | valid_loss /= samples_proc
133 | print('\t[epoch=%d][Step %d] %0.5f' % (val_batches.epoch, step, valid_loss))
134 | tensorboard.add_valid_summary(summaries.eval(feed_dict=feed_dict), step)
135 |
136 | if best_val is None or valid_loss < best_val:
137 | best_val = valid_loss
138 | last_improve = step
139 | if step > 2500:
140 | savemodel(session, step, saver, Path.checkpoint_dir, run_name)
141 |
142 | if interactive.status != interactive.Continue:
143 | if interactive.status in [interactive.SaveQuit, interactive.Save, interactive.SaveGoTest]:
144 | savemodel(session, step, saver, Path.checkpoint_dir, run_name, posfix='_usersaved')
145 | if interactive.status in [interactive.SaveQuit, interactive.Quit]:
146 | do_test = False
147 | break
148 | if interactive.status in [interactive.GoTest, interactive.SaveGoTest]:
149 | do_test = True
150 | break
151 | interactive.status = interactive.Continue
152 |
153 | # early-stop condition
154 | if step - last_improve >= 30000:
155 | print ("Early stop at step %d" % step)
156 | break
157 |
158 | tensorboard.close()
159 |
160 | # Test the net
161 | if do_test:
162 | restore_checkpoint(saver, session, Path.checkpoint_dir, checkpoint_path=FLAGS.checkpoint)
163 |
164 | print('Starting evaluation...')
165 | test_batches = Batcher(captions_file=Path.test_caption_file, visual_file=Path.test_visual_embeddings_file,
166 | we_dim=we_dim,
167 | batch_size=FLAGS.batch_size,
168 | lemmatize=FLAGS.lemma,
169 | model=wordembeddings)
170 |
171 | print('Getting predictions...')
172 | test_img_ids, test_cap_id, predictions = [], [], []
173 | test_loss = 0.0
174 | tests_processed = 0
175 | batch_processed = 0
176 | while True:
177 | img_labels, caps_pos, pooled_embeddings, visual_embeddings = test_batches.next()
178 | samples = len(img_labels)
179 | feed_dict = {input: pooled_embeddings, output: visual_embeddings, keep_p: 1.0}
180 | batch_predictions, test_l = session.run([out_layer, loss], feed_dict=feed_dict)
181 | test_loss += (test_l*samples)
182 | predictions.append(batch_predictions)
183 | test_img_ids += img_labels
184 | test_cap_id += caps_pos
185 | tests_processed += samples
186 | batch_processed += 1
187 | if batch_processed % 100 == 0:
188 | print('Processed %d test examples' % (tests_processed))
189 | if test_batches.epoch != 0: break
190 | predictions = np.concatenate(predictions, axis=0)
191 | test_loss /= tests_processed
192 | wordembeddings = None
193 | gc.collect()
194 |
195 | predictions_file = Path.predictions_dir + '/' + run_name + '.txt'
196 | visual_ids, visual_vectors = test_batches._visual.get_all_vectors()
197 |
198 | evaluation(test_batches, visual_ids, visual_vectors, predictions, test_img_ids, test_cap_id, predictions_file,
199 | method=FLAGS.retrieval,
200 | mean_file=Path.mean_file, eigen_file=Path.eigen_file,
201 | test_loss=str(test_loss), save_predictions=FLAGS.save_predictions)
202 |
203 |
204 |
205 | #-------------------------------------
206 | if __name__ == '__main__':
207 | flags = tf.app.flags
208 | FLAGS = flags.FLAGS
209 |
210 | flags = define_commom_flags(flags, num_steps=500001, summary_frequency=100)
211 |
212 | # net settings
213 | flags.DEFINE_boolean('lemma', True, 'Determines whether to use the word embeddings trained after lemmatizing.')
214 | flags.DEFINE_boolean('large', True, 'If True, adds an additional hidden layer to the net (default True).')
215 | flags.DEFINE_integer('embedding_size', 500, 'Dimensionality of the word embeddings space (default 500).')
216 | flags.DEFINE_float('drop_keep_p', 0.85, 'Keep probability for dropout (default 0.85).')
217 |
218 |
219 |
220 | tf.app.run()
221 |
222 |
223 |
224 |
--------------------------------------------------------------------------------
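
The Batcher feeds the network one we_dim-sized vector per caption; following the Word2VisualVec approach, this is presumably the mean-pooled word2vec embedding of the caption's tokens. A sketch of that pooling under this assumption (the model indexing matches the pre-1.0 gensim API used in this script):

    import numpy as np

    def pool_caption(model, caption, we_dim):
        # Average the word2vec vectors of the in-vocabulary tokens; fall
        # back to a zero vector when no token is known to the model.
        vectors = [model[w] for w in caption.lower().split() if w in model]
        return np.mean(vectors, axis=0) if vectors else np.zeros(we_dim, dtype=np.float32)
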
/src/yfcc100m_extractor.py:
--------------------------------------------------------------------------------
1 | import bz2, sys
2 | import urllib
3 | from os import listdir
4 | from os.path import isfile, join
5 | from nltk.stem import WordNetLemmatizer
6 |
7 | def extract_user_tags(yfcc100m_dir, lemmatize=False):
8 | lemmatizer = WordNetLemmatizer() if lemmatize else None
9 | lemma_posfix = '_lemma' if lemmatize else ''
10 | print('Extracting user_tags from %s' % yfcc100m_dir)
11 | yfcc100m_parts = [f for f in listdir(yfcc100m_dir) if isfile(join(yfcc100m_dir, f))]
12 | yfcc100m_parts.sort()
13 |
14 | with bz2.BZ2File(yfcc100m_dir+'user_tags'+lemma_posfix+'.bz2', 'w') as fout:
15 | for part in yfcc100m_parts:
16 | print('\t%s' % join(yfcc100m_dir, part))
17 | for line in bz2.BZ2File(join(yfcc100m_dir, part), 'r'):
18 | user_tags = line.split('\t')[8].strip()
19 | if user_tags:
20 | user_tags = urllib.unquote_plus(user_tags).split(',')
21 | if lemmatize:
22 | lem_tags = []
23 | for tag in user_tags:
24 | try:
25 | lem_tags.append(lemmatizer.lemmatize(tag))
26 | except UnicodeDecodeError:
27 | pass # skip tags that cannot be decoded
28 | user_tags=lem_tags
29 | if user_tags:
30 | fout.write(' '.join(user_tags) + '\n')
31 |
32 | #extract_user_tags('../wordembeddings/YFCC100M', lemmatize=True)
33 | #extract_user_tags('../wordembeddings/YFCC100M', lemmatize=False)
34 |
--------------------------------------------------------------------------------
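
Field 8 of the YFCC100M metadata holds the user tags URL-encoded and comma-separated, which is why the script applies urllib.unquote_plus before splitting. For example (Python 2, matching this script; the tag string is illustrative):

    import urllib

    raw_tags = 'golden+retriever,dog%20park,beach'
    print(urllib.unquote_plus(raw_tags).split(','))
    # ['golden retriever', 'dog park', 'beach']
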
/visualembeddings/readme.txt:
--------------------------------------------------------------------------------
1 | Place your visual features for the MsCOCO images here.
2 | If you do not know how to extract them, please do not hesitate to contact us.
3 |
--------------------------------------------------------------------------------
/wordembeddings/readme.txt:
--------------------------------------------------------------------------------
1 | Place the word embeddings extracted from the YFCC100M dataset here.
2 |
--------------------------------------------------------------------------------