├── .gitignore ├── README.md ├── compute_results.py ├── data.py ├── data └── vocab │ ├── 10crop_precomp_vocab.pkl │ ├── coco_precomp_vocab.pkl │ ├── coco_vocab.pkl │ ├── f30k_precomp_vocab.pkl │ ├── f30k_vocab.pkl │ ├── f8k_precomp_vocab.pkl │ └── f8k_vocab.pkl ├── evaluation.py ├── model.py ├── requirements.txt ├── train.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VSE-HAL 2 | Code release for **HAL: Improved Text-Image Matching by Mitigating Visual Semantic Hubs** [\[arxiv\]](https://arxiv.org/pdf/1911.10097v1.pdf) at AAAI 2020. 3 | 4 | ```bibtex 5 | @inproceedings{liu2020hal, 6 | title={{HAL}: Improved text-image matching by mitigating visual semantic hubs}, 7 | author={Liu, Fangyu and Ye, Rongtian and Wang, Xun and Li, Shuaipeng}, 8 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 9 | volume={34}, 10 | number={07}, 11 | pages={11563--11571}, 12 | year={2020} 13 | } 14 | ``` 15 | 16 | Upgrade your text-image matching model with a few lines of code: 17 | ```python 18 | class ContrastiveLoss(nn.Module): 19 | ... 20 | def forward(self, im, s, ...): 21 | bsize = im.size()[0] 22 | scores = self.sim(im, s) 23 | ... 24 | tmp = torch.eye(bsize).cuda() 25 | s_diag = tmp * scores 26 | scores_ = scores - s_diag 27 | ... 28 | S_ = torch.exp(self.l_alpha * (scores_ - self.l_ep)) 29 | loss_diag = - torch.log(1 + F.relu(s_diag.sum(0))) 30 | 31 | loss = torch.sum( \ 32 | torch.log(1 + S_.sum(0)) / self.l_alpha \ 33 | + torch.log(1 + S_.sum(1)) / self.l_alpha \ 34 | + loss_diag \ 35 | ) / bsize 36 | 37 | return loss 38 | ``` 39 | 40 | 41 | ## Dependencies 42 | ``` 43 | nltk==3.4.5 44 | pycocotools==2.0.0 45 | numpy==1.18.1 46 | torch==1.5.1 47 | torchvision==0.6.0 48 | tensorboard_logger==0.1.0 49 | ``` 50 | 51 | ## Data 52 | #### MS-COCO 53 | [\[vgg_precomp\]](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip)
54 | [\[resnet_precomp\]](https://drive.google.com/uc?id=1vtUijEbXpVzNt6HjC6ph8ZzMHRRNms5j&export=download) 55 | 56 | #### Flickr30k 57 | [\[vgg_precomp\]](http://www.cs.toronto.edu/~faghri/vsepp/data.tar) 58 | 59 | ## Train 60 | 61 | Run `train.py`. 62 | 63 | #### MS-COCO 64 | 65 | ##### w/o global weighting 66 | 67 | ```bash 68 | python3 train.py \ 69 | --data_path "data/data/resnet_precomp" \ 70 | --vocab_path "data/vocab/" \ 71 | --data_name coco_precomp \ 72 | --batch_size 512 \ 73 | --learning_rate 0.001 \ 74 | --lr_update 8 \ 75 | --num_epochs 13 \ 76 | --img_dim 2048 \ 77 | --logger_name runs/COCO \ 78 | --local_alpha 30.00 \ 79 | --local_ep 0.3 80 | ``` 81 | 82 | ##### with global weighting 83 | 84 | ```bash 85 | python3 train.py \ 86 | --data_path "data/data/resnet_precomp" \ 87 | --vocab_path "data/vocab/" \ 88 | --data_name coco_precomp \ 89 | --batch_size 512 \ 90 | --learning_rate 0.001 \ 91 | --lr_update 8 \ 92 | --num_epochs 13 \ 93 | --img_dim 2048 \ 94 | --logger_name runs/COCO_mb \ 95 | --local_alpha 30.00 \ 96 | --local_ep 0.3 \ 97 | --memory_bank \ 98 | --global_alpha 40.00 \ 99 | --global_beta 40.00 \ 100 | --global_ep_posi 0.20 \ 101 | --global_ep_nega 0.10 \ 102 | --mb_rate 0.05 \ 103 | --mb_k 250 104 | ``` 105 | 106 | #### Flickr30k 107 | 108 | ```bash 109 | python3 train.py \ 110 | --data_path "data/data" \ 111 | --vocab_path "data/vocab/" \ 112 | --data_name f30k_precomp \ 113 | --batch_size 128 \ 114 | --learning_rate 0.001 \ 115 | --lr_update 8 \ 116 | --num_epochs 13 \ 117 | --logger_name runs/f30k \ 118 | --local_alpha 60.00 \ 119 | --local_ep 0.7 120 | ``` 121 | 122 | ## Evaluate 123 | 124 | Run `compute_results.py`. 125 | 126 | #### COCO 127 | 128 | ```bash 129 | python3 compute_results.py --data_path data/data/resnet_precomp --fold5 --model_path runs/COCO/model_best.pth.tar 130 | ``` 131 | 132 | #### Flickr30k 133 | 134 | ```bash 135 | python3 compute_results.py --data_path data/data --model_path runs/f30k/model_best.pth.tar 136 | ``` 137 | #### Trained models 138 | [\[Google Drive\]](https://drive.google.com/drive/folders/1H_EVBFxpYKObNo_CjV0pTaB24A1jWsSF) 139 | 140 | ## Note 141 | Trained models and codes for replicating results on [SCAN](https://github.com/kuanghuei/SCAN) are coming soon. 142 | 143 | ## Acknowledgments 144 | This project would be impossible without the open source implementations of [VSE++](https://github.com/fartashf/vsepp) and [SCAN](https://github.com/kuanghuei/SCAN). 
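## Loss at a glance

The snippet at the top of this README is an excerpt from `ContrastiveLoss` in `model.py`. Below is a minimal, self-contained sketch of just the local (in-batch) part of that loss, with no memory bank or global weighting; the function name `hal_local_loss` and the random embeddings are only for illustration, and `l_alpha` / `l_ep` mirror `--local_alpha` / `--local_ep`.

```python
import torch
import torch.nn.functional as F

def hal_local_loss(im, s, l_alpha=30.0, l_ep=0.3):
    # im, s: (batch, dim) L2-normalized image / caption embeddings
    bsize = im.size(0)
    scores = im.mm(s.t())                       # cosine similarity matrix
    tmp = torch.eye(bsize, device=im.device)
    s_diag = tmp * scores                       # matched pairs on the diagonal
    scores_ = scores - s_diag                   # off-diagonal (mismatched) scores
    S_ = torch.exp(l_alpha * (scores_ - l_ep))
    loss_diag = -torch.log(1 + F.relu(s_diag.sum(0)))
    return torch.sum(
        torch.log(1 + S_.sum(0)) / l_alpha      # aggregate negatives per caption
        + torch.log(1 + S_.sum(1)) / l_alpha    # aggregate negatives per image
        + loss_diag                             # pull matched pairs together
    ) / bsize

im = F.normalize(torch.randn(8, 1024), dim=1)
s = F.normalize(torch.randn(8, 1024), dim=1)
print(hal_local_loss(im, s))
```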
145 | 146 | ## License 147 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0) 148 | -------------------------------------------------------------------------------- /compute_results.py: -------------------------------------------------------------------------------- 1 | from vocab import Vocabulary 2 | import evaluation 3 | 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--model_path', default='$RUN_PATH/coco_vse/model_best.pth.tar', help='path to model') 8 | parser.add_argument('--data_path', default='data/data', help='path to datasets') 9 | parser.add_argument('--fold5', action='store_true', 10 | help='Use 5-fold cross-validation (MSCOCO only)') 11 | parser.add_argument('--save_embeddings', action='store_true', 12 | help='Save the computed image and caption embeddings') 13 | parser.add_argument('--save_csv', default='') 14 | 15 | opt_eval = parser.parse_args() 16 | 17 | evaluation.evalrank(opt_eval, split='test') 18 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision.transforms as transforms 4 | import os 5 | import nltk 6 | from PIL import Image 7 | from pycocotools.coco import COCO 8 | import numpy as np 9 | import json as jsonmod 10 | 11 | 12 | def get_paths(path, name='coco', use_restval=False): 13 | """ 14 | Returns paths to images and annotations for the given datasets. For MSCOCO 15 | indices are also returned to control the data split being used. 16 | The indices are extracted from the Karpathy et al. splits using this 17 | snippet: 18 | 19 | >>> import json 20 | >>> D=json.load(open('dataset_coco.json','r')) 21 | >>> A=[] 22 | >>> for i in range(len(D['images'])): 23 | ... if D['images'][i]['split'] == 'val': 24 | ... A+=D['images'][i]['sentids'][:5] 25 | ... 26 | 27 | :param name: Dataset name. 28 | :param use_restval: If True, the `restval` data is included in train.
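:returns: A tuple ``(roots, ids)`` of dicts keyed by split name.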
29 | """ 30 | roots = {} 31 | ids = {} 32 | if 'coco' == name: 33 | imgdir = os.path.join(path, 'images') 34 | capdir = os.path.join(path, 'annotations') 35 | roots['train'] = { 36 | 'img': os.path.join(imgdir, 'train2014'), 37 | 'cap': os.path.join(capdir, 'captions_train2014.json') 38 | } 39 | roots['val'] = { 40 | 'img': os.path.join(imgdir, 'val2014'), 41 | 'cap': os.path.join(capdir, 'captions_val2014.json') 42 | } 43 | roots['test'] = { 44 | 'img': os.path.join(imgdir, 'val2014'), 45 | 'cap': os.path.join(capdir, 'captions_val2014.json') 46 | } 47 | roots['trainrestval'] = { 48 | 'img': (roots['train']['img'], roots['val']['img']), 49 | 'cap': (roots['train']['cap'], roots['val']['cap']) 50 | } 51 | ids['train'] = np.load(os.path.join(capdir, 'coco_train_ids.npy')) 52 | ids['val'] = np.load(os.path.join(capdir, 'coco_dev_ids.npy'))[:5000] 53 | ids['test'] = np.load(os.path.join(capdir, 'coco_test_ids.npy')) 54 | ids['trainrestval'] = ( 55 | ids['train'], 56 | np.load(os.path.join(capdir, 'coco_restval_ids.npy'))) 57 | if use_restval: 58 | roots['train'] = roots['trainrestval'] 59 | ids['train'] = ids['trainrestval'] 60 | elif 'f8k' == name: 61 | imgdir = os.path.join(path, 'images') 62 | cap = os.path.join(path, 'dataset_flickr8k.json') 63 | roots['train'] = {'img': imgdir, 'cap': cap} 64 | roots['val'] = {'img': imgdir, 'cap': cap} 65 | roots['test'] = {'img': imgdir, 'cap': cap} 66 | ids = {'train': None, 'val': None, 'test': None} 67 | elif 'f30k' == name: 68 | imgdir = os.path.join(path, 'images') 69 | cap = os.path.join(path, 'dataset_flickr30k.json') 70 | roots['train'] = {'img': imgdir, 'cap': cap} 71 | roots['val'] = {'img': imgdir, 'cap': cap} 72 | roots['test'] = {'img': imgdir, 'cap': cap} 73 | ids = {'train': None, 'val': None, 'test': None} 74 | 75 | return roots, ids 76 | 77 | 78 | class CocoDataset(data.Dataset): 79 | """COCO Custom Dataset compatible with torch.utils.data.DataLoader.""" 80 | 81 | def __init__(self, root, json, vocab, transform=None, ids=None): 82 | """ 83 | Args: 84 | root: image directory. 85 | json: coco annotation file path. 86 | vocab: vocabulary wrapper. 87 | transform: transformer for image. 88 | """ 89 | self.root = root 90 | # when using `restval`, two json files are needed 91 | if isinstance(json, tuple): 92 | self.coco = (COCO(json[0]), COCO(json[1])) 93 | else: 94 | self.coco = (COCO(json),) 95 | self.root = (root,) 96 | # if ids provided by get_paths, use split-specific ids 97 | if ids is None: 98 | self.ids = list(self.coco.anns.keys()) 99 | else: 100 | self.ids = ids 101 | 102 | # if `restval` data is to be used, record the break point for ids 103 | if isinstance(self.ids, tuple): 104 | self.bp = len(self.ids[0]) 105 | self.ids = list(self.ids[0]) + list(self.ids[1]) 106 | else: 107 | self.bp = len(self.ids) 108 | self.vocab = vocab 109 | self.transform = transform 110 | 111 | def __getitem__(self, index): 112 | """This function returns a tuple that is further passed to collate_fn 113 | """ 114 | vocab = self.vocab 115 | root, caption, img_id, path, image = self.get_raw_item(index) 116 | 117 | if self.transform is not None: 118 | image = self.transform(image) 119 | 120 | # Convert caption (string) to word ids. 
121 | tokens = nltk.tokenize.word_tokenize( 122 | str(caption).lower()) 123 | caption = [] 124 | caption.append(vocab('<start>')) 125 | caption.extend([vocab(token) for token in tokens]) 126 | caption.append(vocab('<end>')) 127 | target = torch.Tensor(caption) 128 | return image, target, index, img_id, index 129 | 130 | def get_raw_item(self, index): 131 | if index < self.bp: 132 | coco = self.coco[0] 133 | root = self.root[0] 134 | else: 135 | coco = self.coco[1] 136 | root = self.root[1] 137 | ann_id = self.ids[index] 138 | caption = coco.anns[ann_id]['caption'] 139 | img_id = coco.anns[ann_id]['image_id'] 140 | path = coco.loadImgs(img_id)[0]['file_name'] 141 | image = Image.open(os.path.join(root, path)).convert('RGB') 142 | 143 | return root, caption, img_id, path, image 144 | 145 | def __len__(self): 146 | return len(self.ids) 147 | 148 | 149 | class FlickrDataset(data.Dataset): 150 | """ 151 | Dataset loader for Flickr30k and Flickr8k full datasets. 152 | """ 153 | 154 | def __init__(self, root, json, split, vocab, transform=None): 155 | self.root = root 156 | self.vocab = vocab 157 | self.split = split 158 | self.transform = transform 159 | self.dataset = jsonmod.load(open(json, 'r'))['images'] 160 | self.ids = [] 161 | for i, d in enumerate(self.dataset): 162 | if d['split'] == split: 163 | self.ids += [(i, x) for x in range(len(d['sentences']))] 164 | 165 | def __getitem__(self, index): 166 | """This function returns a tuple that is further passed to collate_fn 167 | """ 168 | vocab = self.vocab 169 | root = self.root 170 | ann_id = self.ids[index] 171 | img_id = ann_id[0] 172 | caption = self.dataset[img_id]['sentences'][ann_id[1]]['raw'] 173 | path = self.dataset[img_id]['filename'] 174 | 175 | image = Image.open(os.path.join(root, path)).convert('RGB') 176 | if self.transform is not None: 177 | image = self.transform(image) 178 | 179 | # Convert caption (string) to word ids. 180 | tokens = nltk.tokenize.word_tokenize( 181 | str(caption).lower()) 182 | caption = [] 183 | caption.append(vocab('<start>')) 184 | caption.extend([vocab(token) for token in tokens]) 185 | caption.append(vocab('<end>')) 186 | target = torch.Tensor(caption) 187 | return image, target, index, img_id, index 188 | 189 | def __len__(self): 190 | return len(self.ids) 191 | 192 | 193 | class PrecompDataset(data.Dataset): 194 | """ 195 | Load precomputed captions and image features 196 | Possible options: f8k, f30k, coco, 10crop 197 | """ 198 | 199 | def __init__(self, data_path, data_split, vocab): 200 | self.vocab = vocab 201 | loc = data_path + '/' 202 | 203 | # Captions 204 | self.captions = [] 205 | caps_name = loc + '%s_caps.txt' % data_split 206 | with open(caps_name, 'rb') as f: 207 | for line in f: 208 | self.captions.append(line.strip()) 209 | 210 | # Image features 211 | self.images = np.load(loc+'%s_ims.npy' % data_split) 212 | self.length = len(self.captions) 213 | # rkiros data has redundancy in images, we divide by 5, 10crop doesn't 214 | if self.images.shape[0] != self.length: 215 | self.im_div = 5 216 | else: 217 | self.im_div = 1 218 | # the development set for coco is large and so validation would be slow 219 | if data_split == 'dev': 220 | self.length = 5000 221 | 222 | def __getitem__(self, index): 223 | # handle the image redundancy 224 | img_id = index/self.im_div 225 | image = torch.Tensor(self.images[int(img_id)]) 226 | caption = self.captions[index] 227 | vocab = self.vocab 228 | 229 | # Convert caption (string) to word ids.
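# (captions were read verbatim, one per line, from %s_caps.txt in __init__)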
230 | tokens = nltk.tokenize.word_tokenize(str(caption).lower()) 231 | caption = [] 232 | caption.append(vocab('<start>')) 233 | caption.extend([vocab(token) for token in tokens]) 234 | caption.append(vocab('<end>')) 235 | target = torch.Tensor(caption) 236 | return image, target, index, img_id, index 237 | 238 | def __len__(self): 239 | return self.length 240 | 241 | 242 | def collate_fn(data): 243 | """Build mini-batch tensors from a list of (image, caption) tuples. 244 | Args: 245 | data: list of (image, caption) tuple. 246 | - image: torch tensor of shape (3, 256, 256). 247 | - caption: torch tensor of shape (?); variable length. 248 | 249 | Returns: 250 | images: torch tensor of shape (batch_size, 3, 256, 256). 251 | targets: torch tensor of shape (batch_size, padded_length). 252 | lengths: list; valid length for each padded caption. 253 | """ 254 | # Sort a data list by caption length 255 | data.sort(key=lambda x: len(x[1]), reverse=True) 256 | images, captions, ids, img_ids, indices = zip(*data) 257 | 258 | # Merge images (convert tuple of 3D tensor to 4D tensor) 259 | images = torch.stack(images, 0) 260 | 261 | # Merge captions (convert tuple of 1D tensor to 2D tensor) 262 | lengths = [len(cap) for cap in captions] 263 | targets = torch.zeros(len(captions), max(lengths)).long() 264 | for i, cap in enumerate(captions): 265 | end = lengths[i] 266 | targets[i, :end] = cap[:end] 267 | 268 | return images, targets, lengths, ids, indices 269 | 270 | 271 | def get_loader_single(data_name, split, root, json, vocab, transform, 272 | batch_size=100, shuffle=True, 273 | num_workers=2, ids=None, collate_fn=collate_fn): 274 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 275 | if 'coco' in data_name: 276 | # COCO custom dataset 277 | dataset = CocoDataset(root=root, 278 | json=json, 279 | vocab=vocab, 280 | transform=transform, ids=ids) 281 | elif 'f8k' in data_name or 'f30k' in data_name: 282 | dataset = FlickrDataset(root=root, 283 | split=split, 284 | json=json, 285 | vocab=vocab, 286 | transform=transform) 287 | 288 | # Data loader 289 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 290 | batch_size=batch_size, 291 | shuffle=shuffle, 292 | pin_memory=True, 293 | num_workers=num_workers, 294 | collate_fn=collate_fn) 295 | return data_loader 296 | 297 | 298 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100, 299 | shuffle=True, num_workers=2): 300 | """Returns torch.utils.data.DataLoader for the precomputed-feature dataset.""" 301 | dset = PrecompDataset(data_path, data_split, vocab) 302 | 303 | data_loader = torch.utils.data.DataLoader(dataset=dset, 304 | batch_size=batch_size, 305 | shuffle=shuffle, 306 | pin_memory=True, 307 | collate_fn=collate_fn) 308 | return data_loader 309 | 310 | 311 | def get_transform(data_name, split_name, opt): 312 | normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], 313 | std=[0.229, 0.224, 0.225]) 314 | t_list = [] 315 | if split_name == 'train': 316 | t_list = [transforms.RandomResizedCrop(opt.crop_size), 317 | transforms.RandomHorizontalFlip()] 318 | elif split_name == 'val': 319 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 320 | elif split_name == 'test': 321 | t_list = [transforms.Resize(256), transforms.CenterCrop(224)] 322 | 323 | t_end = [transforms.ToTensor(), normalizer] 324 | transform = transforms.Compose(t_list + t_end) 325 | return transform 326 | 327 | 328 | def get_loaders(data_name, vocab, crop_size, batch_size, workers, opt): 329 | dpath = os.path.join(opt.data_path, data_name)
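# Precomputed-feature sets (*_precomp) go through PrecompDataset; otherwise raw images are read and transformed on the fly.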
330 | if opt.data_name.endswith('_precomp'): 331 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt, 332 | batch_size, True, workers) 333 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt, 334 | batch_size, False, workers) 335 | else: 336 | # Build Dataset Loader 337 | roots, ids = get_paths(dpath, data_name, opt.use_restval) 338 | 339 | transform = get_transform(data_name, 'train', opt) 340 | train_loader = get_loader_single(opt.data_name, 'train', 341 | roots['train']['img'], 342 | roots['train']['cap'], 343 | vocab, transform, ids=ids['train'], 344 | batch_size=batch_size, shuffle=True, 345 | num_workers=workers, 346 | collate_fn=collate_fn) 347 | 348 | transform = get_transform(data_name, 'val', opt) 349 | val_loader = get_loader_single(opt.data_name, 'val', 350 | roots['val']['img'], 351 | roots['val']['cap'], 352 | vocab, transform, ids=ids['val'], 353 | batch_size=batch_size, shuffle=False, 354 | num_workers=workers, 355 | collate_fn=collate_fn) 356 | 357 | return train_loader, val_loader 358 | 359 | 360 | def get_test_loader(split_name, data_name, vocab, crop_size, batch_size, 361 | workers, opt): 362 | dpath = os.path.join(opt.data_path, data_name) 363 | if opt.data_name.endswith('_precomp'): 364 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt, 365 | batch_size, False, workers) 366 | else: 367 | # Build Dataset Loader 368 | roots, ids = get_paths(dpath, data_name, opt.use_restval) 369 | 370 | transform = get_transform(data_name, split_name, opt) 371 | test_loader = get_loader_single(opt.data_name, split_name, 372 | roots[split_name]['img'], 373 | roots[split_name]['cap'], 374 | vocab, transform, ids=ids[split_name], 375 | batch_size=batch_size, shuffle=False, 376 | num_workers=workers, 377 | collate_fn=collate_fn) 378 | 379 | return test_loader 380 | -------------------------------------------------------------------------------- /data/vocab/10crop_precomp_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/10crop_precomp_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/coco_precomp_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/coco_precomp_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/coco_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/coco_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/f30k_precomp_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f30k_precomp_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/f30k_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f30k_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/f8k_precomp_vocab.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f8k_precomp_vocab.pkl -------------------------------------------------------------------------------- /data/vocab/f8k_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardyqr/HAL/5a33289255c7807d35c22a54e7cebed8ed2ee77b/data/vocab/f8k_vocab.pkl -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import pickle 4 | 5 | import numpy 6 | from data import get_test_loader 7 | import time 8 | import numpy as np 9 | from vocab import Vocabulary # NOQA 10 | import torch 11 | from model import VSE, order_sim 12 | from collections import OrderedDict 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.count = 0 25 | 26 | def update(self, val, n=0): 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / (.0001 + self.count) 31 | 32 | def __str__(self): 33 | """String representation for logging 34 | """ 35 | # for values that should be recorded exactly e.g. iteration number 36 | if self.count == 0: 37 | return str(self.val) 38 | # for stats 39 | return '%.4f (%.4f)' % (self.val, self.avg) 40 | 41 | 42 | class LogCollector(object): 43 | """A collection of logging objects that can change from train to val""" 44 | 45 | def __init__(self): 46 | # to keep the order of logged variables deterministic 47 | self.meters = OrderedDict() 48 | 49 | def update(self, k, v, n=0): 50 | # create a new meter if previously not recorded 51 | if k not in self.meters: 52 | self.meters[k] = AverageMeter() 53 | self.meters[k].update(v, n) 54 | 55 | def __str__(self): 56 | """Concatenate the meters in one log line 57 | """ 58 | s = '' 59 | for i, (k, v) in enumerate(self.meters.items()): 60 | if i > 0: 61 | s += ' ' 62 | s += k + ' ' + str(v) 63 | return s 64 | 65 | def tb_log(self, tb_logger, prefix='', step=None): 66 | """Log using tensorboard 67 | """ 68 | for k, v in self.meters.items(): 69 | tb_logger.log_value(prefix + k, v.val, step=step) 70 | 71 | 72 | def encode_data(model, data_loader, log_step=10, logging=print): 73 | """Encode all images and captions loadable by `data_loader` 74 | """ 75 | batch_time = AverageMeter() 76 | val_logger = LogCollector() 77 | 78 | # switch to evaluate mode 79 | model.val_start() 80 | 81 | end = time.time() 82 | 83 | # numpy array to keep all the embeddings 84 | img_embs = None 85 | cap_embs = None 86 | for i, (images, captions, lengths, ids, indices) in enumerate(data_loader): 87 | # make sure val logger is used 88 | model.logger = val_logger 89 | 90 | # compute the embeddings 91 | img_emb, cap_emb = model.forward_emb(images, captions, lengths, 92 | volatile=True) 93 | 94 | # initialize the numpy arrays given the size of the embeddings 95 | if img_embs is None: 96 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 97 | cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1))) 98 | 99 | # preserve the embeddings by copying from gpu and converting to numpy 100 | img_embs[ids] = img_emb.data.cpu().numpy().copy() 
101 | cap_embs[ids] = cap_emb.data.cpu().numpy().copy() 102 | 103 | # measure accuracy and record loss 104 | model.forward_loss(img_emb, cap_emb, indices) 105 | 106 | # measure elapsed time 107 | batch_time.update(time.time() - end) 108 | end = time.time() 109 | 110 | if i % log_step == 0: 111 | logging('Test: [{0}/{1}]\t' 112 | '{e_log}\t' 113 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 114 | .format( 115 | i, len(data_loader), batch_time=batch_time, 116 | e_log=str(model.logger))) 117 | del images, captions 118 | 119 | return img_embs, cap_embs 120 | 121 | 122 | def evalrank(opt_eval, split): 123 | """ 124 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 125 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 126 | used for evaluation. 127 | """ 128 | # load model and options 129 | checkpoint = torch.load(opt_eval.model_path) 130 | opt = checkpoint['opt'] 131 | if opt_eval.data_path is not None: 132 | opt.data_path = opt_eval.data_path 133 | print(opt) 134 | # load vocabulary used by the model 135 | with open(os.path.join(opt.vocab_path, 136 | '%s_vocab.pkl' % opt.data_name), 'rb') as f: 137 | vocab = pickle.load(f) 138 | opt.vocab_size = len(vocab) 139 | 140 | # construct model 141 | model = VSE(opt) 142 | 143 | # load model state 144 | model.load_state_dict(checkpoint['model']) 145 | 146 | print('Loading dataset') 147 | data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size, 148 | opt.batch_size, opt.workers, opt) 149 | 150 | print('Computing results...') 151 | img_embs, cap_embs = encode_data(model, data_loader) 152 | print('Images: %d, Captions: %d' % 153 | (img_embs.shape[0] / 5, cap_embs.shape[0])) 154 | 155 | if opt_eval.save_embeddings: 156 | save_path = opt_eval.model_path.split('/')[-2]+'_img_and_cap_embeddings.pth.tar' 157 | with open(save_path+'.pkl', 'wb') as handle: 158 | pickle.dump({'img_embs':img_embs, 'cap_embs':cap_embs}, 159 | handle, protocol=pickle.HIGHEST_PROTOCOL) 160 | print ("[embeddings saved to {}]".format(save_path)) 161 | 162 | if not opt_eval.fold5: 163 | # no cross-validation, full evaluation 164 | r, rt = i2t(img_embs, cap_embs, measure=opt.measure, return_ranks=True) 165 | ri, rti = t2i(img_embs, cap_embs, 166 | measure=opt.measure, return_ranks=True) 167 | ar = (r[0] + r[1] + r[2]) / 3 168 | ari = (ri[0] + ri[1] + ri[2]) / 3 169 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 170 | print("rsum: %.1f" % rsum) 171 | print("Average i2t Recall: %.1f" % ar) 172 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 173 | print("Average t2i Recall: %.1f" % ari) 174 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 175 | 176 | if len(opt_eval.save_csv) > 0: 177 | with open(opt_eval.save_csv, "a") as f: 178 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % r 179 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % ri 180 | rsum_data = ", %.1f" % rsum 181 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n" 182 | f.write(row_data) 183 | else: 184 | # 5fold cross-validation, only for MSCOCO 185 | results = [] 186 | for i in range(5): 187 | r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000], 188 | cap_embs[i * 5000:(i + 1) * 189 | 5000], measure=opt.measure, 190 | return_ranks=True) 191 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 192 | ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000], 193 | cap_embs[i * 5000:(i + 1) * 194 | 5000], measure=opt.measure, 195 | return_ranks=True) 196 | if i == 0: 197 | rt, rti = rt0, rti0 198 | print("Text to image: %.1f, %.1f, 
%.1f, %.1f, %.1f" % ri) 199 | ar = (r[0] + r[1] + r[2]) / 3 200 | ari = (ri[0] + ri[1] + ri[2]) / 3 201 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 202 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 203 | results += [list(r) + list(ri) + [ar, ari, rsum]] 204 | 205 | if i == 0 and len(opt_eval.save_csv) > 0: 206 | with open(opt_eval.save_csv + "_fold1", "a") as f: 207 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % r 208 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % ri 209 | rsum_data = ", %.1f" % rsum 210 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n" 211 | f.write(row_data) 212 | 213 | print("-----------------------------------") 214 | print("Mean metrics: ") 215 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 216 | print("rsum: %.1f" % (mean_metrics[12])) 217 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 218 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 219 | mean_metrics[:5]) 220 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 221 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 222 | mean_metrics[5:10]) 223 | 224 | if len(opt_eval.save_csv) > 0: 225 | with open(opt_eval.save_csv + "_fold5", "a") as f: 226 | i2t_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % mean_metrics[:5] 227 | t2i_data = ", %.1f, %.1f, %.1f, %.1f, %.1f" % mean_metrics[5:10] 228 | rsum_data = ", %.1f" % mean_metrics[12] 229 | row_data = opt.logger_name + i2t_data + t2i_data + rsum_data + "\n" 230 | f.write(row_data) 231 | 232 | torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') 233 | 234 | def i2t(images, captions, npts=None, measure='cosine', 235 | return_ranks=False): 236 | """ 237 | Images->Text (Image Annotation) 238 | Images: (5N, K) matrix of images 239 | Captions: (5N, K) matrix of captions 240 | """ 241 | if npts is None: 242 | npts = int(images.shape[0] / 5) 243 | #print(npts) 244 | index_list = [] 245 | 246 | scores = images.dot(captions.T) 247 | 248 | ranks = numpy.zeros(npts) 249 | top1 = numpy.zeros(npts) 250 | for index in range(npts): 251 | 252 | # Get query image 253 | im = images[5 * index].reshape(1, images.shape[1]) 254 | 255 | # Compute scores 256 | if measure == 'order': 257 | bs = 100 258 | if index % bs == 0: 259 | mx = min(images.shape[0], 5 * (index + bs)) 260 | im2 = images[5 * index:mx:5] 261 | d2 = order_sim(torch.Tensor(im2), 262 | torch.Tensor(captions)) 263 | d2 = d2.cpu().numpy() 264 | d = d2[index % bs] 265 | else: 266 | d = scores[5 * index] 267 | inds = numpy.argsort(d)[::-1] 268 | index_list.append(inds[0]) 269 | 270 | # Score 271 | rank = 1e20 272 | for i in range(5 * index, 5 * index + 5, 1): 273 | tmp = numpy.where(inds == i)[0][0] 274 | if tmp < rank: 275 | rank = tmp 276 | ranks[index] = rank 277 | top1[index] = inds[0] 278 | 279 | # Compute metrics 280 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 281 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 282 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 283 | medr = numpy.floor(numpy.median(ranks)) + 1 284 | meanr = ranks.mean() + 1 285 | if return_ranks: 286 | return (r1, r5, r10, medr, meanr), (ranks, top1) 287 | else: 288 | return (r1, r5, r10, medr, meanr) 289 | 290 | 291 | def t2i(images, captions, npts=None, measure='cosine', 292 | return_ranks=False): 293 | """ 294 | Text->Images (Image Search) 295 | Images: (5N, K) matrix of images 296 | Captions: (5N, K) matrix of captions 297 | """ 298 | if npts is None: 299 | npts = int(images.shape[0] / 5) 300 | #print("# points:", npts) 301 | ims = 
numpy.array([images[i] for i in range(0, len(images), 5)]) 302 | 303 | scores = captions.dot(ims.T) 304 | 305 | ranks = numpy.zeros(5 * npts) 306 | top1 = numpy.zeros(5 * npts) 307 | for index in range(npts): 308 | 309 | # Compute scores 310 | if measure == 'order': 311 | bs = 100 312 | if 5 * index % bs == 0: 313 | mx = min(captions.shape[0], 5 * index + bs) 314 | q2 = captions[5 * index:mx] 315 | d2 = order_sim(torch.Tensor(ims), 316 | torch.Tensor(q2)) 317 | d2 = d2.cpu().numpy() 318 | 319 | d = d2[:, (5 * index) % bs:(5 * index) % bs + 5].T 320 | else: 321 | d = scores[5 * index:5 * index + 5] 322 | inds = numpy.zeros(d.shape) 323 | for i in range(len(inds)): 324 | inds[i] = numpy.argsort(d[i])[::-1] 325 | ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0] 326 | top1[5 * index + i] = inds[i][0] 327 | 328 | # Compute metrics 329 | r1 = 100.0 * len(numpy.where(ranks < 1)[0]) / len(ranks) 330 | r5 = 100.0 * len(numpy.where(ranks < 5)[0]) / len(ranks) 331 | r10 = 100.0 * len(numpy.where(ranks < 10)[0]) / len(ranks) 332 | medr = numpy.floor(numpy.median(ranks)) + 1 333 | meanr = ranks.mean() + 1 334 | if return_ranks: 335 | return (r1, r5, r10, medr, meanr), (ranks, top1) 336 | else: 337 | return (r1, r5, r10, medr, meanr) 338 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torch.nn.init 5 | import torchvision.models as models 6 | from torch.autograd import Variable 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | import torch.backends.cudnn as cudnn 9 | from torch.nn.utils.clip_grad import clip_grad_norm_ 10 | import numpy as np 11 | from collections import OrderedDict 12 | from random import randint 13 | 14 | def l2norm(X): 15 | """L2-normalize columns of X 16 | """ 17 | norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() 18 | X = torch.div(X, norm) 19 | return X 20 | 21 | 22 | def EncoderImage(data_name, img_dim, embed_size, finetune=False, 23 | cnn_type='vgg19', use_abs=False, no_imgnorm=False): 24 | """A wrapper to image encoders. Chooses between an encoder that uses 25 | precomputed image features, `EncoderImagePrecomp`, or an encoder that 26 | computes image features on the fly `EncoderImageFull`. 27 | """ 28 | if data_name.endswith('_precomp'): 29 | img_enc = EncoderImagePrecomp( 30 | img_dim, embed_size, use_abs, no_imgnorm) 31 | else: 32 | img_enc = EncoderImageFull( 33 | embed_size, finetune, cnn_type, use_abs, no_imgnorm) 34 | 35 | return img_enc 36 | 37 | 38 | # tutorials/09 - Image Captioning 39 | class EncoderImageFull(nn.Module): 40 | 41 | def __init__(self, embed_size, finetune=False, cnn_type='vgg19', 42 | use_abs=False, no_imgnorm=False): 43 | """Load pretrained VGG19 and replace top fc layer.""" 44 | super(EncoderImageFull, self).__init__() 45 | self.embed_size = embed_size 46 | self.no_imgnorm = no_imgnorm 47 | self.use_abs = use_abs 48 | 49 | # Load a pre-trained model 50 | self.cnn = self.get_cnn(cnn_type, True) 51 | 52 | # For efficient memory usage. 
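# Freeze the CNN backbone unless fine-tuning is requested (requires_grad simply follows the `finetune` flag).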
53 | for param in self.cnn.parameters(): 54 | param.requires_grad = finetune 55 | 56 | # Replace the last fully connected layer of CNN with a new one 57 | if cnn_type.startswith('vgg'): 58 | self.fc = nn.Linear(self.cnn.classifier._modules['6'].in_features, 59 | embed_size) 60 | self.cnn.classifier = nn.Sequential( 61 | *list(self.cnn.classifier.children())[:-1]) 62 | elif cnn_type.startswith('resnet'): 63 | self.fc = nn.Linear(self.cnn.module.fc.in_features, embed_size) 64 | self.cnn.module.fc = nn.Sequential() 65 | 66 | self.init_weights() 67 | 68 | def get_cnn(self, arch, pretrained): 69 | """Load a pretrained CNN and parallelize over GPUs 70 | """ 71 | if pretrained: 72 | print("=> using pre-trained model '{}'".format(arch)) 73 | model = models.__dict__[arch](pretrained=True) 74 | else: 75 | print("=> creating model '{}'".format(arch)) 76 | model = models.__dict__[arch]() 77 | 78 | if arch.startswith('alexnet') or arch.startswith('vgg'): 79 | model.features = nn.DataParallel(model.features) 80 | model.cuda() 81 | else: 82 | model = nn.DataParallel(model).cuda() 83 | 84 | return model 85 | 86 | def load_state_dict(self, state_dict): 87 | """ 88 | Handle the models saved before commit pytorch/vision@989d52a 89 | """ 90 | if 'cnn.classifier.1.weight' in state_dict: 91 | state_dict['cnn.classifier.0.weight'] = state_dict[ 92 | 'cnn.classifier.1.weight'] 93 | del state_dict['cnn.classifier.1.weight'] 94 | state_dict['cnn.classifier.0.bias'] = state_dict[ 95 | 'cnn.classifier.1.bias'] 96 | del state_dict['cnn.classifier.1.bias'] 97 | state_dict['cnn.classifier.3.weight'] = state_dict[ 98 | 'cnn.classifier.4.weight'] 99 | del state_dict['cnn.classifier.4.weight'] 100 | state_dict['cnn.classifier.3.bias'] = state_dict[ 101 | 'cnn.classifier.4.bias'] 102 | del state_dict['cnn.classifier.4.bias'] 103 | 104 | super(EncoderImageFull, self).load_state_dict(state_dict) 105 | 106 | def init_weights(self): 107 | """Xavier initialization for the fully connected layer 108 | """ 109 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 110 | self.fc.out_features) 111 | self.fc.weight.data.uniform_(-r, r) 112 | self.fc.bias.data.fill_(0) 113 | 114 | def forward(self, images): 115 | """Extract image feature vectors.""" 116 | features = self.cnn(images) 117 | 118 | # normalization in the image embedding space 119 | features = l2norm(features) 120 | 121 | # linear projection to the joint embedding space 122 | features = self.fc(features) 123 | 124 | # normalization in the joint embedding space 125 | if not self.no_imgnorm: 126 | features = l2norm(features) 127 | 128 | # take the absolute value of the embedding (used in order embeddings) 129 | if self.use_abs: 130 | features = torch.abs(features) 131 | 132 | return features 133 | 134 | 135 | class EncoderImagePrecomp(nn.Module): 136 | 137 | def __init__(self, img_dim, embed_size, use_abs=False, no_imgnorm=False): 138 | super(EncoderImagePrecomp, self).__init__() 139 | self.embed_size = embed_size 140 | self.no_imgnorm = no_imgnorm 141 | self.use_abs = use_abs 142 | 143 | self.fc = nn.Linear(img_dim, embed_size) 144 | 145 | self.init_weights() 146 | 147 | def init_weights(self): 148 | """Xavier initialization for the fully connected layer 149 | """ 150 | r = np.sqrt(6.) 
/ np.sqrt(self.fc.in_features + 151 | self.fc.out_features) 152 | self.fc.weight.data.uniform_(-r, r) 153 | self.fc.bias.data.fill_(0) 154 | 155 | def forward(self, images): 156 | """Extract image feature vectors.""" 157 | # assuming that the precomputed features are already l2-normalized 158 | 159 | features = self.fc(images) 160 | 161 | # normalize in the joint embedding space 162 | if not self.no_imgnorm: 163 | features = l2norm(features) 164 | 165 | # take the absolute value of embedding (used in order embeddings) 166 | if self.use_abs: 167 | features = torch.abs(features) 168 | 169 | return features 170 | 171 | def load_state_dict(self, state_dict): 172 | """Copies parameters. overwritting the default one to 173 | accept state_dict from Full model 174 | """ 175 | own_state = self.state_dict() 176 | new_state = OrderedDict() 177 | for name, param in state_dict.items(): 178 | if name in own_state: 179 | new_state[name] = param 180 | 181 | super(EncoderImagePrecomp, self).load_state_dict(new_state) 182 | 183 | 184 | # tutorials/08 - Language Model 185 | # RNN Based Language Model 186 | class EncoderText(nn.Module): 187 | 188 | def __init__(self, vocab_size, word_dim, embed_size, num_layers, 189 | use_abs=False): 190 | super(EncoderText, self).__init__() 191 | self.use_abs = use_abs 192 | self.embed_size = embed_size 193 | 194 | # word embedding 195 | self.embed = nn.Embedding(vocab_size, word_dim) 196 | 197 | # caption embedding 198 | self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True) 199 | 200 | self.init_weights() 201 | 202 | def init_weights(self): 203 | self.embed.weight.data.uniform_(-0.1, 0.1) 204 | 205 | def forward(self, x, lengths): 206 | """Handles variable size captions 207 | """ 208 | # Embed word ids to vectors 209 | x = self.embed(x) 210 | packed = pack_padded_sequence(x, lengths, batch_first=True) 211 | 212 | # Forward propagate RNN 213 | out, _ = self.rnn(packed) 214 | 215 | # Reshape *final* output to (batch_size, hidden_size) 216 | padded = pad_packed_sequence(out, batch_first=True) 217 | I = torch.LongTensor(lengths).view(-1, 1, 1) 218 | I = Variable(I.expand(x.size(0), 1, self.embed_size)-1).cuda() 219 | out = torch.gather(padded[0], 1, I).squeeze(1) 220 | 221 | # normalization in the joint embedding space 222 | out = l2norm(out) 223 | 224 | # take absolute value, used by order embeddings 225 | if self.use_abs: 226 | out = torch.abs(out) 227 | 228 | return out 229 | 230 | 231 | def cosine_sim(im, s): 232 | """Cosine similarity between all the image and sentence pairs 233 | """ 234 | return im.mm(s.t()) 235 | 236 | 237 | def order_sim(im, s): 238 | """Order embeddings similarity measure $max(0, s-im)$ 239 | """ 240 | YmX = (s.unsqueeze(1).expand(s.size(0), im.size(0), s.size(1)) 241 | - im.unsqueeze(0).expand(s.size(0), im.size(0), s.size(1))) 242 | score = -YmX.clamp(min=0).pow(2).sum(2).sqrt().t() 243 | return score 244 | 245 | 246 | class ContrastiveLoss(nn.Module): 247 | """ 248 | Compute contrastive loss 249 | """ 250 | 251 | def __init__(self, opt): 252 | super(ContrastiveLoss, self).__init__() 253 | 254 | if opt.measure == 'order': 255 | self.sim = order_sim 256 | else: 257 | self.sim = cosine_sim 258 | 259 | self.opt = opt 260 | 261 | # "g" represents "global" 262 | self.g_alpha = self.opt.global_alpha 263 | self.g_beta= self.opt.global_beta # W_it 264 | self.g_ep_posi = self.opt.global_ep_posi # W_ii 265 | self.g_ep_nega = self.opt.global_ep_nega 266 | 267 | # "l" represents "local" 268 | self.l_alpha = self.opt.local_alpha 269 | self.l_ep = 
self.opt.local_ep 270 | 271 | def forward(self, im, s, mb_img, mb_cap, mb_ind, indices): 272 | 273 | bsize = im.size()[0] 274 | 275 | scores = self.sim(im, s) 276 | 277 | if self.opt.max_violation or self.opt.sum_violation: 278 | 279 | diagonal = scores.diag().view(bsize, 1) 280 | d1 = diagonal.expand_as(scores) 281 | d2 = diagonal.t().expand_as(scores) 282 | 283 | cost_s = (self.opt.margin + scores - d1).clamp(min=0) 284 | cost_im = (self.opt.margin + scores - d2).clamp(min=0) 285 | 286 | mask = torch.eye(bsize) > .5 287 | I = Variable(mask) 288 | if torch.cuda.is_available(): 289 | I = I.cuda() 290 | cost_s = cost_s.masked_fill_(I, 0) 291 | cost_im = cost_im.masked_fill_(I, 0) 292 | 293 | if self.opt.max_violation: 294 | 295 | cost_s = cost_s.max(1)[0] 296 | cost_im = cost_im.max(0)[0] 297 | 298 | return cost_s.sum() + cost_im.sum() 299 | 300 | tmp = torch.eye(bsize).cuda() 301 | 302 | s_diag = tmp * scores 303 | scores_ = scores - s_diag 304 | 305 | if mb_img is not None: 306 | 307 | #negative 308 | mb_k = self.opt.mb_k 309 | if im.size()[0] < mb_k: mb_k = bsize 310 | 311 | used_ind = torch.tensor([0 if i in indices else 1 for i in mb_ind]).bool().cuda() 312 | 313 | mb_img = mb_img[used_ind] 314 | mb_cap = mb_cap[used_ind] 315 | 316 | scores_img_glob = self.sim(im, mb_cap) 317 | i2t_k_avg = torch.exp(self.g_beta * torch.topk(scores_img_glob, mb_k)[0] - self.g_ep_nega).sum(1).reshape((bsize,1)) 318 | i2t_k_avg_positive = torch.exp(self.g_alpha * (torch.topk(scores_img_glob, mb_k)[0] - self.g_ep_posi)).sum(1) 319 | 320 | scores_cap_glob = self.sim(s, mb_img) 321 | t2i_k_avg = torch.exp(self.g_beta * torch.topk(scores_cap_glob, mb_k)[0] - self.g_ep_nega).sum(1).reshape((1,bsize)) 322 | t2i_k_avg_positive = torch.exp(self.g_alpha * (torch.topk(scores_cap_glob, mb_k)[0] - self.g_ep_posi)).sum(1) 323 | 324 | tmp_i2t = i2t_k_avg.repeat(1, bsize) 325 | tmp_t2i = t2i_k_avg.repeat(bsize, 1) 326 | 327 | exp_sii = torch.exp(self.g_beta * s_diag.sum(0)) 328 | tmp_expii = exp_sii.reshape((bsize,1)).repeat(1, bsize) 329 | tmp_exptt = exp_sii.reshape((1,bsize)).repeat(bsize, 1) 330 | 331 | wit = (tmp_i2t + tmp_t2i) / (tmp_i2t + tmp_t2i + tmp_expii + tmp_exptt) 332 | 333 | #positive 334 | exp_sii = torch.exp(self.g_alpha * (s_diag.sum(0) - self.g_ep_posi)) 335 | 336 | wii = 1 - exp_sii / (exp_sii + i2t_k_avg_positive + t2i_k_avg_positive) 337 | 338 | wit = wit - wit * tmp 339 | 340 | S_ = torch.exp(self.l_alpha * wit.detach() * (scores_ - self.l_ep)) 341 | 342 | loss_diag = - torch.log(1 + F.relu((s_diag.sum(0) * wii.detach()))) 343 | 344 | else: 345 | 346 | S_ = torch.exp(self.l_alpha * (scores_ - self.l_ep)) 347 | 348 | loss_diag = - torch.log(1 + F.relu(s_diag.sum(0))) 349 | 350 | loss = torch.sum( 351 | torch.log(1 + S_.sum(0)) / self.l_alpha \ 352 | + torch.log(1 + S_.sum(1)) / self.l_alpha \ 353 | + loss_diag 354 | ) / bsize 355 | 356 | return loss 357 | 358 | 359 | class VSE(object): 360 | """ 361 | rkiros/uvs model 362 | """ 363 | 364 | def __init__(self, opt): 365 | # tutorials/09 - Image Captioning 366 | # Build Models 367 | self.grad_clip = opt.grad_clip 368 | self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size, 369 | opt.finetune, opt.cnn_type, 370 | use_abs=opt.use_abs, 371 | no_imgnorm=opt.no_imgnorm) 372 | self.txt_enc = EncoderText(opt.vocab_size, opt.word_dim, 373 | opt.embed_size, opt.num_layers, 374 | use_abs=opt.use_abs) 375 | if torch.cuda.is_available(): 376 | self.img_enc.cuda() 377 | self.txt_enc.cuda() 378 | cudnn.benchmark = True 379 | 380 | # memory bank 381 
| self.mb_img = None 382 | self.mb_cap = None 383 | self.mb_ind = None 384 | 385 | # Loss and Optimizer 386 | self.criterion = ContrastiveLoss(opt=opt) 387 | params = list(self.txt_enc.parameters()) 388 | params += list(self.img_enc.fc.parameters()) 389 | if opt.finetune: 390 | params += list(self.img_enc.cnn.parameters()) 391 | self.params = params 392 | 393 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 394 | 395 | self.Eiters = 0 396 | 397 | def state_dict(self): 398 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict()] 399 | return state_dict 400 | 401 | def load_state_dict(self, state_dict): 402 | self.img_enc.load_state_dict(state_dict[0]) 403 | self.txt_enc.load_state_dict(state_dict[1]) 404 | 405 | def train_start(self): 406 | """switch to train mode 407 | """ 408 | self.img_enc.train() 409 | self.txt_enc.train() 410 | 411 | def val_start(self): 412 | """switch to evaluate mode 413 | """ 414 | self.img_enc.eval() 415 | self.txt_enc.eval() 416 | 417 | def forward_emb(self, images, captions, lengths, volatile=False,**kwargs): 418 | """Compute the image and caption embeddings 419 | """ 420 | # Set mini-batch dataset 421 | if volatile: 422 | with torch.no_grad(): 423 | images = Variable(images) 424 | captions = Variable(captions) 425 | else: 426 | images = Variable(images) 427 | captions = Variable(captions) 428 | 429 | if torch.cuda.is_available(): 430 | images = images.cuda() 431 | captions = captions.cuda() 432 | 433 | # Forward 434 | img_emb = self.img_enc(images) 435 | cap_emb = self.txt_enc(captions, lengths) 436 | return img_emb, cap_emb 437 | 438 | def forward_loss(self, img_emb, cap_emb, indices, **kwargs): 439 | """Compute the loss given pairs of image and caption embeddings 440 | """ 441 | loss = self.criterion( 442 | img_emb, 443 | cap_emb, 444 | self.mb_img, 445 | self.mb_cap, 446 | self.mb_ind, 447 | indices) 448 | self.logger.update('Loss', loss.item(), img_emb.size(0)) 449 | return loss 450 | 451 | def train_emb(self, images, captions, lengths, ids, indices, *args): 452 | """One training step given images and captions. 453 | """ 454 | self.Eiters += 1 455 | self.logger.update('Eit', self.Eiters) 456 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 457 | 458 | # compute the embeddings 459 | img_emb, cap_emb = self.forward_emb(images, captions, lengths) 460 | 461 | # measure accuracy and record loss 462 | self.optimizer.zero_grad() 463 | loss = self.forward_loss(img_emb, cap_emb, indices) 464 | 465 | # compute gradient and do SGD step 466 | loss.backward() 467 | if self.grad_clip > 0: 468 | clip_grad_norm_(self.params, self.grad_clip) 469 | self.optimizer.step() 470 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 
2 | 3 | # HAL/data.py: 6 4 | Pillow == 9.3.0 5 | 6 | # HAL/data.py: 5 7 | # HAL/vocab.py: 2 8 | nltk == 3.6.6 9 | 10 | # HAL/data.py: 8 11 | # HAL/evaluation.py: 5,8 12 | # HAL/model.py: 10 13 | numpy == 1.22.0 14 | 15 | # HAL/data.py: 7 16 | # HAL/vocab.py: 5 17 | pycocotools-fix == 2.0.0.9 18 | 19 | # HAL/data.py: 7 20 | # HAL/vocab.py: 5 21 | pycocotools-win == 2.0 22 | 23 | # HAL/train.py: 9 24 | tensorboard_logger == 0.1.0 25 | 26 | # HAL/data.py: 1,2 27 | # HAL/evaluation.py: 10 28 | # HAL/model.py: 7,8,9 29 | # HAL/train.py: 7 30 | torch == 1.13.1 31 | 32 | # HAL/data.py: 3 33 | # HAL/model.py: 5 34 | torchvision == 0.6.0a0+35d732a 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import time 4 | import shutil 5 | from random import random 6 | import argparse 7 | import torch 8 | import logging 9 | import tensorboard_logger as tb_logger 10 | 11 | import data 12 | from vocab import Vocabulary # NOQA 13 | from model import VSE 14 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data 15 | 16 | def main(): 17 | # Hyper Parameters 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--data_path', default='/w/31/faghri/vsepp_data/', 20 | help='path to datasets') 21 | parser.add_argument('--data_name', default='precomp', 22 | help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k') 23 | parser.add_argument('--vocab_path', default='./vocab/', 24 | help='Path to saved vocabulary pickle files.') 25 | parser.add_argument('--margin', default=0.2, type=float, 26 | help='Rank loss margin.') 27 | parser.add_argument('--num_epochs', default=15, type=int, 28 | help='Number of training epochs.') 29 | parser.add_argument('--batch_size', default=128, type=int, 30 | help='Size of a training mini-batch.') 31 | parser.add_argument('--word_dim', default=300, type=int, 32 | help='Dimensionality of the word embedding.') 33 | parser.add_argument('--embed_size', default=1024, type=int, 34 | help='Dimensionality of the joint embedding.') 35 | parser.add_argument('--grad_clip', default=2., type=float, 36 | help='Gradient clipping threshold.') 37 | parser.add_argument('--crop_size', default=224, type=int, 38 | help='Size of an image crop as the CNN input.') 39 | parser.add_argument('--num_layers', default=1, type=int, 40 | help='Number of GRU layers.') 41 | parser.add_argument('--learning_rate', default=.0002, type=float, 42 | help='Initial learning rate.') 43 | parser.add_argument('--lr_update', default=8, type=int, 44 | help='Number of epochs to update the learning rate.') 45 | parser.add_argument('--workers', default=10, type=int, 46 | help='Number of data loader workers.') 47 | parser.add_argument('--log_step', default=10, type=int, 48 | help='Number of steps to print and record the log.') 49 | parser.add_argument('--val_step', default=500, type=int, 50 | help='Number of steps to run validation.') 51 | parser.add_argument('--logger_name', default='runs/runX', 52 | help='Path to save the model and Tensorboard log.') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('--max_violation', action='store_true', 56 | help='Use max instead of sum in the rank loss.') 57 | parser.add_argument('--sum_violation', action='store_true') 58 | parser.add_argument('--img_dim', default=4096, type=int, 59 | help='Dimensionality of the 
image embedding.') 60 | parser.add_argument('--finetune', action='store_true', 61 | help='Fine-tune the image encoder.') 62 | parser.add_argument('--cnn_type', default='vgg19', 63 | help="""The CNN used for image encoder 64 | (e.g. vgg19, resnet152)""") 65 | parser.add_argument('--use_restval', action='store_true', 66 | help='Use the restval data for training on MSCOCO.') 67 | parser.add_argument('--measure', default='cosine', 68 | help='Similarity measure used (cosine|order)') 69 | parser.add_argument('--use_abs', action='store_true', 70 | help='Take the absolute value of embedding vectors.') 71 | parser.add_argument('--no_imgnorm', action='store_true', 72 | help='Do not normalize the image embeddings.') 73 | parser.add_argument('--reset_train', action='store_true', 74 | help='Ensure the training is always done in ' 75 | 'train mode (Not recommended).') 76 | parser.add_argument('--save_all', action='store_true', 77 | help="Save model after the training of each epoch") 78 | parser.add_argument('--memory_bank', action='store_true', 79 | help="Train model with memory bank") 80 | parser.add_argument('--record_val', action='store_true', 81 | help="Record the rsum values on validation set in file during training") 82 | parser.add_argument('--local_alpha', default=30.0, type=float) 83 | parser.add_argument('--local_ep', default=0.3, type=float) 84 | parser.add_argument('--global_alpha', default=40.0, type=float) 85 | parser.add_argument('--global_beta', default=40.0, type=float) 86 | parser.add_argument('--global_ep_posi', default=0.2, type=float, 87 | help="Global epsilon for positive pairs") 88 | parser.add_argument('--global_ep_nega', default=0.1, type=float, 89 | help="Global epsilon for negative pairs") 90 | parser.add_argument('--mb_k', default=250, type=int, 91 | help="Use top K items in memory bank") 92 | parser.add_argument('--mb_rate', default=0.05, type=float, 93 | help="-") 94 | 95 | opt = parser.parse_args() 96 | print(opt) 97 | 98 | logging.basicConfig(format='%(message)s', level=logging.INFO) 99 | tb_logger.configure(opt.logger_name, flush_secs=5) 100 | 101 | # Load Vocabulary Wrapper 102 | vocab = pickle.load(open(os.path.join( 103 | opt.vocab_path, '%s_vocab.pkl' % opt.data_name), 'rb')) 104 | opt.vocab_size = len(vocab) 105 | print("Vocab Size: %d" % opt.vocab_size) 106 | 107 | # Load data loaders 108 | train_loader, val_loader = data.get_loaders( 109 | opt.data_name, vocab, opt.crop_size, opt.batch_size, opt.workers, opt) 110 | 111 | # Construct the model 112 | model = VSE(opt) 113 | 114 | # optionally resume from a checkpoint 115 | if opt.resume: 116 | if os.path.isfile(opt.resume): 117 | print("=> loading checkpoint '{}'".format(opt.resume)) 118 | checkpoint = torch.load(opt.resume) 119 | start_epoch = checkpoint['epoch'] 120 | best_rsum = checkpoint['best_rsum'] 121 | model.load_state_dict(checkpoint['model']) 122 | # Eiters is used to show logs as the continuation of another 123 | # training 124 | model.Eiters = checkpoint['Eiters'] 125 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})" 126 | .format(opt.resume, start_epoch, best_rsum)) 127 | validate(opt, val_loader, model) 128 | else: 129 | print("=> no checkpoint found at '{}'".format(opt.resume)) 130 | 131 | # Train the Model 132 | best_rsum = 0 133 | for epoch in range(opt.num_epochs): 134 | adjust_learning_rate(opt, model.optimizer, epoch) 135 | 136 | memory_bank = opt.memory_bank 137 | if memory_bank and epoch > 0: 138 | load_memory_bank(opt, train_loader, model) 139 | # train for one epoch 140 | 
train(opt, train_loader, model, epoch, val_loader) 141 | 142 | # evaluate on validation set 143 | rsum = validate(opt, val_loader, model) 144 | print ("rsum: %.1f" % rsum) 145 | if opt.record_val: 146 | with open("rst_val_" + opt.logger_name[5:], "a") as f: 147 | f.write("Epoch: %d ; rsum: %.1f\n" %(epoch, rsum)) 148 | 149 | # remember best R@ sum and save checkpoint 150 | is_best = rsum > best_rsum 151 | best_rsum = max(rsum, best_rsum) 152 | save_checkpoint({ 153 | 'epoch': epoch + 1, 154 | 'model': model.state_dict(), 155 | 'best_rsum': best_rsum, 156 | 'opt': opt, 157 | 'Eiters': model.Eiters, 158 | }, is_best, prefix=opt.logger_name + '/', save_all=opt.save_all) 159 | 160 | # reset memory bank 161 | model.mb_img = None 162 | model.mb_cap = None 163 | 164 | def load_memory_bank(opt, train_loader, model): 165 | mb_img, mb_cap, ind = None, None, None 166 | for i, train_data in enumerate(train_loader): 167 | 168 | if (i+1) % 50 == 0: 169 | print ('[ %d / %d memory bank loading randomly...]' % (i+1,len(train_loader))) 170 | 171 | if random() > opt.mb_rate: continue 172 | 173 | model.val_start() 174 | with torch.no_grad(): 175 | imgs, caps, lengths, _, indices = train_data 176 | img_emb, cap_emb = model.forward_emb(imgs, caps, lengths) 177 | 178 | if mb_img is None: 179 | mb_img = img_emb 180 | mb_cap = cap_emb 181 | ind = indices 182 | else: 183 | mb_img = torch.cat((mb_img, img_emb), 0) 184 | mb_cap = torch.cat((mb_cap, cap_emb), 0) 185 | ind = ind + indices 186 | model.mb_img = mb_img 187 | model.mb_cap = mb_cap 188 | model.mb_ind = ind 189 | 190 | print ('[memory bank fully loaded!]') 191 | print ("MB(Image): ", model.mb_img.size()) 192 | print ("MB(Caption): ", model.mb_cap.size()) 193 | print ("indices len:",len(model.mb_ind)) 194 | 195 | def train(opt, train_loader, model, epoch, val_loader): 196 | # average meters to record the training statistics 197 | batch_time = AverageMeter() 198 | data_time = AverageMeter() 199 | train_logger = LogCollector() 200 | 201 | end = time.time() 202 | for i, train_data in enumerate(train_loader): 203 | 204 | model.train_start() 205 | 206 | # measure data loading time 207 | data_time.update(time.time() - end) 208 | 209 | # make sure train logger is used 210 | model.logger = train_logger 211 | 212 | # Update the model 213 | model.train_emb(*train_data) 214 | 215 | # measure elapsed time 216 | batch_time.update(time.time() - end) 217 | end = time.time() 218 | 219 | # Print log info 220 | if model.Eiters % opt.log_step == 0: 221 | logging.info( 222 | 'Epoch: [{0}][{1}/{2}]\t' 223 | '{e_log}\t' 224 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 225 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 226 | .format( 227 | epoch, i, len(train_loader), batch_time=batch_time, 228 | data_time=data_time, e_log=str(model.logger))) 229 | 230 | # Record logs in tensorboard 231 | tb_logger.log_value('epoch', epoch, step=model.Eiters) 232 | tb_logger.log_value('step', i, step=model.Eiters) 233 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters) 234 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters) 235 | model.logger.tb_log(tb_logger, step=model.Eiters) 236 | 237 | # validate at every val_step 238 | if model.Eiters % opt.val_step == 0: 239 | validate(opt, val_loader, model) 240 | 241 | 242 | def validate(opt, val_loader, model): 243 | # compute the encoding for all the validation images and captions 244 | img_embs, cap_embs = encode_data( 245 | model, val_loader, opt.log_step, logging.info) 246 | 247 | # caption retrieval 
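# i2t / t2i return (R@1, R@5, R@10, median rank, mean rank)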
248 |     (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, measure=opt.measure)
249 |     logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" %
250 |                  (r1, r5, r10, medr, meanr))
251 |     # image retrieval
252 |     (r1i, r5i, r10i, medri, meanri) = t2i(
253 |         img_embs, cap_embs, measure=opt.measure)
254 |     logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" %
255 |                  (r1i, r5i, r10i, medri, meanri))
256 |     # sum of recalls to be used for early stopping
257 |     currscore = r1 + r5 + r10 + r1i + r5i + r10i
258 |
259 |     # record metrics in tensorboard
260 |     tb_logger.log_value('r1', r1, step=model.Eiters)
261 |     tb_logger.log_value('r5', r5, step=model.Eiters)
262 |     tb_logger.log_value('r10', r10, step=model.Eiters)
263 |     tb_logger.log_value('medr', medr, step=model.Eiters)
264 |     tb_logger.log_value('meanr', meanr, step=model.Eiters)
265 |     tb_logger.log_value('r1i', r1i, step=model.Eiters)
266 |     tb_logger.log_value('r5i', r5i, step=model.Eiters)
267 |     tb_logger.log_value('r10i', r10i, step=model.Eiters)
268 |     tb_logger.log_value('medri', medri, step=model.Eiters)
269 |     tb_logger.log_value('meanri', meanri, step=model.Eiters)
270 |     tb_logger.log_value('rsum', currscore, step=model.Eiters)
271 |
272 |     return currscore
273 |
274 |
275 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='', save_all=True):
276 |     torch.save(state, prefix + filename)
277 |     if is_best:
278 |         print ("[Best model so far, saved.]")
279 |         shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')
280 |     if save_all:
281 |         shutil.copyfile(prefix + filename, prefix + "Epoch-" + str(state['epoch']) + "-" + 'model.pth.tar')
282 |
283 |
284 |
285 | def adjust_learning_rate(opt, optimizer, epoch):
286 |     """Sets the learning rate to the initial LR
287 |     decayed by a factor of 10 every opt.lr_update epochs"""
288 |     lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update))
289 |     for param_group in optimizer.param_groups:
290 |         param_group['lr'] = lr
291 |
292 |
293 | def accuracy(output, target, topk=(1,)):
294 |     """Computes the accuracy@k for the specified values of k"""
295 |     maxk = max(topk)
296 |     batch_size = target.size(0)
297 |
298 |     _, pred = output.topk(maxk, 1, True, True)
299 |     pred = pred.t()
300 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
301 |
302 |     res = []
303 |     for k in topk:
304 |         correct_k = correct[:k].view(-1).float().sum(0)
305 |         res.append(correct_k.mul_(100.0 / batch_size))
306 |     return res
307 |
308 |
309 | if __name__ == '__main__':
310 |     main()
311 |
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
1 | # Create a vocabulary wrapper
2 | import nltk
3 | import pickle
4 | from collections import Counter
5 | from pycocotools.coco import COCO
6 | import json
7 | import argparse
8 | import os
9 |
10 | annotations = {
11 |     'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
12 |     'coco': ['annotations/captions_train2014.json',
13 |              'annotations/captions_val2014.json'],
14 |     'f8k_precomp': ['train_caps.txt', 'dev_caps.txt'],
15 |     '10crop_precomp': ['train_caps.txt', 'dev_caps.txt'],
16 |     'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
17 |     'f8k': ['dataset_flickr8k.json'],
18 |     'f30k': ['dataset_flickr30k.json'],
19 | }
20 |
21 |
22 | class Vocabulary(object):
23 |     """Simple vocabulary wrapper."""
24 |
25 |     def __init__(self):
26 |         self.word2idx = {}
27 |         self.idx2word = {}
28 |         self.idx = 0
29 |
30 |     def add_word(self, word):
31 |         if word not in self.word2idx:
32 |             self.word2idx[word] = self.idx
33 |             self.idx2word[self.idx] = word
34 |             self.idx += 1
35 |
36 |     def __call__(self, word):
37 |         if word not in self.word2idx:
38 |             return self.word2idx['<unk>']
39 |         return self.word2idx[word]
40 |
41 |     def __len__(self):
42 |         return len(self.word2idx)
43 |
44 |
45 | def from_coco_json(path):
46 |     coco = COCO(path)
47 |     ids = coco.anns.keys()
48 |     captions = []
49 |     for i, idx in enumerate(ids):
50 |         captions.append(str(coco.anns[idx]['caption']))
51 |
52 |     return captions
53 |
54 |
55 | def from_flickr_json(path):
56 |     dataset = json.load(open(path, 'r'))['images']
57 |     captions = []
58 |     for i, d in enumerate(dataset):
59 |         captions += [str(x['raw']) for x in d['sentences']]
60 |
61 |     return captions
62 |
63 |
64 | def from_txt(txt):
65 |     captions = []
66 |     with open(txt, 'rb') as f:
67 |         for line in f:
68 |             captions.append(line.strip())
69 |     return captions
70 |
71 |
72 | def build_vocab(data_path, data_name, jsons, threshold):
73 |     """Build a simple vocabulary wrapper."""
74 |     counter = Counter()
75 |     for path in jsons[data_name]:
76 |         full_path = os.path.join(os.path.join(data_path, data_name), path)
77 |         if data_name == 'coco':
78 |             captions = from_coco_json(full_path)
79 |         elif data_name == 'f8k' or data_name == 'f30k':
80 |             captions = from_flickr_json(full_path)
81 |         else:
82 |             captions = from_txt(full_path)
83 |         for i, caption in enumerate(captions):
84 |             caption = caption.decode('utf-8') if isinstance(caption, bytes) else caption  # from_txt yields bytes
85 |             tokens = nltk.tokenize.word_tokenize(caption.lower())
86 |             counter.update(tokens)
87 |
88 |             if i % 1000 == 0:
89 |                 print("[%d/%d] tokenized the captions." % (i, len(captions)))
90 |
91 |     # Discard words that occur fewer than `threshold` times.
92 |     words = [word for word, cnt in counter.items() if cnt >= threshold]
93 |
94 |     # Create a vocab wrapper and add some special tokens.
95 |     vocab = Vocabulary()
96 |     vocab.add_word('<pad>')
97 |     vocab.add_word('<start>')
98 |     vocab.add_word('<end>')
99 |     vocab.add_word('<unk>')
100 |
101 |     # Add words to the vocabulary.
102 |     for i, word in enumerate(words):
103 |         vocab.add_word(word)
104 |     return vocab
105 |
106 |
107 | def main(data_path, data_name):
108 |     vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=4)
109 |     with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f:
110 |         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
111 |     print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name)
112 |
113 |
114 | if __name__ == '__main__':
115 |     parser = argparse.ArgumentParser()
116 |     parser.add_argument('--data_path', default='/w/31/faghri/vsepp_data/')
117 |     parser.add_argument('--data_name', default='coco',
118 |                         help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
119 |     opt = parser.parse_args()
120 |     main(opt.data_path, opt.data_name)
121 |
--------------------------------------------------------------------------------
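
The pre-built vocabulary pickles under `data/vocab/` already cover all of the datasets listed in `annotations`, so running `vocab.py` is normally unnecessary. If you do need to rebuild one, a minimal sketch of the invocation (assuming the same `data/data/<data_name>/train_caps.txt` and `dev_caps.txt` layout used by the training commands above, and that a `./vocab/` directory exists in the working directory) is:

```bash
# Rebuild the Flickr30k (precomputed-features) vocabulary from its caption files;
# the result is written to ./vocab/f30k_precomp_vocab.pkl.
python3 vocab.py --data_path data/data --data_name f30k_precomp
```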