├── .gitignore
├── README.md
├── compute_results.py
├── data.py
├── data
│   └── vocab
│       ├── 10crop_precomp_vocab.pkl
│       ├── coco_precomp_vocab.pkl
│       ├── coco_vocab.pkl
│       ├── f30k_precomp_vocab.pkl
│       ├── f30k_vocab.pkl
│       ├── f8k_precomp_vocab.pkl
│       └── f8k_vocab.pkl
├── evaluation.py
├── model.py
├── requirements.txt
├── train.py
└── vocab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | runs/
2 | __pycache__/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VSE-HAL
2 | Code release for **HAL: Improved Text-Image Matching by Mitigating Visual Semantic Hubs** [\[arxiv\]](https://arxiv.org/pdf/1911.10097v1.pdf) at AAAI 2020.
3 |
4 | ```bibtex
5 | @inproceedings{liu2020hal,
6 | title={{HAL}: Improved text-image matching by mitigating visual semantic hubs},
7 | author={Liu, Fangyu and Ye, Rongtian and Wang, Xun and Li, Shuaipeng},
8 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
9 | volume={34},
10 | number={07},
11 | pages={11563--11571},
12 | year={2020}
13 | }
14 | ```
15 |
16 | Upgrade your text-image matching model with a few lines of code:
17 | ```python
18 | class ContrastiveLoss(nn.Module):
19 | ...
20 | def forward(self, im, s, ...):
21 | bsize = im.size()[0]
22 | scores = self.sim(im, s)
23 | ...
24 | tmp = torch.eye(bsize).cuda()
25 | s_diag = tmp * scores
26 | scores_ = scores - s_diag
27 | ...
28 | S_ = torch.exp(self.l_alpha * (scores_ - self.l_ep))
29 | loss_diag = - torch.log(1 + F.relu(s_diag.sum(0)))
30 |
31 | loss = torch.sum( \
32 | torch.log(1 + S_.sum(0)) / self.l_alpha \
33 | + torch.log(1 + S_.sum(1)) / self.l_alpha \
34 | + loss_diag \
35 | ) / bsize
36 |
37 | return loss
38 | ```
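
For reference, here is a minimal, self-contained sketch of the local (no global weighting) variant of the loss above on dummy embeddings; `hal_local_loss` and the toy tensors are illustrative only — the full implementation, including the memory-bank based global weighting, lives in `model.py`:
```python
import torch
import torch.nn.functional as F

def hal_local_loss(im, s, alpha=30.0, ep=0.3):
    """Local HAL loss on a batch of L2-normalized image/caption embeddings."""
    bsize = im.size(0)
    scores = im.mm(s.t())                     # cosine similarity matrix (bsize x bsize)
    eye = torch.eye(bsize, device=im.device)
    s_diag = eye * scores                     # matched (positive) pairs on the diagonal
    scores_ = scores - s_diag                 # zero the diagonal; off-diagonal entries are negatives
    S_ = torch.exp(alpha * (scores_ - ep))
    loss_diag = -torch.log(1 + F.relu(s_diag.sum(0)))
    return torch.sum(
        torch.log(1 + S_.sum(0)) / alpha
        + torch.log(1 + S_.sum(1)) / alpha
        + loss_diag
    ) / bsize

# toy usage on random embeddings
im = F.normalize(torch.randn(8, 1024), dim=1)
s = F.normalize(torch.randn(8, 1024), dim=1)
print(hal_local_loss(im, s).item())
```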
39 |
40 |
41 | ## Dependencies
42 | ```
43 | nltk==3.4.5
44 | pycocotools==2.0.0
45 | numpy==1.18.1
46 | torch==1.5.1
47 | torchvision==0.6.0
48 | tensorboard_logger==0.1.0
49 | ```
50 |
51 | ## Data
52 | #### MS-COCO
53 | [\[vgg_precomp\]](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip)
54 | [\[resnet_precomp\]](https://drive.google.com/uc?id=1vtUijEbXpVzNt6HjC6ph8ZzMHRRNms5j&export=download)
55 |
56 | #### Flickr30k
57 | [\[vgg_precomp\]](http://www.cs.toronto.edu/~faghri/vsepp/data.tar)
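
The commands below assume the downloads above are unpacked into the layout sketched here (illustrative; `PrecompDataset` in `data.py` reads `{split}_caps.txt` and `{split}_ims.npy` from `<data_path>/<data_name>/`, and the vocabulary pickles are looked up under `--vocab_path`):
```
data/
├── vocab/                      # *_vocab.pkl files shipped with this repo
└── data/
    ├── resnet_precomp/
    │   └── coco_precomp/       # train/dev/test *_caps.txt and *_ims.npy
    └── f30k_precomp/           # train/dev/test *_caps.txt and *_ims.npy
```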
58 |
59 | ## Train
60 |
61 | Run `train.py`.
62 |
63 | #### MS-COCO
64 |
65 | ##### w/o global weighting
66 |
67 | ```bash
68 | python3 train.py \
69 | --data_path "data/data/resnet_precomp" \
70 | --vocab_path "data/vocab/" \
71 | --data_name coco_precomp \
72 | --batch_size 512 \
73 | --learning_rate 0.001 \
74 | --lr_update 8 \
75 | --num_epochs 13 \
76 | --img_dim 2048 \
77 | --logger_name runs/COCO \
78 | --local_alpha 30.00 \
79 | --local_ep 0.3
80 | ```
81 |
82 | ##### with global weighting
83 |
84 | ```bash
85 | python3 train.py \
86 | --data_path "data/data/resnet_precomp" \
87 | --vocab_path "data/vocab/" \
88 | --data_name coco_precomp \
89 | --batch_size 512 \
90 | --learning_rate 0.001 \
91 | --lr_update 8 \
92 | --num_epochs 13 \
93 | --img_dim 2048 \
94 | --logger_name runs/COCO_mb \
95 | --local_alpha 30.00 \
96 | --local_ep 0.3 \
97 | --memory_bank \
98 | --global_alpha 40.00 \
99 | --global_beta 40.00 \
100 | --global_ep_posi 0.20 \
101 | --global_ep_nega 0.10 \
102 | --mb_rate 0.05 \
103 | --mb_k 250
104 | ```
105 |
106 | #### Flickr30k
107 |
108 | ```bash
109 | python3 train.py \
110 | --data_path "data/data" \
111 | --vocab_path "data/vocab/" \
112 | --data_name f30k_precomp \
113 | --batch_size 128 \
114 | --learning_rate 0.001 \
115 | --lr_update 8 \
116 | --num_epochs 13 \
117 | --logger_name runs/f30k \
118 | --local_alpha 60.00 \
119 | --local_ep 0.7
120 | ```
121 |
122 | ## Evaluate
123 |
124 | Run `compute_results.py`.
125 |
126 | #### COCO
127 |
128 | ```bash
129 | python3 compute_results.py --data_path data/data/resnet_precomp --fold5 --model_path runs/COCO/model_best.pth.tar
130 | ```
131 |
132 | #### Flickr30k
133 |
134 | ```bash
135 | python3 compute_results.py --data_path data/data --model_path runs/f30k/model_best.pth.tar
136 | ```
137 | #### Trained models
138 | [\[Google Drive\]](https://drive.google.com/drive/folders/1H_EVBFxpYKObNo_CjV0pTaB24A1jWsSF)
139 |
140 | ## Note
141 | Trained models and code for reproducing results on [SCAN](https://github.com/kuanghuei/SCAN) are coming soon.
142 |
143 | ## Acknowledgments
144 | This project would be impossible without the open source implementations of [VSE++](https://github.com/fartashf/vsepp) and [SCAN](https://github.com/kuanghuei/SCAN).
145 |
146 | ## License
147 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0)
148 |
--------------------------------------------------------------------------------
/compute_results.py:
--------------------------------------------------------------------------------
1 | from vocab import Vocabulary
2 | import evaluation
3 |
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--model_path', default='$RUN_PATH/coco_vse/model_best.pth.tar', help='path to model')
8 | parser.add_argument('--data_path', default='data/data', help='path to datasets')
9 | parser.add_argument('--fold5', action='store_true',
10 |                     help='Use 5-fold cross-validation (MS-COCO 5K test set)')
11 | parser.add_argument('--save_embeddings', action='store_true',
12 |                     help='Save the computed image and caption embeddings to disk')
13 | parser.add_argument('--save_csv', default='', help='If set, append results to this CSV file')
14 |
15 | opt_eval = parser.parse_args()
16 |
17 | evaluation.evalrank(opt_eval, split='test')
18 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision.transforms as transforms
4 | import os
5 | import nltk
6 | from PIL import Image
7 | from pycocotools.coco import COCO
8 | import numpy as np
9 | import json as jsonmod
10 |
11 |
12 | def get_paths(path, name='coco', use_restval=False):
13 | """
14 | Returns paths to images and annotations for the given datasets. For MSCOCO
15 | indices are also returned to control the data split being used.
16 | The indices are extracted from the Karpathy et al. splits using this
17 | snippet:
18 |
19 | >>> import json
20 | >>> D = json.load(open('dataset_coco.json','r'))
21 | >>> A=[]
22 | >>> for i in range(len(D['images'])):
23 | ... if D['images'][i]['split'] == 'val':
24 | ... A+=D['images'][i]['sentids'][:5]
25 | ...
26 |
27 | :param name: Dataset names
28 | :param use_restval: If True, the `restval` data is included in train.
29 | """
30 | roots = {}
31 | ids = {}
32 | if 'coco' == name:
33 | imgdir = os.path.join(path, 'images')
34 | capdir = os.path.join(path, 'annotations')
35 | roots['train'] = {
36 | 'img': os.path.join(imgdir, 'train2014'),
37 | 'cap': os.path.join(capdir, 'captions_train2014.json')
38 | }
39 | roots['val'] = {
40 | 'img': os.path.join(imgdir, 'val2014'),
41 | 'cap': os.path.join(capdir, 'captions_val2014.json')
42 | }
43 | roots['test'] = {
44 | 'img': os.path.join(imgdir, 'val2014'),
45 | 'cap': os.path.join(capdir, 'captions_val2014.json')
46 | }
47 | roots['trainrestval'] = {
48 | 'img': (roots['train']['img'], roots['val']['img']),
49 | 'cap': (roots['train']['cap'], roots['val']['cap'])
50 | }
51 | ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy'))
52 | ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000]
53 | ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy'))
54 | ids['trainrestval'] = (
55 | ids['train'],
56 | np.load(os.path.join(capdir, 'coco_restval_ids.npy')))
57 | if use_restval:
58 | roots['train'] = roots['trainrestval']
59 | ids['train'] = ids['trainrestval']
60 | elif 'f8k' == name:
61 | imgdir = os.path.join(path, 'images')
62 | cap = os.path.join(path, 'dataset_flickr8k.json')
63 | roots['train'] = {'img': imgdir, 'cap': cap}
64 | roots['val'] = {'img': imgdir, 'cap': cap}
65 | roots['test'] = {'img': imgdir, 'cap': cap}
66 | ids = {'train': None, 'val': None, 'test': None}
67 | elif 'f30k' == name:
68 | imgdir = os.path.join(path, 'images')
69 | cap = os.path.join(path, 'dataset_flickr30k.json')
70 | roots['train'] = {'img': imgdir, 'cap': cap}
71 | roots['val'] = {'img': imgdir, 'cap': cap}
72 | roots['test'] = {'img': imgdir, 'cap': cap}
73 | ids = {'train': None, 'val': None, 'test': None}
74 |
75 | return roots, ids
76 |
77 |
78 | class CocoDataset(data.Dataset):
79 | """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
80 |
81 | def __init__(self, root, json, vocab, transform=None, ids=None):
82 | """
83 | Args:
84 | root: image directory.
85 | json: coco annotation file path.
86 | vocab: vocabulary wrapper.
87 | transform: transformer for image.
88 | """
89 | self.root = root
90 | # when using `restval`, two json files are needed
91 | if isinstance(json, tuple):
92 | self.coco = (COCO(json[0]), COCO(json[1]))
93 | else:
94 | self.coco = (COCO(json),)
95 | self.root = (root,)
96 | # if ids provided by get_paths, use split-specific ids
97 | if ids is None:
98 |             self.ids = list(self.coco[0].anns.keys())
99 | else:
100 | self.ids = ids
101 |
102 | # if `restval` data is to be used, record the break point for ids
103 | if isinstance(self.ids, tuple):
104 | self.bp = len(self.ids[0])
105 | self.ids = list(self.ids[0]) + list(self.ids[1])
106 | else:
107 | self.bp = len(self.ids)
108 | self.vocab = vocab
109 | self.transform = transform
110 |
111 | def __getitem__(self, index):
112 | """This function returns a tuple that is further passed to collate_fn
113 | """
114 | vocab = self.vocab
115 | root, caption, img_id, path, image = self.get_raw_item(index)
116 |
117 | if self.transform is not None:
118 | image = self.transform(image)
119 |
120 | # Convert caption (string) to word ids.
121 |         tokens = nltk.tokenize.word_tokenize(
122 |             str(caption).lower())
123 |         caption = []
124 |         caption.append(vocab('<start>'))
125 |         caption.extend([vocab(token) for token in tokens])
126 |         caption.append(vocab('<end>'))
127 | target = torch.Tensor(caption)
128 | return image, target, index, img_id, index
129 |
130 | def get_raw_item(self, index):
131 | if index < self.bp:
132 | coco = self.coco[0]
133 | root = self.root[0]
134 | else:
135 | coco = self.coco[1]
136 | root = self.root[1]
137 | ann_id = self.ids[index]
138 | caption = coco.anns[ann_id]['caption']
139 | img_id = coco.anns[ann_id]['image_id']
140 | path = coco.loadImgs(img_id)[0]['file_name']
141 | image = Image.open(os.path.join(root, path)).convert('RGB')
142 |
143 | return root, caption, img_id, path, image
144 |
145 | def __len__(self):
146 | return len(self.ids)
147 |
148 |
149 | class FlickrDataset(data.Dataset):
150 | """
151 | Dataset loader for Flickr30k and Flickr8k full datasets.
152 | """
153 |
154 | def __init__(self, root, json, split, vocab, transform=None):
155 | self.root = root
156 | self.vocab = vocab
157 | self.split = split
158 | self.transform = transform
159 | self.dataset = jsonmod.load(open(json, 'r'))['images']
160 | self.ids = []
161 | for i, d in enumerate(self.dataset):
162 | if d['split'] == split:
163 | self.ids += [(i, x) for x in range(len(d['sentences']))]
164 |
165 | def __getitem__(self, index):
166 | """This function returns a tuple that is further passed to collate_fn
167 | """
168 | vocab = self.vocab
169 | root = self.root
170 | ann_id = self.ids[index]
171 | img_id = ann_id[0]
172 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw']
173 | path = self.dataset[img_id]['filename']
174 |
175 | image = Image.open(os.path.join(root, path)).convert('RGB')
176 | if self.transform is not None:
177 | image = self.transform(image)
178 |
179 | # Convert caption (string) to word ids.
180 | tokens = nltk.tokenize.word_tokenize(
181 | str(caption).lower())
182 | caption = []
183 |         caption.append(vocab('<start>'))
184 |         caption.extend([vocab(token) for token in tokens])
185 |         caption.append(vocab('<end>'))
186 | target = torch.Tensor(caption)
187 | return image, target, index, img_id, index
188 |
189 | def __len__(self):
190 | return len(self.ids)
191 |
192 |
193 | class PrecompDataset(data.Dataset):
194 | """
195 | Load precomputed captions and image features
196 | Possible options: f8k, f30k, coco, 10crop
197 | """
198 |
199 | def __init__(self, data_path, data_split, vocab):
200 | self.vocab = vocab
201 | loc = data_path + '/'
202 |
203 | # Captions
204 | self.captions = []
205 | caps_name = loc + '%s_caps.txt' % data_split
206 | with open(caps_name, 'rb') as f:
207 | for line in f:
208 | self.captions.append(line.strip())
209 |
210 | # Image features
211 | self.images = np.load(loc+'%s_ims.npy' % data_split)
212 | self.length = len(self.captions)
213 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't
214 | if self.images.shape[0] != self.length:
215 | self.im_div = 5
216 | else:
217 | self.im_div = 1
218 | # the development set for coco is large and so validation would be slow
219 | if data_split == 'dev':
220 | self.length = 5000
221 |
222 | def __getitem__(self, index):
223 | # handle the image redundancy
224 |         img_id = index // self.im_div
225 | image = torch.Tensor(self.images[int(img_id)])
226 | caption = self.captions[index]
227 | vocab = self.vocab
228 |
229 | # Convert caption (string) to word ids.
230 | tokens = nltk.tokenize.word_tokenize(str(caption).lower())
231 | caption = []
232 |         caption.append(vocab('<start>'))
233 |         caption.extend([vocab(token) for token in tokens])
234 |         caption.append(vocab('<end>'))
235 | target = torch.Tensor(caption)
236 | return image, target, index, img_id, index
237 |
238 | def __len__(self):
239 | return self.length
240 |
241 |
242 | def collate_fn(data):
243 | """Build mini-batch tensors from a list of (image, caption) tuples.
244 | Args:
245 | data: list of (image, caption) tuple.
246 | - image: torch tensor of shape (3, 256, 256).
247 | - caption: torch tensor of shape (?); variable length.
248 |
249 | Returns:
250 | images: torch tensor of shape (batch_size, 3, 256, 256).
251 | targets: torch tensor of shape (batch_size, padded_length).
252 | lengths: list; valid length for each padded caption.
253 | """
254 | # Sort a data list by caption length
255 | data.sort(key=lambda x: len(x[1]), reverse=True)
256 | images, captions, ids, img_ids, indices = zip(*data)
257 |
258 | # Merge images (convert tuple of 3D tensor to 4D tensor)
259 | images = torch.stack(images, 0)
260 |
261 |     # Merge captions (convert tuple of 1D tensor to 2D tensor)
262 | lengths = [len(cap) for cap in captions]
263 | targets = torch.zeros(len(captions), max(lengths)).long()
264 | for i, cap in enumerate(captions):
265 | end = lengths[i]
266 | targets[i, :end] = cap[:end]
267 |
268 | return images, targets, lengths, ids, indices
269 |
270 |
271 | def get_loader_single(data_name, split, root, json, vocab, transform,
272 | batch_size=100, shuffle=True,
273 | num_workers=2, ids=None, collate_fn=collate_fn):
274 | """Returns torch.utils.data.DataLoader for custom coco dataset."""
275 | if 'coco' in data_name:
276 | # COCO custom dataset
277 | dataset = CocoDataset(root=root,
278 | json=json,
279 | vocab=vocab,
280 | transform=transform, ids=ids)
281 | elif 'f8k' in data_name or 'f30k' in data_name:
282 | dataset = FlickrDataset(root=root,
283 | split=split,
284 | json=json,
285 | vocab=vocab,
286 | transform=transform)
287 |
288 | # Data loader
289 | data_loader = torch.utils.data.DataLoader(dataset=dataset,
290 | batch_size=batch_size,
291 | shuffle=shuffle,
292 | pin_memory=True,
293 | num_workers=num_workers,
294 | collate_fn=collate_fn)
295 | return data_loader
296 |
297 |
298 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100,
299 | shuffle=True, num_workers=2):
300 | """Returns torch.utils.data.DataLoader for custom coco dataset."""
301 | dset = PrecompDataset(data_path, data_split, vocab)
302 |
303 | data_loader = torch.utils.data.DataLoader(dataset=dset,
304 | batch_size=batch_size,
305 | shuffle=shuffle,
306 | pin_memory=True,
307 |                                               num_workers=num_workers, collate_fn=collate_fn)
308 | return data_loader
309 |
310 |
311 | def get_transform(data_name, split_name, opt):
312 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
313 | std=[0.229, 0.224, 0.225])
314 | t_list = []
315 | if split_name == 'train':
316 | t_list = [transforms.RandomResizedCrop(opt.crop_size),
317 | transforms.RandomHorizontalFlip()]
318 | elif split_name == 'val':
319 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
320 | elif split_name == 'test':
321 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)]
322 |
323 | t_end = [transforms.ToTensor(), normalizer]
324 | transform = transforms.Compose(t_list + t_end)
325 | return transform
326 |
327 |
328 | def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt):
329 | dpath = os.path.join(opt.data_path, data_name)
330 | if opt.data_name.endswith('_precomp'):
331 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt,
332 | batch_size, True, workers)
333 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt,
334 | batch_size, False, workers)
335 | else:
336 | # Build Dataset Loader
337 | roots, ids = get_paths(dpath, data_name, opt.use_restval)
338 |
339 | transform = get_transform(data_name, 'train', opt)
340 | train_loader = get_loader_single(opt.data_name, 'train',
341 | roots['train']['img'],
342 | roots['train']['cap'],
343 | vocab, transform, ids=ids['train'],
344 | batch_size=batch_size, shuffle=True,
345 | num_workers=workers,
346 | collate_fn=collate_fn)
347 |
348 | transform = get_transform(data_name, 'val', opt)
349 | val_loader = get_loader_single(opt.data_name, 'val',
350 | roots['val']['img'],
351 | roots['val']['cap'],
352 | vocab, transform, ids=ids['val'],
353 | batch_size=batch_size, shuffle=False,
354 | num_workers=workers,
355 | collate_fn=collate_fn)
356 |
357 | return train_loader, val_loader
358 |
359 |
360 | def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
361 | workers, opt):
362 | dpath = os.path.join(opt.data_path, data_name)
363 | if opt.data_name.endswith('_precomp'):
364 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt,
365 | batch_size, False, workers)
366 | else:
367 | # Build Dataset Loader
368 | roots, ids = get_paths(dpath, data_name, opt.use_restval)
369 |
370 | transform = get_transform(data_name, split_name, opt)
371 | test_loader = get_loader_single(opt.data_name, split_name,
372 | roots[split_name]['img'],
373 | roots[split_name]['cap'],
374 | vocab, transform, ids=ids[split_name],
375 | batch_size=batch_size, shuffle=False,
376 | num_workers=workers,
377 | collate_fn=collate_fn)
378 |
379 | return test_loader
380 |
--------------------------------------------------------------------------------
/data/vocab/10crop_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/10crop_precomp_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/coco_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/coco_precomp_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/coco_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/coco_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/f30k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f30k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/f30k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f30k_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/f8k_precomp_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f8k_precomp_vocab.pkl
--------------------------------------------------------------------------------
/data/vocab/f8k_vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f8k_vocab.pkl
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import pickle
4 |
5 | import numpy
6 | from data import get_test_loader
7 | import time
8 | import numpy as np
9 | from vocab import Vocabulary # NOQA
10 | import torch
11 | from model import VSE, order_sim
12 | from collections import OrderedDict
13 |
14 | class AverageMeter(object):
15 | """Computes and stores the average and current value"""
16 |
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=0):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / (.0001 + self.count)
31 |
32 | def __str__(self):
33 | """String representation for logging
34 | """
35 | # for values that should be recorded exactly e.g. iteration number
36 | if self.count == 0:
37 | return str(self.val)
38 | # for stats
39 | return '%.4f (%.4f)' % (self.val, self.avg)
40 |
41 |
42 | class LogCollector(object):
43 | """A collection of logging objects that can change from train to val"""
44 |
45 | def __init__(self):
46 | # to keep the order of logged variables deterministic
47 | self.meters = OrderedDict()
48 |
49 | def update(self, k, v, n=0):
50 | # create a new meter if previously not recorded
51 | if k not in self.meters:
52 | self.meters[k] = AverageMeter()
53 | self.meters[k].update(v, n)
54 |
55 | def __str__(self):
56 | """Concatenate the meters in one log line
57 | """
58 | s = ''
59 | for i, (k, v) in enumerate(self.meters.items()):
60 | if i > 0:
61 | s += ' '
62 | s += k + ' ' + str(v)
63 | return s
64 |
65 | def tb_log(self, tb_logger, prefix='', step=None):
66 | """Log using tensorboard
67 | """
68 | for k, v in self.meters.items():
69 | tb_logger.log_value(prefix + k, v.val, step=step)
70 |
71 |
72 | def encode_data(model, data_loader, log_step=10, logging=print):
73 | """Encode all images and captions loadable by `data_loader`
74 | """
75 | batch_time = AverageMeter()
76 | val_logger = LogCollector()
77 |
78 | # switch to evaluate mode
79 | model.val_start()
80 |
81 | end = time.time()
82 |
83 | # numpy array to keep all the embeddings
84 | img_embs = None
85 | cap_embs = None
86 | for i, (images, captions, lengths, ids, indices) in enumerate(data_loader):
87 | # make sure val logger is used
88 | model.logger = val_logger
89 |
90 | # compute the embeddings
91 | img_emb, cap_emb = model.forward_emb(images, captions, lengths,
92 | volatile=True)
93 |
94 | # initialize the numpy arrays given the size of the embeddings
95 | if img_embs is None:
96 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1)))
97 | cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1)))
98 |
99 | # preserve the embeddings by copying from gpu and converting to numpy
100 | img_embs[ids] = img_emb.data.cpu().numpy().copy()
101 | cap_embs[ids] = cap_emb.data.cpu().numpy().copy()
102 |
103 | # measure accuracy and record loss
104 | model.forward_loss(img_emb, cap_emb, indices)
105 |
106 | # measure elapsed time
107 | batch_time.update(time.time() - end)
108 | end = time.time()
109 |
110 | if i % log_step == 0:
111 | logging('Test: [{0}/{1}]\t'
112 | '{e_log}\t'
113 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
114 | .format(
115 | i, len(data_loader), batch_time=batch_time,
116 | e_log=str(model.logger)))
117 | del images, captions
118 |
119 | return img_embs, cap_embs
120 |
121 |
122 | def evalrank(opt_eval, split):
123 | """
124 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
125 | cross-validation is done (only for MSCOCO). Otherwise, the full data is
126 | used for evaluation.
127 | """
128 | # load model and options
129 | checkpoint = torch.load(opt_eval.model_path)
130 | opt = checkpoint['opt']
131 | if opt_eval.data_path is not None:
132 | opt.data_path = opt_eval.data_path
133 | print(opt)
134 | # load vocabulary used by the model
135 | with open(os.path.join(opt.vocab_path,
136 | '%s_vocab.pkl' % opt.data_name), 'rb') as f:
137 | vocab = pickle.load(f)
138 | opt.vocab_size = len(vocab)
139 |
140 | # construct model
141 | model = VSE(opt)
142 |
143 | # load model state
144 | model.load_state_dict(checkpoint['model'])
145 |
146 | print('Loading dataset')
147 | data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
148 | opt.batch_size, opt.workers, opt)
149 |
150 | print('Computing results...')
151 | img_embs, cap_embs = encode_data(model, data_loader)
152 | print('Images: %d, Captions: %d' %
153 | (img_embs.shape[0] / 5, cap_embs.shape[0]))
154 |
155 | if opt_eval.save_embeddings:
156 | save_path = opt_eval.model_path.split('/')[-2]+'_img_and_cap_embeddings.pth.tar'
157 | with open(save_path+'.pkl', 'wb') as handle:
158 | pickle.dump({'img_embs':img_embs, 'cap_embs':cap_embs},
159 | handle, protocol=pickle.HIGHEST_PROTOCOL)
160 | print ("[embeddings saved to {}]".format(save_path))
161 |
162 | if not opt_eval.fold5:
163 | # no cross-validation, full evaluation
164 | r, rt = i2t(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
165 | ri, rti = t2i(img_embs, cap_embs,
166 | measure=opt.measure, return_ranks=True)
167 | ar = (r[0] + r[1] + r[2]) / 3
168 | ari = (ri[0] + ri[1] + ri[2]) / 3
169 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
170 | print("rsum: %.1f" % rsum)
171 | print("Average i2t Recall: %.1f" % ar)
172 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
173 | print("Average t2i Recall: %.1f" % ari)
174 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
175 |
176 | if len(opt_eval.save_csv) > 0:
177 | with open(opt_eval.save_csv, "a") as f:
178 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % r
179 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % ri
180 | rsum_data = ", %.1f" % rsum
181 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n"
182 | f.write(row_data)
183 | else:
184 | # 5fold cross-validation, only for MSCOCO
185 | results = []
186 | for i in range(5):
187 | r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000],
188 | cap_embs[i * 5000:(i + 1) *
189 | 5000], measure=opt.measure,
190 | return_ranks=True)
191 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
192 | ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000],
193 | cap_embs[i * 5000:(i + 1) *
194 | 5000], measure=opt.measure,
195 | return_ranks=True)
196 | if i == 0:
197 | rt, rti = rt0, rti0
198 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
199 | ar = (r[0] + r[1] + r[2]) / 3
200 | ari = (ri[0] + ri[1] + ri[2]) / 3
201 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
202 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
203 | results += [list(r) + list(ri) + [ar, ari, rsum]]
204 |
205 | if i == 0 and len(opt_eval.save_csv) > 0:
206 | with open(opt_eval.save_csv + "_fold1", "a") as f:
207 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % r
208 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % ri
209 | rsum_data = ", %.1f" % rsum
210 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n"
211 | f.write(row_data)
212 |
213 | print("-----------------------------------")
214 | print("Mean metrics: ")
215 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
216 | print("rsum: %.1f" % (mean_metrics[12]))
217 |         print("Average i2t Recall: %.1f" % mean_metrics[10])
218 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" %
219 | mean_metrics[:5])
220 |         print("Average t2i Recall: %.1f" % mean_metrics[11])
221 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" %
222 | mean_metrics[5:10])
223 |
224 | if len(opt_eval.save_csv) > 0:
225 | with open(opt_eval.save_csv + "_fold5", "a") as f:
226 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % mean_metrics[:5]
227 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % mean_metrics[5:10]
228 | rsum_data = ", %.1f" % mean_metrics[12]
229 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n"
230 | f.write(row_data)
231 |
232 | torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
233 |
234 | def i2t(images, captions, npts=None, measure='cosine',
235 | return_ranks=False):
236 | """
237 | Images->Text (Image Annotation)
238 | Images: (5N, K) matrix of images
239 | Captions: (5N, K) matrix of captions
240 | """
241 | if npts is None:
242 | npts = int(images.shape[0] / 5)
243 | #print(npts)
244 | index_list = []
245 |
246 | scores = images.dot(captions.T)
247 |
248 | ranks = numpy.zeros(npts)
249 | top1 = numpy.zeros(npts)
250 | for index in range(npts):
251 |
252 | # Get query image
253 | im = images[5 * index].reshape(1, images.shape[1])
254 |
255 | # Compute scores
256 | if measure == 'order':
257 | bs = 100
258 | if index % bs == 0:
259 | mx = min(images.shape[0], 5 * (index + bs))
260 | im2 = images[5 * index:mx:5]
261 | d2 = order_sim(torch.Tensor(im2),
262 | torch.Tensor(captions))
263 | d2 = d2.cpu().numpy()
264 | d = d2[index % bs]
265 | else:
266 | d = scores[5 * index]
267 | inds = numpy.argsort(d)[::-1]
268 | index_list.append(inds[0])
269 |
270 | # Score
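        # rank = best position of any of the 5 ground-truth captions for this image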
271 | rank = 1e20
272 | for i in range(5 * index, 5 * index + 5, 1):
273 | tmp = numpy.where(inds == i)[0][0]
274 | if tmp < rank:
275 | rank = tmp
276 | ranks[index] = rank
277 | top1[index] = inds[0]
278 |
279 | # Compute metrics
280 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
281 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
282 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
283 | medr = numpy.floor(numpy.median(ranks)) + 1
284 | meanr = ranks.mean() + 1
285 | if return_ranks:
286 | return (r1, r5, r10, medr, meanr), (ranks, top1)
287 | else:
288 | return (r1, r5, r10, medr, meanr)
289 |
290 |
291 | def t2i(images, captions, npts=None, measure='cosine',
292 | return_ranks=False):
293 | """
294 | Text->Images (Image Search)
295 | Images: (5N, K) matrix of images
296 | Captions: (5N, K) matrix of captions
297 | """
298 | if npts is None:
299 | npts = int(images.shape[0] / 5)
300 | #print("# points:", npts)
301 | ims = numpy.array([images[i] for i in range(0, len(images), 5)])
302 |
303 | scores = captions.dot(ims.T)
304 |
305 | ranks = numpy.zeros(5 * npts)
306 | top1 = numpy.zeros(5 * npts)
307 | for index in range(npts):
308 |
309 | # Compute scores
310 | if measure == 'order':
311 | bs = 100
312 | if 5 * index % bs == 0:
313 | mx = min(captions.shape[0], 5 * index + bs)
314 | q2 = captions[5 * index:mx]
315 | d2 = order_sim(torch.Tensor(ims),
316 | torch.Tensor(q2))
317 | d2 = d2.cpu().numpy()
318 |
319 | d = d2[:, (5 * index) % bs:(5 * index) % bs + 5].T
320 | else:
321 | d = scores[5 * index:5 * index + 5]
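        # for each of the 5 captions of this image, record the rank of the ground-truth image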
322 | inds = numpy.zeros(d.shape)
323 | for i in range(len(inds)):
324 | inds[i] = numpy.argsort(d[i])[::-1]
325 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0]
326 | top1[5 * index + i] = inds[i][0]
327 |
328 | # Compute metrics
329 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks)
330 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks)
331 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks)
332 | medr = numpy.floor(numpy.median(ranks)) + 1
333 | meanr = ranks.mean() + 1
334 | if return_ranks:
335 | return (r1, r5, r10, medr, meanr), (ranks, top1)
336 | else:
337 | return (r1, r5, r10, medr, meanr)
338 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import torch.nn.init
5 | import torchvision.models as models
6 | from torch.autograd import Variable
7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8 | import torch.backends.cudnn as cudnn
9 | from torch.nn.utils.clip_grad import clip_grad_norm_
10 | import numpy as np
11 | from collections import OrderedDict
12 | from random import randint
13 |
14 | def l2norm(X):
15 | """L2-normalize columns of X
16 | """
17 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt()
18 | X = torch.div(X, norm)
19 | return X
20 |
21 |
22 | def EncoderImage(data_name, img_dim, embed_size, finetune=False,
23 | cnn_type='vgg19', use_abs=False, no_imgnorm=False):
24 | """A wrapper to image encoders. Chooses between an encoder that uses
25 | precomputed image features, `EncoderImagePrecomp`, or an encoder that
26 | computes image features on the fly `EncoderImageFull`.
27 | """
28 | if data_name.endswith('_precomp'):
29 | img_enc = EncoderImagePrecomp(
30 | img_dim, embed_size, use_abs, no_imgnorm)
31 | else:
32 | img_enc = EncoderImageFull(
33 | embed_size, finetune, cnn_type, use_abs, no_imgnorm)
34 |
35 | return img_enc
36 |
37 |
38 | # tutorials/09 - Image Captioning
39 | class EncoderImageFull(nn.Module):
40 |
41 | def __init__(self, embed_size, finetune=False, cnn_type='vgg19',
42 | use_abs=False, no_imgnorm=False):
43 | """Load pretrained VGG19 and replace top fc layer."""
44 | super(EncoderImageFull, self).__init__()
45 | self.embed_size = embed_size
46 | self.no_imgnorm = no_imgnorm
47 | self.use_abs = use_abs
48 |
49 | # Load a pre-trained model
50 | self.cnn = self.get_cnn(cnn_type, True)
51 |
52 | # For efficient memory usage.
53 | for param in self.cnn.parameters():
54 | param.requires_grad = finetune
55 |
56 | # Replace the last fully connected layer of CNN with a new one
57 | if cnn_type.startswith('vgg'):
58 | self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features,
59 | embed_size)
60 | self.cnn.classifier = nn.Sequential(
61 | *list(self.cnn.classifier.children())[:-1])
62 | elif cnn_type.startswith('resnet'):
63 | self.fc = nn.Linear(self.cnn.module.fc.in_features, embed_size)
64 | self.cnn.module.fc = nn.Sequential()
65 |
66 | self.init_weights()
67 |
68 | def get_cnn(self, arch, pretrained):
69 | """Load a pretrained CNN and parallelize over GPUs
70 | """
71 | if pretrained:
72 | print("=> using pre-trained model '{}'".format(arch))
73 | model = models.__dict__[arch](pretrained=True)
74 | else:
75 | print("=> creating model '{}'".format(arch))
76 | model = models.__dict__[arch]()
77 |
78 | if arch.startswith('alexnet') or arch.startswith('vgg'):
79 | model.features = nn.DataParallel(model.features)
80 | model.cuda()
81 | else:
82 | model = nn.DataParallel(model).cuda()
83 |
84 | return model
85 |
86 | def load_state_dict(self, state_dict):
87 | """
88 | Handle the models saved before commit pytorch/vision@989d52a
89 | """
90 | if 'cnn.classifier.1.weight' in state_dict:
91 | state_dict['cnn.classifier.0.weight'] = state_dict[
92 | 'cnn.classifier.1.weight']
93 | del state_dict['cnn.classifier.1.weight']
94 | state_dict['cnn.classifier.0.bias'] = state_dict[
95 | 'cnn.classifier.1.bias']
96 | del state_dict['cnn.classifier.1.bias']
97 | state_dict['cnn.classifier.3.weight'] = state_dict[
98 | 'cnn.classifier.4.weight']
99 | del state_dict['cnn.classifier.4.weight']
100 | state_dict['cnn.classifier.3.bias'] = state_dict[
101 | 'cnn.classifier.4.bias']
102 | del state_dict['cnn.classifier.4.bias']
103 |
104 | super(EncoderImageFull, self).load_state_dict(state_dict)
105 |
106 | def init_weights(self):
107 | """Xavier initialization for the fully connected layer
108 | """
109 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
110 | self.fc.out_features)
111 | self.fc.weight.data.uniform_(-r, r)
112 | self.fc.bias.data.fill_(0)
113 |
114 | def forward(self, images):
115 | """Extract image feature vectors."""
116 | features = self.cnn(images)
117 |
118 | # normalization in the image embedding space
119 | features = l2norm(features)
120 |
121 | # linear projection to the joint embedding space
122 | features = self.fc(features)
123 |
124 | # normalization in the joint embedding space
125 | if not self.no_imgnorm:
126 | features = l2norm(features)
127 |
128 | # take the absolute value of the embedding (used in order embeddings)
129 | if self.use_abs:
130 | features = torch.abs(features)
131 |
132 | return features
133 |
134 |
135 | class EncoderImagePrecomp(nn.Module):
136 |
137 | def __init__(self, img_dim, embed_size, use_abs=False, no_imgnorm=False):
138 | super(EncoderImagePrecomp, self).__init__()
139 | self.embed_size = embed_size
140 | self.no_imgnorm = no_imgnorm
141 | self.use_abs = use_abs
142 |
143 | self.fc = nn.Linear(img_dim, embed_size)
144 |
145 | self.init_weights()
146 |
147 | def init_weights(self):
148 | """Xavier initialization for the fully connected layer
149 | """
150 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features +
151 | self.fc.out_features)
152 | self.fc.weight.data.uniform_(-r, r)
153 | self.fc.bias.data.fill_(0)
154 |
155 | def forward(self, images):
156 | """Extract image feature vectors."""
157 | # assuming that the precomputed features are already l2-normalized
158 |
159 | features = self.fc(images)
160 |
161 | # normalize in the joint embedding space
162 | if not self.no_imgnorm:
163 | features = l2norm(features)
164 |
165 | # take the absolute value of embedding (used in order embeddings)
166 | if self.use_abs:
167 | features = torch.abs(features)
168 |
169 | return features
170 |
171 | def load_state_dict(self, state_dict):
172 |         """Copies parameters, overwriting the default implementation to
173 |         accept a state_dict from the Full model
174 | """
175 | own_state = self.state_dict()
176 | new_state = OrderedDict()
177 | for name, param in state_dict.items():
178 | if name in own_state:
179 | new_state[name] = param
180 |
181 | super(EncoderImagePrecomp, self).load_state_dict(new_state)
182 |
183 |
184 | # tutorials/08 - Language Model
185 | # RNN Based Language Model
186 | class EncoderText(nn.Module):
187 |
188 | def __init__(self, vocab_size, word_dim, embed_size, num_layers,
189 | use_abs=False):
190 | super(EncoderText, self).__init__()
191 | self.use_abs = use_abs
192 | self.embed_size = embed_size
193 |
194 | # word embedding
195 | self.embed = nn.Embedding(vocab_size, word_dim)
196 |
197 | # caption embedding
198 | self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True)
199 |
200 | self.init_weights()
201 |
202 | def init_weights(self):
203 | self.embed.weight.data.uniform_(-0.1, 0.1)
204 |
205 | def forward(self, x, lengths):
206 | """Handles variable size captions
207 | """
208 | # Embed word ids to vectors
209 | x = self.embed(x)
210 | packed = pack_padded_sequence(x, lengths, batch_first=True)
211 |
212 | # Forward propagate RNN
213 | out, _ = self.rnn(packed)
214 |
215 | # Reshape *final* output to (batch_size, hidden_size)
216 | padded = pad_packed_sequence(out, batch_first=True)
217 | I = torch.LongTensor(lengths).view(-1, 1, 1)
218 | I = Variable(I.expand(x.size(0), 1, self.embed_size)-1).cuda()
219 | out = torch.gather(padded[0], 1, I).squeeze(1)
220 |
221 | # normalization in the joint embedding space
222 | out = l2norm(out)
223 |
224 | # take absolute value, used by order embeddings
225 | if self.use_abs:
226 | out = torch.abs(out)
227 |
228 | return out
229 |
230 |
231 | def cosine_sim(im, s):
232 | """Cosine similarity between all the image and sentence pairs
233 | """
234 | return im.mm(s.t())
235 |
236 |
237 | def order_sim(im, s):
238 | """Order embeddings similarity measure $max(0, s-im)$
239 | """
240 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1))
241 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1)))
242 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t()
243 | return score
244 |
245 |
246 | class ContrastiveLoss(nn.Module):
247 | """
248 | Compute contrastive loss
249 | """
250 |
251 | def __init__(self, opt):
252 | super(ContrastiveLoss, self).__init__()
253 |
254 | if opt.measure == 'order':
255 | self.sim = order_sim
256 | else:
257 | self.sim = cosine_sim
258 |
259 | self.opt = opt
260 |
261 | # "g" represents "global"
262 | self.g_alpha = self.opt.global_alpha
263 |         self.g_beta = self.opt.global_beta # W_it
264 | self.g_ep_posi = self.opt.global_ep_posi # W_ii
265 | self.g_ep_nega = self.opt.global_ep_nega
266 |
267 | # "l" represents "local"
268 | self.l_alpha = self.opt.local_alpha
269 | self.l_ep = self.opt.local_ep
270 |
271 | def forward(self, im, s, mb_img, mb_cap, mb_ind, indices):
272 |
273 | bsize = im.size()[0]
274 |
275 | scores = self.sim(im, s)
276 |
277 | if self.opt.max_violation or self.opt.sum_violation:
278 |
279 | diagonal = scores.diag().view(bsize, 1)
280 | d1 = diagonal.expand_as(scores)
281 | d2 = diagonal.t().expand_as(scores)
282 |
283 | cost_s = (self.opt.margin + scores - d1).clamp(min=0)
284 | cost_im = (self.opt.margin + scores - d2).clamp(min=0)
285 |
286 | mask = torch.eye(bsize) > .5
287 | I = Variable(mask)
288 | if torch.cuda.is_available():
289 | I = I.cuda()
290 | cost_s = cost_s.masked_fill_(I, 0)
291 | cost_im = cost_im.masked_fill_(I, 0)
292 |
293 | if self.opt.max_violation:
294 |
295 | cost_s = cost_s.max(1)[0]
296 | cost_im = cost_im.max(0)[0]
297 |
298 | return cost_s.sum() + cost_im.sum()
299 |
300 | tmp = torch.eye(bsize).cuda()
301 |
302 | s_diag = tmp * scores
303 | scores_ = scores - s_diag
304 |
305 | if mb_img is not None:
306 |
307 | #negative
308 | mb_k = self.opt.mb_k
309 | if im.size()[0] < mb_k: mb_k = bsize
310 |
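            # keep only memory-bank entries that do not come from the current mini-batch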
311 |             used_ind = torch.tensor([0 if i in indices else 1 for i in mb_ind]).bool().cuda()
312 |
313 |             mb_img = mb_img[used_ind]
314 |             mb_cap = mb_cap[used_ind]
315 |
316 | scores_img_glob = self.sim(im, mb_cap)
317 | i2t_k_avg = torch.exp(self.g_beta * torch.topk(scores_img_glob, mb_k)[0] - self.g_ep_nega).sum(1).reshape((bsize,1))
318 | i2t_k_avg_positive = torch.exp(self.g_alpha * (torch.topk(scores_img_glob, mb_k)[0] - self.g_ep_posi)).sum(1)
319 |
320 | scores_cap_glob = self.sim(s, mb_img)
321 | t2i_k_avg = torch.exp(self.g_beta * torch.topk(scores_cap_glob, mb_k)[0] - self.g_ep_nega).sum(1).reshape((1,bsize))
322 | t2i_k_avg_positive = torch.exp(self.g_alpha * (torch.topk(scores_cap_glob, mb_k)[0] - self.g_ep_posi)).sum(1)
323 |
324 | tmp_i2t = i2t_k_avg.repeat(1, bsize)
325 | tmp_t2i = t2i_k_avg.repeat(bsize, 1)
326 |
327 | exp_sii = torch.exp(self.g_beta * s_diag.sum(0))
328 | tmp_expii = exp_sii.reshape((bsize,1)).repeat(1, bsize)
329 | tmp_exptt = exp_sii.reshape((1,bsize)).repeat(bsize, 1)
330 |
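            # W_it: global weight for negative (image, text) pairs, estimated from memory-bank similarities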
331 | wit = (tmp_i2t + tmp_t2i) / (tmp_i2t + tmp_t2i + tmp_expii + tmp_exptt)
332 |
333 | #positive
334 | exp_sii = torch.exp(self.g_alpha * (s_diag.sum(0) - self.g_ep_posi))
335 |
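            # W_ii: global weight for the matched (positive) pairs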
336 | wii = 1 - exp_sii / (exp_sii + i2t_k_avg_positive + t2i_k_avg_positive)
337 |
338 | wit = wit - wit * tmp
339 |
340 | S_ = torch.exp(self.l_alpha * wit.detach() * (scores_ - self.l_ep))
341 |
342 | loss_diag = - torch.log(1 + F.relu((s_diag.sum(0) * wii.detach())))
343 |
344 | else:
345 |
346 | S_ = torch.exp(self.l_alpha * (scores_ - self.l_ep))
347 |
348 | loss_diag = - torch.log(1 + F.relu(s_diag.sum(0)))
349 |
350 | loss = torch.sum(
351 | torch.log(1 + S_.sum(0)) / self.l_alpha \
352 | + torch.log(1 + S_.sum(1)) / self.l_alpha \
353 | + loss_diag
354 | ) / bsize
355 |
356 | return loss
357 |
358 |
359 | class VSE(object):
360 | """
361 | rkiros/uvs model
362 | """
363 |
364 | def __init__(self, opt):
365 | # tutorials/09 - Image Captioning
366 | # Build Models
367 | self.grad_clip = opt.grad_clip
368 | self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size,
369 | opt.finetune, opt.cnn_type,
370 | use_abs=opt.use_abs,
371 | no_imgnorm=opt.no_imgnorm)
372 | self.txt_enc = EncoderText(opt.vocab_size, opt.word_dim,
373 | opt.embed_size, opt.num_layers,
374 | use_abs=opt.use_abs)
375 | if torch.cuda.is_available():
376 | self.img_enc.cuda()
377 | self.txt_enc.cuda()
378 | cudnn.benchmark = True
379 |
380 | # memory bank
381 | self.mb_img = None
382 | self.mb_cap = None
383 | self.mb_ind = None
384 |
385 | # Loss and Optimizer
386 | self.criterion = ContrastiveLoss(opt=opt)
387 | params = list(self.txt_enc.parameters())
388 | params += list(self.img_enc.fc.parameters())
389 | if opt.finetune:
390 | params += list(self.img_enc.cnn.parameters())
391 | self.params = params
392 |
393 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
394 |
395 | self.Eiters = 0
396 |
397 | def state_dict(self):
398 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict()]
399 | return state_dict
400 |
401 | def load_state_dict(self, state_dict):
402 | self.img_enc.load_state_dict(state_dict[0])
403 | self.txt_enc.load_state_dict(state_dict[1])
404 |
405 | def train_start(self):
406 | """switch to train mode
407 | """
408 | self.img_enc.train()
409 | self.txt_enc.train()
410 |
411 | def val_start(self):
412 | """switch to evaluate mode
413 | """
414 | self.img_enc.eval()
415 | self.txt_enc.eval()
416 |
417 | def forward_emb(self, images, captions, lengths, volatile=False,**kwargs):
418 | """Compute the image and caption embeddings
419 | """
420 | # Set mini-batch dataset
421 | if volatile:
422 | with torch.no_grad():
423 | images = Variable(images)
424 | captions = Variable(captions)
425 | else:
426 | images = Variable(images)
427 | captions = Variable(captions)
428 |
429 | if torch.cuda.is_available():
430 | images = images.cuda()
431 | captions = captions.cuda()
432 |
433 | # Forward
434 | img_emb = self.img_enc(images)
435 | cap_emb = self.txt_enc(captions, lengths)
436 | return img_emb, cap_emb
437 |
438 | def forward_loss(self, img_emb, cap_emb, indices, **kwargs):
439 | """Compute the loss given pairs of image and caption embeddings
440 | """
441 | loss = self.criterion(
442 | img_emb,
443 | cap_emb,
444 | self.mb_img,
445 | self.mb_cap,
446 | self.mb_ind,
447 | indices)
448 | self.logger.update('Loss', loss.item(), img_emb.size(0))
449 | return loss
450 |
451 | def train_emb(self, images, captions, lengths, ids, indices, *args):
452 | """One training step given images and captions.
453 | """
454 | self.Eiters += 1
455 | self.logger.update('Eit', self.Eiters)
456 | self.logger.update('lr', self.optimizer.param_groups[0]['lr'])
457 |
458 | # compute the embeddings
459 | img_emb, cap_emb = self.forward_emb(images, captions, lengths)
460 |
461 | # measure accuracy and record loss
462 | self.optimizer.zero_grad()
463 | loss = self.forward_loss(img_emb, cap_emb, indices)
464 |
465 | # compute gradient and do SGD step
466 | loss.backward()
467 | if self.grad_clip > 0:
468 | clip_grad_norm_(self.params, self.grad_clip)
469 | self.optimizer.step()
470 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | # HAL/data.py: 6
4 | Pillow == 9.3.0
5 |
6 | # HAL/data.py: 5
7 | # HAL/vocab.py: 2
8 | nltk == 3.6.6
9 |
10 | # HAL/data.py: 8
11 | # HAL/evaluation.py: 5,8
12 | # HAL/model.py: 10
13 | numpy == 1.22.0
14 |
15 | # HAL/data.py: 7
16 | # HAL/vocab.py: 5
17 | pycocotools-fix == 2.0.0.9
18 |
19 | # HAL/data.py: 7
20 | # HAL/vocab.py: 5
21 | pycocotools-win == 2.0
22 |
23 | # HAL/train.py: 9
24 | tensorboard_logger == 0.1.0
25 |
26 | # HAL/data.py: 1,2
27 | # HAL/evaluation.py: 10
28 | # HAL/model.py: 7,8,9
29 | # HAL/train.py: 7
30 | torch == 1.13.1
31 |
32 | # HAL/data.py: 3
33 | # HAL/model.py: 5
34 | torchvision == 0.6.0a0+35d732a
35 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import time
4 | import shutil
5 | from random import random
6 | import argparse
7 | import torch
8 | import logging
9 | import tensorboard_logger as tb_logger
10 |
11 | import data
12 | from vocab import Vocabulary # NOQA
13 | from model import VSE
14 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data
15 |
16 | def main():
17 | # Hyper Parameters
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--data_path', default='/w/31/faghri/vsepp_data/',
20 | help='path to datasets')
21 | parser.add_argument('--data_name', default='precomp',
22 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
23 | parser.add_argument('--vocab_path', default='./vocab/',
24 | help='Path to saved vocabulary pickle files.')
25 | parser.add_argument('--margin', default=0.2, type=float,
26 | help='Rank loss margin.')
27 | parser.add_argument('--num_epochs', default=15, type=int,
28 | help='Number of training epochs.')
29 | parser.add_argument('--batch_size', default=128, type=int,
30 | help='Size of a training mini-batch.')
31 | parser.add_argument('--word_dim', default=300, type=int,
32 | help='Dimensionality of the word embedding.')
33 | parser.add_argument('--embed_size', default=1024, type=int,
34 | help='Dimensionality of the joint embedding.')
35 | parser.add_argument('--grad_clip', default=2., type=float,
36 | help='Gradient clipping threshold.')
37 | parser.add_argument('--crop_size', default=224, type=int,
38 | help='Size of an image crop as the CNN input.')
39 | parser.add_argument('--num_layers', default=1, type=int,
40 | help='Number of GRU layers.')
41 | parser.add_argument('--learning_rate', default=.0002, type=float,
42 | help='Initial learning rate.')
43 | parser.add_argument('--lr_update', default=8, type=int,
44 | help='Number of epochs to update the learning rate.')
45 | parser.add_argument('--workers', default=10, type=int,
46 | help='Number of data loader workers.')
47 | parser.add_argument('--log_step', default=10, type=int,
48 | help='Number of steps to print and record the log.')
49 | parser.add_argument('--val_step', default=500, type=int,
50 | help='Number of steps to run validation.')
51 | parser.add_argument('--logger_name', default='runs/runX',
52 | help='Path to save the model and Tensorboard log.')
53 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
54 | help='path to latest checkpoint (default: none)')
55 | parser.add_argument('--max_violation', action='store_true',
56 | help='Use max instead of sum in the rank loss.')
57 | parser.add_argument('--sum_violation', action='store_true')
58 | parser.add_argument('--img_dim', default=4096, type=int,
59 | help='Dimensionality of the image embedding.')
60 | parser.add_argument('--finetune', action='store_true',
61 | help='Fine-tune the image encoder.')
62 | parser.add_argument('--cnn_type', default='vgg19',
63 | help="""The CNN used for image encoder
64 | (e.g. vgg19, resnet152)""")
65 | parser.add_argument('--use_restval', action='store_true',
66 | help='Use the restval data for training on MSCOCO.')
67 | parser.add_argument('--measure', default='cosine',
68 | help='Similarity measure used (cosine|order)')
69 | parser.add_argument('--use_abs', action='store_true',
70 | help='Take the absolute value of embedding vectors.')
71 | parser.add_argument('--no_imgnorm', action='store_true',
72 | help='Do not normalize the image embeddings.')
73 | parser.add_argument('--reset_train', action='store_true',
74 | help='Ensure the training is always done in '
75 | 'train mode (Not recommended).')
76 | parser.add_argument('--save_all', action='store_true',
77 | help="Save model after the training of each epoch")
78 | parser.add_argument('--memory_bank', action='store_true',
79 | help="Train model with memory bank")
80 | parser.add_argument('--record_val', action='store_true',
81 | help="Record the rsum values on validation set in file during training")
82 | parser.add_argument('--local_alpha', default=30.0, type=float)
83 | parser.add_argument('--local_ep', default=0.3, type=float)
84 | parser.add_argument('--global_alpha', default=40.0, type=float)
85 | parser.add_argument('--global_beta', default=40.0, type=float)
86 | parser.add_argument('--global_ep_posi', default=0.2, type=float,
87 | help="Global epsilon for positive pairs")
88 | parser.add_argument('--global_ep_nega', default=0.1, type=float,
89 | help="Global epsilon for negative pairs")
90 | parser.add_argument('--mb_k', default=250, type=int,
91 | help="Use top K items in memory bank")
92 |     parser.add_argument('--mb_rate', default=0.05, type=float,
93 |                         help="Fraction of training batches sampled into the memory bank")
94 |
95 | opt = parser.parse_args()
96 | print(opt)
97 |
98 | logging.basicConfig(format='%(message)s', level=logging.INFO)
99 | tb_logger.configure(opt.logger_name, flush_secs=5)
100 |
101 | # Load Vocabulary Wrapper
102 | vocab = pickle.load(open(os.path.join(
103 | opt.vocab_path, '%s_vocab.pkl' % opt.data_name), 'rb'))
104 | opt.vocab_size = len(vocab)
105 | print("Vocab Size: %d" % opt.vocab_size)
106 |
107 | # Load data loaders
108 | train_loader, val_loader = data.get_loaders(
109 | opt.data_name, vocab, opt.crop_size, opt.batch_size, opt.workers, opt)
110 |
111 | # Construct the model
112 | model = VSE(opt)
113 |
114 | # optionally resume from a checkpoint
115 | if opt.resume:
116 | if os.path.isfile(opt.resume):
117 | print("=> loading checkpoint '{}'".format(opt.resume))
118 | checkpoint = torch.load(opt.resume)
119 | start_epoch = checkpoint['epoch']
120 | best_rsum = checkpoint['best_rsum']
121 | model.load_state_dict(checkpoint['model'])
122 | # Eiters is used to show logs as the continuation of another
123 | # training
124 | model.Eiters = checkpoint['Eiters']
125 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
126 | .format(opt.resume, start_epoch, best_rsum))
127 | validate(opt, val_loader, model)
128 | else:
129 | print("=> no checkpoint found at '{}'".format(opt.resume))
130 |
131 | # Train the Model
132 | best_rsum = 0
133 | for epoch in range(opt.num_epochs):
134 | adjust_learning_rate(opt, model.optimizer, epoch)
135 |
136 | memory_bank = opt.memory_bank
137 | if memory_bank and epoch > 0:
138 | load_memory_bank(opt, train_loader, model)
139 | # train for one epoch
140 | train(opt, train_loader, model, epoch, val_loader)
141 |
142 | # evaluate on validation set
143 | rsum = validate(opt, val_loader, model)
144 | print ("rsum: %.1f" % rsum)
145 | if opt.record_val:
146 | with open("rst_val_" + opt.logger_name[5:], "a") as f:
147 | f.write("Epoch: %d ; rsum: %.1f\n" %(epoch, rsum))
148 |
149 | # remember best R@ sum and save checkpoint
150 | is_best = rsum > best_rsum
151 | best_rsum = max(rsum, best_rsum)
152 | save_checkpoint({
153 | 'epoch': epoch + 1,
154 | 'model': model.state_dict(),
155 | 'best_rsum': best_rsum,
156 | 'opt': opt,
157 | 'Eiters': model.Eiters,
158 | }, is_best, prefix=opt.logger_name + '/', save_all=opt.save_all)
159 |
160 | # reset memory bank
161 | model.mb_img = None
162 | model.mb_cap = None
163 |
164 | def load_memory_bank(opt, train_loader, model):
165 | mb_img, mb_cap, ind = None, None, None
166 | for i, train_data in enumerate(train_loader):
167 |
168 | if (i+1) % 50 == 0:
169 | print ('[ %d / %d memory bank loading randomly...]' % (i+1,len(train_loader)))
170 |
171 | if random() > opt.mb_rate: continue
172 |
173 | model.val_start()
174 | with torch.no_grad():
175 | imgs, caps, lengths, _, indices = train_data
176 | img_emb, cap_emb = model.forward_emb(imgs, caps, lengths)
177 |
178 | if mb_img is None:
179 | mb_img = img_emb
180 | mb_cap = cap_emb
181 | ind = indices
182 | else:
183 | mb_img = torch.cat((mb_img, img_emb), 0)
184 | mb_cap = torch.cat((mb_cap, cap_emb), 0)
185 | ind = ind + indices
186 | model.mb_img = mb_img
187 | model.mb_cap = mb_cap
188 | model.mb_ind = ind
189 |
190 | print ('[memory bank fully loaded!]')
191 | print ("MB(Image): ", model.mb_img.size())
192 | print ("MB(Caption): ", model.mb_cap.size())
193 | print ("indices len:",len(model.mb_ind))
194 |
195 | def train(opt, train_loader, model, epoch, val_loader):
196 | # average meters to record the training statistics
197 | batch_time = AverageMeter()
198 | data_time = AverageMeter()
199 | train_logger = LogCollector()
200 |
201 | end = time.time()
202 | for i, train_data in enumerate(train_loader):
203 |
204 | model.train_start()
205 |
206 | # measure data loading time
207 | data_time.update(time.time() - end)
208 |
209 | # make sure train logger is used
210 | model.logger = train_logger
211 |
212 | # Update the model
213 | model.train_emb(*train_data)
214 |
215 | # measure elapsed time
216 | batch_time.update(time.time() - end)
217 | end = time.time()
218 |
219 | # Print log info
220 | if model.Eiters % opt.log_step == 0:
221 | logging.info(
222 | 'Epoch: [{0}][{1}/{2}]\t'
223 | '{e_log}\t'
224 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
225 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
226 | .format(
227 | epoch, i, len(train_loader), batch_time=batch_time,
228 | data_time=data_time, e_log=str(model.logger)))
229 |
230 | # Record logs in tensorboard
231 | tb_logger.log_value('epoch', epoch, step=model.Eiters)
232 | tb_logger.log_value('step', i, step=model.Eiters)
233 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters)
234 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters)
235 | model.logger.tb_log(tb_logger, step=model.Eiters)
236 |
237 | # validate at every val_step
238 | if model.Eiters % opt.val_step == 0:
239 | validate(opt, val_loader, model)
240 |
241 |
242 | def validate(opt, val_loader, model):
243 | # compute the encoding for all the validation images and captions
244 | img_embs, cap_embs = encode_data(
245 | model, val_loader, opt.log_step, logging.info)
246 |
247 | # caption retrieval
248 | (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, measure=opt.measure)
249 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
250 | (r1, r5, r10, medr, meanr))
251 |     # image retrieval
252 |     (r1i, r5i, r10i, medri, meanri) = t2i(
253 |         img_embs, cap_embs, measure=opt.measure)
254 |     logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
255 |                  (r1i, r5i, r10i, medri, meanri))
256 | # sum of recalls to be used for early stopping
257 | currscore = r1 + r5 + r10 + r1i + r5i + r10i
258 |
259 | # record metrics in tensorboard
260 | tb_logger.log_value('r1', r1, step=model.Eiters)
261 | tb_logger.log_value('r5', r5, step=model.Eiters)
262 | tb_logger.log_value('r10', r10, step=model.Eiters)
263 | tb_logger.log_value('medr', medr, step=model.Eiters)
264 | tb_logger.log_value('meanr', meanr, step=model.Eiters)
265 | tb_logger.log_value('r1i', r1i, step=model.Eiters)
266 | tb_logger.log_value('r5i', r5i, step=model.Eiters)
267 | tb_logger.log_value('r10i', r10i, step=model.Eiters)
268 | tb_logger.log_value('medri', medri, step=model.Eiters)
269 |     tb_logger.log_value('meanri', meanri, step=model.Eiters)
270 | tb_logger.log_value('rsum', currscore, step=model.Eiters)
271 |
272 | return currscore
273 |
274 |
275 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='', save_all=True):
276 | torch.save(state, prefix + filename)
277 | if is_best:
278 |         print("[Best model so far, saved.]")
279 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')
280 | if save_all:
281 | shutil.copyfile(prefix + filename, prefix + "Epoch-" + str(state['epoch']) + "-" + 'model.pth.tar')
282 |
283 |
284 |
285 | def adjust_learning_rate(opt, optimizer, epoch):
286 |     """Sets the learning rate to the initial LR
287 |     decayed by a factor of 10 every lr_update epochs."""
288 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
289 | for param_group in optimizer.param_groups:
290 | param_group['lr'] = lr
291 |
292 |
293 | def accuracy(output, target, topk=(1,)):
294 | """Computes the precision@k for the specified values of k"""
295 | maxk = max(topk)
296 | batch_size = target.size(0)
297 |
298 | _, pred = output.topk(maxk, 1, True, True)
299 | pred = pred.t()
300 | correct = pred.eq(target.view(1, -1).expand_as(pred))
301 |
302 | res = []
303 | for k in topk:
304 |         correct_k = correct[:k].reshape(-1).float().sum(0)
305 | res.append(correct_k.mul_(100.0 / batch_size))
306 | return res
307 |
308 |
309 | if __name__ == '__main__':
310 | main()
311 |
--------------------------------------------------------------------------------
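For reference, a minimal sketch of resuming an interrupted run from the checkpoints that `save_checkpoint` writes under `opt.logger_name` (`checkpoint.pth.tar` for the latest epoch, `model_best.pth.tar` for the best rsum). This assumes `opt.resume` is exposed as a `--resume` flag, as in the upstream VSE++ code; `EXP` is a placeholder for the experiment name:

```bash
# hypothetical resume command; EXP stands in for the --logger_name used in the original run,
# and the remaining data/model flags should be repeated from that run
python3 train.py \
    --logger_name runs/EXP \
    --resume runs/EXP/checkpoint.pth.tar
```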
/vocab.py:
--------------------------------------------------------------------------------
1 | # Create a vocabulary wrapper
2 | import nltk
3 | import pickle
4 | from collections import Counter
5 | from pycocotools.coco import COCO
6 | import json
7 | import argparse
8 | import os
9 |
10 | annotations = {
11 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
12 | 'coco': ['annotations/captions_train2014.json',
13 | 'annotations/captions_val2014.json'],
14 | 'f8k_precomp': ['train_caps.txt', 'dev_caps.txt'],
15 | '10crop_precomp': ['train_caps.txt', 'dev_caps.txt'],
16 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
17 | 'f8k': ['dataset_flickr8k.json'],
18 | 'f30k': ['dataset_flickr30k.json'],
19 | }
20 |
21 |
22 | class Vocabulary(object):
23 | """Simple vocabulary wrapper."""
24 |
25 | def __init__(self):
26 | self.word2idx = {}
27 | self.idx2word = {}
28 | self.idx = 0
29 |
30 | def add_word(self, word):
31 | if word not in self.word2idx:
32 | self.word2idx[word] = self.idx
33 | self.idx2word[self.idx] = word
34 | self.idx += 1
35 |
36 | def __call__(self, word):
37 | if word not in self.word2idx:
38 |             return self.word2idx['<unk>']
39 | return self.word2idx[word]
40 |
41 | def __len__(self):
42 | return len(self.word2idx)
43 |
44 |
45 | def from_coco_json(path):
46 | coco = COCO(path)
47 | ids = coco.anns.keys()
48 | captions = []
49 | for i, idx in enumerate(ids):
50 | captions.append(str(coco.anns[idx]['caption']))
51 |
52 | return captions
53 |
54 |
55 | def from_flickr_json(path):
56 | dataset = json.load(open(path, 'r'))['images']
57 | captions = []
58 | for i, d in enumerate(dataset):
59 | captions += [str(x['raw']) for x in d['sentences']]
60 |
61 | return captions
62 |
63 |
64 | def from_txt(txt):
65 | captions = []
66 | with open(txt, 'rb') as f:
67 | for line in f:
68 | captions.append(line.strip())
69 | return captions
70 |
71 |
72 | def build_vocab(data_path, data_name, jsons, threshold):
73 | """Build a simple vocabulary wrapper."""
74 | counter = Counter()
75 | for path in jsons[data_name]:
76 | full_path = os.path.join(os.path.join(data_path, data_name), path)
77 | if data_name == 'coco':
78 | captions = from_coco_json(full_path)
79 | elif data_name == 'f8k' or data_name == 'f30k':
80 | captions = from_flickr_json(full_path)
81 | else:
82 | captions = from_txt(full_path)
83 |         for i, caption in enumerate(captions):
84 |             caption = caption.decode('utf-8') if isinstance(caption, bytes) else caption
85 |             tokens = nltk.tokenize.word_tokenize(caption.lower())
86 | counter.update(tokens)
87 |
88 | if i % 1000 == 0:
89 | print("[%d/%d] tokenized the captions." % (i, len(captions)))
90 |
91 | # Discard if the occurrence of the word is less than min_word_cnt.
92 | words = [word for word, cnt in counter.items() if cnt >= threshold]
93 |
94 | # Create a vocab wrapper and add some special tokens.
95 | vocab = Vocabulary()
96 |     vocab.add_word('<pad>')
97 |     vocab.add_word('<start>')
98 |     vocab.add_word('<end>')
99 |     vocab.add_word('<unk>')
100 |
101 | # Add words to the vocabulary.
102 | for i, word in enumerate(words):
103 | vocab.add_word(word)
104 | return vocab
105 |
106 |
107 | def main(data_path, data_name):
108 | vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=4)
109 | with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f:
110 | pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
111 | print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name)
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument('--data_path', default='/w/31/faghri/vsepp_data/')
117 | parser.add_argument('--data_name', default='coco',
118 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
119 | opt = parser.parse_args()
120 | main(opt.data_path, opt.data_name)
121 |
--------------------------------------------------------------------------------
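A minimal usage sketch for `vocab.py`, assuming the `*_precomp` caption files live under `data/<data_name>/` (the exact data layout is an assumption); the resulting pickle is written to `./vocab/`:

```bash
# build the Flickr30k (precomp) vocabulary from data/f30k_precomp/{train,dev}_caps.txt
python3 vocab.py --data_path data --data_name f30k_precomp
# -> ./vocab/f30k_precomp_vocab.pkl
```

Note that unpickling the saved file in another script only works if the `Vocabulary` class is resolvable there (e.g. via `from vocab import Vocabulary`), since pickle stores a reference to the class rather than its definition.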