├── LICENSE ├── README.md ├── data ├── __init__.py ├── dataset.py ├── example.py ├── field.py ├── utils.py └── vocab.py ├── environment.yml ├── evaluation ├── __init__.py ├── bleu │ ├── __init__.py │ ├── bleu.py │ └── bleu_scorer.py ├── cider │ ├── __init__.py │ ├── cider.py │ └── cider_scorer.py ├── meteor │ ├── __init__.py │ └── meteor.py ├── rouge │ ├── __init__.py │ └── rouge.py ├── stanford-corenlp-3.4.1.jar └── tokenizer.py ├── feats_process.py ├── images ├── RSTNet.png ├── results.png ├── train_cider.png └── visualness.png ├── models ├── __init__.py ├── beam_search │ ├── __init__.py │ └── beam_search.py ├── captioning_model.py ├── containers.py ├── m2_transformer │ ├── __init__.py │ ├── attention.py │ ├── decoders.py │ ├── encoders.py │ ├── transformer.py │ └── utils.py ├── rstnet │ ├── __init__.py │ ├── attention.py │ ├── decoders.py │ ├── encoders.py │ ├── grid_aug.py │ ├── language_model.py │ ├── transformer.py │ └── utils.py └── transformer │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── attention.cpython-36.pyc │ ├── decoders.cpython-36.pyc │ ├── encoders.cpython-36.pyc │ ├── transformer.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── attention.py │ ├── decoders.py │ ├── encoders.py │ ├── transformer.py │ └── utils.py ├── pretrained_models └── README.md ├── switch_datatype.py ├── tensorboard_logs ├── rstnet │ ├── events.out.tfevents.1603849421.MAC-U2S5.25151.0 │ ├── events.out.tfevents.1603849484.MAC-U2S5.28641.0 │ ├── events.out.tfevents.1603849514.MAC-U2S5.30196.0 │ ├── events.out.tfevents.1603850322.MAC-U2S5.11650.0 │ └── events.out.tfevents.1605359164.MAC-U2S5.4519.0 └── transformer │ └── events.out.tfevents.1594888876.socialmedia.34273.0 ├── test_offline.py ├── test_online.py ├── train_language.py ├── train_transformer.py ├── utils ├── __init__.py ├── typing.py └── utils.py └── vocab.pkl /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, 张旭迎 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RSTNet: Relationship-Sensitive Transformer Network 2 | This repository contains the reference code for our paper [_RSTNet: Captioning with Adaptive Attention on Visual and Non-Visual Words_ (CVPR 2021)](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhang_RSTNet_Captioning_With_Adaptive_Attention_on_Visual_and_Non-Visual_Words_CVPR_2021_paper.pdf). 3 | 4 |

5 | ![Relationship-Sensitive Transformer](images/RSTNet.png) 6 |
7 | 8 | ## Tips 9 | If you have any questions about our work, feel free to post issues on this GitHub project. I will answer your questions and update the code monthly. 10 | If you are in a hurry, please email me at [zhangxuying1004@gmail.com](zhangxuying1004@gmail.com). 11 | If our work is helpful to you or gives you some inspiration, please star this project and cite our paper. Thank you! 12 | ``` 13 | @inproceedings{zhang2021rstnet, 14 | title={RSTNet: Captioning with adaptive attention on visual and non-visual words}, 15 | author={Zhang, Xuying and Sun, Xiaoshuai and Luo, Yunpeng and Ji, Jiayi and Zhou, Yiyi and Wu, Yongjian and Huang, Feiyue and Ji, Rongrong}, 16 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 17 | pages={15465--15474}, 18 | year={2021} 19 | } 20 | ``` 21 | 22 | ## Environment setup 23 | Clone the repository and create the `m2release` conda environment using the `environment.yml` file: 24 | ``` 25 | conda env create -f environment.yml 26 | conda activate m2release 27 | ``` 28 | 29 | Then, download the spaCy English data by executing one of the following commands: 30 | ```python -m spacy download en``` or ```python -m spacy download en_core_web_sm```. 31 | 32 | You also need to create 5 new folders, namely ```Datasets```, ```save_language_models```, ```language_tensorboard_logs```, ```save_transformer_models``` and ```transformer_tensorboard_logs```, in the root directory of this project. 33 | 34 | ## Data preparation 35 | To run our code, you need to put the annotations folder ```m2_annotations``` and the visual features folder ```X101-features``` for the COCO dataset into ```Datasets```. 36 | 37 | Most annotations have been prepared by [1]. Please download [m2_annotations](https://drive.google.com/drive/folders/1tJnetunBkQ4Y5A3pq2P53yeJuGa4lX9e?usp=sharing) and put it into this root directory. 38 | 39 | Visual features are computed with the code provided by [2]. To reproduce our results, please download the COCO features file, e.g. ```X-101-features.tgz```, from [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa) 40 | and rename the extracted folder to ```X101-features```. Since this feature file is huge, you can alternatively save the features as float16 to reduce storage space by executing the following command: 41 | ``` 42 | python switch_datatype.py 43 | ``` 44 | To resolve the shape difference and match the grid feature shape with the region feature shape (`50` regions), please execute the following command, which reshapes the visual features to `49 (7x7)` and saves all of them into a single h5py file. 45 | ``` 46 | python feats_process.py 47 | ``` 48 | Note that, for convenience, you can also download my processed offline image features [coco_grid_feats](https://pan.baidu.com/s/1myelTYJE8a1HDZHkoccfIA) from Baidu Netdisk with the extraction code ```cvpr```. 49 | 50 | 51 | 52 | In addition, if you want to extract grid-based features for your custom image dataset, you can refer to the code in the [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa) project. 53 |
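For reference, the sketch below illustrates the kind of processing these two scripts perform: reshape each image's grid feature map into a `49 x channels` matrix, optionally cast it to float16, and store everything in a single h5py file under the `'%d_grids'` keys that the dataset classes in `data/` expect. The per-image `.npy` input layout and the paths are illustrative assumptions, not the exact logic of `switch_datatype.py` / `feats_process.py`:
```
# A rough sketch, not the actual feats_process.py: it assumes the raw grid features
# have already been dumped as one <image_id>.npy array per image, which is a
# hypothetical layout -- adapt the loading part to however you extracted the features.
import glob
import os

import h5py
import numpy as np

src_dir = 'Datasets/X101-features'          # assumed input folder
dst_path = 'Datasets/X101-grid-feats.hdf5'  # assumed output file

with h5py.File(dst_path, 'w') as f:
    for npy_file in glob.glob(os.path.join(src_dir, '*.npy')):
        image_id = int(os.path.splitext(os.path.basename(npy_file))[0])
        feat = np.load(npy_file)                  # e.g. a (2048, 7, 7) feature map
        feat = feat.reshape(feat.shape[0], -1).T  # -> (49, 2048): one row per grid cell
        # the data loading code reads features under the '%d_grids' key
        f.create_dataset('%d_grids' % image_id, data=feat.astype(np.float16))
```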
54 | 55 | ## Training procedure 56 | Run `python train_language.py` and `python train_transformer.py` in sequence using the following arguments: 57 | 58 | | Argument | Possible values | 59 | |------|------| 60 | | `--exp_name` | Experiment name | 61 | | `--batch_size` | Batch size (default: 50) | 62 | | `--workers` | Number of data-loading workers; a larger value speeds up training in the XE stage. | 63 | | `--head` | Number of heads (default: 8) | 64 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. | 65 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. | 66 | | `--features_path` | Path to the visual features file (h5py) | 67 | | `--annotation_folder` | Path to m2_annotations | 68 | 69 | For example, to train our BERT-based language model with the parameters used in our experiments, use 70 | ``` 71 | python train_language.py --exp_name bert_language --batch_size 50 --features_path /path/to/features --annotation_folder /path/to/annotations 72 | ``` 73 | To train our RSTNet model with the parameters used in our experiments, use 74 | ``` 75 | python train_transformer.py --exp_name rstnet --batch_size 50 --m 40 --head 8 --features_path /path/to/features --annotation_folder /path/to/annotations 76 | ``` 77 | The figure below shows how the CIDEr score changes during the training of RSTNet. You can also inspect the training details by loading the tensorboard files in ```tensorboard_logs```. 78 |

79 | ![cider changes](images/train_cider.png) 80 |
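Under the hood, both training scripts build their data pipeline from the classes in the `data/` package. The snippet below is a hedged sketch of that wiring; the paths, special tokens and hyper-parameters are illustrative, so check `train_language.py` and `train_transformer.py` for the exact arguments:
```
# Illustrative wiring of the data pipeline; paths and token names are assumptions.
from data import ImageDetectionsField, TextField, COCO, DataLoader

# grid features are read from the h5py file produced by feats_process.py
image_field = ImageDetectionsField(detections_path='Datasets/X101-grid-feats.hdf5',
                                   max_detections=49, load_in_tmp=False)

# captions are tokenized with spaCy, lower-cased and stripped of punctuation
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True,
                       tokenize='spacy', remove_punctuation=True, nopoints=False)

# COCO pairs every image with its captions and exposes the Karpathy splits
dataset = COCO(image_field, text_field, 'Datasets/coco_images',
               'Datasets/m2_annotations', 'Datasets/m2_annotations')
train_dataset, val_dataset, test_dataset = dataset.splits

# the vocabulary is built from the caption field of the training split
text_field.build_vocab(train_dataset, min_freq=5)

# the DataLoader wrapper in data/__init__.py plugs in each dataset's collate_fn
train_loader = DataLoader(train_dataset, batch_size=50, shuffle=True, num_workers=4)
```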
81 | 82 | ## Evaluation 83 | ### Offline Evaluation 84 | Run `python test_offline.py` to evaluate the performance of RSTNet on the Karpathy test split of the MS COCO dataset. 85 | 86 | ### Online Evaluation 87 | Run `python test_online.py` to generate the required files and evaluate the performance of RSTNet on the official MS COCO test server. 88 | 89 | Note that, to reproduce our reported results, you can also download our pretrained model files in the ```pretrained_models``` folder and put them into the ```saved_language_models``` and ```saved_transformer_models``` folders, respectively. The results of offline evaluation (Karpathy test split of MS COCO) are as follows: 90 |

91 | ![offline evaluation](images/results.png) 92 |
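The metric values above are computed with the wrappers under `evaluation/`: the `compute_scores` helper in `evaluation/__init__.py` takes two dicts mapping each image id to a list of reference captions and to a single-element list with the generated caption. A toy usage sketch (the captions are invented, and passing both dicts through `PTBTokenizer` first mirrors what the test scripts are expected to do):
```
# Toy example of the scoring interface; the captions below are made up.
from evaluation import PTBTokenizer, compute_scores

gts = {0: ['a man is riding a horse on the beach', 'a person rides a horse near the ocean']}
gen = {0: ['a man riding a horse on the beach']}

# normalise ground-truth and generated captions the same way before scoring
gts = PTBTokenizer.tokenize(gts)
gen = PTBTokenizer.tokenize(gen)

# returns overall scores and per-image scores for BLEU, METEOR, ROUGE and CIDEr
scores, per_image_scores = compute_scores(gts, gen)
print(scores)
```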
93 | 94 | #### References 95 | [1] Cornia, M., Stefanini, M., Baraldi, L., & Cucchiara, R. (2020). Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 96 | [2] Jiang, H., Misra, I., Rohrbach, M., Learned-Miller, E., & Chen, X. (2020). In defense of grid features for visual question answering. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 97 | 98 | 99 | #### Acknowledgements 100 | Thank Cornia _et.al_ for their open source code [meshed-memory-transformer 101 | ](https://github.com/aimagelab/meshed-memory-transformer), on which our implements are based. 102 | Thank Jiang _et.al_ for the significant discovery in visual representation [2], which has given us a lot of inspiration. 103 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import RawField, Merge, ImageDetectionsField, TextField 2 | from .dataset import COCO, COCO_TestOnline 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import collections 5 | import torch 6 | from .example import Example 7 | from .utils import nostdout 8 | from pycocotools.coco import COCO as pyCOCO 9 | 10 | 11 | class Dataset(object): 12 | def __init__(self, examples, fields): 13 | self.examples = examples 14 | self.fields = dict(fields) 15 | 16 | def collate_fn(self): 17 | def collate(batch): 18 | if len(self.fields) == 1: 19 | batch = [batch, ] 20 | else: 21 | batch = list(zip(*batch)) 22 | 23 | tensors = [] 24 | for field, data in zip(self.fields.values(), batch): 25 | tensor = field.process(data) 26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): 27 | tensors.extend(tensor) 28 | else: 29 | tensors.append(tensor) 30 | 31 | if len(tensors) > 1: 32 | return tensors 33 | else: 34 | return tensors[0] 35 | 36 | return collate 37 | 38 | def __getitem__(self, i): 39 | example = self.examples[i] 40 | data = [] 41 | for field_name, field in self.fields.items(): 42 | data.append(field.preprocess(getattr(example, field_name))) 43 | 44 | if len(data) == 1: 45 | data = data[0] 46 | return data 47 | 48 | def __len__(self): 49 | return len(self.examples) 50 | 51 | def __getattr__(self, attr): 52 | if attr in self.fields: 53 | for x in self.examples: 54 | yield getattr(x, attr) 55 | 56 | 57 | class ValueDataset(Dataset): 58 | def __init__(self, examples, fields, dictionary): 59 | self.dictionary = dictionary 60 | super(ValueDataset, self).__init__(examples, fields) 61 | 62 | def collate_fn(self): 63 | def collate(batch): 64 | value_batch_flattened = list(itertools.chain(*batch)) 65 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) 66 | 67 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) 68 | if isinstance(value_tensors_flattened, collections.Sequence) \ 69 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened): 70 | value_tensors = [[vt[s:e] 
for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened] 71 | else: 72 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] 73 | 74 | return value_tensors 75 | return collate 76 | 77 | def __getitem__(self, i): 78 | if i not in self.dictionary: 79 | raise IndexError 80 | 81 | values_data = [] 82 | for idx in self.dictionary[i]: 83 | value_data = super(ValueDataset, self).__getitem__(idx) 84 | values_data.append(value_data) 85 | return values_data 86 | 87 | def __len__(self): 88 | return len(self.dictionary) 89 | 90 | 91 | class DictionaryDataset(Dataset): 92 | def __init__(self, examples, fields, key_fields): 93 | if not isinstance(key_fields, (tuple, list)): 94 | key_fields = (key_fields,) 95 | for field in key_fields: 96 | assert (field in fields) 97 | 98 | dictionary = collections.defaultdict(list) 99 | key_fields = {k: fields[k] for k in key_fields} 100 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} 101 | key_examples = [] 102 | key_dict = dict() 103 | value_examples = [] 104 | 105 | for i, e in enumerate(examples): 106 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) 107 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) 108 | if key_example not in key_dict: 109 | key_dict[key_example] = len(key_examples) 110 | key_examples.append(key_example) 111 | 112 | value_examples.append(value_example) 113 | dictionary[key_dict[key_example]].append(i) 114 | 115 | self.key_dataset = Dataset(key_examples, key_fields) 116 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) 117 | super(DictionaryDataset, self).__init__(examples, fields) 118 | 119 | def collate_fn(self): 120 | def collate(batch): 121 | key_batch, value_batch = list(zip(*batch)) 122 | key_tensors = self.key_dataset.collate_fn()(key_batch) 123 | value_tensors = self.value_dataset.collate_fn()(value_batch) 124 | return key_tensors, value_tensors 125 | return collate 126 | 127 | def __getitem__(self, i): 128 | return self.key_dataset[i], self.value_dataset[i] 129 | 130 | def __len__(self): 131 | return len(self.key_dataset) 132 | 133 | 134 | def unique(sequence): 135 | seen = set() 136 | if isinstance(sequence[0], list): 137 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] 138 | else: 139 | return [x for x in sequence if not (x in seen or seen.add(x))] 140 | 141 | 142 | class PairedDataset(Dataset): 143 | def __init__(self, examples, fields): 144 | assert ('image' in fields) 145 | assert ('text' in fields) 146 | super(PairedDataset, self).__init__(examples, fields) 147 | self.image_field = self.fields['image'] 148 | self.text_field = self.fields['text'] 149 | 150 | def image_set(self): 151 | img_list = [e.image for e in self.examples] 152 | image_set = unique(img_list) 153 | examples = [Example.fromdict({'image': i}) for i in image_set] 154 | dataset = Dataset(examples, {'image': self.image_field}) 155 | return dataset 156 | 157 | def text_set(self): 158 | text_list = [e.text for e in self.examples] 159 | text_list = unique(text_list) 160 | examples = [Example.fromdict({'text': t}) for t in text_list] 161 | dataset = Dataset(examples, {'text': self.text_field}) 162 | return dataset 163 | 164 | def image_dictionary(self, fields=None): 165 | if not fields: 166 | fields = self.fields 167 | dataset = DictionaryDataset(self.examples, fields, key_fields='image') 168 | return dataset 169 | 170 | def text_dictionary(self, fields=None): 171 | if 
not fields: 172 | fields = self.fields 173 | dataset = DictionaryDataset(self.examples, fields, key_fields='text') 174 | return dataset 175 | 176 | @property 177 | def splits(self): 178 | raise NotImplementedError 179 | 180 | 181 | class COCO(PairedDataset): 182 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 183 | cut_validation=False): 184 | roots = {} 185 | roots['train'] = { 186 | 'img': os.path.join(img_root, 'train2014'), 187 | 'cap': os.path.join(ann_root, 'captions_train2014.json') 188 | } 189 | roots['val'] = { 190 | 'img': os.path.join(img_root, 'val2014'), 191 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 192 | } 193 | roots['test'] = { 194 | 'img': os.path.join(img_root, 'val2014'), 195 | 'cap': os.path.join(ann_root, 'captions_val2014.json') 196 | } 197 | roots['trainrestval'] = { 198 | 'img': (roots['train']['img'], roots['val']['img']), 199 | 'cap': (roots['train']['cap'], roots['val']['cap']) 200 | } 201 | 202 | if id_root is not None: 203 | ids = {} 204 | ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy')) 205 | ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy')) 206 | if cut_validation: 207 | ids['val'] = ids['val'][:5000] 208 | ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy')) 209 | ids['trainrestval'] = ( 210 | ids['train'], 211 | np.load(os.path.join(id_root, 'coco_restval_ids.npy'))) 212 | 213 | if use_restval: 214 | roots['train'] = roots['trainrestval'] 215 | ids['train'] = ids['trainrestval'] 216 | else: 217 | ids = None 218 | 219 | with nostdout(): 220 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids) 221 | examples = self.train_examples + self.val_examples + self.test_examples 222 | super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}) 223 | 224 | @property 225 | def splits(self): 226 | train_split = PairedDataset(self.train_examples, self.fields) 227 | val_split = PairedDataset(self.val_examples, self.fields) 228 | test_split = PairedDataset(self.test_examples, self.fields) 229 | return train_split, val_split, test_split 230 | 231 | @classmethod 232 | def get_samples(cls, roots, ids_dataset=None): 233 | train_samples = [] 234 | val_samples = [] 235 | test_samples = [] 236 | 237 | for split in ['train', 'val', 'test']: 238 | if isinstance(roots[split]['cap'], tuple): 239 | coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1])) 240 | root = roots[split]['img'] 241 | else: 242 | coco_dataset = (pyCOCO(roots[split]['cap']),) 243 | root = (roots[split]['img'],) 244 | 245 | if ids_dataset is None: 246 | ids = list(coco_dataset.anns.keys()) 247 | else: 248 | ids = ids_dataset[split] 249 | 250 | if isinstance(ids, tuple): 251 | bp = len(ids[0]) 252 | ids = list(ids[0]) + list(ids[1]) 253 | else: 254 | bp = len(ids) 255 | 256 | for index in range(len(ids)): 257 | if index < bp: 258 | coco = coco_dataset[0] 259 | img_root = root[0] 260 | else: 261 | coco = coco_dataset[1] 262 | img_root = root[1] 263 | 264 | ann_id = ids[index] 265 | caption = coco.anns[ann_id]['caption'] 266 | img_id = coco.anns[ann_id]['image_id'] 267 | filename = coco.loadImgs(img_id)[0]['file_name'] 268 | 269 | example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption}) 270 | 271 | if split == 'train': 272 | train_samples.append(example) 273 | elif split == 'val': 274 | val_samples.append(example) 275 | elif split == 'test': 276 | test_samples.append(example) 277 | 278 | return 
train_samples, val_samples, test_samples 279 | 280 | import json, h5py  # needed by COCO_TestOnline below; not imported at the top of this file 281 | class COCO_TestOnline(Dataset): 282 | def __init__(self, feat_path, ann_file, max_detections=49): 283 | """ 284 | feat_path: path to the features of the official COCO train/val splits 285 | ann_file: annotation file of the train or val split, used to obtain each image_id and then look up the corresponding features 286 | """ 287 | super(COCO_TestOnline, self).__init__() 288 | 289 | # load the image info 290 | with open(ann_file, 'r') as f: 291 | self.images_info = json.load(f)['images'] 292 | 293 | # open the feature file 294 | self.f = h5py.File(feat_path, 'r') 295 | 296 | # record the maximum number of features kept per image 297 | self.max_detections = max_detections 298 | 299 | def __len__(self): 300 | return len(self.images_info) 301 | 302 | def __getitem__(self, idx): 303 | image_id = self.images_info[idx]['id'] 304 | precomp_data = self.f['%d_grids' % image_id][()] 305 | 306 | delta = self.max_detections - precomp_data.shape[0] 307 | if delta > 0: 308 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0) 309 | elif delta < 0: 310 | precomp_data = precomp_data[:self.max_detections] 311 | 312 | return int(image_id), precomp_data 313 | 314 | -------------------------------------------------------------------------------- /data/example.py: -------------------------------------------------------------------------------- 1 | 2 | class Example(object): 3 | """Defines a single training or test example. 4 | Stores each column of the example as an attribute. 5 | """ 6 | @classmethod 7 | def fromdict(cls, data): 8 | ex = cls(data) 9 | return ex 10 | 11 | def __init__(self, data): 12 | for key, val in data.items(): 13 | super(Example, self).__setattr__(key, val) 14 | 15 | def __setattr__(self, key, value): 16 | raise AttributeError 17 | 18 | def __hash__(self): 19 | return hash(tuple(x for x in self.__dict__.values())) 20 | 21 | def __eq__(self, other): 22 | this = tuple(x for x in self.__dict__.values()) 23 | other = tuple(x for x in other.__dict__.values()) 24 | return this == other 25 | 26 | def __ne__(self, other): 27 | return not self.__eq__(other) 28 | -------------------------------------------------------------------------------- /data/field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | from torch.utils.data.dataloader import default_collate 4 | from itertools import chain 5 | import six 6 | import torch 7 | import numpy as np 8 | import h5py 9 | import os 10 | import warnings 11 | import shutil 12 | 13 | from .dataset import Dataset 14 | from .vocab import Vocab 15 | from .utils import get_tokenizer 16 | 17 | 18 | class RawField(object): 19 | """ Defines a general datatype. 20 | 21 | Every dataset consists of one or more types of data. For instance, 22 | a machine translation dataset contains paired examples of text, while 23 | an image captioning dataset contains images and texts. 24 | Each of these types of data is represented by a RawField object. 25 | A RawField object does not assume any property of the data type and 26 | it holds parameters relating to how a datatype should be processed. 27 | 28 | Attributes: 29 | preprocessing: The Pipeline that will be applied to examples 30 | using this field before creating an example. 31 | Default: None. 32 | postprocessing: A Pipeline that will be applied to a list of examples 33 | using this field before assigning to a batch. 34 | Function signature: (batch(list)) -> object 35 | Default: None.
36 | """ 37 | 38 | def __init__(self, preprocessing=None, postprocessing=None): 39 | self.preprocessing = preprocessing 40 | self.postprocessing = postprocessing 41 | 42 | def preprocess(self, x): 43 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 44 | if self.preprocessing is not None: 45 | return self.preprocessing(x) 46 | else: 47 | return x 48 | 49 | def process(self, batch, *args, **kwargs): 50 | """ Process a list of examples to create a batch. 51 | 52 | Postprocess the batch with user-provided Pipeline. 53 | 54 | Args: 55 | batch (list(object)): A list of object from a batch of examples. 56 | Returns: 57 | object: Processed object given the input and custom 58 | postprocessing Pipeline. 59 | """ 60 | if self.postprocessing is not None: 61 | batch = self.postprocessing(batch) 62 | return default_collate(batch) 63 | 64 | 65 | class Merge(RawField): 66 | def __init__(self, *fields): 67 | super(Merge, self).__init__() 68 | self.fields = fields 69 | 70 | def preprocess(self, x): 71 | return tuple(f.preprocess(x) for f in self.fields) 72 | 73 | def process(self, batch, *args, **kwargs): 74 | if len(self.fields) == 1: 75 | batch = [batch, ] 76 | else: 77 | batch = list(zip(*batch)) 78 | 79 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch)) 80 | return out 81 | 82 | 83 | class ImageDetectionsField(RawField): 84 | def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100, 85 | sort_by_prob=False, load_in_tmp=True): 86 | self.max_detections = max_detections 87 | self.detections_path = detections_path 88 | self.sort_by_prob = sort_by_prob 89 | 90 | tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path)) 91 | 92 | if load_in_tmp: 93 | if not os.path.isfile(tmp_detections_path): 94 | if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path): 95 | warnings.warn('Loading from %s, because /tmp has no enough space.' % detections_path) 96 | else: 97 | warnings.warn("Copying detection file to /tmp") 98 | shutil.copyfile(detections_path, tmp_detections_path) 99 | warnings.warn("Done.") 100 | self.detections_path = tmp_detections_path 101 | else: 102 | self.detections_path = tmp_detections_path 103 | 104 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing) 105 | 106 | def preprocess(self, x, avoid_precomp=False): 107 | image_id = int(x.split('_')[-1].split('.')[0]) 108 | try: 109 | f = h5py.File(self.detections_path, 'r') 110 | # precomp_data = f['%d_features' % image_id][()] 111 | precomp_data = f['%d_grids' % image_id][()] 112 | if self.sort_by_prob: 113 | precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]] 114 | except KeyError: 115 | warnings.warn('Could not find detections for %d' % image_id) 116 | precomp_data = np.random.rand(10,2048) 117 | 118 | delta = self.max_detections - precomp_data.shape[0] 119 | if delta > 0: 120 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0) 121 | elif delta < 0: 122 | precomp_data = precomp_data[:self.max_detections] 123 | 124 | return precomp_data.astype(np.float32) 125 | 126 | 127 | class TextField(RawField): 128 | vocab_cls = Vocab 129 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python 130 | # numeric type. 
131 | dtypes = { 132 | torch.float32: float, 133 | torch.float: float, 134 | torch.float64: float, 135 | torch.double: float, 136 | torch.float16: float, 137 | torch.half: float, 138 | 139 | torch.uint8: int, 140 | torch.int8: int, 141 | torch.int16: int, 142 | torch.short: int, 143 | torch.int32: int, 144 | torch.int: int, 145 | torch.int64: int, 146 | torch.long: int, 147 | } 148 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 149 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 150 | 151 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long, 152 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()), 153 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="", 154 | unk_token="", pad_first=False, truncate_first=False, vectors=None, nopoints=True): 155 | self.use_vocab = use_vocab 156 | self.init_token = init_token 157 | self.eos_token = eos_token 158 | self.fix_length = fix_length 159 | self.dtype = dtype 160 | self.lower = lower 161 | self.tokenize = get_tokenizer(tokenize) 162 | self.remove_punctuation = remove_punctuation 163 | self.include_lengths = include_lengths 164 | self.batch_first = batch_first 165 | self.pad_token = pad_token 166 | self.unk_token = unk_token 167 | self.pad_first = pad_first 168 | self.truncate_first = truncate_first 169 | self.vocab = None 170 | self.vectors = vectors 171 | if nopoints: 172 | self.punctuations.append("..") 173 | 174 | super(TextField, self).__init__(preprocessing, postprocessing) 175 | 176 | def preprocess(self, x): 177 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type): 178 | x = six.text_type(x, encoding='utf-8') 179 | if self.lower: 180 | x = six.text_type.lower(x) 181 | x = self.tokenize(x.rstrip('\n')) 182 | if self.remove_punctuation: 183 | x = [w for w in x if w not in self.punctuations] 184 | if self.preprocessing is not None: 185 | return self.preprocessing(x) 186 | else: 187 | return x 188 | 189 | def process(self, batch, device=None): 190 | padded = self.pad(batch) 191 | tensor = self.numericalize(padded, device=device) 192 | return tensor 193 | 194 | def build_vocab(self, *args, **kwargs): 195 | counter = Counter() 196 | sources = [] 197 | for arg in args: 198 | if isinstance(arg, Dataset): 199 | sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] 200 | else: 201 | sources.append(arg) 202 | 203 | for data in sources: 204 | for x in data: 205 | x = self.preprocess(x) 206 | try: 207 | counter.update(x) 208 | except TypeError: 209 | counter.update(chain.from_iterable(x)) 210 | 211 | specials = list(OrderedDict.fromkeys([ 212 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 213 | self.eos_token] 214 | if tok is not None])) 215 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 216 | 217 | def pad(self, minibatch): 218 | """Pad a batch of examples using this field. 219 | Pads to self.fix_length if provided, otherwise pads to the length of 220 | the longest example in the batch. Prepends self.init_token and appends 221 | self.eos_token if those attributes are not None. Returns a tuple of the 222 | padded list and a list containing lengths of each example if 223 | `self.include_lengths` is `True`, else just 224 | returns the padded list. 
225 | """ 226 | minibatch = list(minibatch) 227 | if self.fix_length is None: 228 | max_len = max(len(x) for x in minibatch) 229 | else: 230 | max_len = self.fix_length + ( 231 | self.init_token, self.eos_token).count(None) - 2 232 | padded, lengths = [], [] 233 | for x in minibatch: 234 | if self.pad_first: 235 | padded.append( 236 | [self.pad_token] * max(0, max_len - len(x)) + 237 | ([] if self.init_token is None else [self.init_token]) + 238 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 239 | ([] if self.eos_token is None else [self.eos_token])) 240 | else: 241 | padded.append( 242 | ([] if self.init_token is None else [self.init_token]) + 243 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 244 | ([] if self.eos_token is None else [self.eos_token]) + 245 | [self.pad_token] * max(0, max_len - len(x))) 246 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 247 | if self.include_lengths: 248 | return padded, lengths 249 | return padded 250 | 251 | def numericalize(self, arr, device=None): 252 | """Turn a batch of examples that use this field into a list of Variables. 253 | If the field has include_lengths=True, a tensor of lengths will be 254 | included in the return value. 255 | Arguments: 256 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 257 | List of tokenized and padded examples, or tuple of List of 258 | tokenized and padded examples and List of lengths of each 259 | example if self.include_lengths is True. 260 | device (str or torch.device): A string or instance of `torch.device` 261 | specifying which device the Variables are going to be created on. 262 | If left as default, the tensors will be created on cpu. Default: None. 263 | """ 264 | if self.include_lengths and not isinstance(arr, tuple): 265 | raise ValueError("Field has include_lengths set to True, but " 266 | "input data is not a tuple of " 267 | "(data batch, batch lengths).") 268 | if isinstance(arr, tuple): 269 | arr, lengths = arr 270 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device) 271 | 272 | if self.use_vocab: 273 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 274 | 275 | if self.postprocessing is not None: 276 | arr = self.postprocessing(arr, self.vocab) 277 | 278 | var = torch.tensor(arr, dtype=self.dtype, device=device) 279 | else: 280 | if self.vectors: 281 | arr = [[self.vectors[x] for x in ex] for ex in arr] 282 | if self.dtype not in self.dtypes: 283 | raise ValueError( 284 | "Specified Field dtype {} can not be used with " 285 | "use_vocab=False because we do not know how to numericalize it. " 286 | "Please raise an issue at " 287 | "https://github.com/pytorch/text/issues".format(self.dtype)) 288 | numericalization_func = self.dtypes[self.dtype] 289 | # It doesn't make sense to explictly coerce to a numeric type if 290 | # the data is sequential, since it's unclear how to coerce padding tokens 291 | # to a numeric type. 
292 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 293 | else x for x in arr] 294 | 295 | if self.postprocessing is not None: 296 | arr = self.postprocessing(arr, None) 297 | 298 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) 299 | 300 | # var = torch.tensor(arr, dtype=self.dtype, device=device) 301 | if not self.batch_first: 302 | var.t_() 303 | var = var.contiguous() 304 | 305 | if self.include_lengths: 306 | return var, lengths 307 | return var 308 | 309 | def decode(self, word_idxs, join_words=True): 310 | if isinstance(word_idxs, list) and len(word_idxs) == 0: 311 | return self.decode([word_idxs, ], join_words)[0] 312 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): 313 | return self.decode([word_idxs, ], join_words)[0] 314 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: 315 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0] 316 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: 317 | return self.decode(word_idxs.unsqueeze(0), join_words)[0] 318 | 319 | captions = [] 320 | for wis in word_idxs: 321 | caption = [] 322 | for wi in wis: 323 | word = self.vocab.itos[int(wi)] 324 | if word == self.eos_token: 325 | break 326 | caption.append(word) 327 | if join_words: 328 | caption = ' '.join(caption) 329 | captions.append(caption) 330 | return captions 331 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | class DummyFile(object): 4 | def write(self, x): pass 5 | 6 | @contextlib.contextmanager 7 | def nostdout(): 8 | save_stdout = sys.stdout 9 | sys.stdout = DummyFile() 10 | yield 11 | sys.stdout = save_stdout 12 | 13 | def reporthook(t): 14 | """https://github.com/tqdm/tqdm""" 15 | last_b = [0] 16 | 17 | def inner(b=1, bsize=1, tsize=None): 18 | """ 19 | b: int, optionala 20 | Number of blocks just transferred [default: 1]. 21 | bsize: int, optional 22 | Size of each block (in tqdm units) [default: 1]. 23 | tsize: int, optional 24 | Total size (in tqdm units). If [default: None] remains unchanged. 25 | """ 26 | if tsize is not None: 27 | t.total = tsize 28 | t.update((b - last_b[0]) * bsize) 29 | last_b[0] = b 30 | return inner 31 | 32 | def get_tokenizer(tokenizer): 33 | if callable(tokenizer): 34 | return tokenizer 35 | if tokenizer == "spacy": 36 | try: 37 | import spacy 38 | spacy_en = spacy.load('en') 39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 40 | except ImportError: 41 | print("Please install SpaCy and the SpaCy English tokenizer. " 42 | "See the docs at https://spacy.io for more information.") 43 | raise 44 | except AttributeError: 45 | print("Please install SpaCy and the SpaCy English tokenizer. " 46 | "See the docs at https://spacy.io for more information.") 47 | raise 48 | elif tokenizer == "moses": 49 | try: 50 | from nltk.tokenize.moses import MosesTokenizer 51 | moses_tokenizer = MosesTokenizer() 52 | return moses_tokenizer.tokenize 53 | except ImportError: 54 | print("Please install NLTK. " 55 | "See the docs at http://nltk.org for more information.") 56 | raise 57 | except LookupError: 58 | print("Please install the necessary NLTK corpora. 
" 59 | "See the docs at http://nltk.org for more information.") 60 | raise 61 | elif tokenizer == 'revtok': 62 | try: 63 | import revtok 64 | return revtok.tokenize 65 | except ImportError: 66 | print("Please install revtok.") 67 | raise 68 | elif tokenizer == 'subword': 69 | try: 70 | import revtok 71 | return lambda x: revtok.tokenize(x, decap=True) 72 | except ImportError: 73 | print("Please install revtok.") 74 | raise 75 | raise ValueError("Requested tokenizer {}, valid choices are a " 76 | "callable that takes a single string as input, " 77 | "\"revtok\" for the revtok reversible tokenizer, " 78 | "\"subword\" for the revtok caps-aware tokenizer, " 79 | "\"spacy\" for the SpaCy English tokenizer, or " 80 | "\"moses\" for the NLTK port of the Moses tokenization " 81 | "script.".format(tokenizer)) 82 | -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from .utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Vocab(object): 22 | """Defines a vocabulary object that will be used to numericalize a field. 23 | 24 | Attributes: 25 | freqs: A collections.Counter object holding the frequencies of tokens 26 | in the data used to build the Vocab. 27 | stoi: A collections.defaultdict instance mapping token strings to 28 | numerical identifiers. 29 | itos: A list of token strings indexed by their numerical identifiers. 30 | """ 31 | def __init__(self, counter, max_size=None, min_freq=1, specials=[''], 32 | vectors=None, unk_init=None, vectors_cache=None): 33 | """Create a Vocab object from a collections.Counter. 34 | 35 | Arguments: 36 | counter: collections.Counter object holding the frequencies of 37 | each value found in the data. 38 | max_size: The maximum size of the vocabulary, or None for no 39 | maximum. Default: None. 40 | min_freq: The minimum frequency needed to include a token in the 41 | vocabulary. Values less than 1 will be set to 1. Default: 1. 42 | specials: The list of special tokens (e.g., padding or eos) that 43 | will be prepended to the vocabulary in addition to an 44 | token. Default: [''] 45 | vectors: One of either the available pretrained vectors 46 | or custom pretrained vectors (see Vocab.load_vectors); 47 | or a list of aforementioned vectors 48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 49 | to zero vectors; can be any function that takes in a Tensor and 50 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 51 | vectors_cache: directory for cached vectors. 
Default: '.vector_cache' 52 | """ 53 | self.freqs = counter 54 | counter = counter.copy() 55 | min_freq = max(min_freq, 1) 56 | 57 | self.itos = list(specials) 58 | # frequencies of special tokens are not counted when building vocabulary 59 | # in frequency order 60 | for tok in specials: 61 | del counter[tok] 62 | 63 | max_size = None if max_size is None else max_size + len(self.itos) 64 | 65 | # sort by frequency, then alphabetically 66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 68 | 69 | for word, freq in words_and_frequencies: 70 | if freq < min_freq or len(self.itos) == max_size: 71 | break 72 | self.itos.append(word) 73 | 74 | self.stoi = defaultdict(_default_unk_index) 75 | # stoi is simply a reverse dict for itos 76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 77 | 78 | self.vectors = None 79 | if vectors is not None: 80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 81 | else: 82 | assert unk_init is None and vectors_cache is None 83 | 84 | def __eq__(self, other): 85 | if self.freqs != other.freqs: 86 | return False 87 | if self.stoi != other.stoi: 88 | return False 89 | if self.itos != other.itos: 90 | return False 91 | if self.vectors != other.vectors: 92 | return False 93 | return True 94 | 95 | def __len__(self): 96 | return len(self.itos) 97 | 98 | def extend(self, v, sort=False): 99 | words = sorted(v.itos) if sort else v.itos 100 | for w in words: 101 | if w not in self.stoi: 102 | self.itos.append(w) 103 | self.stoi[w] = len(self.itos) - 1 104 | 105 | def load_vectors(self, vectors, **kwargs): 106 | """ 107 | Arguments: 108 | vectors: one of or a list containing instantiations of the 109 | GloVe, CharNGram, or Vectors classes. Alternatively, one 110 | of or a list of available pretrained vectors: 111 | charngram.100d 112 | fasttext.en.300d 113 | fasttext.simple.300d 114 | glove.42B.300d 115 | glove.840B.300d 116 | glove.twitter.27B.25d 117 | glove.twitter.27B.50d 118 | glove.twitter.27B.100d 119 | glove.twitter.27B.200d 120 | glove.6B.50d 121 | glove.6B.100d 122 | glove.6B.200d 123 | glove.6B.300d 124 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 
125 | """ 126 | if not isinstance(vectors, list): 127 | vectors = [vectors] 128 | for idx, vector in enumerate(vectors): 129 | if six.PY2 and isinstance(vector, str): 130 | vector = six.text_type(vector) 131 | if isinstance(vector, six.string_types): 132 | # Convert the string pretrained vector identifier 133 | # to a Vectors object 134 | if vector not in pretrained_aliases: 135 | raise ValueError( 136 | "Got string input vector {}, but allowed pretrained " 137 | "vectors are {}".format( 138 | vector, list(pretrained_aliases.keys()))) 139 | vectors[idx] = pretrained_aliases[vector](**kwargs) 140 | elif not isinstance(vector, Vectors): 141 | raise ValueError( 142 | "Got input vectors of type {}, expected str or " 143 | "Vectors object".format(type(vector))) 144 | 145 | tot_dim = sum(v.dim for v in vectors) 146 | self.vectors = torch.Tensor(len(self), tot_dim) 147 | for i, token in enumerate(self.itos): 148 | start_dim = 0 149 | for v in vectors: 150 | end_dim = start_dim + v.dim 151 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 152 | start_dim = end_dim 153 | assert(start_dim == tot_dim) 154 | 155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 156 | """ 157 | Set the vectors for the Vocab instance from a collection of Tensors. 158 | 159 | Arguments: 160 | stoi: A dictionary of string to the index of the associated vector 161 | in the `vectors` input argument. 162 | vectors: An indexed iterable (or other structure supporting __getitem__) that 163 | given an input index, returns a FloatTensor representing the vector 164 | for the token associated with the index. For example, 165 | vector[stoi["string"]] should return the vector for "string". 166 | dim: The dimensionality of the vectors. 167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 168 | to zero vectors; can be any function that takes in a Tensor and 169 | returns a Tensor of the same size. 
Default: torch.Tensor.zero_ 170 | """ 171 | self.vectors = torch.Tensor(len(self), dim) 172 | for i, token in enumerate(self.itos): 173 | wv_index = stoi.get(token, None) 174 | if wv_index is not None: 175 | self.vectors[i] = vectors[wv_index] 176 | else: 177 | self.vectors[i] = unk_init(self.vectors[i]) 178 | 179 | 180 | class Vectors(object): 181 | 182 | def __init__(self, name, cache=None, 183 | url=None, unk_init=None): 184 | """ 185 | Arguments: 186 | name: name of the file that contains the vectors 187 | cache: directory for cached vectors 188 | url: url for download if vectors not found in cache 189 | unk_init (callback): by default, initalize out-of-vocabulary word vectors 190 | to zero vectors; can be any function that takes in a Tensor and 191 | returns a Tensor of the same size 192 | """ 193 | cache = '.vector_cache' if cache is None else cache 194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 195 | self.cache(name, cache, url=url) 196 | 197 | def __getitem__(self, token): 198 | if token in self.stoi: 199 | return self.vectors[self.stoi[token]] 200 | else: 201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) 202 | 203 | def cache(self, name, cache, url=None): 204 | if os.path.isfile(name): 205 | path = name 206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 207 | else: 208 | path = os.path.join(cache, name) 209 | path_pt = path + '.pt' 210 | 211 | if not os.path.isfile(path_pt): 212 | if not os.path.isfile(path) and url: 213 | logger.info('Downloading vectors from {}'.format(url)) 214 | if not os.path.exists(cache): 215 | os.makedirs(cache) 216 | dest = os.path.join(cache, os.path.basename(url)) 217 | if not os.path.isfile(dest): 218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 219 | try: 220 | urlretrieve(url, dest, reporthook=reporthook(t)) 221 | except KeyboardInterrupt as e: # remove the partial zip file 222 | os.remove(dest) 223 | raise e 224 | logger.info('Extracting vectors into {}'.format(cache)) 225 | ext = os.path.splitext(dest)[1][1:] 226 | if ext == 'zip': 227 | with zipfile.ZipFile(dest, "r") as zf: 228 | zf.extractall(cache) 229 | elif ext == 'gz': 230 | with tarfile.open(dest, 'r:gz') as tar: 231 | tar.extractall(path=cache) 232 | if not os.path.isfile(path): 233 | raise RuntimeError('no vectors found at {}'.format(path)) 234 | 235 | # str call is necessary for Python 2/3 compatibility, since 236 | # argument must be Python 2 str (Python 3 bytes) or 237 | # Python 3 str (Python 2 unicode) 238 | itos, vectors, dim = [], array.array(str('d')), None 239 | 240 | # Try to read the whole file with utf-8 encoding. 241 | binary_lines = False 242 | try: 243 | with io.open(path, encoding="utf8") as f: 244 | lines = [line for line in f] 245 | # If there are malformed lines, read in binary mode 246 | # and manually decode each word from utf-8 247 | except: 248 | logger.warning("Could not read {} as UTF8 file, " 249 | "reading file as bytes and skipping " 250 | "words with malformed UTF8.".format(path)) 251 | with open(path, 'rb') as f: 252 | lines = [line for line in f] 253 | binary_lines = True 254 | 255 | logger.info("Loading vectors from {}".format(path)) 256 | for line in tqdm(lines, total=len(lines)): 257 | # Explicitly splitting on " " is important, so we don't 258 | # get rid of Unicode non-breaking spaces in the vectors. 
259 | entries = line.rstrip().split(b" " if binary_lines else " ") 260 | 261 | word, entries = entries[0], entries[1:] 262 | if dim is None and len(entries) > 1: 263 | dim = len(entries) 264 | elif len(entries) == 1: 265 | logger.warning("Skipping token {} with 1-dimensional " 266 | "vector {}; likely a header".format(word, entries)) 267 | continue 268 | elif dim != len(entries): 269 | raise RuntimeError( 270 | "Vector for token {} has {} dimensions, but previously " 271 | "read vectors have {} dimensions. All vectors must have " 272 | "the same number of dimensions.".format(word, len(entries), dim)) 273 | 274 | if binary_lines: 275 | try: 276 | if isinstance(word, six.binary_type): 277 | word = word.decode('utf-8') 278 | except: 279 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 280 | continue 281 | vectors.extend(float(x) for x in entries) 282 | itos.append(word) 283 | 284 | self.itos = itos 285 | self.stoi = {word: i for i, word in enumerate(itos)} 286 | self.vectors = torch.Tensor(vectors).view(-1, dim) 287 | self.dim = dim 288 | logger.info('Saving vectors to {}'.format(path_pt)) 289 | if not os.path.exists(cache): 290 | os.makedirs(cache) 291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 292 | else: 293 | logger.info('Loading vectors from {}'.format(path_pt)) 294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 295 | 296 | 297 | class GloVe(Vectors): 298 | url = { 299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 303 | } 304 | 305 | def __init__(self, name='840B', dim=300, **kwargs): 306 | url = self.url[name] 307 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 308 | super(GloVe, self).__init__(name, url=url, **kwargs) 309 | 310 | 311 | class FastText(Vectors): 312 | 313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 314 | 315 | def __init__(self, language="en", **kwargs): 316 | url = self.url_base.format(language) 317 | name = os.path.basename(url) 318 | super(FastText, self).__init__(name, url=url, **kwargs) 319 | 320 | 321 | class CharNGram(Vectors): 322 | 323 | name = 'charNgram.txt' 324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 325 | 'jmt_pre-trained_embeddings.tar.gz') 326 | 327 | def __init__(self, **kwargs): 328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 329 | 330 | def __getitem__(self, token): 331 | vector = torch.Tensor(1, self.dim).zero_() 332 | if token == "": 333 | return self.unk_init(vector) 334 | # These literals need to be coerced to unicode for Python 2 compatibility 335 | # when we try to join them with read ngrams from the files. 
336 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 337 | num_vectors = 0 338 | for n in [2, 3, 4]: 339 | end = len(chars) - n + 1 340 | grams = [chars[i:(i + n)] for i in range(end)] 341 | for gram in grams: 342 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 343 | if gram_key in self.stoi: 344 | vector += self.vectors[self.stoi[gram_key]] 345 | num_vectors += 1 346 | if num_vectors > 0: 347 | vector /= num_vectors 348 | else: 349 | vector = self.unk_init(vector) 350 | return vector 351 | 352 | 353 | def _default_unk_index(): 354 | return 0 355 | 356 | 357 | pretrained_aliases = { 358 | "charngram.100d": partial(CharNGram), 359 | "fasttext.en.300d": partial(FastText, language="en"), 360 | "fasttext.simple.300d": partial(FastText, language="simple"), 361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 371 | } 372 | """Mapping from string name to factory function""" 373 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: m2release 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - asn1crypto=1.2.0=py36_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.10.16=0 10 | - certifi=2019.9.11=py36_0 11 | - cffi=1.13.2=py36h2e261b9_0 12 | - chardet=3.0.4=py36_1003 13 | - cryptography=2.8=py36h1ba5d50_0 14 | - cython=0.29.14=py36he6710b0_0 15 | - dill=0.2.9=py36_0 16 | - idna=2.8=py36_0 17 | - intel-openmp=2019.5=281 18 | - libedit=3.1.20181209=hc058e9b_0 19 | - libffi=3.2.1=hd88cf55_4 20 | - libgcc-ng=9.1.0=hdf63c60_0 21 | - libgfortran-ng=7.3.0=hdf63c60_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - mkl=2019.5=281 24 | - mkl-service=2.3.0=py36he904b0f_0 25 | - mkl_fft=1.0.15=py36ha843d7b_0 26 | - mkl_random=1.1.0=py36hd6b4f25_0 27 | - msgpack-numpy=0.4.4.3=py_0 28 | - msgpack-python=0.5.6=py36h6bb024c_1 29 | - ncurses=6.1=he6710b0_1 30 | - openjdk=8.0.152=h46b5887_1 31 | - openssl=1.1.1=h7b6447c_0 32 | - pip=19.3.1=py36_0 33 | - pycparser=2.19=py_0 34 | - pyopenssl=19.1.0=py36_0 35 | - pysocks=1.7.1=py36_0 36 | - python=3.6.9=h265db76_0 37 | - readline=7.0=h7b6447c_5 38 | - requests=2.22.0=py36_0 39 | - setuptools=41.6.0=py36_0 40 | - six=1.13.0=py36_0 41 | - spacy=2.0.11=py36h04863e7_2 42 | - sqlite=3.30.1=h7b6447c_0 43 | - termcolor=1.1.0=py36_1 44 | - thinc=6.11.2=py36hedc7406_1 45 | - tk=8.6.8=hbc83047_0 46 | - toolz=0.10.0=py_0 47 | - urllib3=1.24.2=py36_0 48 | - wheel=0.33.6=py36_0 49 | - xz=5.2.4=h14c3975_4 50 | - zlib=1.2.11=h7b6447c_3 51 | - pip: 52 | - absl-py==0.8.1 53 | - cycler==0.10.0 54 | - cymem==1.31.2 55 | - cytoolz==0.9.0.1 56 | - future==0.17.1 57 | - grpcio==1.25.0 58 | - h5py==2.8.0 59 | - kiwisolver==1.1.0 60 | - markdown==3.1.1 61 | - matplotlib==2.2.3 62 | - msgpack==0.6.2 63 | - multiprocess==0.70.9 64 | - murmurhash==0.28.0 65 | - numpy==1.16.4 66 | - pathlib==1.0.1 67 | - 
pathos==0.2.3 68 | - pillow==6.2.1 69 | - plac==0.9.6 70 | - pox==0.2.7 71 | - ppft==1.6.6.1 72 | - preshed==1.0.1 73 | - protobuf==3.10.0 74 | - pycocotools==2.0.0 75 | - pyparsing==2.4.5 76 | - python-dateutil==2.8.1 77 | - pytz==2019.3 78 | - regex==2017.4.5 79 | - tensorboard==1.14.0 80 | - torch==1.1.0 81 | - torchvision==0.3.0 82 | - tqdm==4.32.2 83 | - ujson==1.35 84 | - werkzeug==0.16.0 85 | - wrapt==1.10.11 86 | - transformers 87 | 88 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu 2 | from .meteor import Meteor 3 | from .rouge import Rouge 4 | from .cider import Cider 5 | from .tokenizer import PTBTokenizer 6 | 7 | def compute_scores(gts, gen): 8 | metrics = (Bleu(), Meteor(), Rouge(), Cider()) 9 | all_score = {} 10 | all_scores = {} 11 | for metric in metrics: 12 | score, scores = metric.compute_score(gts, gen) 13 | all_score[str(metric)] = score 14 | all_scores[str(metric)] = scores 15 | 16 | return all_score, all_scores 17 | -------------------------------------------------------------------------------- /evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | from .bleu import Bleu -------------------------------------------------------------------------------- /evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from .bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | # score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | return score, scores 44 | 45 | def __str__(self): 46 | return 'BLEU' 47 | -------------------------------------------------------------------------------- /evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | ''' Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 
17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | 24 | def precook(s, n=4, out=False): 25 | """Takes a string as input and returns an object that can be given to 26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 27 | can take string arguments as well.""" 28 | words = s.split() 29 | counts = defaultdict(int) 30 | for k in range(1, n + 1): 31 | for i in range(len(words) - k + 1): 32 | ngram = tuple(words[i:i + k]) 33 | counts[ngram] += 1 34 | return (len(words), counts) 35 | 36 | 37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 38 | '''Takes a list of reference sentences for a single segment 39 | and returns an object that encapsulates everything that BLEU 40 | needs to know about them.''' 41 | 42 | reflen = [] 43 | maxcounts = {} 44 | for ref in refs: 45 | rl, counts = precook(ref, n) 46 | reflen.append(rl) 47 | for (ngram, count) in counts.items(): 48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 49 | 50 | # Calculate effective reference sentence length. 51 | if eff == "shortest": 52 | reflen = min(reflen) 53 | elif eff == "average": 54 | reflen = float(sum(reflen)) / len(reflen) 55 | 56 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 57 | 58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, ref_tuple, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | 67 | testlen, counts = precook(test, n, True) 68 | reflen, refmaxcounts = ref_tuple 69 | 70 | result = {} 71 | 72 | # Calculate effective reference sentence length. 73 | 74 | if eff == "closest": 75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 76 | else: ## i.e., "average" or "shortest" or None 77 | result["reflen"] = reflen 78 | 79 | result["testlen"] = testlen 80 | 81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 82 | 83 | result['correct'] = [0] * n 84 | for (ngram, count) in counts.items(): 85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 86 | 87 | return result 88 | 89 | 90 | class BleuScorer(object): 91 | """Bleu scorer. 92 | """ 93 | 94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 95 | 96 | # special_reflen is used in oracle (proportional effective ref len for a node). 
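# Usage sketch (added for illustration, not part of the original scorer; the example captions are made up). The class is driven exactly as Bleu.compute_score does above: accumulate (hypothesis, references) pairs with `+=`, then call compute_score:
#   scorer = BleuScorer(n=4)
#   scorer += ('a man riding a horse', ['a man is riding a horse', 'a person rides a horse'])
#   score, scores = scorer.compute_score(option='closest')  # score: corpus-level BLEU-1..4, scores: per-sentence lists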
97 | 98 | def copy(self): 99 | ''' copy the refs.''' 100 | new = BleuScorer(n=self.n) 101 | new.ctest = copy.copy(self.ctest) 102 | new.crefs = copy.copy(self.crefs) 103 | new._score = None 104 | return new 105 | 106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 107 | ''' singular instance ''' 108 | 109 | self.n = n 110 | self.crefs = [] 111 | self.ctest = [] 112 | self.cook_append(test, refs) 113 | self.special_reflen = special_reflen 114 | 115 | def cook_append(self, test, refs): 116 | '''called by constructor and __iadd__ to avoid creating new instances.''' 117 | 118 | if refs is not None: 119 | self.crefs.append(cook_refs(refs)) 120 | if test is not None: 121 | cooked_test = cook_test(test, self.crefs[-1]) 122 | self.ctest.append(cooked_test) ## N.B.: -1 123 | else: 124 | self.ctest.append(None) # lens of crefs and ctest have to match 125 | 126 | self._score = None ## need to recompute 127 | 128 | def ratio(self, option=None): 129 | self.compute_score(option=option) 130 | return self._ratio 131 | 132 | def score_ratio(self, option=None): 133 | ''' 134 | return (bleu, len_ratio) pair 135 | ''' 136 | 137 | return self.fscore(option=option), self.ratio(option=option) 138 | 139 | def score_ratio_str(self, option=None): 140 | return "%.4f (%.2f)" % self.score_ratio(option) 141 | 142 | def reflen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._reflen 145 | 146 | def testlen(self, option=None): 147 | self.compute_score(option=option) 148 | return self._testlen 149 | 150 | def retest(self, new_test): 151 | if type(new_test) is str: 152 | new_test = [new_test] 153 | assert len(new_test) == len(self.crefs), new_test 154 | self.ctest = [] 155 | for t, rs in zip(new_test, self.crefs): 156 | self.ctest.append(cook_test(t, rs)) 157 | self._score = None 158 | 159 | return self 160 | 161 | def rescore(self, new_test): 162 | ''' replace test(s) with new test(s), and returns the new score.''' 163 | 164 | return self.retest(new_test).compute_score() 165 | 166 | def size(self): 167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 168 | return len(self.crefs) 169 | 170 | def __iadd__(self, other): 171 | '''add an instance (e.g., from another sentence).''' 172 | 173 | if type(other) is tuple: 174 | ## avoid creating new BleuScorer instances 175 | self.cook_append(other[0], other[1]) 176 | else: 177 | assert self.compatible(other), "incompatible BLEUs." 
178 | self.ctest.extend(other.ctest) 179 | self.crefs.extend(other.crefs) 180 | self._score = None ## need to recompute 181 | 182 | return self 183 | 184 | def compatible(self, other): 185 | return isinstance(other, BleuScorer) and self.n == other.n 186 | 187 | def single_reflen(self, option="average"): 188 | return self._single_reflen(self.crefs[0][0], option) 189 | 190 | def _single_reflen(self, reflens, option=None, testlen=None): 191 | 192 | if option == "shortest": 193 | reflen = min(reflens) 194 | elif option == "average": 195 | reflen = float(sum(reflens)) / len(reflens) 196 | elif option == "closest": 197 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 198 | else: 199 | assert False, "unsupported reflen option %s" % option 200 | 201 | return reflen 202 | 203 | def recompute_score(self, option=None, verbose=0): 204 | self._score = None 205 | return self.compute_score(option, verbose) 206 | 207 | def compute_score(self, option=None, verbose=0): 208 | n = self.n 209 | small = 1e-9 210 | tiny = 1e-15 ## so that if guess is 0 still return 0 211 | bleu_list = [[] for _ in range(n)] 212 | 213 | if self._score is not None: 214 | return self._score 215 | 216 | if option is None: 217 | option = "average" if len(self.crefs) == 1 else "closest" 218 | 219 | self._testlen = 0 220 | self._reflen = 0 221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 222 | 223 | # for each sentence 224 | for comps in self.ctest: 225 | testlen = comps['testlen'] 226 | self._testlen += testlen 227 | 228 | if self.special_reflen is None: ## need computation 229 | reflen = self._single_reflen(comps['reflen'], option, testlen) 230 | else: 231 | reflen = self.special_reflen 232 | 233 | self._reflen += reflen 234 | 235 | for key in ['guess', 'correct']: 236 | for k in range(n): 237 | totalcomps[key][k] += comps[key][k] 238 | 239 | # append per image bleu score 240 | bleu = 1. 241 | for k in range(n): 242 | bleu *= (float(comps['correct'][k]) + tiny) \ 243 | / (float(comps['guess'][k]) + small) 244 | bleu_list[k].append(bleu ** (1. / (k + 1))) 245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 246 | if ratio < 1: 247 | for k in range(n): 248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 249 | 250 | if verbose > 1: 251 | print(comps, reflen) 252 | 253 | totalcomps['reflen'] = self._reflen 254 | totalcomps['testlen'] = self._testlen 255 | 256 | bleus = [] 257 | bleu = 1. 258 | for k in range(n): 259 | bleu *= float(totalcomps['correct'][k] + tiny) \ 260 | / (totalcomps['guess'][k] + small) 261 | bleus.append(bleu ** (1. 
/ (k + 1))) 262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 263 | if ratio < 1: 264 | for k in range(n): 265 | bleus[k] *= math.exp(1 - 1 / ratio) 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | self._score = bleus 272 | return self._score, bleu_list 273 | -------------------------------------------------------------------------------- /evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider -------------------------------------------------------------------------------- /evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from .cider_scorer import CiderScorer 11 | 12 | class Cider: 13 | """ 14 | Main Class to compute the CIDEr metric 15 | 16 | """ 17 | def __init__(self, gts=None, n=4, sigma=6.0): 18 | # set cider to sum over 1 to 4-grams 19 | self._n = n 20 | # set the standard deviation parameter for gaussian penalty 21 | self._sigma = sigma 22 | self.doc_frequency = None 23 | self.ref_len = None 24 | if gts is not None: 25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma) 26 | self.doc_frequency = tmp_cider.doc_frequency 27 | self.ref_len = tmp_cider.ref_len 28 | 29 | def compute_score(self, gts, res): 30 | """ 31 | Main function to compute CIDEr score 32 | :param gts (dict) : dictionary with key and value 33 | res (dict) : dictionary with key and value 34 | :return: cider (float) : computed CIDEr score for the corpus 35 | """ 36 | assert(gts.keys() == res.keys()) 37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency, 38 | ref_len=self.ref_len) 39 | return cider_scorer.compute_score() 40 | 41 | def __str__(self): 42 | return 'CIDEr' 43 | -------------------------------------------------------------------------------- /evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import math 9 | 10 | def precook(s, n=4): 11 | """ 12 | Takes a string as input and returns an object that can be given to 13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 14 | can take string arguments as well. 15 | :param s: string : sentence to be converted into ngrams 16 | :param n: int : number of ngrams for which representation is calculated 17 | :return: term frequency vector for occuring ngrams 18 | """ 19 | words = s.split() 20 | counts = defaultdict(int) 21 | for k in range(1,n+1): 22 | for i in range(len(words)-k+1): 23 | ngram = tuple(words[i:i+k]) 24 | counts[ngram] += 1 25 | return counts 26 | 27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 28 | '''Takes a list of reference sentences for a single segment 29 | and returns an object that encapsulates everything that BLEU 30 | needs to know about them. 
31 | :param refs: list of string : reference sentences for some image 32 | :param n: int : number of ngrams for which (ngram) representation is calculated 33 | :return: result (list of dict) 34 | ''' 35 | return [precook(ref, n) for ref in refs] 36 | 37 | def cook_test(test, n=4): 38 | '''Takes a test sentence and returns an object that 39 | encapsulates everything that BLEU needs to know about it. 40 | :param test: list of string : hypothesis sentence for some image 41 | :param n: int : number of ngrams for which (ngram) representation is calculated 42 | :return: result (dict) 43 | ''' 44 | return precook(test, n) 45 | 46 | class CiderScorer(object): 47 | """CIDEr scorer. 48 | """ 49 | 50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None): 51 | ''' singular instance ''' 52 | self.n = n 53 | self.sigma = sigma 54 | self.crefs = [] 55 | self.ctest = [] 56 | self.doc_frequency = defaultdict(float) 57 | self.ref_len = None 58 | 59 | for k in refs.keys(): 60 | self.crefs.append(cook_refs(refs[k])) 61 | if test is not None: 62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1 63 | else: 64 | self.ctest.append(None) # lens of crefs and ctest have to match 65 | 66 | if doc_frequency is None and ref_len is None: 67 | # compute idf 68 | self.compute_doc_freq() 69 | # compute log reference length 70 | self.ref_len = np.log(float(len(self.crefs))) 71 | else: 72 | self.doc_frequency = doc_frequency 73 | self.ref_len = ref_len 74 | 75 | def compute_doc_freq(self): 76 | ''' 77 | Compute term frequency for reference data. 78 | This will be used to compute idf (inverse document frequency later) 79 | The term frequency is stored in the object 80 | :return: None 81 | ''' 82 | for refs in self.crefs: 83 | # refs, k ref captions of one image 84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 85 | self.doc_frequency[ngram] += 1 86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 87 | 88 | def compute_cider(self): 89 | def counts2vec(cnts): 90 | """ 91 | Function maps counts of ngram to vector of tfidf weights. 92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 93 | The n-th entry of array denotes length of n-grams. 94 | :param cnts: 95 | :return: vec (array of dict), norm (array of float), length (int) 96 | """ 97 | vec = [defaultdict(float) for _ in range(self.n)] 98 | length = 0 99 | norm = [0.0 for _ in range(self.n)] 100 | for (ngram,term_freq) in cnts.items(): 101 | # give word count 1 if it doesn't appear in reference corpus 102 | df = np.log(max(1.0, self.doc_frequency[ngram])) 103 | # ngram index 104 | n = len(ngram)-1 105 | # tf (term_freq) * idf (precomputed idf) for n-grams 106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 107 | # compute norm for the vector. the norm will be used for computing similarity 108 | norm[n] += pow(vec[n][ngram], 2) 109 | 110 | if n == 1: 111 | length += term_freq 112 | norm = [np.sqrt(n) for n in norm] 113 | return vec, norm, length 114 | 115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 116 | ''' 117 | Compute the cosine similarity of two vectors. 
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 119 | :param vec_ref: array of dictionary for vector corresponding to reference 120 | :param norm_hyp: array of float for vector corresponding to hypothesis 121 | :param norm_ref: array of float for vector corresponding to reference 122 | :param length_hyp: int containing length of hypothesis 123 | :param length_ref: int containing length of reference 124 | :return: array of score for each n-grams cosine similarity 125 | ''' 126 | delta = float(length_hyp - length_ref) 127 | # measure consine similarity 128 | val = np.array([0.0 for _ in range(self.n)]) 129 | for n in range(self.n): 130 | # ngram 131 | for (ngram,count) in vec_hyp[n].items(): 132 | # vrama91 : added clipping 133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 134 | 135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 136 | val[n] /= (norm_hyp[n]*norm_ref[n]) 137 | 138 | assert(not math.isnan(val[n])) 139 | # vrama91: added a length based gaussian penalty 140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 141 | return val 142 | 143 | scores = [] 144 | for test, refs in zip(self.ctest, self.crefs): 145 | # compute vector for test captions 146 | vec, norm, length = counts2vec(test) 147 | # compute vector for ref captions 148 | score = np.array([0.0 for _ in range(self.n)]) 149 | for ref in refs: 150 | vec_ref, norm_ref, length_ref = counts2vec(ref) 151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 152 | # change by vrama91 - mean of ngram scores, instead of sum 153 | score_avg = np.mean(score) 154 | # divide by number of references 155 | score_avg /= len(refs) 156 | # multiply score by 10 157 | score_avg *= 10.0 158 | # append score of an image to the score list 159 | scores.append(score_avg) 160 | return scores 161 | 162 | def compute_score(self): 163 | # compute cider score 164 | score = self.compute_cider() 165 | # debug 166 | # print score 167 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | from .meteor import Meteor -------------------------------------------------------------------------------- /evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | # Python wrapper for METEOR implementation, by Xinlei Chen 2 | # Acknowledge Michael Denkowski for the generous discussion and help 3 | 4 | import os 5 | import subprocess 6 | import threading 7 | import tarfile 8 | from utils import download_from_url 9 | 10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz' 11 | METEOR_JAR = 'meteor-1.5.jar' 12 | 13 | class Meteor: 14 | def __init__(self): 15 | base_path = os.path.dirname(os.path.abspath(__file__)) 16 | jar_path = os.path.join(base_path, METEOR_JAR) 17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL)) 18 | if not os.path.isfile(jar_path): 19 | if not os.path.isfile(gz_path): 20 | download_from_url(METEOR_GZ_URL, gz_path) 21 | tar = tarfile.open(gz_path, "r") 22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__))) 23 | tar.close() 24 | os.remove(gz_path) 25 | 26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 27 | '-', '-', '-stdio', '-l', 'en', '-norm'] 28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 29 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 30 | 
stdin=subprocess.PIPE, \ 31 | stdout=subprocess.PIPE, \ 32 | stderr=subprocess.PIPE) 33 | # Used to guarantee thread safety 34 | self.lock = threading.Lock() 35 | 36 | def compute_score(self, gts, res): 37 | assert(gts.keys() == res.keys()) 38 | imgIds = gts.keys() 39 | scores = [] 40 | 41 | eval_line = 'EVAL' 42 | self.lock.acquire() 43 | for i in imgIds: 44 | assert(len(res[i]) == 1) 45 | stat = self._stat(res[i][0], gts[i]) 46 | eval_line += ' ||| {}'.format(stat) 47 | 48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 49 | self.meteor_p.stdin.flush() 50 | for i in range(0,len(imgIds)): 51 | scores.append(float(self.meteor_p.stdout.readline().strip())) 52 | score = float(self.meteor_p.stdout.readline().strip()) 53 | self.lock.release() 54 | 55 | return score, scores 56 | 57 | def _stat(self, hypothesis_str, reference_list): 58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 59 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 62 | self.meteor_p.stdin.flush() 63 | raw = self.meteor_p.stdout.readline().decode().strip() 64 | numbers = [str(int(float(n))) for n in raw.split()] 65 | return ' '.join(numbers) 66 | 67 | def __del__(self): 68 | self.lock.acquire() 69 | self.meteor_p.stdin.close() 70 | self.meteor_p.kill() 71 | self.meteor_p.wait() 72 | self.lock.release() 73 | 74 | def __str__(self): 75 | return 'METEOR' 76 | -------------------------------------------------------------------------------- /evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | from .rouge import Rouge -------------------------------------------------------------------------------- /evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | 14 | def my_lcs(string, sub): 15 | """ 16 | Calculates longest common subsequence for a pair of tokenized strings 17 | :param string : list of str : tokens from a string split using whitespace 18 | :param sub : list of str : shorter string, also split using whitespace 19 | :returns: length (list of int): length of the longest common subsequence between the two strings 20 | 21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 22 | """ 23 | if (len(string) < len(sub)): 24 | sub, string = string, sub 25 | 26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 27 | 28 | for j in range(1, len(sub) + 1): 29 | for i in range(1, len(string) + 1): 30 | if (string[i - 1] == sub[j - 1]): 31 | lengths[i][j] = lengths[i - 1][j - 1] + 1 32 | else: 33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 34 | 35 | return lengths[len(string)][len(sub)] 36 | 37 | 38 | class Rouge(): 39 | ''' 40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 41 | 42 | ''' 43 | 44 | def __init__(self): 45 | # vrama91: updated the value below based on discussion with Hovey 46 | self.beta = 1.2 47 | 48 | def calc_score(self, candidate, refs): 49 | """ 50 | Compute ROUGE-L score given one candidate 
and references for an image 51 | :param candidate: str : candidate sentence to be evaluated 52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 53 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 54 | """ 55 | assert (len(candidate) == 1) 56 | assert (len(refs) > 0) 57 | prec = [] 58 | rec = [] 59 | 60 | # split into tokens 61 | token_c = candidate[0].split(" ") 62 | 63 | for reference in refs: 64 | # split into tokens 65 | token_r = reference.split(" ") 66 | # compute the longest common subsequence 67 | lcs = my_lcs(token_r, token_c) 68 | prec.append(lcs / float(len(token_c))) 69 | rec.append(lcs / float(len(token_r))) 70 | 71 | prec_max = max(prec) 72 | rec_max = max(rec) 73 | 74 | if (prec_max != 0 and rec_max != 0): 75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 76 | else: 77 | score = 0.0 78 | return score 79 | 80 | def compute_score(self, gts, res): 81 | """ 82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 83 | Invoked by evaluate_captions.py 84 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 85 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 87 | """ 88 | assert (gts.keys() == res.keys()) 89 | imgIds = gts.keys() 90 | 91 | score = [] 92 | for id in imgIds: 93 | hypo = res[id] 94 | ref = gts[id] 95 | 96 | score.append(self.calc_score(hypo, ref)) 97 | 98 | # Sanity check. 99 | assert (type(hypo) is list) 100 | assert (len(hypo) == 1) 101 | assert (type(ref) is list) 102 | assert (len(ref) > 0) 103 | 104 | average_score = np.mean(np.array(score)) 105 | return average_score, np.array(score) 106 | 107 | def __str__(self): 108 | return 'ROUGE' 109 | -------------------------------------------------------------------------------- /evaluation/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/evaluation/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /evaluation/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 
6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import subprocess 13 | import tempfile 14 | 15 | class PTBTokenizer(object): 16 | """Python wrapper of Stanford PTBTokenizer""" 17 | 18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar' 19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 21 | 22 | @classmethod 23 | def tokenize(cls, corpus): 24 | cmd = ['java', '-cp', cls.corenlp_jar, \ 25 | 'edu.stanford.nlp.process.PTBTokenizer', \ 26 | '-preserveLines', '-lowerCase'] 27 | 28 | if isinstance(corpus, list) or isinstance(corpus, tuple): 29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple): 30 | corpus = {i:c for i, c in enumerate(corpus)} 31 | else: 32 | corpus = {i: [c, ] for i, c in enumerate(corpus)} 33 | 34 | # prepare data for PTB Tokenizer 35 | tokenized_corpus = {} 36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v]) 38 | 39 | # save sentences to temporary file 40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 42 | tmp_file.write(sentences.encode()) 43 | tmp_file.close() 44 | 45 | # tokenize sentence 46 | cmd.append(os.path.basename(tmp_file.name)) 47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w')) 49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 50 | token_lines = token_lines.decode() 51 | lines = token_lines.split('\n') 52 | # remove temp file 53 | os.remove(tmp_file.name) 54 | 55 | # create dictionary for tokenized captions 56 | for k, line in zip(image_id, lines): 57 | if not k in tokenized_corpus: 58 | tokenized_corpus[k] = [] 59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 60 | if w not in cls.punctuations]) 61 | tokenized_corpus[k].append(tokenized_caption) 62 | 63 | return tokenized_corpus -------------------------------------------------------------------------------- /feats_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import argparse 4 | import json 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | class DataProcessor(nn.Module): 10 | def __init__(self): 11 | super(DataProcessor, self).__init__() 12 | self.pool = nn.AdaptiveAvgPool2d((7, 7)) 13 | 14 | def forward(self, x): 15 | x = self.pool(x) 16 | x = torch.squeeze(x) # [1, d, h, w] => [d, h, w] 17 | x = x.permute(1, 2, 0) # [d, h, w] => [h, w, d] 18 | return x.view(-1, x.size(-1)) # [h*w, d] 19 | 20 | 21 | def process_dataset(file_path, feat_paths): 22 | print('save the original grid features as fixed-size 7x7 features') 23 | # build the feature processor 24 | processor = DataProcessor() 25 | with h5py.File(file_path, 'w') as f: 26 | for i in tqdm(range(len(feat_paths))): 27 | # load a raw feature 28 | feat_path = feat_paths[i] 29 | img_feat = torch.load(feat_path) 30 | # process the feature 31 | img_feat = processor(img_feat) 32 | # save the processed feature 33 | img_name = feat_path.split('/')[-1] 34 | img_id = int(img_name.split('.')[0]) 35 | f.create_dataset('%d_grids' % img_id, data=img_feat.numpy()) 36 | f.close() 37 | 38 | 39 | def get_feat_paths(dir_to_save_feats, data_split='trainval', test2014_info_path=None): 40 | print('get the paths of raw grid features') 41 | ans = [] 42 | # offline training and testing
43 | if data_split == 'trainval': 44 | filenames_train = os.listdir(os.path.join(dir_to_save_feats, 'train2014')) 45 | ans_train = [os.path.join(dir_to_save_feats, 'train2014', filename) for filename in filenames_train] 46 | filenames_val = os.listdir(os.path.join(dir_to_save_feats, 'val2014')) 47 | ans_val = [os.path.join(dir_to_save_feats, 'val2014', filename) for filename in filenames_val] 48 | ans = ans_train + ans_val 49 | # online testing 50 | elif data_split == 'test': 51 | assert test2014_info_path is not None 52 | with open(test2014_info_path, 'r') as f: 53 | test2014_info = json.load(f) 54 | 55 | for image in test2014_info['images']: 56 | img_id = image['id'] 57 | feat_path = os.path.join(dir_to_save_feats, 'test2015', str(img_id)+'.pth') 58 | assert os.path.exists(feat_path) 59 | ans.append(feat_path) 60 | assert len(ans) == 40775 61 | assert len(ans) > 0 # make sure the ans list is not empty 62 | return ans 63 | 64 | 65 | def main(args): 66 | # collect the absolute paths of the raw features 67 | feat_paths = get_feat_paths(args.dir_to_raw_feats, args.data_split, args.test2014_info_path) 68 | # build the filename and save path of the processed features 69 | file_path = os.path.join(args.dir_to_processed_feats, 'X101_grid_feats_coco_'+args.data_split+'.hdf5') 70 | # process the features and save them 71 | process_dataset(file_path, feat_paths) 72 | print('finished!') 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | parser = argparse.ArgumentParser(description='data process') 78 | parser.add_argument('--dir_to_raw_feats', type=str, default='./datasets/X101-features/') 79 | parser.add_argument('--dir_to_processed_feats', type=str, default='./datasets/X101-features/') 80 | # trainval = train2014 + val2014, used for training and offline evaluation; test = test2014, used for online evaluation 81 | parser.add_argument('--data_split', type=str, default='trainval') # trainval, test 82 | # test2015 contains test2014; to get test2014, first load the test2014 index and then load the features; image_info_test2014.json is the file that stores the test2014 image info 83 | parser.add_argument('--test2014_info_path', type=str, default='./datasets/m2_annotations/image_info_test2014.json') 84 | args = parser.parse_args() 85 | 86 | main(args) 87 | 88 | -------------------------------------------------------------------------------- /images/RSTNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/RSTNet.png -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/results.png -------------------------------------------------------------------------------- /images/train_cider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/train_cider.png -------------------------------------------------------------------------------- /images/visualness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/visualness.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 |
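# Decoding sketch (added for illustration, not part of the original package): it strings together the constructors and the CaptioningModel.beam_search interface defined in the modules below; the layer counts, vocabulary size, BOS/EOS/padding indices and the random feature tensor are placeholder assumptions, not the repository's actual configuration.
#   import torch
#   from models.m2_transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder
#   encoder = MemoryAugmentedEncoder(3, padding_idx=0, d_in=2048)
#   decoder = MeshedDecoder(vocab_size=10000, max_len=54, N_dec=3, padding_idx=0)
#   model = Transformer(bos_idx=2, encoder=encoder, decoder=decoder)
#   grids = torch.randn(1, 49, 2048)   # one image, 7x7 grid features
#   seqs, log_probs = model.beam_search(grids, max_len=20, eos_idx=3, beam_size=5, out_size=1)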
-------------------------------------------------------------------------------- /models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | -------------------------------------------------------------------------------- /models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 28 | s = s.view(*([-1, ] + shape[1:])) 29 | return s 30 | 31 | return fn 32 | 33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 34 | if isinstance(visual, torch.Tensor): 35 | visual_shape = visual.shape 36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 40 | visual_exp = visual.view(visual_exp_shape) 41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 43 | else: 44 | new_visual = [] 45 | for im in visual: 46 | visual_shape = im.shape 47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 51 | visual_exp = im.view(visual_exp_shape) 52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 54 | new_visual.append(new_im) 55 | visual = tuple(new_visual) 56 | return visual 57 | 58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 59 | self.b_s = utils.get_batch_size(visual) 60 | self.device = utils.get_device(visual) 61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 63 | self.log_probs = [] 64 | self.selected_words = None 65 | if return_probs: 66 | self.all_log_probs = [] 67 | 68 | outputs = [] 69 | with self.model.statefulness(self.b_s): 70 | for t in range(self.max_len): 71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs) 72 | 73 | # Sort result 74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 75 | outputs = 
torch.cat(outputs, -1) 76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 77 | log_probs = torch.cat(self.log_probs, -1) 78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 79 | if return_probs: 80 | all_log_probs = torch.cat(self.all_log_probs, 2) 81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 82 | self.max_len, 83 | all_log_probs.shape[-1])) 84 | 85 | outputs = outputs.contiguous()[:, :out_size] 86 | log_probs = log_probs.contiguous()[:, :out_size] 87 | if out_size == 1: 88 | outputs = outputs.squeeze(1) 89 | log_probs = log_probs.squeeze(1) 90 | 91 | if return_probs: 92 | return outputs, log_probs, all_log_probs 93 | else: 94 | return outputs, log_probs 95 | 96 | def select(self, t, candidate_logprob, **kwargs): 97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 99 | return selected_idx, selected_logprob 100 | 101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs): 102 | cur_beam_size = 1 if t == 0 else self.beam_size 103 | 104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs) 105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 106 | candidate_logprob = self.seq_logprob + word_logprob 107 | 108 | # Mask sequence if it reaches EOS 109 | if t > 0: 110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 111 | self.seq_mask = self.seq_mask * mask 112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 114 | old_seq_logprob[:, :, 1:] = -999 115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 116 | 117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 118 | selected_beam = selected_idx / candidate_logprob.shape[-1] 119 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 120 | 121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 123 | 124 | self.seq_logprob = selected_logprob.unsqueeze(-1) 125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 127 | outputs.append(selected_words.unsqueeze(-1)) 128 | 129 | if return_probs: 130 | if t == 0: 131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 132 | else: 133 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 134 | 135 | this_word_logprob = torch.gather(word_logprob, 1, 136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 137 | word_logprob.shape[-1])) 138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 139 | self.log_probs = list( 140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 141 | self.log_probs.append(this_word_logprob) 142 | self.selected_words = selected_words.view(-1, 1) 143 | 144 | return visual, outputs 145 | -------------------------------------------------------------------------------- 
/models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import utils 4 | from models.containers import Module 5 | from models.beam_search import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1, 68 | return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, out_size, return_probs, **kwargs) 71 | -------------------------------------------------------------------------------- /models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def 
apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for m in self.children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def _init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | -------------------------------------------------------------------------------- /models/m2_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /models/m2_transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' 9 | Scaled dot-product attention 10 | ''' 11 | 12 | def __init__(self, d_model, d_k, d_v, h): 13 | ''' 14 | :param d_model: Output dimensionality of the model 15 | :param d_k: Dimensionality of queries and keys 16 | :param d_v: Dimensionality of values 17 | :param h: Number of heads 18 | ''' 19 | super(ScaledDotProductAttention, self).__init__() 20 | self.fc_q = nn.Linear(d_model, h * d_k) 21 | self.fc_k = nn.Linear(d_model, h * d_k) 22 | self.fc_v = nn.Linear(d_model, h * d_v) 23 | self.fc_o = nn.Linear(h * d_v, d_model) 24 | 25 | self.d_model = d_model 26 | self.d_k = d_k 27 | self.d_v = d_v 28 | self.h = h 29 | 30 | self.init_weights() 31 | 32 | def init_weights(self): 33 | nn.init.xavier_uniform_(self.fc_q.weight) 34 | nn.init.xavier_uniform_(self.fc_k.weight) 35 | nn.init.xavier_uniform_(self.fc_v.weight) 36 | nn.init.xavier_uniform_(self.fc_o.weight) 37 | nn.init.constant_(self.fc_q.bias, 0) 38 | nn.init.constant_(self.fc_k.bias, 0) 39 | nn.init.constant_(self.fc_v.bias, 0) 40 | nn.init.constant_(self.fc_o.bias, 0) 41 | 42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 43 | ''' 44 | Computes 45 | 
:param queries: Queries (b_s, nq, d_model) 46 | :param keys: Keys (b_s, nk, d_model) 47 | :param values: Values (b_s, nk, d_model) 48 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 50 | :return: 51 | ''' 52 | b_s, nq = queries.shape[:2] 53 | nk = keys.shape[1] 54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 57 | 58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 59 | if attention_weights is not None: 60 | att = att * attention_weights 61 | if attention_mask is not None: 62 | att = att.masked_fill(attention_mask, -np.inf) 63 | att = torch.softmax(att, -1) 64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 65 | out = self.fc_o(out) # (b_s, nq, d_model) 66 | return out 67 | 68 | 69 | class ScaledDotProductAttentionMemory(nn.Module): 70 | ''' 71 | Scaled dot-product attention with memory 72 | ''' 73 | 74 | def __init__(self, d_model, d_k, d_v, h, m): 75 | ''' 76 | :param d_model: Output dimensionality of the model 77 | :param d_k: Dimensionality of queries and keys 78 | :param d_v: Dimensionality of values 79 | :param h: Number of heads 80 | :param m: Number of memory slots 81 | ''' 82 | super(ScaledDotProductAttentionMemory, self).__init__() 83 | self.fc_q = nn.Linear(d_model, h * d_k) 84 | self.fc_k = nn.Linear(d_model, h * d_k) 85 | self.fc_v = nn.Linear(d_model, h * d_v) 86 | self.fc_o = nn.Linear(h * d_v, d_model) 87 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 88 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 89 | 90 | self.d_model = d_model 91 | self.d_k = d_k 92 | self.d_v = d_v 93 | self.h = h 94 | self.m = m 95 | 96 | self.init_weights() 97 | 98 | def init_weights(self): 99 | nn.init.xavier_uniform_(self.fc_q.weight) 100 | nn.init.xavier_uniform_(self.fc_k.weight) 101 | nn.init.xavier_uniform_(self.fc_v.weight) 102 | nn.init.xavier_uniform_(self.fc_o.weight) 103 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 104 | nn.init.normal_(self.m_v, 0, 1 / self.m) 105 | nn.init.constant_(self.fc_q.bias, 0) 106 | nn.init.constant_(self.fc_k.bias, 0) 107 | nn.init.constant_(self.fc_v.bias, 0) 108 | nn.init.constant_(self.fc_o.bias, 0) 109 | 110 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 111 | ''' 112 | Computes 113 | :param queries: Queries (b_s, nq, d_model) 114 | :param keys: Keys (b_s, nk, d_model) 115 | :param values: Values (b_s, nk, d_model) 116 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 117 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
118 | :return: 119 | ''' 120 | b_s, nq = queries.shape[:2] 121 | nk = keys.shape[1] 122 | 123 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 124 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 125 | 126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 127 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 128 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 129 | 130 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 131 | if attention_weights is not None: 132 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 133 | if attention_mask is not None: 134 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 135 | att = torch.softmax(att, -1) 136 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 137 | out = self.fc_o(out) # (b_s, nq, d_model) 138 | return out 139 | 140 | 141 | class MultiHeadAttention(Module): 142 | ''' 143 | Multi-head attention layer with Dropout and Layer Normalization. 144 | ''' 145 | 146 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 147 | attention_module=None, attention_module_kwargs=None): 148 | super(MultiHeadAttention, self).__init__() 149 | self.identity_map_reordering = identity_map_reordering 150 | if attention_module is not None: 151 | if attention_module_kwargs is not None: 152 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs) 153 | else: 154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 155 | else: 156 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 157 | self.dropout = nn.Dropout(p=dropout) 158 | self.layer_norm = nn.LayerNorm(d_model) 159 | 160 | self.can_be_stateful = can_be_stateful 161 | if self.can_be_stateful: 162 | self.register_state('running_keys', torch.zeros((0, d_model))) 163 | self.register_state('running_values', torch.zeros((0, d_model))) 164 | 165 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 166 | if self.can_be_stateful and self._is_stateful: 167 | self.running_keys = torch.cat([self.running_keys, keys], 1) 168 | keys = self.running_keys 169 | 170 | self.running_values = torch.cat([self.running_values, values], 1) 171 | values = self.running_values 172 | 173 | if self.identity_map_reordering: 174 | q_norm = self.layer_norm(queries) 175 | k_norm = self.layer_norm(keys) 176 | v_norm = self.layer_norm(values) 177 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 178 | out = queries + self.dropout(torch.relu(out)) 179 | else: 180 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 181 | out = self.dropout(out) 182 | out = self.layer_norm(queries + out) 183 | return out 184 | -------------------------------------------------------------------------------- /models/m2_transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from models.m2_transformer.attention import MultiHeadAttention 7 | from models.m2_transformer.utils import 
sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 19 | attention_module=enc_att_module, 20 | attention_module_kwargs=enc_att_module_kwargs) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | 23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model) 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | nn.init.xavier_uniform_(self.fc_alpha1.weight) 31 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 33 | nn.init.constant_(self.fc_alpha1.bias, 0) 34 | nn.init.constant_(self.fc_alpha2.bias, 0) 35 | nn.init.constant_(self.fc_alpha3.bias, 0) 36 | 37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 38 | self_att = self.self_att(input, input, input, mask_self_att) 39 | self_att = self_att * mask_pad 40 | 41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad 42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad 43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad 44 | 45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) 46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) 47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) 48 | 49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 50 | enc_att = enc_att * mask_pad 51 | 52 | ff = self.pwff(enc_att) 53 | ff = ff * mask_pad 54 | return ff 55 | 56 | 57 | class MeshedDecoder(Module): 58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 60 | super(MeshedDecoder, self).__init__() 61 | self.d_model = d_model 62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 64 | self.layers = ModuleList( 65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 68 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 69 | self.max_len = max_len 70 | self.padding_idx = padding_idx 71 | self.N = N_dec 72 | 73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 74 | self.register_state('running_seq', torch.zeros((1,)).long()) 75 | 76 | def forward(self, input, encoder_output, mask_encoder): 77 | # input (b_s, seq_len) 78 | b_s, seq_len = input.shape[:2] 
79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 81 | diagonal=1) 82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 85 | if self._is_stateful: 86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1) 87 | mask_self_attention = self.running_mask_self_attention 88 | 89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 91 | if self._is_stateful: 92 | self.running_seq.add_(1) 93 | seq = self.running_seq 94 | 95 | out = self.word_emb(input) + self.pos_emb(seq) 96 | for i, l in enumerate(self.layers): 97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 98 | 99 | out = self.fc(out) 100 | return F.log_softmax(out, dim=-1) 101 | -------------------------------------------------------------------------------- /models/m2_transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.m2_transformer.utils import PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.m2_transformer.attention import MultiHeadAttention 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 10 | attention_module=None, attention_module_kwargs=None): 11 | super(EncoderLayer, self).__init__() 12 | self.identity_map_reordering = identity_map_reordering 13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 14 | attention_module=attention_module, 15 | attention_module_kwargs=attention_module_kwargs) 16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 17 | 18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 19 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 20 | ff = self.pwff(att) 21 | return ff 22 | 23 | 24 | class MultiLevelEncoder(nn.Module): 25 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 26 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 27 | super(MultiLevelEncoder, self).__init__() 28 | self.d_model = d_model 29 | self.dropout = dropout 30 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 31 | identity_map_reordering=identity_map_reordering, 32 | attention_module=attention_module, 33 | attention_module_kwargs=attention_module_kwargs) 34 | for _ in range(N)]) 35 | self.padding_idx = padding_idx 36 | 37 | def forward(self, input, attention_weights=None): 38 | # input (b_s, seq_len, d_in) 39 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 40 | 41 | outs = [] 42 | out = input 43 | for l in self.layers: 44 | out = l(out, out, out, attention_mask, attention_weights) 45 | outs.append(out.unsqueeze(1)) 46 | 47 | outs = torch.cat(outs, 1) 48 | return outs, 
attention_mask 49 | 50 | 51 | class MemoryAugmentedEncoder(MultiLevelEncoder): 52 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 53 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 54 | self.fc = nn.Linear(d_in, self.d_model) 55 | self.dropout = nn.Dropout(p=self.dropout) 56 | self.layer_norm = nn.LayerNorm(self.d_model) 57 | 58 | def forward(self, input, attention_weights=None): 59 | out = F.relu(self.fc(input)) 60 | out = self.dropout(out) 61 | out = self.layer_norm(out) 62 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights) 63 | -------------------------------------------------------------------------------- /models/m2_transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /models/m2_transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, 
dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out -------------------------------------------------------------------------------- /models/rstnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /models/rstnet/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from models.transformer.attention import MultiHeadAttention 6 | from models.rstnet.attention import MultiHeadAdaptiveAttention 7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | from models.rstnet.language_model import LanguageModel 10 | 11 | 12 | class DecoderLayer(Module): 13 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 14 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 15 | super(DecoderLayer, self).__init__() 16 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 17 | attention_module=self_att_module, 18 | attention_module_kwargs=self_att_module_kwargs) 19 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 20 | attention_module=enc_att_module, 21 | attention_module_kwargs=enc_att_module_kwargs) 22 | 23 | self.dropout1 = nn.Dropout(dropout) 24 | self.lnorm1 = nn.LayerNorm(d_model) 25 | 26 | self.dropout2 = nn.Dropout(dropout) 27 | self.lnorm2 = nn.LayerNorm(d_model) 28 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 29 | 30 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att, pos): 31 | # MHA+AddNorm 32 | self_att = self.self_att(input, input, input, mask_self_att) 33 | self_att = self.lnorm1(input + self.dropout1(self_att)) 34 | self_att = self_att * mask_pad 35 | 
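# Note that in the cross-attention below the grid position embedding `pos` is
# added to the keys only (key = enc_output + pos), while the values remain the
# raw encoder output, similar to the DETR-style positional encoding scheme.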
# MHA+AddNorm 36 | key = enc_output + pos 37 | enc_att = self.enc_att(self_att, key, enc_output, mask_enc_att) 38 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att)) 39 | enc_att = enc_att * mask_pad 40 | # FFN+AddNorm 41 | ff = self.pwff(enc_att) 42 | ff = ff * mask_pad 43 | return ff 44 | 45 | 46 | class DecoderAdaptiveLayer(Module): 47 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 48 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 49 | super(DecoderAdaptiveLayer, self).__init__() 50 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 51 | attention_module=self_att_module, 52 | attention_module_kwargs=self_att_module_kwargs) 53 | 54 | self.enc_att = MultiHeadAdaptiveAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, attention_module=enc_att_module, attention_module_kwargs=enc_att_module_kwargs) 55 | 56 | self.dropout1 = nn.Dropout(dropout) 57 | self.lnorm1 = nn.LayerNorm(d_model) 58 | 59 | self.dropout2 = nn.Dropout(dropout) 60 | self.lnorm2 = nn.LayerNorm(d_model) 61 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 62 | 63 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att, language_feature=None, pos=None): 64 | # MHA+AddNorm 65 | self_att = self.self_att(input, input, input, mask_self_att) 66 | self_att = self.lnorm1(input + self.dropout1(self_att)) 67 | self_att = self_att * mask_pad 68 | # MHA+AddNorm 69 | key = enc_output + pos 70 | enc_att = self.enc_att(self_att, key, enc_output, mask_enc_att, language_feature=language_feature) 71 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att)) 72 | enc_att = enc_att * mask_pad 73 | # FFN+AddNorm 74 | ff = self.pwff(enc_att) 75 | ff = ff * mask_pad 76 | return ff 77 | 78 | 79 | class TransformerDecoderLayer(Module): 80 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, bert_hidden_size=768, dropout=.1, 81 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None, language_model_path=None): 82 | super(TransformerDecoderLayer, self).__init__() 83 | self.d_model = d_model 84 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 85 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 86 | self.layers = ModuleList( 87 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) if i < N_dec else DecoderAdaptiveLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) for i in range(N_dec + 1)]) 88 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 89 | 90 | # 加载语言模型 91 | self.language_model = LanguageModel(padding_idx=padding_idx, bert_hidden_size=bert_hidden_size, vocab_size=vocab_size, max_len=max_len) 92 | assert language_model_path is not None 93 | language_model_file = torch.load(language_model_path) 94 | self.language_model.load_state_dict(language_model_file['state_dict'], strict=False) 95 | for p in self.language_model.parameters(): 96 | p.requires_grad = False 97 | 98 | self.max_len = max_len 99 | self.padding_idx = padding_idx 100 | self.N = N_dec 101 | 102 | 
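# Buffers used for incremental (stateful) decoding, e.g. during beam search:
# running_mask_self_attention grows by one column per generated token, and
# running_seq tracks the current time step for the positional embedding.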
self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 103 | self.register_state('running_seq', torch.zeros((1,)).long()) 104 | 105 | def forward(self, input, encoder_output, mask_encoder, pos): 106 | # input (b_s, seq_len) 107 | b_s, seq_len = input.shape[:2] 108 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 109 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 110 | diagonal=1) 111 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 112 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 113 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 114 | if self._is_stateful: 115 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1) 116 | mask_self_attention = self.running_mask_self_attention 117 | 118 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 119 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 120 | if self._is_stateful: 121 | self.running_seq.add_(1) 122 | seq = self.running_seq 123 | 124 | out = self.word_emb(input) + self.pos_emb(seq) 125 | _, language_feature = self.language_model(input) 126 | 127 | if encoder_output.shape[0] != pos.shape[0]: 128 | assert encoder_output.shape[0] % pos.shape[0] == 0 129 | beam_size = int(encoder_output.shape[0] / pos.shape[0]) 130 | shape = (pos.shape[0], beam_size, pos.shape[1], pos.shape[2]) 131 | pos = pos.unsqueeze(1) # bs * 1 * 50 * 512 132 | pos = pos.expand(shape) # bs * 5 * 50 * 512 133 | pos = pos.contiguous().flatten(0, 1) # (bs*5) * 50 * 512 134 | 135 | for i, l in enumerate(self.layers): 136 | if i < self.N: 137 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder, pos=pos) 138 | else: 139 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder, language_feature, pos=pos) 140 | 141 | out = self.fc(out) 142 | return F.log_softmax(out, dim=-1) 143 | -------------------------------------------------------------------------------- /models/rstnet/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.rstnet.attention import MultiHeadGeometryAttention 6 | from models.rstnet.grid_aug import BoxRelationalEmbedding 7 | 8 | 9 | class EncoderLayer(nn.Module): 10 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 11 | attention_module=None, attention_module_kwargs=None): 12 | super(EncoderLayer, self).__init__() 13 | self.identity_map_reordering = identity_map_reordering 14 | self.mhatt = MultiHeadGeometryAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 15 | attention_module=attention_module, 16 | attention_module_kwargs=attention_module_kwargs) 17 | self.dropout = nn.Dropout(dropout) 18 | self.lnorm = nn.LayerNorm(d_model) 19 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 20 | 21 | def forward(self, queries, keys, values, relative_geometry_weights, attention_mask=None, attention_weights=None, pos=None): 22 | 23 | # q, k = (queries + pos, keys + pos) if 
pos is not None else (queries, keys) 24 | q = queries + pos 25 | k = keys + pos 26 | att = self.mhatt(q, k, values, relative_geometry_weights, attention_mask, attention_weights) 27 | att = self.lnorm(queries + self.dropout(att)) 28 | ff = self.pwff(att) 29 | return ff 30 | 31 | 32 | class MultiLevelEncoder(nn.Module): 33 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 34 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 35 | super(MultiLevelEncoder, self).__init__() 36 | self.d_model = d_model 37 | self.dropout = dropout 38 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 39 | identity_map_reordering=identity_map_reordering, 40 | attention_module=attention_module, 41 | attention_module_kwargs=attention_module_kwargs) 42 | for _ in range(N)]) 43 | self.padding_idx = padding_idx 44 | 45 | self.WGs = nn.ModuleList([nn.Linear(64, 1, bias=True) for _ in range(h)]) 46 | 47 | def forward(self, input, attention_weights=None, pos=None): 48 | # input (b_s, seq_len, d_in) 49 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 50 | 51 | # grid geometry embedding 52 | relative_geometry_embeddings = BoxRelationalEmbedding(input) 53 | flatten_relative_geometry_embeddings = relative_geometry_embeddings.view(-1, 64) 54 | box_size_per_head = list(relative_geometry_embeddings.shape[:3]) 55 | box_size_per_head.insert(1, 1) 56 | relative_geometry_weights_per_head = [layer(flatten_relative_geometry_embeddings).view(box_size_per_head) for layer in self.WGs] 57 | relative_geometry_weights = torch.cat((relative_geometry_weights_per_head), 1) 58 | relative_geometry_weights = F.relu(relative_geometry_weights) 59 | 60 | 61 | relative_geometry_embeddings = BoxRelationalEmbedding(input) 62 | flatten_relative_geometry_embeddings = relative_geometry_embeddings.view(-1, 64) 63 | box_size_per_head = list(relative_geometry_embeddings.shape[:3]) 64 | box_size_per_head.insert(1, 1) 65 | relative_geometry_weights_per_head = [layer(flatten_relative_geometry_embeddings).view(box_size_per_head) for layer in self.WGs] 66 | relative_geometry_weights = torch.cat((relative_geometry_weights_per_head), 1) 67 | relative_geometry_weights = F.relu(relative_geometry_weights) 68 | 69 | out = input 70 | for layer in self.layers: 71 | out = layer(out, out, out, relative_geometry_weights, attention_mask, attention_weights, pos=pos) 72 | 73 | return out, attention_mask 74 | 75 | 76 | class TransformerEncoder(MultiLevelEncoder): 77 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 78 | super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs) 79 | self.fc = nn.Linear(d_in, self.d_model) 80 | self.dropout = nn.Dropout(p=self.dropout) 81 | self.layer_norm = nn.LayerNorm(self.d_model) 82 | 83 | def forward(self, input, attention_weights=None, pos=None): 84 | mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1) 85 | out = F.relu(self.fc(input)) 86 | out = self.dropout(out) 87 | out = self.layer_norm(out) 88 | out = out.masked_fill(mask, 0) 89 | return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights, pos=pos) 90 | -------------------------------------------------------------------------------- /models/rstnet/grid_aug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | class PositionEmbeddingSine(nn.Module): 6 | """ 7 | This is a 
more standard version of the position embedding, very similar to the one 8 | used by the Attention is all you need paper, generalized to work on images. 9 | """ 10 | 11 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 12 | super().__init__() 13 | self.num_pos_feats = num_pos_feats 14 | self.temperature = temperature 15 | self.normalize = normalize 16 | if scale is not None and normalize is False: 17 | raise ValueError("normalize should be True if scale is passed") 18 | if scale is None: 19 | scale = 2 * math.pi 20 | self.scale = scale 21 | 22 | def forward(self, x, mask=None): 23 | if mask is None: 24 | mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=x.device) 25 | not_mask = (mask == False) 26 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 27 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 28 | if self.normalize: 29 | eps = 1e-6 30 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 31 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 32 | 33 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 34 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 35 | 36 | pos_x = x_embed[:, :, :, None] / dim_t 37 | pos_y = y_embed[:, :, :, None] / dim_t 38 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 39 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 40 | pos = torch.cat((pos_y, pos_x), dim=3) # .permute(0, 3, 1, 2) 41 | pos = pos.flatten(1, 2) 42 | return pos 43 | 44 | 45 | def get_relative_pos(x, bs, grid_size): 46 | x = x.view(1, -1, 1).expand(bs, -1, -1) 47 | return x / grid_size 48 | 49 | def get_grids_pos(bs, grid_size=7): 50 | x = torch.arange(0, grid_size).float().cuda() 51 | y = torch.arange(0, grid_size).float().cuda() 52 | px_min = x.view(-1, 1).expand(-1, grid_size).contiguous().view(-1) 53 | py_min = y.view(1, -1).expand(grid_size, -1).contiguous().view(-1) 54 | px_max = px_min + 1 55 | py_max = py_min + 1 56 | 57 | x_min = get_relative_pos(px_min, bs, grid_size) 58 | y_min = get_relative_pos(py_min, bs, grid_size) 59 | x_max = get_relative_pos(px_max, bs, grid_size) 60 | y_max = get_relative_pos(py_max, bs, grid_size) 61 | return y_min, x_min, y_max, x_max 62 | 63 | 64 | def BoxRelationalEmbedding(f_g, dim_g=64, wave_len=1000, trignometric_embedding=True): 65 | """ 66 | Given a tensor with bbox coordinates for detected objects on each batch image, 67 | this function computes a matrix for each image 68 | 69 | with entry (i,j) given by a vector representation of the 70 | displacement between the coordinates of bbox_i, and bbox_j 71 | 72 | input: np.array of shape=(batch_size, max_nr_bounding_boxes, 4) 73 | output: np.array of shape=(batch_size, max_nr_bounding_boxes, max_nr_bounding_boxes, 64) 74 | """ 75 | # returns a relational embedding for each pair of bboxes, with dimension = dim_g 76 | # follow implementation of https://github.com/heefe92/Relation_Networks-pytorch/blob/master/model.py#L1014-L1055 77 | 78 | batch_size = f_g.size(0) 79 | x_min, y_min, x_max, y_max = get_grids_pos(batch_size) 80 | 81 | cx = (x_min + x_max) * 0.5 82 | cy = (y_min + y_max) * 0.5 83 | w = (x_max - x_min) + 1. 84 | h = (y_max - y_min) + 1. 
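# The block below computes, for every ordered pair of grid cells (i, j), the
# 4-d relative geometry feature, roughly
#     ( log(max(|cx_i - cx_j| / w_i, 1e-3)),
#       log(max(|cy_i - cy_j| / h_i, 1e-3)),
#       log(w_i / w_j),
#       log(h_i / h_j) ),
# which is then optionally expanded with sinusoidal frequencies into a
# dim_g-dimensional embedding, following the Relation Networks implementation
# referenced above.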
85 | 86 | # cx.view(1,-1) transposes the vector cx, and so dim(delta_x) = (dim(cx), dim(cx)) 87 | delta_x = cx - cx.view(batch_size, 1, -1) 88 | delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3) 89 | delta_x = torch.log(delta_x) 90 | 91 | delta_y = cy - cy.view(batch_size, 1, -1) 92 | delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3) 93 | delta_y = torch.log(delta_y) 94 | 95 | delta_w = torch.log(w / w.view(batch_size, 1, -1)) 96 | delta_h = torch.log(h / h.view(batch_size, 1, -1)) 97 | 98 | matrix_size = delta_h.size() 99 | delta_x = delta_x.view(batch_size, matrix_size[1], matrix_size[2], 1) 100 | delta_y = delta_y.view(batch_size, matrix_size[1], matrix_size[2], 1) 101 | delta_w = delta_w.view(batch_size, matrix_size[1], matrix_size[2], 1) 102 | delta_h = delta_h.view(batch_size, matrix_size[1], matrix_size[2], 1) 103 | 104 | position_mat = torch.cat((delta_x, delta_y, delta_w, delta_h), -1) # bs * r * r *4 105 | 106 | if trignometric_embedding == True: 107 | feat_range = torch.arange(dim_g / 8).cuda() 108 | dim_mat = feat_range / (dim_g / 8) 109 | dim_mat = 1. / (torch.pow(wave_len, dim_mat)) 110 | 111 | dim_mat = dim_mat.view(1, 1, 1, -1) 112 | position_mat = position_mat.view(batch_size, matrix_size[1], matrix_size[2], 4, -1) 113 | position_mat = 100. * position_mat 114 | 115 | mul_mat = position_mat * dim_mat 116 | mul_mat = mul_mat.view(batch_size, matrix_size[1], matrix_size[2], -1) 117 | sin_mat = torch.sin(mul_mat) 118 | cos_mat = torch.cos(mul_mat) 119 | embedding = torch.cat((sin_mat, cos_mat), -1) 120 | else: 121 | embedding = position_mat 122 | return (embedding) 123 | -------------------------------------------------------------------------------- /models/rstnet/language_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from transformers import BertModel 5 | 6 | from models.transformer.attention import MultiHeadAttention 7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module 9 | 10 | 11 | class EncoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(EncoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | 19 | self.dropout1 = nn.Dropout(dropout) 20 | self.lnorm1 = nn.LayerNorm(d_model) 21 | 22 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 23 | 24 | def forward(self, input, mask_pad, mask_self_att): 25 | # MHA+AddNorm 26 | self_att = self.self_att(input, input, input, mask_self_att) 27 | self_att = self.lnorm1(input + self.dropout1(self_att)) 28 | self_att = self_att * mask_pad 29 | 30 | # FFN+AddNorm 31 | ff = self.pwff(self_att) 32 | ff = ff * mask_pad 33 | return ff 34 | 35 | 36 | class LanguageModel(Module): 37 | def __init__(self, padding_idx=0, bert_hidden_size=768, vocab_size=10201, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, max_len=54, dropout=.1): 38 | super(LanguageModel, self).__init__() 39 | self.padding_idx = padding_idx 40 | self.d_model = d_model 41 | 42 | self.language_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True) 43 | self.language_model.config.vocab_size = vocab_size 44 | 
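# Project BERT's last hidden states (bert_hidden_size, 768 for bert-base-uncased)
# down to the caption model dimension d_model so that the language feature can be
# fused with the caption decoder; proj_to_vocab later maps it to vocabulary logits.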
self.proj_to_caption_model = nn.Linear(bert_hidden_size, d_model) 45 | 46 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 47 | self.encoder_layer = EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout) 48 | self.proj_to_vocab = nn.Linear(d_model, vocab_size) 49 | 50 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 51 | self.register_state('running_seq', torch.zeros((1,)).long()) 52 | 53 | def forward( 54 | self, input_ids, attention_mask=None, token_type_ids=None, 55 | position_ids=None, head_mask=None, inputs_embeds=None, output_attentions=False, 56 | output_hidden_states=False, return_dict=False, encoder_hidden_states=None, 57 | encoder_attention_mask=None 58 | ): 59 | # input (b_s, seq_len) 60 | b_s, seq_len = input_ids.shape[:2] 61 | mask_queries = (input_ids != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 62 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device), diagonal=1) 63 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 64 | mask_self_attention = mask_self_attention + (input_ids == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 65 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 66 | if self._is_stateful: 67 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1) 68 | mask_self_attention = self.running_mask_self_attention 69 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input_ids.device) # (b_s, seq_len) 70 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 71 | if self._is_stateful: 72 | self.running_seq.add_(1) 73 | seq = self.running_seq 74 | 75 | if attention_mask is None: 76 | attention_mask = torch.ones_like(input_ids) 77 | if token_type_ids is None: 78 | token_type_ids = torch.zeros_like(input_ids).long() 79 | 80 | bert_output = self.language_model( 81 | input_ids=input_ids, 82 | token_type_ids=token_type_ids, 83 | attention_mask=attention_mask 84 | ) 85 | language_feature = self.proj_to_caption_model(bert_output.last_hidden_state) 86 | language_feature = language_feature + self.pos_emb(seq) 87 | 88 | language_feature = self.encoder_layer(language_feature, mask_queries, mask_self_attention) 89 | 90 | logits = self.proj_to_vocab(language_feature) 91 | out = F.log_softmax(logits, dim=-1) 92 | return out, language_feature 93 | -------------------------------------------------------------------------------- /models/rstnet/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | from .grid_aug import PositionEmbeddingSine 7 | 8 | 9 | class Transformer(CaptioningModel): 10 | def __init__(self, bos_idx, encoder, decoder): 11 | super(Transformer, self).__init__() 12 | self.bos_idx = bos_idx 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | self.grid_embedding = PositionEmbeddingSine(self.decoder.d_model // 2, normalize=True) 16 | 17 | self.register_state('enc_output', None) 18 | self.register_state('mask_enc', None) 19 | self.init_weights() 20 | 21 | @property 22 | def d_model(self): 23 | return self.decoder.d_model 24 | 25 | def init_weights(self): 26 | for p in self.parameters(): 27 | if p.dim() > 1: 28 | 
nn.init.xavier_uniform_(p) 29 | 30 | def get_pos_embedding(self, grids): 31 | bs = grids.shape[0] 32 | grid_embed = self.grid_embedding(grids.view(bs, 7, 7, -1)) 33 | return grid_embed 34 | 35 | def forward(self, images, seq, *args): 36 | grid_embed = self.get_pos_embedding(images) 37 | enc_output, mask_enc = self.encoder(images, pos=grid_embed) 38 | dec_output = self.decoder(seq, enc_output, mask_enc, pos=grid_embed) 39 | return dec_output 40 | 41 | def init_state(self, b_s, device): 42 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 43 | None, None] 44 | 45 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 46 | it = None 47 | if mode == 'teacher_forcing': 48 | raise NotImplementedError 49 | elif mode == 'feedback': 50 | if t == 0: 51 | self.grid_embed = self.get_pos_embedding(visual) 52 | self.enc_output, self.mask_enc = self.encoder(visual, pos=self.grid_embed) 53 | 54 | if isinstance(visual, torch.Tensor): 55 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 56 | else: 57 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 58 | else: 59 | it = prev_output 60 | 61 | return self.decoder(it, self.enc_output, self.mask_enc, pos=self.grid_embed) 62 | 63 | 64 | class TransformerEnsemble(CaptioningModel): 65 | def __init__(self, model: Transformer, weight_files): 66 | super(TransformerEnsemble, self).__init__() 67 | self.n = len(weight_files) 68 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 69 | for i in range(self.n): 70 | state_dict_i = torch.load(weight_files[i])['state_dict'] 71 | self.models[i].load_state_dict(state_dict_i) 72 | 73 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 74 | out_ensemble = [] 75 | for i in range(self.n): 76 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 77 | out_ensemble.append(out_i.unsqueeze(0)) 78 | 79 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 80 | -------------------------------------------------------------------------------- /models/rstnet/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if 
self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out 51 | 52 | 53 | def in_norm(x_btc): 54 | 55 | b, t, c = x_btc.size() 56 | eps = 1e-6 57 | 58 | mu = torch.sum(x_btc, dim=1) / t 59 | sigma_square = torch.sum((x_btc - mu) ** 2, dim=1) / t 60 | x_btc = (x_btc - mu) / torch.sqrt(sigma_square + eps) 61 | 62 | return x_btc 63 | -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /models/transformer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/decoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/decoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/encoders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/encoders.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' 9 | Scaled dot-product attention 10 
| ''' 11 | 12 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None): 13 | ''' 14 | :param d_model: Output dimensionality of the model 15 | :param d_k: Dimensionality of queries and keys 16 | :param d_v: Dimensionality of values 17 | :param h: Number of heads 18 | ''' 19 | super(ScaledDotProductAttention, self).__init__() 20 | self.fc_q = nn.Linear(d_model, h * d_k) 21 | self.fc_k = nn.Linear(d_model, h * d_k) 22 | self.fc_v = nn.Linear(d_model, h * d_v) 23 | self.fc_o = nn.Linear(h * d_v, d_model) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | self.d_model = d_model 27 | self.d_k = d_k 28 | self.d_v = d_v 29 | self.h = h 30 | 31 | self.init_weights() 32 | 33 | self.comment = comment 34 | 35 | def init_weights(self): 36 | nn.init.xavier_uniform_(self.fc_q.weight) 37 | nn.init.xavier_uniform_(self.fc_k.weight) 38 | nn.init.xavier_uniform_(self.fc_v.weight) 39 | nn.init.xavier_uniform_(self.fc_o.weight) 40 | nn.init.constant_(self.fc_q.bias, 0) 41 | nn.init.constant_(self.fc_k.bias, 0) 42 | nn.init.constant_(self.fc_v.bias, 0) 43 | nn.init.constant_(self.fc_o.bias, 0) 44 | 45 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 46 | ''' 47 | Computes 48 | :param queries: Queries (b_s, nq, d_model) 49 | :param keys: Keys (b_s, nk, d_model) 50 | :param values: Values (b_s, nk, d_model) 51 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 52 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 53 | :return: 54 | ''' 55 | 56 | b_s, nq = queries.shape[:2] 57 | nk = keys.shape[1] 58 | 59 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 60 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 61 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 62 | 63 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 64 | if attention_weights is not None: 65 | att = att * attention_weights 66 | if attention_mask is not None: 67 | att = att.masked_fill(attention_mask, -np.inf) 68 | att = torch.softmax(att, -1) 69 | att = self.dropout(att) 70 | 71 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 72 | out = self.fc_o(out) # (b_s, nq, d_model) 73 | return out 74 | 75 | 76 | class ScaledDotProductAttentionMemory(nn.Module): 77 | ''' 78 | Scaled dot-product attention with memory 79 | ''' 80 | 81 | def __init__(self, d_model, d_k, d_v, h, m): 82 | ''' 83 | :param d_model: Output dimensionality of the model 84 | :param d_k: Dimensionality of queries and keys 85 | :param d_v: Dimensionality of values 86 | :param h: Number of heads 87 | :param m: Number of memory slots 88 | ''' 89 | super(ScaledDotProductAttentionMemory, self).__init__() 90 | self.fc_q = nn.Linear(d_model, h * d_k) 91 | self.fc_k = nn.Linear(d_model, h * d_k) 92 | self.fc_v = nn.Linear(d_model, h * d_v) 93 | self.fc_o = nn.Linear(h * d_v, d_model) 94 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k)) 95 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v)) 96 | 97 | self.d_model = d_model 98 | self.d_k = d_k 99 | self.d_v = d_v 100 | self.h = h 101 | self.m = m 102 | 103 | self.init_weights() 104 | 105 | def init_weights(self): 106 | nn.init.xavier_uniform_(self.fc_q.weight) 107 | nn.init.xavier_uniform_(self.fc_k.weight) 108 | nn.init.xavier_uniform_(self.fc_v.weight) 109 | 
nn.init.xavier_uniform_(self.fc_o.weight) 110 | nn.init.normal_(self.m_k, 0, 1 / self.d_k) 111 | nn.init.normal_(self.m_v, 0, 1 / self.m) 112 | nn.init.constant_(self.fc_q.bias, 0) 113 | nn.init.constant_(self.fc_k.bias, 0) 114 | nn.init.constant_(self.fc_v.bias, 0) 115 | nn.init.constant_(self.fc_o.bias, 0) 116 | 117 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 118 | ''' 119 | Computes 120 | :param queries: Queries (b_s, nq, d_model) 121 | :param keys: Keys (b_s, nk, d_model) 122 | :param values: Values (b_s, nk, d_model) 123 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 124 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 125 | :return: 126 | ''' 127 | b_s, nq = queries.shape[:2] 128 | nk = keys.shape[1] 129 | 130 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k) 131 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v) 132 | 133 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 134 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 135 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 136 | 137 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 138 | if attention_weights is not None: 139 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1) 140 | if attention_mask is not None: 141 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf) 142 | att = torch.softmax(att, -1) 143 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 144 | out = self.fc_o(out) # (b_s, nq, d_model) 145 | return out 146 | 147 | 148 | class MultiHeadAttention(Module): 149 | ''' 150 | Multi-head attention layer with Dropout and Layer Normalization. 
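    When can_be_stateful is True and the module runs in stateful mode, keys and
    values are cached across decoding steps; identity_map_reordering applies the
    layer normalization before the attention (pre-norm) instead of after it.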
151 | ''' 152 | 153 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 154 | attention_module=None, attention_module_kwargs=None, comment=None): 155 | super(MultiHeadAttention, self).__init__() 156 | self.identity_map_reordering = identity_map_reordering 157 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment) 158 | self.dropout = nn.Dropout(p=dropout) 159 | self.layer_norm = nn.LayerNorm(d_model) 160 | 161 | self.can_be_stateful = can_be_stateful 162 | if self.can_be_stateful: 163 | self.register_state('running_keys', torch.zeros((0, d_model))) 164 | self.register_state('running_values', torch.zeros((0, d_model))) 165 | 166 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 167 | if self.can_be_stateful and self._is_stateful: 168 | self.running_keys = torch.cat([self.running_keys, keys], 1) 169 | keys = self.running_keys 170 | 171 | self.running_values = torch.cat([self.running_values, values], 1) 172 | values = self.running_values 173 | 174 | if self.identity_map_reordering: 175 | q_norm = self.layer_norm(queries) 176 | k_norm = self.layer_norm(keys) 177 | v_norm = self.layer_norm(values) 178 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights) 179 | out = queries + self.dropout(torch.relu(out)) 180 | else: 181 | out = self.attention(queries, keys, values, attention_mask, attention_weights) 182 | out = self.dropout(out) 183 | out = self.layer_norm(queries + out) 184 | return out 185 | -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from models.transformer.attention import MultiHeadAttention 6 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 7 | from models.containers import Module, ModuleList 8 | 9 | 10 | class DecoderLayer(Module): 11 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 12 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 13 | super(DecoderLayer, self).__init__() 14 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 15 | attention_module=self_att_module, 16 | attention_module_kwargs=self_att_module_kwargs) 17 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 18 | attention_module=enc_att_module, 19 | attention_module_kwargs=enc_att_module_kwargs) 20 | 21 | self.dropout1 = nn.Dropout(dropout) 22 | self.lnorm1 = nn.LayerNorm(d_model) 23 | 24 | self.dropout2 = nn.Dropout(dropout) 25 | self.lnorm2 = nn.LayerNorm(d_model) 26 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 27 | 28 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): 29 | # MHA+AddNorm 30 | self_att = self.self_att(input, input, input, mask_self_att) 31 | self_att = self.lnorm1(input + self.dropout1(self_att)) 32 | self_att = self_att * mask_pad 33 | # MHA+AddNorm 34 | enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att) 35 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att)) 36 | enc_att = enc_att * mask_pad 37 | # FFN+AddNorm 38 | ff = self.pwff(enc_att) 39 | ff = ff * mask_pad 40 | return ff 41 | 42 | 43 | class TransformerDecoderLayer(Module): 44 
| def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 45 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 46 | super(TransformerDecoderLayer, self).__init__() 47 | self.d_model = d_model 48 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 49 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True) 50 | self.layers = ModuleList( 51 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 52 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 53 | self.max_len = max_len 54 | self.padding_idx = padding_idx 55 | self.N = N_dec 56 | 57 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 58 | self.register_state('running_seq', torch.zeros((1,)).long()) 59 | 60 | def forward(self, input, encoder_output, mask_encoder): 61 | # input (b_s, seq_len) 62 | b_s, seq_len = input.shape[:2] 63 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1) 64 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device), 65 | diagonal=1) 66 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) 67 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 68 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len) 69 | if self._is_stateful: 70 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1) 71 | mask_self_attention = self.running_mask_self_attention 72 | 73 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len) 74 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 75 | if self._is_stateful: 76 | self.running_seq.add_(1) 77 | seq = self.running_seq 78 | 79 | out = self.word_emb(input) + self.pos_emb(seq) 80 | 81 | for i, l in enumerate(self.layers): 82 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 83 | 84 | out = self.fc(out) 85 | return F.log_softmax(out, dim=-1) 86 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward 3 | import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention 6 | 7 | 8 | class EncoderLayer(nn.Module): 9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 10 | attention_module=None, attention_module_kwargs=None): 11 | super(EncoderLayer, self).__init__() 12 | self.identity_map_reordering = identity_map_reordering 13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 14 | attention_module=attention_module, 15 | attention_module_kwargs=attention_module_kwargs) 16 | self.dropout = nn.Dropout(dropout) 17 | self.lnorm = nn.LayerNorm(d_model) 18 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, 
identity_map_reordering=identity_map_reordering) 19 | 20 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 21 | 22 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights) 23 | att = self.lnorm(queries + self.dropout(att)) 24 | ff = self.pwff(att) 25 | return ff 26 | 27 | 28 | class MultiLevelEncoder(nn.Module): 29 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 30 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 31 | super(MultiLevelEncoder, self).__init__() 32 | self.d_model = d_model 33 | self.dropout = dropout 34 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 35 | identity_map_reordering=identity_map_reordering, 36 | attention_module=attention_module, 37 | attention_module_kwargs=attention_module_kwargs) 38 | for _ in range(N)]) 39 | self.padding_idx = padding_idx 40 | 41 | def forward(self, input, attention_weights=None): 42 | # input (b_s, seq_len, d_in) 43 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len) 44 | 45 | out = input 46 | for l in self.layers: 47 | out = l(out, out, out, attention_mask, attention_weights) 48 | # outs.append(out.unsqueeze(1)) 49 | 50 | # outs = torch.cat(outs, 1) 51 | return out, attention_mask 52 | 53 | 54 | class TransformerEncoder(MultiLevelEncoder): 55 | def __init__(self, N, padding_idx, d_in=2048, **kwargs): 56 | super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs) 57 | self.fc = nn.Linear(d_in, self.d_model) 58 | self.dropout = nn.Dropout(p=self.dropout) 59 | self.layer_norm = nn.LayerNorm(self.d_model) 60 | 61 | def forward(self, input, attention_weights=None): 62 | mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1) 63 | out = F.relu(self.fc(input)) 64 | out = self.dropout(out) 65 | out = self.layer_norm(out) 66 | out = out.masked_fill(mask, 0) 67 | return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights) 68 | -------------------------------------------------------------------------------- /models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, seq, *args): 28 | enc_output, mask_enc = self.encoder(images) 29 | dec_output = self.decoder(seq, enc_output, mask_enc) 30 | return dec_output 31 | 32 | def init_state(self, b_s, device): 33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 34 | None, None] 35 | 36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 37 | it = None 38 | if mode == 'teacher_forcing': 39 | raise NotImplementedError 40 | elif mode == 'feedback': 41 | if t == 0: 42 | self.enc_output, self.mask_enc = self.encoder(visual) 43 | if 
isinstance(visual, torch.Tensor): 44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 45 | else: 46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 47 | else: 48 | it = prev_output 49 | 50 | return self.decoder(it, self.enc_output, self.mask_enc) 51 | 52 | 53 | class TransformerEnsemble(CaptioningModel): 54 | def __init__(self, model: Transformer, weight_files): 55 | super(TransformerEnsemble, self).__init__() 56 | self.n = len(weight_files) 57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 58 | for i in range(self.n): 59 | state_dict_i = torch.load(weight_files[i])['state_dict'] 60 | self.models[i].load_state_dict(state_dict_i) 61 | 62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 63 | out_ensemble = [] 64 | for i in range(self.n): 65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 66 | out_ensemble.append(out_i.unsqueeze(0)) 67 | 68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 69 | -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def position_embedding(input, d_model): 7 | input = input.view(-1, 1) 8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model)) 10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 11 | 12 | out = torch.zeros((input.shape[0], d_model), device=input.device) 13 | out[:, ::2] = sin 14 | out[:, 1::2] = cos 15 | return out 16 | 17 | 18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 19 | pos = torch.arange(max_len, dtype=torch.float32) 20 | out = position_embedding(pos, d_model) 21 | 22 | if padding_idx is not None: 23 | out[padding_idx] = 0 24 | return out 25 | 26 | 27 | class PositionWiseFeedForward(nn.Module): 28 | ''' 29 | Position-wise feed forward layer 30 | ''' 31 | 32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 33 | super(PositionWiseFeedForward, self).__init__() 34 | self.identity_map_reordering = identity_map_reordering 35 | self.fc1 = nn.Linear(d_model, d_ff) 36 | self.fc2 = nn.Linear(d_ff, d_model) 37 | self.dropout = nn.Dropout(p=dropout) 38 | self.dropout_2 = nn.Dropout(p=dropout) 39 | self.layer_norm = nn.LayerNorm(d_model) 40 | 41 | def forward(self, input): 42 | if self.identity_map_reordering: 43 | out = self.layer_norm(input) 44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 45 | out = input + self.dropout(torch.relu(out)) 46 | else: 47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 48 | out = self.dropout(out) 49 | out = self.layer_norm(input + out) 50 | return out 51 | -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | you can download our pretrained models here. 2 | 3 | #### 1. pretrained languge model 4 | [language_context.pth](https://drive.google.com/file/d/1sayx7qwOd79XE4RFdvSXG3zyQH4FpyYJ/view?usp=sharing) 5 | #### 2. 
pretrained rstnet model 6 | [rstnet.pth](https://drive.google.com/file/d/13Faz1NIjvwSqBofDo7yO94bgbEdFtKgE/view?usp=sharing) 7 | -------------------------------------------------------------------------------- /switch_datatype.py: -------------------------------------------------------------------------------- 1 | # switch data from float32 to float16 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | import numpy as np 6 | import argparse 7 | 8 | 9 | def main(args): 10 | 11 | data_splits = os.listdir(args.dir_to_raw_feats) 12 | 13 | for data_split in data_splits: 14 | print('processing {} ...'.format(data_split)) 15 | if not os.path.exists(os.path.join(args.dir_to_float16_feats, data_split)): 16 | os.mkdir(os.path.join(args.dir_to_float16_feats, data_split)) 17 | 18 | feat_dir = os.path.join(args.dir_to_raw_feats, data_split) 19 | file_names = os.listdir(feat_dir) 20 | print(len(file_names)) 21 | 22 | for i in tqdm(range(len(file_names))): 23 | file_name = file_names[i] 24 | file_path = os.path.join(args.dir_to_raw_feats, data_split, file_name) 25 | data32 = torch.load(file_path).numpy().squeeze() 26 | data16 = data32.astype('float16') 27 | 28 | image_id = int(file_name.split('.')[0]) 29 | saved_file_path = os.path.join(args.dir_to_float16_feats, data_split, str(image_id)+'.npy') 30 | np.save(saved_file_path, data16) 31 | 32 | 33 | if __name__ == '__main__': 34 | 35 | parser = argparse.ArgumentParser(description='switch the data type of features') 36 | parser.add_argument('--dir_to_raw_feats', type=str, default='./Datasets/X101-features/', help='directory of the original float32 features') 37 | parser.add_argument('--dir_to_float16_feats', type=str, default='./Datasets/X101-features-float16', help='directory to save the float16 features') 38 | args = parser.parse_args() 39 | 40 | main(args) 41 | 42 | -------------------------------------------------------------------------------- /tensorboard_logs/rstnet/events.out.tfevents.1603849421.MAC-U2S5.25151.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849421.MAC-U2S5.25151.0 -------------------------------------------------------------------------------- /tensorboard_logs/rstnet/events.out.tfevents.1603849484.MAC-U2S5.28641.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849484.MAC-U2S5.28641.0 -------------------------------------------------------------------------------- /tensorboard_logs/rstnet/events.out.tfevents.1603849514.MAC-U2S5.30196.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849514.MAC-U2S5.30196.0 -------------------------------------------------------------------------------- /tensorboard_logs/rstnet/events.out.tfevents.1603850322.MAC-U2S5.11650.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603850322.MAC-U2S5.11650.0 -------------------------------------------------------------------------------- 
/tensorboard_logs/rstnet/events.out.tfevents.1605359164.MAC-U2S5.4519.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1605359164.MAC-U2S5.4519.0 -------------------------------------------------------------------------------- /tensorboard_logs/transformer/events.out.tfevents.1594888876.socialmedia.34273.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/transformer/events.out.tfevents.1594888876.socialmedia.34273.0 -------------------------------------------------------------------------------- /test_offline.py: -------------------------------------------------------------------------------- 1 | # Offline test: evaluating the performance of the captioning model on the Karpathy test split of MS-COCO. 2 | 3 | import random 4 | from data import ImageDetectionsField, TextField, RawField 5 | from data import COCO, DataLoader 6 | import evaluation 7 | from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention 8 | 9 | import torch 10 | from tqdm import tqdm 11 | import argparse 12 | import pickle 13 | import numpy as np 14 | import time 15 | 16 | random.seed(1234) 17 | torch.manual_seed(1234) 18 | np.random.seed(1234) 19 | 20 | 21 | def predict_captions(model, dataloader, text_field): 22 | import itertools 23 | model.eval() 24 | gen = {} 25 | gts = {} 26 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 27 | for it, (images, caps_gt) in enumerate(iter(dataloader)): 28 | images = images.to(device) 29 | with torch.no_grad(): 30 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1) 31 | caps_gen = text_field.decode(out, join_words=False) 32 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 33 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 34 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ] 35 | gts['%d_%d' % (it, i)] = gts_i 36 | pbar.update() 37 | 38 | gts = evaluation.PTBTokenizer.tokenize(gts) 39 | gen = evaluation.PTBTokenizer.tokenize(gen) 40 | scores, _ = evaluation.compute_scores(gts, gen) 41 | 42 | return scores 43 | 44 | 45 | if __name__ == '__main__': 46 | start_time = time.time() 47 | device = torch.device('cuda') 48 | 49 | parser = argparse.ArgumentParser(description='Relationship-Sensitive Transformer Network') 50 | parser.add_argument('--batch_size', type=int, default=10) 51 | parser.add_argument('--workers', type=int, default=4) 52 | parser.add_argument('--m', type=int, default=40) 53 | 54 | parser.add_argument('--features_path', type=str, default='./datasets/X101-features/X101-grid-coco_trainval.hdf5') 55 | parser.add_argument('--annotation_folder', type=str, default='./datasets/m2_annotations') 56 | 57 | # paths of the tested model and the vocabulary 58 | parser.add_argument('--language_model_path', type=str, default='./saved_language_models/language_context.pth') 59 | parser.add_argument('--model_path', type=str, default='./saved_transformer_models/rstnet_best.pth') 60 | parser.add_argument('--vocab_path', type=str, default='./vocab.pkl') 61 | args = parser.parse_args() 62 | 63 | print('The Offline Evaluation of RSTNet') 64 | 65 | # Pipeline for image regions 66 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49,
load_in_tmp=False) 67 | 68 | # Pipeline for text 69 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', 70 | remove_punctuation=True, nopoints=False) 71 | 72 | # load the dataset 73 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder) 74 | _, _, test_dataset = dataset.splits 75 | text_field.vocab = pickle.load(open(args.vocab_path, 'rb')) 76 | 77 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 78 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers) 79 | 80 | # load the model 81 | encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m}) 82 | decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], language_model_path=args.language_model_path) 83 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device) 84 | 85 | data = torch.load(args.model_path) 86 | model.load_state_dict(data['state_dict']) 87 | 88 | # compute the scores 89 | scores = predict_captions(model, dict_dataloader_test, text_field) 90 | print(scores) 91 | print('it costs {} s to test.'.format(time.time() - start_time)) 92 | 93 | 94 | -------------------------------------------------------------------------------- /test_online.py: -------------------------------------------------------------------------------- 1 | # Online test: evaluating the performance of the captioning model on the official MS-COCO test server. 2 | # This script generates caption files in the required format for the official test and validation sets; compress the two files and upload them to [CodaLab](https://competitions.codalab.org/competitions/3221#participate) to obtain 3 | # the online evaluation results and ranking. 4 | 5 | import torch 6 | import argparse 7 | import pickle 8 | import numpy as np 9 | import itertools 10 | import json 11 | import os 12 | 13 | from tqdm import tqdm 14 | 15 | from data import TextField, COCO_TestOnline 16 | from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention, TransformerEnsemble 17 | 18 | import random 19 | random.seed(1234) 20 | torch.manual_seed(1234) 21 | np.random.seed(1234) 22 | 23 | from torch.utils.data import DataLoader 24 | 25 | 26 | def gen_caps(captioning_model, dataset, batch_size=10, workers=0): 27 | dataloader = DataLoader( 28 | dataset, 29 | batch_size=batch_size, 30 | num_workers=workers 31 | ) 32 | 33 | outputs = [] 34 | with tqdm(len(dataloader)) as pbar: 35 | for it, (image_ids, images) in enumerate(iter(dataloader)): 36 | images = images.to(device) 37 | with torch.no_grad(): 38 | out, _ = captioning_model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1) 39 | caps_gen = text_field.decode(out, join_words=False) 40 | caps_gen = [' '.join([k for k, g in itertools.groupby(gen_i)]).strip() for gen_i in caps_gen] 41 | for i in range(image_ids.size(0)): 42 | item = {} 43 | item['image_id'] = int(image_ids[i]) 44 | item['caption'] = caps_gen[i] 45 | outputs.append(item) 46 | pbar.update() 47 | return outputs 48 | 49 | 50 | def save_results(outputs, datasplit, dir_to_save_caps): 51 | if not os.path.exists(dir_to_save_caps): 52 | os.makedirs(dir_to_save_caps) 53 | # naming convention: captions_test2014_XXX_results.json and captions_val2014_XXX_results.json 54 | output_path = os.path.join(dir_to_save_caps, 'captions_' + datasplit + '2014_RSTNet_results.json') 55 | with open(output_path, 'w') as f: 56 | json.dump(outputs, f) 57 | 58 | 59 | if __name__ == '__main__': 60 | 61 | device = torch.device('cuda') 62 | 63 | parser =
argparse.ArgumentParser(description='Relationship-Sensitive Transformer Network') 64 | parser.add_argument('--batch_size', type=int, default=10) 65 | parser.add_argument('--workers', type=int, default=0) 66 | 67 | parser.add_argument('--datasplit', type=str, default='test') # test, val 68 | 69 | # test set 70 | parser.add_argument('--test_features_path', type=str, default='./datasets/X101-features/X101_grid_feats_coco_test.hdf5') 71 | parser.add_argument('--test_annotation_folder', type=str, default='./datasets/m2_annotations/image_info_test2014.json') 72 | 73 | # validation set 74 | parser.add_argument('--val_features_path', type=str, default='./datasets/X101-features/X101_grid_feats_coco_trainval.hdf5') 75 | parser.add_argument('--val_annotation_folder', type=str, default='/home/DATA/m2_annotations/captions_val2014.json') 76 | 77 | # model weights 78 | parser.add_argument('--models_path', type=list, default=[ 79 | './test_online/models/rstnet_x101_1.pth', 80 | './test_online/models/rstnet_x101_2.pth', 81 | './test_online/models/rstnet_x101_3.pth', 82 | './test_online/models/rstnet_x101_4.pth' 83 | ]) 84 | 85 | parser.add_argument('--dir_to_save_caps', type=str, default='./test_online/results/') # directory to save the caption files 86 | 87 | args = parser.parse_args() 88 | 89 | print('The Online Evaluation of RSTNet') 90 | 91 | # load the dataset 92 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', 93 | remove_punctuation=True, nopoints=False) 94 | text_field.vocab = pickle.load(open('./vocab.pkl', 'rb')) 95 | 96 | if args.datasplit == 'test': 97 | dataset = COCO_TestOnline(feat_path=args.test_features_path, ann_file=args.test_annotation_folder) 98 | else: 99 | dataset = COCO_TestOnline(feat_path=args.val_features_path, ann_file=args.val_annotation_folder) 100 | 101 | # load model weights 102 | # model architecture 103 | encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40}) 104 | decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>']) 105 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device) 106 | # ensemble model 107 | ensemble_model = TransformerEnsemble(model=model, weight_files=args.models_path) 108 | 109 | # generate results 110 | outputs = gen_caps(ensemble_model, dataset, batch_size=args.batch_size, workers=args.workers) 111 | 112 | # save results 113 | save_results(outputs, args.datasplit, args.dir_to_save_caps) 114 | 115 | print('finished!') 116 | 117 | 118 | -------------------------------------------------------------------------------- /train_language.py: -------------------------------------------------------------------------------- 1 | import random 2 | from data import ImageDetectionsField, TextField, RawField 3 | from data import COCO, DataLoader 4 | 5 | from models.rstnet.language_model import LanguageModel 6 | 7 | import torch 8 | from torch.optim import Adam 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from torch.nn import NLLLoss 11 | from tqdm import tqdm 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | import argparse 15 | import os 16 | import pickle 17 | import numpy as np 18 | from shutil import copyfile 19 | 20 | random.seed(1234) 21 | torch.manual_seed(1234) 22 | np.random.seed(1234) 23 | 24 | 25 | def evaluate_loss(model, dataloader, loss_fn, text_field): 26 | # Validation loss 27 | model.eval() 28 | running_loss = .0 29 | with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar: 30 | with torch.no_grad(): 31 | for it, (detections, captions) in enumerate(dataloader): 32 |
detections, captions = detections.to(device), captions.to(device) 33 | # out = model(detections, captions) 34 | out, _ = model(captions) 35 | captions = captions[:, 1:].contiguous() 36 | out = out[:, :-1].contiguous() 37 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1)) 38 | this_loss = loss.item() 39 | running_loss += this_loss 40 | 41 | pbar.set_postfix(loss=running_loss / (it + 1)) 42 | pbar.update() 43 | 44 | val_loss = running_loss / len(dataloader) 45 | return val_loss 46 | 47 | 48 | def evaluate_metrics(model, dataloader, text_field): 49 | model.eval() 50 | scores = {} 51 | total_num = 0. 52 | correct_num = 0. 53 | with tqdm(desc='Epoch %d - evaluation' % e, unit='ite', total=len(dataloader)) as pbar: 54 | with torch.no_grad(): 55 | for it, (detections, captions) in enumerate(dataloader): 56 | detections, captions = detections.to(device), captions.to(device) 57 | # out = model(detections, captions) 58 | out, _ = model(captions) 59 | captions = captions[:, 1:].contiguous() 60 | out = torch.argmax(out[:, :-1], dim=-1).contiguous() 61 | b_s, seq_len = out.size() 62 | total_num += float(b_s * seq_len) 63 | correct_num += float((out == captions).sum()) 64 | pbar.update() 65 | 66 | scores['correct_num'] = correct_num 67 | scores['total_num'] = total_num 68 | scores['accuracy'] = correct_num / total_num 69 | return scores 70 | 71 | 72 | def train_xe(model, dataloader_train, optim, text_field): 73 | # Training with cross-entropy 74 | model.train() 75 | scheduler.step() 76 | running_loss = .0 77 | # print('lr = {}'.format(scheduler.get_lr()[0])) 78 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader_train)) as pbar: 79 | for it, (detections, captions) in enumerate(dataloader_train): 80 | detections, captions = detections.to(device), captions.to(device) 81 | # out = model(detections, captions) 82 | out, _ = model(captions) 83 | optim.zero_grad() 84 | captions_gt = captions[:, 1:].contiguous() 85 | out = out[:, :-1].contiguous() 86 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1)) 87 | loss.backward() 88 | 89 | optim.step() 90 | this_loss = loss.item() 91 | running_loss += this_loss 92 | 93 | pbar.set_postfix(loss=running_loss / (it + 1)) 94 | pbar.update() 95 | scheduler.step() 96 | 97 | loss = running_loss / len(dataloader_train) 98 | return loss 99 | 100 | 101 | if __name__ == '__main__': 102 | device = torch.device('cuda') 103 | parser = argparse.ArgumentParser(description='Bert Language Model') 104 | parser.add_argument('--exp_name', type=str, default='bert_language') 105 | parser.add_argument('--batch_size', type=int, default=50) 106 | parser.add_argument('--workers', type=int, default=4) 107 | parser.add_argument('--m', type=int, default=40) 108 | parser.add_argument('--head', type=int, default=8) 109 | parser.add_argument('--warmup', type=int, default=11328) 110 | parser.add_argument('--resume_last', action='store_true') 111 | parser.add_argument('--resume_best', action='store_true') 112 | 113 | parser.add_argument('--features_path', type=str, default='./Datasets/X101-features/X101-grid-coco_trainval.hdf5') 114 | parser.add_argument('--annotation_folder', type=str, default='./Datasets/m2_annotations') 115 | 116 | parser.add_argument('--dir_to_save_model', type=str, default='./saved_language_models') 117 | parser.add_argument('--logs_folder', type=str, default='./language_tensorboard_logs') 118 | 119 | args = parser.parse_args() 120 | print(args) 121 | 122 | print('Bert Language Model Training') 123 | # preparation 
124 | if not os.path.exists(args.dir_to_save_model): 125 | os.makedirs(args.dir_to_save_model) 126 | if not os.path.exists(args.logs_folder): 127 | os.makedirs(args.logs_folder) 128 | writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name)) 129 | 130 | # Pipeline for image regions 131 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False) 132 | 133 | # Pipeline for text 134 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', 135 | remove_punctuation=True, nopoints=False) 136 | 137 | # Create the dataset 138 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder) 139 | train_dataset, val_dataset, test_dataset = dataset.splits 140 | 141 | if not os.path.isfile('vocab.pkl'): 142 | print("Building vocabulary") 143 | text_field.build_vocab(train_dataset, val_dataset, min_freq=5) 144 | pickle.dump(text_field.vocab, open('vocab.pkl', 'wb')) 145 | else: 146 | print('Loading vocabulary') 147 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb')) 148 | 149 | model = LanguageModel(padding_idx=text_field.vocab.stoi['<pad>'], bert_hidden_size=768, vocab_size=len(text_field.vocab)).to(device) 150 | 151 | dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 152 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 153 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 154 | 155 | def lambda_lr(s): 156 | warm_up = args.warmup 157 | s += 1 158 | if s % 11331 == 0: 159 | s = 1 160 | else: 161 | s = s % 11331 162 | 163 | lr = (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5) 164 | if lr > 1e-6: 165 | lr = 1e-6 166 | 167 | print('s = {}, lr = {}'.format(s, lr)) 168 | return lr 169 | 170 | # Initial conditions 171 | optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98)) 172 | scheduler = LambdaLR(optim, lambda_lr) 173 | # scheduler = StepLR(optim, step_size=2, gamma=0.5) 174 | loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>']) 175 | use_rl = False 176 | best_score = .0 177 | best_test_score = .0 178 | patience = 0 179 | start_epoch = 0 180 | 181 | if args.resume_last or args.resume_best: 182 | if args.resume_last: 183 | fname = os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name) 184 | else: 185 | fname = os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name) 186 | 187 | if os.path.exists(fname): 188 | data = torch.load(fname) 189 | torch.set_rng_state(data['torch_rng_state']) 190 | torch.cuda.set_rng_state(data['cuda_rng_state']) 191 | np.random.set_state(data['numpy_rng_state']) 192 | random.setstate(data['random_rng_state']) 193 | model.load_state_dict(data['state_dict'], strict=False) 194 | optim.load_state_dict(data['optimizer']) 195 | scheduler.load_state_dict(data['scheduler']) 196 | start_epoch = data['epoch'] + 1 197 | best_score = data['best_score'] 198 | patience = data['patience'] 199 | use_rl = data['use_rl'] 200 | print('Resuming from epoch %d, validation loss %f, and best score %f' % ( 201 | data['epoch'], data['val_loss'], data['best_score'])) 202 | 203 | print("Training starts") 204 | for e in range(start_epoch, start_epoch + 100): 205 | dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 206 | drop_last=True) 207 | dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers) 208 | dataloader_test = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 209 | 210 | dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True, 211 | num_workers=args.workers) 212 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5) 213 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5) 214 | 215 | if not use_rl: 216 | train_loss = train_xe(model, dataloader_train, optim, text_field) 217 | writer.add_scalar('data/train_loss', train_loss, e) 218 | else: 219 | break 220 | 221 | # Validation loss 222 | val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field) 223 | writer.add_scalar('data/val_loss', val_loss, e) 224 | 225 | # Validation scores 226 | val_scores = evaluate_metrics(model, dataloader_val, text_field) 227 | print("epoch {}: Validation scores {}".format(e, val_scores)) 228 | val_score = val_scores['accuracy'] 229 | writer.add_scalar('data/val_score', val_score, e) 230 | 231 | # Test scores 232 | test_scores = evaluate_metrics(model, dataloader_test, text_field) 233 | print("epoch {}: Test scores {}".format(e, test_scores)) 234 | test_score = test_scores['accuracy'] 235 | writer.add_scalar('data/test_score', test_score, e) 236 | 237 | # Prepare for next epoch 238 | best = False 239 | if val_score >= best_score: 240 | best_score = val_score 241 | patience = 0 242 | best = True 243 | else: 244 | patience += 1 245 | 246 | best_test = False 247 | if test_score >= best_test_score: 248 | best_test_score = test_score 249 | best_test = True 250 | 251 | switch_to_rl = False 252 | exit_train = False 253 | if patience == 5: 254 | if not use_rl: 255 | use_rl = True 256 | switch_to_rl = True 257 | patience = 0 258 | break 259 | else: 260 | print('patience reached.') 261 | exit_train = True 262 | 263 | if switch_to_rl and not best: 264 | data = torch.load(os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name)) 265 | torch.set_rng_state(data['torch_rng_state']) 266 | torch.cuda.set_rng_state(data['cuda_rng_state']) 267 | np.random.set_state(data['numpy_rng_state']) 268 | random.setstate(data['random_rng_state']) 269 | model.load_state_dict(data['state_dict']) 270 | print('Resuming from epoch %d, validation loss %f, and best score %f' % ( 271 | data['epoch'], data['val_loss'], data['best_score'])) 272 | 273 | torch.save({ 274 | 'torch_rng_state': torch.get_rng_state(), 275 | 'cuda_rng_state': torch.cuda.get_rng_state(), 276 | 'numpy_rng_state': np.random.get_state(), 277 | 'random_rng_state': random.getstate(), 278 | 'epoch': e, 279 | 'val_loss': val_loss, 280 | 'val_score': val_score, 281 | 'state_dict': model.state_dict(), 282 | 'optimizer': optim.state_dict(), 283 | 'scheduler': scheduler.state_dict(), 284 | 'patience': patience, 285 | 'best_score': best_score, 286 | 'use_rl': use_rl, 287 | }, os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name)) 288 | 289 | if best: 290 | copyfile(os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name), os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name)) 291 | if best_test: 292 | copyfile(os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name), os.path.join(args.dir_to_save_model, '%s_best_test.pth' % args.exp_name)) 293 | 294 | if exit_train: 295 | writer.close() 296 | break 297 | -------------------------------------------------------------------------------- /utils/__init__.py:
-------------------------------------------------------------------------------- 1 | from .utils import download_from_url 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, torch.Tensor): 6 | b_s = x.size(0) 7 | else: 8 | b_s = x[0].size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> torch.device: 13 | if isinstance(x, torch.Tensor): 14 | device = x.device 15 | else: 16 | device = x[0].device 17 | return device 18 | -------------------------------------------------------------------------------- /utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/vocab.pkl --------------------------------------------------------------------------------
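For reference, the two helpers exported from utils/__init__.py accept either a single tensor or a sequence of tensors. A minimal usage sketch follows; the tensor shapes are arbitrary assumptions chosen for illustration, not values taken from the repository.

```python
# Sketch only: shapes are arbitrary; get_batch_size / get_device come from utils/__init__.py.
import torch
from utils import get_batch_size, get_device

grid_feats = torch.zeros(8, 49, 2048)            # a single feature tensor
paired_input = (grid_feats, torch.zeros(8, 49))  # or a sequence of tensors

# both forms report the batch size of the first dimension
assert get_batch_size(grid_feats) == get_batch_size(paired_input) == 8
print(get_device(grid_feats))  # e.g. cpu
```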