├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── example.py
│   ├── field.py
│   ├── utils.py
│   └── vocab.py
├── environment.yml
├── evaluation
│   ├── __init__.py
│   ├── bleu
│   │   ├── __init__.py
│   │   ├── bleu.py
│   │   └── bleu_scorer.py
│   ├── cider
│   │   ├── __init__.py
│   │   ├── cider.py
│   │   └── cider_scorer.py
│   ├── meteor
│   │   ├── __init__.py
│   │   └── meteor.py
│   ├── rouge
│   │   ├── __init__.py
│   │   └── rouge.py
│   ├── stanford-corenlp-3.4.1.jar
│   └── tokenizer.py
├── feats_process.py
├── images
│   ├── RSTNet.png
│   ├── results.png
│   ├── train_cider.png
│   └── visualness.png
├── models
│   ├── __init__.py
│   ├── beam_search
│   │   ├── __init__.py
│   │   └── beam_search.py
│   ├── captioning_model.py
│   ├── containers.py
│   ├── m2_transformer
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── rstnet
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── decoders.py
│   │   ├── encoders.py
│   │   ├── grid_aug.py
│   │   ├── language_model.py
│   │   ├── transformer.py
│   │   └── utils.py
│   └── transformer
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── attention.cpython-36.pyc
│       │   ├── decoders.cpython-36.pyc
│       │   ├── encoders.cpython-36.pyc
│       │   ├── transformer.cpython-36.pyc
│       │   └── utils.cpython-36.pyc
│       ├── attention.py
│       ├── decoders.py
│       ├── encoders.py
│       ├── transformer.py
│       └── utils.py
├── pretrained_models
│   └── README.md
├── switch_datatype.py
├── tensorboard_logs
│   ├── rstnet
│   │   ├── events.out.tfevents.1603849421.MAC-U2S5.25151.0
│   │   ├── events.out.tfevents.1603849484.MAC-U2S5.28641.0
│   │   ├── events.out.tfevents.1603849514.MAC-U2S5.30196.0
│   │   ├── events.out.tfevents.1603850322.MAC-U2S5.11650.0
│   │   └── events.out.tfevents.1605359164.MAC-U2S5.4519.0
│   └── transformer
│       └── events.out.tfevents.1594888876.socialmedia.34273.0
├── test_offline.py
├── test_online.py
├── train_language.py
├── train_transformer.py
├── utils
│   ├── __init__.py
│   ├── typing.py
│   └── utils.py
└── vocab.pkl
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, 张旭迎
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RSTNet: Relationship-Sensitive Transformer Network
2 | This repository contains the reference code for our paper [_RSTNet: Captioning with Adaptive Attention on Visual and Non-Visual Words_ (CVPR 2021)](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhang_RSTNet_Captioning_With_Adaptive_Attention_on_Visual_and_Non-Visual_Words_CVPR_2021_paper.pdf).
3 |
4 |
5 |
6 |
7 |
8 | ## Tips
9 | If you have any questions about our work, feel free to open issues on this GitHub project; I will answer them and update the code monthly.
10 | If you are in a hurry, please email me at [zhangxuying1004@gmail.com](zhangxuying1004@gmail.com).
11 | If our work is helpful to you or gives you some inspiration, please star this project and cite our paper. Thank you!
12 | ```
13 | @inproceedings{zhang2021rstnet,
14 | title={RSTNet: Captioning with adaptive attention on visual and non-visual words},
15 | author={Zhang, Xuying and Sun, Xiaoshuai and Luo, Yunpeng and Ji, Jiayi and Zhou, Yiyi and Wu, Yongjian and Huang, Feiyue and Ji, Rongrong},
16 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
17 | pages={15465--15474},
18 | year={2021}
19 | }
20 | ```
21 |
22 | ## Environment setup
23 | Clone the repository and create the `m2release` conda environment using the `environment.yml` file:
24 | ```
25 | conda env create -f environment.yml
26 | conda activate m2release
27 | ```
28 |
29 | Then, download the spaCy English data by executing one of the following commands:
30 | ```python -m spacy download en``` or ```python -m spacy download en_core_web_sm```.
31 |
32 | You also need to create 5 new folders, namely ```Datasets```, ```save_language_models```, ```language_tensorboard_logs```, ```save_transformer_models``` and ```transformer_tensorboard_logs``` in the root directory of this project.
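
For convenience, these five folders can also be created with a short Python snippet (a minimal sketch, equivalent to running `mkdir` once per name):
```
import os

# working directories expected by the training and logging scripts (names taken from the list above)
for folder in ['Datasets', 'save_language_models', 'language_tensorboard_logs',
               'save_transformer_models', 'transformer_tensorboard_logs']:
    os.makedirs(folder, exist_ok=True)  # no-op if the folder already exists
```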
33 |
34 | ## Data preparation
35 | To run our code, you need to put the annotation folder ```m2_annotations``` and the visual feature folder ```X101-features``` of the COCO dataset into ```Datasets```.
36 |
37 | Most annotations have been prepared by [1]. Please download [m2_annotations](https://drive.google.com/drive/folders/1tJnetunBkQ4Y5A3pq2P53yeJuGa4lX9e?usp=sharing) and put the folder into ```Datasets```.
38 |
39 | Visual features are computed with the code provided by [2]. To reproduce our results, please download a COCO feature file such as ```X-101-features.tgz``` from [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa)
40 | and rename the extracted folder to ```X101-features```. Since this feature file is huge, you can optionally save the features as float16 to reduce the storage cost by executing the following command:
41 | ```
42 | python switch_datatype.py
43 | ```
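
The conversion itself is just a cast to float16. As a rough illustration only, here is a minimal sketch assuming one `.npy` feature file per image, which may not match the layout actually handled by ```switch_datatype.py```:
```
import os
import numpy as np

# hypothetical source/destination folders, for illustration only
src_dir, dst_dir = 'Datasets/X101-features', 'Datasets/X101-features-fp16'
os.makedirs(dst_dir, exist_ok=True)

for name in os.listdir(src_dir):
    feat = np.load(os.path.join(src_dir, name))                    # float32 grid features
    np.save(os.path.join(dst_dir, name), feat.astype(np.float16))  # float16 halves the storage size
```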
44 | To resolve the shape difference with the region features (`50` regions), please execute the following command to reshape the grid features to `49 (7x7)` and save all visual features in a single h5py file.
45 | ```
46 | python feats_process.py
47 | ```
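
Conceptually, this step pools each feature map down to a fixed `7x7` grid and stores it under a per-image `'%d_grids'` key, which is the key format the data loaders in ```data/dataset.py``` and ```data/field.py``` read back. A simplified sketch of that idea, assuming `(C, H, W)` feature maps as input (the real ```feats_process.py``` may differ in its details):
```
import h5py
import numpy as np
import torch
import torch.nn.functional as F

def save_grid_feats(feats_by_image_id, out_path='coco_grid_feats.hdf5'):
    # feats_by_image_id: dict mapping an int image_id to a (C, H, W) float32 feature map (assumed format)
    with h5py.File(out_path, 'w') as f:
        for image_id, feat in feats_by_image_id.items():
            x = torch.from_numpy(feat).unsqueeze(0)       # (1, C, H, W)
            x = F.adaptive_avg_pool2d(x, (7, 7))          # pool to a fixed 7x7 grid
            x = x.squeeze(0).flatten(1).t().contiguous()  # (49, C): one row per grid cell
            f.create_dataset('%d_grids' % image_id, data=x.numpy())
```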
48 | Note that, for convenience, you can also download my processed offline image features [coco_grid_feats](https://pan.baidu.com/s/1myelTYJE8a1HDZHkoccfIA) from Baidu Netdisk with the extraction code ```cvpr```.
49 |
50 |
51 |
52 | In addition, if you want to extract grid-based features for your own image dataset, you can refer to the code in the
53 | [grid-feats-vqa](https://github.com/facebookresearch/grid-feats-vqa) project.
54 |
55 | ## Training procedure
56 | Run `python train_language.py` and `python train_transformer.py` in sequence using the following arguments:
57 |
58 | | Argument | Possible values |
59 | |------|------|
60 | | `--exp_name` | Experiment name|
61 | | `--batch_size` | Batch size (default: 50) |
62 | | `--workers` | Number of data-loading workers; speeds up training in the XE stage. |
63 | | `--head` | Number of heads (default: 8) |
64 | | `--resume_last` | If used, the training will be resumed from the last checkpoint. |
65 | | `--resume_best` | If used, the training will be resumed from the best checkpoint. |
66 | | `--features_path` | Path to visual features file (h5py)|
67 | | `--annotation_folder` | Path to m2_annotations |
68 |
69 | For example, to train our BERT-based language model with the parameters used in our experiments, use
70 | ```
71 | python train_language.py --exp_name bert_language --batch_size 50 --features_path /path/to/features --annotation_folder /path/to/annotations
72 | ```
73 | To train our RSTNet model with the parameters used in our experiments, use
74 | ```
75 | python train_transformer.py --exp_name rstnet --batch_size 50 --m 40 --head 8 --features_path /path/to/features --annotation_folder /path/to/annotations
76 | ```
77 | The figure below shows how the CIDEr score changes during the training of RSTNet. You can also visualize the training details by pointing TensorBoard at the log files in ```tensorboard_logs``` (e.g. ```tensorboard --logdir tensorboard_logs```).
78 |
79 |
80 |
81 |
82 | ## Evaluation
83 | ### Offline Evaluation
84 | Run `python test_offline.py` to evaluate the performance of RSTNet on the Karpathy test split of the MS COCO dataset.
85 |
86 | ### Online Evaluation
87 | Run `python test_online.py` to generate the required files and evaluate the performance of RSTNet on the official test server of the MS COCO dataset.
88 |
89 | Note that, to reproduce our reported results, you can also download our pretrained model files (see the ```pretrained_models``` folder) and put them into the ```saved_language_models``` and ```saved_transformer_models``` folders respectively. The results of offline evaluation (Karpathy test split of MS COCO) are as follows:
90 |
91 |
92 |
93 |
94 | #### References
95 | [1] Cornia, M., Stefanini, M., Baraldi, L., & Cucchiara, R. (2020). Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
96 | [2] Jiang, H., Misra, I., Rohrbach, M., Learned-Miller, E., & Chen, X. (2020). In defense of grid features for visual question answering. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
97 |
98 |
99 | #### Acknowledgements
100 | Thanks to Cornia _et al._ for their open-source code
101 | [meshed-memory-transformer](https://github.com/aimagelab/meshed-memory-transformer), on which our implementation is based.
102 | Thanks to Jiang _et al._ for their significant findings on visual representations [2], which gave us a lot of inspiration.
103 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .field import RawField, Merge, ImageDetectionsField, TextField
2 | from .dataset import COCO, COCO_TestOnline
3 | from torch.utils.data import DataLoader as TorchDataLoader
4 |
5 | class DataLoader(TorchDataLoader):
6 | def __init__(self, dataset, *args, **kwargs):
7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
8 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os, json, h5py  # json and h5py are required by COCO_TestOnline below
2 | import numpy as np
3 | import itertools
4 | import collections
5 | import torch
6 | from .example import Example
7 | from .utils import nostdout
8 | from pycocotools.coco import COCO as pyCOCO
9 |
10 |
11 | class Dataset(object):
12 | def __init__(self, examples, fields):
13 | self.examples = examples
14 | self.fields = dict(fields)
15 |
16 | def collate_fn(self):
17 | def collate(batch):
18 | if len(self.fields) == 1:
19 | batch = [batch, ]
20 | else:
21 | batch = list(zip(*batch))
22 |
23 | tensors = []
24 | for field, data in zip(self.fields.values(), batch):
25 | tensor = field.process(data)
26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
27 | tensors.extend(tensor)
28 | else:
29 | tensors.append(tensor)
30 |
31 | if len(tensors) > 1:
32 | return tensors
33 | else:
34 | return tensors[0]
35 |
36 | return collate
37 |
38 | def __getitem__(self, i):
39 | example = self.examples[i]
40 | data = []
41 | for field_name, field in self.fields.items():
42 | data.append(field.preprocess(getattr(example, field_name)))
43 |
44 | if len(data) == 1:
45 | data = data[0]
46 | return data
47 |
48 | def __len__(self):
49 | return len(self.examples)
50 |
51 | def __getattr__(self, attr):
52 | if attr in self.fields:
53 | for x in self.examples:
54 | yield getattr(x, attr)
55 |
56 |
57 | class ValueDataset(Dataset):
58 | def __init__(self, examples, fields, dictionary):
59 | self.dictionary = dictionary
60 | super(ValueDataset, self).__init__(examples, fields)
61 |
62 | def collate_fn(self):
63 | def collate(batch):
64 | value_batch_flattened = list(itertools.chain(*batch))
65 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
66 |
67 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
68 | if isinstance(value_tensors_flattened, collections.Sequence) \
69 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
70 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened]
71 | else:
72 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
73 |
74 | return value_tensors
75 | return collate
76 |
77 | def __getitem__(self, i):
78 | if i not in self.dictionary:
79 | raise IndexError
80 |
81 | values_data = []
82 | for idx in self.dictionary[i]:
83 | value_data = super(ValueDataset, self).__getitem__(idx)
84 | values_data.append(value_data)
85 | return values_data
86 |
87 | def __len__(self):
88 | return len(self.dictionary)
89 |
90 |
91 | class DictionaryDataset(Dataset):
92 | def __init__(self, examples, fields, key_fields):
93 | if not isinstance(key_fields, (tuple, list)):
94 | key_fields = (key_fields,)
95 | for field in key_fields:
96 | assert (field in fields)
97 |
98 | dictionary = collections.defaultdict(list)
99 | key_fields = {k: fields[k] for k in key_fields}
100 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
101 | key_examples = []
102 | key_dict = dict()
103 | value_examples = []
104 |
105 | for i, e in enumerate(examples):
106 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
107 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
108 | if key_example not in key_dict:
109 | key_dict[key_example] = len(key_examples)
110 | key_examples.append(key_example)
111 |
112 | value_examples.append(value_example)
113 | dictionary[key_dict[key_example]].append(i)
114 |
115 | self.key_dataset = Dataset(key_examples, key_fields)
116 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
117 | super(DictionaryDataset, self).__init__(examples, fields)
118 |
119 | def collate_fn(self):
120 | def collate(batch):
121 | key_batch, value_batch = list(zip(*batch))
122 | key_tensors = self.key_dataset.collate_fn()(key_batch)
123 | value_tensors = self.value_dataset.collate_fn()(value_batch)
124 | return key_tensors, value_tensors
125 | return collate
126 |
127 | def __getitem__(self, i):
128 | return self.key_dataset[i], self.value_dataset[i]
129 |
130 | def __len__(self):
131 | return len(self.key_dataset)
132 |
133 |
134 | def unique(sequence):
135 | seen = set()
136 | if isinstance(sequence[0], list):
137 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))]
138 | else:
139 | return [x for x in sequence if not (x in seen or seen.add(x))]
140 |
141 |
142 | class PairedDataset(Dataset):
143 | def __init__(self, examples, fields):
144 | assert ('image' in fields)
145 | assert ('text' in fields)
146 | super(PairedDataset, self).__init__(examples, fields)
147 | self.image_field = self.fields['image']
148 | self.text_field = self.fields['text']
149 |
150 | def image_set(self):
151 | img_list = [e.image for e in self.examples]
152 | image_set = unique(img_list)
153 | examples = [Example.fromdict({'image': i}) for i in image_set]
154 | dataset = Dataset(examples, {'image': self.image_field})
155 | return dataset
156 |
157 | def text_set(self):
158 | text_list = [e.text for e in self.examples]
159 | text_list = unique(text_list)
160 | examples = [Example.fromdict({'text': t}) for t in text_list]
161 | dataset = Dataset(examples, {'text': self.text_field})
162 | return dataset
163 |
164 | def image_dictionary(self, fields=None):
165 | if not fields:
166 | fields = self.fields
167 | dataset = DictionaryDataset(self.examples, fields, key_fields='image')
168 | return dataset
169 |
170 | def text_dictionary(self, fields=None):
171 | if not fields:
172 | fields = self.fields
173 | dataset = DictionaryDataset(self.examples, fields, key_fields='text')
174 | return dataset
175 |
176 | @property
177 | def splits(self):
178 | raise NotImplementedError
179 |
180 |
181 | class COCO(PairedDataset):
182 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
183 | cut_validation=False):
184 | roots = {}
185 | roots['train'] = {
186 | 'img': os.path.join(img_root, 'train2014'),
187 | 'cap': os.path.join(ann_root, 'captions_train2014.json')
188 | }
189 | roots['val'] = {
190 | 'img': os.path.join(img_root, 'val2014'),
191 | 'cap': os.path.join(ann_root, 'captions_val2014.json')
192 | }
193 | roots['test'] = {
194 | 'img': os.path.join(img_root, 'val2014'),
195 | 'cap': os.path.join(ann_root, 'captions_val2014.json')
196 | }
197 | roots['trainrestval'] = {
198 | 'img': (roots['train']['img'], roots['val']['img']),
199 | 'cap': (roots['train']['cap'], roots['val']['cap'])
200 | }
201 |
202 | if id_root is not None:
203 | ids = {}
204 | ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy'))
205 | ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy'))
206 | if cut_validation:
207 | ids['val'] = ids['val'][:5000]
208 | ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy'))
209 | ids['trainrestval'] = (
210 | ids['train'],
211 | np.load(os.path.join(id_root, 'coco_restval_ids.npy')))
212 |
213 | if use_restval:
214 | roots['train'] = roots['trainrestval']
215 | ids['train'] = ids['trainrestval']
216 | else:
217 | ids = None
218 |
219 | with nostdout():
220 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids)
221 | examples = self.train_examples + self.val_examples + self.test_examples
222 | super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field})
223 |
224 | @property
225 | def splits(self):
226 | train_split = PairedDataset(self.train_examples, self.fields)
227 | val_split = PairedDataset(self.val_examples, self.fields)
228 | test_split = PairedDataset(self.test_examples, self.fields)
229 | return train_split, val_split, test_split
230 |
231 | @classmethod
232 | def get_samples(cls, roots, ids_dataset=None):
233 | train_samples = []
234 | val_samples = []
235 | test_samples = []
236 |
237 | for split in ['train', 'val', 'test']:
238 | if isinstance(roots[split]['cap'], tuple):
239 | coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1]))
240 | root = roots[split]['img']
241 | else:
242 | coco_dataset = (pyCOCO(roots[split]['cap']),)
243 | root = (roots[split]['img'],)
244 |
245 | if ids_dataset is None:
246 | ids = list(coco_dataset[0].anns.keys())  # coco_dataset is always a tuple here
247 | else:
248 | ids = ids_dataset[split]
249 |
250 | if isinstance(ids, tuple):
251 | bp = len(ids[0])
252 | ids = list(ids[0]) + list(ids[1])
253 | else:
254 | bp = len(ids)
255 |
256 | for index in range(len(ids)):
257 | if index < bp:
258 | coco = coco_dataset[0]
259 | img_root = root[0]
260 | else:
261 | coco = coco_dataset[1]
262 | img_root = root[1]
263 |
264 | ann_id = ids[index]
265 | caption = coco.anns[ann_id]['caption']
266 | img_id = coco.anns[ann_id]['image_id']
267 | filename = coco.loadImgs(img_id)[0]['file_name']
268 |
269 | example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption})
270 |
271 | if split == 'train':
272 | train_samples.append(example)
273 | elif split == 'val':
274 | val_samples.append(example)
275 | elif split == 'test':
276 | test_samples.append(example)
277 |
278 | return train_samples, val_samples, test_samples
279 |
280 |
281 | class COCO_TestOnline(Dataset):
282 | def __init__(self, feat_path, ann_file, max_detections=49):
283 | """
284 | feat_path: path to the grid features of the official COCO train/val splits
285 | ann_file: annotation file of the train or val split, used to obtain each image_id and then retrieve the corresponding features
286 | """
287 | super(COCO_TestOnline, self).__init__()
288 |
289 | # load the image info from the annotation file
290 | with open(ann_file, 'r') as f:
291 | self.images_info = json.load(f)['images']
292 |
293 | # open the feature file
294 | self.f = h5py.File(feat_path, 'r')
295 |
296 | # maximum number of feature vectors kept per image
297 | self.max_detections = max_detections
298 |
299 | def __len__(self):
300 | return len(self.images_info)
301 |
302 | def __getitem__(self, idx):
303 | image_id = self.images_info[idx]['id']
304 | precomp_data = self.f['%d_grids' % image_id][()]
305 |
306 | delta = self.max_detections - precomp_data.shape[0]
307 | if delta > 0:
308 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0)
309 | elif delta < 0:
310 | precomp_data = precomp_data[:self.max_detections]
311 |
312 | return int(image_id), precomp_data
313 |
314 |
--------------------------------------------------------------------------------
/data/example.py:
--------------------------------------------------------------------------------
1 |
2 | class Example(object):
3 | """Defines a single training or test example.
4 | Stores each column of the example as an attribute.
5 | """
6 | @classmethod
7 | def fromdict(cls, data):
8 | ex = cls(data)
9 | return ex
10 |
11 | def __init__(self, data):
12 | for key, val in data.items():
13 | super(Example, self).__setattr__(key, val)
14 |
15 | def __setattr__(self, key, value):
16 | raise AttributeError
17 |
18 | def __hash__(self):
19 | return hash(tuple(x for x in self.__dict__.values()))
20 |
21 | def __eq__(self, other):
22 | this = tuple(x for x in self.__dict__.values())
23 | other = tuple(x for x in other.__dict__.values())
24 | return this == other
25 |
26 | def __ne__(self, other):
27 | return not self.__eq__(other)
28 |
--------------------------------------------------------------------------------
/data/field.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | from collections import Counter, OrderedDict
3 | from torch.utils.data.dataloader import default_collate
4 | from itertools import chain
5 | import six
6 | import torch
7 | import numpy as np
8 | import h5py
9 | import os
10 | import warnings
11 | import shutil
12 |
13 | from .dataset import Dataset
14 | from .vocab import Vocab
15 | from .utils import get_tokenizer
16 |
17 |
18 | class RawField(object):
19 | """ Defines a general datatype.
20 |
21 | Every dataset consists of one or more types of data. For instance,
22 | a machine translation dataset contains paired examples of text, while
23 | an image captioning dataset contains images and texts.
24 | Each of these types of data is represented by a RawField object.
25 | An RawField object does not assume any property of the data type and
26 | it holds parameters relating to how a datatype should be processed.
27 |
28 | Attributes:
29 | preprocessing: The Pipeline that will be applied to examples
30 | using this field before creating an example.
31 | Default: None.
32 | postprocessing: A Pipeline that will be applied to a list of examples
33 | using this field before assigning to a batch.
34 | Function signature: (batch(list)) -> object
35 | Default: None.
36 | """
37 |
38 | def __init__(self, preprocessing=None, postprocessing=None):
39 | self.preprocessing = preprocessing
40 | self.postprocessing = postprocessing
41 |
42 | def preprocess(self, x):
43 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """
44 | if self.preprocessing is not None:
45 | return self.preprocessing(x)
46 | else:
47 | return x
48 |
49 | def process(self, batch, *args, **kwargs):
50 | """ Process a list of examples to create a batch.
51 |
52 | Postprocess the batch with user-provided Pipeline.
53 |
54 | Args:
55 | batch (list(object)): A list of object from a batch of examples.
56 | Returns:
57 | object: Processed object given the input and custom
58 | postprocessing Pipeline.
59 | """
60 | if self.postprocessing is not None:
61 | batch = self.postprocessing(batch)
62 | return default_collate(batch)
63 |
64 |
65 | class Merge(RawField):
66 | def __init__(self, *fields):
67 | super(Merge, self).__init__()
68 | self.fields = fields
69 |
70 | def preprocess(self, x):
71 | return tuple(f.preprocess(x) for f in self.fields)
72 |
73 | def process(self, batch, *args, **kwargs):
74 | if len(self.fields) == 1:
75 | batch = [batch, ]
76 | else:
77 | batch = list(zip(*batch))
78 |
79 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
80 | return out
81 |
82 |
83 | class ImageDetectionsField(RawField):
84 | def __init__(self, preprocessing=None, postprocessing=None, detections_path=None, max_detections=100,
85 | sort_by_prob=False, load_in_tmp=True):
86 | self.max_detections = max_detections
87 | self.detections_path = detections_path
88 | self.sort_by_prob = sort_by_prob
89 |
90 | tmp_detections_path = os.path.join('/tmp', os.path.basename(detections_path))
91 |
92 | if load_in_tmp:
93 | if not os.path.isfile(tmp_detections_path):
94 | if shutil.disk_usage("/tmp")[-1] < os.path.getsize(detections_path):
95 | warnings.warn('Loading from %s, because /tmp does not have enough space.' % detections_path)
96 | else:
97 | warnings.warn("Copying detection file to /tmp")
98 | shutil.copyfile(detections_path, tmp_detections_path)
99 | warnings.warn("Done.")
100 | self.detections_path = tmp_detections_path
101 | else:
102 | self.detections_path = tmp_detections_path
103 |
104 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing)
105 |
106 | def preprocess(self, x, avoid_precomp=False):
107 | image_id = int(x.split('_')[-1].split('.')[0])
108 | try:
109 | f = h5py.File(self.detections_path, 'r')
110 | # precomp_data = f['%d_features' % image_id][()]
111 | precomp_data = f['%d_grids' % image_id][()]
112 | if self.sort_by_prob:
113 | precomp_data = precomp_data[np.argsort(np.max(f['%d_cls_prob' % image_id][()], -1))[::-1]]
114 | except KeyError:
115 | warnings.warn('Could not find detections for %d' % image_id)
116 | precomp_data = np.random.rand(10,2048)
117 |
118 | delta = self.max_detections - precomp_data.shape[0]
119 | if delta > 0:
120 | precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0)
121 | elif delta < 0:
122 | precomp_data = precomp_data[:self.max_detections]
123 |
124 | return precomp_data.astype(np.float32)
125 |
126 |
127 | class TextField(RawField):
128 | vocab_cls = Vocab
129 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
130 | # numeric type.
131 | dtypes = {
132 | torch.float32: float,
133 | torch.float: float,
134 | torch.float64: float,
135 | torch.double: float,
136 | torch.float16: float,
137 | torch.half: float,
138 |
139 | torch.uint8: int,
140 | torch.int8: int,
141 | torch.int16: int,
142 | torch.short: int,
143 | torch.int32: int,
144 | torch.int: int,
145 | torch.int64: int,
146 | torch.long: int,
147 | }
148 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
149 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
150 |
151 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long,
152 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()),
153 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="<pad>",
154 | unk_token="<unk>", pad_first=False, truncate_first=False, vectors=None, nopoints=True):
155 | self.use_vocab = use_vocab
156 | self.init_token = init_token
157 | self.eos_token = eos_token
158 | self.fix_length = fix_length
159 | self.dtype = dtype
160 | self.lower = lower
161 | self.tokenize = get_tokenizer(tokenize)
162 | self.remove_punctuation = remove_punctuation
163 | self.include_lengths = include_lengths
164 | self.batch_first = batch_first
165 | self.pad_token = pad_token
166 | self.unk_token = unk_token
167 | self.pad_first = pad_first
168 | self.truncate_first = truncate_first
169 | self.vocab = None
170 | self.vectors = vectors
171 | if nopoints:
172 | self.punctuations.append("..")
173 |
174 | super(TextField, self).__init__(preprocessing, postprocessing)
175 |
176 | def preprocess(self, x):
177 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type):
178 | x = six.text_type(x, encoding='utf-8')
179 | if self.lower:
180 | x = six.text_type.lower(x)
181 | x = self.tokenize(x.rstrip('\n'))
182 | if self.remove_punctuation:
183 | x = [w for w in x if w not in self.punctuations]
184 | if self.preprocessing is not None:
185 | return self.preprocessing(x)
186 | else:
187 | return x
188 |
189 | def process(self, batch, device=None):
190 | padded = self.pad(batch)
191 | tensor = self.numericalize(padded, device=device)
192 | return tensor
193 |
194 | def build_vocab(self, *args, **kwargs):
195 | counter = Counter()
196 | sources = []
197 | for arg in args:
198 | if isinstance(arg, Dataset):
199 | sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self]
200 | else:
201 | sources.append(arg)
202 |
203 | for data in sources:
204 | for x in data:
205 | x = self.preprocess(x)
206 | try:
207 | counter.update(x)
208 | except TypeError:
209 | counter.update(chain.from_iterable(x))
210 |
211 | specials = list(OrderedDict.fromkeys([
212 | tok for tok in [self.unk_token, self.pad_token, self.init_token,
213 | self.eos_token]
214 | if tok is not None]))
215 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
216 |
217 | def pad(self, minibatch):
218 | """Pad a batch of examples using this field.
219 | Pads to self.fix_length if provided, otherwise pads to the length of
220 | the longest example in the batch. Prepends self.init_token and appends
221 | self.eos_token if those attributes are not None. Returns a tuple of the
222 | padded list and a list containing lengths of each example if
223 | `self.include_lengths` is `True`, else just
224 | returns the padded list.
225 | """
226 | minibatch = list(minibatch)
227 | if self.fix_length is None:
228 | max_len = max(len(x) for x in minibatch)
229 | else:
230 | max_len = self.fix_length + (
231 | self.init_token, self.eos_token).count(None) - 2
232 | padded, lengths = [], []
233 | for x in minibatch:
234 | if self.pad_first:
235 | padded.append(
236 | [self.pad_token] * max(0, max_len - len(x)) +
237 | ([] if self.init_token is None else [self.init_token]) +
238 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
239 | ([] if self.eos_token is None else [self.eos_token]))
240 | else:
241 | padded.append(
242 | ([] if self.init_token is None else [self.init_token]) +
243 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
244 | ([] if self.eos_token is None else [self.eos_token]) +
245 | [self.pad_token] * max(0, max_len - len(x)))
246 | lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
247 | if self.include_lengths:
248 | return padded, lengths
249 | return padded
250 |
251 | def numericalize(self, arr, device=None):
252 | """Turn a batch of examples that use this field into a list of Variables.
253 | If the field has include_lengths=True, a tensor of lengths will be
254 | included in the return value.
255 | Arguments:
256 | arr (List[List[str]], or tuple of (List[List[str]], List[int])):
257 | List of tokenized and padded examples, or tuple of List of
258 | tokenized and padded examples and List of lengths of each
259 | example if self.include_lengths is True.
260 | device (str or torch.device): A string or instance of `torch.device`
261 | specifying which device the Variables are going to be created on.
262 | If left as default, the tensors will be created on cpu. Default: None.
263 | """
264 | if self.include_lengths and not isinstance(arr, tuple):
265 | raise ValueError("Field has include_lengths set to True, but "
266 | "input data is not a tuple of "
267 | "(data batch, batch lengths).")
268 | if isinstance(arr, tuple):
269 | arr, lengths = arr
270 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
271 |
272 | if self.use_vocab:
273 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
274 |
275 | if self.postprocessing is not None:
276 | arr = self.postprocessing(arr, self.vocab)
277 |
278 | var = torch.tensor(arr, dtype=self.dtype, device=device)
279 | else:
280 | if self.vectors:
281 | arr = [[self.vectors[x] for x in ex] for ex in arr]
282 | if self.dtype not in self.dtypes:
283 | raise ValueError(
284 | "Specified Field dtype {} can not be used with "
285 | "use_vocab=False because we do not know how to numericalize it. "
286 | "Please raise an issue at "
287 | "https://github.com/pytorch/text/issues".format(self.dtype))
288 | numericalization_func = self.dtypes[self.dtype]
289 | # It doesn't make sense to explicitly coerce to a numeric type if
290 | # the data is sequential, since it's unclear how to coerce padding tokens
291 | # to a numeric type.
292 | arr = [numericalization_func(x) if isinstance(x, six.string_types)
293 | else x for x in arr]
294 |
295 | if self.postprocessing is not None:
296 | arr = self.postprocessing(arr, None)
297 |
298 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr])
299 |
300 | # var = torch.tensor(arr, dtype=self.dtype, device=device)
301 | if not self.batch_first:
302 | var.t_()
303 | var = var.contiguous()
304 |
305 | if self.include_lengths:
306 | return var, lengths
307 | return var
308 |
309 | def decode(self, word_idxs, join_words=True):
310 | if isinstance(word_idxs, list) and len(word_idxs) == 0:
311 | return self.decode([word_idxs, ], join_words)[0]
312 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int):
313 | return self.decode([word_idxs, ], join_words)[0]
314 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1:
315 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0]
316 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
317 | return self.decode(word_idxs.unsqueeze(0), join_words)[0]
318 |
319 | captions = []
320 | for wis in word_idxs:
321 | caption = []
322 | for wi in wis:
323 | word = self.vocab.itos[int(wi)]
324 | if word == self.eos_token:
325 | break
326 | caption.append(word)
327 | if join_words:
328 | caption = ' '.join(caption)
329 | captions.append(caption)
330 | return captions
331 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib, sys
2 |
3 | class DummyFile(object):
4 | def write(self, x): pass
5 |
6 | @contextlib.contextmanager
7 | def nostdout():
8 | save_stdout = sys.stdout
9 | sys.stdout = DummyFile()
10 | yield
11 | sys.stdout = save_stdout
12 |
13 | def reporthook(t):
14 | """https://github.com/tqdm/tqdm"""
15 | last_b = [0]
16 |
17 | def inner(b=1, bsize=1, tsize=None):
18 | """
19 | b: int, optional
20 | Number of blocks just transferred [default: 1].
21 | bsize: int, optional
22 | Size of each block (in tqdm units) [default: 1].
23 | tsize: int, optional
24 | Total size (in tqdm units). If [default: None] remains unchanged.
25 | """
26 | if tsize is not None:
27 | t.total = tsize
28 | t.update((b - last_b[0]) * bsize)
29 | last_b[0] = b
30 | return inner
31 |
32 | def get_tokenizer(tokenizer):
33 | if callable(tokenizer):
34 | return tokenizer
35 | if tokenizer == "spacy":
36 | try:
37 | import spacy
38 | spacy_en = spacy.load('en')
39 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)]
40 | except ImportError:
41 | print("Please install SpaCy and the SpaCy English tokenizer. "
42 | "See the docs at https://spacy.io for more information.")
43 | raise
44 | except AttributeError:
45 | print("Please install SpaCy and the SpaCy English tokenizer. "
46 | "See the docs at https://spacy.io for more information.")
47 | raise
48 | elif tokenizer == "moses":
49 | try:
50 | from nltk.tokenize.moses import MosesTokenizer
51 | moses_tokenizer = MosesTokenizer()
52 | return moses_tokenizer.tokenize
53 | except ImportError:
54 | print("Please install NLTK. "
55 | "See the docs at http://nltk.org for more information.")
56 | raise
57 | except LookupError:
58 | print("Please install the necessary NLTK corpora. "
59 | "See the docs at http://nltk.org for more information.")
60 | raise
61 | elif tokenizer == 'revtok':
62 | try:
63 | import revtok
64 | return revtok.tokenize
65 | except ImportError:
66 | print("Please install revtok.")
67 | raise
68 | elif tokenizer == 'subword':
69 | try:
70 | import revtok
71 | return lambda x: revtok.tokenize(x, decap=True)
72 | except ImportError:
73 | print("Please install revtok.")
74 | raise
75 | raise ValueError("Requested tokenizer {}, valid choices are a "
76 | "callable that takes a single string as input, "
77 | "\"revtok\" for the revtok reversible tokenizer, "
78 | "\"subword\" for the revtok caps-aware tokenizer, "
79 | "\"spacy\" for the SpaCy English tokenizer, or "
80 | "\"moses\" for the NLTK port of the Moses tokenization "
81 | "script.".format(tokenizer))
82 |
--------------------------------------------------------------------------------
/data/vocab.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | import array
3 | from collections import defaultdict
4 | from functools import partial
5 | import io
6 | import logging
7 | import os
8 | import zipfile
9 |
10 | import six
11 | from six.moves.urllib.request import urlretrieve
12 | import torch
13 | from tqdm import tqdm
14 | import tarfile
15 |
16 | from .utils import reporthook
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Vocab(object):
22 | """Defines a vocabulary object that will be used to numericalize a field.
23 |
24 | Attributes:
25 | freqs: A collections.Counter object holding the frequencies of tokens
26 | in the data used to build the Vocab.
27 | stoi: A collections.defaultdict instance mapping token strings to
28 | numerical identifiers.
29 | itos: A list of token strings indexed by their numerical identifiers.
30 | """
31 | def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>'],
32 | vectors=None, unk_init=None, vectors_cache=None):
33 | """Create a Vocab object from a collections.Counter.
34 |
35 | Arguments:
36 | counter: collections.Counter object holding the frequencies of
37 | each value found in the data.
38 | max_size: The maximum size of the vocabulary, or None for no
39 | maximum. Default: None.
40 | min_freq: The minimum frequency needed to include a token in the
41 | vocabulary. Values less than 1 will be set to 1. Default: 1.
42 | specials: The list of special tokens (e.g., padding or eos) that
43 | will be prepended to the vocabulary in addition to an <unk>
44 | token. Default: ['<unk>']
45 | vectors: One of either the available pretrained vectors
46 | or custom pretrained vectors (see Vocab.load_vectors);
47 | or a list of aforementioned vectors
48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors
49 | to zero vectors; can be any function that takes in a Tensor and
50 | returns a Tensor of the same size. Default: torch.Tensor.zero_
51 | vectors_cache: directory for cached vectors. Default: '.vector_cache'
52 | """
53 | self.freqs = counter
54 | counter = counter.copy()
55 | min_freq = max(min_freq, 1)
56 |
57 | self.itos = list(specials)
58 | # frequencies of special tokens are not counted when building vocabulary
59 | # in frequency order
60 | for tok in specials:
61 | del counter[tok]
62 |
63 | max_size = None if max_size is None else max_size + len(self.itos)
64 |
65 | # sort by frequency, then alphabetically
66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
68 |
69 | for word, freq in words_and_frequencies:
70 | if freq < min_freq or len(self.itos) == max_size:
71 | break
72 | self.itos.append(word)
73 |
74 | self.stoi = defaultdict(_default_unk_index)
75 | # stoi is simply a reverse dict for itos
76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
77 |
78 | self.vectors = None
79 | if vectors is not None:
80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
81 | else:
82 | assert unk_init is None and vectors_cache is None
83 |
84 | def __eq__(self, other):
85 | if self.freqs != other.freqs:
86 | return False
87 | if self.stoi != other.stoi:
88 | return False
89 | if self.itos != other.itos:
90 | return False
91 | if self.vectors != other.vectors:
92 | return False
93 | return True
94 |
95 | def __len__(self):
96 | return len(self.itos)
97 |
98 | def extend(self, v, sort=False):
99 | words = sorted(v.itos) if sort else v.itos
100 | for w in words:
101 | if w not in self.stoi:
102 | self.itos.append(w)
103 | self.stoi[w] = len(self.itos) - 1
104 |
105 | def load_vectors(self, vectors, **kwargs):
106 | """
107 | Arguments:
108 | vectors: one of or a list containing instantiations of the
109 | GloVe, CharNGram, or Vectors classes. Alternatively, one
110 | of or a list of available pretrained vectors:
111 | charngram.100d
112 | fasttext.en.300d
113 | fasttext.simple.300d
114 | glove.42B.300d
115 | glove.840B.300d
116 | glove.twitter.27B.25d
117 | glove.twitter.27B.50d
118 | glove.twitter.27B.100d
119 | glove.twitter.27B.200d
120 | glove.6B.50d
121 | glove.6B.100d
122 | glove.6B.200d
123 | glove.6B.300d
124 | Remaining keyword arguments: Passed to the constructor of Vectors classes.
125 | """
126 | if not isinstance(vectors, list):
127 | vectors = [vectors]
128 | for idx, vector in enumerate(vectors):
129 | if six.PY2 and isinstance(vector, str):
130 | vector = six.text_type(vector)
131 | if isinstance(vector, six.string_types):
132 | # Convert the string pretrained vector identifier
133 | # to a Vectors object
134 | if vector not in pretrained_aliases:
135 | raise ValueError(
136 | "Got string input vector {}, but allowed pretrained "
137 | "vectors are {}".format(
138 | vector, list(pretrained_aliases.keys())))
139 | vectors[idx] = pretrained_aliases[vector](**kwargs)
140 | elif not isinstance(vector, Vectors):
141 | raise ValueError(
142 | "Got input vectors of type {}, expected str or "
143 | "Vectors object".format(type(vector)))
144 |
145 | tot_dim = sum(v.dim for v in vectors)
146 | self.vectors = torch.Tensor(len(self), tot_dim)
147 | for i, token in enumerate(self.itos):
148 | start_dim = 0
149 | for v in vectors:
150 | end_dim = start_dim + v.dim
151 | self.vectors[i][start_dim:end_dim] = v[token.strip()]
152 | start_dim = end_dim
153 | assert(start_dim == tot_dim)
154 |
155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
156 | """
157 | Set the vectors for the Vocab instance from a collection of Tensors.
158 |
159 | Arguments:
160 | stoi: A dictionary of string to the index of the associated vector
161 | in the `vectors` input argument.
162 | vectors: An indexed iterable (or other structure supporting __getitem__) that
163 | given an input index, returns a FloatTensor representing the vector
164 | for the token associated with the index. For example,
165 | vector[stoi["string"]] should return the vector for "string".
166 | dim: The dimensionality of the vectors.
167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors
168 | to zero vectors; can be any function that takes in a Tensor and
169 | returns a Tensor of the same size. Default: torch.Tensor.zero_
170 | """
171 | self.vectors = torch.Tensor(len(self), dim)
172 | for i, token in enumerate(self.itos):
173 | wv_index = stoi.get(token, None)
174 | if wv_index is not None:
175 | self.vectors[i] = vectors[wv_index]
176 | else:
177 | self.vectors[i] = unk_init(self.vectors[i])
178 |
179 |
180 | class Vectors(object):
181 |
182 | def __init__(self, name, cache=None,
183 | url=None, unk_init=None):
184 | """
185 | Arguments:
186 | name: name of the file that contains the vectors
187 | cache: directory for cached vectors
188 | url: url for download if vectors not found in cache
189 | unk_init (callback): by default, initialize out-of-vocabulary word vectors
190 | to zero vectors; can be any function that takes in a Tensor and
191 | returns a Tensor of the same size
192 | """
193 | cache = '.vector_cache' if cache is None else cache
194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init
195 | self.cache(name, cache, url=url)
196 |
197 | def __getitem__(self, token):
198 | if token in self.stoi:
199 | return self.vectors[self.stoi[token]]
200 | else:
201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim))
202 |
203 | def cache(self, name, cache, url=None):
204 | if os.path.isfile(name):
205 | path = name
206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt'
207 | else:
208 | path = os.path.join(cache, name)
209 | path_pt = path + '.pt'
210 |
211 | if not os.path.isfile(path_pt):
212 | if not os.path.isfile(path) and url:
213 | logger.info('Downloading vectors from {}'.format(url))
214 | if not os.path.exists(cache):
215 | os.makedirs(cache)
216 | dest = os.path.join(cache, os.path.basename(url))
217 | if not os.path.isfile(dest):
218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
219 | try:
220 | urlretrieve(url, dest, reporthook=reporthook(t))
221 | except KeyboardInterrupt as e: # remove the partial zip file
222 | os.remove(dest)
223 | raise e
224 | logger.info('Extracting vectors into {}'.format(cache))
225 | ext = os.path.splitext(dest)[1][1:]
226 | if ext == 'zip':
227 | with zipfile.ZipFile(dest, "r") as zf:
228 | zf.extractall(cache)
229 | elif ext == 'gz':
230 | with tarfile.open(dest, 'r:gz') as tar:
231 | tar.extractall(path=cache)
232 | if not os.path.isfile(path):
233 | raise RuntimeError('no vectors found at {}'.format(path))
234 |
235 | # str call is necessary for Python 2/3 compatibility, since
236 | # argument must be Python 2 str (Python 3 bytes) or
237 | # Python 3 str (Python 2 unicode)
238 | itos, vectors, dim = [], array.array(str('d')), None
239 |
240 | # Try to read the whole file with utf-8 encoding.
241 | binary_lines = False
242 | try:
243 | with io.open(path, encoding="utf8") as f:
244 | lines = [line for line in f]
245 | # If there are malformed lines, read in binary mode
246 | # and manually decode each word from utf-8
247 | except:
248 | logger.warning("Could not read {} as UTF8 file, "
249 | "reading file as bytes and skipping "
250 | "words with malformed UTF8.".format(path))
251 | with open(path, 'rb') as f:
252 | lines = [line for line in f]
253 | binary_lines = True
254 |
255 | logger.info("Loading vectors from {}".format(path))
256 | for line in tqdm(lines, total=len(lines)):
257 | # Explicitly splitting on " " is important, so we don't
258 | # get rid of Unicode non-breaking spaces in the vectors.
259 | entries = line.rstrip().split(b" " if binary_lines else " ")
260 |
261 | word, entries = entries[0], entries[1:]
262 | if dim is None and len(entries) > 1:
263 | dim = len(entries)
264 | elif len(entries) == 1:
265 | logger.warning("Skipping token {} with 1-dimensional "
266 | "vector {}; likely a header".format(word, entries))
267 | continue
268 | elif dim != len(entries):
269 | raise RuntimeError(
270 | "Vector for token {} has {} dimensions, but previously "
271 | "read vectors have {} dimensions. All vectors must have "
272 | "the same number of dimensions.".format(word, len(entries), dim))
273 |
274 | if binary_lines:
275 | try:
276 | if isinstance(word, six.binary_type):
277 | word = word.decode('utf-8')
278 | except:
279 | logger.info("Skipping non-UTF8 token {}".format(repr(word)))
280 | continue
281 | vectors.extend(float(x) for x in entries)
282 | itos.append(word)
283 |
284 | self.itos = itos
285 | self.stoi = {word: i for i, word in enumerate(itos)}
286 | self.vectors = torch.Tensor(vectors).view(-1, dim)
287 | self.dim = dim
288 | logger.info('Saving vectors to {}'.format(path_pt))
289 | if not os.path.exists(cache):
290 | os.makedirs(cache)
291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)
292 | else:
293 | logger.info('Loading vectors from {}'.format(path_pt))
294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)
295 |
296 |
297 | class GloVe(Vectors):
298 | url = {
299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip',
302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip',
303 | }
304 |
305 | def __init__(self, name='840B', dim=300, **kwargs):
306 | url = self.url[name]
307 | name = 'glove.{}.{}d.txt'.format(name, str(dim))
308 | super(GloVe, self).__init__(name, url=url, **kwargs)
309 |
310 |
311 | class FastText(Vectors):
312 |
313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec'
314 |
315 | def __init__(self, language="en", **kwargs):
316 | url = self.url_base.format(language)
317 | name = os.path.basename(url)
318 | super(FastText, self).__init__(name, url=url, **kwargs)
319 |
320 |
321 | class CharNGram(Vectors):
322 |
323 | name = 'charNgram.txt'
324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
325 | 'jmt_pre-trained_embeddings.tar.gz')
326 |
327 | def __init__(self, **kwargs):
328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs)
329 |
330 | def __getitem__(self, token):
331 | vector = torch.Tensor(1, self.dim).zero_()
332 | if token == "":
333 | return self.unk_init(vector)
334 | # These literals need to be coerced to unicode for Python 2 compatibility
335 | # when we try to join them with read ngrams from the files.
336 | chars = ['#BEGIN#'] + list(token) + ['#END#']
337 | num_vectors = 0
338 | for n in [2, 3, 4]:
339 | end = len(chars) - n + 1
340 | grams = [chars[i:(i + n)] for i in range(end)]
341 | for gram in grams:
342 | gram_key = '{}gram-{}'.format(n, ''.join(gram))
343 | if gram_key in self.stoi:
344 | vector += self.vectors[self.stoi[gram_key]]
345 | num_vectors += 1
346 | if num_vectors > 0:
347 | vector /= num_vectors
348 | else:
349 | vector = self.unk_init(vector)
350 | return vector
351 |
352 |
353 | def _default_unk_index():
354 | return 0
355 |
356 |
357 | pretrained_aliases = {
358 | "charngram.100d": partial(CharNGram),
359 | "fasttext.en.300d": partial(FastText, language="en"),
360 | "fasttext.simple.300d": partial(FastText, language="simple"),
361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"),
362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"),
363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"),
364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"),
365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"),
366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"),
367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"),
368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"),
369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"),
370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300")
371 | }
372 | """Mapping from string name to factory function"""
373 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: m2release
2 | channels:
3 | - anaconda
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - asn1crypto=1.2.0=py36_0
8 | - blas=1.0=mkl
9 | - ca-certificates=2019.10.16=0
10 | - certifi=2019.9.11=py36_0
11 | - cffi=1.13.2=py36h2e261b9_0
12 | - chardet=3.0.4=py36_1003
13 | - cryptography=2.8=py36h1ba5d50_0
14 | - cython=0.29.14=py36he6710b0_0
15 | - dill=0.2.9=py36_0
16 | - idna=2.8=py36_0
17 | - intel-openmp=2019.5=281
18 | - libedit=3.1.20181209=hc058e9b_0
19 | - libffi=3.2.1=hd88cf55_4
20 | - libgcc-ng=9.1.0=hdf63c60_0
21 | - libgfortran-ng=7.3.0=hdf63c60_0
22 | - libstdcxx-ng=9.1.0=hdf63c60_0
23 | - mkl=2019.5=281
24 | - mkl-service=2.3.0=py36he904b0f_0
25 | - mkl_fft=1.0.15=py36ha843d7b_0
26 | - mkl_random=1.1.0=py36hd6b4f25_0
27 | - msgpack-numpy=0.4.4.3=py_0
28 | - msgpack-python=0.5.6=py36h6bb024c_1
29 | - ncurses=6.1=he6710b0_1
30 | - openjdk=8.0.152=h46b5887_1
31 | - openssl=1.1.1=h7b6447c_0
32 | - pip=19.3.1=py36_0
33 | - pycparser=2.19=py_0
34 | - pyopenssl=19.1.0=py36_0
35 | - pysocks=1.7.1=py36_0
36 | - python=3.6.9=h265db76_0
37 | - readline=7.0=h7b6447c_5
38 | - requests=2.22.0=py36_0
39 | - setuptools=41.6.0=py36_0
40 | - six=1.13.0=py36_0
41 | - spacy=2.0.11=py36h04863e7_2
42 | - sqlite=3.30.1=h7b6447c_0
43 | - termcolor=1.1.0=py36_1
44 | - thinc=6.11.2=py36hedc7406_1
45 | - tk=8.6.8=hbc83047_0
46 | - toolz=0.10.0=py_0
47 | - urllib3=1.24.2=py36_0
48 | - wheel=0.33.6=py36_0
49 | - xz=5.2.4=h14c3975_4
50 | - zlib=1.2.11=h7b6447c_3
51 | - pip:
52 | - absl-py==0.8.1
53 | - cycler==0.10.0
54 | - cymem==1.31.2
55 | - cytoolz==0.9.0.1
56 | - future==0.17.1
57 | - grpcio==1.25.0
58 | - h5py==2.8.0
59 | - kiwisolver==1.1.0
60 | - markdown==3.1.1
61 | - matplotlib==2.2.3
62 | - msgpack==0.6.2
63 | - multiprocess==0.70.9
64 | - murmurhash==0.28.0
65 | - numpy==1.16.4
66 | - pathlib==1.0.1
67 | - pathos==0.2.3
68 | - pillow==6.2.1
69 | - plac==0.9.6
70 | - pox==0.2.7
71 | - ppft==1.6.6.1
72 | - preshed==1.0.1
73 | - protobuf==3.10.0
74 | - pycocotools==2.0.0
75 | - pyparsing==2.4.5
76 | - python-dateutil==2.8.1
77 | - pytz==2019.3
78 | - regex==2017.4.5
79 | - tensorboard==1.14.0
80 | - torch==1.1.0
81 | - torchvision==0.3.0
82 | - tqdm==4.32.2
83 | - ujson==1.35
84 | - werkzeug==0.16.0
85 | - wrapt==1.10.11
86 | - transformers
87 |
88 |
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
2 | from .meteor import Meteor
3 | from .rouge import Rouge
4 | from .cider import Cider
5 | from .tokenizer import PTBTokenizer
6 |
7 | def compute_scores(gts, gen):
8 | metrics = (Bleu(), Meteor(), Rouge(), Cider())
9 | all_score = {}
10 | all_scores = {}
11 | for metric in metrics:
12 | score, scores = metric.compute_score(gts, gen)
13 | all_score[str(metric)] = score
14 | all_scores[str(metric)] = scores
15 |
16 | return all_score, all_scores
17 |
--------------------------------------------------------------------------------
/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | from .bleu import Bleu
--------------------------------------------------------------------------------
/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from .bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # by default, compute BLEU score up to 4-grams
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | # score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
41 | # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
42 |
43 | return score, scores
44 |
45 | def __str__(self):
46 | return 'BLEU'
47 |
--------------------------------------------------------------------------------
/evaluation/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | ''' Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 |
24 | def precook(s, n=4, out=False):
25 | """Takes a string as input and returns an object that can be given to
26 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
27 | can take string arguments as well."""
28 | words = s.split()
29 | counts = defaultdict(int)
30 | for k in range(1, n + 1):
31 | for i in range(len(words) - k + 1):
32 | ngram = tuple(words[i:i + k])
33 | counts[ngram] += 1
34 | return (len(words), counts)
35 |
36 |
37 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
38 | '''Takes a list of reference sentences for a single segment
39 | and returns an object that encapsulates everything that BLEU
40 | needs to know about them.'''
41 |
42 | reflen = []
43 | maxcounts = {}
44 | for ref in refs:
45 | rl, counts = precook(ref, n)
46 | reflen.append(rl)
47 | for (ngram, count) in counts.items():
48 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
49 |
50 | # Calculate effective reference sentence length.
51 | if eff == "shortest":
52 | reflen = min(reflen)
53 | elif eff == "average":
54 | reflen = float(sum(reflen)) / len(reflen)
55 |
56 |     ## lhuang: N.B.: leave reflen computation to the very end!!
57 |
58 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
59 |
60 | return (reflen, maxcounts)
61 |
62 |
63 | def cook_test(test, ref_tuple, eff=None, n=4):
64 | '''Takes a test sentence and returns an object that
65 | encapsulates everything that BLEU needs to know about it.'''
66 |
67 | testlen, counts = precook(test, n, True)
68 | reflen, refmaxcounts = ref_tuple
69 |
70 | result = {}
71 |
72 | # Calculate effective reference sentence length.
73 |
74 | if eff == "closest":
75 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
76 | else: ## i.e., "average" or "shortest" or None
77 | result["reflen"] = reflen
78 |
79 | result["testlen"] = testlen
80 |
81 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]
82 |
83 | result['correct'] = [0] * n
84 | for (ngram, count) in counts.items():
85 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
86 |
87 | return result
88 |
89 |
90 | class BleuScorer(object):
91 | """Bleu scorer.
92 | """
93 |
94 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
95 |
96 | # special_reflen is used in oracle (proportional effective ref len for a node).
97 |
98 | def copy(self):
99 | ''' copy the refs.'''
100 | new = BleuScorer(n=self.n)
101 | new.ctest = copy.copy(self.ctest)
102 | new.crefs = copy.copy(self.crefs)
103 | new._score = None
104 | return new
105 |
106 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
107 | ''' singular instance '''
108 |
109 | self.n = n
110 | self.crefs = []
111 | self.ctest = []
112 | self.cook_append(test, refs)
113 | self.special_reflen = special_reflen
114 |
115 | def cook_append(self, test, refs):
116 | '''called by constructor and __iadd__ to avoid creating new instances.'''
117 |
118 | if refs is not None:
119 | self.crefs.append(cook_refs(refs))
120 | if test is not None:
121 | cooked_test = cook_test(test, self.crefs[-1])
122 | self.ctest.append(cooked_test) ## N.B.: -1
123 | else:
124 | self.ctest.append(None) # lens of crefs and ctest have to match
125 |
126 | self._score = None ## need to recompute
127 |
128 | def ratio(self, option=None):
129 | self.compute_score(option=option)
130 | return self._ratio
131 |
132 | def score_ratio(self, option=None):
133 | '''
134 | return (bleu, len_ratio) pair
135 | '''
136 |
137 | return self.fscore(option=option), self.ratio(option=option)
138 |
139 | def score_ratio_str(self, option=None):
140 | return "%.4f (%.2f)" % self.score_ratio(option)
141 |
142 | def reflen(self, option=None):
143 | self.compute_score(option=option)
144 | return self._reflen
145 |
146 | def testlen(self, option=None):
147 | self.compute_score(option=option)
148 | return self._testlen
149 |
150 | def retest(self, new_test):
151 | if type(new_test) is str:
152 | new_test = [new_test]
153 | assert len(new_test) == len(self.crefs), new_test
154 | self.ctest = []
155 | for t, rs in zip(new_test, self.crefs):
156 | self.ctest.append(cook_test(t, rs))
157 | self._score = None
158 |
159 | return self
160 |
161 | def rescore(self, new_test):
162 | ''' replace test(s) with new test(s), and returns the new score.'''
163 |
164 | return self.retest(new_test).compute_score()
165 |
166 | def size(self):
167 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
168 | return len(self.crefs)
169 |
170 | def __iadd__(self, other):
171 | '''add an instance (e.g., from another sentence).'''
172 |
173 | if type(other) is tuple:
174 | ## avoid creating new BleuScorer instances
175 | self.cook_append(other[0], other[1])
176 | else:
177 | assert self.compatible(other), "incompatible BLEUs."
178 | self.ctest.extend(other.ctest)
179 | self.crefs.extend(other.crefs)
180 | self._score = None ## need to recompute
181 |
182 | return self
183 |
184 | def compatible(self, other):
185 | return isinstance(other, BleuScorer) and self.n == other.n
186 |
187 | def single_reflen(self, option="average"):
188 | return self._single_reflen(self.crefs[0][0], option)
189 |
190 | def _single_reflen(self, reflens, option=None, testlen=None):
191 |
192 | if option == "shortest":
193 | reflen = min(reflens)
194 | elif option == "average":
195 | reflen = float(sum(reflens)) / len(reflens)
196 | elif option == "closest":
197 | reflen = min((abs(l - testlen), l) for l in reflens)[1]
198 | else:
199 | assert False, "unsupported reflen option %s" % option
200 |
201 | return reflen
202 |
203 | def recompute_score(self, option=None, verbose=0):
204 | self._score = None
205 | return self.compute_score(option, verbose)
206 |
207 | def compute_score(self, option=None, verbose=0):
208 | n = self.n
209 | small = 1e-9
210 | tiny = 1e-15 ## so that if guess is 0 still return 0
211 | bleu_list = [[] for _ in range(n)]
212 |
213 | if self._score is not None:
214 | return self._score
215 |
216 | if option is None:
217 | option = "average" if len(self.crefs) == 1 else "closest"
218 |
219 | self._testlen = 0
220 | self._reflen = 0
221 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
222 |
223 | # for each sentence
224 | for comps in self.ctest:
225 | testlen = comps['testlen']
226 | self._testlen += testlen
227 |
228 | if self.special_reflen is None: ## need computation
229 | reflen = self._single_reflen(comps['reflen'], option, testlen)
230 | else:
231 | reflen = self.special_reflen
232 |
233 | self._reflen += reflen
234 |
235 | for key in ['guess', 'correct']:
236 | for k in range(n):
237 | totalcomps[key][k] += comps[key][k]
238 |
239 | # append per image bleu score
240 | bleu = 1.
241 | for k in range(n):
242 | bleu *= (float(comps['correct'][k]) + tiny) \
243 | / (float(comps['guess'][k]) + small)
244 | bleu_list[k].append(bleu ** (1. / (k + 1)))
245 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
246 | if ratio < 1:
247 | for k in range(n):
248 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio)
249 |
250 | if verbose > 1:
251 | print(comps, reflen)
252 |
253 | totalcomps['reflen'] = self._reflen
254 | totalcomps['testlen'] = self._testlen
255 |
256 | bleus = []
257 | bleu = 1.
258 | for k in range(n):
259 | bleu *= float(totalcomps['correct'][k] + tiny) \
260 | / (totalcomps['guess'][k] + small)
261 | bleus.append(bleu ** (1. / (k + 1)))
262 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
263 | if ratio < 1:
264 | for k in range(n):
265 | bleus[k] *= math.exp(1 - 1 / ratio)
266 |
267 | if verbose > 0:
268 | print(totalcomps)
269 | print("ratio:", ratio)
270 |
271 | self._score = bleus
272 | return self._score, bleu_list
273 |
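For intuition, a short sketch of the 'cooked' representation and of the accumulation interface used above; the toy sentences are illustrative:

from evaluation.bleu.bleu_scorer import precook, BleuScorer

length, counts = precook('the cat sat', n=2)
# length == 3 and counts holds the n-gram counts:
# {('the',): 1, ('cat',): 1, ('sat',): 1, ('the', 'cat'): 1, ('cat', 'sat'): 1}

scorer = BleuScorer(n=4)
scorer += ('the cat sat on the mat', ['the cat sat on a mat'])   # (hypothesis, references)
score, per_sentence = scorer.compute_score(option='closest')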
--------------------------------------------------------------------------------
/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | from .cider import Cider
--------------------------------------------------------------------------------
/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from .cider_scorer import CiderScorer
11 |
12 | class Cider:
13 | """
14 | Main Class to compute the CIDEr metric
15 |
16 | """
17 | def __init__(self, gts=None, n=4, sigma=6.0):
18 | # set cider to sum over 1 to 4-grams
19 | self._n = n
20 | # set the standard deviation parameter for gaussian penalty
21 | self._sigma = sigma
22 | self.doc_frequency = None
23 | self.ref_len = None
24 | if gts is not None:
25 | tmp_cider = CiderScorer(gts, n=self._n, sigma=self._sigma)
26 | self.doc_frequency = tmp_cider.doc_frequency
27 | self.ref_len = tmp_cider.ref_len
28 |
29 | def compute_score(self, gts, res):
30 | """
31 | Main function to compute CIDEr score
32 | :param gts (dict) : dictionary with key and value
33 | res (dict) : dictionary with key and value
34 | :return: cider (float) : computed CIDEr score for the corpus
35 | """
36 | assert(gts.keys() == res.keys())
37 | cider_scorer = CiderScorer(gts, test=res, n=self._n, sigma=self._sigma, doc_frequency=self.doc_frequency,
38 | ref_len=self.ref_len)
39 | return cider_scorer.compute_score()
40 |
41 | def __str__(self):
42 | return 'CIDEr'
43 |
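A usage sketch with illustrative data: passing gts to the constructor freezes the document frequencies and reference length once, so repeated compute_score calls (e.g. when CIDEr is used as a reinforcement-learning reward) reuse the same idf statistics:

from evaluation.cider import Cider

refs = {0: ['a man riding a horse', 'a person on a horse'],
        1: ['two dogs running in the grass', 'dogs run on the grass']}
cider = Cider(gts=refs)                      # precompute idf statistics from refs

gen = {0: ['a man rides a horse'], 1: ['two dogs run in the grass']}
corpus_score, per_image = cider.compute_score(refs, gen)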
--------------------------------------------------------------------------------
/evaluation/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import math
9 |
10 | def precook(s, n=4):
11 | """
12 | Takes a string as input and returns an object that can be given to
13 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
14 | can take string arguments as well.
15 | :param s: string : sentence to be converted into ngrams
16 | :param n: int : number of ngrams for which representation is calculated
17 |     :return: term frequency vector for occurring ngrams
18 | """
19 | words = s.split()
20 | counts = defaultdict(int)
21 | for k in range(1,n+1):
22 | for i in range(len(words)-k+1):
23 | ngram = tuple(words[i:i+k])
24 | counts[ngram] += 1
25 | return counts
26 |
27 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
28 | '''Takes a list of reference sentences for a single segment
29 |     and returns an object that encapsulates everything that CIDEr
30 | needs to know about them.
31 | :param refs: list of string : reference sentences for some image
32 | :param n: int : number of ngrams for which (ngram) representation is calculated
33 | :return: result (list of dict)
34 | '''
35 | return [precook(ref, n) for ref in refs]
36 |
37 | def cook_test(test, n=4):
38 | '''Takes a test sentence and returns an object that
39 |     encapsulates everything that CIDEr needs to know about it.
40 |     :param test: string : hypothesis sentence for some image
41 | :param n: int : number of ngrams for which (ngram) representation is calculated
42 | :return: result (dict)
43 | '''
44 | return precook(test, n)
45 |
46 | class CiderScorer(object):
47 | """CIDEr scorer.
48 | """
49 |
50 | def __init__(self, refs, test=None, n=4, sigma=6.0, doc_frequency=None, ref_len=None):
51 | ''' singular instance '''
52 | self.n = n
53 | self.sigma = sigma
54 | self.crefs = []
55 | self.ctest = []
56 | self.doc_frequency = defaultdict(float)
57 | self.ref_len = None
58 |
59 | for k in refs.keys():
60 | self.crefs.append(cook_refs(refs[k]))
61 | if test is not None:
62 | self.ctest.append(cook_test(test[k][0])) ## N.B.: -1
63 | else:
64 | self.ctest.append(None) # lens of crefs and ctest have to match
65 |
66 | if doc_frequency is None and ref_len is None:
67 | # compute idf
68 | self.compute_doc_freq()
69 | # compute log reference length
70 | self.ref_len = np.log(float(len(self.crefs)))
71 | else:
72 | self.doc_frequency = doc_frequency
73 | self.ref_len = ref_len
74 |
75 | def compute_doc_freq(self):
76 | '''
77 |         Compute document frequency for the reference data.
78 |         This will be used to compute idf (inverse document frequency) later.
79 |         The document frequency is stored in the object
80 | :return: None
81 | '''
82 | for refs in self.crefs:
83 | # refs, k ref captions of one image
84 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
85 | self.doc_frequency[ngram] += 1
86 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
87 |
88 | def compute_cider(self):
89 | def counts2vec(cnts):
90 | """
91 | Function maps counts of ngram to vector of tfidf weights.
92 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
93 | The n-th entry of array denotes length of n-grams.
94 | :param cnts:
95 | :return: vec (array of dict), norm (array of float), length (int)
96 | """
97 | vec = [defaultdict(float) for _ in range(self.n)]
98 | length = 0
99 | norm = [0.0 for _ in range(self.n)]
100 | for (ngram,term_freq) in cnts.items():
101 |                 # assign a document frequency of 1 if the ngram doesn't appear in the reference corpus
102 | df = np.log(max(1.0, self.doc_frequency[ngram]))
103 | # ngram index
104 | n = len(ngram)-1
105 | # tf (term_freq) * idf (precomputed idf) for n-grams
106 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
107 | # compute norm for the vector. the norm will be used for computing similarity
108 | norm[n] += pow(vec[n][ngram], 2)
109 |
110 | if n == 1:
111 | length += term_freq
112 | norm = [np.sqrt(n) for n in norm]
113 | return vec, norm, length
114 |
115 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
116 | '''
117 | Compute the cosine similarity of two vectors.
118 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
119 | :param vec_ref: array of dictionary for vector corresponding to reference
120 | :param norm_hyp: array of float for vector corresponding to hypothesis
121 | :param norm_ref: array of float for vector corresponding to reference
122 | :param length_hyp: int containing length of hypothesis
123 | :param length_ref: int containing length of reference
124 | :return: array of score for each n-grams cosine similarity
125 | '''
126 | delta = float(length_hyp - length_ref)
127 |             # measure cosine similarity
128 | val = np.array([0.0 for _ in range(self.n)])
129 | for n in range(self.n):
130 | # ngram
131 | for (ngram,count) in vec_hyp[n].items():
132 | # vrama91 : added clipping
133 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
134 |
135 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
136 | val[n] /= (norm_hyp[n]*norm_ref[n])
137 |
138 | assert(not math.isnan(val[n]))
139 | # vrama91: added a length based gaussian penalty
140 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
141 | return val
142 |
143 | scores = []
144 | for test, refs in zip(self.ctest, self.crefs):
145 | # compute vector for test captions
146 | vec, norm, length = counts2vec(test)
147 | # compute vector for ref captions
148 | score = np.array([0.0 for _ in range(self.n)])
149 | for ref in refs:
150 | vec_ref, norm_ref, length_ref = counts2vec(ref)
151 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
152 | # change by vrama91 - mean of ngram scores, instead of sum
153 | score_avg = np.mean(score)
154 | # divide by number of references
155 | score_avg /= len(refs)
156 | # multiply score by 10
157 | score_avg *= 10.0
158 | # append score of an image to the score list
159 | scores.append(score_avg)
160 | return scores
161 |
162 | def compute_score(self):
163 | # compute cider score
164 | score = self.compute_cider()
165 | # debug
166 | # print score
167 | return np.mean(np.array(score)), np.array(score)
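For reference, the tf-idf weight built in counts2vec above can be written out explicitly; a tiny sketch of that computation for one n-gram, with illustrative numbers for the corpus size and document frequency:

import numpy as np

num_images = 5000     # size of the reference set, i.e. len(self.crefs)
df = 12.0             # number of images whose references contain the n-gram
term_freq = 2         # count of the n-gram in the candidate caption

ref_len = np.log(float(num_images))
weight = float(term_freq) * (ref_len - np.log(max(1.0, df)))   # tf * idf, as in counts2vec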
--------------------------------------------------------------------------------
/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | from .meteor import Meteor
--------------------------------------------------------------------------------
/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | # Python wrapper for METEOR implementation, by Xinlei Chen
2 | # Acknowledge Michael Denkowski for the generous discussion and help
3 |
4 | import os
5 | import subprocess
6 | import threading
7 | import tarfile
8 | from utils import download_from_url
9 |
10 | METEOR_GZ_URL = 'http://aimagelab.ing.unimore.it/speaksee/data/meteor.tgz'
11 | METEOR_JAR = 'meteor-1.5.jar'
12 |
13 | class Meteor:
14 | def __init__(self):
15 | base_path = os.path.dirname(os.path.abspath(__file__))
16 | jar_path = os.path.join(base_path, METEOR_JAR)
17 | gz_path = os.path.join(base_path, os.path.basename(METEOR_GZ_URL))
18 | if not os.path.isfile(jar_path):
19 | if not os.path.isfile(gz_path):
20 | download_from_url(METEOR_GZ_URL, gz_path)
21 | tar = tarfile.open(gz_path, "r")
22 | tar.extractall(path=os.path.dirname(os.path.abspath(__file__)))
23 | tar.close()
24 | os.remove(gz_path)
25 |
26 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
27 | '-', '-', '-stdio', '-l', 'en', '-norm']
28 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
29 | cwd=os.path.dirname(os.path.abspath(__file__)), \
30 | stdin=subprocess.PIPE, \
31 | stdout=subprocess.PIPE, \
32 | stderr=subprocess.PIPE)
33 | # Used to guarantee thread safety
34 | self.lock = threading.Lock()
35 |
36 | def compute_score(self, gts, res):
37 | assert(gts.keys() == res.keys())
38 | imgIds = gts.keys()
39 | scores = []
40 |
41 | eval_line = 'EVAL'
42 | self.lock.acquire()
43 | for i in imgIds:
44 | assert(len(res[i]) == 1)
45 | stat = self._stat(res[i][0], gts[i])
46 | eval_line += ' ||| {}'.format(stat)
47 |
48 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode())
49 | self.meteor_p.stdin.flush()
50 | for i in range(0,len(imgIds)):
51 | scores.append(float(self.meteor_p.stdout.readline().strip()))
52 | score = float(self.meteor_p.stdout.readline().strip())
53 | self.lock.release()
54 |
55 | return score, scores
56 |
57 | def _stat(self, hypothesis_str, reference_list):
58 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
59 |         hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
60 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
61 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode())
62 | self.meteor_p.stdin.flush()
63 | raw = self.meteor_p.stdout.readline().decode().strip()
64 | numbers = [str(int(float(n))) for n in raw.split()]
65 | return ' '.join(numbers)
66 |
67 | def __del__(self):
68 | self.lock.acquire()
69 | self.meteor_p.stdin.close()
70 | self.meteor_p.kill()
71 | self.meteor_p.wait()
72 | self.lock.release()
73 |
74 | def __str__(self):
75 | return 'METEOR'
76 |
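A usage sketch, assuming Java is on the PATH so the METEOR process can be launched (the jar is downloaded next to this file on first use); the captions are illustrative:

from evaluation.meteor import Meteor

gts = {0: ['a cat sits on a mat', 'there is a cat on the mat']}
res = {0: ['a cat is sitting on a mat']}

score, scores = Meteor().compute_score(gts, res)   # corpus score and per-image scores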
--------------------------------------------------------------------------------
/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | from .rouge import Rouge
--------------------------------------------------------------------------------
/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 |
14 | def my_lcs(string, sub):
15 | """
16 | Calculates longest common subsequence for a pair of tokenized strings
17 | :param string : list of str : tokens from a string split using whitespace
18 | :param sub : list of str : shorter string, also split using whitespace
19 |     :returns: length (int): length of the longest common subsequence between the two strings
20 |
21 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
22 | """
23 | if (len(string) < len(sub)):
24 | sub, string = string, sub
25 |
26 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
27 |
28 | for j in range(1, len(sub) + 1):
29 | for i in range(1, len(string) + 1):
30 | if (string[i - 1] == sub[j - 1]):
31 | lengths[i][j] = lengths[i - 1][j - 1] + 1
32 | else:
33 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
34 |
35 | return lengths[len(string)][len(sub)]
36 |
37 |
38 | class Rouge():
39 | '''
40 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
41 |
42 | '''
43 |
44 | def __init__(self):
45 |         # vrama91: updated the value below based on discussion with Hovy
46 | self.beta = 1.2
47 |
48 | def calc_score(self, candidate, refs):
49 | """
50 | Compute ROUGE-L score given one candidate and references for an image
51 | :param candidate: str : candidate sentence to be evaluated
52 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
53 |         :returns score: float (ROUGE-L score for the candidate evaluated against references)
54 | """
55 | assert (len(candidate) == 1)
56 | assert (len(refs) > 0)
57 | prec = []
58 | rec = []
59 |
60 | # split into tokens
61 | token_c = candidate[0].split(" ")
62 |
63 | for reference in refs:
64 | # split into tokens
65 | token_r = reference.split(" ")
66 | # compute the longest common subsequence
67 | lcs = my_lcs(token_r, token_c)
68 | prec.append(lcs / float(len(token_c)))
69 | rec.append(lcs / float(len(token_r)))
70 |
71 | prec_max = max(prec)
72 | rec_max = max(rec)
73 |
74 | if (prec_max != 0 and rec_max != 0):
75 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
76 | else:
77 | score = 0.0
78 | return score
79 |
80 | def compute_score(self, gts, res):
81 | """
82 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
83 | Invoked by evaluate_captions.py
84 |         :param gts: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
85 |         :param res: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
86 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
87 | """
88 | assert (gts.keys() == res.keys())
89 | imgIds = gts.keys()
90 |
91 | score = []
92 | for id in imgIds:
93 | hypo = res[id]
94 | ref = gts[id]
95 |
96 | score.append(self.calc_score(hypo, ref))
97 |
98 | # Sanity check.
99 | assert (type(hypo) is list)
100 | assert (len(hypo) == 1)
101 | assert (type(ref) is list)
102 | assert (len(ref) > 0)
103 |
104 | average_score = np.mean(np.array(score))
105 | return average_score, np.array(score)
106 |
107 | def __str__(self):
108 | return 'ROUGE'
109 |
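A small sketch of the pieces above: my_lcs works on token lists, and calc_score turns the best LCS-based precision/recall pair into the F-measure ((1 + beta^2) * P * R) / (R + beta^2 * P) with beta = 1.2; the sentences are illustrative:

from evaluation.rouge.rouge import my_lcs, Rouge

lcs_len = my_lcs('the cat sat on the mat'.split(), 'the cat on the mat'.split())
# lcs_len == 5: 'the cat on the mat' is a subsequence of the longer sentence

gts = {0: ['the cat sat on the mat']}
res = {0: ['the cat on the mat']}
score, scores = Rouge().compute_score(gts, res)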
--------------------------------------------------------------------------------
/evaluation/stanford-corenlp-3.4.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/evaluation/stanford-corenlp-3.4.1.jar
--------------------------------------------------------------------------------
/evaluation/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import subprocess
13 | import tempfile
14 |
15 | class PTBTokenizer(object):
16 | """Python wrapper of Stanford PTBTokenizer"""
17 |
18 | corenlp_jar = 'stanford-corenlp-3.4.1.jar'
19 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
20 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
21 |
22 | @classmethod
23 | def tokenize(cls, corpus):
24 | cmd = ['java', '-cp', cls.corenlp_jar, \
25 | 'edu.stanford.nlp.process.PTBTokenizer', \
26 | '-preserveLines', '-lowerCase']
27 |
28 | if isinstance(corpus, list) or isinstance(corpus, tuple):
29 | if isinstance(corpus[0], list) or isinstance(corpus[0], tuple):
30 | corpus = {i:c for i, c in enumerate(corpus)}
31 | else:
32 | corpus = {i: [c, ] for i, c in enumerate(corpus)}
33 |
34 | # prepare data for PTB Tokenizer
35 | tokenized_corpus = {}
36 | image_id = [k for k, v in list(corpus.items()) for _ in range(len(v))]
37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in corpus.items() for c in v])
38 |
39 | # save sentences to temporary file
40 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
41 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
42 | tmp_file.write(sentences.encode())
43 | tmp_file.close()
44 |
45 | # tokenize sentence
46 | cmd.append(os.path.basename(tmp_file.name))
47 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
48 | stdout=subprocess.PIPE, stderr=open(os.devnull, 'w'))
49 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
50 | token_lines = token_lines.decode()
51 | lines = token_lines.split('\n')
52 | # remove temp file
53 | os.remove(tmp_file.name)
54 |
55 | # create dictionary for tokenized captions
56 | for k, line in zip(image_id, lines):
57 | if not k in tokenized_corpus:
58 | tokenized_corpus[k] = []
59 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
60 | if w not in cls.punctuations])
61 | tokenized_corpus[k].append(tokenized_caption)
62 |
63 | return tokenized_corpus
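A usage sketch, assuming Java is available and the CoreNLP jar sits next to this file (as in this repository); the input may be a dict of caption lists or a plain list of strings:

from evaluation import PTBTokenizer

captions = {0: ['A man is riding a horse.'], 1: ['Two dogs play in the grass!']}
tokenized = PTBTokenizer.tokenize(captions)
# e.g. tokenized[0] == ['a man is riding a horse']   (lower-cased, punctuation removed)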
--------------------------------------------------------------------------------
/feats_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import h5py
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 |
9 | class DataProcessor(nn.Module):
10 | def __init__(self):
11 | super(DataProcessor, self).__init__()
12 | self.pool = nn.AdaptiveAvgPool2d((7, 7))
13 |
14 | def forward(self, x):
15 | x = self.pool(x)
16 | x = torch.squeeze(x) # [1, d, h, w] => [d, h, w]
17 | x = x.permute(1, 2, 0) # [d, h, w] => [h, w, d]
18 | return x.view(-1, x.size(-1)) # [h*w, d]
19 |
20 |
21 | def process_dataset(file_path, feat_paths):
22 |     print('convert the original grid features to features with the specified size')
23 |     # load the feature processor
24 | processor = DataProcessor()
25 | with h5py.File(file_path, 'w') as f:
26 | for i in tqdm(range(len(feat_paths))):
27 |             # load the raw feature
28 |             feat_path = feat_paths[i]
29 |             img_feat = torch.load(feat_path)
30 |             # process the feature
31 |             img_feat = processor(img_feat)
32 |             # save the processed feature
33 | img_name = feat_path.split('/')[-1]
34 | img_id = int(img_name.split('.')[0])
35 | f.create_dataset('%d_grids' % img_id, data=img_feat.numpy())
36 | f.close()
37 |
38 |
39 | def get_feat_paths(dir_to_save_feats, data_split='trainval', test2014_info_path=None):
40 | print('get the paths of raw grid features')
41 | ans = []
42 |     # offline training and testing
43 | if data_split == 'trainval':
44 | filenames_train = os.listdir(os.path.join(dir_to_save_feats, 'train2014'))
45 | ans_train = [os.path.join(dir_to_save_feats, 'train2014', filename) for filename in filenames_train]
46 | filenames_val = os.listdir(os.path.join(dir_to_save_feats, 'val2014'))
47 | ans_val = [os.path.join(dir_to_save_feats, 'val2014', filename) for filename in filenames_val]
48 | ans = ans_train + ans_val
49 |     # online testing
50 | elif data_split == 'test':
51 | assert test2014_info_path is not None
52 | with open(test2014_info_path, 'r') as f:
53 | test2014_info = json.load(f)
54 |
55 | for image in test2014_info['images']:
56 | img_id = image['id']
57 |             feat_path = os.path.join(dir_to_save_feats, 'test2015', str(img_id) + '.pth')
58 | assert os.path.exists(feat_path)
59 | ans.append(feat_path)
60 | assert len(ans) == 40775
61 |     assert len(ans) > 0  # make sure the ans list is not empty
62 | return ans
63 |
64 |
65 | def main(args):
66 |     # get the absolute paths of the raw features
67 |     feat_paths = get_feat_paths(args.dir_to_raw_feats, args.data_split, args.test2014_info_path)
68 |     # build the filename and save path for the processed features
69 |     file_path = os.path.join(args.dir_to_processed_feats, 'X101_grid_feats_coco_'+args.data_split+'.hdf5')
70 |     # process the features and save them
71 | process_dataset(file_path, feat_paths)
72 | print('finished!')
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 | parser = argparse.ArgumentParser(description='data process')
78 | parser.add_argument('--dir_to_raw_feats', type=str, default='./datasets/X101-features/')
79 | parser.add_argument('--dir_to_processed_feats', type=str, default='./datasets/X101-features/')
80 |     # trainval = train2014 + val2014, used for training and offline testing; test = test2014, used for online testing
81 | parser.add_argument('--data_split', type=str, default='trainval') # trainval, test
82 |     # test2015 contains test2014; to get test2014, first load the test2014 index and then load the features; image_info_test2014.json is the file that stores the test2014 info
83 | parser.add_argument('--test2014_info_path', type=str, default='./datasets/m2_annotations/image_info_test2014.json')
84 | args = parser.parse_args()
85 |
86 | main(args)
87 |
88 |
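Once this script has run, each image's 7x7 pooled grid features can be read back from the HDF5 file by image id; a minimal sketch (the path follows the defaults above, and the image id is illustrative):

import h5py

with h5py.File('./datasets/X101-features/X101_grid_feats_coco_trainval.hdf5', 'r') as f:
    grids = f['%d_grids' % 391895][()]   # shape (49, d): the 7x7 grid flattened by DataProcessor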
--------------------------------------------------------------------------------
/images/RSTNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/RSTNet.png
--------------------------------------------------------------------------------
/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/results.png
--------------------------------------------------------------------------------
/images/train_cider.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/train_cider.png
--------------------------------------------------------------------------------
/images/visualness.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/images/visualness.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import Transformer
2 | from .captioning_model import CaptioningModel
3 |
--------------------------------------------------------------------------------
/models/beam_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .beam_search import BeamSearch
2 |
--------------------------------------------------------------------------------
/models/beam_search/beam_search.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 |
4 |
5 | class BeamSearch(object):
6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
7 | self.model = model
8 | self.max_len = max_len
9 | self.eos_idx = eos_idx
10 | self.beam_size = beam_size
11 | self.b_s = None
12 | self.device = None
13 | self.seq_mask = None
14 | self.seq_logprob = None
15 | self.outputs = None
16 | self.log_probs = None
17 | self.selected_words = None
18 | self.all_log_probs = None
19 |
20 | def _expand_state(self, selected_beam, cur_beam_size):
21 | def fn(s):
22 | shape = [int(sh) for sh in s.shape]
23 | beam = selected_beam
24 | for _ in shape[1:]:
25 | beam = beam.unsqueeze(-1)
26 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
27 | beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
28 | s = s.view(*([-1, ] + shape[1:]))
29 | return s
30 |
31 | return fn
32 |
33 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
34 | if isinstance(visual, torch.Tensor):
35 | visual_shape = visual.shape
36 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
37 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
38 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
39 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
40 | visual_exp = visual.view(visual_exp_shape)
41 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
42 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
43 | else:
44 | new_visual = []
45 | for im in visual:
46 | visual_shape = im.shape
47 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
48 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
49 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
50 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
51 | visual_exp = im.view(visual_exp_shape)
52 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
53 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
54 | new_visual.append(new_im)
55 | visual = tuple(new_visual)
56 | return visual
57 |
58 | def apply(self, visual: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs):
59 | self.b_s = utils.get_batch_size(visual)
60 | self.device = utils.get_device(visual)
61 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
62 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
63 | self.log_probs = []
64 | self.selected_words = None
65 | if return_probs:
66 | self.all_log_probs = []
67 |
68 | outputs = []
69 | with self.model.statefulness(self.b_s):
70 | for t in range(self.max_len):
71 | visual, outputs = self.iter(t, visual, outputs, return_probs, **kwargs)
72 |
73 | # Sort result
74 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
75 | outputs = torch.cat(outputs, -1)
76 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
77 | log_probs = torch.cat(self.log_probs, -1)
78 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
79 | if return_probs:
80 | all_log_probs = torch.cat(self.all_log_probs, 2)
81 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
82 | self.max_len,
83 | all_log_probs.shape[-1]))
84 |
85 | outputs = outputs.contiguous()[:, :out_size]
86 | log_probs = log_probs.contiguous()[:, :out_size]
87 | if out_size == 1:
88 | outputs = outputs.squeeze(1)
89 | log_probs = log_probs.squeeze(1)
90 |
91 | if return_probs:
92 | return outputs, log_probs, all_log_probs
93 | else:
94 | return outputs, log_probs
95 |
96 | def select(self, t, candidate_logprob, **kwargs):
97 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
98 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
99 | return selected_idx, selected_logprob
100 |
101 | def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_probs, **kwargs):
102 | cur_beam_size = 1 if t == 0 else self.beam_size
103 |
104 | word_logprob = self.model.step(t, self.selected_words, visual, None, mode='feedback', **kwargs)
105 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1)
106 | candidate_logprob = self.seq_logprob + word_logprob
107 |
108 | # Mask sequence if it reaches EOS
109 | if t > 0:
110 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1)
111 | self.seq_mask = self.seq_mask * mask
112 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
113 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
114 | old_seq_logprob[:, :, 1:] = -999
115 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
116 |
117 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
118 |         selected_beam = selected_idx // candidate_logprob.shape[-1]  # integer beam index
119 |         selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
120 |
121 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
122 | visual = self._expand_visual(visual, cur_beam_size, selected_beam)
123 |
124 | self.seq_logprob = selected_logprob.unsqueeze(-1)
125 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
126 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
127 | outputs.append(selected_words.unsqueeze(-1))
128 |
129 | if return_probs:
130 | if t == 0:
131 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
132 | else:
133 | self.all_log_probs.append(word_logprob.unsqueeze(2))
134 |
135 | this_word_logprob = torch.gather(word_logprob, 1,
136 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
137 | word_logprob.shape[-1]))
138 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
139 | self.log_probs = list(
140 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
141 | self.log_probs.append(this_word_logprob)
142 | self.selected_words = selected_words.view(-1, 1)
143 |
144 | return visual, outputs
145 |
--------------------------------------------------------------------------------
/models/captioning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 | import utils
4 | from models.containers import Module
5 | from models.beam_search import *
6 |
7 |
8 | class CaptioningModel(Module):
9 | def __init__(self):
10 | super(CaptioningModel, self).__init__()
11 |
12 | def init_weights(self):
13 | raise NotImplementedError
14 |
15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
16 | raise NotImplementedError
17 |
18 | def forward(self, images, seq, *args):
19 | device = images.device
20 | b_s = images.size(0)
21 | seq_len = seq.size(1)
22 | state = self.init_state(b_s, device)
23 | out = None
24 |
25 | outputs = []
26 | for t in range(seq_len):
27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
28 | outputs.append(out)
29 |
30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1)
31 | return outputs
32 |
33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
34 | b_s = utils.get_batch_size(visual)
35 | device = utils.get_device(visual)
36 | outputs = []
37 | log_probs = []
38 |
39 | mask = torch.ones((b_s,), device=device)
40 | with self.statefulness(b_s):
41 | out = None
42 | for t in range(max_len):
43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs)
44 | out = torch.max(log_probs_t, -1)[1]
45 | mask = mask * (out.squeeze(-1) != eos_idx).float()
46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1))
47 | outputs.append(out)
48 |
49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
50 |
51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
52 | b_s = utils.get_batch_size(visual)
53 | outputs = []
54 | log_probs = []
55 |
56 | with self.statefulness(b_s):
57 | out = None
58 | for t in range(max_len):
59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs)
60 | distr = distributions.Categorical(logits=out[:, 0])
61 | out = distr.sample().unsqueeze(1)
62 | outputs.append(out)
63 | log_probs.append(distr.log_prob(out).unsqueeze(1))
64 |
65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
66 |
67 | def beam_search(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, beam_size: int, out_size=1,
68 | return_probs=False, **kwargs):
69 | bs = BeamSearch(self, max_len, eos_idx, beam_size)
70 | return bs.apply(visual, out_size, return_probs, **kwargs)
71 |
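A decoding sketch showing how these entry points are typically used at inference time; model stands for any trained CaptioningModel subclass, features for a batch of precomputed grid features, and the eos index is assumed to come from a vocabulary built as in the training scripts:

import torch

# model: a trained CaptioningModel (placeholder)
# features: grid features of shape (batch, regions, d_in), on the same device as model
# text_field: the field whose vocabulary was built during training (placeholder)
with torch.no_grad():
    out, log_probs = model.beam_search(features, max_len=20,
                                       eos_idx=text_field.vocab.stoi['<eos>'],
                                       beam_size=5, out_size=1)
# out: (batch, max_len) token indices; map them back through the vocabulary to get captions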
--------------------------------------------------------------------------------
/models/containers.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from torch import nn
3 | from utils.typing import *
4 |
5 |
6 | class Module(nn.Module):
7 | def __init__(self):
8 | super(Module, self).__init__()
9 | self._is_stateful = False
10 | self._state_names = []
11 | self._state_defaults = dict()
12 |
13 | def register_state(self, name: str, default: TensorOrNone):
14 | self._state_names.append(name)
15 | if default is None:
16 | self._state_defaults[name] = None
17 | else:
18 | self._state_defaults[name] = default.clone().detach()
19 | self.register_buffer(name, default)
20 |
21 | def states(self):
22 | for name in self._state_names:
23 | yield self._buffers[name]
24 | for m in self.children():
25 | if isinstance(m, Module):
26 | yield from m.states()
27 |
28 | def apply_to_states(self, fn):
29 | for name in self._state_names:
30 | self._buffers[name] = fn(self._buffers[name])
31 | for m in self.children():
32 | if isinstance(m, Module):
33 | m.apply_to_states(fn)
34 |
35 | def _init_states(self, batch_size: int):
36 | for name in self._state_names:
37 | if self._state_defaults[name] is None:
38 | self._buffers[name] = None
39 | else:
40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
41 | self._buffers[name] = self._buffers[name].unsqueeze(0)
42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
43 | self._buffers[name] = self._buffers[name].contiguous()
44 |
45 | def _reset_states(self):
46 | for name in self._state_names:
47 | if self._state_defaults[name] is None:
48 | self._buffers[name] = None
49 | else:
50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
51 |
52 | def enable_statefulness(self, batch_size: int):
53 | for m in self.children():
54 | if isinstance(m, Module):
55 | m.enable_statefulness(batch_size)
56 | self._init_states(batch_size)
57 | self._is_stateful = True
58 |
59 | def disable_statefulness(self):
60 | for m in self.children():
61 | if isinstance(m, Module):
62 | m.disable_statefulness()
63 | self._reset_states()
64 | self._is_stateful = False
65 |
66 | @contextmanager
67 | def statefulness(self, batch_size: int):
68 | self.enable_statefulness(batch_size)
69 | try:
70 | yield
71 | finally:
72 | self.disable_statefulness()
73 |
74 |
75 | class ModuleList(nn.ModuleList, Module):
76 | pass
77 |
78 |
79 | class ModuleDict(nn.ModuleDict, Module):
80 | pass
81 |
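A minimal sketch of the statefulness contract defined above, using a hypothetical module: register_state declares a buffer once, the buffer is expanded to the batch size inside the statefulness context, updated during decoding steps, and reset on exit:

import torch
from models.containers import Module

class RunningSum(Module):
    def __init__(self):
        super(RunningSum, self).__init__()
        self.register_state('total', torch.zeros((1,)))   # default state, shape (1,)

    def forward(self, x):
        if self._is_stateful:
            self.total = self.total + x    # buffer assignment, as in MultiHeadAttention
        return self.total

m = RunningSum()
with m.statefulness(batch_size=4):         # 'total' becomes a (4, 1) per-sample buffer
    m(torch.ones(4, 1))
    out = m(torch.ones(4, 1))              # each entry now holds 2.0
# on exit the buffer is reset to its default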
--------------------------------------------------------------------------------
/models/m2_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/models/m2_transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from models.containers import Module
5 |
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | '''
9 | Scaled dot-product attention
10 | '''
11 |
12 | def __init__(self, d_model, d_k, d_v, h):
13 | '''
14 | :param d_model: Output dimensionality of the model
15 | :param d_k: Dimensionality of queries and keys
16 | :param d_v: Dimensionality of values
17 | :param h: Number of heads
18 | '''
19 | super(ScaledDotProductAttention, self).__init__()
20 | self.fc_q = nn.Linear(d_model, h * d_k)
21 | self.fc_k = nn.Linear(d_model, h * d_k)
22 | self.fc_v = nn.Linear(d_model, h * d_v)
23 | self.fc_o = nn.Linear(h * d_v, d_model)
24 |
25 | self.d_model = d_model
26 | self.d_k = d_k
27 | self.d_v = d_v
28 | self.h = h
29 |
30 | self.init_weights()
31 |
32 | def init_weights(self):
33 | nn.init.xavier_uniform_(self.fc_q.weight)
34 | nn.init.xavier_uniform_(self.fc_k.weight)
35 | nn.init.xavier_uniform_(self.fc_v.weight)
36 | nn.init.xavier_uniform_(self.fc_o.weight)
37 | nn.init.constant_(self.fc_q.bias, 0)
38 | nn.init.constant_(self.fc_k.bias, 0)
39 | nn.init.constant_(self.fc_v.bias, 0)
40 | nn.init.constant_(self.fc_o.bias, 0)
41 |
42 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
43 | '''
44 | Computes
45 | :param queries: Queries (b_s, nq, d_model)
46 | :param keys: Keys (b_s, nk, d_model)
47 | :param values: Values (b_s, nk, d_model)
48 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
49 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
50 | :return:
51 | '''
52 | b_s, nq = queries.shape[:2]
53 | nk = keys.shape[1]
54 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
55 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
56 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
57 |
58 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
59 | if attention_weights is not None:
60 | att = att * attention_weights
61 | if attention_mask is not None:
62 | att = att.masked_fill(attention_mask, -np.inf)
63 | att = torch.softmax(att, -1)
64 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
65 | out = self.fc_o(out) # (b_s, nq, d_model)
66 | return out
67 |
68 |
69 | class ScaledDotProductAttentionMemory(nn.Module):
70 | '''
71 | Scaled dot-product attention with memory
72 | '''
73 |
74 | def __init__(self, d_model, d_k, d_v, h, m):
75 | '''
76 | :param d_model: Output dimensionality of the model
77 | :param d_k: Dimensionality of queries and keys
78 | :param d_v: Dimensionality of values
79 | :param h: Number of heads
80 | :param m: Number of memory slots
81 | '''
82 | super(ScaledDotProductAttentionMemory, self).__init__()
83 | self.fc_q = nn.Linear(d_model, h * d_k)
84 | self.fc_k = nn.Linear(d_model, h * d_k)
85 | self.fc_v = nn.Linear(d_model, h * d_v)
86 | self.fc_o = nn.Linear(h * d_v, d_model)
87 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
88 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
89 |
90 | self.d_model = d_model
91 | self.d_k = d_k
92 | self.d_v = d_v
93 | self.h = h
94 | self.m = m
95 |
96 | self.init_weights()
97 |
98 | def init_weights(self):
99 | nn.init.xavier_uniform_(self.fc_q.weight)
100 | nn.init.xavier_uniform_(self.fc_k.weight)
101 | nn.init.xavier_uniform_(self.fc_v.weight)
102 | nn.init.xavier_uniform_(self.fc_o.weight)
103 | nn.init.normal_(self.m_k, 0, 1 / self.d_k)
104 | nn.init.normal_(self.m_v, 0, 1 / self.m)
105 | nn.init.constant_(self.fc_q.bias, 0)
106 | nn.init.constant_(self.fc_k.bias, 0)
107 | nn.init.constant_(self.fc_v.bias, 0)
108 | nn.init.constant_(self.fc_o.bias, 0)
109 |
110 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
111 | '''
112 | Computes
113 | :param queries: Queries (b_s, nq, d_model)
114 | :param keys: Keys (b_s, nk, d_model)
115 | :param values: Values (b_s, nk, d_model)
116 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
117 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
118 | :return:
119 | '''
120 | b_s, nq = queries.shape[:2]
121 | nk = keys.shape[1]
122 |
123 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
124 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
125 |
126 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
127 | k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
128 | v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
129 |
130 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
131 | if attention_weights is not None:
132 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
133 | if attention_mask is not None:
134 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
135 | att = torch.softmax(att, -1)
136 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
137 | out = self.fc_o(out) # (b_s, nq, d_model)
138 | return out
139 |
140 |
141 | class MultiHeadAttention(Module):
142 | '''
143 | Multi-head attention layer with Dropout and Layer Normalization.
144 | '''
145 |
146 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
147 | attention_module=None, attention_module_kwargs=None):
148 | super(MultiHeadAttention, self).__init__()
149 | self.identity_map_reordering = identity_map_reordering
150 | if attention_module is not None:
151 | if attention_module_kwargs is not None:
152 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs)
153 | else:
154 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
155 | else:
156 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
157 | self.dropout = nn.Dropout(p=dropout)
158 | self.layer_norm = nn.LayerNorm(d_model)
159 |
160 | self.can_be_stateful = can_be_stateful
161 | if self.can_be_stateful:
162 | self.register_state('running_keys', torch.zeros((0, d_model)))
163 | self.register_state('running_values', torch.zeros((0, d_model)))
164 |
165 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
166 | if self.can_be_stateful and self._is_stateful:
167 | self.running_keys = torch.cat([self.running_keys, keys], 1)
168 | keys = self.running_keys
169 |
170 | self.running_values = torch.cat([self.running_values, values], 1)
171 | values = self.running_values
172 |
173 | if self.identity_map_reordering:
174 | q_norm = self.layer_norm(queries)
175 | k_norm = self.layer_norm(keys)
176 | v_norm = self.layer_norm(values)
177 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
178 | out = queries + self.dropout(torch.relu(out))
179 | else:
180 | out = self.attention(queries, keys, values, attention_mask, attention_weights)
181 | out = self.dropout(out)
182 | out = self.layer_norm(queries + out)
183 | return out
184 |
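A shape-level sketch of the modules above (dimensions follow the defaults used by the encoder and decoder layers in this package; the tensors are random placeholders):

import torch
from models.m2_transformer.attention import MultiHeadAttention, ScaledDotProductAttentionMemory

x = torch.rand(2, 50, 512)                        # (b_s, nq, d_model)
mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, h=8,
                         attention_module=ScaledDotProductAttentionMemory,
                         attention_module_kwargs={'m': 40})
out = mha(x, x, x)                                # self-attention with 40 memory slots: (2, 50, 512)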
--------------------------------------------------------------------------------
/models/m2_transformer/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from models.m2_transformer.attention import MultiHeadAttention
7 | from models.m2_transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module, ModuleList
9 |
10 |
11 | class MeshedDecoderLayer(Module):
12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
14 | super(MeshedDecoderLayer, self).__init__()
15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
16 | attention_module=self_att_module,
17 | attention_module_kwargs=self_att_module_kwargs)
18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
19 | attention_module=enc_att_module,
20 | attention_module_kwargs=enc_att_module_kwargs)
21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
22 |
23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model)
24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model)
25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model)
26 |
27 | self.init_weights()
28 |
29 | def init_weights(self):
30 | nn.init.xavier_uniform_(self.fc_alpha1.weight)
31 | nn.init.xavier_uniform_(self.fc_alpha2.weight)
32 | nn.init.xavier_uniform_(self.fc_alpha3.weight)
33 | nn.init.constant_(self.fc_alpha1.bias, 0)
34 | nn.init.constant_(self.fc_alpha2.bias, 0)
35 | nn.init.constant_(self.fc_alpha3.bias, 0)
36 |
37 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
38 | self_att = self.self_att(input, input, input, mask_self_att)
39 | self_att = self_att * mask_pad
40 |
41 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad
42 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad
43 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad
44 |
45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1)))
46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1)))
47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1)))
48 |
49 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3)
50 | enc_att = enc_att * mask_pad
51 |
52 | ff = self.pwff(enc_att)
53 | ff = ff * mask_pad
54 | return ff
55 |
56 |
57 | class MeshedDecoder(Module):
58 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
59 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
60 | super(MeshedDecoder, self).__init__()
61 | self.d_model = d_model
62 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
63 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
64 | self.layers = ModuleList(
65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
68 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
69 | self.max_len = max_len
70 | self.padding_idx = padding_idx
71 | self.N = N_dec
72 |
73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
74 | self.register_state('running_seq', torch.zeros((1,)).long())
75 |
76 | def forward(self, input, encoder_output, mask_encoder):
77 | # input (b_s, seq_len)
78 | b_s, seq_len = input.shape[:2]
79 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
80 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
81 | diagonal=1)
82 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
83 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
84 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
85 | if self._is_stateful:
86 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
87 | mask_self_attention = self.running_mask_self_attention
88 |
89 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
90 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
91 | if self._is_stateful:
92 | self.running_seq.add_(1)
93 | seq = self.running_seq
94 |
95 | out = self.word_emb(input) + self.pos_emb(seq)
96 | for i, l in enumerate(self.layers):
97 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
98 |
99 | out = self.fc(out)
100 | return F.log_softmax(out, dim=-1)
101 |
--------------------------------------------------------------------------------
/models/m2_transformer/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.m2_transformer.utils import PositionWiseFeedForward
3 | import torch
4 | from torch import nn
5 | from models.m2_transformer.attention import MultiHeadAttention
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
10 | attention_module=None, attention_module_kwargs=None):
11 | super(EncoderLayer, self).__init__()
12 | self.identity_map_reordering = identity_map_reordering
13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
14 | attention_module=attention_module,
15 | attention_module_kwargs=attention_module_kwargs)
16 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
17 |
18 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
19 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
20 | ff = self.pwff(att)
21 | return ff
22 |
23 |
24 | class MultiLevelEncoder(nn.Module):
25 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
26 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
27 | super(MultiLevelEncoder, self).__init__()
28 | self.d_model = d_model
29 | self.dropout = dropout
30 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
31 | identity_map_reordering=identity_map_reordering,
32 | attention_module=attention_module,
33 | attention_module_kwargs=attention_module_kwargs)
34 | for _ in range(N)])
35 | self.padding_idx = padding_idx
36 |
37 | def forward(self, input, attention_weights=None):
38 | # input (b_s, seq_len, d_in)
39 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
40 |
41 | outs = []
42 | out = input
43 | for l in self.layers:
44 | out = l(out, out, out, attention_mask, attention_weights)
45 | outs.append(out.unsqueeze(1))
46 |
47 | outs = torch.cat(outs, 1)
48 | return outs, attention_mask
49 |
50 |
51 | class MemoryAugmentedEncoder(MultiLevelEncoder):
52 | def __init__(self, N, padding_idx, d_in=2048, **kwargs):
53 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs)
54 | self.fc = nn.Linear(d_in, self.d_model)
55 | self.dropout = nn.Dropout(p=self.dropout)
56 | self.layer_norm = nn.LayerNorm(self.d_model)
57 |
58 | def forward(self, input, attention_weights=None):
59 | out = F.relu(self.fc(input))
60 | out = self.dropout(out)
61 | out = self.layer_norm(out)
62 | return super(MemoryAugmentedEncoder, self).forward(out, attention_weights=attention_weights)
63 |
--------------------------------------------------------------------------------
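A minimal shape-check sketch of the encoder above (not part of the original source tree; it assumes the repository root is on PYTHONPATH): MemoryAugmentedEncoder projects 2048-d visual features to d_model=512 and returns the output of every encoder layer stacked along dim 1, which is exactly what MeshedDecoderLayer indexes as enc_output[:, 0], enc_output[:, 1], enc_output[:, 2].

```python
import torch
from models.m2_transformer.encoders import MemoryAugmentedEncoder

encoder = MemoryAugmentedEncoder(3, 0, d_in=2048)   # 3 layers, padding_idx=0
feats = torch.randn(2, 50, 2048)                    # (b_s, seq_len, d_in)

outs, mask_enc = encoder(feats)
print(outs.shape)       # torch.Size([2, 3, 50, 512])  one output per encoder layer
print(mask_enc.shape)   # torch.Size([2, 1, 1, 50])    padding mask broadcast over heads/queries
```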
/models/m2_transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 |
7 |
8 | class Transformer(CaptioningModel):
9 | def __init__(self, bos_idx, encoder, decoder):
10 | super(Transformer, self).__init__()
11 | self.bos_idx = bos_idx
12 | self.encoder = encoder
13 | self.decoder = decoder
14 | self.register_state('enc_output', None)
15 | self.register_state('mask_enc', None)
16 | self.init_weights()
17 |
18 | @property
19 | def d_model(self):
20 | return self.decoder.d_model
21 |
22 | def init_weights(self):
23 | for p in self.parameters():
24 | if p.dim() > 1:
25 | nn.init.xavier_uniform_(p)
26 |
27 | def forward(self, images, seq, *args):
28 | enc_output, mask_enc = self.encoder(images)
29 | dec_output = self.decoder(seq, enc_output, mask_enc)
30 | return dec_output
31 |
32 | def init_state(self, b_s, device):
33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
34 | None, None]
35 |
36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
37 | it = None
38 | if mode == 'teacher_forcing':
39 | raise NotImplementedError
40 | elif mode == 'feedback':
41 | if t == 0:
42 | self.enc_output, self.mask_enc = self.encoder(visual)
43 | if isinstance(visual, torch.Tensor):
44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
45 | else:
46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
47 | else:
48 | it = prev_output
49 |
50 | return self.decoder(it, self.enc_output, self.mask_enc)
51 |
52 |
53 | class TransformerEnsemble(CaptioningModel):
54 | def __init__(self, model: Transformer, weight_files):
55 | super(TransformerEnsemble, self).__init__()
56 | self.n = len(weight_files)
57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
58 | for i in range(self.n):
59 | state_dict_i = torch.load(weight_files[i])['state_dict']
60 | self.models[i].load_state_dict(state_dict_i)
61 |
62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
63 | out_ensemble = []
64 | for i in range(self.n):
65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
66 | out_ensemble.append(out_i.unsqueeze(0))
67 |
68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
69 |
--------------------------------------------------------------------------------
/models/m2_transformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | out = self.dropout(out)
49 | out = self.layer_norm(input + out)
50 | return out
--------------------------------------------------------------------------------
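As a quick illustration of the helpers above (a sketch, not part of the original source tree): sinusoid_encoding_table builds a fixed (max_len, d_model) table of sinusoidal position encodings whose padding_idx row is zeroed, and the decoders wrap it into a frozen nn.Embedding.

```python
import torch
from models.m2_transformer.utils import sinusoid_encoding_table

table = sinusoid_encoding_table(55, 512, padding_idx=0)
print(table.shape)                      # torch.Size([55, 512])
print(bool(torch.all(table[0] == 0)))   # True: the padding row contributes nothing

# The decoders use it exactly like this (frozen, so it is never trained):
pos_emb = torch.nn.Embedding.from_pretrained(table, freeze=True)
```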
/models/rstnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/models/rstnet/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from models.transformer.attention import MultiHeadAttention
6 | from models.rstnet.attention import MultiHeadAdaptiveAttention
7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module, ModuleList
9 | from models.rstnet.language_model import LanguageModel
10 |
11 |
12 | class DecoderLayer(Module):
13 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
14 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
15 | super(DecoderLayer, self).__init__()
16 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
17 | attention_module=self_att_module,
18 | attention_module_kwargs=self_att_module_kwargs)
19 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
20 | attention_module=enc_att_module,
21 | attention_module_kwargs=enc_att_module_kwargs)
22 |
23 | self.dropout1 = nn.Dropout(dropout)
24 | self.lnorm1 = nn.LayerNorm(d_model)
25 |
26 | self.dropout2 = nn.Dropout(dropout)
27 | self.lnorm2 = nn.LayerNorm(d_model)
28 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
29 |
30 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att, pos):
31 | # MHA+AddNorm
32 | self_att = self.self_att(input, input, input, mask_self_att)
33 | self_att = self.lnorm1(input + self.dropout1(self_att))
34 | self_att = self_att * mask_pad
35 | # MHA+AddNorm
36 | key = enc_output + pos
37 | enc_att = self.enc_att(self_att, key, enc_output, mask_enc_att)
38 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att))
39 | enc_att = enc_att * mask_pad
40 | # FFN+AddNorm
41 | ff = self.pwff(enc_att)
42 | ff = ff * mask_pad
43 | return ff
44 |
45 |
46 | class DecoderAdaptiveLayer(Module):
47 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
48 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
49 | super(DecoderAdaptiveLayer, self).__init__()
50 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
51 | attention_module=self_att_module,
52 | attention_module_kwargs=self_att_module_kwargs)
53 |
54 | self.enc_att = MultiHeadAdaptiveAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, attention_module=enc_att_module, attention_module_kwargs=enc_att_module_kwargs)
55 |
56 | self.dropout1 = nn.Dropout(dropout)
57 | self.lnorm1 = nn.LayerNorm(d_model)
58 |
59 | self.dropout2 = nn.Dropout(dropout)
60 | self.lnorm2 = nn.LayerNorm(d_model)
61 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
62 |
63 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att, language_feature=None, pos=None):
64 | # MHA+AddNorm
65 | self_att = self.self_att(input, input, input, mask_self_att)
66 | self_att = self.lnorm1(input + self.dropout1(self_att))
67 | self_att = self_att * mask_pad
68 | # MHA+AddNorm
69 | key = enc_output + pos
70 | enc_att = self.enc_att(self_att, key, enc_output, mask_enc_att, language_feature=language_feature)
71 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att))
72 | enc_att = enc_att * mask_pad
73 | # FFN+AddNorm
74 | ff = self.pwff(enc_att)
75 | ff = ff * mask_pad
76 | return ff
77 |
78 |
79 | class TransformerDecoderLayer(Module):
80 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, bert_hidden_size=768, dropout=.1,
81 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None, language_model_path=None):
82 | super(TransformerDecoderLayer, self).__init__()
83 | self.d_model = d_model
84 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
85 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
86 | self.layers = ModuleList(
87 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) if i < N_dec else DecoderAdaptiveLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) for i in range(N_dec + 1)])
88 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
89 |
90 |         # load the pretrained language model
91 | self.language_model = LanguageModel(padding_idx=padding_idx, bert_hidden_size=bert_hidden_size, vocab_size=vocab_size, max_len=max_len)
92 | assert language_model_path is not None
93 | language_model_file = torch.load(language_model_path)
94 | self.language_model.load_state_dict(language_model_file['state_dict'], strict=False)
95 | for p in self.language_model.parameters():
96 | p.requires_grad = False
97 |
98 | self.max_len = max_len
99 | self.padding_idx = padding_idx
100 | self.N = N_dec
101 |
102 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
103 | self.register_state('running_seq', torch.zeros((1,)).long())
104 |
105 | def forward(self, input, encoder_output, mask_encoder, pos):
106 | # input (b_s, seq_len)
107 | b_s, seq_len = input.shape[:2]
108 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
109 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
110 | diagonal=1)
111 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
112 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
113 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
114 | if self._is_stateful:
115 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1)
116 | mask_self_attention = self.running_mask_self_attention
117 |
118 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
119 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
120 | if self._is_stateful:
121 | self.running_seq.add_(1)
122 | seq = self.running_seq
123 |
124 | out = self.word_emb(input) + self.pos_emb(seq)
125 | _, language_feature = self.language_model(input)
126 |
127 | if encoder_output.shape[0] != pos.shape[0]:
128 | assert encoder_output.shape[0] % pos.shape[0] == 0
129 | beam_size = int(encoder_output.shape[0] / pos.shape[0])
130 | shape = (pos.shape[0], beam_size, pos.shape[1], pos.shape[2])
131 | pos = pos.unsqueeze(1) # bs * 1 * 50 * 512
132 | pos = pos.expand(shape) # bs * 5 * 50 * 512
133 | pos = pos.contiguous().flatten(0, 1) # (bs*5) * 50 * 512
134 |
135 | for i, l in enumerate(self.layers):
136 | if i < self.N:
137 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder, pos=pos)
138 | else:
139 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder, language_feature, pos=pos)
140 |
141 | out = self.fc(out)
142 | return F.log_softmax(out, dim=-1)
143 |
--------------------------------------------------------------------------------
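One step of TransformerDecoderLayer.forward above deserves a standalone illustration: during beam search, encoder_output arrives replicated beam_size times along the batch dimension while the grid position embedding pos does not, so pos is expanded and flattened to match. A tiny sketch with hypothetical sizes (not part of the original source tree):

```python
import torch

bs, beam_size, n, d = 2, 5, 49, 512
pos = torch.randn(bs, n, d)                          # one grid position embedding per image
encoder_output = torch.randn(bs * beam_size, n, d)   # already replicated for beam search

pos = pos.unsqueeze(1)                    # (bs, 1, n, d)
pos = pos.expand(bs, beam_size, n, d)     # (bs, beam_size, n, d), a view, no copy
pos = pos.contiguous().flatten(0, 1)      # (bs * beam_size, n, d)

assert pos.shape[0] == encoder_output.shape[0]
```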
/models/rstnet/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.transformer.utils import PositionWiseFeedForward
3 | import torch
4 | from torch import nn
5 | from models.rstnet.attention import MultiHeadGeometryAttention
6 | from models.rstnet.grid_aug import BoxRelationalEmbedding
7 |
8 |
9 | class EncoderLayer(nn.Module):
10 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
11 | attention_module=None, attention_module_kwargs=None):
12 | super(EncoderLayer, self).__init__()
13 | self.identity_map_reordering = identity_map_reordering
14 | self.mhatt = MultiHeadGeometryAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
15 | attention_module=attention_module,
16 | attention_module_kwargs=attention_module_kwargs)
17 | self.dropout = nn.Dropout(dropout)
18 | self.lnorm = nn.LayerNorm(d_model)
19 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
20 |
21 | def forward(self, queries, keys, values, relative_geometry_weights, attention_mask=None, attention_weights=None, pos=None):
22 |
23 |         # add the grid position embedding to queries and keys (pos is provided by MultiLevelEncoder)
24 |         q = queries + pos
25 |         k = keys + pos
26 | att = self.mhatt(q, k, values, relative_geometry_weights, attention_mask, attention_weights)
27 | att = self.lnorm(queries + self.dropout(att))
28 | ff = self.pwff(att)
29 | return ff
30 |
31 |
32 | class MultiLevelEncoder(nn.Module):
33 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
34 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
35 | super(MultiLevelEncoder, self).__init__()
36 | self.d_model = d_model
37 | self.dropout = dropout
38 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
39 | identity_map_reordering=identity_map_reordering,
40 | attention_module=attention_module,
41 | attention_module_kwargs=attention_module_kwargs)
42 | for _ in range(N)])
43 | self.padding_idx = padding_idx
44 |
45 | self.WGs = nn.ModuleList([nn.Linear(64, 1, bias=True) for _ in range(h)])
46 |
47 | def forward(self, input, attention_weights=None, pos=None):
48 | # input (b_s, seq_len, d_in)
49 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
50 |
51 | # grid geometry embedding
52 | relative_geometry_embeddings = BoxRelationalEmbedding(input)
53 | flatten_relative_geometry_embeddings = relative_geometry_embeddings.view(-1, 64)
54 | box_size_per_head = list(relative_geometry_embeddings.shape[:3])
55 | box_size_per_head.insert(1, 1)
56 | relative_geometry_weights_per_head = [layer(flatten_relative_geometry_embeddings).view(box_size_per_head) for layer in self.WGs]
57 | relative_geometry_weights = torch.cat((relative_geometry_weights_per_head), 1)
58 | relative_geometry_weights = F.relu(relative_geometry_weights)
59 |
60 |         out = input
61 |         for layer in self.layers:
62 |             out = layer(out, out, out, relative_geometry_weights, attention_mask, attention_weights, pos=pos)
63 |
64 |         return out, attention_mask
65 |
66 |
67 | class TransformerEncoder(MultiLevelEncoder):
68 |     def __init__(self, N, padding_idx, d_in=2048, **kwargs):
69 |         super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs)
70 |         self.fc = nn.Linear(d_in, self.d_model)
71 |         self.dropout = nn.Dropout(p=self.dropout)
72 |         self.layer_norm = nn.LayerNorm(self.d_model)
73 |
74 |     def forward(self, input, attention_weights=None, pos=None):
75 |         mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1)
76 |         out = F.relu(self.fc(input))
77 |         out = self.dropout(out)
78 |         out = self.layer_norm(out)
79 |         out = out.masked_fill(mask, 0)
80 |         return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights, pos=pos)
81 |
--------------------------------------------------------------------------------
/models/rstnet/grid_aug.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | class PositionEmbeddingSine(nn.Module):
6 | """
7 | This is a more standard version of the position embedding, very similar to the one
8 | used by the Attention is all you need paper, generalized to work on images.
9 | """
10 |
11 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
12 | super().__init__()
13 | self.num_pos_feats = num_pos_feats
14 | self.temperature = temperature
15 | self.normalize = normalize
16 | if scale is not None and normalize is False:
17 | raise ValueError("normalize should be True if scale is passed")
18 | if scale is None:
19 | scale = 2 * math.pi
20 | self.scale = scale
21 |
22 | def forward(self, x, mask=None):
23 | if mask is None:
24 | mask = torch.zeros(x.shape[:-1], dtype=torch.bool, device=x.device)
25 |         not_mask = ~mask
26 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
27 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
28 | if self.normalize:
29 | eps = 1e-6
30 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
31 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
32 |
33 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
34 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
35 |
36 | pos_x = x_embed[:, :, :, None] / dim_t
37 | pos_y = y_embed[:, :, :, None] / dim_t
38 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
39 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
40 | pos = torch.cat((pos_y, pos_x), dim=3) # .permute(0, 3, 1, 2)
41 | pos = pos.flatten(1, 2)
42 | return pos
43 |
44 |
45 | def get_relative_pos(x, bs, grid_size):
46 | x = x.view(1, -1, 1).expand(bs, -1, -1)
47 | return x / grid_size
48 |
49 | def get_grids_pos(bs, grid_size=7):
50 | x = torch.arange(0, grid_size).float().cuda()
51 | y = torch.arange(0, grid_size).float().cuda()
52 | px_min = x.view(-1, 1).expand(-1, grid_size).contiguous().view(-1)
53 | py_min = y.view(1, -1).expand(grid_size, -1).contiguous().view(-1)
54 | px_max = px_min + 1
55 | py_max = py_min + 1
56 |
57 | x_min = get_relative_pos(px_min, bs, grid_size)
58 | y_min = get_relative_pos(py_min, bs, grid_size)
59 | x_max = get_relative_pos(px_max, bs, grid_size)
60 | y_max = get_relative_pos(py_max, bs, grid_size)
61 | return y_min, x_min, y_max, x_max
62 |
63 |
64 | def BoxRelationalEmbedding(f_g, dim_g=64, wave_len=1000, trignometric_embedding=True):
65 |     """
66 |     Given a batch of 7x7 grid features, computes for every image a relational geometry
67 |     embedding between each pair of grid cells (the original formulation works on detected
68 |     bounding boxes; here the "boxes" are the fixed grid cells and f_g is only used to
69 |     read the batch size).
70 |
71 |     Entry (i, j) encodes the displacement and relative size of cell_i with respect to cell_j.
72 |
73 |     input: tensor of shape (batch_size, num_grids, d)
74 |     output: tensor of shape (batch_size, num_grids, num_grids, dim_g)
75 |     """
76 |     # follows the implementation of https://github.com/heefe92/Relation_Networks-pytorch/blob/master/model.py#L1014-L1055
77 |
78 | batch_size = f_g.size(0)
79 | x_min, y_min, x_max, y_max = get_grids_pos(batch_size)
80 |
81 | cx = (x_min + x_max) * 0.5
82 | cy = (y_min + y_max) * 0.5
83 | w = (x_max - x_min) + 1.
84 | h = (y_max - y_min) + 1.
85 |
86 | # cx.view(1,-1) transposes the vector cx, and so dim(delta_x) = (dim(cx), dim(cx))
87 | delta_x = cx - cx.view(batch_size, 1, -1)
88 | delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3)
89 | delta_x = torch.log(delta_x)
90 |
91 | delta_y = cy - cy.view(batch_size, 1, -1)
92 | delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3)
93 | delta_y = torch.log(delta_y)
94 |
95 | delta_w = torch.log(w / w.view(batch_size, 1, -1))
96 | delta_h = torch.log(h / h.view(batch_size, 1, -1))
97 |
98 | matrix_size = delta_h.size()
99 | delta_x = delta_x.view(batch_size, matrix_size[1], matrix_size[2], 1)
100 | delta_y = delta_y.view(batch_size, matrix_size[1], matrix_size[2], 1)
101 | delta_w = delta_w.view(batch_size, matrix_size[1], matrix_size[2], 1)
102 | delta_h = delta_h.view(batch_size, matrix_size[1], matrix_size[2], 1)
103 |
104 | position_mat = torch.cat((delta_x, delta_y, delta_w, delta_h), -1) # bs * r * r *4
105 |
106 |     if trignometric_embedding:
107 | feat_range = torch.arange(dim_g / 8).cuda()
108 | dim_mat = feat_range / (dim_g / 8)
109 | dim_mat = 1. / (torch.pow(wave_len, dim_mat))
110 |
111 | dim_mat = dim_mat.view(1, 1, 1, -1)
112 | position_mat = position_mat.view(batch_size, matrix_size[1], matrix_size[2], 4, -1)
113 | position_mat = 100. * position_mat
114 |
115 | mul_mat = position_mat * dim_mat
116 | mul_mat = mul_mat.view(batch_size, matrix_size[1], matrix_size[2], -1)
117 | sin_mat = torch.sin(mul_mat)
118 | cos_mat = torch.cos(mul_mat)
119 | embedding = torch.cat((sin_mat, cos_mat), -1)
120 | else:
121 | embedding = position_mat
122 | return (embedding)
123 |
--------------------------------------------------------------------------------
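A minimal, CPU-runnable shape check for PositionEmbeddingSine above (a sketch, not part of the original source tree); it mirrors Transformer.get_pos_embedding in models/rstnet/transformer.py, which reshapes the 49 grid features to 7x7 before calling it:

```python
import torch
from models.rstnet.grid_aug import PositionEmbeddingSine

grid_embedding = PositionEmbeddingSine(num_pos_feats=256, normalize=True)  # 256 = d_model // 2

bs = 2
grids = torch.randn(bs, 49, 2048)                  # flattened 7x7 grid features
pos = grid_embedding(grids.view(bs, 7, 7, -1))     # same reshape as get_pos_embedding

print(pos.shape)   # torch.Size([2, 49, 512]): y/x sine-cosine embeddings, one per grid cell
```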
/models/rstnet/language_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from transformers import BertModel
5 |
6 | from models.transformer.attention import MultiHeadAttention
7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module
9 |
10 |
11 | class EncoderLayer(Module):
12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
14 | super(EncoderLayer, self).__init__()
15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
16 | attention_module=self_att_module,
17 | attention_module_kwargs=self_att_module_kwargs)
18 |
19 | self.dropout1 = nn.Dropout(dropout)
20 | self.lnorm1 = nn.LayerNorm(d_model)
21 |
22 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
23 |
24 | def forward(self, input, mask_pad, mask_self_att):
25 | # MHA+AddNorm
26 | self_att = self.self_att(input, input, input, mask_self_att)
27 | self_att = self.lnorm1(input + self.dropout1(self_att))
28 | self_att = self_att * mask_pad
29 |
30 | # FFN+AddNorm
31 | ff = self.pwff(self_att)
32 | ff = ff * mask_pad
33 | return ff
34 |
35 |
36 | class LanguageModel(Module):
37 | def __init__(self, padding_idx=0, bert_hidden_size=768, vocab_size=10201, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, max_len=54, dropout=.1):
38 | super(LanguageModel, self).__init__()
39 | self.padding_idx = padding_idx
40 | self.d_model = d_model
41 |
42 | self.language_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
43 | self.language_model.config.vocab_size = vocab_size
44 | self.proj_to_caption_model = nn.Linear(bert_hidden_size, d_model)
45 |
46 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
47 | self.encoder_layer = EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout)
48 | self.proj_to_vocab = nn.Linear(d_model, vocab_size)
49 |
50 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
51 | self.register_state('running_seq', torch.zeros((1,)).long())
52 |
53 | def forward(
54 | self, input_ids, attention_mask=None, token_type_ids=None,
55 | position_ids=None, head_mask=None, inputs_embeds=None, output_attentions=False,
56 | output_hidden_states=False, return_dict=False, encoder_hidden_states=None,
57 | encoder_attention_mask=None
58 | ):
59 | # input (b_s, seq_len)
60 | b_s, seq_len = input_ids.shape[:2]
61 | mask_queries = (input_ids != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
62 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input_ids.device), diagonal=1)
63 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
64 | mask_self_attention = mask_self_attention + (input_ids == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
65 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
66 | if self._is_stateful:
67 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1)
68 | mask_self_attention = self.running_mask_self_attention
69 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input_ids.device) # (b_s, seq_len)
70 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
71 | if self._is_stateful:
72 | self.running_seq.add_(1)
73 | seq = self.running_seq
74 |
75 | if attention_mask is None:
76 | attention_mask = torch.ones_like(input_ids)
77 | if token_type_ids is None:
78 | token_type_ids = torch.zeros_like(input_ids).long()
79 |
80 | bert_output = self.language_model(
81 | input_ids=input_ids,
82 | token_type_ids=token_type_ids,
83 | attention_mask=attention_mask
84 | )
85 | language_feature = self.proj_to_caption_model(bert_output.last_hidden_state)
86 | language_feature = language_feature + self.pos_emb(seq)
87 |
88 | language_feature = self.encoder_layer(language_feature, mask_queries, mask_self_attention)
89 |
90 | logits = self.proj_to_vocab(language_feature)
91 | out = F.log_softmax(logits, dim=-1)
92 | return out, language_feature
93 |
--------------------------------------------------------------------------------
/models/rstnet/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 | from .grid_aug import PositionEmbeddingSine
7 |
8 |
9 | class Transformer(CaptioningModel):
10 | def __init__(self, bos_idx, encoder, decoder):
11 | super(Transformer, self).__init__()
12 | self.bos_idx = bos_idx
13 | self.encoder = encoder
14 | self.decoder = decoder
15 | self.grid_embedding = PositionEmbeddingSine(self.decoder.d_model // 2, normalize=True)
16 |
17 | self.register_state('enc_output', None)
18 | self.register_state('mask_enc', None)
19 | self.init_weights()
20 |
21 | @property
22 | def d_model(self):
23 | return self.decoder.d_model
24 |
25 | def init_weights(self):
26 | for p in self.parameters():
27 | if p.dim() > 1:
28 | nn.init.xavier_uniform_(p)
29 |
30 | def get_pos_embedding(self, grids):
31 | bs = grids.shape[0]
32 | grid_embed = self.grid_embedding(grids.view(bs, 7, 7, -1))
33 | return grid_embed
34 |
35 | def forward(self, images, seq, *args):
36 | grid_embed = self.get_pos_embedding(images)
37 | enc_output, mask_enc = self.encoder(images, pos=grid_embed)
38 | dec_output = self.decoder(seq, enc_output, mask_enc, pos=grid_embed)
39 | return dec_output
40 |
41 | def init_state(self, b_s, device):
42 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
43 | None, None]
44 |
45 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
46 | it = None
47 | if mode == 'teacher_forcing':
48 | raise NotImplementedError
49 | elif mode == 'feedback':
50 | if t == 0:
51 | self.grid_embed = self.get_pos_embedding(visual)
52 | self.enc_output, self.mask_enc = self.encoder(visual, pos=self.grid_embed)
53 |
54 | if isinstance(visual, torch.Tensor):
55 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
56 | else:
57 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
58 | else:
59 | it = prev_output
60 |
61 | return self.decoder(it, self.enc_output, self.mask_enc, pos=self.grid_embed)
62 |
63 |
64 | class TransformerEnsemble(CaptioningModel):
65 | def __init__(self, model: Transformer, weight_files):
66 | super(TransformerEnsemble, self).__init__()
67 | self.n = len(weight_files)
68 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
69 | for i in range(self.n):
70 | state_dict_i = torch.load(weight_files[i])['state_dict']
71 | self.models[i].load_state_dict(state_dict_i)
72 |
73 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
74 | out_ensemble = []
75 | for i in range(self.n):
76 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
77 | out_ensemble.append(out_i.unsqueeze(0))
78 |
79 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
80 |
--------------------------------------------------------------------------------
/models/rstnet/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | out = self.dropout(out)
49 | out = self.layer_norm(input + out)
50 | return out
51 |
52 |
53 | def in_norm(x_btc):
54 |
55 | b, t, c = x_btc.size()
56 | eps = 1e-6
57 |
58 |     mu = torch.sum(x_btc, dim=1, keepdim=True) / t
59 |     sigma_square = torch.sum((x_btc - mu) ** 2, dim=1, keepdim=True) / t
60 | x_btc = (x_btc - mu) / torch.sqrt(sigma_square + eps)
61 |
62 | return x_btc
63 |
--------------------------------------------------------------------------------
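A small usage sketch for in_norm above (not part of the original source tree): it standardizes each channel of a (batch, time, channel) tensor over the time dimension, so every (b, c) slice comes out with roughly zero mean and unit variance.

```python
import torch
from models.rstnet.utils import in_norm

x = torch.randn(2, 10, 512) * 3.0 + 1.0            # (b, t, c)
y = in_norm(x)

print(y.shape)                                     # torch.Size([2, 10, 512])
print(float(y.mean(dim=1).abs().max()))            # close to 0
print(float(y.var(dim=1, unbiased=False).mean()))  # close to 1
```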
/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/models/transformer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/attention.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/attention.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/decoders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/decoders.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/encoders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/encoders.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/transformer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/transformer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/models/transformer/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/models/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from models.containers import Module
5 |
6 |
7 | class ScaledDotProductAttention(nn.Module):
8 | '''
9 | Scaled dot-product attention
10 | '''
11 |
12 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None):
13 | '''
14 | :param d_model: Output dimensionality of the model
15 | :param d_k: Dimensionality of queries and keys
16 | :param d_v: Dimensionality of values
17 | :param h: Number of heads
18 | '''
19 | super(ScaledDotProductAttention, self).__init__()
20 | self.fc_q = nn.Linear(d_model, h * d_k)
21 | self.fc_k = nn.Linear(d_model, h * d_k)
22 | self.fc_v = nn.Linear(d_model, h * d_v)
23 | self.fc_o = nn.Linear(h * d_v, d_model)
24 | self.dropout = nn.Dropout(dropout)
25 |
26 | self.d_model = d_model
27 | self.d_k = d_k
28 | self.d_v = d_v
29 | self.h = h
30 |
31 | self.init_weights()
32 |
33 | self.comment = comment
34 |
35 | def init_weights(self):
36 | nn.init.xavier_uniform_(self.fc_q.weight)
37 | nn.init.xavier_uniform_(self.fc_k.weight)
38 | nn.init.xavier_uniform_(self.fc_v.weight)
39 | nn.init.xavier_uniform_(self.fc_o.weight)
40 | nn.init.constant_(self.fc_q.bias, 0)
41 | nn.init.constant_(self.fc_k.bias, 0)
42 | nn.init.constant_(self.fc_v.bias, 0)
43 | nn.init.constant_(self.fc_o.bias, 0)
44 |
45 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
46 | '''
47 | Computes
48 | :param queries: Queries (b_s, nq, d_model)
49 | :param keys: Keys (b_s, nk, d_model)
50 | :param values: Values (b_s, nk, d_model)
51 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
52 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
53 | :return:
54 | '''
55 |
56 | b_s, nq = queries.shape[:2]
57 | nk = keys.shape[1]
58 |
59 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
60 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
61 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
62 |
63 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
64 | if attention_weights is not None:
65 | att = att * attention_weights
66 | if attention_mask is not None:
67 | att = att.masked_fill(attention_mask, -np.inf)
68 | att = torch.softmax(att, -1)
69 | att = self.dropout(att)
70 |
71 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
72 | out = self.fc_o(out) # (b_s, nq, d_model)
73 | return out
74 |
75 |
76 | class ScaledDotProductAttentionMemory(nn.Module):
77 | '''
78 | Scaled dot-product attention with memory
79 | '''
80 |
81 | def __init__(self, d_model, d_k, d_v, h, m):
82 | '''
83 | :param d_model: Output dimensionality of the model
84 | :param d_k: Dimensionality of queries and keys
85 | :param d_v: Dimensionality of values
86 | :param h: Number of heads
87 | :param m: Number of memory slots
88 | '''
89 | super(ScaledDotProductAttentionMemory, self).__init__()
90 | self.fc_q = nn.Linear(d_model, h * d_k)
91 | self.fc_k = nn.Linear(d_model, h * d_k)
92 | self.fc_v = nn.Linear(d_model, h * d_v)
93 | self.fc_o = nn.Linear(h * d_v, d_model)
94 | self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
95 | self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
96 |
97 | self.d_model = d_model
98 | self.d_k = d_k
99 | self.d_v = d_v
100 | self.h = h
101 | self.m = m
102 |
103 | self.init_weights()
104 |
105 | def init_weights(self):
106 | nn.init.xavier_uniform_(self.fc_q.weight)
107 | nn.init.xavier_uniform_(self.fc_k.weight)
108 | nn.init.xavier_uniform_(self.fc_v.weight)
109 | nn.init.xavier_uniform_(self.fc_o.weight)
110 | nn.init.normal_(self.m_k, 0, 1 / self.d_k)
111 | nn.init.normal_(self.m_v, 0, 1 / self.m)
112 | nn.init.constant_(self.fc_q.bias, 0)
113 | nn.init.constant_(self.fc_k.bias, 0)
114 | nn.init.constant_(self.fc_v.bias, 0)
115 | nn.init.constant_(self.fc_o.bias, 0)
116 |
117 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
118 | '''
119 | Computes
120 | :param queries: Queries (b_s, nq, d_model)
121 | :param keys: Keys (b_s, nk, d_model)
122 | :param values: Values (b_s, nk, d_model)
123 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
124 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
125 | :return:
126 | '''
127 | b_s, nq = queries.shape[:2]
128 | nk = keys.shape[1]
129 |
130 | m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
131 | m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
132 |
133 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
134 |         k = torch.cat([self.fc_k(keys), m_k], 1).view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk + m)
135 |         v = torch.cat([self.fc_v(values), m_v], 1).view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk + m, d_v)
136 |
137 |         att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk + m)
138 | if attention_weights is not None:
139 | att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
140 | if attention_mask is not None:
141 | att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
142 | att = torch.softmax(att, -1)
143 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
144 | out = self.fc_o(out) # (b_s, nq, d_model)
145 | return out
146 |
147 |
148 | class MultiHeadAttention(Module):
149 | '''
150 | Multi-head attention layer with Dropout and Layer Normalization.
151 | '''
152 |
153 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
154 | attention_module=None, attention_module_kwargs=None, comment=None):
155 | super(MultiHeadAttention, self).__init__()
156 | self.identity_map_reordering = identity_map_reordering
157 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment)
158 | self.dropout = nn.Dropout(p=dropout)
159 | self.layer_norm = nn.LayerNorm(d_model)
160 |
161 | self.can_be_stateful = can_be_stateful
162 | if self.can_be_stateful:
163 | self.register_state('running_keys', torch.zeros((0, d_model)))
164 | self.register_state('running_values', torch.zeros((0, d_model)))
165 |
166 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
167 | if self.can_be_stateful and self._is_stateful:
168 | self.running_keys = torch.cat([self.running_keys, keys], 1)
169 | keys = self.running_keys
170 |
171 | self.running_values = torch.cat([self.running_values, values], 1)
172 | values = self.running_values
173 |
174 | if self.identity_map_reordering:
175 | q_norm = self.layer_norm(queries)
176 | k_norm = self.layer_norm(keys)
177 | v_norm = self.layer_norm(values)
178 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
179 | out = queries + self.dropout(torch.relu(out))
180 | else:
181 | out = self.attention(queries, keys, values, attention_mask, attention_weights)
182 | out = self.dropout(out)
183 | out = self.layer_norm(queries + out)
184 | return out
185 |
--------------------------------------------------------------------------------
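A minimal shape-check sketch for the MultiHeadAttention layer above (not part of the original source tree; it assumes the repository root is on PYTHONPATH): queries can attend over a different number of keys/values, and True entries of attention_mask are excluded from the softmax.

```python
import torch
from models.transformer.attention import MultiHeadAttention

mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, h=8)

queries = torch.randn(2, 5, 512)    # (b_s, nq, d_model)
keys = torch.randn(2, 10, 512)      # (b_s, nk, d_model)
values = torch.randn(2, 10, 512)

out = mha(queries, keys, values)    # no mask: full attention
print(out.shape)                    # torch.Size([2, 5, 512])

mask = torch.zeros(2, 8, 5, 10, dtype=torch.bool)
mask[..., -2:] = True               # hide the last two key positions from every query
out_masked = mha(queries, keys, values, attention_mask=mask)
```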
/models/transformer/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from models.transformer.attention import MultiHeadAttention
6 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
7 | from models.containers import Module, ModuleList
8 |
9 |
10 | class DecoderLayer(Module):
11 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
12 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
13 | super(DecoderLayer, self).__init__()
14 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
15 | attention_module=self_att_module,
16 | attention_module_kwargs=self_att_module_kwargs)
17 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
18 | attention_module=enc_att_module,
19 | attention_module_kwargs=enc_att_module_kwargs)
20 |
21 | self.dropout1 = nn.Dropout(dropout)
22 | self.lnorm1 = nn.LayerNorm(d_model)
23 |
24 | self.dropout2 = nn.Dropout(dropout)
25 | self.lnorm2 = nn.LayerNorm(d_model)
26 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
27 |
28 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
29 | # MHA+AddNorm
30 | self_att = self.self_att(input, input, input, mask_self_att)
31 | self_att = self.lnorm1(input + self.dropout1(self_att))
32 | self_att = self_att * mask_pad
33 | # MHA+AddNorm
34 | enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att)
35 | enc_att = self.lnorm2(self_att + self.dropout2(enc_att))
36 | enc_att = enc_att * mask_pad
37 | # FFN+AddNorm
38 | ff = self.pwff(enc_att)
39 | ff = ff * mask_pad
40 | return ff
41 |
42 |
43 | class TransformerDecoderLayer(Module):
44 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
45 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
46 | super(TransformerDecoderLayer, self).__init__()
47 | self.d_model = d_model
48 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
49 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
50 | self.layers = ModuleList(
51 | [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
52 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
53 | self.max_len = max_len
54 | self.padding_idx = padding_idx
55 | self.N = N_dec
56 |
57 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
58 | self.register_state('running_seq', torch.zeros((1,)).long())
59 |
60 | def forward(self, input, encoder_output, mask_encoder):
61 | # input (b_s, seq_len)
62 | b_s, seq_len = input.shape[:2]
63 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() # (b_s, seq_len, 1)
64 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
65 | diagonal=1)
66 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
67 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
68 | mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
69 | if self._is_stateful:
70 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention.type_as(mask_self_attention), mask_self_attention], -1)
71 | mask_self_attention = self.running_mask_self_attention
72 |
73 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
74 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
75 | if self._is_stateful:
76 | self.running_seq.add_(1)
77 | seq = self.running_seq
78 |
79 | out = self.word_emb(input) + self.pos_emb(seq)
80 |
81 | for i, l in enumerate(self.layers):
82 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
83 |
84 | out = self.fc(out)
85 | return F.log_softmax(out, dim=-1)
86 |
--------------------------------------------------------------------------------
/models/transformer/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.transformer.utils import PositionWiseFeedForward
3 | import torch
4 | from torch import nn
5 | from models.transformer.attention import MultiHeadAttention
6 |
7 |
8 | class EncoderLayer(nn.Module):
9 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
10 | attention_module=None, attention_module_kwargs=None):
11 | super(EncoderLayer, self).__init__()
12 | self.identity_map_reordering = identity_map_reordering
13 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
14 | attention_module=attention_module,
15 | attention_module_kwargs=attention_module_kwargs)
16 | self.dropout = nn.Dropout(dropout)
17 | self.lnorm = nn.LayerNorm(d_model)
18 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
19 |
20 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
21 |
22 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
23 | att = self.lnorm(queries + self.dropout(att))
24 | ff = self.pwff(att)
25 | return ff
26 |
27 |
28 | class MultiLevelEncoder(nn.Module):
29 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
30 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
31 | super(MultiLevelEncoder, self).__init__()
32 | self.d_model = d_model
33 | self.dropout = dropout
34 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
35 | identity_map_reordering=identity_map_reordering,
36 | attention_module=attention_module,
37 | attention_module_kwargs=attention_module_kwargs)
38 | for _ in range(N)])
39 | self.padding_idx = padding_idx
40 |
41 | def forward(self, input, attention_weights=None):
42 | # input (b_s, seq_len, d_in)
43 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) # (b_s, 1, 1, seq_len)
44 |
45 | out = input
46 | for l in self.layers:
47 | out = l(out, out, out, attention_mask, attention_weights)
48 | # outs.append(out.unsqueeze(1))
49 |
50 | # outs = torch.cat(outs, 1)
51 | return out, attention_mask
52 |
53 |
54 | class TransformerEncoder(MultiLevelEncoder):
55 | def __init__(self, N, padding_idx, d_in=2048, **kwargs):
56 | super(TransformerEncoder, self).__init__(N, padding_idx, **kwargs)
57 | self.fc = nn.Linear(d_in, self.d_model)
58 | self.dropout = nn.Dropout(p=self.dropout)
59 | self.layer_norm = nn.LayerNorm(self.d_model)
60 |
61 | def forward(self, input, attention_weights=None):
62 | mask = (torch.sum(input, dim=-1) == 0).unsqueeze(-1)
63 | out = F.relu(self.fc(input))
64 | out = self.dropout(out)
65 | out = self.layer_norm(out)
66 | out = out.masked_fill(mask, 0)
67 | return super(TransformerEncoder, self).forward(out, attention_weights=attention_weights)
68 |
--------------------------------------------------------------------------------
/models/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 |
7 |
8 | class Transformer(CaptioningModel):
9 | def __init__(self, bos_idx, encoder, decoder):
10 | super(Transformer, self).__init__()
11 | self.bos_idx = bos_idx
12 | self.encoder = encoder
13 | self.decoder = decoder
14 | self.register_state('enc_output', None)
15 | self.register_state('mask_enc', None)
16 | self.init_weights()
17 |
18 | @property
19 | def d_model(self):
20 | return self.decoder.d_model
21 |
22 | def init_weights(self):
23 | for p in self.parameters():
24 | if p.dim() > 1:
25 | nn.init.xavier_uniform_(p)
26 |
27 | def forward(self, images, seq, *args):
28 | enc_output, mask_enc = self.encoder(images)
29 | dec_output = self.decoder(seq, enc_output, mask_enc)
30 | return dec_output
31 |
32 | def init_state(self, b_s, device):
33 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
34 | None, None]
35 |
36 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
37 | it = None
38 | if mode == 'teacher_forcing':
39 | raise NotImplementedError
40 | elif mode == 'feedback':
41 | if t == 0:
42 | self.enc_output, self.mask_enc = self.encoder(visual)
43 | if isinstance(visual, torch.Tensor):
44 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
45 | else:
46 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
47 | else:
48 | it = prev_output
49 |
50 | return self.decoder(it, self.enc_output, self.mask_enc)
51 |
52 |
53 | class TransformerEnsemble(CaptioningModel):
54 | def __init__(self, model: Transformer, weight_files):
55 | super(TransformerEnsemble, self).__init__()
56 | self.n = len(weight_files)
57 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
58 | for i in range(self.n):
59 | state_dict_i = torch.load(weight_files[i])['state_dict']
60 | self.models[i].load_state_dict(state_dict_i)
61 |
62 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
63 | out_ensemble = []
64 | for i in range(self.n):
65 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
66 | out_ensemble.append(out_i.unsqueeze(0))
67 |
68 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
69 |
--------------------------------------------------------------------------------
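Putting the vanilla models/transformer pieces together, a minimal end-to-end sketch (not part of the original source tree; vocab_size, max_len and the special-token indices below are placeholder values, and the repository root is assumed to be on PYTHONPATH):

```python
import torch
from models.transformer.encoders import TransformerEncoder
from models.transformer.decoders import TransformerDecoderLayer
from models.transformer.transformer import Transformer

encoder = TransformerEncoder(3, padding_idx=0, d_in=2048)                                # 3 encoder layers
decoder = TransformerDecoderLayer(vocab_size=10201, max_len=54, N_dec=3, padding_idx=0)  # 3 decoder layers
model = Transformer(bos_idx=2, encoder=encoder, decoder=decoder)                         # bos_idx is a placeholder

images = torch.randn(2, 49, 2048)            # (b_s, num_grids, d_in) grid features
seq = torch.randint(3, 10201, (2, 20))       # (b_s, seq_len) caption token ids

log_probs = model(images, seq)               # teacher-forced pass
print(log_probs.shape)                       # torch.Size([2, 20, 10201]) per-token log-probabilities
```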
/models/transformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | def position_embedding(input, d_model):
7 | input = input.view(-1, 1)
8 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
9 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
10 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
11 |
12 | out = torch.zeros((input.shape[0], d_model), device=input.device)
13 | out[:, ::2] = sin
14 | out[:, 1::2] = cos
15 | return out
16 |
17 |
18 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
19 | pos = torch.arange(max_len, dtype=torch.float32)
20 | out = position_embedding(pos, d_model)
21 |
22 | if padding_idx is not None:
23 | out[padding_idx] = 0
24 | return out
25 |
26 |
27 | class PositionWiseFeedForward(nn.Module):
28 | '''
29 | Position-wise feed forward layer
30 | '''
31 |
32 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
33 | super(PositionWiseFeedForward, self).__init__()
34 | self.identity_map_reordering = identity_map_reordering
35 | self.fc1 = nn.Linear(d_model, d_ff)
36 | self.fc2 = nn.Linear(d_ff, d_model)
37 | self.dropout = nn.Dropout(p=dropout)
38 | self.dropout_2 = nn.Dropout(p=dropout)
39 | self.layer_norm = nn.LayerNorm(d_model)
40 |
41 | def forward(self, input):
42 | if self.identity_map_reordering:
43 | out = self.layer_norm(input)
44 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
45 | out = input + self.dropout(torch.relu(out))
46 | else:
47 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
48 | out = self.dropout(out)
49 | out = self.layer_norm(input + out)
50 | return out
51 |
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | You can download our pretrained models here.
2 |
3 | #### 1. Pretrained language model
4 | [language_context.pth](https://drive.google.com/file/d/1sayx7qwOd79XE4RFdvSXG3zyQH4FpyYJ/view?usp=sharing)
5 | #### 2. Pretrained RSTNet model
6 | [rstnet.pth](https://drive.google.com/file/d/13Faz1NIjvwSqBofDo7yO94bgbEdFtKgE/view?usp=sharing)
7 |
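8 | As a rough sketch, the checkpoints can be loaded the same way `test_offline.py` builds the model; the paths below are placeholders, so adjust them to wherever you saved the downloaded files:
9 |
10 | ```python
11 | import pickle
12 | import torch
13 |
14 | from data import TextField
15 | from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention
16 |
17 | # vocabulary used at training time (vocab.pkl ships with this repo)
18 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)
19 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))
20 |
21 | # same architecture hyper-parameters as in test_offline.py
22 | encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40})
23 | decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], language_model_path='./pretrained_models/language_context.pth')
24 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(torch.device('cuda'))
25 | model.load_state_dict(torch.load('./pretrained_models/rstnet.pth')['state_dict'])
26 | ```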
--------------------------------------------------------------------------------
/switch_datatype.py:
--------------------------------------------------------------------------------
1 | # switch data from float32 to float16
2 | import os
3 | import torch
4 | from tqdm import tqdm
5 | import numpy as np
6 | import argparse
7 |
8 |
9 | def main(args):
10 |
11 | data_splits = os.listdir(args.dir_to_save_feats)
12 |
13 | for data_split in data_splits:
14 | print('processing {} ...'.format(data_split))
15 | if not os.path.exists(os.path.join(args.dir_to_save_float16_feats, data_split)):
16 | os.mkdir(os.path.join(args.dir_to_save_float16_feats, data_split))
17 |
18 | feat_dir = os.path.join(args.dir_to_save_feats, data_split)
19 | file_names = os.listdir(feat_dir)
20 | print(len(file_names))
21 |
22 | for i in tqdm(range(len(file_names))):
23 | file_name = file_names[i]
24 | file_path = os.path.join(args.dir_to_save_feats, data_split, file_name)
25 | data32 = torch.load(file_path).numpy().squeeze()
26 | data16 = data32.astype('float16')
27 |
28 | image_id = int(file_name.split('.')[0])
29 | saved_file_path = os.path.join(args.dir_to_save_float16_feats, data_split, str(image_id)+'.npy')
30 | np.save(saved_file_path, data16)
31 |
32 |
33 | if __name__ == '__main__':
34 |
35 | parser = argparse.ArgumentParser(description='switch the data type of features')
36 | parser.add_argument('--dir_to_save_feats', type=str, default='./Datasets/X101-features/', help='directory of the original float32 features')
37 | parser.add_argument('--dir_to_save_float16_feats', type=str, default='./Datasets/X101-features-float16', help='directory to save the float16 features')
38 | args = parser.parse_args()
39 |
40 | main(args)
41 |
42 |
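43 | # Example invocation (flags as defined above; adjust the directories to your setup):
44 | #   python switch_datatype.py --dir_to_save_feats ./Datasets/X101-features/ --dir_to_save_float16_feats ./Datasets/X101-features-float16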
--------------------------------------------------------------------------------
/tensorboard_logs/rstnet/events.out.tfevents.1603849421.MAC-U2S5.25151.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849421.MAC-U2S5.25151.0
--------------------------------------------------------------------------------
/tensorboard_logs/rstnet/events.out.tfevents.1603849484.MAC-U2S5.28641.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849484.MAC-U2S5.28641.0
--------------------------------------------------------------------------------
/tensorboard_logs/rstnet/events.out.tfevents.1603849514.MAC-U2S5.30196.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603849514.MAC-U2S5.30196.0
--------------------------------------------------------------------------------
/tensorboard_logs/rstnet/events.out.tfevents.1603850322.MAC-U2S5.11650.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1603850322.MAC-U2S5.11650.0
--------------------------------------------------------------------------------
/tensorboard_logs/rstnet/events.out.tfevents.1605359164.MAC-U2S5.4519.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/rstnet/events.out.tfevents.1605359164.MAC-U2S5.4519.0
--------------------------------------------------------------------------------
/tensorboard_logs/transformer/events.out.tfevents.1594888876.socialmedia.34273.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/tensorboard_logs/transformer/events.out.tfevents.1594888876.socialmedia.34273.0
--------------------------------------------------------------------------------
/test_offline.py:
--------------------------------------------------------------------------------
1 | # Offline test: evaluating the performance of the captioning model on the Karpathy test split of MS-COCO.
2 |
3 | import random
4 | from data import ImageDetectionsField, TextField, RawField
5 | from data import COCO, DataLoader
6 | import evaluation
7 | from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention
8 |
9 | import torch
10 | from tqdm import tqdm
11 | import argparse
12 | import pickle
13 | import numpy as np
14 | import time
15 |
16 | random.seed(1234)
17 | torch.manual_seed(1234)
18 | np.random.seed(1234)
19 |
20 |
21 | def predict_captions(model, dataloader, text_field):
22 | import itertools
23 | model.eval()
24 | gen = {}
25 | gts = {}
26 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
27 | for it, (images, caps_gt) in enumerate(iter(dataloader)):
28 | images = images.to(device)
29 | with torch.no_grad():
30 | out, _ = model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
31 | caps_gen = text_field.decode(out, join_words=False)
32 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
33 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
34 | gen['%d_%d' % (it, i)] = [gen_i.strip(), ]
35 | gts['%d_%d' % (it, i)] = gts_i
36 | pbar.update()
37 |
38 | gts = evaluation.PTBTokenizer.tokenize(gts)
39 | gen = evaluation.PTBTokenizer.tokenize(gen)
40 | scores, _ = evaluation.compute_scores(gts, gen)
41 |
42 | return scores
43 |
44 |
45 | if __name__ == '__main__':
46 | start_time = time.time()
47 | device = torch.device('cuda')
48 |
49 | parser = argparse.ArgumentParser(description='Relationship-Sensitive Transformer Network')
50 | parser.add_argument('--batch_size', type=int, default=10)
51 | parser.add_argument('--workers', type=int, default=4)
52 | parser.add_argument('--m', type=int, default=40)
53 |
54 | parser.add_argument('--features_path', type=str, default='./datasets/X101-features/X101-grid-coco_trainval.hdf5')
55 | parser.add_argument('--annotation_folder', type=str, default='./datasets/m2_annotations')
56 |
57 | # the path of tested model and vocabulary
58 | parser.add_argument('--language_model_path', type=str, default='./saved_language_models/language_context.pth')
59 | parser.add_argument('--model_path', type=str, default='./saved_transformer_models/rstnet_best.pth')
60 | parser.add_argument('--vocab_path', type=str, default='./vocab.pkl')
61 | args = parser.parse_args()
62 |
63 | print('The Offline Evaluation of RSTNet')
64 |
65 | # Pipeline for image regions
66 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=49, load_in_tmp=False)
67 |
68 | # Pipeline for text
69 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
70 | remove_punctuation=True, nopoints=False)
71 |
72 | # Load the dataset
73 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
74 | _, _, test_dataset = dataset.splits
75 | text_field.vocab = pickle.load(open(args.vocab_path, 'rb'))
76 |
77 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
78 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)
79 |
80 | # Load the model
81 | encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': args.m})
82 | decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'], language_model_path=args.language_model_path)
83 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
84 |
85 | data = torch.load(args.model_path)
86 | model.load_state_dict(data['state_dict'])
87 |
88 | # Compute the scores
89 | scores = predict_captions(model, dict_dataloader_test, text_field)
90 | print(scores)
91 | print('it costs {} s to test.'.format(time.time() - start_time))
92 |
93 |
94 |
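95 | # Example invocation (paths follow the argparse defaults above; adjust them to your setup):
96 | #   python test_offline.py --features_path ./datasets/X101-features/X101-grid-coco_trainval.hdf5 --annotation_folder ./datasets/m2_annotations --model_path ./saved_transformer_models/rstnet_best.pth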
--------------------------------------------------------------------------------
/test_online.py:
--------------------------------------------------------------------------------
1 | # Online test: evaluating the performance of the captioning model on the official MS-COCO test server.
2 | # This script generates the caption files for the official test and validation sets in the required format; zip these two files and upload the archive to
3 | # [CodaLab](https://competitions.codalab.org/competitions/3221#participate) to obtain the online evaluation results and ranking.
4 |
5 | import torch
6 | import argparse
7 | import pickle
8 | import numpy as np
9 | import itertools
10 | import json
11 | import os
12 |
13 | from tqdm import tqdm
14 |
15 | from data import TextField, COCO_TestOnline
16 | from models.rstnet import Transformer, TransformerEncoder, TransformerDecoderLayer, ScaledDotProductAttention, TransformerEnsemble
17 |
18 | import random
19 | random.seed(1234)
20 | torch.manual_seed(1234)
21 | np.random.seed(1234)
22 |
23 | from torch.utils.data import DataLoader
24 |
25 |
26 | def gen_caps(captioning_model, dataset, batch_size=10, workers=0):
27 | dataloader = DataLoader(
28 | dataset,
29 | batch_size=batch_size,
30 | num_workers=workers
31 | )
32 |
33 | outputs = []
34 | with tqdm(total=len(dataloader)) as pbar:
35 | for it, (image_ids, images) in enumerate(iter(dataloader)):
36 | images = images.to(device)
37 | with torch.no_grad():
38 | out, _ = captioning_model.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
39 | caps_gen = text_field.decode(out, join_words=False)
40 | caps_gen = [' '.join([k for k, g in itertools.groupby(gen_i)]).strip() for gen_i in caps_gen]
41 | for i in range(image_ids.size(0)):
42 | item = {}
43 | item['image_id'] = int(image_ids[i])
44 | item['caption'] = caps_gen[i]
45 | outputs.append(item)
46 | pbar.update()
47 | return outputs
48 |
49 |
50 | def save_results(outputs, datasplit, dir_to_save_caps):
51 | if not os.path.exists(dir_to_save_caps):
52 | os.makedirs(dir_to_save_caps)
53 | # Naming convention: captions_test2014_XXX_results.json and captions_val2014_XXX_results.json
54 | output_path = os.path.join(dir_to_save_caps, 'captions_' + datasplit + '2014_RSTNet_results.json')
55 | with open(output_path, 'w') as f:
56 | json.dump(outputs, f)
57 |
58 |
59 | if __name__ == '__main__':
60 |
61 | device = torch.device('cuda')
62 |
63 | parser = argparse.ArgumentParser(description='Relationship-Sensitive Transformer Network')
64 | parser.add_argument('--batch_size', type=int, default=10)
65 | parser.add_argument('--workers', type=int, default=0)
66 |
67 | parser.add_argument('--datasplit', type=str, default='test') # test, val
68 |
69 | # Test set
70 | parser.add_argument('--test_features_path', type=str, default='./datasets/X101-features/X101_grid_feats_coco_test.hdf5')
71 | parser.add_argument('--test_annotation_folder', type=str, default='./datasets/m2_annotations/image_info_test2014.json')
72 |
73 | # Validation set
74 | parser.add_argument('--val_features_path', type=str, default='./datasets/X101-features/X101_grid_feats_coco_trainval.hdf5')
75 | parser.add_argument('--val_annotation_folder', type=str, default='/home/DATA/m2_annotations/captions_val2014.json')
76 |
77 | # Model checkpoints (used for ensembling)
78 | parser.add_argument('--models_path', type=str, nargs='+', default=[
79 | './test_online/models/rstnet_x101_1.pth',
80 | './test_online/models/rstnet_x101_2.pth',
81 | './test_online/models/rstnet_x101_3.pth',
82 | './test_online/models/rstnet_x101_4.pth'
83 | ])
84 |
85 | parser.add_argument('--dir_to_save_caps', type=str, default='./test_online/results/')  # directory to save the caption files
86 |
87 | args = parser.parse_args()
88 |
89 | print('The Online Evaluation of RSTNet')
90 |
91 | # Load the dataset
92 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
93 | remove_punctuation=True, nopoints=False)
94 | text_field.vocab = pickle.load(open('./vocab.pkl', 'rb'))
95 |
96 | if args.datasplit == 'test':
97 | dataset = COCO_TestOnline(feat_path=args.test_features_path, ann_file=args.test_annotation_folder)
98 | else:
99 | dataset = COCO_TestOnline(feat_path=args.val_features_path, ann_file=args.val_annotation_folder)
100 |
101 | # Load the model weights
102 | # Model architecture
103 | encoder = TransformerEncoder(3, 0, attention_module=ScaledDotProductAttention, attention_module_kwargs={'m': 40})
104 | decoder = TransformerDecoderLayer(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
105 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
106 | # Ensemble model
107 | ensemble_model = TransformerEnsemble(model=model, weight_files=args.models_path)
108 |
109 | # Generate the captions
110 | outputs = gen_caps(ensemble_model, dataset, batch_size=args.batch_size, workers=args.workers)
111 |
112 | # Save the results
113 | save_results(outputs, args.datasplit, args.dir_to_save_caps)
114 |
115 | print('finished!')
116 |
117 |
118 |
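119 | # Example of preparing the CodaLab upload (file names follow save_results above; the zip name is arbitrary):
120 | #   python test_online.py --datasplit test
121 | #   python test_online.py --datasplit val
122 | #   cd ./test_online/results && zip rstnet_results.zip captions_test2014_RSTNet_results.json captions_val2014_RSTNet_results.json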
--------------------------------------------------------------------------------
/train_language.py:
--------------------------------------------------------------------------------
1 | import random
2 | from data import ImageDetectionsField, TextField, RawField
3 | from data import COCO, DataLoader
4 |
5 | from models.rstnet.language_model import LanguageModel
6 |
7 | import torch
8 | from torch.optim import Adam
9 | from torch.optim.lr_scheduler import LambdaLR
10 | from torch.nn import NLLLoss
11 | from tqdm import tqdm
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 | import argparse
15 | import os
16 | import pickle
17 | import numpy as np
18 | from shutil import copyfile
19 |
20 | random.seed(1234)
21 | torch.manual_seed(1234)
22 | np.random.seed(1234)
23 |
24 |
25 | def evaluate_loss(model, dataloader, loss_fn, text_field):
26 | # Validation loss
27 | model.eval()
28 | running_loss = .0
29 | with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar:
30 | with torch.no_grad():
31 | for it, (detections, captions) in enumerate(dataloader):
32 | detections, captions = detections.to(device), captions.to(device)
33 | # out = model(detections, captions)
34 | out, _ = model(captions)
35 | captions = captions[:, 1:].contiguous()
36 | out = out[:, :-1].contiguous()
37 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1))
38 | this_loss = loss.item()
39 | running_loss += this_loss
40 |
41 | pbar.set_postfix(loss=running_loss / (it + 1))
42 | pbar.update()
43 |
44 | val_loss = running_loss / len(dataloader)
45 | return val_loss
46 |
47 |
48 | def evaluate_metrics(model, dataloader, text_field):
49 | model.eval()
50 | scores = {}
51 | total_num = 0.
52 | correct_num = 0.
53 | with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar:
54 | with torch.no_grad():
55 | for it, (detections, captions) in enumerate(dataloader):
56 | detections, captions = detections.to(device), captions.to(device)
57 | # out = model(detections, captions)
58 | out, _ = model(captions)
59 | captions = captions[:, 1:].contiguous()
60 | out = torch.argmax(out[:, :-1], dim=-1).contiguous()
61 | b_s, seq_len = out.size()
62 | total_num += float(b_s * seq_len)
63 | correct_num += float((out == captions).sum())
64 | pbar.update()
65 |
66 | scores['correct_num'] = correct_num
67 | scores['total_num'] = total_num
68 | scores['accuracy'] = correct_num / total_num
69 | return scores
70 |
71 |
72 | def train_xe(model, dataloader_train, optim, text_field):
73 | # Training with cross-entropy
74 | model.train()
75 | scheduler.step()
76 | running_loss = .0
77 | # print('lr = {}'.format(scheduler.get_lr()[0]))
78 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader_train)) as pbar:
79 | for it, (detections, captions) in enumerate(dataloader_train):
80 | detections, captions = detections.to(device), captions.to(device)
81 | # out = model(detections, captions)
82 | out, _ = model(captions)
83 | optim.zero_grad()
84 | captions_gt = captions[:, 1:].contiguous()
85 | out = out[:, :-1].contiguous()
86 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
87 | loss.backward()
88 |
89 | optim.step()
90 | this_loss = loss.item()
91 | running_loss += this_loss
92 |
93 | pbar.set_postfix(loss=running_loss / (it + 1))
94 | pbar.update()
95 | scheduler.step()
96 |
97 | loss = running_loss / len(dataloader_train)
98 | return loss
99 |
100 |
101 | if __name__ == '__main__':
102 | device = torch.device('cuda')
103 | parser = argparse.ArgumentParser(description='Bert Language Model')
104 | parser.add_argument('--exp_name', type=str, default='bert_language')
105 | parser.add_argument('--batch_size', type=int, default=50)
106 | parser.add_argument('--workers', type=int, default=4)
107 | parser.add_argument('--m', type=int, default=40)
108 | parser.add_argument('--head', type=int, default=8)
109 | parser.add_argument('--warmup', type=int, default=11328)
110 | parser.add_argument('--resume_last', action='store_true')
111 | parser.add_argument('--resume_best', action='store_true')
112 |
113 | parser.add_argument('--features_path', type=str, default='./Datasets/X101-features/X101-grid-coco_trainval.hdf5')
114 | parser.add_argument('--annotation_folder', type=str, default='./Datasets/m2_annotations')
115 |
116 | parser.add_argument('--dir_to_save_model', type=str, default='./saved_language_models')
117 | parser.add_argument('--logs_folder', type=str, default='./language_tensorboard_logs')
118 |
119 | args = parser.parse_args()
120 | print(args)
121 |
122 | print('Bert Language Model Training')
123 | # preparation
124 | if not os.path.exists(args.dir_to_save_model):
125 | os.makedirs(args.dir_to_save_model)
126 | if not os.path.exists(args.logs_folder):
127 | os.makedirs(args.logs_folder)
128 | writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))
129 |
130 | # Pipeline for image regions
131 | image_field = ImageDetectionsField(detections_path=args.features_path, max_detections=50, load_in_tmp=False)
132 |
133 | # Pipeline for text
134 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
135 | remove_punctuation=True, nopoints=False)
136 |
137 | # Create the dataset
138 | dataset = COCO(image_field, text_field, 'coco/images/', args.annotation_folder, args.annotation_folder)
139 | train_dataset, val_dataset, test_dataset = dataset.splits
140 |
141 | if not os.path.isfile('vocab.pkl'):
142 | print("Building vocabulary")
143 | text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
144 | pickle.dump(text_field.vocab, open('vocab.pkl', 'wb'))
145 | else:
146 | print('Loading from vocabulary')
147 | text_field.vocab = pickle.load(open('vocab.pkl', 'rb'))
148 |
149 | model = LanguageModel(padding_idx=text_field.vocab.stoi['<pad>'], bert_hidden_size=768, vocab_size=len(text_field.vocab)).to(device)
150 |
151 | dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()})
152 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
153 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
154 |
155 | def lambda_lr(s):
156 | warm_up = args.warmup
157 | s += 1
158 | if s % 11331 == 0:
159 | s = 1
160 | else:
161 | s = s % 11331
162 |
163 | lr = (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
164 | if lr > 1e-6:
165 | lr = 1e-6
166 |
167 | print('s = {}, lr = {}'.format(s, lr))
168 | return lr
169 |
170 | # Initial conditions
171 | optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
172 | scheduler = LambdaLR(optim, lambda_lr)
173 | # scheduler = StepLR(optim, step_size=2, gamma=0.5)
174 | loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
175 | use_rl = False
176 | best_score = .0
177 | best_test_score = .0
178 | patience = 0
179 | start_epoch = 0
180 |
181 | if args.resume_last or args.resume_best:
182 | if args.resume_last:
183 | fname = os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name)
184 | else:
185 | fname = os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name)
186 |
187 | if os.path.exists(fname):
188 | data = torch.load(fname)
189 | torch.set_rng_state(data['torch_rng_state'])
190 | torch.cuda.set_rng_state(data['cuda_rng_state'])
191 | np.random.set_state(data['numpy_rng_state'])
192 | random.setstate(data['random_rng_state'])
193 | model.load_state_dict(data['state_dict'], strict=False)
194 | optim.load_state_dict(data['optimizer'])
195 | scheduler.load_state_dict(data['scheduler'])
196 | start_epoch = data['epoch'] + 1
197 | best_score = data['best_score']
198 | patience = data['patience']
199 | use_rl = data['use_rl']
200 | print('Resuming from epoch %d, validation loss %f, and best score %f' % (
201 | data['epoch'], data['val_loss'], data['best_score']))
202 |
203 | print("Training starts")
204 | for e in range(start_epoch, start_epoch + 100):
205 | dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
206 | drop_last=True)
207 | dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
208 | dataloader_test = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
209 |
210 | dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True,
211 | num_workers=args.workers)
212 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
213 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)
214 |
215 | if not use_rl:
216 | train_loss = train_xe(model, dataloader_train, optim, text_field)
217 | writer.add_scalar('data/train_loss', train_loss, e)
218 | else:
219 | break
220 |
221 | # Validation loss
222 | val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field)
223 | writer.add_scalar('data/val_loss', val_loss, e)
224 |
225 | # Validation scores
226 | val_scores = evaluate_metrics(model, dataloader_val, text_field)
227 | print("epoch {}: Validation scores", val_scores)
228 | val_score = val_scores['accuracy']
229 | writer.add_scalar('data/val_score', val_score, e)
230 |
231 | # Test scores
232 | test_scores = evaluate_metrics(model, dataloader_test, text_field)
233 | print("epoch {}: Test scores", test_scores)
234 | test_score = test_scores['accuracy']
235 | writer.add_scalar('data/test_score', test_score, e)
236 |
237 | # Prepare for next epoch
238 | best = False
239 | if val_score >= best_score:
240 | best_score = val_score
241 | patience = 0
242 | best = True
243 | else:
244 | patience += 1
245 |
246 | best_test = False
247 | if test_score >= best_test_score:
248 | best_test_score = test_score
249 | best_test = True
250 |
251 | switch_to_rl = False
252 | exit_train = False
253 | if patience == 5:
254 | if not use_rl:
255 | use_rl = True
256 | switch_to_rl = True
257 | patience = 0
258 | break
259 | else:
260 | print('patience reached.')
261 | exit_train = True
262 |
263 | if switch_to_rl and not best:
264 | data = torch.load(os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name))
265 | torch.set_rng_state(data['torch_rng_state'])
266 | torch.cuda.set_rng_state(data['cuda_rng_state'])
267 | np.random.set_state(data['numpy_rng_state'])
268 | random.setstate(data['random_rng_state'])
269 | model.load_state_dict(data['state_dict'])
270 | print('Resuming from epoch %d, validation loss %f, and best score %f' % (
271 | data['epoch'], data['val_loss'], data['best_score']))
272 |
273 | torch.save({
274 | 'torch_rng_state': torch.get_rng_state(),
275 | 'cuda_rng_state': torch.cuda.get_rng_state(),
276 | 'numpy_rng_state': np.random.get_state(),
277 | 'random_rng_state': random.getstate(),
278 | 'epoch': e,
279 | 'val_loss': val_loss,
280 | 'val_score': val_score,
281 | 'state_dict': model.state_dict(),
282 | 'optimizer': optim.state_dict(),
283 | 'scheduler': scheduler.state_dict(),
284 | 'patience': patience,
285 | 'best_score': best_score,
286 | 'use_rl': use_rl,
287 | }, os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name))
288 |
289 | if best:
290 | copyfile(os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name), os.path.join(args.dir_to_save_model, '%s_best.pth' % args.exp_name))
291 | if best_test:
292 | copyfile(os.path.join(args.dir_to_save_model, '%s_last.pth' % args.exp_name), os.path.join(args.dir_to_save_model, '%s_best_test.pth' % args.exp_name))
293 |
294 | if exit_train:
295 | writer.close()
296 | break
297 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import download_from_url
2 | from .typing import *
3 |
4 | def get_batch_size(x: TensorOrSequence) -> int:
5 | if isinstance(x, torch.Tensor):
6 | b_s = x.size(0)
7 | else:
8 | b_s = x[0].size(0)
9 | return b_s
10 |
11 |
12 | def get_device(x: TensorOrSequence) -> int:
13 | if isinstance(x, torch.Tensor):
14 | b_s = x.device
15 | else:
16 | b_s = x[0].device
17 | return b_s
18 |
--------------------------------------------------------------------------------
/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Sequence, Tuple
2 | import torch
3 |
4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
5 | TensorOrNone = Union[torch.Tensor, None]
6 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_from_url(url, path):
4 | """Download file, with logic (from tensor2tensor) for Google Drive"""
5 | if 'drive.google.com' not in url:
6 | print('Downloading %s; may take a few minutes' % url)
7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
8 | with open(path, "wb") as file:
9 | file.write(r.content)
10 | return
11 | print('Downloading from Google Drive; may take a few minutes')
12 | confirm_token = None
13 | session = requests.Session()
14 | response = session.get(url, stream=True)
15 | for k, v in response.cookies.items():
16 | if k.startswith("download_warning"):
17 | confirm_token = v
18 |
19 | if confirm_token:
20 | url = url + "&confirm=" + confirm_token
21 | response = session.get(url, stream=True)
22 |
23 | chunk_size = 16 * 1024
24 | with open(path, "wb") as f:
25 | for chunk in response.iter_content(chunk_size):
26 | if chunk:
27 | f.write(chunk)
28 |
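29 | # Example usage (the URL and destination path are placeholders):
30 | #   download_from_url('https://drive.google.com/uc?id=<file_id>&export=download', './pretrained_models/rstnet.pth')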
--------------------------------------------------------------------------------
/vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxuying1004/RSTNet/d308459ad1c28e0f4c2c3e4a9986ae76870e3c72/vocab.pkl
--------------------------------------------------------------------------------