├── .gitignore ├── .gitmodules ├── util └── convert_data.py ├── vocab.py ├── data.py ├── README.md ├── LICENSE ├── train.py ├── evaluation.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.ipynb_checkpoints 4 | *.json 5 | *.pth.tar 6 | .DS_Store 7 | /data 8 | /pycocotools 9 | /runs 10 | /vocab 11 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "bottom-up-attention"] 2 | path = bottom-up-attention 3 | url = https://github.com/peteanderson80/bottom-up-attention.git 4 | -------------------------------------------------------------------------------- /util/convert_data.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Convert image features from bottom up attention to numpy array""" 10 | import os 11 | import base64 12 | import csv 13 | import sys 14 | import zlib 15 | import json 16 | import argparse 17 | 18 | import numpy as np 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--imgid_list', default='data/coco_precomp/train_ids.txt', 23 | help='Path to list of image id') 24 | parser.add_argument('--input_file', default='/media/data/kualee/coco_bottom_up_feature/trainval_36/trainval_resnet101_faster_rcnn_genome_36.tsv', 25 | help='tsv of all image data (output of bottom-up-attention/tools/generate_tsv.py), \ 26 | where each columns are: [image_id, image_w, image_h, num_boxes, boxes, features].') 27 | parser.add_argument('--output_dir', default='data/coco_precomp/', 28 | help='Output directory.') 29 | parser.add_argument('--split', default='train', 30 | help='train|dev|test') 31 | opt = parser.parse_args() 32 | print(opt) 33 | 34 | 35 | meta = [] 36 | feature = {} 37 | for line in open(opt.imgid_list): 38 | sid = int(line.strip()) 39 | meta.append(sid) 40 | feature[sid] = None 41 | 42 | csv.field_size_limit(sys.maxsize) 43 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features'] 44 | 45 | if __name__ == '__main__': 46 | with open(opt.input_file, "r+b") as tsv_in_file: 47 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) 48 | for item in reader: 49 | item['image_id'] = int(item['image_id']) 50 | item['image_h'] = int(item['image_h']) 51 | item['image_w'] = int(item['image_w']) 52 | item['num_boxes'] = int(item['num_boxes']) 53 | for field in ['boxes', 'features']: 54 | data = item[field] 55 | buf = base64.decodestring(data) 56 | temp = np.frombuffer(buf, dtype=np.float32) 57 | item[field] = temp.reshape((item['num_boxes'],-1)) 58 | if item['image_id'] in feature: 59 | feature[item['image_id']] = item['features'] 60 | data_out = np.stack([feature[sid] for sid in meta], axis=0) 61 | print("Final numpy array shape:", data_out.shape) 62 | np.save(os.path.join(opt.output_dir, '{}_ims.npy'.format(opt.split)), data_out) 63 | -------------------------------------------------------------------------------- /vocab.py: 
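A minimal usage sketch of this vocabulary module follows (illustration only, not part of the original repository; it assumes `data/f30k_precomp/train_caps.txt` and `dev_caps.txt` exist and that NLTK's punkt tokenizer is installed):

```python
# Hypothetical usage of vocab.py; mirrors the defaults in its main().
from vocab import annotations, build_vocab, serialize_vocab, deserialize_vocab

# Build the vocabulary from the precomputed caption files (threshold=4 as in main()).
vocab = build_vocab('data', 'f30k_precomp', caption_file=annotations, threshold=4)
serialize_vocab(vocab, './vocab/f30k_precomp_vocab.json')

# Reload it and map tokens to ids; unseen words fall back to '<unk>'.
vocab = deserialize_vocab('./vocab/f30k_precomp_vocab.json')
print(len(vocab))        # vocabulary size, including the 4 special tokens
print(vocab('<start>'))  # -> 1
print(vocab('zzz-not-in-vocab'))  # -> same id as '<unk>', i.e. 3
```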
-------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Written by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Vocabulary wrapper""" 10 | 11 | import nltk 12 | from collections import Counter 13 | import argparse 14 | import os 15 | import json 16 | 17 | annotations = { 18 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'], 19 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'], 20 | } 21 | 22 | 23 | class Vocabulary(object): 24 | """Simple vocabulary wrapper.""" 25 | 26 | def __init__(self): 27 | self.word2idx = {} 28 | self.idx2word = {} 29 | self.idx = 0 30 | 31 | def add_word(self, word): 32 | if word not in self.word2idx: 33 | self.word2idx[word] = self.idx 34 | self.idx2word[self.idx] = word 35 | self.idx += 1 36 | 37 | def __call__(self, word): 38 | if word not in self.word2idx: 39 | return self.word2idx['<unk>'] 40 | return self.word2idx[word] 41 | 42 | def __len__(self): 43 | return len(self.word2idx) 44 | 45 | 46 | def serialize_vocab(vocab, dest): 47 | d = {} 48 | d['word2idx'] = vocab.word2idx 49 | d['idx2word'] = vocab.idx2word 50 | d['idx'] = vocab.idx 51 | with open(dest, "w") as f: 52 | json.dump(d, f) 53 | 54 | 55 | def deserialize_vocab(src): 56 | with open(src) as f: 57 | d = json.load(f) 58 | vocab = Vocabulary() 59 | vocab.word2idx = d['word2idx'] 60 | vocab.idx2word = d['idx2word'] 61 | vocab.idx = d['idx'] 62 | return vocab 63 | 64 | 65 | def from_txt(txt): 66 | captions = [] 67 | with open(txt, 'rb') as f: 68 | for line in f: 69 | captions.append(line.strip()) 70 | return captions 71 | 72 | 73 | def build_vocab(data_path, data_name, caption_file, threshold): 74 | """Build a simple vocabulary wrapper.""" 75 | counter = Counter() 76 | for path in caption_file[data_name]: 77 | full_path = os.path.join(os.path.join(data_path, data_name), path) 78 | captions = from_txt(full_path) 79 | for i, caption in enumerate(captions): 80 | tokens = nltk.tokenize.word_tokenize( 81 | caption.lower().decode('utf-8')) 82 | counter.update(tokens) 83 | 84 | if i % 1000 == 0: 85 | print("[%d/%d] tokenized the captions." % (i, len(captions))) 86 | 87 | # Discard words whose occurrence count is less than the threshold. 88 | words = [word for word, cnt in counter.items() if cnt >= threshold] 89 | 90 | # Create a vocab wrapper and add some special tokens. 91 | vocab = Vocabulary() 92 | vocab.add_word('<pad>') 93 | vocab.add_word('<start>') 94 | vocab.add_word('<end>') 95 | vocab.add_word('<unk>') 96 | 97 | # Add words to the vocabulary.
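# Note: the four special tokens above occupy indices 0-3 ('<pad>'=0, '<start>'=1, '<end>'=2, '<unk>'=3); corpus words are appended after them, which is why index 0 can safely be used as the padding value when captions are batched in data.py.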
98 | for i, word in enumerate(words): 99 | vocab.add_word(word) 100 | return vocab 101 | 102 | 103 | def main(data_path, data_name): 104 | vocab = build_vocab(data_path, data_name, caption_file=annotations, threshold=4) 105 | serialize_vocab(vocab, './vocab/%s_vocab.json' % data_name) 106 | print("Saved vocabulary file to ", './vocab/%s_vocab.json' % data_name) 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('--data_path', default='data') 113 | parser.add_argument('--data_name', default='f30k_precomp', 114 | help='{coco,f30k}_precomp') 115 | opt = parser.parse_args() 116 | main(opt.data_path, opt.data_name) 117 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Written by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Data provider""" 10 | 11 | import torch 12 | import torch.utils.data as data 13 | import torchvision.transforms as transforms 14 | import os 15 | import nltk 16 | from PIL import Image 17 | import numpy as np 18 | import json as jsonmod 19 | 20 | 21 | class PrecompDataset(data.Dataset): 22 | """ 23 | Load precomputed captions and image features 24 | Possible options: f30k_precomp, coco_precomp 25 | """ 26 | 27 | def __init__(self, data_path, data_split, vocab): 28 | self.vocab = vocab 29 | loc = data_path + '/' 30 | 31 | # Captions 32 | self.captions = [] 33 | with open(loc+'%s_caps.txt' % data_split, 'rb') as f: 34 | for line in f: 35 | self.captions.append(line.strip()) 36 | 37 | # Image features 38 | self.images = np.load(loc+'%s_ims.npy' % data_split) 39 | self.length = len(self.captions) 40 | # rkiros data has redundancy in images (5 captions per image), so we divide by 5; 10crop doesn't 41 | if self.images.shape[0] != self.length: 42 | self.im_div = 5 43 | else: 44 | self.im_div = 1 45 | # the development set for coco is large and so validation would be slow 46 | if data_split == 'dev': 47 | self.length = 5000 48 | 49 | def __getitem__(self, index): 50 | # handle the image redundancy 51 | img_id = index // self.im_div 52 | image = torch.Tensor(self.images[img_id]) 53 | caption = self.captions[index] 54 | vocab = self.vocab 55 | 56 | # Convert caption (string) to word ids. 57 | tokens = nltk.tokenize.word_tokenize( 58 | str(caption).lower().decode('utf-8')) 59 | caption = [] 60 | caption.append(vocab('<start>')) 61 | caption.extend([vocab(token) for token in tokens]) 62 | caption.append(vocab('<end>')) 63 | target = torch.Tensor(caption) 64 | return image, target, index, img_id 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | 70 | def collate_fn(data): 71 | """Build mini-batch tensors from a list of (image, caption) tuples. 72 | Args: 73 | data: list of (image, caption) tuple. 74 | - image: torch tensor of shape (n_regions, img_dim), e.g. (36, 2048). 75 | - caption: torch tensor of shape (?); variable length. 76 | 77 | Returns: 78 | images: torch tensor of shape (batch_size, n_regions, img_dim). 79 | targets: torch tensor of shape (batch_size, padded_length). 80 | lengths: list; valid length for each padded caption.
81 | """ 82 | # Sort a data list by caption length 83 | data.sort(key=lambda x: len(x[1]), reverse=True) 84 | images, captions, ids, img_ids = zip(*data) 85 | 86 | # Merge images (convert tuple of 2D tensors to a 3D tensor) 87 | images = torch.stack(images, 0) 88 | 89 | # Merge captions (convert tuple of 1D tensors to a 2D tensor) 90 | lengths = [len(cap) for cap in captions] 91 | targets = torch.zeros(len(captions), max(lengths)).long() 92 | for i, cap in enumerate(captions): 93 | end = lengths[i] 94 | targets[i, :end] = cap[:end] 95 | 96 | return images, targets, lengths, ids 97 | 98 | 99 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100, 100 | shuffle=True, num_workers=2): 101 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 102 | dset = PrecompDataset(data_path, data_split, vocab) 103 | 104 | data_loader = torch.utils.data.DataLoader(dataset=dset, 105 | batch_size=batch_size, 106 | shuffle=shuffle, 107 | pin_memory=True, num_workers=num_workers, 108 | collate_fn=collate_fn) 109 | return data_loader 110 | 111 | 112 | def get_loaders(data_name, vocab, batch_size, workers, opt): 113 | dpath = os.path.join(opt.data_path, data_name) 114 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt, 115 | batch_size, True, workers) 116 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt, 117 | batch_size, False, workers) 118 | return train_loader, val_loader 119 | 120 | 121 | def get_test_loader(split_name, data_name, vocab, batch_size, 122 | workers, opt): 123 | dpath = os.path.join(opt.data_path, data_name) 124 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt, 125 | batch_size, False, workers) 126 | return test_loader 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This is the Stacked Cross Attention Network (SCAN), source code for [Stacked Cross Attention for Image-Text Matching](https://arxiv.org/abs/1803.08024) ([project page](https://kuanghuei.github.io/SCANProject/)) from Microsoft AI and Research. The paper appeared in ECCV 2018. It is built on top of [VSE++](https://github.com/fartashf/vsepp) in PyTorch. 4 | 5 | 6 | ## Requirements and Installation 7 | We recommend the following dependencies. 8 | 9 | * Python 2.7 10 | * [PyTorch](http://pytorch.org/) 0.3 11 | * [NumPy](http://www.numpy.org/) (>1.12.1) 12 | * [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger) 13 | 14 | * Punkt Sentence Tokenizer: 15 | ```python 16 | import nltk 17 | nltk.download('punkt') 18 | 19 | ``` 20 | 21 | ## Download data 22 | 23 | Download the dataset files and pre-trained models. We use splits produced by [Andrej Karpathy](http://cs.stanford.edu/people/karpathy/deepimagesent/). The raw images can be downloaded from their original sources [here](http://nlp.cs.illinois.edu/HockenmaierGroup/Framing_Image_Description/KCCA.html), [here](http://shannon.cs.illinois.edu/DenotationGraph/) and [here](http://mscoco.org/). 24 | 25 | The precomputed image features of MS-COCO are from [here](https://github.com/peteanderson80/bottom-up-attention). The precomputed image features of Flickr30K are extracted from the raw Flickr30K images using the bottom-up attention model from [here](https://github.com/peteanderson80/bottom-up-attention).
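As a quick sanity check after extracting the data (a minimal sketch, not from the original repository; it assumes the MS-COCO archive was unpacked to `data/coco_precomp/` and that the 36-regions-per-image bottom-up features are used):

```python
import numpy as np

# Precomputed image features: one (36, 2048) block of region features per image.
ims = np.load('data/coco_precomp/train_ims.npy')
print(ims.shape)  # expected (n_images, 36, 2048)

# Captions come 5 per image in the matching *_caps.txt file.
with open('data/coco_precomp/train_caps.txt') as f:
    caps = [line.strip() for line in f]
print(len(caps))                  # expected 5 * n_images
print(len(caps) // ims.shape[0])  # -> 5
```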
All the data needed for reproducing the experiments in the paper, including image features and vocabularies, can be downloaded from: 26 | 27 | https://www.kaggle.com/datasets/kuanghueilee/scan-features 28 | 29 | We refer to the path of extracted files for `data.zip` as `$DATA_PATH` and files for `vocab.zip` to `./vocab` directory. Alternatively, you can also run vocab.py to produce vocabulary files. For example, 30 | 31 | ```bash 32 | python vocab.py --data_path data --data_name f30k_precomp 33 | python vocab.py --data_path data --data_name coco_precomp 34 | ``` 35 | 36 | ## Data pre-processing (Optional) 37 | 38 | The image features of Flickr30K and MS-COCO are available in numpy array format, which can be used for training directly. However, if you wish to test on another dataset, you will need to start from scratch: 39 | 40 | 1. Use the `bottom-up-attention/tools/generate_tsv.py` and the bottom-up attention model to extract features of image regions. The output file format will be a tsv, where the columns are ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']. 41 | 2. Use `util/convert_data.py` to convert the above output to a numpy array. 42 | 43 | 44 | If downloading the whole data package containing bottom-up image features for Flickr30K and MS-COCO is too slow for you, you can download everything but image features from https://www.kaggle.com/datasets/kuanghueilee/scan-features and compute image features locally from raw images. 45 | 46 | 47 | ## Training new models 48 | Run `train.py`: 49 | 50 | ```bash 51 | python train.py --data_path "$DATA_PATH" --data_name coco_precomp --vocab_path "$VOCAB_PATH" --logger_name runs/coco_scan/log --model_name runs/coco_scan/log --max_violation --bi_gru 52 | ``` 53 | 54 | Arguments used to train Flickr30K models: 55 | 56 | | Method | Arguments | 57 | | :-------: | :-------: | 58 | | SCAN t-i LSE | `--max_violation --bi_gru --agg_func=LogSumExp --cross_attn=t2i --lambda_lse=6 --lambda_softmax=9` | 59 | | SCAN t-i AVG | `--max_violation --bi_gru --agg_func=Mean --cross_attn=t2i --lambda_softmax=9` | 60 | | SCAN i-t LSE | `--max_violation --bi_gru --agg_func=LogSumExp --cross_attn=i2t --lambda_lse=5 --lambda_softmax=4` | 61 | | SCAN i-t AVG | `--max_violation --bi_gru --agg_func=Mean --cross_attn=i2t --lambda_softmax=4` | 62 | 63 | 64 | Arguments used to train MS-COCO models: 65 | 66 | | Method | Arguments | 67 | | :-------: | :-------: | 68 | | SCAN t-i LSE | `--max_violation --bi_gru --agg_func=LogSumExp --cross_attn=t2i --lambda_lse=6 --lambda_softmax=9 --num_epochs=20 --lr_update=10 --learning_rate=.0005` | 69 | | SCAN t-i AVG | `--max_violation --bi_gru --agg_func=Mean --cross_attn=t2i --lambda_softmax=9 --num_epochs=20 --lr_update=10 --learning_rate=.0005` | 70 | | SCAN i-t LSE | `--max_violation --bi_gru --agg_func=LogSumExp --cross_attn=i2t --lambda_lse=20 --lambda_softmax=4 --num_epochs=20 --lr_update=10 --learning_rate=.0005` | 71 | | SCAN i-t AVG | `--max_violation --bi_gru --agg_func=Mean --cross_attn=i2t --lambda_softmax=4 --num_epochs=20 --lr_update=10 --learning_rate=.0005` | 72 | 73 | ## Evaluate trained models 74 | 75 | ```python 76 | from vocab import Vocabulary 77 | import evaluation 78 | evaluation.evalrank("$RUN_PATH/coco_scan/model_best.pth.tar", data_path="$DATA_PATH", split="test") 79 | ``` 80 | 81 | To do cross-validation on MSCOCO, pass `fold5=True` with a model trained using 82 | `--data_name coco_precomp`. 
83 | 84 | ## Reference 85 | 86 | If you found this code useful, please cite the following paper: 87 | 88 | ``` 89 | @inproceedings{lee2018stacked, 90 | title={Stacked cross attention for image-text matching}, 91 | author={Lee, Kuang-Huei and Chen, Xi and Hua, Gang and Hu, Houdong and He, Xiaodong}, 92 | booktitle={Proceedings of the European conference on computer vision (ECCV)}, 93 | pages={201--216}, 94 | year={2018} 95 | } 96 | ``` 97 | 98 | ## License 99 | 100 | [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0) 101 | 102 | 103 | ## Acknowledgments 104 | 105 | The authors would like to thank [Po-Sen Huang](https://posenhuang.github.io/) and Yokesh Kumar for helping the manuscript. We also thank Li Huang, Arun Sacheti, and Bing Multimedia team for supporting this work. 106 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 
4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Training script""" 10 | 11 | import os 12 | import time 13 | import shutil 14 | 15 | import torch 16 | import numpy 17 | 18 | import data 19 | from vocab import Vocabulary, deserialize_vocab 20 | from model import SCAN 21 | from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data, shard_xattn_t2i, shard_xattn_i2t 22 | from torch.autograd import Variable 23 | 24 | import logging 25 | import tensorboard_logger as tb_logger 26 | 27 | import argparse 28 | 29 | 30 | def main(): 31 | # Hyper Parameters 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--data_path', default='./data/', 34 | help='path to datasets') 35 | parser.add_argument('--data_name', default='precomp', 36 | help='{coco,f30k}_precomp') 37 | parser.add_argument('--vocab_path', default='./vocab/', 38 | help='Path to saved vocabulary json files.') 39 | parser.add_argument('--margin', default=0.2, type=float, 40 | help='Rank loss margin.') 41 | parser.add_argument('--num_epochs', default=30, type=int, 42 | help='Number of training epochs.') 43 | parser.add_argument('--batch_size', default=128, type=int, 44 | help='Size of a training mini-batch.') 45 | parser.add_argument('--word_dim', default=300, type=int, 46 | help='Dimensionality of the word embedding.') 47 | parser.add_argument('--embed_size', default=1024, type=int, 48 | help='Dimensionality of the joint embedding.') 49 | parser.add_argument('--grad_clip', default=2., type=float, 50 | help='Gradient clipping threshold.') 51 | parser.add_argument('--num_layers', default=1, type=int, 52 | help='Number of GRU layers.') 53 | parser.add_argument('--learning_rate', default=.0002, type=float, 54 | help='Initial learning rate.') 55 | parser.add_argument('--lr_update', default=15, type=int, 56 | help='Number of epochs to update the learning rate.') 57 | parser.add_argument('--workers', default=10, type=int, 58 | help='Number of data loader workers.') 59 | parser.add_argument('--log_step', default=10, type=int, 60 | help='Number of steps to print and record the log.') 61 | parser.add_argument('--val_step', default=500, type=int, 62 | help='Number of steps to run validation.') 63 | parser.add_argument('--logger_name', default='./runs/runX/log', 64 | help='Path to save Tensorboard log.') 65 | parser.add_argument('--model_name', default='./runs/runX/checkpoint', 66 | help='Path to save the model.') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('--max_violation', action='store_true', 70 | help='Use max instead of sum in the rank loss.') 71 | parser.add_argument('--img_dim', default=2048, type=int, 72 | help='Dimensionality of the image embedding.') 73 | parser.add_argument('--no_imgnorm', action='store_true', 74 | help='Do not normalize the image embeddings.') 75 | parser.add_argument('--no_txtnorm', action='store_true', 76 | help='Do not normalize the text embeddings.') 77 | parser.add_argument('--raw_feature_norm', default="clipped_l2norm", 78 | help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax') 79 | parser.add_argument('--agg_func', default="LogSumExp", 80 | help='LogSumExp|Mean|Max|Sum') 81 | parser.add_argument('--cross_attn', default="t2i", 82 | help='t2i|i2t') 83 | 
parser.add_argument('--precomp_enc_type', default="basic", 84 | help='basic|weight_norm') 85 | parser.add_argument('--bi_gru', action='store_true', 86 | help='Use bidirectional GRU.') 87 | parser.add_argument('--lambda_lse', default=6., type=float, 88 | help='LogSumExp temp.') 89 | parser.add_argument('--lambda_softmax', default=9., type=float, 90 | help='Attention softmax temperature.') 91 | opt = parser.parse_args() 92 | print(opt) 93 | 94 | logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO) 95 | tb_logger.configure(opt.logger_name, flush_secs=5) 96 | 97 | # Load Vocabulary Wrapper 98 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 99 | opt.vocab_size = len(vocab) 100 | 101 | # Load data loaders 102 | train_loader, val_loader = data.get_loaders( 103 | opt.data_name, vocab, opt.batch_size, opt.workers, opt) 104 | 105 | # Construct the model 106 | model = SCAN(opt) 107 | 108 | best_rsum = 0 109 | start_epoch = 0 110 | # optionally resume from a checkpoint 111 | if opt.resume: 112 | if os.path.isfile(opt.resume): 113 | print("=> loading checkpoint '{}'".format(opt.resume)) 114 | checkpoint = torch.load(opt.resume) 115 | start_epoch = checkpoint['epoch'] + 1 116 | best_rsum = checkpoint['best_rsum'] 117 | model.load_state_dict(checkpoint['model']) 118 | # Eiters is used to show logs as the continuation of another 119 | # training 120 | model.Eiters = checkpoint['Eiters'] 121 | print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})" 122 | .format(opt.resume, start_epoch, best_rsum)) 123 | validate(opt, val_loader, model) 124 | else: 125 | print("=> no checkpoint found at '{}'".format(opt.resume)) 126 | 127 | # Train the Model 128 | for epoch in range(start_epoch, opt.num_epochs): 129 | print(opt.logger_name) 130 | print(opt.model_name) 131 | 132 | adjust_learning_rate(opt, model.optimizer, epoch) 133 | 134 | # train for one epoch 135 | train(opt, train_loader, model, epoch, val_loader) 136 | 137 | # evaluate on validation set 138 | rsum = validate(opt, val_loader, model) 139 | 140 | # remember best R@ sum and save checkpoint 141 | is_best = rsum > best_rsum 142 | best_rsum = max(rsum, best_rsum) 143 | if not os.path.exists(opt.model_name): 144 | os.mkdir(opt.model_name) 145 | save_checkpoint({ 146 | 'epoch': epoch, 147 | 'model': model.state_dict(), 148 | 'best_rsum': best_rsum, 149 | 'opt': opt, 150 | 'Eiters': model.Eiters, 151 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch), prefix=opt.model_name + '/') 152 | 153 | 154 | def train(opt, train_loader, model, epoch, val_loader): 155 | # average meters to record the training statistics 156 | batch_time = AverageMeter() 157 | data_time = AverageMeter() 158 | train_logger = LogCollector() 159 | 160 | end = time.time() 161 | for i, train_data in enumerate(train_loader): 162 | # switch to train mode 163 | model.train_start() 164 | 165 | # measure data loading time 166 | data_time.update(time.time() - end) 167 | 168 | # make sure train logger is used 169 | model.logger = train_logger 170 | 171 | # Update the model 172 | model.train_emb(*train_data) 173 | 174 | # measure elapsed time 175 | batch_time.update(time.time() - end) 176 | end = time.time() 177 | 178 | # Print log info 179 | if model.Eiters % opt.log_step == 0: 180 | logging.info( 181 | 'Epoch: [{0}][{1}/{2}]\t' 182 | '{e_log}\t' 183 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 184 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 185 | .format( 186 | epoch, i, len(train_loader), 
batch_time=batch_time, 187 | data_time=data_time, e_log=str(model.logger))) 188 | 189 | # Record logs in tensorboard 190 | tb_logger.log_value('epoch', epoch, step=model.Eiters) 191 | tb_logger.log_value('step', i, step=model.Eiters) 192 | tb_logger.log_value('batch_time', batch_time.val, step=model.Eiters) 193 | tb_logger.log_value('data_time', data_time.val, step=model.Eiters) 194 | model.logger.tb_log(tb_logger, step=model.Eiters) 195 | 196 | # validate at every val_step 197 | if model.Eiters % opt.val_step == 0: 198 | validate(opt, val_loader, model) 199 | 200 | 201 | def validate(opt, val_loader, model): 202 | # compute the encoding for all the validation images and captions 203 | img_embs, cap_embs, cap_lens = encode_data( 204 | model, val_loader, opt.log_step, logging.info) 205 | 206 | img_embs = numpy.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 207 | 208 | start = time.time() 209 | if opt.cross_attn == 't2i': 210 | sims = shard_xattn_t2i(img_embs, cap_embs, cap_lens, opt, shard_size=128) 211 | elif opt.cross_attn == 'i2t': 212 | sims = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt, shard_size=128) 213 | else: 214 | raise NotImplementedError 215 | end = time.time() 216 | print("calculate similarity time:", end-start) 217 | 218 | # caption retrieval 219 | (r1, r5, r10, medr, meanr) = i2t(img_embs, cap_embs, cap_lens, sims) 220 | logging.info("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % 221 | (r1, r5, r10, medr, meanr)) 222 | # image retrieval 223 | (r1i, r5i, r10i, medri, meanr) = t2i( 224 | img_embs, cap_embs, cap_lens, sims) 225 | logging.info("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % 226 | (r1i, r5i, r10i, medri, meanr)) 227 | # sum of recalls to be used for early stopping 228 | currscore = r1 + r5 + r10 + r1i + r5i + r10i 229 | 230 | # record metrics in tensorboard 231 | tb_logger.log_value('r1', r1, step=model.Eiters) 232 | tb_logger.log_value('r5', r5, step=model.Eiters) 233 | tb_logger.log_value('r10', r10, step=model.Eiters) 234 | tb_logger.log_value('medr', medr, step=model.Eiters) 235 | tb_logger.log_value('meanr', meanr, step=model.Eiters) 236 | tb_logger.log_value('r1i', r1i, step=model.Eiters) 237 | tb_logger.log_value('r5i', r5i, step=model.Eiters) 238 | tb_logger.log_value('r10i', r10i, step=model.Eiters) 239 | tb_logger.log_value('medri', medri, step=model.Eiters) 240 | tb_logger.log_value('meanr', meanr, step=model.Eiters) 241 | tb_logger.log_value('rsum', currscore, step=model.Eiters) 242 | 243 | return currscore 244 | 245 | 246 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix=''): 247 | tries = 15 248 | error = None 249 | 250 | # deal with unstable I/O. Usually not necessary. 
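# Retry torch.save up to 15 times to cope with flaky file systems; the else-clause breaks out of the loop on success, and the last IOError is re-raised once all tries are exhausted.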
251 | while tries: 252 | try: 253 | torch.save(state, prefix + filename) 254 | if is_best: 255 | shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar') 256 | except IOError as e: 257 | error = e 258 | tries -= 1 259 | else: 260 | break 261 | print('model save {} failed, remaining {} trials'.format(filename, tries)) 262 | if not tries: 263 | raise error 264 | 265 | 266 | def adjust_learning_rate(opt, optimizer, epoch): 267 | """Sets the learning rate to the initial LR 268 | decayed by 10 every 30 epochs""" 269 | lr = opt.learning_rate * (0.1 ** (epoch // opt.lr_update)) 270 | for param_group in optimizer.param_groups: 271 | param_group['lr'] = lr 272 | 273 | 274 | def accuracy(output, target, topk=(1,)): 275 | """Computes the precision@k for the specified values of k""" 276 | maxk = max(topk) 277 | batch_size = target.size(0) 278 | 279 | _, pred = output.topk(maxk, 1, True, True) 280 | pred = pred.t() 281 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 282 | 283 | res = [] 284 | for k in topk: 285 | correct_k = correct[:k].view(-1).float().sum(0) 286 | res.append(correct_k.mul_(100.0 / batch_size)) 287 | return res 288 | 289 | 290 | if __name__ == '__main__': 291 | main() 292 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Evaluation""" 10 | 11 | from __future__ import print_function 12 | import os 13 | 14 | import sys 15 | from data import get_test_loader 16 | import time 17 | import numpy as np 18 | from vocab import Vocabulary, deserialize_vocab # NOQA 19 | import torch 20 | from model import SCAN, xattn_score_t2i, xattn_score_i2t 21 | from collections import OrderedDict 22 | import time 23 | from torch.autograd import Variable 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=0): 38 | self.val = val 39 | self.sum += val * n 40 | self.count += n 41 | self.avg = self.sum / (.0001 + self.count) 42 | 43 | def __str__(self): 44 | """String representation for logging 45 | """ 46 | # for values that should be recorded exactly e.g. 
iteration number 47 | if self.count == 0: 48 | return str(self.val) 49 | # for stats 50 | return '%.4f (%.4f)' % (self.val, self.avg) 51 | 52 | 53 | class LogCollector(object): 54 | """A collection of logging objects that can change from train to val""" 55 | 56 | def __init__(self): 57 | # to keep the order of logged variables deterministic 58 | self.meters = OrderedDict() 59 | 60 | def update(self, k, v, n=0): 61 | # create a new meter if previously not recorded 62 | if k not in self.meters: 63 | self.meters[k] = AverageMeter() 64 | self.meters[k].update(v, n) 65 | 66 | def __str__(self): 67 | """Concatenate the meters in one log line 68 | """ 69 | s = '' 70 | for i, (k, v) in enumerate(self.meters.iteritems()): 71 | if i > 0: 72 | s += ' ' 73 | s += k + ' ' + str(v) 74 | return s 75 | 76 | def tb_log(self, tb_logger, prefix='', step=None): 77 | """Log using tensorboard 78 | """ 79 | for k, v in self.meters.iteritems(): 80 | tb_logger.log_value(prefix + k, v.val, step=step) 81 | 82 | 83 | def encode_data(model, data_loader, log_step=10, logging=print): 84 | """Encode all images and captions loadable by `data_loader` 85 | """ 86 | batch_time = AverageMeter() 87 | val_logger = LogCollector() 88 | 89 | # switch to evaluate mode 90 | model.val_start() 91 | 92 | end = time.time() 93 | 94 | # np array to keep all the embeddings 95 | img_embs = None 96 | cap_embs = None 97 | cap_lens = None 98 | 99 | max_n_word = 0 100 | for i, (images, captions, lengths, ids) in enumerate(data_loader): 101 | max_n_word = max(max_n_word, max(lengths)) 102 | 103 | for i, (images, captions, lengths, ids) in enumerate(data_loader): 104 | # make sure val logger is used 105 | model.logger = val_logger 106 | 107 | # compute the embeddings 108 | img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths, volatile=True) 109 | #print(img_emb) 110 | if img_embs is None: 111 | if img_emb.dim() == 3: 112 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 113 | else: 114 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 115 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 116 | cap_lens = [0] * len(data_loader.dataset) 117 | # cache embeddings 118 | img_embs[ids] = img_emb.data.cpu().numpy().copy() 119 | cap_embs[ids,:max(lengths),:] = cap_emb.data.cpu().numpy().copy() 120 | for j, nid in enumerate(ids): 121 | cap_lens[nid] = cap_len[j] 122 | 123 | # measure accuracy and record loss 124 | model.forward_loss(img_emb, cap_emb, cap_len) 125 | 126 | # measure elapsed time 127 | batch_time.update(time.time() - end) 128 | end = time.time() 129 | 130 | if i % log_step == 0: 131 | logging('Test: [{0}/{1}]\t' 132 | '{e_log}\t' 133 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 134 | .format( 135 | i, len(data_loader), batch_time=batch_time, 136 | e_log=str(model.logger))) 137 | del images, captions 138 | return img_embs, cap_embs, cap_lens 139 | 140 | 141 | def evalrank(model_path, data_path=None, split='dev', fold5=False): 142 | """ 143 | Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold 144 | cross-validation is done (only for MSCOCO). Otherwise, the full data is 145 | used for evaluation. 
146 | """ 147 | # load model and options 148 | checkpoint = torch.load(model_path) 149 | opt = checkpoint['opt'] 150 | print(opt) 151 | if data_path is not None: 152 | opt.data_path = data_path 153 | 154 | # load vocabulary used by the model 155 | vocab = deserialize_vocab(os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name)) 156 | opt.vocab_size = len(vocab) 157 | 158 | # construct model 159 | model = SCAN(opt) 160 | 161 | # load model state 162 | model.load_state_dict(checkpoint['model']) 163 | 164 | print('Loading dataset') 165 | data_loader = get_test_loader(split, opt.data_name, vocab, 166 | opt.batch_size, opt.workers, opt) 167 | 168 | print('Computing results...') 169 | img_embs, cap_embs, cap_lens = encode_data(model, data_loader) 170 | print('Images: %d, Captions: %d' % 171 | (img_embs.shape[0] / 5, cap_embs.shape[0])) 172 | 173 | 174 | if not fold5: 175 | # no cross-validation, full evaluation 176 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)]) 177 | start = time.time() 178 | if opt.cross_attn == 't2i': 179 | sims = shard_xattn_t2i(img_embs, cap_embs, cap_lens, opt, shard_size=128) 180 | elif opt.cross_attn == 'i2t': 181 | sims = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt, shard_size=128) 182 | else: 183 | raise NotImplementedError 184 | end = time.time() 185 | print("calculate similarity time:", end-start) 186 | 187 | r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True) 188 | ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True) 189 | ar = (r[0] + r[1] + r[2]) / 3 190 | ari = (ri[0] + ri[1] + ri[2]) / 3 191 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 192 | print("rsum: %.1f" % rsum) 193 | print("Average i2t Recall: %.1f" % ar) 194 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r) 195 | print("Average t2i Recall: %.1f" % ari) 196 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri) 197 | else: 198 | # 5fold cross-validation, only for MSCOCO 199 | results = [] 200 | for i in range(5): 201 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 202 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 203 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 204 | start = time.time() 205 | if opt.cross_attn == 't2i': 206 | sims = shard_xattn_t2i(img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=128) 207 | elif opt.cross_attn == 'i2t': 208 | sims = shard_xattn_i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, opt, shard_size=128) 209 | else: 210 | raise NotImplementedError 211 | end = time.time() 212 | print("calculate similarity time:", end-start) 213 | 214 | r, rt0 = i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, sims, return_ranks=True) 215 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 216 | ri, rti0 = t2i(img_embs_shard, cap_embs_shard, cap_lens_shard, sims, return_ranks=True) 217 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 218 | 219 | if i == 0: 220 | rt, rti = rt0, rti0 221 | ar = (r[0] + r[1] + r[2]) / 3 222 | ari = (ri[0] + ri[1] + ri[2]) / 3 223 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 224 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 225 | results += [list(r) + list(ri) + [ar, ari, rsum]] 226 | 227 | print("-----------------------------------") 228 | print("Mean metrics: ") 229 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 230 | print("rsum: %.1f" % (mean_metrics[10] * 6)) 231 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 232 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 233 | 
mean_metrics[:5]) 234 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 235 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 236 | mean_metrics[5:10]) 237 | 238 | torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') 239 | 240 | 241 | def softmax(X, axis): 242 | """ 243 | Compute the softmax of each element along an axis of X. 244 | """ 245 | y = np.atleast_2d(X) 246 | # subtract the max for numerical stability 247 | y = y - np.expand_dims(np.max(y, axis = axis), axis) 248 | # exponentiate y 249 | y = np.exp(y) 250 | # take the sum along the specified axis 251 | ax_sum = np.expand_dims(np.sum(y, axis = axis), axis) 252 | # finally: divide elementwise 253 | p = y / ax_sum 254 | return p 255 | 256 | 257 | def shard_xattn_t2i(images, captions, caplens, opt, shard_size=128): 258 | """ 259 | Computer pairwise t2i image-caption distance with locality sharding 260 | """ 261 | n_im_shard = (len(images)-1)/shard_size + 1 262 | n_cap_shard = (len(captions)-1)/shard_size + 1 263 | 264 | d = np.zeros((len(images), len(captions))) 265 | for i in range(n_im_shard): 266 | im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images)) 267 | for j in range(n_cap_shard): 268 | sys.stdout.write('\r>> shard_xattn_t2i batch (%d,%d)' % (i,j)) 269 | cap_start, cap_end = shard_size*j, min(shard_size*(j+1), len(captions)) 270 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 271 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 272 | l = caplens[cap_start:cap_end] 273 | sim = xattn_score_t2i(im, s, l, opt) 274 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 275 | sys.stdout.write('\n') 276 | return d 277 | 278 | 279 | def shard_xattn_i2t(images, captions, caplens, opt, shard_size=128): 280 | """ 281 | Computer pairwise i2t image-caption distance with locality sharding 282 | """ 283 | n_im_shard = (len(images)-1)/shard_size + 1 284 | n_cap_shard = (len(captions)-1)/shard_size + 1 285 | 286 | d = np.zeros((len(images), len(captions))) 287 | for i in range(n_im_shard): 288 | im_start, im_end = shard_size*i, min(shard_size*(i+1), len(images)) 289 | for j in range(n_cap_shard): 290 | sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i,j)) 291 | cap_start, cap_end = shard_size*j, min(shard_size*(j+1), len(captions)) 292 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 293 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 294 | l = caplens[cap_start:cap_end] 295 | sim = xattn_score_i2t(im, s, l, opt) 296 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 297 | sys.stdout.write('\n') 298 | return d 299 | 300 | 301 | def i2t(images, captions, caplens, sims, npts=None, return_ranks=False): 302 | """ 303 | Images->Text (Image Annotation) 304 | Images: (N, n_region, d) matrix of images 305 | Captions: (5N, max_n_word, d) matrix of captions 306 | CapLens: (5N) array of caption lengths 307 | sims: (N, 5N) matrix of similarity im-cap 308 | """ 309 | npts = images.shape[0] 310 | ranks = np.zeros(npts) 311 | top1 = np.zeros(npts) 312 | for index in range(npts): 313 | inds = np.argsort(sims[index])[::-1] 314 | # Score 315 | rank = 1e20 316 | for i in range(5 * index, 5 * index + 5, 1): 317 | tmp = np.where(inds == i)[0][0] 318 | if tmp < rank: 319 | rank = tmp 320 | ranks[index] = rank 321 | top1[index] = inds[0] 322 | 323 | # Compute metrics 324 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 325 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / 
len(ranks) 326 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 327 | medr = np.floor(np.median(ranks)) + 1 328 | meanr = ranks.mean() + 1 329 | if return_ranks: 330 | return (r1, r5, r10, medr, meanr), (ranks, top1) 331 | else: 332 | return (r1, r5, r10, medr, meanr) 333 | 334 | 335 | def t2i(images, captions, caplens, sims, npts=None, return_ranks=False): 336 | """ 337 | Text->Images (Image Search) 338 | Images: (N, n_region, d) matrix of images 339 | Captions: (5N, max_n_word, d) matrix of captions 340 | CapLens: (5N) array of caption lengths 341 | sims: (N, 5N) matrix of similarity im-cap 342 | """ 343 | npts = images.shape[0] 344 | ranks = np.zeros(5 * npts) 345 | top1 = np.zeros(5 * npts) 346 | 347 | # --> (5N(caption), N(image)) 348 | sims = sims.T 349 | 350 | for index in range(npts): 351 | for i in range(5): 352 | inds = np.argsort(sims[5 * index + i])[::-1] 353 | ranks[5 * index + i] = np.where(inds == index)[0][0] 354 | top1[5 * index + i] = inds[0] 355 | 356 | # Compute metrics 357 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 358 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 359 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 360 | medr = np.floor(np.median(ranks)) + 1 361 | meanr = ranks.mean() + 1 362 | if return_ranks: 363 | return (r1, r5, r10, medr, meanr), (ranks, top1) 364 | else: 365 | return (r1, r5, r10, medr, meanr) 366 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """SCAN model""" 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init 14 | import torchvision.models as models 15 | from torch.autograd import Variable 16 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 17 | from torch.nn.utils.weight_norm import weight_norm 18 | import torch.backends.cudnn as cudnn 19 | from torch.nn.utils.clip_grad import clip_grad_norm 20 | import numpy as np 21 | from collections import OrderedDict 22 | 23 | 24 | def l1norm(X, dim, eps=1e-8): 25 | """L1-normalize columns of X 26 | """ 27 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 28 | X = torch.div(X, norm) 29 | return X 30 | 31 | 32 | def l2norm(X, dim, eps=1e-8): 33 | """L2-normalize columns of X 34 | """ 35 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 36 | X = torch.div(X, norm) 37 | return X 38 | 39 | 40 | def EncoderImage(data_name, img_dim, embed_size, precomp_enc_type='basic', 41 | no_imgnorm=False): 42 | """A wrapper to image encoders. Chooses between an different encoders 43 | that uses precomputed image features. 
44 | """ 45 | if precomp_enc_type == 'basic': 46 | img_enc = EncoderImagePrecomp( 47 | img_dim, embed_size, no_imgnorm) 48 | elif precomp_enc_type == 'weight_norm': 49 | img_enc = EncoderImageWeightNormPrecomp( 50 | img_dim, embed_size, no_imgnorm) 51 | else: 52 | raise ValueError("Unknown precomp_enc_type: {}".format(precomp_enc_type)) 53 | 54 | return img_enc 55 | 56 | 57 | class EncoderImagePrecomp(nn.Module): 58 | 59 | def __init__(self, img_dim, embed_size, no_imgnorm=False): 60 | super(EncoderImagePrecomp, self).__init__() 61 | self.embed_size = embed_size 62 | self.no_imgnorm = no_imgnorm 63 | self.fc = nn.Linear(img_dim, embed_size) 64 | 65 | self.init_weights() 66 | 67 | def init_weights(self): 68 | """Xavier initialization for the fully connected layer 69 | """ 70 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 71 | self.fc.out_features) 72 | self.fc.weight.data.uniform_(-r, r) 73 | self.fc.bias.data.fill_(0) 74 | 75 | def forward(self, images): 76 | """Extract image feature vectors.""" 77 | # assuming that the precomputed features are already l2-normalized 78 | 79 | features = self.fc(images) 80 | 81 | # normalize in the joint embedding space 82 | if not self.no_imgnorm: 83 | features = l2norm(features, dim=-1) 84 | 85 | return features 86 | 87 | def load_state_dict(self, state_dict): 88 | """Copies parameters. overwritting the default one to 89 | accept state_dict from Full model 90 | """ 91 | own_state = self.state_dict() 92 | new_state = OrderedDict() 93 | for name, param in state_dict.items(): 94 | if name in own_state: 95 | new_state[name] = param 96 | 97 | super(EncoderImagePrecomp, self).load_state_dict(new_state) 98 | 99 | 100 | class EncoderImageWeightNormPrecomp(nn.Module): 101 | 102 | def __init__(self, img_dim, embed_size, no_imgnorm=False): 103 | super(EncoderImageWeightNormPrecomp, self).__init__() 104 | self.embed_size = embed_size 105 | self.no_imgnorm = no_imgnorm 106 | self.fc = weight_norm(nn.Linear(img_dim, embed_size), dim=None) 107 | 108 | def forward(self, images): 109 | """Extract image feature vectors.""" 110 | # assuming that the precomputed features are already l2-normalized 111 | 112 | features = self.fc(images) 113 | 114 | # normalize in the joint embedding space 115 | if not self.no_imgnorm: 116 | features = l2norm(features, dim=-1) 117 | 118 | return features 119 | 120 | def load_state_dict(self, state_dict): 121 | """Copies parameters. 
overwriting the default one to 122 | accept state_dict from Full model 123 | """ 124 | own_state = self.state_dict() 125 | new_state = OrderedDict() 126 | for name, param in state_dict.items(): 127 | if name in own_state: 128 | new_state[name] = param 129 | 130 | super(EncoderImageWeightNormPrecomp, self).load_state_dict(new_state) 131 | 132 | 133 | # RNN Based Language Model 134 | class EncoderText(nn.Module): 135 | 136 | def __init__(self, vocab_size, word_dim, embed_size, num_layers, 137 | use_bi_gru=False, no_txtnorm=False): 138 | super(EncoderText, self).__init__() 139 | self.embed_size = embed_size 140 | self.no_txtnorm = no_txtnorm 141 | 142 | # word embedding 143 | self.embed = nn.Embedding(vocab_size, word_dim) 144 | 145 | # caption embedding 146 | self.use_bi_gru = use_bi_gru 147 | self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru) 148 | 149 | self.init_weights() 150 | 151 | def init_weights(self): 152 | self.embed.weight.data.uniform_(-0.1, 0.1) 153 | 154 | def forward(self, x, lengths): 155 | """Handles variable size captions 156 | """ 157 | # Embed word ids to vectors 158 | x = self.embed(x) 159 | packed = pack_padded_sequence(x, lengths, batch_first=True) 160 | 161 | # Forward propagate RNN 162 | out, _ = self.rnn(packed) 163 | 164 | # Reshape *final* output to (batch_size, hidden_size) 165 | padded = pad_packed_sequence(out, batch_first=True) 166 | cap_emb, cap_len = padded 167 | 168 | if self.use_bi_gru: 169 | cap_emb = (cap_emb[:,:,:cap_emb.size(2)//2] + cap_emb[:,:,cap_emb.size(2)//2:])/2 170 | 171 | # normalization in the joint embedding space 172 | if not self.no_txtnorm: 173 | cap_emb = l2norm(cap_emb, dim=-1) 174 | 175 | return cap_emb, cap_len 176 | 177 | 178 | def func_attention(query, context, opt, smooth, eps=1e-8): 179 | """ 180 | query: (n_context, queryL, d) 181 | context: (n_context, sourceL, d) 182 | """ 183 | batch_size_q, queryL = query.size(0), query.size(1) 184 | batch_size, sourceL = context.size(0), context.size(1) 185 | 186 | 187 | # Get attention 188 | # --> (batch, d, queryL) 189 | queryT = torch.transpose(query, 1, 2) 190 | 191 | # (batch, sourceL, d)(batch, d, queryL) 192 | # --> (batch, sourceL, queryL) 193 | attn = torch.bmm(context, queryT) 194 | if opt.raw_feature_norm == "softmax": 195 | # --> (batch*sourceL, queryL) 196 | attn = attn.view(batch_size*sourceL, queryL) 197 | attn = nn.Softmax(dim=1)(attn) 198 | # --> (batch, sourceL, queryL) 199 | attn = attn.view(batch_size, sourceL, queryL) 200 | elif opt.raw_feature_norm == "l2norm": 201 | attn = l2norm(attn, 2) 202 | elif opt.raw_feature_norm == "clipped_l2norm": 203 | attn = nn.LeakyReLU(0.1)(attn) 204 | attn = l2norm(attn, 2) 205 | elif opt.raw_feature_norm == "l1norm": 206 | attn = l1norm(attn, 2) 207 | elif opt.raw_feature_norm == "clipped_l1norm": 208 | attn = nn.LeakyReLU(0.1)(attn) 209 | attn = l1norm(attn, 2) 210 | elif opt.raw_feature_norm == "clipped": 211 | attn = nn.LeakyReLU(0.1)(attn) 212 | elif opt.raw_feature_norm == "no_norm": 213 | pass 214 | else: 215 | raise ValueError("unknown first norm type:", opt.raw_feature_norm) 216 | # --> (batch, queryL, sourceL) 217 | attn = torch.transpose(attn, 1, 2).contiguous() 218 | # --> (batch*queryL, sourceL) 219 | attn = attn.view(batch_size*queryL, sourceL) 220 | attn = nn.Softmax(dim=1)(attn*smooth) 221 | # --> (batch, queryL, sourceL) 222 | attn = attn.view(batch_size, queryL, sourceL) 223 | # --> (batch, sourceL, queryL) 224 | attnT = torch.transpose(attn, 1, 2).contiguous() 225 | 226 | # -->
(batch, d, sourceL) 227 | contextT = torch.transpose(context, 1, 2) 228 | # (batch x d x sourceL)(batch x sourceL x queryL) 229 | # --> (batch, d, queryL) 230 | weightedContext = torch.bmm(contextT, attnT) 231 | # --> (batch, queryL, d) 232 | weightedContext = torch.transpose(weightedContext, 1, 2) 233 | 234 | return weightedContext, attnT 235 | 236 | 237 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 238 | """Returns cosine similarity between x1 and x2, computed along dim.""" 239 | w12 = torch.sum(x1 * x2, dim) 240 | w1 = torch.norm(x1, 2, dim) 241 | w2 = torch.norm(x2, 2, dim) 242 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 243 | 244 | 245 | def xattn_score_t2i(images, captions, cap_lens, opt): 246 | """ 247 | Images: (n_image, n_regions, d) matrix of images 248 | Captions: (n_caption, max_n_word, d) matrix of captions 249 | CapLens: (n_caption) array of caption lengths 250 | """ 251 | similarities = [] 252 | n_image = images.size(0) 253 | n_caption = captions.size(0) 254 | for i in range(n_caption): 255 | # Get the i-th text description 256 | n_word = cap_lens[i] 257 | cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous() 258 | # --> (n_image, n_word, d) 259 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 260 | """ 261 | word(query): (n_image, n_word, d) 262 | image(context): (n_image, n_regions, d) 263 | weiContext: (n_image, n_word, d) 264 | attn: (n_image, n_region, n_word) 265 | """ 266 | weiContext, attn = func_attention(cap_i_expand, images, opt, smooth=opt.lambda_softmax) 267 | cap_i_expand = cap_i_expand.contiguous() 268 | weiContext = weiContext.contiguous() 269 | # (n_image, n_word) 270 | row_sim = cosine_similarity(cap_i_expand, weiContext, dim=2) 271 | if opt.agg_func == 'LogSumExp': 272 | row_sim.mul_(opt.lambda_lse).exp_() 273 | row_sim = row_sim.sum(dim=1, keepdim=True) 274 | row_sim = torch.log(row_sim)/opt.lambda_lse 275 | elif opt.agg_func == 'Max': 276 | row_sim = row_sim.max(dim=1, keepdim=True)[0] 277 | elif opt.agg_func == 'Sum': 278 | row_sim = row_sim.sum(dim=1, keepdim=True) 279 | elif opt.agg_func == 'Mean': 280 | row_sim = row_sim.mean(dim=1, keepdim=True) 281 | else: 282 | raise ValueError("unknown aggfunc: {}".format(opt.agg_func)) 283 | similarities.append(row_sim) 284 | 285 | # (n_image, n_caption) 286 | similarities = torch.cat(similarities, 1) 287 | 288 | return similarities 289 | 290 | 291 | def xattn_score_i2t(images, captions, cap_lens, opt): 292 | """ 293 | Images: (batch_size, n_regions, d) matrix of images 294 | Captions: (batch_size, max_n_words, d) matrix of captions 295 | CapLens: (batch_size) array of caption lengths 296 | """ 297 | similarities = [] 298 | n_image = images.size(0) 299 | n_caption = captions.size(0) 300 | n_region = images.size(1) 301 | for i in range(n_caption): 302 | # Get the i-th text description 303 | n_word = cap_lens[i] 304 | cap_i = captions[i, :n_word, :].unsqueeze(0).contiguous() 305 | # (n_image, n_word, d) 306 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 307 | """ 308 | word(query): (n_image, n_word, d) 309 | image(context): (n_image, n_region, d) 310 | weiContext: (n_image, n_region, d) 311 | attn: (n_image, n_word, n_region) 312 | """ 313 | weiContext, attn = func_attention(images, cap_i_expand, opt, smooth=opt.lambda_softmax) 314 | # (n_image, n_region) 315 | row_sim = cosine_similarity(images, weiContext, dim=2) 316 | if opt.agg_func == 'LogSumExp': 317 | row_sim.mul_(opt.lambda_lse).exp_() 318 | row_sim = row_sim.sum(dim=1, keepdim=True) 319 | row_sim = torch.log(row_sim)/opt.lambda_lse 320 | elif 
opt.agg_func == 'Max': 321 | row_sim = row_sim.max(dim=1, keepdim=True)[0] 322 | elif opt.agg_func == 'Sum': 323 | row_sim = row_sim.sum(dim=1, keepdim=True) 324 | elif opt.agg_func == 'Mean': 325 | row_sim = row_sim.mean(dim=1, keepdim=True) 326 | else: 327 | raise ValueError("unknown aggfunc: {}".format(opt.agg_func)) 328 | similarities.append(row_sim) 329 | 330 | # (n_image, n_caption) 331 | similarities = torch.cat(similarities, 1) 332 | return similarities 333 | 334 | 335 | class ContrastiveLoss(nn.Module): 336 | """ 337 | Compute contrastive loss 338 | """ 339 | def __init__(self, opt, margin=0, max_violation=False): 340 | super(ContrastiveLoss, self).__init__() 341 | self.opt = opt 342 | self.margin = margin 343 | self.max_violation = max_violation 344 | 345 | def forward(self, im, s, s_l): 346 | # compute image-sentence score matrix 347 | if self.opt.cross_attn == 't2i': 348 | scores = xattn_score_t2i(im, s, s_l, self.opt) 349 | elif self.opt.cross_attn == 'i2t': 350 | scores = xattn_score_i2t(im, s, s_l, self.opt) 351 | else: 352 | raise ValueError("unknown cross_attn type: {}".format(self.opt.cross_attn)) 353 | diagonal = scores.diag().view(im.size(0), 1) 354 | d1 = diagonal.expand_as(scores) 355 | d2 = diagonal.t().expand_as(scores) 356 | 357 | # compare every diagonal score to scores in its column 358 | # caption retrieval 359 | cost_s = (self.margin + scores - d1).clamp(min=0) 360 | # compare every diagonal score to scores in its row 361 | # image retrieval 362 | cost_im = (self.margin + scores - d2).clamp(min=0) 363 | 364 | # clear diagonals 365 | mask = torch.eye(scores.size(0)) > .5 366 | I = Variable(mask) 367 | if torch.cuda.is_available(): 368 | I = I.cuda() 369 | cost_s = cost_s.masked_fill_(I, 0) 370 | cost_im = cost_im.masked_fill_(I, 0) 371 | 372 | # keep the maximum violating negative for each query 373 | if self.max_violation: 374 | cost_s = cost_s.max(1)[0] 375 | cost_im = cost_im.max(0)[0] 376 | return cost_s.sum() + cost_im.sum() 377 | 378 | 379 | class SCAN(object): 380 | """ 381 | Stacked Cross Attention Network (SCAN) model 382 | """ 383 | def __init__(self, opt): 384 | # Build Models 385 | self.grad_clip = opt.grad_clip 386 | self.img_enc = EncoderImage(opt.data_name, opt.img_dim, opt.embed_size, 387 | precomp_enc_type=opt.precomp_enc_type, 388 | no_imgnorm=opt.no_imgnorm) 389 | self.txt_enc = EncoderText(opt.vocab_size, opt.word_dim, 390 | opt.embed_size, opt.num_layers, 391 | use_bi_gru=opt.bi_gru, 392 | no_txtnorm=opt.no_txtnorm) 393 | if torch.cuda.is_available(): 394 | self.img_enc.cuda() 395 | self.txt_enc.cuda() 396 | cudnn.benchmark = True 397 | 398 | # Loss and Optimizer 399 | self.criterion = ContrastiveLoss(opt=opt, 400 | margin=opt.margin, 401 | max_violation=opt.max_violation) 402 | params = list(self.txt_enc.parameters()) 403 | params += list(self.img_enc.fc.parameters()) 404 | 405 | self.params = params 406 | 407 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 408 | 409 | self.Eiters = 0 410 | 411 | def state_dict(self): 412 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict()] 413 | return state_dict 414 | 415 | def load_state_dict(self, state_dict): 416 | self.img_enc.load_state_dict(state_dict[0]) 417 | self.txt_enc.load_state_dict(state_dict[1]) 418 | 419 | def train_start(self): 420 | """switch to train mode 421 | """ 422 | self.img_enc.train() 423 | self.txt_enc.train() 424 | 425 | def val_start(self): 426 | """switch to evaluate mode 427 | """ 428 | self.img_enc.eval() 429 | self.txt_enc.eval() 430 | 431 | def
forward_emb(self, images, captions, lengths, volatile=False): 432 | """Compute the image and caption embeddings 433 | """ 434 | # Set mini-batch dataset 435 | images = Variable(images, volatile=volatile) 436 | captions = Variable(captions, volatile=volatile) 437 | if torch.cuda.is_available(): 438 | images = images.cuda() 439 | captions = captions.cuda() 440 | 441 | # Forward 442 | img_emb = self.img_enc(images) 443 | 444 | # cap_emb (tensor), cap_lens (list) 445 | cap_emb, cap_lens = self.txt_enc(captions, lengths) 446 | return img_emb, cap_emb, cap_lens 447 | 448 | def forward_loss(self, img_emb, cap_emb, cap_len, **kwargs): 449 | """Compute the loss given pairs of image and caption embeddings 450 | """ 451 | loss = self.criterion(img_emb, cap_emb, cap_len) 452 | self.logger.update('Le', loss.data[0], img_emb.size(0)) 453 | return loss 454 | 455 | def train_emb(self, images, captions, lengths, ids=None, *args): 456 | """One training step given images and captions. 457 | """ 458 | self.Eiters += 1 459 | self.logger.update('Eit', self.Eiters) 460 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 461 | 462 | # compute the embeddings 463 | img_emb, cap_emb, cap_lens = self.forward_emb(images, captions, lengths) 464 | 465 | # measure accuracy and record loss 466 | self.optimizer.zero_grad() 467 | loss = self.forward_loss(img_emb, cap_emb, cap_lens) 468 | 469 | # compute gradient and do SGD step 470 | loss.backward() 471 | if self.grad_clip > 0: 472 | clip_grad_norm(self.params, self.grad_clip) 473 | self.optimizer.step() 474 | --------------------------------------------------------------------------------
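The retrieval metrics reported by the i2t and t2i routines in evaluation.py above all derive from the rank of the ground-truth item in a sorted similarity matrix, where each of the N images is paired with 5 captions. Below is a minimal numpy sketch of the t2i direction on toy data, assuming a random (N, 5N) score matrix with the matched pairs artificially boosted:

import numpy as np

# Toy setup: 3 images, 5 captions per image, random scores with the
# ground-truth pairs boosted so the metrics are non-trivial.
np.random.seed(0)
n_images = 3
sims = np.random.rand(n_images, 5 * n_images)                  # (N, 5N) image-to-caption scores
gt_cols = (np.arange(n_images) * 5)[:, None] + np.arange(5)    # columns of the 5 matched captions
sims[np.arange(n_images)[:, None], gt_cols] += 0.5

# Text->Image: transpose to (5N, N) and rank the matched image for every
# caption, exactly as t2i() does.
sims_t = sims.T
ranks = np.zeros(5 * n_images)
for index in range(n_images):
    for i in range(5):
        inds = np.argsort(sims_t[5 * index + i])[::-1]
        ranks[5 * index + i] = np.where(inds == index)[0][0]

r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)   # Recall@1
r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)   # Recall@5
medr = np.floor(np.median(ranks)) + 1                   # median rank (1-based)
print(r1, r5, medr)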
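EncoderImage in model.py is a thin factory: for both precomp_enc_type values the encoder is a single linear layer that projects each precomputed region feature into the joint embedding space, followed by per-region l2 normalization unless no_imgnorm is set. A usage sketch; the batch size and the 36-region / 2048-d / 1024-d shapes are only illustrative:

import torch
from torch.autograd import Variable
from model import EncoderImage

# 8 images, each represented by 36 bottom-up-attention region features of size 2048
img_enc = EncoderImage('coco_precomp', img_dim=2048, embed_size=1024,
                       precomp_enc_type='basic')
regions = Variable(torch.randn(8, 36, 2048))
emb = img_enc(regions)   # projected into the joint space
print(emb.size())        # (8, 36, 1024); each region vector has unit l2 norm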
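func_attention is the core of the stacked cross attention: it builds the raw dot-product affinity matrix between query and context vectors, normalizes it according to opt.raw_feature_norm, then applies a softmax over the context dimension scaled by smooth (lambda_softmax), and returns one attended context vector per query. A shape walkthrough in the t2i direction (words attend over image regions); the option values below mirror common settings for this code but are assumptions here:

from argparse import Namespace

import torch
from torch.autograd import Variable

from model import func_attention

opt = Namespace(raw_feature_norm='clipped_l2norm', lambda_softmax=9.0)
n_image, n_word, n_region, d = 4, 7, 36, 1024
words = Variable(torch.randn(n_image, n_word, d))       # query:   (batch, queryL, d)
regions = Variable(torch.randn(n_image, n_region, d))   # context: (batch, sourceL, d)

weiContext, attn = func_attention(words, regions, opt, smooth=opt.lambda_softmax)
print(weiContext.size())    # (4, 7, 1024): one attended image vector per word
print(attn.size())          # (4, 36, 7):  attention weights, regions x words
print(attn.sum(dim=1)[0])   # the softmax over regions sums to 1 for every word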
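xattn_score_t2i turns that attention into a pairwise similarity: for every caption it attends each word over every image's regions, takes the cosine similarity between each word and its attended image vector, and aggregates over words with opt.agg_func (LogSumExp, Max, Sum or Mean). The result is an (n_image, n_caption) matrix whose diagonal holds the matched pairs when images and captions are aligned. A sketch with assumed hyperparameter values:

from argparse import Namespace

import torch
from torch.autograd import Variable

from model import xattn_score_t2i

opt = Namespace(raw_feature_norm='clipped_l2norm', agg_func='LogSumExp',
                lambda_softmax=9.0, lambda_lse=6.0)
images = Variable(torch.randn(4, 36, 1024))      # (n_image, n_region, d)
captions = Variable(torch.randn(4, 20, 1024))    # (n_caption, max_n_word, d), padded
cap_lens = [12, 20, 9, 7]                        # true caption lengths

scores = xattn_score_t2i(images, captions, cap_lens, opt)
print(scores.size())    # (4, 4) image-caption similarity matrix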
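ContrastiveLoss wraps this scoring in a hinge objective: every off-diagonal score is pushed below its matching diagonal score by at least margin, and with max_violation=True only the hardest negative in the batch contributes for each query. A standalone sketch of just that hinge, starting from an already computed (batch, batch) score matrix; the margin value is an assumption:

import torch

margin = 0.2
scores = torch.randn(5, 5)              # scores[i, j]: image i vs caption j
diagonal = scores.diag().view(5, 1)

# hinge against the matched score in the same row (caption retrieval)
cost_s = (margin + scores - diagonal.expand_as(scores)).clamp(min=0)
# hinge against the matched score in the same column (image retrieval)
cost_im = (margin + scores - diagonal.t().expand_as(scores)).clamp(min=0)

# never penalize the positive pair itself
mask = torch.eye(5) > .5
cost_s = cost_s.masked_fill_(mask, 0)
cost_im = cost_im.masked_fill_(mask, 0)

# max_violation: keep only the hardest negative per row / per column
loss = cost_s.max(1)[0].sum() + cost_im.max(0)[0].sum()
print(loss)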