├── cc152k.png ├── framework.png ├── flickr_mscoco.png ├── well_annotated.png ├── .gitignore ├── vocab.py ├── README.md ├── data.py ├── LICENSE ├── model └── SGRAF.py └── evaluation.py /cc152k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qxzha/UCPM/HEAD/cc152k.png -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qxzha/UCPM/HEAD/framework.png -------------------------------------------------------------------------------- /flickr_mscoco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qxzha/UCPM/HEAD/flickr_mscoco.png -------------------------------------------------------------------------------- /well_annotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qxzha/UCPM/HEAD/well_annotated.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Stacked Cross Attention Network implementation based on 3 | # https://arxiv.org/abs/1803.08024. 
4 | # "Stacked Cross Attention for Image-Text Matching" 5 | # Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He 6 | # 7 | # Writen by Kuang-Huei Lee, 2018 8 | # --------------------------------------------------------------- 9 | """Vocabulary wrapper""" 10 | 11 | import nltk 12 | from collections import Counter 13 | import argparse 14 | import os 15 | import json 16 | import csv 17 | 18 | annotations = { 19 | 'coco_precomp': ['train_caps.txt', 'dev_caps.txt'], 20 | 'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'], 21 | 'cc152k_precomp': ['train_caps.tsv', 'dev_caps.tsv'], 22 | 'cc510k_precomp': ['train_caps.txt', 'dev_caps.txt'], 23 | } 24 | 25 | 26 | class Vocabulary(object): 27 | """Simple vocabulary wrapper.""" 28 | 29 | def __init__(self): 30 | self.word2idx = {} 31 | self.idx2word = {} 32 | self.idx = 0 33 | 34 | def add_word(self, word): 35 | if word not in self.word2idx: 36 | self.word2idx[word] = self.idx 37 | self.idx2word[self.idx] = word 38 | self.idx += 1 39 | 40 | def __call__(self, word): 41 | if word not in self.word2idx: 42 | return self.word2idx[''] 43 | return self.word2idx[word] 44 | 45 | def __len__(self): 46 | return len(self.word2idx) 47 | 48 | 49 | def serialize_vocab(vocab, dest): 50 | d = {} 51 | d['word2idx'] = vocab.word2idx 52 | d['idx2word'] = vocab.idx2word 53 | d['idx'] = vocab.idx 54 | with open(dest, "w") as f: 55 | json.dump(d, f) 56 | 57 | 58 | def deserialize_vocab(src): 59 | with open(src) as f: 60 | d = json.load(f) 61 | vocab = Vocabulary() 62 | vocab.word2idx = d['word2idx'] 63 | vocab.idx2word = d['idx2word'] 64 | vocab.idx = d['idx'] 65 | vocab.add_word('') 66 | return vocab 67 | 68 | 69 | def deserialize_vocab_glove(src): 70 | with open(src) as f: 71 | d = json.load(f) 72 | vocab = Vocabulary() 73 | vocab.word2idx = d['word2idx'] 74 | vocab.idx2word = {v: k for k, v in vocab.word2idx.items()} 75 | vocab.idx = max(vocab.idx2word) 76 | return vocab 77 | 78 | 79 | def from_txt(txt): 80 | captions = [] 81 | with open(txt, 'r') as f: 82 | for line in f: 83 | captions.append(line.strip()) 84 | return captions 85 | 86 | 87 | def from_tsv(tsv): 88 | captions = [] 89 | img_ids = [] 90 | with open(tsv) as f: 91 | tsvreader = csv.reader(f, delimiter='\t') 92 | for line in tsvreader: 93 | captions.append(line[1]) 94 | img_ids.append(line[0]) 95 | return captions 96 | 97 | 98 | def build_vocab(data_path, data_name, caption_file, threshold): 99 | """Build a simple vocabulary wrapper.""" 100 | counter = Counter() 101 | for path in caption_file[data_name]: 102 | if data_name == 'cc152k_precomp': 103 | full_path = os.path.join(os.path.join(data_path, data_name), path) 104 | captions = from_tsv(full_path) 105 | elif data_name in ['coco_precomp', 'f30k_precomp','cc510k_precomp']: 106 | full_path = os.path.join(os.path.join(data_path, data_name), path) 107 | captions = from_txt(full_path) 108 | else: 109 | raise NotImplementedError('Not support!') 110 | 111 | for i, caption in enumerate(captions): 112 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 113 | counter.update(tokens) 114 | 115 | if i % 1000 == 0: 116 | print("[%d/%d] tokenized the captions." % (i, len(captions))) 117 | 118 | # Discard if the occurrence of the word is less than min_word_cnt. 119 | words = [word for word, cnt in counter.items() if cnt >= threshold] 120 | 121 | # Create a vocab wrapper and add some special tokens. 
122 | vocab = Vocabulary() 123 | vocab.add_word('') 124 | vocab.add_word('') 125 | vocab.add_word('') 126 | vocab.add_word('') 127 | 128 | # Add words to the vocabulary. 129 | for i, word in enumerate(words): 130 | vocab.add_word(word) 131 | return vocab 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UCPM: Uncertainty-Guided Cross-Modal Retrieval with Partially Mismatched Pairs 2 | 3 | # 🚀 Seeking a PhD Opportunity 4 | I am currently looking for exciting PhD opportunities. If you know of any openings or can connect me with potential advisors, please feel free to reach out of me. 5 | I would greatly appreciate your support! 6 | 7 | **Email**: [quanxing.zha@gmail.com](quanxing.zha@gmail.com) 8 | 9 | --- 10 | 11 | The official pytorch implementation of [UCPM: Uncertainty-Guided Cross-Modal Retrieval with Partially Mismatched Pairs]() (submitted to IEEE TIP). 12 | 13 | ## Introduction 14 | 15 | ### UCPM framework 16 | 17 | 18 | ## Requirements 19 | - Python 3.8 20 | - PyTorch 1.20.0 21 | - numpy 22 | - scikit-learn 23 | - Punkt Sentence Tokenizer: 24 | 25 | ``` 26 | import nltk 27 | nltk.download() 28 | > d punkt 29 | ``` 30 | (Optional) if the above download failed, you can manually download it from [here](https://drive.google.com/file/d/1eY9FnCm1YbnU5PHwiay7agQjuwSPz3_Z/view?usp=drive_link). 31 | The directory structure is: 32 | ``` 33 | /home/username/ 34 | ├── nltk_data 35 | │ ├── tokenizers 36 | │ ├── punkt 37 | │ ├── czech.pickle 38 | │ ├── french.pickle 39 | │ ├── polish.pickle 40 | │ ├── ...... 41 | ``` 42 | 43 | ## DATASETS 44 | Our directory structure of ```data```. 45 | ``` 46 | data 47 | ├── f30k_precomp # pre-computed BUTD region features for Flickr30K, provided by SCAN 48 | │ ├── train_ids.txt 49 | │ ├── train_caps.txt 50 | │ ├── ...... 51 | │ 52 | ├── coco_precomp # pre-computed BUTD region features for COCO, provided by SCAN 53 | │ ├── train_ids.txt 54 | │ ├── train_caps.txt 55 | │ ├── ...... 56 | │ 57 | ├── cc152k_precomp # pre-computed BUTD region features for cc152k, provided by NCR 58 | │ ├── train_ids.txt 59 | │ ├── train_caps.tsv 60 | │ ├── ...... 61 | │ 62 | └── vocab # vocab files provided by SCAN and NCR 63 | ├── f30k_precomp_vocab.json 64 | ├── coco_precomp_vocab.json 65 | └── cc152k_precomp_vocab.json 66 | ``` 67 | ### Downloads 68 | We follow [NCR](https://github.com/XLearning-SCU/2021-NeurIPS-NCR) to obtain image features and vocabularies. 69 | [Download Dataset](https://ncr-paper.cdn.bcebos.com/data/NCR-data.tar) 70 | 71 | ### Noise Index 72 | If you want to experiment with the same noise index as in this paper, the noise index files can be downloaded from [here](https://drive.google.com/file/d/1JG0-dIS_d8SdaUw-Bbgf00rL8nOkHA5J/view?usp=drive_link). 73 | 74 | ## Training and Evaluation 75 | ### Training 76 | Coming soon 77 | 78 | ### Pre-trained models and Evaluation 79 | The pre-trained models are available here: 80 | 81 | 1. CC152K model [Download](https://drive.google.com/file/d/1pmnNmxZDcO99Jb0li1_vU9kkrz3N7wAO/view?usp=drive_link) 82 | 2. F30k 20% MRate model [Download](https://drive.google.com/file/d/1Ut15QxkkaEjpDVIjU4xZb58HrKWuZcNA/view?usp=drive_link) 83 | 3. F30k 40% MRate model [Download](https://drive.google.com/file/d/1E83kUnr1gwrvPB0Ry4zFJwA0g3dfLsPp/view?usp=drive_link) 84 | 4. F30k 60% MRate model [Download](https://drive.google.com/file/d/10TidbMZ68iO0ERRF9M6_wxSDLXDB0QKv/view?usp=drive_link) 85 | 5. 
F30k 80% MRate model [Download](https://drive.google.com/file/d/1mnG8Nw9ZhpnCEIVYMBcEPgCgOsSd8SlG/view?usp=drive_link) 86 | 6. COCO 20% MRate model [Download](https://drive.google.com/file/d/1Ck6bReHF0rQjNeVEwmRKnBn_swBy2ej8/view?usp=drive_link) 87 | 7. COCO 40% MRate model [Download](https://drive.google.com/file/d/1nwwR8sHbJlz5fj7yHCrz9Q4dXmX4PLdM/view?usp=drive_link) 88 | 8. COCO 60% MRate model [Download](https://drive.google.com/file/d/1WG-GzfljnwdAoj9DlFIKfzw_YN4kXd7j/view?usp=drive_link) 89 | 9. COCO 80% MRate model [Download](https://drive.google.com/file/d/1_rbC88LOKthc7fmVxm3YPY4trKKoEFtg/view?usp=drive_link) 90 | 91 | Modify the ```data_path```, ```vocab_path```, and ```model_paths``` in the ```eval.py``` file and run it. 92 | ``` 93 | python eval.py 94 | ``` 95 | 96 | ## Experiment Results 97 | 98 | ### Results on Well-Annotated Datasets 99 | 100 | 101 | ### Results on Simulated PMPs 102 | 103 | 104 | ### Results on Real-World PMPs 105 | 106 | 107 | ## License 108 | [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) 109 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """Dataloader""" 2 | 3 | import csv 4 | import torch 5 | import torch.utils.data as data 6 | import os 7 | import nltk 8 | import numpy as np 9 | import random 10 | import h5py 11 | 12 | 13 | class PrecompDataset(data.Dataset): 14 | """ 15 | Load precomputed captions and image features 16 | Possible options: f30k_precomp, coco_precomp, cc152k_precomp 17 | """ 18 | 19 | def __init__(self, data_path, data_split, vocab, opt=None, logger=None): 20 | self.vocab = vocab 21 | self.data_split = data_split 22 | loc = data_path + '/' 23 | self.module = opt.module_name 24 | 25 | # load the raw captions 26 | self.captions = [] 27 | 28 | if 'cc152k' in opt.data_name: 29 | with open(loc + '%s_caps.tsv' % data_split) as f: 30 | tsvreader = csv.reader(f, delimiter='\t') 31 | for line in tsvreader: 32 | self.captions.append(line[1].strip()) 33 | else: 34 | with open(loc + '%s_caps.txt' % data_split, 'r', encoding='utf-8') as f: 35 | for line in f.readlines(): 36 | self.captions.append(line.strip()) 37 | 38 | self.images = np.load(loc + '%s_ims.npy' % data_split) 39 | 40 | # rkiros data has redundancy in images, we divide by 5 41 | img_len = self.images.shape[0] 42 | self.length = len(self.captions) 43 | if img_len != self.length: 44 | self.im_div = 5 45 | else: 46 | self.im_div = 1 47 | 48 | if data_split == 'dev': 49 | self.length = 5000 50 | if 'cc152k' in opt.data_name: 51 | self.length = 1000 52 | 53 | self.noise_type = 'NCR' 54 | if self.noise_type == 'DECL': 55 | self.noisy_inx = np.arange(img_len) 56 | if data_split == 'train' and opt.noise_ratio > 0.0: 57 | noise_file = opt.noise_file 58 | if os.path.exists(noise_file): 59 | logger.info('=> load noisy index from {}'.format(noise_file)) 60 | self.noisy_inx = np.load(noise_file) 61 | else: 62 | noisy_ratio = opt.noise_ratio 63 | inx = np.arange(img_len) 64 | np.random.shuffle(inx) 65 | noisy_inx = inx[0: int(noisy_ratio * img_len)] 66 | shuffle_noisy_inx = np.array(noisy_inx) 67 | np.random.shuffle(shuffle_noisy_inx) 68 | self.noisy_inx[noisy_inx] = shuffle_noisy_inx 69 | np.save(noise_file, self.noisy_inx) 70 | logger.info('Noisy rate: %g' % noisy_ratio) 71 | 72 | self.real_correspondence = np.zeros(self.length) 73 | for i in range(self.length): 74 | if self.noisy_inx[i // self.im_div] == i // 5: 75 | self.real_correspondence[i] = 
1.0 76 | else: 77 | self.noisy_inx = np.arange(0, self.length) // self.im_div 78 | if data_split == 'train' and opt.noise_ratio > 0.0: 79 | noise_file = opt.noise_file 80 | if os.path.exists(noise_file): 81 | logger.info('=> load noisy index from {}'.format(noise_file)) 82 | self.noisy_inx = np.load(noise_file) 83 | else: 84 | idx = np.arange(self.length) 85 | np.random.shuffle(idx) 86 | noise_length = int(opt.noise_ratio * self.length) 87 | shuffle_index = self.noisy_inx[idx[:noise_length]] 88 | np.random.shuffle(shuffle_index) 89 | self.noisy_inx[idx[:noise_length]] = shuffle_index 90 | np.save(noise_file, self.noisy_inx) 91 | logger.info("=> save noisy index to {}".format(noise_file)) 92 | 93 | self.real_correspondence = np.zeros(self.length) 94 | for i in range(self.length): 95 | if self.noisy_inx[i] == i // 5: 96 | self.real_correspondence[i] = 1.0 97 | print(data_split, self.noisy_inx.shape, self.real_correspondence.sum()) 98 | 99 | def process_caption(self, caption, caption_enhance): 100 | enhance = caption_enhance if self.data_split == 'train' else False 101 | if not enhance: 102 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 103 | caption = list() 104 | caption.append(self.vocab('')) 105 | caption.extend([self.vocab(token) for token in tokens]) 106 | caption.append(self.vocab('')) 107 | target = torch.Tensor(caption) 108 | return target 109 | else: 110 | # Convert caption (string) to word ids. 111 | tokens = ['', ] 112 | tokens.extend(nltk.tokenize.word_tokenize(caption.lower())) 113 | tokens.append('') 114 | deleted_idx = [] 115 | for i, token in enumerate(tokens): 116 | prob = random.random() 117 | if prob < 0.20: 118 | prob /= 0.20 119 | # 50% randomly change token to mask token 120 | if prob < 0.5: 121 | tokens[i] = self.vocab.word2idx[''] 122 | # 10% randomly change token to random token 123 | elif prob < 0.6: 124 | tokens[i] = random.randrange(len(self.vocab)) 125 | # 40% randomly remove the token 126 | else: 127 | tokens[i] = self.vocab(token) 128 | deleted_idx.append(i) 129 | else: 130 | tokens[i] = self.vocab(token) 131 | if len(deleted_idx) != 0: 132 | tokens = [tokens[i] for i in range(len(tokens)) if i not in deleted_idx] 133 | target = torch.Tensor(tokens) 134 | return target 135 | 136 | def process_image(self, image, img_enhance): 137 | enhance = img_enhance if self.data_split == 'train' else False 138 | if enhance: # Size augmentation on region features. 139 | num_features = image.shape[0] 140 | rand_list = np.random.rand(num_features) 141 | tmp = image[np.where(rand_list > 0.20)] 142 | while tmp.size(1) <= 1: 143 | rand_list = np.random.rand(num_features) 144 | tmp = image[np.where(rand_list > 0.20)] 145 | return tmp 146 | else: 147 | return image 148 | 149 | def process_image_2(self, image, img_enhance): 150 | enhance = img_enhance if self.data_split == 'train' else False 151 | if enhance: # Size augmentation on region features. 
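            # Both image augmentations draw one random value per region and treat
            # values below 0.2 as "dropped" (re-sampling the mask via the guard
            # loop). process_image above returns only the kept regions, so the
            # region count varies and collate_fn re-pads it; this variant keeps
            # the tensor shape fixed and instead overwrites the dropped regions
            # with a tiny constant (1e-10), which suits the SGRAF branch whose
            # VisualSA module assumes a fixed number of regions per image.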
152 | num_features = image.shape[0] 153 | rand_list = np.random.rand(num_features) 154 | tmp = image[np.where(rand_list > 0.20)] 155 | while tmp.size(1) <= 1: 156 | rand_list = np.random.rand(num_features) 157 | tmp = image[np.where(rand_list > 0.20)] 158 | image[np.where(rand_list < 0.20)] = 1e-10 159 | return image 160 | else: 161 | return image 162 | 163 | def __getitem__(self, index): 164 | # handle the image redundancy 165 | caption = self.captions[index] 166 | if self.data_split == 'train': 167 | if self.noise_type == 'DECL': 168 | img_id = self.noisy_inx[index // self.im_div] 169 | else: 170 | img_id = self.noisy_inx[index] # NCR 171 | else: 172 | img_id = index // self.im_div 173 | image = torch.Tensor(self.images[img_id]) 174 | # Convert caption (string) to word ids. 175 | a = True 176 | if self.module == 'VSEinfty': 177 | target = self.process_caption(caption, caption_enhance=a) 178 | image = self.process_image(image, img_enhance=a) 179 | else: 180 | target = self.process_caption(caption, caption_enhance=a) 181 | image = self.process_image_2(image, img_enhance=a) 182 | 183 | return image, target, index 184 | 185 | def __len__(self): 186 | return self.length 187 | 188 | def get_real_correspondence(self): 189 | return self.real_correspondence 190 | 191 | 192 | def collate_fn(data): 193 | # Sort a data list by caption length 194 | data.sort(key=lambda x: len(x[1]), reverse=True) 195 | images, captions, ids = zip(*data) 196 | 197 | img_lengths = [len(image) for image in images] 198 | all_images = torch.zeros(len(images), max(img_lengths), images[0].size(-1)) 199 | for i, image in enumerate(images): 200 | end = img_lengths[i] 201 | all_images[i, :end] = image[:end] 202 | img_lengths = torch.Tensor(img_lengths) 203 | # Merget captions 204 | lengths = [len(cap) for cap in captions] 205 | targets = torch.zeros(len(captions), max(lengths)).long() 206 | for i, cap in enumerate(captions): 207 | end = lengths[i] 208 | targets[i, :end] = cap[:end] 209 | lengths = torch.Tensor(lengths) 210 | return all_images, img_lengths, targets, lengths, list(ids) 211 | 212 | 213 | def get_precomp_loader(data_path, data_split, vocab, opt, batch_size=100, 214 | shuffle=True, num_workers=2, logger=None): 215 | if shuffle: 216 | dset = PrecompDataset(data_path, data_split, vocab, opt, logger) 217 | data_loader = torch.utils.data.DataLoader(dataset=dset, 218 | batch_size=batch_size, 219 | shuffle=shuffle, 220 | pin_memory=False, 221 | num_workers=num_workers, 222 | collate_fn=collate_fn) 223 | return data_loader 224 | else: 225 | dset = PrecompDataset(data_path, data_split, vocab, opt, logger) 226 | 227 | real_correspondence = dset.get_real_correspondence() 228 | 229 | data_loader = torch.utils.data.DataLoader(dataset=dset, 230 | batch_size=batch_size, 231 | shuffle=shuffle, 232 | pin_memory=False, 233 | num_workers=num_workers, 234 | collate_fn=collate_fn) 235 | # return data_loader, real_correspondence 236 | 237 | return data_loader 238 | 239 | 240 | def get_loaders(data_name, vocab, batch_size, workers, opt, logger): 241 | # get the data path 242 | dpath = os.path.join(opt.data_path, data_name) 243 | 244 | # get the train_loader 245 | train_loader = get_precomp_loader(dpath, 'train', vocab, opt, 246 | batch_size, True, workers, logger) 247 | # get the val_loader 248 | val_loader = get_precomp_loader(dpath, 'dev', vocab, opt, 249 | batch_size, False, workers, logger) 250 | # get the test_loader 251 | test_loader = get_precomp_loader(dpath, 'test', vocab, opt, 252 | batch_size, False, workers, logger) 253 | 
return train_loader, val_loader, test_loader 254 | 255 | 256 | def get_test_loader(split_name, data_name, vocab, batch_size, workers, opt): 257 | # get the data path 258 | data_path = "/data" 259 | dpath = os.path.join(data_path, data_name) 260 | 261 | # get the test_loader 262 | test_loader = get_precomp_loader(dpath, split_name, vocab, opt, 263 | batch_size, False, workers) 264 | return test_loader 265 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /model/SGRAF.py: -------------------------------------------------------------------------------- 1 | """SGRAF model""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torch.nn.functional as F 7 | 8 | import torch.backends.cudnn as cudnn 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | from torch.nn.utils.clip_grad import clip_grad_norm_ 11 | 12 | import numpy as np 13 | from collections import OrderedDict 14 | 15 | 16 | def l1norm(X, dim, eps=1e-8): 17 | """L1-normalize columns of X""" 18 | norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps 19 | X = torch.div(X, norm) 20 | return X 21 | 22 | 23 | def l2norm(X, dim=-1, eps=1e-8): 24 | """L2-normalize columns of X""" 25 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 26 | X = torch.div(X, norm) 27 | return X 28 | 29 | 30 | def cosine_sim(x1, x2, dim=-1, eps=1e-8): 31 | """Returns cosine similarity between x1 and x2, computed along dim.""" 32 | w12 = torch.sum(x1 * x2, dim) 33 | w1 = torch.norm(x1, 2, dim) 34 | w2 = torch.norm(x2, 2, dim) 35 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 36 | 37 | 38 | class EncoderImage(nn.Module): 39 | """ 40 | Build local region representations by common-used FC-layer. 41 | Args: - images: raw local detected regions, shape: (batch_size, 36, 2048). 42 | Returns: - img_emb: finial local region embeddings, shape: (batch_size, 36, 1024). 43 | """ 44 | def __init__(self, img_dim, embed_size, no_imgnorm=False): 45 | super(EncoderImage, self).__init__() 46 | self.embed_size = embed_size 47 | self.no_imgnorm = no_imgnorm 48 | self.fc = nn.Linear(img_dim, embed_size) 49 | 50 | self.init_weights() 51 | 52 | def init_weights(self): 53 | """Xavier initialization for the fully connected layer""" 54 | r = np.sqrt(6.) / np.sqrt(self.fc.in_features + 55 | self.fc.out_features) 56 | self.fc.weight.data.uniform_(-r, r) 57 | self.fc.bias.data.fill_(0) 58 | 59 | def forward(self, images): 60 | """Extract image feature vectors.""" 61 | # assuming that the precomputed features are already l2-normalized 62 | img_emb = self.fc(images) 63 | 64 | # normalize in the joint embedding space 65 | if not self.no_imgnorm: 66 | img_emb = l2norm(img_emb, dim=-1) 67 | 68 | return img_emb 69 | 70 | def load_state_dict(self, state_dict): 71 | """Overwrite the default one to accept state_dict from Full model""" 72 | own_state = self.state_dict() 73 | new_state = OrderedDict() 74 | for name, param in state_dict.items(): 75 | if name in own_state: 76 | new_state[name] = param 77 | 78 | super(EncoderImage, self).load_state_dict(new_state) 79 | 80 | 81 | class EncoderText(nn.Module): 82 | """ 83 | Build local word representations by common-used Bi-GRU or GRU. 84 | Args: - images: raw local word ids, shape: (batch_size, L). 85 | Returns: - img_emb: final local word embeddings, shape: (batch_size, L, 1024). 
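    A shape-only usage sketch (sizes here are hypothetical):
        enc = EncoderText(vocab_size=8000, word_dim=300, embed_size=1024,
                          num_layers=1, use_bi_gru=True)
        caps = torch.randint(1, 8000, (32, 20))   # (batch, max_len) word ids
        lens = [20] * 32                          # caption lengths
        cap_emb = enc(caps, lens)                 # -> (32, 20, 1024)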
86 | """ 87 | def __init__(self, vocab_size, word_dim, embed_size, num_layers, 88 | use_bi_gru=False, no_txtnorm=False): 89 | super(EncoderText, self).__init__() 90 | self.embed_size = embed_size 91 | self.no_txtnorm = no_txtnorm 92 | 93 | # word embedding 94 | self.embed = nn.Embedding(vocab_size, word_dim) 95 | self.dropout = nn.Dropout(0.4) 96 | 97 | # caption embedding 98 | self.use_bi_gru = use_bi_gru 99 | self.cap_rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True, bidirectional=use_bi_gru) 100 | 101 | self.init_weights() 102 | 103 | def init_weights(self): 104 | self.embed.weight.data.uniform_(-0.1, 0.1) 105 | 106 | def forward(self, captions, lengths): 107 | """Handles variable size captions""" 108 | # embed word ids to vectors 109 | cap_emb = self.embed(captions) 110 | cap_emb = self.dropout(cap_emb) 111 | 112 | # pack the caption 113 | packed = pack_padded_sequence(cap_emb, lengths, batch_first=True, enforce_sorted = False) 114 | 115 | # forward propagate RNN 116 | out, _ = self.cap_rnn(packed) 117 | 118 | # reshape output to (batch_size, hidden_size) 119 | cap_emb, _ = pad_packed_sequence(out, batch_first=True) 120 | 121 | if self.use_bi_gru: 122 | cap_emb = (cap_emb[:, :, :cap_emb.size(2)//2] + cap_emb[:, :, cap_emb.size(2)//2:])/2 123 | 124 | # normalization in the joint embedding space 125 | if not self.no_txtnorm: 126 | cap_emb = l2norm(cap_emb, dim=-1) 127 | 128 | return cap_emb 129 | 130 | 131 | class VisualSA(nn.Module): 132 | """ 133 | Build global image representations by self-attention. 134 | Args: - local: local region embeddings, shape: (batch_size, 36, 1024) 135 | - raw_global: raw image by averaging regions, shape: (batch_size, 1024) 136 | Returns: - new_global: final image by self-attention, shape: (batch_size, 1024). 137 | """ 138 | def __init__(self, embed_dim, dropout_rate, num_region): 139 | super(VisualSA, self).__init__() 140 | 141 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 142 | nn.BatchNorm1d(num_region), 143 | nn.Tanh(), nn.Dropout(dropout_rate)) 144 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 145 | nn.BatchNorm1d(embed_dim), 146 | nn.Tanh(), nn.Dropout(dropout_rate)) 147 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 148 | 149 | self.init_weights() 150 | self.softmax = nn.Softmax(dim=1) 151 | 152 | def init_weights(self): 153 | for embeddings in self.children(): 154 | for m in embeddings: 155 | if isinstance(m, nn.Linear): 156 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 157 | m.weight.data.uniform_(-r, r) 158 | m.bias.data.fill_(0) 159 | elif isinstance(m, nn.BatchNorm1d): 160 | m.weight.data.fill_(1) 161 | m.bias.data.zero_() 162 | 163 | def forward(self, local, raw_global): 164 | # compute embedding of local regions and raw global image 165 | l_emb = self.embedding_local(local) 166 | g_emb = self.embedding_global(raw_global) 167 | 168 | # compute the normalized weights, shape: (batch_size, 36) 169 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 170 | common = l_emb.mul(g_emb) 171 | weights = self.embedding_common(common).squeeze(2) 172 | weights = self.softmax(weights) 173 | 174 | # compute final image, shape: (batch_size, 1024) 175 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 176 | new_global = l2norm(new_global, dim=-1) 177 | 178 | return new_global 179 | 180 | 181 | class TextSA(nn.Module): 182 | """ 183 | Build global text representations by self-attention. 
184 | Args: - local: local word embeddings, shape: (batch_size, L, 1024) 185 | - raw_global: raw text by averaging words, shape: (batch_size, 1024) 186 | Returns: - new_global: final text by self-attention, shape: (batch_size, 1024). 187 | """ 188 | 189 | def __init__(self, embed_dim, dropout_rate): 190 | super(TextSA, self).__init__() 191 | 192 | self.embedding_local = nn.Sequential(nn.Linear(embed_dim, embed_dim), 193 | nn.Tanh(), nn.Dropout(dropout_rate)) 194 | self.embedding_global = nn.Sequential(nn.Linear(embed_dim, embed_dim), 195 | nn.Tanh(), nn.Dropout(dropout_rate)) 196 | self.embedding_common = nn.Sequential(nn.Linear(embed_dim, 1)) 197 | 198 | self.init_weights() 199 | self.softmax = nn.Softmax(dim=1) 200 | 201 | def init_weights(self): 202 | for embeddings in self.children(): 203 | for m in embeddings: 204 | if isinstance(m, nn.Linear): 205 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 206 | m.weight.data.uniform_(-r, r) 207 | m.bias.data.fill_(0) 208 | elif isinstance(m, nn.BatchNorm1d): 209 | m.weight.data.fill_(1) 210 | m.bias.data.zero_() 211 | 212 | def forward(self, local, raw_global): 213 | # compute embedding of local words and raw global text 214 | l_emb = self.embedding_local(local) 215 | g_emb = self.embedding_global(raw_global) 216 | 217 | # compute the normalized weights, shape: (batch_size, L) 218 | g_emb = g_emb.unsqueeze(1).repeat(1, l_emb.size(1), 1) 219 | common = l_emb.mul(g_emb) 220 | weights = self.embedding_common(common).squeeze(2) 221 | weights = self.softmax(weights) 222 | 223 | # compute final text, shape: (batch_size, 1024) 224 | new_global = (weights.unsqueeze(2) * local).sum(dim=1) 225 | new_global = l2norm(new_global, dim=-1) 226 | 227 | return new_global 228 | 229 | 230 | class GraphReasoning(nn.Module): 231 | """ 232 | Perform the similarity graph reasoning with a full-connected graph 233 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 234 | Returns; - sim_sgr: reasoned graph nodes after several steps, shape: (batch_size, L+1, 256) 235 | """ 236 | def __init__(self, sim_dim): 237 | super(GraphReasoning, self).__init__() 238 | 239 | self.graph_query_w = nn.Linear(sim_dim, sim_dim) 240 | self.graph_key_w = nn.Linear(sim_dim, sim_dim) 241 | self.sim_graph_w = nn.Linear(sim_dim, sim_dim) 242 | self.relu = nn.ReLU() 243 | 244 | self.init_weights() 245 | 246 | def forward(self, sim_emb): 247 | sim_query = self.graph_query_w(sim_emb) 248 | sim_key = self.graph_key_w(sim_emb) 249 | sim_edge = torch.softmax(torch.bmm(sim_query, sim_key.permute(0, 2, 1)), dim=-1) 250 | sim_sgr = torch.bmm(sim_edge, sim_emb) 251 | sim_sgr = self.relu(self.sim_graph_w(sim_sgr)) 252 | return sim_sgr 253 | 254 | def init_weights(self): 255 | for m in self.children(): 256 | if isinstance(m, nn.Linear): 257 | r = np.sqrt(6.) 
/ np.sqrt(m.in_features + m.out_features) 258 | m.weight.data.uniform_(-r, r) 259 | m.bias.data.fill_(0) 260 | elif isinstance(m, nn.BatchNorm1d): 261 | m.weight.data.fill_(1) 262 | m.bias.data.zero_() 263 | 264 | 265 | class AttentionFiltration(nn.Module): 266 | """ 267 | Perform the similarity Attention Filtration with a gate-based attention 268 | Args: - sim_emb: global and local alignments, shape: (batch_size, L+1, 256) 269 | Returns; - sim_saf: aggregated alignment after attention filtration, shape: (batch_size, 256) 270 | """ 271 | def __init__(self, sim_dim): 272 | super(AttentionFiltration, self).__init__() 273 | 274 | self.attn_sim_w = nn.Linear(sim_dim, 1) 275 | self.bn = nn.BatchNorm1d(1) 276 | 277 | self.init_weights() 278 | 279 | def forward(self, sim_emb): 280 | sim_attn = l1norm(torch.sigmoid(self.bn(self.attn_sim_w(sim_emb).permute(0, 2, 1))), dim=-1) 281 | sim_saf = torch.matmul(sim_attn, sim_emb) 282 | sim_saf = l2norm(sim_saf.squeeze(1), dim=-1) 283 | return sim_saf 284 | 285 | def init_weights(self): 286 | for m in self.children(): 287 | if isinstance(m, nn.Linear): 288 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 289 | m.weight.data.uniform_(-r, r) 290 | m.bias.data.fill_(0) 291 | elif isinstance(m, nn.BatchNorm1d): 292 | m.weight.data.fill_(1) 293 | m.bias.data.zero_() 294 | 295 | 296 | class EncoderSimilarity(nn.Module): 297 | """ 298 | Compute the image-text similarity by SGR, SAF, AVE 299 | Args: - img_emb: local region embeddings, shape: (batch_size, 36, 1024) 300 | - cap_emb: local word embeddings, shape: (batch_size, L, 1024) 301 | Returns: 302 | - sim_all: final image-text similarities, shape: (batch_size, batch_size). 303 | """ 304 | def __init__(self, embed_size, sim_dim, module_name='AVE', sgr_step=3): 305 | super(EncoderSimilarity, self).__init__() 306 | self.module_name = module_name 307 | 308 | self.v_global_w = VisualSA(embed_size, 0.4, 36) 309 | self.t_global_w = TextSA(embed_size, 0.4) 310 | 311 | self.sim_tranloc_w = nn.Linear(embed_size, sim_dim) 312 | self.sim_tranglo_w = nn.Linear(embed_size, sim_dim) 313 | 314 | self.sim_eval_w = nn.Linear(sim_dim, 1) 315 | self.sigmoid = nn.Sigmoid() 316 | 317 | if module_name == 'SGR': 318 | self.SGR_module = nn.ModuleList([GraphReasoning(sim_dim) for i in range(sgr_step)]) 319 | elif module_name == 'SAF': 320 | self.SAF_module = AttentionFiltration(sim_dim) 321 | else: 322 | raise ValueError('Invalid input of opt.module_name in opts.py') 323 | 324 | self.init_weights() 325 | 326 | def forward(self, img_emb, cap_emb, cap_lens): 327 | sim_all = [] 328 | expand_sim_all = [] 329 | n_image = img_emb.size(0) 330 | n_caption = cap_emb.size(0) 331 | 332 | # get enhanced global images by self-attention 333 | img_ave = torch.mean(img_emb, 1) 334 | img_glo = self.v_global_w(img_emb, img_ave) 335 | 336 | for i in range(n_caption): 337 | # get the i-th sentence 338 | n_word = cap_lens[i] 339 | cap_i = cap_emb[i, :n_word, :].unsqueeze(0) 340 | cap_i_expand = cap_i.repeat(n_image, 1, 1) 341 | 342 | # get enhanced global i-th text by self-attention 343 | cap_ave_i = torch.mean(cap_i, 1) 344 | cap_glo_i = self.t_global_w(cap_i, cap_ave_i) 345 | 346 | # local-global alignment construction 347 | Context_img = SCAN_attention(cap_i_expand, img_emb, smooth=9.0) 348 | sim_loc = torch.pow(torch.sub(Context_img, cap_i_expand), 2) 349 | sim_loc = l2norm(self.sim_tranloc_w(sim_loc), dim=-1) 350 | 351 | sim_glo = torch.pow(torch.sub(img_glo, cap_glo_i), 2) 352 | sim_glo = l2norm(self.sim_tranglo_w(sim_glo), dim=-1) 353 | 
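            # At this point sim_loc holds one alignment vector per word (the squared
            # difference between each word embedding and its SCAN-attended image
            # context, projected to sim_dim), while sim_glo is a single global
            # alignment vector between the attended global image and the global
            # caption embedding.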
354 | # concat the global and local alignments 355 | sim_emb = torch.cat([sim_glo.unsqueeze(1), sim_loc], 1) 356 | 357 | # compute the final similarity vector 358 | if self.module_name == 'SGR': 359 | for module in self.SGR_module: 360 | sim_emb = module(sim_emb) 361 | sim_vec = sim_emb[:, 0, :] 362 | else: 363 | sim_vec = self.SAF_module(sim_emb) 364 | 365 | # compute the final similarity score 366 | sim_i = self.sigmoid(self.sim_eval_w(sim_vec)) 367 | sim_all.append(sim_i) 368 | 369 | expand_sim_all.append(self.sim_eval_w(sim_vec)) 370 | 371 | # (n_image, n_caption) 372 | sim_all = torch.cat(sim_all, 1) 373 | expand_sim_all = torch.cat(expand_sim_all, 1) 374 | 375 | return sim_all, expand_sim_all 376 | 377 | def init_weights(self): 378 | for m in self.children(): 379 | if isinstance(m, nn.Linear): 380 | r = np.sqrt(6.) / np.sqrt(m.in_features + m.out_features) 381 | m.weight.data.uniform_(-r, r) 382 | m.bias.data.fill_(0) 383 | elif isinstance(m, nn.BatchNorm1d): 384 | m.weight.data.fill_(1) 385 | m.bias.data.zero_() 386 | 387 | 388 | def SCAN_attention(query, context, smooth, eps=1e-8): 389 | """ 390 | query: (n_context, queryL, d) 391 | context: (n_context, sourceL, d) 392 | """ 393 | # --> (batch, d, queryL) 394 | queryT = torch.transpose(query, 1, 2) 395 | 396 | # (batch, sourceL, d)(batch, d, queryL) 397 | # --> (batch, sourceL, queryL) 398 | attn = torch.bmm(context, queryT) 399 | 400 | attn = nn.LeakyReLU(0.1)(attn) 401 | attn = l2norm(attn, 2) 402 | 403 | # --> (batch, queryL, sourceL) 404 | attn = torch.transpose(attn, 1, 2).contiguous() 405 | # --> (batch, queryL, sourceL 406 | attn = F.softmax(attn*smooth, dim=2) 407 | 408 | # --> (batch, sourceL, queryL) 409 | attnT = torch.transpose(attn, 1, 2).contiguous() 410 | 411 | # --> (batch, d, sourceL) 412 | contextT = torch.transpose(context, 1, 2) 413 | # (batch x d x sourceL)(batch x sourceL x queryL) 414 | # --> (batch, d, queryL) 415 | weightedContext = torch.bmm(contextT, attnT) 416 | # --> (batch, queryL, d) 417 | weightedContext = torch.transpose(weightedContext, 1, 2) 418 | weightedContext = l2norm(weightedContext, dim=-1) 419 | 420 | return weightedContext 421 | 422 | 423 | class ContrastiveLoss(nn.Module): 424 | """ 425 | Compute contrastive loss 426 | """ 427 | def __init__(self, margin=0, max_violation=False): 428 | super(ContrastiveLoss, self).__init__() 429 | self.margin = margin 430 | self.max_violation = max_violation 431 | 432 | def forward(self, scores, max_violation): 433 | self.max_violation = max_violation 434 | # compute image-sentence score matrix 435 | diagonal = scores.diag().view(scores.size(0), 1) 436 | d1 = diagonal.expand_as(scores) 437 | d2 = diagonal.t().expand_as(scores) 438 | 439 | # compare every diagonal score to scores in its column 440 | # caption retrieval 441 | cost_s = (self.margin + scores - d1).clamp(min=0) 442 | # compare every diagonal score to scores in its row 443 | # image retrieval 444 | cost_im = (self.margin + scores - d2).clamp(min=0) 445 | 446 | # clear diagonals 447 | mask = torch.eye(scores.size(0)) > .5 448 | if torch.cuda.is_available(): 449 | I = mask.cuda() 450 | cost_s = cost_s.masked_fill_(I, 0) 451 | cost_im = cost_im.masked_fill_(I, 0) 452 | 453 | # keep the maximum violating negative for each query 454 | if self.max_violation: 455 | cost_s = cost_s.max(1)[0] 456 | cost_im = cost_im.max(0)[0] 457 | return cost_s.sum() + cost_im.sum() 458 | 459 | 460 | class SGRAF(object): 461 | """ 462 | Similarity Reasoning and Filtration (SGRAF) Network 463 | """ 464 | def 
__init__(self, opt): 465 | # Build Models 466 | self.grad_clip = opt.grad_clip 467 | self.img_enc = EncoderImage(opt.img_dim, opt.embed_size, 468 | no_imgnorm=opt.no_imgnorm) 469 | self.txt_enc = EncoderText(opt.vocab_size, opt.word_dim, 470 | opt.embed_size, opt.num_layers, 471 | use_bi_gru=opt.bi_gru, 472 | no_txtnorm=opt.no_txtnorm) 473 | self.sim_enc = EncoderSimilarity(opt.embed_size, opt.sim_dim, 474 | opt.module_name, opt.sgr_step) 475 | 476 | if torch.cuda.is_available(): 477 | self.img_enc.cuda() 478 | self.txt_enc.cuda() 479 | self.sim_enc.cuda() 480 | cudnn.benchmark = True 481 | 482 | # Loss and Optimizer 483 | self.criterion = ContrastiveLoss(margin=opt.margin, 484 | max_violation=opt.max_violation) 485 | params = list(self.txt_enc.parameters()) 486 | params += list(self.img_enc.parameters()) 487 | params += list(self.sim_enc.parameters()) 488 | self.params = params 489 | 490 | self.optimizer = torch.optim.Adam(params, lr=opt.learning_rate) 491 | self.Eiters = 0 492 | 493 | def state_dict(self): 494 | state_dict = [self.img_enc.state_dict(), self.txt_enc.state_dict(), self.sim_enc.state_dict()] 495 | return state_dict 496 | 497 | def load_state_dict(self, state_dict): 498 | self.img_enc.load_state_dict(state_dict[0]) 499 | self.txt_enc.load_state_dict(state_dict[1]) 500 | self.sim_enc.load_state_dict(state_dict[2]) 501 | 502 | def train_start(self): 503 | """switch to train mode""" 504 | self.img_enc.train() 505 | self.txt_enc.train() 506 | self.sim_enc.train() 507 | 508 | def val_start(self): 509 | """switch to evaluate mode""" 510 | self.img_enc.eval() 511 | self.txt_enc.eval() 512 | self.sim_enc.eval() 513 | 514 | def forward_emb(self, images, captions, lengths): 515 | """Compute the image and caption embeddings""" 516 | if torch.cuda.is_available(): 517 | images = images.cuda() 518 | captions = captions.cuda() 519 | 520 | # Forward feature encoding 521 | img_embs = self.img_enc(images) 522 | cap_embs = self.txt_enc(captions, lengths) 523 | return img_embs, cap_embs, lengths 524 | 525 | def forward_sim(self, img_embs, cap_embs, cap_lens): 526 | # Forward similarity encoding 527 | sims, expand_sims = self.sim_enc(img_embs, cap_embs, cap_lens) 528 | 529 | expand_sims, evidences, sims_tanh = torch.sigmoid(expand_sims), torch.exp(torch.tanh(expand_sims) / 0.1), torch.tanh(expand_sims) 530 | return sims, expand_sims, evidences, sims_tanh 531 | 532 | def forward_loss(self, sims, epoch = None, **kwargs): 533 | """Compute the loss given pairs of image and caption embeddings 534 | """ 535 | if epoch >= 1: 536 | loss = self.criterion(sims, max_violation=True) 537 | else: 538 | loss = self.criterion(sims, max_violation=False) 539 | self.logger.update('Loss', loss.item(), sims.size(0)) 540 | return loss 541 | 542 | def train_emb(self, images, captions, lengths, ids=None, epoch=None, *args): 543 | """One training step given images and captions. 
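        Note: forward_sim returns a 4-tuple (sims, expand_sims, evidences,
        sims_tanh); only the first element, the sigmoid similarity matrix of
        shape (n_image, n_caption), is what the contrastive loss in
        forward_loss expects.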
544 | """ 545 | self.Eiters += 1 546 | self.logger.update('Eit', self.Eiters) 547 | self.logger.update('lr', self.optimizer.param_groups[0]['lr']) 548 | 549 | # compute the embeddings 550 | img_embs, cap_embs, cap_lens = self.forward_emb(images, captions, lengths) 551 | sims = self.forward_sim(img_embs, cap_embs, cap_lens) 552 | 553 | # measure accuracy and record loss 554 | self.optimizer.zero_grad() 555 | loss = self.forward_loss(sims,epoch=epoch) 556 | 557 | # compute gradient and do SGD step 558 | loss.backward() 559 | if self.grad_clip > 0: 560 | clip_grad_norm_(self.params, self.grad_clip) 561 | self.optimizer.step() 562 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation""" 2 | 3 | from __future__ import print_function 4 | import os 5 | import sys 6 | from tkinter import Variable 7 | import torch 8 | import numpy as np 9 | from utils import cosine_similarity_matrix 10 | 11 | from model.CRCL import CRCL 12 | from data import get_test_loader 13 | from vocab import deserialize_vocab 14 | from collections import OrderedDict 15 | 16 | 17 | class AverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=0): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / (.0001 + self.count) 34 | 35 | def __str__(self): 36 | """String representation for logging 37 | """ 38 | # for values that should be recorded exactly e.g. iteration number 39 | if self.count == 0: 40 | return str(self.val) 41 | # for stats 42 | return '%.4f (%.4f)' % (self.val, self.avg) 43 | 44 | 45 | class LogCollector(object): 46 | """A collection of logging objects that can change from train to val""" 47 | 48 | def __init__(self): 49 | # to keep the order of logged variables deterministic 50 | self.meters = OrderedDict() 51 | 52 | def update(self, k, v, n=0): 53 | # create a new meter if previously not recorded 54 | if k not in self.meters: 55 | self.meters[k] = AverageMeter() 56 | self.meters[k].update(v, n) 57 | 58 | def __str__(self): 59 | """Concatenate the meters in one log line 60 | """ 61 | s = '' 62 | for i, (k, v) in enumerate(self.meters.items()): 63 | if i > 0: 64 | s += ' ' 65 | s += k + ' ' + str(v) 66 | return s 67 | 68 | def tb_log(self, tb_logger, prefix='', step=None): 69 | """Log using tensorboard 70 | """ 71 | for k, v in self.meters.items(): 72 | tb_logger.log_value(prefix + k, v.val, step=step) 73 | 74 | 75 | def encode_data(model, data_loader, log_step=10, logging=print, sub=0): 76 | """Encode all images and captions loadable by `data_loader` 77 | """ 78 | val_logger = LogCollector() 79 | 80 | # switch to evaluate mode 81 | model.val_start() 82 | 83 | # np array to keep all the embeddings 84 | img_embs = None 85 | cap_embs = None 86 | 87 | max_n_word = 0 88 | for i, (images, _, captions, lengths, ids) in enumerate(data_loader): 89 | lengths = lengths.numpy().astype(np.int64).tolist() 90 | max_n_word = max(max_n_word, max(lengths)) 91 | ids_ = [] 92 | for i, (images, _, captions, lengths, ids) in enumerate(data_loader): 93 | lengths = lengths.numpy().astype(np.int64).tolist() 94 | # make sure val logger is used 95 | model.logger = val_logger 96 | ids_ += ids 97 | # compute the embeddings 98 | with torch.no_grad(): 99 | 
img_emb, cap_emb, cap_len = model.forward_emb(images, captions, lengths) 100 | if img_embs is None: 101 | # img_embs = np.zeros((100 * 100, img_emb.size(1), img_emb.size(2))) 102 | # cap_embs = np.zeros((100 * 100, max_n_word, cap_emb.size(2))) 103 | # cap_lens = [0] * (100 * 100) 104 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1), img_emb.size(2))) 105 | cap_embs = np.zeros((len(data_loader.dataset), max_n_word, cap_emb.size(2))) 106 | cap_lens = [0] * len(data_loader.dataset) 107 | # cache embeddings 108 | img_embs[ids, :, :] = img_emb.data.cpu().numpy().copy() 109 | cap_embs[ids, :max(lengths), :] = cap_emb.data.cpu().numpy().copy() 110 | 111 | for j, nid in enumerate(ids): 112 | cap_lens[nid] = cap_len[j] 113 | 114 | del images, captions 115 | if sub > 0: 116 | print(f"===>batch {i}") 117 | if sub > 0 and i > sub: 118 | break 119 | if sub > 0: 120 | return np.array(img_embs)[ids_].tolist(), np.array(cap_embs)[ids_].tolist(), np.array(cap_lens)[ 121 | ids_].tolist(), ids_ 122 | else: 123 | return img_embs, cap_embs, cap_lens 124 | 125 | 126 | def shard_attn_scores(model, img_embs, cap_embs, cap_lens, opt, shard_size=100): 127 | n_im_shard = (len(img_embs) - 1) // shard_size + 1 128 | n_cap_shard = (len(cap_embs) - 1) // shard_size + 1 129 | 130 | sims = np.zeros((len(img_embs), len(cap_embs))) 131 | for i in range(n_im_shard): 132 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(img_embs)) 133 | for j in range(n_cap_shard): 134 | sys.stdout.write('\r>> shard_attn_scores batch (%d,%d)' % (i, j)) 135 | ca_start, ca_end = shard_size * j, min(shard_size * (j + 1), len(cap_embs)) 136 | 137 | with torch.no_grad(): 138 | im = torch.from_numpy(img_embs[im_start:im_end]).float().cuda() 139 | ca = torch.from_numpy(cap_embs[ca_start:ca_end]).float().cuda() 140 | l = cap_lens[ca_start:ca_end] 141 | sim, _, _, _ = model.forward_sim(im, ca, l) 142 | 143 | sims[im_start:im_end, ca_start:ca_end] = sim.data.cpu().numpy() 144 | sys.stdout.write('\n') 145 | return sims 146 | 147 | 148 | def shard_xattn_t2i(images, captions, caplens, opt, shard_size=128): 149 | """ 150 | Compute pairwise t2i image-caption distance with locality sharding 151 | """ 152 | n_im_shard = (len(images) - 1) // shard_size + 1 153 | n_cap_shard = (len(captions) - 1) // shard_size + 1 154 | 155 | d = np.zeros((len(images), len(captions))) 156 | for i in range(n_im_shard): 157 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(images)) 158 | for j in range(n_cap_shard): 159 | sys.stdout.write('\r>> shard_xattn_t2i batch (%d,%d)' % (i, j)) 160 | cap_start, cap_end = shard_size * j, min(shard_size * (j + 1), len(captions)) 161 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 162 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 163 | l = caplens[cap_start:cap_end] 164 | sim = xattn_score_t2i(im, s, l, opt) 165 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 166 | sys.stdout.write('\n') 167 | return d 168 | 169 | 170 | def shard_xattn_i2t(images, captions, caplens, opt, shard_size=128): 171 | """ 172 | Compute pairwise i2t image-caption distance with locality sharding 173 | """ 174 | n_im_shard = (len(images) - 1) // shard_size + 1 175 | n_cap_shard = (len(captions) - 1) // shard_size + 1 176 | 177 | d = np.zeros((len(images), len(captions))) 178 | for i in range(n_im_shard): 179 | im_start, im_end = shard_size * i, min(shard_size * (i + 1), len(images)) 180 | for j in range(n_cap_shard): 181 | 
sys.stdout.write('\r>> shard_xattn_i2t batch (%d,%d)' % (i, j)) 182 | cap_start, cap_end = shard_size * j, min(shard_size * (j + 1), len(captions)) 183 | im = Variable(torch.from_numpy(images[im_start:im_end]), volatile=True).cuda() 184 | s = Variable(torch.from_numpy(captions[cap_start:cap_end]), volatile=True).cuda() 185 | l = caplens[cap_start:cap_end] 186 | sim = xattn_score_i2t(im, s, l, opt) 187 | d[im_start:im_end, cap_start:cap_end] = sim.data.cpu().numpy() 188 | sys.stdout.write('\n') 189 | return d 190 | 191 | 192 | def t2i(npts, sims, per_captions=1, return_ranks=False): 193 | """ 194 | Text->Images (Image Search) 195 | Images: (N, n_region, d) matrix of images 196 | Captions: (per_captions * N, max_n_word, d) matrix of captions 197 | CapLens: (per_captions * N) array of caption lengths 198 | sims: (N, per_captions * N) matrix of similarity im-cap 199 | """ 200 | ranks = np.zeros(per_captions * npts) 201 | top1 = np.zeros(per_captions * npts) 202 | top5 = np.zeros((per_captions * npts, 5), dtype=int) 203 | 204 | # --> (per_captions * N(caption), N(image)) 205 | sims = sims.T 206 | retrieved_index = [] 207 | for index in range(npts): 208 | for i in range(per_captions): 209 | inds = np.argsort(sims[per_captions * index + i])[::-1] 210 | retrieved_index.append(inds) 211 | ranks[per_captions * index + i] = np.where(inds == index)[0][0] 212 | top1[per_captions * index + i] = inds[0] 213 | top5[per_captions * index + i] = inds[0:5] 214 | 215 | # Compute metrics 216 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 217 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 218 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 219 | medr = np.floor(np.median(ranks)) + 1 220 | meanr = ranks.mean() + 1 221 | if return_ranks: 222 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retrieved_index) 223 | else: 224 | return (r1, r5, r10, medr, meanr) 225 | 226 | 227 | def i2t(npts, sims, per_captions=1, return_ranks=False): 228 | """ 229 | Images->Text (Image Annotation) 230 | Images: (N, n_region, d) matrix of images 231 | Captions: (per_captions * N, max_n_word, d) matrix of captions 232 | CapLens: (per_captions * N) array of caption lengths 233 | sims: (N, per_captions * N) matrix of similarity im-cap 234 | """ 235 | ranks = np.zeros(npts) 236 | top1 = np.zeros(npts) 237 | top5 = np.zeros((npts, 5), dtype=int) 238 | retrieved_index = [] 239 | for index in range(npts): 240 | inds = np.argsort(sims[index])[::-1] 241 | retrieved_index.append(inds) 242 | # Score 243 | rank = 1e20 244 | for i in range(per_captions * index, per_captions * index + per_captions, 1): 245 | tmp = np.where(inds == i)[0][0] 246 | if tmp < rank: 247 | rank = tmp 248 | ranks[index] = rank 249 | top1[index] = inds[0] 250 | top5[index] = inds[0:5] 251 | 252 | # Compute metrics 253 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 254 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 255 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 256 | medr = np.floor(np.median(ranks)) + 1 257 | meanr = ranks.mean() + 1 258 | if return_ranks: 259 | return (r1, r5, r10, medr, meanr), (ranks, top1, top5, retrieved_index) 260 | else: 261 | return (r1, r5, r10, medr, meanr) 262 | 263 | 264 | def encode_data_vse(model, data_loader): 265 | """Encode all images and captions loadable by `data_loader` 266 | """ 267 | model.val_start() 268 | img_embs = None 269 | cap_embs = None 270 | for i, data_i in enumerate(data_loader): 271 | images, image_lengths, captions, caption_lengths, ids = 
data_i 272 | # print(images.size(),captions.size(),caption_lengths) 273 | img_emb, cap_emb = model.forward_emb(images, captions, caption_lengths, image_lengths) 274 | if img_embs is None: 275 | img_embs = np.zeros((len(data_loader.dataset), img_emb.size(1))) 276 | cap_embs = np.zeros((len(data_loader.dataset), cap_emb.size(1))) 277 | # cache embeddings 278 | img_embs[ids, :] = img_emb.data.cpu().numpy().copy() 279 | cap_embs[ids, :] = cap_emb.data.cpu().numpy().copy() 280 | del images, captions 281 | 282 | return img_embs, cap_embs 283 | 284 | 285 | def i2t_vse(npts, sims, return_ranks=False, mode='coco', per=5): 286 | """ 287 | Images->Text (Image Annotation) 288 | Images: (N, n_region, d) matrix of images 289 | Captions: (5N, max_n_word, d) matrix of captions 290 | CapLens: (5N) array of caption lengths 291 | sims: (N, 5N) matrix of similarity im-cap 292 | """ 293 | ranks = np.zeros(npts) 294 | top1 = np.zeros(npts) 295 | for index in range(npts): 296 | inds = np.argsort(sims[index])[::-1] 297 | if mode == 'coco': 298 | rank = 1e20 299 | for i in range(per * index, per * index + per, 1): 300 | tmp = np.where(inds == i)[0][0] 301 | if tmp < rank: 302 | rank = tmp 303 | ranks[index] = rank 304 | top1[index] = inds[0] 305 | else: 306 | rank = np.where(inds == index)[0][0] 307 | ranks[index] = rank 308 | top1[index] = inds[0] 309 | 310 | # Compute metrics 311 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 312 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 313 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 314 | medr = np.floor(np.median(ranks)) + 1 315 | meanr = ranks.mean() + 1 316 | 317 | if return_ranks: 318 | return (r1, r5, r10, medr, meanr), (ranks, top1) 319 | else: 320 | return (r1, r5, r10, medr, meanr) 321 | 322 | 323 | def t2i_vse(npts, sims, return_ranks=False, mode='coco', per=5): 324 | """ 325 | Text->Images (Image Search) 326 | Images: (N, n_region, d) matrix of images 327 | Captions: (5N, max_n_word, d) matrix of captions 328 | CapLens: (5N) array of caption lengths 329 | sims: (N, 5N) matrix of similarity im-cap 330 | """ 331 | # npts = images.shape[0] 332 | 333 | if mode == 'coco': 334 | ranks = np.zeros(per * npts) 335 | top1 = np.zeros(per * npts) 336 | else: 337 | ranks = np.zeros(npts) 338 | top1 = np.zeros(npts) 339 | 340 | # --> (5N(caption), N(image)) 341 | sims = sims.T 342 | 343 | for index in range(npts): 344 | if mode == 'coco': 345 | for i in range(per): 346 | inds = np.argsort(sims[per * index + i])[::-1] 347 | ranks[per * index + i] = np.where(inds == index)[0][0] 348 | top1[per * index + i] = inds[0] 349 | else: 350 | inds = np.argsort(sims[index])[::-1] 351 | ranks[index] = np.where(inds == index)[0][0] 352 | top1[index] = inds[0] 353 | 354 | # Compute metrics 355 | r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 356 | r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 357 | r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 358 | medr = np.floor(np.median(ranks)) + 1 359 | meanr = ranks.mean() + 1 360 | if return_ranks: 361 | return (r1, r5, r10, medr, meanr), (ranks, top1) 362 | else: 363 | return (r1, r5, r10, medr, meanr) 364 | 365 | 366 | def validation_SGR_or_SAF(opt, val_loader, model, fold=False): 367 | # compute the encoding for all the validation images and captions 368 | if opt.data_name == 'cc152k_precomp': 369 | per_captions = 1 370 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 371 | per_captions = 5 372 | else: 373 | print(f"No dataset") 374 | return 0 375 | 376 | model.val_start() 377 | 
print('Encoding with model') 378 | img_embs, cap_embs, cap_lens = encode_data(model.base_model, val_loader) 379 | # clear duplicate 5*images and keep 1*images FIXME 380 | if not fold: 381 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 382 | # record computation time of validation 383 | print('Computing similarity from model') 384 | sims_mean = shard_attn_scores(model.base_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000) 385 | print("Calculate similarity time with model") 386 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 387 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 388 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 389 | # image retrieval 390 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 391 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 392 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 393 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 394 | print("Sum of Recall: %.2f" % (r_sum)) 395 | else: 396 | # 5fold cross-validation, only for MSCOCO 397 | results = [] 398 | for i in range(5): 399 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 400 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 401 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 402 | sims = shard_attn_scores(model.base_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, 403 | shard_size=1000, 404 | ) 405 | 406 | print('Computing similarity from model') 407 | r, rt = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 408 | ri, rti = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 409 | 410 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 411 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 412 | ar = (r[0] + r[1] + r[2]) / 3 413 | ari = (ri[0] + ri[1] + ri[2]) / 3 414 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 415 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 416 | results += [list(r) + list(ri) + [ar, ari, rsum]] 417 | print("-----------------------------------") 418 | print("Mean metrics: ") 419 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 420 | a = np.array(mean_metrics) 421 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 422 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 423 | mean_metrics[:5]) 424 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 425 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 426 | mean_metrics[5:10]) 427 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 428 | 429 | 430 | def validation_SGRAF(opt, val_loader, models, fold=False): 431 | # compute the encoding for all the validation images and captions 432 | if opt.data_name == 'cc152k_precomp': 433 | per_captions = 1 434 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 435 | per_captions = 5 436 | else: 437 | print(f"No dataset") 438 | return 0 439 | 440 | models[0].val_start() 441 | models[1].val_start() 442 | print('Encoding with model') 443 | img_embs, cap_embs, cap_lens = encode_data(models[0].base_model, val_loader, opt.log_step) 444 | img_embs1, cap_embs1, cap_lens1 = encode_data(models[1].base_model, val_loader, opt.log_step) 445 | if not fold: 446 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 447 | img_embs1 = np.array([img_embs1[i] for i in range(0, len(img_embs1), per_captions)]) 448 | # 
record computation time of validation 449 | print('Computing similarity from model') 450 | sims_mean = shard_attn_scores(models[0].base_model, img_embs, cap_embs, cap_lens, opt, shard_size=1000, 451 | ) 452 | sims_mean += shard_attn_scores(models[1].base_model, img_embs1, cap_embs1, cap_lens1, opt, 453 | shard_size=1000, ) 454 | sims_mean /= 2 455 | 456 | # np.save("./sims/f30k/sim.npy", sims_mean) 457 | 458 | print("Calculate similarity time with model") 459 | # caption retrieval 460 | (r1, r5, r10, medr, meanr) = i2t(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 461 | print("Average i2t Recall: %.2f" % ((r1 + r5 + r10) / 3)) 462 | print("Image to text: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1, r5, r10, medr, meanr)) 463 | # image retrieval 464 | (r1i, r5i, r10i, medri, meanr) = t2i(img_embs.shape[0], sims_mean, per_captions, return_ranks=False) 465 | print("Average t2i Recall: %.2f" % ((r1i + r5i + r10i) / 3)) 466 | print("Text to image: %.2f, %.2f, %.2f, %.2f, %.2f" % (r1i, r5i, r10i, medri, meanr)) 467 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 468 | print("Sum of Recall: %.2f" % (r_sum)) 469 | return r_sum 470 | else: 471 | # 5fold cross-validation, only for MSCOCO 472 | results = [] 473 | for i in range(5): 474 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 475 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 476 | cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000] 477 | 478 | img_embs_shard1 = img_embs1[i * 5000:(i + 1) * 5000:5] 479 | cap_embs_shard1 = cap_embs1[i * 5000:(i + 1) * 5000] 480 | cap_lens_shard1 = cap_lens1[i * 5000:(i + 1) * 5000] 481 | sims = shard_attn_scores(models[0].base_model, img_embs_shard, cap_embs_shard, cap_lens_shard, opt, 482 | shard_size=1000, 483 | ) 484 | sims += shard_attn_scores(models[1].base_model, img_embs_shard1, cap_embs_shard1, cap_lens_shard1, 485 | opt, 486 | shard_size=1000, 487 | ) 488 | sims /= 2 489 | 490 | print('Computing similarity from model') 491 | r, rt0 = i2t(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 492 | ri, rti0 = t2i(img_embs_shard.shape[0], sims, per_captions, return_ranks=True) 493 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r) 494 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri) 495 | if i == 0: 496 | rt, rti = rt0, rti0 497 | ar = (r[0] + r[1] + r[2]) / 3 498 | ari = (ri[0] + ri[1] + ri[2]) / 3 499 | rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] 500 | print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) 501 | results += [list(r) + list(ri) + [ar, ari, rsum]] 502 | 503 | print("-----------------------------------") 504 | print("Mean metrics: ") 505 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 506 | a = np.array(mean_metrics) 507 | 508 | print("Average i2t Recall: %.1f" % mean_metrics[11]) 509 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 510 | mean_metrics[:5]) 511 | print("Average t2i Recall: %.1f" % mean_metrics[12]) 512 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 513 | mean_metrics[5:10]) 514 | print("rsum: %.1f" % (a[0:3].sum() + a[5:8].sum())) 515 | 516 | 517 | def validation_VSEinfty(opt, val_loader, model, fold=False): 518 | # compute the encoding for all the validation images and captions 519 | if opt.data_name in ['cc152k_precomp', 'cc510k_precomp']: 520 | per_captions = 1 521 | elif opt.data_name in ['coco_precomp', 'f30k_precomp']: 522 | per_captions = 5 523 | 524 | model.val_start() 525 | if not fold: 526 | with torch.no_grad(): 527 | img_embs, cap_embs = encode_data_vse(model.base_model, val_loader) 
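            # the loader yields each image once per caption; keep a single copy of every image embedding (stride per_captions) before building the image-caption similarity matrix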
528 | img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), per_captions)]) 529 | sims = cosine_similarity_matrix(img_embs, cap_embs) 530 | npts = img_embs.shape[0] 531 | # np.save('./sims/f30K_0.4_RSRL.npy',sims) 532 | (r1, r5, r10, medr, meanr) = i2t_vse(npts, sims, per=per_captions) 533 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % 534 | (r1, r5, r10, medr, meanr)) 535 | (r1i, r5i, r10i, medri, meanr) = t2i_vse(npts, sims, per=per_captions) 536 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % 537 | (r1i, r5i, r10i, medri, meanr)) 538 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 539 | print('Current rsum is {}'.format(r_sum)) 540 | return r1, r5, r10, r1i, r5i, r10i 541 | else: 542 | with torch.no_grad(): 543 | results = [] 544 | img_embs, cap_embs = encode_data_vse(model.base_model, val_loader) 545 | for i in range(5): 546 | print("fold: {}".format(i + 1)) 547 | img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5] 548 | cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000] 549 | sims = cosine_similarity_matrix(img_embs_shard, cap_embs_shard) 550 | npts = img_embs_shard.shape[0] 551 | (r1, r5, r10, medr, meanr) = i2t_vse(npts, sims, per=per_captions) 552 | print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % 553 | (r1, r5, r10, medr, meanr)) 554 | (r1i, r5i, r10i, medri, meanr) = t2i_vse(npts, sims, per=per_captions) 555 | print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % 556 | (r1i, r5i, r10i, medri, meanr)) 557 | r_sum = r1 + r5 + r10 + r1i + r5i + r10i 558 | print('Current rsum is {}'.format(r_sum)) 559 | results.append([r1, r5, r10, medr, meanr, r1i, r5i, r10i, medri, meanr, r_sum]) 560 | 561 | print("-----------------------------------") 562 | print("Mean metrics: ") 563 | mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) 564 | print("rsum: %.1f" % (mean_metrics[10])) 565 | print("Image to text: %.1f %.1f %.1f %.1f %.1f" % 566 | mean_metrics[:5]) 567 | print("Text to image: %.1f %.1f %.1f %.1f %.1f" % 568 | mean_metrics[5:10]) 569 | --------------------------------------------------------------------------------
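For reference, the ranking metrics printed by the validation routines above (R@1/R@5/R@10, median rank, mean rank) can be reproduced from any similarity matrix with a few lines of NumPy. The sketch below is a minimal, self-contained illustration and is not part of the repository: the names toy_sims and recall_at_k are invented for this example, and it assumes one ground-truth caption per image (per_captions = 1, as in cc152k_precomp), mirroring the ranking logic of i2t above.

import numpy as np

def recall_at_k(ranks, k):
    # percentage of queries whose ground-truth item is ranked within the top k
    return 100.0 * np.mean(ranks < k)

rng = np.random.default_rng(0)
n_images = 8                                                # toy gallery size
toy_sims = rng.normal(size=(n_images, n_images))            # similarity of image i to caption j
toy_sims[np.arange(n_images), np.arange(n_images)] += 2.0   # make matching pairs score higher

# image-to-text: rank all captions for each image by descending similarity and
# record the position of the ground-truth caption (0 = retrieved first)
ranks = np.zeros(n_images)
for index in range(n_images):
    inds = np.argsort(toy_sims[index])[::-1]
    ranks[index] = np.where(inds == index)[0][0]

print("R@1 %.1f, R@5 %.1f, R@10 %.1f, medr %.1f, meanr %.1f" % (
    recall_at_k(ranks, 1), recall_at_k(ranks, 5), recall_at_k(ranks, 10),
    np.floor(np.median(ranks)) + 1, ranks.mean() + 1))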