├── misc
│   ├── __init__.py
│   ├── resnet_utils.py
│   ├── loss_wrapper.py
│   ├── resnet.py
│   ├── rewards.py
│   └── utils.py
├── models
│   ├── beam_search
│   │   ├── __init__.py
│   │   └── beam_search.py
│   ├── __init__.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── transformer.py
│   │   ├── encoders.py
│   │   ├── decoders.py
│   │   └── attention.py
│   ├── containers.py
│   └── captioning_model.py
├── evaluation
│   └── readme.txt
├── vocab_UCM.pkl
├── vocab_RSICD.pkl
├── vocab_Sydney.pkl
├── MG-Transformer.png
├── images
│   └── MG-Transformer.png
├── utils
│   ├── typing.py
│   ├── __init__.py
│   └── utils.py
├── data
│   ├── __init__.py
│   ├── example.py
│   ├── utils.py
│   ├── field.py
│   ├── dataset.py
│   └── vocab.py
├── feature_pro
│   ├── pre_CLIP_feature.py
│   ├── pre_region_feature.py
│   └── split_group.py
├── LICENSE
├── README.md
├── environment.yml
├── test.py
└── train.py
/misc/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/beam_search/__init__.py:
--------------------------------------------------------------------------------
1 | from .beam_search import BeamSearch
2 |
--------------------------------------------------------------------------------
/evaluation/readme.txt:
--------------------------------------------------------------------------------
1 | Link: https://pan.baidu.com/s/13ZfH-CMYbW3RsW0-RX7KKQ  Extraction code: wuiu
--------------------------------------------------------------------------------
/vocab_UCM.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_UCM.pkl
--------------------------------------------------------------------------------
/vocab_RSICD.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_RSICD.pkl
--------------------------------------------------------------------------------
/vocab_Sydney.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_Sydney.pkl
--------------------------------------------------------------------------------
/MG-Transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/MG-Transformer.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import Transformer
2 | from .captioning_model import CaptioningModel
3 |
--------------------------------------------------------------------------------
/images/MG-Transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/images/MG-Transformer.png
--------------------------------------------------------------------------------
/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import *
2 | from .encoders import *
3 | from .decoders import *
4 | from .attention import *
5 |
--------------------------------------------------------------------------------
/utils/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Sequence, Tuple
2 | import torch
3 |
4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor]
5 | TensorOrNone = Union[torch.Tensor, None]
6 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import download_from_url
2 | from .typing import *
3 |
4 | def get_batch_size(x: TensorOrSequence) -> int:
5 | if isinstance(x, torch.Tensor): # x 2,5,2048
6 | b_s = x.size(0)
7 | else:
8 | b_s = x[0].size(0)
9 | return b_s
10 |
11 |
12 | def get_device(x: TensorOrSequence) -> int:
13 | if isinstance(x, torch.Tensor):
14 | b_s = x.device
15 | else:
16 | b_s = x[0].device
17 | return b_s
18 |
--------------------------------------------------------------------------------
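
A quick usage sketch of the two helpers above, covering both accepted input types (run from the repository root; shapes are illustrative):

```
import torch
from utils import get_batch_size, get_device

feats = torch.zeros(2, 5, 2048)                          # a single tensor
assert get_batch_size(feats) == 2
assert get_device(feats) == torch.device('cpu')

pair = (torch.zeros(4, 49, 2048), torch.zeros(4, 512))   # a sequence of tensors
assert get_batch_size(pair) == 4                         # read from the first element
```
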
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .field import RawField, Merge, ImageDetectionsField, TextField
2 | from .dataset import Sydney,UCM,RSICD
3 | from torch.utils.data import DataLoader as TorchDataLoader
4 |
5 | class DataLoader(TorchDataLoader):
6 | def __init__(self, dataset, *args, **kwargs):
7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
8 |
9 | def __len__(self):
10 | return len(self.dataset.examples)//self.batch_size
11 |
--------------------------------------------------------------------------------
/misc/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class myResnet(nn.Module): #
7 | def __init__(self, resnet):
8 | super(myResnet, self).__init__()
9 | self.resnet = resnet
10 |
11 | def forward(self, img, att_size=14):
12 | x = img.unsqueeze(0)
13 | x = self.resnet.conv1(x)
14 | x = self.resnet.bn1(x)
15 | x = self.resnet.relu(x)
16 | x = self.resnet.maxpool(x)
17 | x = self.resnet.layer1(x)
18 | x = self.resnet.layer2(x)
19 | x = self.resnet.layer3(x)
20 | x = self.resnet.layer4(x)
21 |
22 | return x.squeeze()
23 |
--------------------------------------------------------------------------------
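
A minimal usage sketch of `myResnet` (assuming a ResNet-152 backbone and a 3×224×224 input; the resulting 2048×7×7 conv5 map is what `feature_pro/split_group.py` later reshapes into 49 positions):

```
import torch
import misc.resnet as resnet
from misc.resnet_utils import myResnet

net = resnet.resnet152(pretrained=False)   # or load local weights as in pre_region_feature.py
extractor = myResnet(net).eval()

img = torch.randn(3, 224, 224)             # one preprocessed image (C, H, W)
with torch.no_grad():
    conv5 = extractor(img)                 # forward unsqueezes a batch dim, then squeezes it back out
print(conv5.shape)                         # expected: torch.Size([2048, 7, 7])
```

Note that the `att_size` argument is accepted but not used by `forward`; only the conv5 feature map is returned.
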
/data/example.py:
--------------------------------------------------------------------------------
1 | class Example(object):
2 | """Defines a single training or test example.
3 | Stores each column of the example as an attribute.
4 |     Defines a single training or test example.
5 |     Stores each column of the example as an attribute.
6 | """
7 |
8 | @classmethod
9 | def fromdict(cls, data):
10 | ex = cls(data)
11 | return ex
12 |
13 | def __init__(self, data):
14 | for key, val in data.items():
15 | super(Example, self).__setattr__(key, val)
16 |
17 | def __setattr__(self, key, value):
18 | raise AttributeError
19 |
20 | def __hash__(self):
21 | return hash(tuple(x for x in self.__dict__.values()))
22 |
23 | def __eq__(self, other):
24 | this = tuple(x for x in self.__dict__.values())
25 | other = tuple(x for x in other.__dict__.values())
26 | return this == other
27 |
28 | def __ne__(self, other):
29 | return not self.__eq__(other)
30 |
--------------------------------------------------------------------------------
/feature_pro/pre_CLIP_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 | from PIL import Image
4 | import numpy as np
5 | import os
6 |
7 | device = "cuda" if torch.cuda.is_available() else "cpu"
8 | model, preprocess = clip.load("ViT-B/32", device=device)
9 |
10 | # sydney
11 | path = r"/media/dmd/ours/mlw/rs/Sydney_Captions/imgs"
12 | output_dir='/media/dmd/ours/mlw/rs/clip_feature/sydney_224'
13 |
14 |
15 | for root, dirs, files in os.walk(path):
16 | for i in files:
17 | i = os.path.join(path, i)
18 | image_name = i.split('/')[-1]
19 | image = preprocess(Image.open(i)).unsqueeze(0).to(device) # 1,3,224,224
20 | with torch.no_grad():
21 | image_features = model.encode_image(image) #
22 | """
23 | Note: Delete part of the code in CLIP VisionTransformer as follows:
24 | if self.proj is not None:
25 | x = x @ self.proj
26 | """
27 |
28 | np.save(os.path.join(output_dir, str(image_name)), image_features.data.cpu().float().numpy())
29 |
30 |
31 |
--------------------------------------------------------------------------------
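
The note in the script asks you to edit the installed CLIP package so that `VisionTransformer.forward` skips the final `x @ self.proj` projection. Because that projection is guarded by `if self.proj is not None:` (as quoted above), an equivalent and less invasive sketch is to disable it on the loaded model; this yields the 768-dim ViT-B/32 embedding that matches `d_clip=768` in `models/transformer/encoders.py`:

```
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
model.visual.proj = None   # skip the 768 -> 512 projection instead of editing CLIP's source

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)   # shape (1, 768) rather than (1, 512)
```
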
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_from_url(url, path):
4 | """Download file, with logic (from tensor2tensor) for Google Drive"""
5 | if 'drive.google.com' not in url:
6 | print('Downloading %s; may take a few minutes' % url)
7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
8 | with open(path, "wb") as file:
9 | file.write(r.content)
10 | return
11 | print('Downloading from Google Drive; may take a few minutes')
12 | confirm_token = None
13 | session = requests.Session()
14 | response = session.get(url, stream=True)
15 | for k, v in response.cookies.items():
16 | if k.startswith("download_warning"):
17 | confirm_token = v
18 |
19 | if confirm_token:
20 | url = url + "&confirm=" + confirm_token
21 | response = session.get(url, stream=True)
22 |
23 | chunk_size = 16 * 1024
24 | with open(path, "wb") as f:
25 | for chunk in response.iter_content(chunk_size):
26 | if chunk:
27 | f.write(chunk)
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 One-paper-luck
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MG-Transformer
2 |
3 |
4 |
5 |
6 | ## Installation and Dependencies
7 | Create the `m2` conda environment using the `environment.yml` file:
8 | ```
9 | conda env create -f environment.yml
10 | conda activate m2
11 | ```
12 | ## Data preparation
13 | For the evaluation metrics, please download [evaluation.zip](https://pan.baidu.com/s/13ZfH-CMYbW3RsW0-RX7KKQ) (BaiduPan code: wuiu) and extract it to `./evaluation`.
14 |
15 |
16 | For feature extraction:
17 | - Region features: `./feature_pro/pre_region_feature.py`
18 | - CLIP image embeddings: `./feature_pro/pre_CLIP_feature.py`
19 | - Group mask matrices: `./feature_pro/split_group.py`
20 |
21 |
22 | ## Train
23 | ```
24 | python train.py
25 | ```
26 |
27 | ## Evaluate
28 | ```
29 | python test.py
30 | ```
31 |
32 |
33 | ## Citation
34 | ```
35 | @ARTICLE{10298250,
36 | author={Meng, Lingwu and Wang, Jing and Meng, Ran and Yang, Yang and Xiao, Liang},
37 | journal={IEEE Transactions on Geoscience and Remote Sensing},
38 | title={A Multiscale Grouping Transformer with CLIP Latents for Remote Sensing Image Captioning},
39 | year={2024},
40 | volume={62},
41 | number={4703515},
42 | pages={1-15},
43 | doi={10.1109/TGRS.2024.3385500}}
44 | ```
45 |
46 |
47 |
48 | ## References
49 | 1. https://github.com/tylin/coco-caption
50 | 2. https://github.com/aimagelab/meshed-memory-transformer
51 |
--------------------------------------------------------------------------------
/misc/loss_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import misc.utils as utils
3 | from misc.rewards import init_scorer, get_self_critical_reward
4 |
5 |
6 | class LossWrapper(torch.nn.Module):
7 | def __init__(self, model, opt):
8 | super(LossWrapper, self).__init__()
9 | self.opt = opt
10 | self.model = model
11 | if opt.label_smoothing > 0:
12 | self.crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
13 | else:
14 | self.crit = utils.LanguageModelCriterion()
15 | self.rl_crit = utils.RewardCriterion()
16 |
17 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
18 | sc_flag):
19 | out = {}
20 | if not sc_flag:
21 | loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
22 | else:
23 | self.model.eval()
24 | with torch.no_grad():
25 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample')
26 | self.model.train()
27 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_method': 'sample'},
28 | mode='sample')
29 | gts = [gts[_] for _ in gt_indices.tolist()]
30 | reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
31 | reward = torch.from_numpy(reward).float().to(gen_result.device)
32 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
33 | out['reward'] = reward[:, 0].mean()
34 | out['loss'] = loss
35 | return out
36 |
--------------------------------------------------------------------------------
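
`misc/utils.py` (which provides `LanguageModelCriterion`, `LabelSmoothing`, and `RewardCriterion`) is not reproduced in this listing. For orientation, here is a minimal sketch of a reward-weighted criterion in the usual self-critical style; it illustrates the idea and is not the repository's actual implementation:

```
import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    """Policy-gradient loss: -logprob(sampled token) * reward, masked after the end token (index 0)."""
    def forward(self, sample_logprobs, seq, reward):
        # sample_logprobs, seq, reward: all of shape (batch, seq_len)
        mask = (seq > 0).float()
        # keep the end-token position itself, then zero everything after it
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        loss = -(sample_logprobs * reward * mask).sum() / mask.sum()
        return loss
```
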
/models/transformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | def l2norm(X, dim=-1, eps=1e-12):
6 | """L2-normalize columns of X
7 | """
8 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
9 | X = torch.div(X, norm)
10 | return X
11 |
12 | def position_embedding(input, d_model):
13 | input = input.view(-1, 1)
14 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
15 | sin = torch.sin(input / 10000 ** (2 * dim / d_model))
16 | cos = torch.cos(input / 10000 ** (2 * dim / d_model))
17 |
18 | out = torch.zeros((input.shape[0], d_model), device=input.device)
19 | out[:, ::2] = sin
20 | out[:, 1::2] = cos
21 | return out
22 |
23 |
24 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
25 | pos = torch.arange(max_len, dtype=torch.float32)
26 | out = position_embedding(pos, d_model)
27 |
28 | if padding_idx is not None:
29 | out[padding_idx] = 0
30 | return out
31 |
32 |
33 | class PositionWiseFeedForward(nn.Module):
34 | '''
35 | Position-wise feed forward layer
36 | '''
37 |
38 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
39 | super(PositionWiseFeedForward, self).__init__()
40 | self.identity_map_reordering = identity_map_reordering
41 | self.fc1 = nn.Linear(d_model, d_ff)
42 | self.fc2 = nn.Linear(d_ff, d_model)
43 | self.dropout = nn.Dropout(p=dropout)
44 | self.dropout_2 = nn.Dropout(p=dropout)
45 | self.layer_norm = nn.LayerNorm(d_model)
46 |
47 | def forward(self, input):
48 | if self.identity_map_reordering:
49 | out = self.layer_norm(input)
50 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
51 | out = input + self.dropout(torch.relu(out))
52 | else:
53 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
54 | out = self.dropout(out)
55 | out = self.layer_norm(input + out)
56 | return out
--------------------------------------------------------------------------------
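
A short usage sketch of `sinusoid_encoding_table`: the table is frozen into an `nn.Embedding`, which is how `models/transformer/decoders.py` consumes it (the lengths below are illustrative):

```
import torch
from torch import nn
from models.transformer.utils import sinusoid_encoding_table

table = sinusoid_encoding_table(20 + 1, 512, padding_idx=0)   # row 0 zeroed for padding
pos_emb = nn.Embedding.from_pretrained(table, freeze=True)

positions = torch.tensor([[1, 2, 3, 0]])    # 0 marks padded positions
print(pos_emb(positions).shape)             # torch.Size([1, 4, 512])
```
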
/misc/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils import model_zoo
4 | import torchvision.models.resnet
5 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
6 |
7 | class ResNet(torchvision.models.resnet.ResNet):
8 | def __init__(self, block, layers, num_classes=1000):
9 | super(ResNet, self).__init__(block, layers, num_classes)
10 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
11 | for i in range(2, 5):
12 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
13 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
14 |
15 | def resnet18(pretrained=False):
16 | """Constructs a ResNet-18 model.
17 |
18 | Args:
19 | pretrained (bool): If True, returns a model pre-trained on ImageNet
20 | """
21 | model = ResNet(BasicBlock, [2, 2, 2, 2])
22 | if pretrained:
23 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
24 | return model
25 |
26 |
27 | def resnet34(pretrained=False):
28 | """Constructs a ResNet-34 model.
29 |
30 | Args:
31 | pretrained (bool): If True, returns a model pre-trained on ImageNet
32 | """
33 | model = ResNet(BasicBlock, [3, 4, 6, 3])
34 | if pretrained:
35 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
36 | return model
37 |
38 |
39 | def resnet50(pretrained=False):
40 | """Constructs a ResNet-50 model.
41 |
42 | Args:
43 | pretrained (bool): If True, returns a model pre-trained on ImageNet
44 | """
45 | model = ResNet(Bottleneck, [3, 4, 6, 3])
46 | if pretrained:
47 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
48 | return model
49 |
50 |
51 | def resnet101(pretrained=False):
52 | """Constructs a ResNet-101 model.
53 |
54 | Args:
55 | pretrained (bool): If True, returns a model pre-trained on ImageNet
56 | """
57 | model = ResNet(Bottleneck, [3, 4, 23, 3])
58 | if pretrained:
59 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
60 | return model
61 |
62 |
63 | def resnet152(pretrained=False):
64 | """Constructs a ResNet-152 model.
65 |
66 | Args:
67 | pretrained (bool): If True, returns a model pre-trained on ImageNet
68 | """
69 | model = ResNet(Bottleneck, [3, 8, 36, 3])
70 | if pretrained:
71 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
72 | return model
--------------------------------------------------------------------------------
/misc/rewards.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import time
7 | import misc.utils as utils
8 | from collections import OrderedDict
9 | import torch
10 |
11 |
12 | from cider.pyciderevalcap.ciderD.ciderD import CiderD
13 | from cococaption.pycocoevalcap.bleu.bleu import Bleu
14 |
15 | CiderD_scorer = None
16 | Bleu_scorer = None
17 | #CiderD_scorer = CiderD(df='corpus')
18 |
19 | def init_scorer(cached_tokens):
20 | global CiderD_scorer
21 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
22 | global Bleu_scorer
23 | Bleu_scorer = Bleu_scorer or Bleu(4)
24 |
25 | def array_to_str(arr):
26 | out = ''
27 | for i in range(len(arr)):
28 | out += str(arr[i]) + ' '
29 | if arr[i] == 0:
30 | break
31 | return out.strip()
32 |
33 | def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
34 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
35 | seq_per_img = batch_size // len(data_gts)
36 |
37 | res = OrderedDict()
38 |
39 | gen_result = gen_result.data.cpu().numpy()
40 | greedy_res = greedy_res.data.cpu().numpy()
41 | for i in range(batch_size):
42 | res[i] = [array_to_str(gen_result[i])]
43 | for i in range(batch_size):
44 | res[batch_size + i] = [array_to_str(greedy_res[i])]
45 |
46 | gts = OrderedDict()
47 | for i in range(len(data_gts)):
48 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
49 |
50 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)]
51 | res__ = {i: res[i] for i in range(2 * batch_size)}
52 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
53 | if opt.cider_reward_weight > 0:
54 | _, cider_scores = CiderD_scorer.compute_score(gts, res_)
55 | print('Cider scores:', _)
56 | else:
57 | cider_scores = 0
58 | if opt.bleu_reward_weight > 0:
59 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
60 | bleu_scores = np.array(bleu_scores[3])
61 | print('Bleu scores:', _[3])
62 | else:
63 | bleu_scores = 0
64 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
65 |
66 | scores = scores[:batch_size] - scores[batch_size:]
67 |
68 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
69 |
70 | return rewards
71 |
--------------------------------------------------------------------------------
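
On the reward layout: `res_` holds the sampled captions (ids `0..batch_size-1`) followed by the greedy baselines (ids `batch_size..2*batch_size-1`), so the final lines subtract the greedy score from the sampled score per image and broadcast that advantage over all time steps. A tiny numeric sketch of that last step:

```
import numpy as np

# scorer outputs for 2 sampled captions followed by their 2 greedy baselines
scores = np.array([0.80, 0.40, 0.65, 0.55])
batch_size, seq_len = 2, 5

advantage = scores[:batch_size] - scores[batch_size:]       # [ 0.15, -0.15]
rewards = np.repeat(advantage[:, np.newaxis], seq_len, 1)   # shape (2, 5): same advantage at every step
print(rewards)
```
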
/models/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import copy
4 | from models.containers import ModuleList
5 | from ..captioning_model import CaptioningModel
6 |
7 |
8 | class Transformer(CaptioningModel):
9 | def __init__(self, bos_idx, encoder, decoder):
10 | super(Transformer, self).__init__()
11 | self.bos_idx = bos_idx
12 | self.encoder = encoder
13 | self.decoder = decoder
14 | self.register_state('enc_output', None)
15 | self.register_state('mask_enc', None)
16 | self.init_weights()
17 |
18 | @property
19 | def d_model(self):
20 | return self.decoder.d_model
21 |
22 | def init_weights(self):
23 | for p in self.parameters():
24 | if p.dim() > 1:
25 | nn.init.xavier_uniform_(p)
26 |
27 | def forward(self, images, images_gl, detections_mask, seq, isencoder=None, *args):
28 |
29 | enc_output, mask_enc = self.encoder(images, images_gl,detections_mask,isencoder=isencoder)
30 | dec_output = self.decoder(seq, enc_output, mask_enc)
31 | return dec_output
32 |
33 | def init_state(self, b_s, device):
34 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
35 | None, None]
36 |
37 | def step(self, t, prev_output, visual, visual_gl,detections_mask, seq, mode='teacher_forcing', **kwargs):
38 | it = None
39 | if mode == 'teacher_forcing':
40 | raise NotImplementedError
41 | elif mode == 'feedback':
42 | if t == 0:
43 | self.enc_output, self.mask_enc = self.encoder(visual,visual_gl,detections_mask,isencoder=True)
44 | if isinstance(visual, torch.Tensor):
45 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
46 | else:
47 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
48 | else:
49 | it = prev_output
50 |
51 | return self.decoder(it, self.enc_output, self.mask_enc)
52 |
53 |
54 | class TransformerEnsemble(CaptioningModel):
55 | def __init__(self, model: Transformer, weight_files):
56 | super(TransformerEnsemble, self).__init__()
57 | self.n = len(weight_files)
58 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
59 | for i in range(self.n):
60 | state_dict_i = torch.load(weight_files[i])['state_dict']
61 | self.models[i].load_state_dict(state_dict_i)
62 |
63 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
64 | out_ensemble = []
65 | for i in range(self.n):
66 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
67 | out_ensemble.append(out_i.unsqueeze(0))
68 |
69 | return torch.mean(torch.cat(out_ensemble, 0), dim=0)
70 |
--------------------------------------------------------------------------------
/feature_pro/pre_region_feature.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import json
7 | import argparse
8 | import numpy as np
9 | import torch
10 | import skimage.io
11 | import skimage.transform
12 | from torchvision import transforms as trn
13 | from misc.resnet_utils import myResnet
14 | import misc.resnet as resnet
15 |
16 | preprocess = trn.Compose([
17 | # trn.ToTensor(),
18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19 | ])
20 |
21 |
22 | def main(params):
23 | net = getattr(resnet, params['model'])()
24 | net.load_state_dict(torch.load(os.path.join(params['model_root'], params['model'] + '.pth')))
25 | my_resnet = myResnet(net)
26 | my_resnet.cuda()
27 | my_resnet.eval()
28 |
29 | imgs = json.load(open(params['input_json'], 'r'))
30 | imgs = imgs['images']
31 | N = len(imgs)
32 |
33 | for i, img in enumerate(imgs):
34 | I = skimage.io.imread(os.path.join(params['images_root'], img['filename'])) # sydney 500 500 3
35 | I = skimage.transform.resize(I, (224, 224))
36 |
37 | if len(I.shape) == 2:
38 | I = I[:, :, np.newaxis]
39 | I = np.concatenate((I, I, I), axis=2)
40 |
41 |         I = I.astype('float32')  # skimage.transform.resize above already rescales to [0, 1]
42 | I = torch.from_numpy(I.transpose([2, 0, 1])).cuda()
43 | I = preprocess(I)
44 |
45 | with torch.no_grad():
46 | conv5= my_resnet(I,params['att_size'])
47 |
48 |         np.save(os.path.join(params['output_dir'], str(img['filename'])), conv5.data.cpu().float().numpy())
49 |
50 | if i % 1000 == 0:
51 | print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N))
52 | print('wrote ', params['output_dir'])
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--input_json', default='/media/dmd/ours/mlw/rs/Sydney_Captions/dataset.json')
58 | # parser.add_argument('--output_dir', default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_500')
59 | parser.add_argument('--output_dir', default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_name')
60 | parser.add_argument('--images_root', default='/media/dmd/ours/mlw/rs/Sydney_Captions/imgs',
61 | help='root location in which images are stored, to be prepended to file_path in input json')
62 |
63 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7')
64 | parser.add_argument('--model', default='resnet152', type=str, help='resnet50,resnet101, resnet152')
65 | parser.add_argument('--model_root', default='/media/dmd/ours/mlw/pre_model', type=str,
66 | help='model root')
67 |
68 | args = parser.parse_args()
69 | params = vars(args) # convert to ordinary dict
70 | print('parsed input parameters:')
71 | print(json.dumps(params, indent=2))
72 | main(params)
73 |
--------------------------------------------------------------------------------
/models/containers.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from torch import nn
3 | from utils.typing import *
4 |
5 |
6 | class Module(nn.Module):
7 | def __init__(self):
8 | super(Module, self).__init__()
9 | self._is_stateful = False
10 | self._state_names = []
11 | self._state_defaults = dict()
12 |
13 | def register_state(self, name: str, default: TensorOrNone):
14 | self._state_names.append(name)
15 | if default is None:
16 | self._state_defaults[name] = None
17 | else:
18 | self._state_defaults[name] = default.clone().detach()
19 | self.register_buffer(name, default)
20 |
21 | def states(self):
22 | for name in self._state_names:
23 | yield self._buffers[name]
24 | for m in self.children():
25 | if isinstance(m, Module):
26 | yield from m.states()
27 |
28 | def apply_to_states(self, fn):
29 | for name in self._state_names:
30 | self._buffers[name] = fn(self._buffers[name])
31 | for m in self.children():
32 | if isinstance(m, Module):
33 | m.apply_to_states(fn)
34 |
35 | def _init_states(self, batch_size: int):
36 | for name in self._state_names:
37 | if self._state_defaults[name] is None:
38 | self._buffers[name] = None
39 | else:
40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
41 | self._buffers[name] = self._buffers[name].unsqueeze(0)
42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
43 | self._buffers[name] = self._buffers[name].contiguous()
44 |
45 | def _reset_states(self):
46 | for name in self._state_names:
47 | if self._state_defaults[name] is None:
48 | self._buffers[name] = None
49 | else:
50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
51 |
52 | def enable_statefulness(self, batch_size: int):
53 | for m in self.children():
54 | if isinstance(m, Module):
55 | m.enable_statefulness(batch_size)
56 | self._init_states(batch_size)
57 | self._is_stateful = True
58 |
59 | def disable_statefulness(self):
60 | for m in self.children():
61 | if isinstance(m, Module):
62 | m.disable_statefulness()
63 | self._reset_states()
64 | self._is_stateful = False
65 |
66 | @contextmanager
67 | def statefulness(self, batch_size: int):
68 | self.enable_statefulness(batch_size)
69 | try:
70 | yield
71 | finally:
72 | self.disable_statefulness()
73 |
74 |
75 | class ModuleList(nn.ModuleList, Module):
76 | pass
77 |
78 |
79 | class ModuleDict(nn.ModuleDict, Module):
80 | pass
81 |
--------------------------------------------------------------------------------
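
A minimal sketch of the statefulness machinery: inside the `statefulness` context a registered state buffer is expanded to the batch size and can be mutated step by step, and it is reset to its default on exit (the toy module below is illustrative; the decoder uses the same pattern for `running_seq`):

```
import torch
from models.containers import Module

class Counter(Module):
    def __init__(self):
        super(Counter, self).__init__()
        self.register_state('running_seq', torch.zeros((1,)).long())

    def forward(self):
        self.running_seq.add_(1)      # mutate the per-batch state in place
        return self.running_seq

m = Counter()
with m.statefulness(batch_size=3):
    print(m())                        # tensor([[1], [1], [1]]) -- default expanded to the batch size
print(m.running_seq)                  # tensor([0]) -- reset to the registered default on exit
```
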
/models/captioning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 | import utils
4 | from models.containers import Module
5 | from models.beam_search import *
6 |
7 |
8 | class CaptioningModel(Module):
9 | def __init__(self):
10 | super(CaptioningModel, self).__init__()
11 |
12 | def init_weights(self):
13 | raise NotImplementedError
14 |
15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
16 | raise NotImplementedError
17 |
18 | def forward(self, images, seq, *args):
19 | device = images.device
20 | b_s = images.size(0)
21 | seq_len = seq.size(1)
22 | state = self.init_state(b_s, device)
23 | out = None
24 |
25 | outputs = []
26 | for t in range(seq_len):
27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
28 | outputs.append(out)
29 |
30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1)
31 | return outputs
32 |
33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
34 | b_s = utils.get_batch_size(visual)
35 | device = utils.get_device(visual)
36 | outputs = []
37 | log_probs = []
38 |
39 | mask = torch.ones((b_s,), device=device)
40 | with self.statefulness(b_s):
41 | out = None
42 | for t in range(max_len):
43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs)
44 | out = torch.max(log_probs_t, -1)[1]
45 | mask = mask * (out.squeeze(-1) != eos_idx).float()
46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1))
47 | outputs.append(out)
48 |
49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
50 |
51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]:
52 | b_s = utils.get_batch_size(visual)
53 | outputs = []
54 | log_probs = []
55 |
56 | with self.statefulness(b_s):
57 | out = None
58 | for t in range(max_len):
59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs)
60 | distr = distributions.Categorical(logits=out[:, 0])
61 | out = distr.sample().unsqueeze(1)
62 | outputs.append(out)
63 | log_probs.append(distr.log_prob(out).unsqueeze(1))
64 |
65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1)
66 |
67 | def beam_search(self, visual: utils.TensorOrSequence, visual_gl: utils.TensorOrSequence,detections_mask: utils.TensorOrSequence,
68 | max_len: int, eos_idx: int, beam_size: int, out_size=1,return_probs=False, **kwargs):
69 | bs = BeamSearch(self, max_len, eos_idx, beam_size)
70 | return bs.apply(visual, visual_gl,detections_mask, out_size, return_probs, **kwargs)
71 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: m2
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - blas=1.0=mkl
8 | - ca-certificates=2022.07.19=h06a4308_0
9 | - certifi=2021.5.30=py36h06a4308_0
10 | - intel-openmp=2022.0.1=h06a4308_3633
11 | - joblib=1.0.1=pyhd3eb1b0_0
12 | - ld_impl_linux-64=2.38=h1181459_1
13 | - libffi=3.3=he6710b0_2
14 | - libgcc-ng=11.2.0=h1234567_1
15 | - libgfortran-ng=7.5.0=ha8ba4b0_17
16 | - libgfortran4=7.5.0=ha8ba4b0_17
17 | - libgomp=11.2.0=h1234567_1
18 | - libstdcxx-ng=11.2.0=h1234567_1
19 | - mkl=2020.2=256
20 | - mkl-service=2.3.0=py36he8ac12f_0
21 | - mkl_fft=1.3.0=py36h54f3939_0
22 | - mkl_random=1.1.1=py36h0573a6f_0
23 | - ncurses=6.3=h7f8727e_2
24 | - numpy-base=1.19.2=py36hfa32c7d_0
25 | - openssl=1.1.1q=h7f8727e_0
26 | - pip=21.2.2=py36h06a4308_0
27 | - python=3.6.13=h12debd9_1
28 | - readline=8.1.2=h7f8727e_1
29 | - scikit-learn=0.24.2=py36ha9443f7_0
30 | - setuptools=58.0.4=py36h06a4308_0
31 | - six=1.16.0=pyhd3eb1b0_1
32 | - sqlite=3.38.3=hc218d9a_0
33 | - threadpoolctl=2.2.0=pyh0d69192_0
34 | - tk=8.6.12=h1ccaba5_0
35 | - wheel=0.37.1=pyhd3eb1b0_0
36 | - xz=5.2.5=h7f8727e_1
37 | - zlib=1.2.12=h7f8727e_2
38 | - pip:
39 | - absl-py==1.1.0
40 | - blis==0.2.4
41 | - cached-property==1.5.2
42 | - cachetools==4.2.4
43 | - charset-normalizer==2.0.12
44 | - cycler==0.11.0
45 | - cymem==2.0.6
46 | - cython==0.29.30
47 | - dataclasses==0.8
48 | - decorator==4.4.2
49 | - en-core-web-md==2.1.0
50 | - en-core-web-sm==2.1.0
51 | - google-auth==2.7.0
52 | - google-auth-oauthlib==0.4.6
53 | - grpcio==1.46.3
54 | - h5py==3.1.0
55 | - idna==3.3
56 | - imagecodecs==2020.5.30
57 | - imageio==2.15.0
58 | - importlib-metadata==4.8.3
59 | - importlib-resources==5.4.0
60 | - jsonschema==2.6.0
61 | - kiwisolver==1.3.1
62 | - markdown==3.3.7
63 | - matplotlib==3.3.4
64 | - murmurhash==1.0.7
65 | - networkx==2.5.1
66 | - numpy==1.19.5
67 | - oauthlib==3.2.0
68 | - opencv-contrib-python==4.6.0.66
69 | - opencv-python==4.5.1.48
70 | - param==1.13.0
71 | - pillow==8.4.0
72 | - plac==0.9.6
73 | - preshed==2.0.1
74 | - protobuf==3.19.4
75 | - pyasn1==0.4.8
76 | - pyasn1-modules==0.2.8
77 | - pycocotools==2.0.1
78 | - pyct==0.4.8
79 | - pyheatmap==0.1.12
80 | - pyparsing==3.0.9
81 | - python-dateutil==2.8.2
82 | - pywavelets==1.1.1
83 | - requests==2.27.1
84 | - requests-oauthlib==1.3.1
85 | - rsa==4.8
86 | - scikit-image==0.17.2
87 | - scipy==1.5.4
88 | - spacy==2.1.0
89 | - srsly==1.0.5
90 | - tensorboard==2.9.0
91 | - tensorboard-data-server==0.6.1
92 | - tensorboard-plugin-wit==1.8.1
93 | - thinc==7.0.8
94 | - thop==0.1.1-2209072238
95 | - tifffile==2020.9.3
96 | - torch==1.10.0+cu113
97 | - torchvision==0.11.0+cu113
98 | - tqdm==4.64.0
99 | - typing-extensions==4.1.1
100 | - urllib3==1.26.9
101 | - wasabi==0.9.1
102 | - werkzeug==2.0.3
103 | - zipp==3.6.0
104 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib, sys
2 |
3 |
4 | class DummyFile(object):
5 | def write(self, x): pass
6 |
7 |
8 | @contextlib.contextmanager
9 | def nostdout():
10 | save_stdout = sys.stdout
11 | sys.stdout = DummyFile()
12 | yield
13 | sys.stdout = save_stdout
14 |
15 |
16 | def reporthook(t):
17 | """https://github.com/tqdm/tqdm"""
18 | last_b = [0]
19 |
20 | def inner(b=1, bsize=1, tsize=None):
21 | """
22 |         b: int, optional
23 | Number of blocks just transferred [default: 1].
24 | bsize: int, optional
25 | Size of each block (in tqdm units) [default: 1].
26 | tsize: int, optional
27 | Total size (in tqdm units). If [default: None] remains unchanged.
28 | """
29 | if tsize is not None:
30 | t.total = tsize
31 | t.update((b - last_b[0]) * bsize)
32 | last_b[0] = b
33 |
34 | return inner
35 |
36 |
37 | def get_tokenizer(tokenizer):
38 | if callable(tokenizer):
39 | return tokenizer
40 | if tokenizer == "spacy":
41 | try:
42 | import spacy
43 | # spacy_en = spacy.load('en')
44 | spacy_en = spacy.load('en_core_web_sm')
45 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)]
46 | except ImportError:
47 | print("Please install SpaCy and the SpaCy English tokenizer. "
48 | "See the docs at https://spacy.io for more information.")
49 | raise
50 | except AttributeError:
51 | print("Please install SpaCy and the SpaCy English tokenizer. "
52 | "See the docs at https://spacy.io for more information.")
53 | raise
54 | elif tokenizer == "moses":
55 | try:
56 | from nltk.tokenize.moses import MosesTokenizer
57 | moses_tokenizer = MosesTokenizer()
58 | return moses_tokenizer.tokenize
59 | except ImportError:
60 | print("Please install NLTK. "
61 | "See the docs at http://nltk.org for more information.")
62 | raise
63 | except LookupError:
64 | print("Please install the necessary NLTK corpora. "
65 | "See the docs at http://nltk.org for more information.")
66 | raise
67 | elif tokenizer == 'revtok':
68 | try:
69 | import revtok
70 | return revtok.tokenize
71 | except ImportError:
72 | print("Please install revtok.")
73 | raise
74 | elif tokenizer == 'subword':
75 | try:
76 | import revtok
77 | return lambda x: revtok.tokenize(x, decap=True)
78 | except ImportError:
79 | print("Please install revtok.")
80 | raise
81 | raise ValueError("Requested tokenizer {}, valid choices are a "
82 | "callable that takes a single string as input, "
83 | "\"revtok\" for the revtok reversible tokenizer, "
84 | "\"subword\" for the revtok caps-aware tokenizer, "
85 | "\"spacy\" for the SpaCy English tokenizer, or "
86 | "\"moses\" for the NLTK port of the Moses tokenization "
87 | "script.".format(tokenizer))
88 |
--------------------------------------------------------------------------------
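
A quick usage sketch of `get_tokenizer` (the `spacy` branch expects the `en_core_web_sm` model, which `environment.yml` pins alongside spacy 2.1):

```
from data.utils import get_tokenizer

tokenize = get_tokenizer('spacy')     # prints a hint and re-raises if SpaCy or the model is missing
print(tokenize('two planes are parked at the airport'))
# ['two', 'planes', 'are', 'parked', 'at', 'the', 'airport']
```
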
/feature_pro/split_group.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import random
6 | random.seed(1234)
7 | torch.manual_seed(1234)
8 | np.random.seed(1234)
9 |
10 | def readInfo(filePath):
11 | name = os.listdir(filePath)
12 | return name
13 |
14 | def data_normal(data):
15 | d_min=data.min()
16 | if d_min<0:
17 | data+=torch.abs(d_min)
18 | d_min=data.min()
19 |
20 | d_max=data.max()
21 | dst=d_max-d_min
22 | nor_data=(data-d_min).true_divide(dst)
23 | return nor_data
24 |
25 |
26 | features_path='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224'
27 | out_path='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6'
28 |
29 | num=6
30 |
31 |
32 |
33 | fileList = readInfo(features_path)
34 |
35 | for i, img in enumerate(fileList):
36 | res_feature=np.load(features_path + '/%s' % img)
37 | r_feature = res_feature.reshape(res_feature.shape[0], res_feature.shape[1] * res_feature.shape[2])
38 | r_feature=torch.Tensor(r_feature)
39 |
40 | res_feature=torch.Tensor(res_feature)
41 | avg_feature = F.adaptive_avg_pool2d(res_feature, [1, 1]).squeeze()
42 | avg_feature=avg_feature.unsqueeze(0)
43 | p=torch.matmul(avg_feature, r_feature)
44 |
45 | pp=data_normal(p)
46 |
47 | # divide group 8
48 | group_mask = torch.ones((8, 49, 49)).cuda() # res
49 |
50 | index_1 = torch.nonzero(pp <= 1 / 8, as_tuple=False)[:, 1]
51 | index_2 = torch.nonzero((pp <= 2 / 8) & (pp > 1 / 8), as_tuple=False)[:, 1]
52 | index_3 = torch.nonzero((pp <= 3 / 8) & (pp > 2 / 8), as_tuple=False)[:, 1]
53 | index_4 = torch.nonzero((pp <= 4 / 8) & (pp > 3 / 8), as_tuple=False)[:, 1]
54 | index_5 = torch.nonzero((pp <= 5 / 8) & (pp > 4 / 8), as_tuple=False)[:, 1]
55 | index_6 = torch.nonzero((pp <= 6 / 8) & (pp > 5 / 8), as_tuple=False)[:, 1]
56 |     index_7 = torch.nonzero((pp <= 7 / 8) & (pp > 6 / 8), as_tuple=False)[:, 1]
57 | index_8 = torch.nonzero((pp <= 8 / 8) & (pp > 7 / 8), as_tuple=False)[:, 1]
58 |
59 |
60 |
61 | for k in range(8):
62 | list = random.sample(range(1, 9), num)
63 | for j in range(2):
64 | if list[j] == 1:
65 | length = index_1.__len__()
66 | for i in range(length):
67 | group_mask[k][:, index_1[i]] = 0
68 | elif list[j] == 2:
69 | length = index_2.__len__()
70 | for i in range(length):
71 | group_mask[k][:, index_2[i]] = 0
72 | elif list[j] == 3:
73 | length = index_3.__len__()
74 | for i in range(length):
75 | group_mask[k][:, index_3[i]] = 0
76 | elif list[j] == 4:
77 | length = index_4.__len__()
78 | for i in range(length):
79 | group_mask[k][:, index_4[i]] = 0
80 | elif list[j] == 5:
81 | length = index_5.__len__()
82 | for i in range(length):
83 | group_mask[k][:, index_5[i]] = 0
84 | elif list[j] == 6:
85 | length = index_6.__len__()
86 | for i in range(length):
87 | group_mask[k][:, index_6[i]] = 0
88 | elif list[j] == 7:
89 | length = index_7.__len__()
90 | for i in range(length):
91 | group_mask[k][:, index_7[i]] = 0
92 | elif list[j] == 8:
93 | length = index_8.__len__()
94 | for i in range(length):
95 | group_mask[k][:, index_8[i]] = 0
96 |
97 | np.save(os.path.join(out_path, img), group_mask.data.cpu().float().numpy()) # size: 8,49,49
98 |
99 |
--------------------------------------------------------------------------------
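
The eight near-identical `elif` branches above all perform the same operation: zero the key columns whose normalized response falls into the chosen intensity bins. A compact sketch of that masking step (here `pp` is a random stand-in for the normalized response, and sampling 2 bins directly is distributionally the same as sampling 6 and using only the first 2, as the script does):

```
import random
import torch

pp = torch.rand(1, 49)    # stand-in for the normalized response produced by data_normal
bins = [torch.nonzero((pp > k / 8) & (pp <= (k + 1) / 8), as_tuple=False)[:, 1] for k in range(8)]

group_mask = torch.ones((8, 49, 49))
for k in range(8):
    for b in random.sample(range(8), 2):   # suppress 2 of the 8 bins for this group
        group_mask[k][:, bins[b]] = 0      # zero those key columns for every query position
```
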
/test.py:
--------------------------------------------------------------------------------
1 | import random
2 | from data import ImageDetectionsField, TextField, RawField
3 | from data import DataLoader, Sydney, UCM, RSICD
4 | import evaluation
5 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, GlobalGroupingAttention_with_DC
6 | import torch
7 | from tqdm import tqdm
8 | import argparse, os, pickle
9 | import numpy as np
10 | import warnings
11 | warnings.filterwarnings("ignore")
12 |
13 |
14 |
15 | def seed_torch(seed=1):
16 | random.seed(seed)
17 | os.environ['PYTHONHASHSEED'] = str(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed(seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 | seed_torch()
24 |
25 |
26 | def predict_captions(model, dataloader, text_field):
27 | import itertools
28 | model.eval()
29 | gen = {}
30 | gts = {}
31 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar:
32 | for it, (images, caps_gt) in enumerate(dataloader):
33 | detections = images[0].to(device)
34 | detections_gl = images[1].to(device)
35 | detections_mask = images[2].to(device)
36 |
37 | with torch.no_grad():
38 | out, _ = model.beam_search(detections, detections_gl, detections_mask, 20,
39 |                                            text_field.vocab.stoi['<eos>'], 5, out_size=1)
40 | caps_gen = text_field.decode(out, join_words=False)
41 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
42 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
43 | gen['%d_%d' % (it, i)] = [gen_i, ]
44 | gts['%d_%d' % (it, i)] = gts_i
45 | pbar.update()
46 |
47 | gts = evaluation.PTBTokenizer.tokenize(gts)
48 | gen = evaluation.PTBTokenizer.tokenize(gen)
49 | scores, _ = evaluation.compute_scores(gts, gen)
50 | return scores
51 |
52 |
53 |
54 |
55 | if __name__ == '__main__':
56 | device = torch.device('cuda')
57 |
58 | parser = argparse.ArgumentParser(description='MG-Transformer')
59 | parser.add_argument('--exp_name', type=str, default='Sydney') # Sydney,UCM,RSICD
60 | parser.add_argument('--batch_size', type=int, default=50)
61 | parser.add_argument('--workers', type=int, default=0)
62 |
63 | # sydney
64 | parser.add_argument('--annotation_folder', type=str,
65 | default='/media/dmd/ours/mlw/rs/Sydney_Captions')
66 | parser.add_argument('--res_features_path', type=str,
67 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224')
68 |
69 | parser.add_argument('--clip_features_path', type=str,
70 | default='/media/dmd/ours/mlw/rs/clip_feature/sydney_224')
71 |
72 | parser.add_argument('--mask_features_path', type=str,
73 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6')
74 |
75 |
76 | args = parser.parse_args()
77 |
78 | print('Transformer Evaluation')
79 |
80 | # Pipeline for image regions
81 | image_field = ImageDetectionsField(clip_features_path=args.clip_features_path,
82 | res_features_path=args.res_features_path,
83 | mask_features_path=args.mask_features_path)
84 | # Pipeline for text
85 |     text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
86 | remove_punctuation=True, nopoints=False)
87 |
88 |
89 | # Create the dataset Sydney,UCM,RSICD
90 | if args.exp_name == 'Sydney':
91 | dataset = Sydney(image_field, text_field, 'Sydney/images/', args.annotation_folder, args.annotation_folder)
92 | elif args.exp_name == 'UCM':
93 | dataset = UCM(image_field, text_field, 'UCM/images/', args.annotation_folder, args.annotation_folder)
94 | elif args.exp_name == 'RSICD':
95 | dataset = RSICD(image_field, text_field, 'RSICD/images/', args.annotation_folder, args.annotation_folder)
96 |
97 | _, _, test_dataset = dataset.splits
98 |
99 | text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb'))
100 |
101 |
102 |
103 | # Model and dataloaders
104 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=GlobalGroupingAttention_with_DC)
105 |     decoder = MeshedDecoder(len(text_field.vocab), 127, 3, text_field.vocab.stoi['<pad>'])
106 |     model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
107 |
108 |
109 | data = torch.load('./saved_models/Sydney_best.pth')
110 | # data = torch.load('./saved_models/Sydney_best_test.pth')
111 |
112 | model.load_state_dict(data['state_dict'])
113 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
114 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers)
115 |
116 | scores = predict_captions(model, dict_dataloader_test, text_field)
117 |     print(scores)
118 |
119 |
--------------------------------------------------------------------------------
/models/transformer/encoders.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 | from models.transformer.utils import PositionWiseFeedForward
3 | import torch
4 | from torch import nn
5 | from models.transformer.attention import MultiHeadAttention
6 | from models.transformer.attention import ScaledDotProductAttention
7 |
8 |
9 | from models.transformer.utils import *
10 |
11 |
12 |
13 | class EncoderLayer(nn.Module):
14 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
15 | attention_module=None, attention_module_kwargs=None, dilation=None):
16 | super(EncoderLayer, self).__init__()
17 | self.identity_map_reordering = identity_map_reordering
18 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
19 | attention_module=attention_module,
20 | attention_module_kwargs=attention_module_kwargs, dilation=dilation)
21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
22 |
23 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None,
24 | input_gl=None, memory=None, isencoder=None):
25 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights, group_mask=group_mask,
26 | input_gl=input_gl, memory=memory, isencoder=isencoder)
27 | ff = self.pwff(att)
28 | return ff
29 |
30 |
31 | class MultiLevelEncoder(nn.Module):
32 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
33 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None):
34 | super(MultiLevelEncoder, self).__init__()
35 | self.d_model = d_model
36 | self.dropout = dropout
37 |
38 | # aoa
39 | self.aoa_layer = nn.Sequential(nn.Linear(2 * self.d_model, 2 * self.d_model), nn.GLU())
40 | self.dropout_aoa = nn.Dropout(p=self.dropout)
41 |
42 |
43 | self.layers = nn.Sequential(EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
44 | identity_map_reordering=identity_map_reordering,
45 | attention_module=attention_module,
46 | attention_module_kwargs=attention_module_kwargs, dilation=2),
47 | EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
48 | identity_map_reordering=identity_map_reordering,
49 | attention_module=attention_module,
50 | attention_module_kwargs=attention_module_kwargs, dilation=4),
51 | EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
52 | identity_map_reordering=identity_map_reordering,
53 | attention_module=attention_module,
54 | attention_module_kwargs=attention_module_kwargs, dilation=8)
55 | )
56 |
57 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
58 | attention_module=ScaledDotProductAttention,
59 | attention_module_kwargs=attention_module_kwargs)
60 | self.padding_idx = padding_idx
61 | def forward(self, input, input_gl=None, isencoder=None, detections_mask=None, memory=None, attention_weights=None):
62 | # fine module
63 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1)
64 | input = self.mhatt(input, input, input, attention_mask)
65 | out = self.aoa_layer(self.dropout_aoa(torch.cat([input, input_gl], -1)))
66 | group_mask = detections_mask
67 | outs = []
68 | for l in self.layers:
69 | out = l(out, out, out, attention_mask, attention_weights, group_mask=group_mask, input_gl=input_gl,
70 | memory=memory,isencoder=isencoder)
71 | outs.append(out.unsqueeze(1))
72 |
73 | outs = torch.cat(outs, 1)
74 | return outs, attention_mask
75 |
76 |
77 | class MemoryAugmentedEncoder(MultiLevelEncoder):
78 | def __init__(self, N, padding_idx, d_in=2048, d_clip=768, **kwargs):
79 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs)
80 | self.fc = nn.Linear(d_in, self.d_model)
81 | self.fc_clip = nn.Linear(d_clip, self.d_model)
82 |
83 | self.dropout_1 = nn.Dropout(p=self.dropout)
84 | self.layer_norm_1 = nn.LayerNorm(self.d_model)
85 |
86 | self.dropout_2 = nn.Dropout(p=self.dropout)
87 | self.layer_norm_2 = nn.LayerNorm(self.d_model)
88 |
89 |
90 | def forward(self, input, input_gl=None, detections_mask=None, isencoder=None, attention_weights=None):
91 | input_gl = F.relu(self.fc_clip(input_gl))
92 | input_gl = self.dropout_1(input_gl)
93 | input_gl = self.layer_norm_1(input_gl)
94 |
95 | out = F.relu(self.fc(input))
96 | out = self.dropout_2(out)
97 | out = self.layer_norm_2(out)
98 | return super(MemoryAugmentedEncoder, self).forward(out, input_gl=input_gl, detections_mask=detections_mask,
99 | isencoder=isencoder, attention_weights=attention_weights)
100 |
--------------------------------------------------------------------------------
/models/transformer/decoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from models.transformer.attention import MultiHeadAttention
7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
8 | from models.containers import Module, ModuleList
9 |
10 |
11 | class MeshedDecoderLayer(Module):
12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
14 | super(MeshedDecoderLayer, self).__init__()
15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
16 | attention_module=self_att_module,
17 | attention_module_kwargs=self_att_module_kwargs)
18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
19 | attention_module=enc_att_module,
20 | attention_module_kwargs=enc_att_module_kwargs)
21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
22 |
23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model)
24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model)
25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model)
26 |
27 | self.init_weights()
28 |
29 | def init_weights(self):
30 | nn.init.xavier_uniform_(self.fc_alpha1.weight)
31 | nn.init.xavier_uniform_(self.fc_alpha2.weight)
32 | nn.init.xavier_uniform_(self.fc_alpha3.weight)
33 | nn.init.constant_(self.fc_alpha1.bias, 0)
34 | nn.init.constant_(self.fc_alpha2.bias, 0)
35 | nn.init.constant_(self.fc_alpha3.bias, 0)
36 |
37 |
38 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): # decoder(seq, enc_output, mask_enc)
39 | self_att = self.self_att(input, input, input, mask_self_att, input_gl=None,
40 | isencoder=False)
41 | self_att = self_att * mask_pad
42 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad # 10 15 512
43 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad # 10 15 512
44 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad # 10 15 512
45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) # 10 15 512
46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) # 10 15 512
47 | alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) # 10 15 512
48 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3)
49 | enc_att = enc_att * mask_pad # 10 15 512
50 |
51 | ff = self.pwff(enc_att)
52 | ff = ff * mask_pad
53 | return ff
54 |
55 |
56 | class MeshedDecoder(Module):
57 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
58 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
59 | super(MeshedDecoder, self).__init__()
60 | self.d_model = d_model
61 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
62 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len+1, d_model, 0),
63 | freeze=True)
64 | self.layers = ModuleList(
65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
68 | self.fc = nn.Linear(d_model, vocab_size, bias=False)
69 | self.max_len = max_len
70 | self.padding_idx = padding_idx
71 | self.N = N_dec
72 |
73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte())
74 | self.register_state('running_seq', torch.zeros((1,)).long())
75 |
76 | def forward(self, input, encoder_output, mask_encoder):
77 | b_s, seq_len = input.shape[:2] # 10 15
78 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float()
79 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),diagonal=1)
80 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0)
81 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
82 | mask_self_attention = mask_self_attention.gt(0)
83 |
84 | if self._is_stateful: # false
85 | device = self.running_mask_self_attention.device
86 | mask_self_attention = torch.tensor(mask_self_attention, dtype=torch.uint8).to(device)
87 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],-1)
88 | mask_self_attention = self.running_mask_self_attention
89 |
90 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device)
91 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
92 |
93 | if self._is_stateful:
94 | self.running_seq.add_(1)
95 | seq = self.running_seq
96 |
97 | out = self.word_emb(input) + self.pos_emb(seq)
98 |
99 | for i, l in enumerate(self.layers):
100 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
101 |
102 |
103 | out = self.fc(out)
104 | return F.log_softmax(out, dim=-1)
105 |
--------------------------------------------------------------------------------
/models/beam_search/beam_search.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils
3 |
4 |
5 | class BeamSearch(object):
6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
7 | self.model = model
8 | self.max_len = max_len
9 | self.eos_idx = eos_idx
10 | self.beam_size = beam_size
11 | self.b_s = None
12 | self.device = None
13 | self.seq_mask = None
14 | self.seq_logprob = None
15 | self.outputs = None
16 | self.log_probs = None
17 | self.selected_words = None
18 | self.all_log_probs = None
19 |
20 | def _expand_state(self, selected_beam, cur_beam_size):
21 | def fn(s):
22 | shape = [int(sh) for sh in s.shape]
23 | beam = selected_beam
24 | for _ in shape[1:]:
25 | beam = beam.unsqueeze(-1)
26 | # import pdb
27 | # pdb.set_trace()
28 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
29 | beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
30 | s = s.view(*([-1, ] + shape[1:]))
31 | return s
32 |
33 | return fn
34 |
35 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
36 | if isinstance(visual, torch.Tensor):
37 | visual_shape = visual.shape
38 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
39 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
40 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
41 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
42 | visual_exp = visual.view(visual_exp_shape)
43 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
44 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
45 | else:
46 | new_visual = []
47 | for im in visual:
48 | visual_shape = im.shape
49 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
50 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
51 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
52 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
53 | visual_exp = im.view(visual_exp_shape)
54 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
55 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
56 | new_visual.append(new_im)
57 | visual = tuple(new_visual)
58 | return visual
59 |
60 | def apply(self, visual: utils.TensorOrSequence, visual_gl: utils.TensorOrSequence, detections_mask: utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs):
61 | self.b_s = utils.get_batch_size(visual)
62 | self.device = utils.get_device(visual)
63 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
64 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
65 | self.log_probs = []
66 | self.selected_words = None
67 | if return_probs:
68 | self.all_log_probs = []
69 |
70 | outputs = []
71 | with self.model.statefulness(self.b_s):
72 | for t in range(self.max_len):
73 | visual, outputs = self.iter(t, visual, visual_gl, detections_mask, outputs, return_probs, **kwargs)
74 |
75 |
76 | # Sort result
77 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
78 | outputs = torch.cat(outputs, -1)
79 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
80 | log_probs = torch.cat(self.log_probs, -1)
81 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
82 | if return_probs:
83 | all_log_probs = torch.cat(self.all_log_probs, 2)
84 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
85 | self.max_len,
86 | all_log_probs.shape[-1]))
87 |
88 | outputs = outputs.contiguous()[:, :out_size]
89 | log_probs = log_probs.contiguous()[:, :out_size]
90 | if out_size == 1:
91 | outputs = outputs.squeeze(1)
92 | log_probs = log_probs.squeeze(1)
93 |
94 | if return_probs:
95 | return outputs, log_probs, all_log_probs
96 | else:
97 | return outputs, log_probs
98 |
99 | def select(self, t, candidate_logprob, **kwargs):
100 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
101 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
102 | return selected_idx, selected_logprob
103 |
104 | def iter(self, t: int, visual: utils.TensorOrSequence, visual_gl: utils.TensorOrSequence, detections_mask: utils.TensorOrSequence, outputs, return_probs, **kwargs):
105 | cur_beam_size = 1 if t == 0 else self.beam_size
106 | # word_logprob = self.model.step(t, self.selected_words, visual, visual_gl, None, mode='feedback', **kwargs)
107 |
108 | word_logprob = self.model.step(t, self.selected_words, visual, visual_gl, detections_mask, None, mode='feedback', **kwargs)
109 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1)
110 | candidate_logprob = self.seq_logprob + word_logprob
111 |
112 | # Mask sequence if it reaches EOS
113 | if t > 0:
114 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1)
115 | self.seq_mask = self.seq_mask * mask
116 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
117 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
118 | old_seq_logprob[:, :, 1:] = -999
119 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
120 |
121 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
122 | selected_beam = selected_idx // candidate_logprob.shape[-1]
123 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
124 |
125 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
126 | visual = self._expand_visual(visual, cur_beam_size, selected_beam)
127 |
128 | self.seq_logprob = selected_logprob.unsqueeze(-1)
129 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
130 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
131 | outputs.append(selected_words.unsqueeze(-1))
132 |
133 | if return_probs:
134 | if t == 0:
135 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
136 | else:
137 | self.all_log_probs.append(word_logprob.unsqueeze(2))
138 |
139 | this_word_logprob = torch.gather(word_logprob, 1,
140 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
141 | word_logprob.shape[-1]))
142 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
143 | self.log_probs = list(
144 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
145 | self.log_probs.append(this_word_logprob)
146 | self.selected_words = selected_words.view(-1, 1)
147 |
148 | return visual, outputs
149 |
--------------------------------------------------------------------------------
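A small illustration (toy numbers, not the repo's API) of how select() and iter() recover the beam index and the word index from a ranking over the flattened beam_size x vocab_size scores:

    import torch

    beam_size, vocab_size = 2, 5
    candidate_logprob = torch.tensor([[[-1.0, -2.0, -0.5, -3.0, -4.0],
                                       [-0.2, -5.0, -1.5, -2.5, -0.7]]])    # (b_s, beam, vocab)
    flat = candidate_logprob.view(1, -1)
    selected_logprob, selected_idx = torch.sort(flat, -1, descending=True)
    selected_idx = selected_idx[:, :beam_size]                   # top-k flattened indices
    selected_beam = selected_idx // vocab_size                   # beam each survivor came from
    selected_words = selected_idx - selected_beam * vocab_size   # word that extends that beam
    print(selected_beam, selected_words)    # tensor([[1, 0]]) tensor([[0, 2]])

The same floor-division/remainder step appears in iter(), with vocab_size written as candidate_logprob.shape[-1].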
/models/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from models.containers import Module
5 | from torch.nn import functional as F
6 | from models.transformer.utils import *
7 |
8 |
9 | class SELayer(nn.Module):
10 | def __init__(self, channel, reduction=16):
11 | super(SELayer, self).__init__()
12 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
13 | self.fc = nn.Sequential(
14 | nn.Linear(channel, channel // reduction, bias=False),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(channel // reduction, channel, bias=False),
17 | nn.Sigmoid()
18 | )
19 |
20 | def forward(self, x):
21 | b, c, _, _ = x.size()
22 | y = self.avg_pool(x).view(b, c)
23 | y = self.fc(y).view(b, c, 1, 1)
24 | return x * y.expand_as(x)
25 |
26 |
27 | class ScaledDotProductAttention(nn.Module):
28 | """
29 | Scaled dot-product attention
30 | """
31 |
32 | def __init__(self, d_model, d_k, d_v, h, dilation=None):
33 | """
34 | :param d_model: Output dimensionality of the model
35 | :param d_k: Dimensionality of queries and keys
36 | :param d_v: Dimensionality of values
37 | :param h: Number of heads
38 | """
39 | super(ScaledDotProductAttention, self).__init__()
40 | self.fc_q = nn.Linear(d_model, h * d_k)
41 | self.fc_k = nn.Linear(d_model, h * d_k)
42 | self.fc_v = nn.Linear(d_model, h * d_v)
43 | self.fc_o = nn.Linear(h * d_v, d_model)
44 |
45 | self.d_model = d_model
46 | self.d_k = d_k
47 | self.d_v = d_v
48 | self.h = h
49 |
50 | self.init_weights()
51 |
52 | def init_weights(self):
53 | nn.init.xavier_uniform_(self.fc_q.weight)
54 | nn.init.xavier_uniform_(self.fc_k.weight)
55 | nn.init.xavier_uniform_(self.fc_v.weight)
56 | nn.init.xavier_uniform_(self.fc_o.weight)
57 | nn.init.constant_(self.fc_q.bias, 0)
58 | nn.init.constant_(self.fc_k.bias, 0)
59 | nn.init.constant_(self.fc_v.bias, 0)
60 | nn.init.constant_(self.fc_o.bias, 0)
61 |
62 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None,
63 | input_gl=None, memory=None,
64 | isencoder=None, dilation=None):
65 | """
66 | Computes
67 | :param queries: Queries (b_s, nq, d_model)
68 | :param keys: Keys (b_s, nk, d_model)
69 | :param values: Values (b_s, nk, d_model)
70 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
71 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
72 | :return:
73 | """
74 | # att[0][0].argmax()
75 | b_s, nq = queries.shape[:2]
76 | nk = keys.shape[1]
77 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 10 8 50 64
78 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 10 8 64 50
79 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 10 8 50 64
80 |
81 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 10 8 50 50
82 |
83 | if attention_weights is not None:
84 | att = att * attention_weights
85 | if attention_mask is not None:
86 | att = att.masked_fill(attention_mask.bool(), -np.inf)
87 |
88 | att = torch.softmax(att, -1)
89 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
90 | out = self.fc_o(out) # (b_s, nq, d_model) 10 50 512
91 | return out
92 |
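# Shape walk-through of the attention block above, with illustrative sizes (a 7x7 ResNet grid
# gives nq = nk = 49 regions; h = 8 heads with d_k = d_v = 64, so h * d_v = 512 = d_model):
#
#     q = torch.randn(10, 8, 49, 64); k = torch.randn(10, 8, 64, 49); v = torch.randn(10, 8, 49, 64)
#     att = torch.softmax(torch.matmul(q, k) / 64 ** 0.5, -1)                # (10, 8, 49, 49)
#     out = torch.matmul(att, v).permute(0, 2, 1, 3).reshape(10, 49, 8 * 64)
#     print(out.shape)                                                       # torch.Size([10, 49, 512])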
93 |
94 |
95 | # ScaledDotProductAttention_dilated
96 | class GlobalGroupingAttention_with_DC(nn.Module):
97 | """
98 | Global grouping attention with a dilated convolution (DC) branch
99 | """
100 |
101 | def __init__(self, d_model, d_k, d_v, h, dilation=None):
102 | """
103 | :param d_model: Output dimensionality of the model
104 | :param d_k: Dimensionality of queries and keys
105 | :param d_v: Dimensionality of values
106 | :param h: Number of heads
107 | """
108 | super(GlobalGroupingAttention_with_DC, self).__init__()
109 | self.fc_c = nn.Linear(49, 1)
110 |
111 | self.fc_q = nn.Linear(d_model, h * d_k)
112 | self.fc_k = nn.Linear(d_model, h * d_k)
113 | self.fc_v = nn.Linear(d_model, h * d_v)
114 | self.fc_o = nn.Linear(h * d_v, d_model)
115 |
116 | self.d_model = d_model
117 | self.d_k = d_k
118 | self.d_v = d_v
119 | self.h = h
120 |
121 | self.net = nn.Sequential(
122 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=1, padding=dilation,
123 | dilation=dilation),
124 | nn.BatchNorm2d(512, affine=False), nn.ReLU(inplace=True))
125 |
126 | # self.SE = SELayer(channel=512)
127 | self.init_weights()
128 |
129 | def init_weights(self):
130 | nn.init.xavier_uniform_(self.fc_c.weight)
131 | nn.init.constant_(self.fc_c.bias, 0)
132 |
133 | nn.init.xavier_uniform_(self.fc_q.weight)
134 | nn.init.xavier_uniform_(self.fc_k.weight)
135 | nn.init.xavier_uniform_(self.fc_v.weight)
136 | nn.init.xavier_uniform_(self.fc_o.weight)
137 | nn.init.constant_(self.fc_q.bias, 0)
138 | nn.init.constant_(self.fc_k.bias, 0)
139 | nn.init.constant_(self.fc_v.bias, 0)
140 | nn.init.constant_(self.fc_o.bias, 0)
141 |
142 |
143 |
144 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None,
145 | input_gl=None, memory=None, isencoder=None, dilation=None):
146 | """
147 | Computes
148 | :param queries: Queries (b_s, nq, d_model)
149 | :param keys: Keys (b_s, nk, d_model)
150 | :param values: Values (b_s, nk, d_model)
151 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
152 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
153 | :return:
154 | """
155 | #########################################################################
156 | # dilation
157 | x = queries.permute(0, 2, 1)
158 | x = x.reshape(x.shape[0], x.shape[1], 7, 7) # ResNet backbone: 7x7 feature grid
159 | # x = x.reshape(x.shape[0], x.shape[1], 14, 14) # VGG backbone: 14x14 feature grid
160 | x = self.net(x)
161 | x_s = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(0, 2, 1)
162 |
163 | # channel
164 | # x_c=x_s.permute(0, 2, 1)
165 | # c=torch.softmax(self.fc_c(x_c),1).permute(0, 2, 1)
166 | # c = c.repeat(1, 49, 1)
167 |
168 | # spatial
169 | # queries = queries*c
170 | # keys = keys*c
171 | # values = values*c
172 |
173 | queries = x_s + queries
174 | keys = x_s + keys
175 | values = x_s + values
176 |
177 | b_s, nq = queries.shape[:2]
178 | nk = keys.shape[1]
179 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
180 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
181 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
182 |
183 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
184 | if attention_weights is not None:
185 | att = att * attention_weights
186 | if attention_mask is not None:
187 | att = att.masked_fill(attention_mask.bool(), -np.inf)
188 | if group_mask is not None:
189 | # (1)
190 | # att = att.masked_fill(group_mask.bool(), -np.inf)
191 | # (2)
192 | # att = att.masked_fill(group_mask.bool(), torch.tensor(-1e9))
193 | # (3)
194 | group_mask_mat = group_mask.masked_fill(group_mask.bool(), torch.tensor(-1e9))
195 | att = att + group_mask_mat
196 |
197 | att = torch.softmax(att, -1)
198 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
199 | out = self.fc_o(out) # (b_s, nq, d_model)
200 | return out
201 |
202 |
203 |
204 | class MultiHeadAttention(Module):
205 | """
206 | Multi-head attention layer with Dropout and Layer Normalization.
207 | """
208 |
209 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
210 | attention_module=None, attention_module_kwargs=None, isenc=None, dilation=None):
211 | super(MultiHeadAttention, self).__init__()
212 | self.identity_map_reordering = identity_map_reordering
213 | if attention_module is not None:
214 | if attention_module_kwargs is not None:
215 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs,
216 | dilation=dilation)
217 | else:
218 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, dilation=dilation)
219 | else:
220 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
221 |
222 |
223 | self.dropout = nn.Dropout(p=dropout)
224 | self.layer_norm = nn.LayerNorm(d_model)
225 |
226 | self.can_be_stateful = can_be_stateful
227 | if self.can_be_stateful:
228 | self.register_state('running_keys', torch.zeros((0, d_model)))
229 | self.register_state('running_values', torch.zeros((0, d_model)))
230 |
231 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None,
232 | input_gl=None, memory=None, isencoder=None):
233 | if self.can_be_stateful and self._is_stateful:
234 | self.running_keys = torch.cat([self.running_keys, keys], 1)
235 | keys = self.running_keys
236 |
237 | self.running_values = torch.cat([self.running_values, values], 1)
238 | values = self.running_values
239 |
240 | if self.identity_map_reordering:
241 | q_norm = self.layer_norm(queries)
242 | k_norm = self.layer_norm(keys)
243 | v_norm = self.layer_norm(values)
244 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights, input_gl=input_gl)
245 | out = queries + self.dropout(torch.relu(out))
246 | else:
247 | if isencoder:
248 | out = self.attention(queries, keys, values, attention_mask, attention_weights, group_mask=group_mask,
249 | input_gl=input_gl, memory=memory, isencoder=isencoder)
250 | else:
251 | out = self.attention(queries, keys, values, attention_mask, attention_weights,
252 | input_gl=None, memory=memory, isencoder=isencoder)
253 | out = self.dropout(out)
254 | out = self.layer_norm(queries + out)
255 | return out
256 |
--------------------------------------------------------------------------------
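A note on option (3) above: adding a large negative bias where group_mask is set, rather than masked_fill with -inf, keeps every softmax row finite even when a query's whole group is masked out. A toy single-row check (the real group_mask is a float matrix loaded per image):

    import torch

    scores = torch.tensor([[0.3, 1.2, -0.4, 0.8]])
    group_mask = torch.tensor([[1., 1., 1., 1.]])       # every key masked for this query

    nan_row = torch.softmax(scores.masked_fill(group_mask.bool(), float('-inf')), -1)
    bias = group_mask.masked_fill(group_mask.bool(), torch.tensor(-1e9))
    safe_row = torch.softmax(scores + bias, -1)
    print(nan_row)    # tensor([[nan, nan, nan, nan]])
    print(safe_row)   # finite (effectively uniform) weights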
/misc/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import torch.optim as optim
10 | import os
11 |
12 | import six
13 | from six.moves import cPickle
14 |
15 | bad_endings = ['with', 'in', 'on', 'of', 'a', 'at', 'to', 'for', 'an', 'this', 'his', 'her', 'that']
16 | bad_endings += ['the']
17 |
18 |
19 | def pickle_load(f):
20 | """ Load a pickle.
21 | Parameters
22 | ----------
23 | f: file-like object
24 | """
25 | if six.PY3:
26 | return cPickle.load(f, encoding='latin-1')
27 | else:
28 | return cPickle.load(f)
29 |
30 |
31 | def pickle_dump(obj, f):
32 | """ Dump a pickle.
33 | Parameters
34 | ----------
35 | obj: pickled object
36 | f: file-like object
37 | """
38 | if six.PY3:
39 | return cPickle.dump(obj, f, protocol=2)
40 | else:
41 | return cPickle.dump(obj, f)
42 |
43 |
44 | def if_use_feat(caption_model):
45 | # Decide if load attention feature according to caption model
46 | if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
47 | use_att, use_fc = False, True
48 | elif caption_model == 'language_model':
49 | use_att, use_fc = False, False
50 | elif caption_model in ['topdown', 'aoa']:
51 | use_fc, use_att = True, True
52 | else:
53 | use_att, use_fc = True, False
54 | return use_fc, use_att
55 |
56 |
57 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
58 | def decode_sequence(ix_to_word, seq):
59 | N, D = seq.size()
60 | out = []
61 | for i in range(N):
62 | txt = ''
63 | for j in range(D):
64 | ix = seq[i, j]
65 | if ix > 0:
66 | if j >= 1:
67 | txt = txt + ' '
68 | txt = txt + ix_to_word[str(ix.item())]
69 | else:
70 | break
71 | if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
72 | flag = 0
73 | words = txt.split(' ')
74 | for j in range(len(words)):
75 | if words[-j - 1] not in bad_endings:
76 | flag = -j
77 | break
78 | txt = ' '.join(words[0:len(words) + flag])
79 | out.append(txt.replace('@@ ', ''))
80 | return out
81 |
82 |
83 | def to_contiguous(tensor):
84 | if tensor.is_contiguous():
85 | return tensor
86 | else:
87 | return tensor.contiguous()
88 |
89 |
90 | class RewardCriterion(nn.Module):
91 | def __init__(self):
92 | super(RewardCriterion, self).__init__()
93 |
94 | def forward(self, input, seq, reward):
95 | input = to_contiguous(input).view(-1)
96 | reward = to_contiguous(reward).view(-1)
97 | mask = (seq > 0).float()
98 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1)
99 | output = - input * reward * mask
100 | output = torch.sum(output) / torch.sum(mask)
101 |
102 | return output
103 |
104 |
105 | class LanguageModelCriterion(nn.Module):
106 | def __init__(self):
107 | super(LanguageModelCriterion, self).__init__()
108 |
109 | def forward(self, input, target, mask):
110 | # truncate to the same size
111 | target = target[:, :input.size(1)]
112 | mask = mask[:, :input.size(1)]
113 |
114 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
115 | output = torch.sum(output) / torch.sum(mask)
116 |
117 | return output
118 |
119 |
120 | class LabelSmoothing(nn.Module):
121 | "Implement label smoothing."
122 |
123 | def __init__(self, size=0, padding_idx=0, smoothing=0.0):
124 | super(LabelSmoothing, self).__init__()
125 | """ nn.KLDivLoss计算KL散度的损失函数,要将模型输出的原始预测值要先进行softmax,然后进行log运算
126 | (torch.nn.functional.log_softmax可以直接实现),得到结果作为input输入到KLDivLoss中。target是二维的,形状与input一样
127 |
128 | torch.nn.CrossEntropyLoss()交叉熵计算,输入的预测值不需要进行softmax,不需要进行log运算!直接用原始的预测输出,标签用整数序列。
129 | """
130 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
131 | # self.padding_idx = padding_idx
132 | self.confidence = 1.0 - smoothing # 0.8
133 | self.smoothing = smoothing
134 | # self.size = size
135 | self.true_dist = None
136 |
137 | def forward(self, input, target, mask):
138 | # truncate to the same size
139 | target = target[:, :input.size(1)]
140 | mask = mask[:, :input.size(1)]
141 |
142 | input = to_contiguous(input).view(-1, input.size(-1))
143 | target = to_contiguous(target).view(-1)
144 | mask = to_contiguous(mask).view(-1)
145 |
146 | # assert x.size(1) == self.size
147 | self.size = input.size(1)
148 | # true_dist = x.data.clone()
149 | true_dist = input.data.clone()
150 | # true_dist.fill_(self.smoothing / (self.size - 2))
151 | true_dist.fill_(self.smoothing / (self.size - 1))
152 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
153 | # true_dist[:, self.padding_idx] = 0
154 | # mask = torch.nonzero(target.data == self.padding_idx)
155 | # self.true_dist = true_dist
156 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
157 |
158 |
159 | def set_lr(optimizer, lr):
160 | for group in optimizer.param_groups:
161 | group['lr'] = lr
162 |
163 |
164 | def get_lr(optimizer):
165 | for group in optimizer.param_groups:
166 | return group['lr']
167 |
168 |
169 | def clip_gradient(optimizer, grad_clip): # gradient clipping: clamp each gradient to [-grad_clip, grad_clip] to avoid exploding gradients
170 | for group in optimizer.param_groups:
171 | for param in group['params']:
172 | param.grad.data.clamp_(-grad_clip, grad_clip)
173 |
174 |
175 | def build_optimizer(params, opt):
176 | if opt.optim == 'rmsprop':
177 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon,
178 | weight_decay=opt.weight_decay)
179 | elif opt.optim == 'adagrad':
180 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
181 | elif opt.optim == 'sgd':
182 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
183 | elif opt.optim == 'sgdm':
184 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
185 | elif opt.optim == 'sgdmom':
186 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
187 | elif opt.optim == 'adam':
188 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon,
189 | weight_decay=opt.weight_decay)
190 | else:
191 | raise Exception("bad option opt.optim: {}".format(opt.optim))
192 |
193 |
194 | def penalty_builder(penalty_config):
195 | if penalty_config == '':
196 | return lambda x, y: y
197 | pen_type, alpha = penalty_config.split('_')
198 | alpha = float(alpha)
199 | if pen_type == 'wu':
200 | return lambda x, y: length_wu(x, y, alpha)
201 | if pen_type == 'avg':
202 | return lambda x, y: length_average(x, y, alpha)
203 |
204 |
205 | def length_wu(length, logprobs, alpha=0.):
206 | """
207 | NMT length re-ranking score from
208 | "Google's Neural Machine Translation System" :cite:`wu2016google`.
209 | """
210 |
211 | modifier = (((5 + length) ** alpha) /
212 | ((5 + 1) ** alpha))
213 | return (logprobs / modifier)
214 |
215 |
216 | def length_average(length, logprobs, alpha=0.):
217 | """
218 | Returns the average probability of tokens in a sequence.
219 | """
220 | return logprobs / length
221 |
222 |
223 | class NoamOpt(object):
224 | "Optim wrapper that implements rate."
225 |
226 | def __init__(self, model_size, factor, warmup, optimizer):
227 | self.optimizer = optimizer
228 | self._step = 0
229 | self.warmup = warmup
230 | self.factor = factor
231 | self.model_size = model_size
232 | self._rate = 0
233 |
234 | def step(self):
235 | "Update parameters and rate"
236 | self._step += 1
237 | rate = self.rate()
238 | for p in self.optimizer.param_groups:
239 | p['lr'] = rate
240 | self._rate = rate
241 | self.optimizer.step()
242 |
243 | def rate(self, step=None):
244 | "Implement `lrate` above"
245 | if step is None:
246 | step = self._step
247 | return self.factor * \
248 | (self.model_size ** (-0.5) *
249 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
250 |
251 | def __getattr__(self, name):
252 | return getattr(self.optimizer, name)
253 |
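# The rate() method above implements the Noam schedule:
#     lr(step) = factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
# i.e. a linear warm-up for `warmup` steps followed by an inverse-square-root decay. For the
# defaults in get_std_opt (factor=1, warmup=2000) and d_model=512, the peak rate at step 2000
# is 512 ** -0.5 * 2000 ** -0.5, roughly 1e-3.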
254 |
255 | class ReduceLROnPlateau(object):
256 | "Optim wrapper that implements rate."
257 |
258 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,
259 | threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
260 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold,
261 | threshold_mode, cooldown, min_lr, eps)
262 | self.optimizer = optimizer
263 | self.current_lr = get_lr(optimizer)
264 |
265 | def step(self):
266 | "Update parameters and rate"
267 | self.optimizer.step()
268 |
269 | def scheduler_step(self, val):
270 | self.scheduler.step(val)
271 | self.current_lr = get_lr(self.optimizer)
272 |
273 | def state_dict(self):
274 | return {'current_lr': self.current_lr,
275 | 'scheduler_state_dict': self.scheduler.state_dict(),
276 | 'optimizer_state_dict': self.optimizer.state_dict()}
277 |
278 | def load_state_dict(self, state_dict):
279 | if 'current_lr' not in state_dict:
280 | # it's a plain optimizer state dict
281 | self.optimizer.load_state_dict(state_dict)
282 | set_lr(self.optimizer, self.current_lr) # use the lr from the option
283 | else:
284 | # it's a scheduler state dict
285 | self.current_lr = state_dict['current_lr']
286 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
287 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
288 | # current_lr is actually useless in this case
289 |
290 | def rate(self, step=None):
291 | "Implement `lrate` above"
292 | if step is None:
293 | step = self._step
294 | return self.factor * \
295 | (self.model_size ** (-0.5) *
296 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
297 |
298 | def __getattr__(self, name):
299 | return getattr(self.optimizer, name)
300 |
301 |
302 | def get_std_opt(model, factor=1, warmup=2000):
303 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
304 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
305 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup,
306 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
307 |
--------------------------------------------------------------------------------
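For reference, a short sketch of the smoothed target distribution LabelSmoothing builds (illustrative numbers: vocabulary size 5, smoothing 0.2, gold index 2):

    import torch

    vocab_size, smoothing = 5, 0.2
    confidence = 1.0 - smoothing
    target = torch.tensor([2])
    true_dist = torch.full((1, vocab_size), smoothing / (vocab_size - 1))
    true_dist.scatter_(1, target.unsqueeze(1), confidence)
    print(true_dist)   # tensor([[0.0500, 0.0500, 0.8000, 0.0500, 0.0500]])

The gold token keeps probability `confidence` and the remaining mass is spread uniformly over the other vocab_size - 1 entries, matching the fill_/scatter_ pair in LabelSmoothing.forward.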
/data/field.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | from collections import Counter, OrderedDict
3 | from torch.utils.data.dataloader import default_collate
4 | from itertools import chain
5 | import six
6 | import torch
7 | import numpy as np
8 | import h5py
9 | import os
10 | import warnings
11 | import shutil
12 |
13 | from .dataset import Dataset
14 | from .vocab import Vocab
15 | from .utils import get_tokenizer
16 |
17 | class RawField(object):
18 | """ Defines a general datatype.
19 |
20 | Every dataset consists of one or more types of data. For instance,
21 | a machine translation dataset contains paired examples of text, while
22 | an image captioning dataset contains images and texts.
23 | Each of these types of data is represented by a RawField object.
24 | A RawField object does not assume any property of the data type and
25 | it holds parameters relating to how a datatype should be processed.
26 |
27 | Attributes:
28 | preprocessing: The Pipeline that will be applied to examples
29 | using this field before creating an example.
30 | Default: None.
31 | postprocessing: A Pipeline that will be applied to a list of examples
32 | using this field before assigning to a batch.
33 | Function signature: (batch(list)) -> object
34 | Default: None.
35 | """
36 |
37 | def __init__(self, preprocessing=None, postprocessing=None):
38 | self.preprocessing = preprocessing
39 | self.postprocessing = postprocessing
40 |
41 | def preprocess(self, x):
42 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """
43 | if self.preprocessing is not None:
44 | return self.preprocessing(x)
45 | else:
46 | return x
47 |
48 | def process(self, batch, *args, **kwargs):
49 | """ Process a list of examples to create a batch.
50 |
51 | Postprocess the batch with user-provided Pipeline.
52 |
53 | Args:
54 | batch (list(object)): A list of object from a batch of examples.
55 | Returns:
56 | object: Processed object given the input and custom
57 | postprocessing Pipeline.
58 | """
59 | if self.postprocessing is not None:
60 | batch = self.postprocessing(batch)
61 | return default_collate(batch)
62 |
63 |
64 | class Merge(RawField):
65 | def __init__(self, *fields):
66 | super(Merge, self).__init__()
67 | self.fields = fields
68 |
69 | def preprocess(self, x):
70 | return tuple(f.preprocess(x) for f in self.fields)
71 |
72 | def process(self, batch, *args, **kwargs):
73 | if len(self.fields) == 1:
74 | batch = [batch, ]
75 | else:
76 | batch = list(zip(*batch))
77 |
78 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
79 | return out
80 |
81 |
82 | class ImageDetectionsField(RawField):
83 | def __init__(self, preprocessing=None, postprocessing=None, clip_features_path=None,
84 | res_features_path=None, mask_features_path=None):
85 | self.clip_features_path = clip_features_path
86 | self.res_features_path = res_features_path
87 | self.mask_features_path = mask_features_path
88 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing)
89 |
90 |
91 | def preprocess(self, x, avoid_precomp=False):
92 | image_name = x.split('/')[-1]
93 |
94 | # resnet feature
95 | precomp_data = np.load(self.res_features_path + '/%s.npy' % image_name) # ResNet feature map: (2048, 7, 7)
96 | precomp_data = precomp_data.reshape(precomp_data.shape[0], precomp_data.shape[1] * precomp_data.shape[2])
97 | precomp_data = precomp_data.transpose(1, 0)
98 |
99 | # clip feature
100 | precomp_data_clip = np.load(self.clip_features_path + '/%s.npy' % image_name)
101 | precomp_data_clip = precomp_data_clip[0]
102 | precomp_data_clip = np.expand_dims(precomp_data_clip, axis=0)
103 | precomp_data_clip = precomp_data_clip.repeat([49], axis=0)
104 |
105 | # group mask matrix
106 | precomp_data_mask = np.load(self.mask_features_path + '/%s.npy' % image_name)
107 |
108 |
109 | return precomp_data.astype(np.float32), precomp_data_clip.astype(np.float32), precomp_data_mask.astype(np.float32)
110 |
111 |
112 |
113 |
114 | class TextField(RawField):
115 | vocab_cls = Vocab
116 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
117 | # numeric type.
118 | dtypes = {
119 | torch.float32: float,
120 | torch.float: float,
121 | torch.float64: float,
122 | torch.double: float,
123 | torch.float16: float,
124 | torch.half: float,
125 |
126 | torch.uint8: int,
127 | torch.int8: int,
128 | torch.int16: int,
129 | torch.short: int,
130 | torch.int32: int,
131 | torch.int: int,
132 | torch.int64: int,
133 | torch.long: int,
134 | }
135 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
136 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
137 |
138 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long,
139 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()),
140 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="<pad>",
141 | unk_token="<unk>", pad_first=False, truncate_first=False, vectors=None, nopoints=True):
142 | self.use_vocab = use_vocab
143 | self.init_token = init_token
144 | self.eos_token = eos_token
145 | self.fix_length = fix_length
146 | self.dtype = dtype
147 | self.lower = lower
148 | self.tokenize = get_tokenizer(tokenize)
149 | self.remove_punctuation = remove_punctuation
150 | self.include_lengths = include_lengths
151 | self.batch_first = batch_first
152 | self.pad_token = pad_token
153 | self.unk_token = unk_token
154 | self.pad_first = pad_first
155 | self.truncate_first = truncate_first
156 | self.vocab = None
157 | self.vectors = vectors
158 | if nopoints:
159 | self.punctuations.append("..")
160 |
161 | super(TextField, self).__init__(preprocessing, postprocessing)
162 |
163 | def preprocess(self, x):
164 |
165 | if isinstance(x, list):
166 | x = str(x)
167 |
168 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type):
169 | x = six.text_type(x, encoding='utf-8')
170 | if self.lower:
171 | x = six.text_type.lower(x)
172 | x = self.tokenize(x.rstrip('\n'))
173 | if self.remove_punctuation:
174 | x = [w for w in x if w not in self.punctuations]
175 | if self.preprocessing is not None:
176 | return self.preprocessing(x)
177 | else:
178 | return x
179 |
180 | def process(self, batch, device=None):
181 | padded = self.pad(batch)
182 | tensor = self.numericalize(padded, device=device)
183 | return tensor
184 |
185 | def build_vocab(self, *args, **kwargs):
186 | counter = Counter()
187 | sources = []
188 | for arg in args:
189 | if isinstance(arg, Dataset):
190 | # sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self]
191 | for name, field in arg.fields.items():
192 | if field is self:
193 | sources += [getattr(arg, name)]
194 | else:
195 | sources.append(arg)
196 |
197 | for data in sources:
198 | for x in data:
199 | x = self.preprocess(x) # returns a list of tokens
200 | try:
201 | counter.update(x)
202 | except TypeError:
203 | counter.update(chain.from_iterable(x))
204 |
205 | specials = list(OrderedDict.fromkeys([
206 | tok for tok in [self.unk_token, self.pad_token, self.init_token,
207 | self.eos_token]
208 | if tok is not None]))
209 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
210 |
211 | def pad(self, minibatch):
212 | """Pad a batch of examples using this field.
213 | Pads to self.fix_length if provided, otherwise pads to the length of
214 | the longest example in the batch. Prepends self.init_token and appends
215 | self.eos_token if those attributes are not None. Returns a tuple of the
216 | padded list and a list containing lengths of each example if
217 | `self.include_lengths` is `True`, else just
218 | returns the padded list.
219 | """
220 | minibatch = list(minibatch)
221 | if self.fix_length is None:
222 | max_len = max(len(x) for x in minibatch)
223 | else:
224 | max_len = self.fix_length + (
225 | self.init_token, self.eos_token).count(None) - 2
226 | padded, lengths = [], []
227 | for x in minibatch:
228 | if self.pad_first:
229 | padded.append(
230 | [self.pad_token] * max(0, max_len - len(x)) +
231 | ([] if self.init_token is None else [self.init_token]) +
232 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
233 | ([] if self.eos_token is None else [self.eos_token]))
234 | else:
235 | padded.append(
236 | ([] if self.init_token is None else [self.init_token]) +
237 | list(x[-max_len:] if self.truncate_first else x[:max_len]) +
238 | ([] if self.eos_token is None else [self.eos_token]) +
239 | [self.pad_token] * max(0, max_len - len(x)))
240 | lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
241 | if self.include_lengths:
242 | return padded, lengths
243 | return padded # return the padded batch of captions, batch_size x max_len
244 |
245 | def numericalize(self, arr, device=None):
246 | """Turn a batch of examples that use this field into a list of Variables.
247 | If the field has include_lengths=True, a tensor of lengths will be
248 | included in the return value.
249 | Arguments:
250 | arr (List[List[str]], or tuple of (List[List[str]], List[int])):
251 | List of tokenized and padded examples, or tuple of List of
252 | tokenized and padded examples and List of lengths of each
253 | example if self.include_lengths is True.
254 | device (str or torch.device): A string or instance of `torch.device`
255 | specifying which device the Variables are going to be created on.
256 | If left as default, the tensors will be created on cpu. Default: None.
257 | """
258 | if self.include_lengths and not isinstance(arr, tuple):
259 | raise ValueError("Field has include_lengths set to True, but "
260 | "input data is not a tuple of "
261 | "(data batch, batch lengths).")
262 | if isinstance(arr, tuple):
263 | arr, lengths = arr
264 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
265 |
266 | if self.use_vocab:
267 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
268 |
269 | if self.postprocessing is not None:
270 | arr = self.postprocessing(arr, self.vocab)
271 |
272 | var = torch.tensor(arr, dtype=self.dtype, device=device)
273 | else:
274 | if self.vectors:
275 | arr = [[self.vectors[x] for x in ex] for ex in arr]
276 | if self.dtype not in self.dtypes:
277 | raise ValueError(
278 | "Specified Field dtype {} can not be used with "
279 | "use_vocab=False because we do not know how to numericalize it. "
280 | "Please raise an issue at "
281 | "https://github.com/pytorch/text/issues".format(self.dtype))
282 | numericalization_func = self.dtypes[self.dtype]
283 | # It doesn't make sense to explicitly coerce to a numeric type if
284 | # the data is sequential, since it's unclear how to coerce padding tokens
285 | # to a numeric type.
286 | arr = [numericalization_func(x) if isinstance(x, six.string_types)
287 | else x for x in arr]
288 |
289 | if self.postprocessing is not None:
290 | arr = self.postprocessing(arr, None)
291 |
292 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr])
293 |
294 | # var = torch.tensor(arr, dtype=self.dtype, device=device)
295 | if not self.batch_first:
296 | var.t_() # transpose from (batch_size, max_len) to (max_len, batch_size)
297 | var = var.contiguous()
298 |
299 | if self.include_lengths:
300 | return var, lengths
301 | return var
302 |
303 | def decode(self, word_idxs, join_words=True):
304 | if isinstance(word_idxs, list) and len(word_idxs) == 0:
305 | return self.decode([word_idxs, ], join_words)[0]
306 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int):
307 | return self.decode([word_idxs, ], join_words)[0]
308 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1:
309 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0]
310 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
311 | return self.decode(word_idxs.unsqueeze(0), join_words)[0]
312 |
313 | captions = []
314 | for wis in word_idxs:
315 | caption = []
316 | for wi in wis:
317 | word = self.vocab.itos[int(wi)]
318 | if word == self.eos_token:
319 | break
320 | caption.append(word)
321 | if join_words:
322 | caption = ' '.join(caption)
323 | captions.append(caption)
324 |
325 | return captions
326 |
--------------------------------------------------------------------------------
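A standalone sketch of the padding rule TextField.pad applies in its default (pad-last) branch, with hypothetical special tokens; the real field takes these from its constructor:

    init_token, eos_token, pad_token = '<bos>', '<eos>', '<pad>'
    minibatch = [['a', 'plane'], ['two', 'storage', 'tanks']]
    max_len = max(len(x) for x in minibatch)
    padded = [[init_token] + x[:max_len] + [eos_token] + [pad_token] * (max_len - len(x))
              for x in minibatch]
    print(padded)
    # [['<bos>', 'a', 'plane', '<eos>', '<pad>'],
    #  ['<bos>', 'two', 'storage', 'tanks', '<eos>']]

numericalize then maps every token through vocab.stoi and stacks the result into a LongTensor.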
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import itertools
4 | import collections
5 | import torch
6 | from .example import Example
7 | from .utils import nostdout
8 | import json
9 |
10 |
11 | class Dataset(object):
12 | def __init__(self, examples, fields):
13 | self.examples = examples
14 | self.fields = dict(fields)
15 |
16 | def collate_fn(self):
17 | def collate(batch):
18 | if len(self.fields) == 1:
19 | batch = [batch, ]
20 | else:
21 | batch = list(zip(*batch))
22 |
23 | tensors = []
24 | for field, data in zip(self.fields.values(), batch):
25 | tensor = field.process(data)
26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
27 | tensors.extend(tensor)
28 | else:
29 | tensors.append(tensor)
30 |
31 | if len(tensors) > 1:
32 | return tensors
33 | else:
34 | return tensors[0]
35 |
36 | return collate
37 |
38 | def __getitem__(self, i):
39 | example = self.examples[i]
40 | data = []
41 | for field_name, field in self.fields.items():
42 | tem = getattr(example, field_name)
43 | data.append(field.preprocess(tem))
44 |
45 | if len(data) == 1:
46 | data = data[0]
47 | return data
48 |
49 | def __len__(self):
50 | return len(self.examples)
51 |
52 | def __getattr__(self, attr):
53 | if attr in self.fields:
54 | for x in self.examples:
55 | yield getattr(x, attr)
56 |
57 |
58 | class ValueDataset(Dataset):
59 | def __init__(self, examples, fields, dictionary):
60 | self.dictionary = dictionary
61 | super(ValueDataset, self).__init__(examples, fields)
62 |
63 | def collate_fn(self):
64 | def collate(batch):
65 | value_batch_flattened = list(itertools.chain(*batch))
66 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
67 |
68 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
69 | if isinstance(value_tensors_flattened, collections.Sequence) \
70 | and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
71 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in
72 | value_tensors_flattened]
73 | else:
74 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
75 |
76 | return value_tensors
77 |
78 | return collate
79 |
80 | def __getitem__(self, i):
81 | if i not in self.dictionary:
82 | raise IndexError
83 |
84 | values_data = []
85 | for idx in self.dictionary[i]:
86 | value_data = super(ValueDataset, self).__getitem__(idx)
87 | values_data.append(value_data)
88 | return values_data
89 |
90 | def __len__(self):
91 | return len(self.dictionary)
92 |
93 |
94 | class DictionaryDataset(Dataset):
95 | def __init__(self, examples, fields, key_fields):
96 | if not isinstance(key_fields, (tuple, list)):
97 | key_fields = (key_fields,)
98 | for field in key_fields:
99 | assert (field in fields)
100 |
101 | dictionary = collections.defaultdict(list)
102 | key_fields = {k: fields[k] for k in key_fields}
103 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
104 | key_examples = []
105 | key_dict = dict()
106 | value_examples = []
107 |
108 | for i, e in enumerate(examples):
109 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
110 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
111 | if key_example not in key_dict:
112 | key_dict[key_example] = len(key_examples)
113 | key_examples.append(key_example)
114 |
115 | value_examples.append(value_example)
116 | dictionary[key_dict[key_example]].append(i)
117 |
118 | self.key_dataset = Dataset(key_examples, key_fields)
119 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
120 | super(DictionaryDataset, self).__init__(examples, fields)
121 |
122 | def collate_fn(self):
123 | def collate(batch):
124 | key_batch, value_batch = list(zip(*batch))
125 | key_tensors = self.key_dataset.collate_fn()(key_batch)
126 | value_tensors = self.value_dataset.collate_fn()(value_batch)
127 | return key_tensors, value_tensors
128 |
129 | return collate
130 |
131 | def __getitem__(self, i):
132 | return self.key_dataset[i], self.value_dataset[i]
133 |
134 | def __len__(self):
135 | return len(self.key_dataset)
136 |
137 |
138 | def unique(sequence):
139 | seen = set()
140 | if isinstance(sequence[0], list):
141 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))]
142 | else:
143 | return [x for x in sequence if not (x in seen or seen.add(x))]
144 |
145 |
146 | class PairedDataset(Dataset):
147 | def __init__(self, examples, fields):
148 | assert ('image' in fields)
149 | assert ('text' in fields)
150 | super(PairedDataset, self).__init__(examples, fields)
151 | self.image_field = self.fields['image']
152 | self.text_field = self.fields['text']
153 |
154 | def image_set(self):
155 | img_list = [e.image for e in self.examples]
156 | image_set = unique(img_list)
157 | examples = [Example.fromdict({'image': i}) for i in image_set]
158 | dataset = Dataset(examples, {'image': self.image_field})
159 | return dataset
160 |
161 | def text_set(self):
162 | text_list = [e.text for e in self.examples]
163 | text_list = unique(text_list)
164 | examples = [Example.fromdict({'text': t}) for t in text_list]
165 | dataset = Dataset(examples, {'text': self.text_field})
166 | return dataset
167 |
168 | def image_dictionary(self, fields=None):
169 | if not fields:
170 | fields = self.fields
171 | dataset = DictionaryDataset(self.examples, fields, key_fields='image')
172 | return dataset
173 |
174 | def text_dictionary(self, fields=None):
175 | if not fields:
176 | fields = self.fields
177 | dataset = DictionaryDataset(self.examples, fields, key_fields='text')
178 | return dataset
179 |
180 | @property
181 | def splits(self):
182 | raise NotImplementedError
183 |
184 |
185 | class Sydney(PairedDataset):
186 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
187 | cut_validation=False):
188 | roots = {}
189 | img_root = os.path.join(img_root, '')
190 | roots['train'] = {
191 | 'img': os.path.join(img_root, ''),
192 | 'cap': os.path.join(ann_root, 'dataset.json')
193 | }
194 | roots['val'] = {
195 | 'img': os.path.join(img_root, ''),
196 | 'cap': os.path.join(ann_root, 'dataset.json')
197 | }
198 | roots['test'] = {
199 | 'img': os.path.join(img_root, ''),
200 | 'cap': os.path.join(ann_root, 'dataset.json')
201 | }
202 | ids = {'train': None, 'val': None, 'test': None}
203 |
204 | with nostdout():
205 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root,
206 | ids)
207 | examples = self.train_examples + self.val_examples + self.test_examples
208 | super(Sydney, self).__init__(examples, {'image': image_field, 'text': text_field})
209 |
210 | @property
211 | def splits(self):
212 | train_split = PairedDataset(self.train_examples, self.fields)
213 | val_split = PairedDataset(self.val_examples, self.fields)
214 | test_split = PairedDataset(self.test_examples, self.fields)
215 | return train_split, val_split, test_split
216 |
217 | @classmethod
218 | def get_samples(cls, img_root, ann_root, ids_dataset=None):
219 | train_samples = []
220 | val_samples = []
221 | test_samples = []
222 |
223 | dataset = json.load(open(os.path.join(ann_root, 'dataset.json'), 'r'))['images']
224 | index = list(range(613))
225 | np.random.shuffle(index)
226 | # 497 58 58
227 | a1 = index[:497]
228 | a2 = index[497:555]
229 | a3 = index[555:613]
230 | for i, d in enumerate(dataset):
231 | if i in a1:
232 | for x in range(len(d['sentences'])):
233 | captions = d['sentences'][x]['raw']
234 | filename = d['filename']
235 | # imgid = d['imgid']
236 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
237 | train_samples.append(example)
238 | elif i in a2:
239 | for x in range(len(d['sentences'])):
240 | captions = d['sentences'][x]['raw']
241 | filename = d['filename']
242 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
243 | val_samples.append(example)
244 | elif i in a3:
245 | for x in range(len(d['sentences'])):
246 | captions = d['sentences'][x]['raw']
247 | filename = d['filename']
248 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
249 | test_samples.append(example)
250 |
251 | return train_samples, val_samples, test_samples
252 |
253 |
254 | class UCM(PairedDataset):
255 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
256 | cut_validation=False):
257 |
258 | roots = {}
259 | img_root = os.path.join(img_root, '')
260 | roots['train'] = {
261 | 'img': os.path.join(img_root, ''),
262 | 'cap': os.path.join(ann_root, 'dataset.json')
263 | }
264 | roots['val'] = {
265 | 'img': os.path.join(img_root, ''),
266 | 'cap': os.path.join(ann_root, 'dataset.json')
267 | }
268 | roots['test'] = {
269 | 'img': os.path.join(img_root, ''),
270 | 'cap': os.path.join(ann_root, 'dataset.json')
271 | }
272 | ids = {'train': None, 'val': None, 'test': None}
273 |
274 | with nostdout():
275 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root, ids)
276 | examples = self.train_examples + self.val_examples + self.test_examples # merge the three splits
277 | super(UCM, self).__init__(examples, {'image': image_field, 'text': text_field}) # delegate to PairedDataset
278 |
279 | @property
280 | def splits(self):
281 | train_split = PairedDataset(self.train_examples, self.fields)
282 | val_split = PairedDataset(self.val_examples, self.fields)
283 | test_split = PairedDataset(self.test_examples, self.fields)
284 | return train_split, val_split, test_split
285 |
286 | @classmethod
287 | def get_samples(cls, img_root, ann_root, ids_dataset=None):
288 | train_samples = []
289 | val_samples = []
290 | test_samples = []
291 |
292 | dataset = json.load(open(os.path.join(ann_root, 'dataset.json'), 'r'))['images']
293 | index = list(range(2100))
294 | np.random.shuffle(index)
295 |
296 | # UCM 1680 210 210
297 | a1 = index[:1680]
298 | a2 = index[1680:1890]
299 | a3 = index[1890:2100]
300 |
301 | for i, d in enumerate(dataset):
302 | if i in a1:
303 | for x in range(len(d['sentences'])):
304 | captions = d['sentences'][x]['raw']
305 | filename = d['filename']
306 | # imgid = d['imgid']
307 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
308 | train_samples.append(example)
309 | elif i in a2:
310 | for x in range(len(d['sentences'])):
311 | captions = d['sentences'][x]['raw']
312 | filename = d['filename']
313 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
314 | val_samples.append(example)
315 | elif i in a3:
316 | for x in range(len(d['sentences'])):
317 | captions = d['sentences'][x]['raw']
318 | filename = d['filename']
319 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
320 | test_samples.append(example)
321 |
322 | return train_samples, val_samples, test_samples
323 |
324 |
325 | class RSICD(PairedDataset):
326 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
327 | cut_validation=False):
328 |
329 | roots = {}
330 | img_root = os.path.join(img_root, '')
331 | roots['train'] = {
332 | 'img': os.path.join(img_root, ''),
333 | 'cap': os.path.join(ann_root, 'dataset.json')
334 | }
335 | roots['val'] = {
336 | 'img': os.path.join(img_root, ''),
337 | 'cap': os.path.join(ann_root, 'dataset.json')
338 | }
339 | roots['test'] = {
340 | 'img': os.path.join(img_root, ''),
341 | 'cap': os.path.join(ann_root, 'dataset.json')
342 | }
343 | ids = {'train': None, 'val': None, 'test': None}
344 |
345 | with nostdout():
346 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root, ids)
347 | examples = self.train_examples + self.val_examples + self.test_examples
348 | super(RSICD, self).__init__(examples, {'image': image_field, 'text': text_field})
349 |
350 | @property
351 | def splits(self):
352 | train_split = PairedDataset(self.train_examples, self.fields)
353 | val_split = PairedDataset(self.val_examples, self.fields)
354 | test_split = PairedDataset(self.test_examples, self.fields)
355 | return train_split, val_split, test_split
356 |
357 | @classmethod
358 | def get_samples(cls, img_root, ann_root, ids_dataset=None):
359 | train_samples = []
360 | val_samples = []
361 | test_samples = []
362 |
363 | dataset = json.load(open(os.path.join(ann_root, 'dataset_rsicd.json'), 'r'))['images']
364 | index = list(range(10921))
365 | np.random.shuffle(index)
366 |
367 | # 10921 8734 1094 1093
368 | a1 = index[:8734]
369 | a2 = index[8734:9828]
370 | a3 = index[9828:10921]
371 | for i, d in enumerate(dataset):
372 | if i in a1:
373 | for x in range(len(d['sentences'])):
374 | captions = d['sentences'][x]['raw']
375 | filename = d['filename']
376 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
377 | train_samples.append(example)
378 | elif i in a2:
379 | for x in range(len(d['sentences'])):
380 | captions = d['sentences'][x]['raw']
381 | filename = d['filename']
382 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
383 | val_samples.append(example)
384 | elif i in a3:
385 | for x in range(len(d['sentences'])):
386 | captions = d['sentences'][x]['raw']
387 | filename = d['filename']
388 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions})
389 | test_samples.append(example)
390 |
391 | return train_samples, val_samples, test_samples
392 |
393 |
--------------------------------------------------------------------------------
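The three dataset classes share the same split recipe: shuffle the image indices, cut fixed-size train/val/test ranges, and emit one Example per (image, caption) pair. A minimal sketch with the UCM figures (1680/210/210 of 2100 images); note the repo shuffles without a fixed seed, so the split differs across runs unless one is set:

    import numpy as np

    index = list(range(2100))
    np.random.shuffle(index)
    train_idx, val_idx, test_idx = index[:1680], index[1680:1890], index[1890:2100]
    print(len(train_idx), len(val_idx), len(test_idx))   # 1680 210 210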
/data/vocab.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | import array
3 | from collections import defaultdict
4 | from functools import partial
5 | import io
6 | import logging
7 | import os
8 | import zipfile
9 |
10 | import six
11 | from six.moves.urllib.request import urlretrieve
12 | import torch
13 | from tqdm import tqdm
14 | import tarfile
15 |
16 | from .utils import reporthook
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class Vocab(object):
22 | """Defines a vocabulary object that will be used to numericalize a field.
23 |
24 | Attributes:
25 | freqs: A collections.Counter object holding the frequencies of tokens
26 | in the data used to build the Vocab.
27 | stoi: A collections.defaultdict instance mapping token strings to
28 | numerical identifiers.
29 | itos: A list of token strings indexed by their numerical identifiers.
30 | """
31 | def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>'],
32 | vectors=None, unk_init=None, vectors_cache=None):
33 | """Create a Vocab object from a collections.Counter.
34 |
35 | Arguments:
36 | counter: collections.Counter object holding the frequencies of
37 | each value found in the data.
38 | max_size: The maximum size of the vocabulary, or None for no
39 | maximum. Default: None.
40 | min_freq: The minimum frequency needed to include a token in the
41 | vocabulary. Values less than 1 will be set to 1. Default: 1.
42 | specials: The list of special tokens (e.g., padding or eos) that
43 | will be prepended to the vocabulary in addition to an <unk>
44 | token. Default: ['<unk>']
45 | vectors: One of either the available pretrained vectors
46 | or custom pretrained vectors (see Vocab.load_vectors);
47 | or a list of aforementioned vectors
48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors
49 | to zero vectors; can be any function that takes in a Tensor and
50 | returns a Tensor of the same size. Default: torch.Tensor.zero_
51 | vectors_cache: directory for cached vectors. Default: '.vector_cache'
52 | """
53 | self.freqs = counter
54 | counter = counter.copy()
55 | min_freq = max(min_freq, 1)
56 |
57 | self.itos = list(specials)
58 | # frequencies of special tokens are not counted when building vocabulary
59 | # in frequency order
60 | for tok in specials:
61 | del counter[tok]
62 |
63 | max_size = None if max_size is None else max_size + len(self.itos)
64 |
65 | # sort by frequency, then alphabetically
66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
68 |
69 | for word, freq in words_and_frequencies:
70 | if freq < min_freq or len(self.itos) == max_size:
71 | break
72 | self.itos.append(word)
73 |
74 | self.stoi = defaultdict(_default_unk_index)
75 | # stoi is simply a reverse dict for itos
76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
77 |
78 | self.vectors = None
79 | if vectors is not None:
80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
81 | else:
82 | assert unk_init is None and vectors_cache is None
83 |
84 | def __eq__(self, other):
85 | if self.freqs != other.freqs:
86 | return False
87 | if self.stoi != other.stoi:
88 | return False
89 | if self.itos != other.itos:
90 | return False
91 | if self.vectors != other.vectors:
92 | return False
93 | return True
94 |
95 | def __len__(self):
96 | return len(self.itos)
97 |
98 | def extend(self, v, sort=False):
99 | words = sorted(v.itos) if sort else v.itos
100 | for w in words:
101 | if w not in self.stoi:
102 | self.itos.append(w)
103 | self.stoi[w] = len(self.itos) - 1
104 |
105 | def load_vectors(self, vectors, **kwargs):
106 | """
107 | Arguments:
108 | vectors: one of or a list containing instantiations of the
109 | GloVe, CharNGram, or Vectors classes. Alternatively, one
110 | of or a list of available pretrained vectors:
111 | charngram.100d
112 | fasttext.en.300d
113 | fasttext.simple.300d
114 | glove.42B.300d
115 | glove.840B.300d
116 | glove.twitter.27B.25d
117 | glove.twitter.27B.50d
118 | glove.twitter.27B.100d
119 | glove.twitter.27B.200d
120 | glove.6B.50d
121 | glove.6B.100d
122 | glove.6B.200d
123 | glove.6B.300d
124 | Remaining keyword arguments: Passed to the constructor of Vectors classes.
125 | """
126 | if not isinstance(vectors, list):
127 | vectors = [vectors]
128 | for idx, vector in enumerate(vectors):
129 | if six.PY2 and isinstance(vector, str):
130 | vector = six.text_type(vector)
131 | if isinstance(vector, six.string_types):
132 | # Convert the string pretrained vector identifier
133 | # to a Vectors object
134 | if vector not in pretrained_aliases:
135 | raise ValueError(
136 | "Got string input vector {}, but allowed pretrained "
137 | "vectors are {}".format(
138 | vector, list(pretrained_aliases.keys())))
139 | vectors[idx] = pretrained_aliases[vector](**kwargs)
140 | elif not isinstance(vector, Vectors):
141 | raise ValueError(
142 | "Got input vectors of type {}, expected str or "
143 | "Vectors object".format(type(vector)))
144 |
145 | tot_dim = sum(v.dim for v in vectors)
146 | self.vectors = torch.Tensor(len(self), tot_dim)
147 | for i, token in enumerate(self.itos):
148 | start_dim = 0
149 | for v in vectors:
150 | end_dim = start_dim + v.dim
151 | self.vectors[i][start_dim:end_dim] = v[token.strip()]
152 | start_dim = end_dim
153 | assert(start_dim == tot_dim)
154 |
155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
156 | """
157 | Set the vectors for the Vocab instance from a collection of Tensors.
158 |
159 | Arguments:
160 | stoi: A dictionary of string to the index of the associated vector
161 | in the `vectors` input argument.
162 | vectors: An indexed iterable (or other structure supporting __getitem__) that
163 | given an input index, returns a FloatTensor representing the vector
164 | for the token associated with the index. For example,
165 | vector[stoi["string"]] should return the vector for "string".
166 | dim: The dimensionality of the vectors.
167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors
168 | to zero vectors; can be any function that takes in a Tensor and
169 | returns a Tensor of the same size. Default: torch.Tensor.zero_
170 | """
171 | self.vectors = torch.Tensor(len(self), dim)
172 | for i, token in enumerate(self.itos):
173 | wv_index = stoi.get(token, None)
174 | if wv_index is not None:
175 | self.vectors[i] = vectors[wv_index]
176 | else:
177 | self.vectors[i] = unk_init(self.vectors[i])
178 |
179 |
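# Usage sketch (illustrative addition, not part of the original module; it assumes the
# constructor signature documented above): building a small vocabulary from a Counter.
#
#   from collections import Counter
#   counter = Counter('a plane is parked near a plane terminal'.split())
#   v = Vocab(counter, min_freq=1, specials=['<unk>', '<pad>'])
#   v.stoi['<unk>']    # -> 0; specials are prepended in the given order
#   v.stoi['plane']    # token -> index; unseen tokens fall back to index 0 (<unk>)
#   v.itos[:4]         # index -> token, frequency-sorted after the specials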
180 | class Vectors(object):
181 |
182 | def __init__(self, name, cache=None,
183 | url=None, unk_init=None):
184 | """
185 | Arguments:
186 | name: name of the file that contains the vectors
187 | cache: directory for cached vectors
188 | url: url for download if vectors not found in cache
189 |             unk_init (callback): by default, initialize out-of-vocabulary word vectors
190 | to zero vectors; can be any function that takes in a Tensor and
191 | returns a Tensor of the same size
192 | """
193 | cache = '.vector_cache' if cache is None else cache
194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init
195 | self.cache(name, cache, url=url)
196 |
197 | def __getitem__(self, token):
198 | if token in self.stoi:
199 | return self.vectors[self.stoi[token]]
200 | else:
201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim))
202 |
203 | def cache(self, name, cache, url=None):
204 | if os.path.isfile(name):
205 | path = name
206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt'
207 | else:
208 | path = os.path.join(cache, name)
209 | path_pt = path + '.pt'
210 |
211 | if not os.path.isfile(path_pt):
212 | if not os.path.isfile(path) and url:
213 | logger.info('Downloading vectors from {}'.format(url))
214 | if not os.path.exists(cache):
215 | os.makedirs(cache)
216 | dest = os.path.join(cache, os.path.basename(url))
217 | if not os.path.isfile(dest):
218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
219 | try:
220 | urlretrieve(url, dest, reporthook=reporthook(t))
221 | except KeyboardInterrupt as e: # remove the partial zip file
222 | os.remove(dest)
223 | raise e
224 | logger.info('Extracting vectors into {}'.format(cache))
225 | ext = os.path.splitext(dest)[1][1:]
226 | if ext == 'zip':
227 | with zipfile.ZipFile(dest, "r") as zf:
228 | zf.extractall(cache)
229 | elif ext == 'gz':
230 | with tarfile.open(dest, 'r:gz') as tar:
231 | tar.extractall(path=cache)
232 | if not os.path.isfile(path):
233 | raise RuntimeError('no vectors found at {}'.format(path))
234 |
235 | # str call is necessary for Python 2/3 compatibility, since
236 | # argument must be Python 2 str (Python 3 bytes) or
237 | # Python 3 str (Python 2 unicode)
238 | itos, vectors, dim = [], array.array(str('d')), None
239 |
240 | # Try to read the whole file with utf-8 encoding.
241 | binary_lines = False
242 | try:
243 | with io.open(path, encoding="utf8") as f:
244 | lines = [line for line in f]
245 | # If there are malformed lines, read in binary mode
246 | # and manually decode each word from utf-8
247 | except:
248 | logger.warning("Could not read {} as UTF8 file, "
249 | "reading file as bytes and skipping "
250 | "words with malformed UTF8.".format(path))
251 | with open(path, 'rb') as f:
252 | lines = [line for line in f]
253 | binary_lines = True
254 |
255 | logger.info("Loading vectors from {}".format(path))
256 | for line in tqdm(lines, total=len(lines)):
257 | # Explicitly splitting on " " is important, so we don't
258 | # get rid of Unicode non-breaking spaces in the vectors.
259 | entries = line.rstrip().split(b" " if binary_lines else " ")
260 |
261 | word, entries = entries[0], entries[1:]
262 | if dim is None and len(entries) > 1:
263 | dim = len(entries)
264 | elif len(entries) == 1:
265 | logger.warning("Skipping token {} with 1-dimensional "
266 | "vector {}; likely a header".format(word, entries))
267 | continue
268 | elif dim != len(entries):
269 | raise RuntimeError(
270 | "Vector for token {} has {} dimensions, but previously "
271 | "read vectors have {} dimensions. All vectors must have "
272 | "the same number of dimensions.".format(word, len(entries), dim))
273 |
274 | if binary_lines:
275 | try:
276 | if isinstance(word, six.binary_type):
277 | word = word.decode('utf-8')
278 | except:
279 | logger.info("Skipping non-UTF8 token {}".format(repr(word)))
280 | continue
281 | vectors.extend(float(x) for x in entries)
282 | itos.append(word)
283 |
284 | self.itos = itos
285 | self.stoi = {word: i for i, word in enumerate(itos)}
286 | self.vectors = torch.Tensor(vectors).view(-1, dim)
287 | self.dim = dim
288 | logger.info('Saving vectors to {}'.format(path_pt))
289 | if not os.path.exists(cache):
290 | os.makedirs(cache)
291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)
292 | else:
293 | logger.info('Loading vectors from {}'.format(path_pt))
294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)
295 |
296 |
297 | class GloVe(Vectors):
298 | url = {
299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip',
302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip',
303 | }
304 |
305 | def __init__(self, name='840B', dim=300, **kwargs):
306 | url = self.url[name]
307 | name = 'glove.{}.{}d.txt'.format(name, str(dim))
308 | super(GloVe, self).__init__(name, url=url, **kwargs)
309 |
310 |
311 | class FastText(Vectors):
312 |
313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec'
314 |
315 | def __init__(self, language="en", **kwargs):
316 | url = self.url_base.format(language)
317 | name = os.path.basename(url)
318 | super(FastText, self).__init__(name, url=url, **kwargs)
319 |
320 |
321 | class CharNGram(Vectors):
322 |
323 | name = 'charNgram.txt'
324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
325 | 'jmt_pre-trained_embeddings.tar.gz')
326 |
327 | def __init__(self, **kwargs):
328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs)
329 |
330 | def __getitem__(self, token):
331 | vector = torch.Tensor(1, self.dim).zero_()
332 |         if token == "<unk>":
333 | return self.unk_init(vector)
334 | # These literals need to be coerced to unicode for Python 2 compatibility
335 | # when we try to join them with read ngrams from the files.
336 | chars = ['#BEGIN#'] + list(token) + ['#END#']
337 | num_vectors = 0
338 | for n in [2, 3, 4]:
339 | end = len(chars) - n + 1
340 | grams = [chars[i:(i + n)] for i in range(end)]
341 | for gram in grams:
342 | gram_key = '{}gram-{}'.format(n, ''.join(gram))
343 | if gram_key in self.stoi:
344 | vector += self.vectors[self.stoi[gram_key]]
345 | num_vectors += 1
346 | if num_vectors > 0:
347 | vector /= num_vectors
348 | else:
349 | vector = self.unk_init(vector)
350 | return vector
351 |
352 |
353 | def _default_unk_index():
354 | return 0
355 |
356 |
357 | pretrained_aliases = {
358 | "charngram.100d": partial(CharNGram),
359 | "fasttext.en.300d": partial(FastText, language="en"),
360 | "fasttext.simple.300d": partial(FastText, language="simple"),
361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"),
362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"),
363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"),
364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"),
365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"),
366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"),
367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"),
368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"),
369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"),
370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300")
371 | }
372 | """Mapping from string name to factory function"""
373 |
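# Illustrative sketch (an assumption, not part of the original module): attaching
# pretrained embeddings to an existing `vocab`. String names are resolved through
# `pretrained_aliases` and downloaded to `.vector_cache` on first use; the resulting
# matrix can then seed an embedding layer.
#
#   vocab.load_vectors('glove.6B.50d')
#   emb = torch.nn.Embedding.from_pretrained(vocab.vectors, freeze=False)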
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from torch.backends import cudnn
4 |
5 | from data import ImageDetectionsField, TextField, RawField
6 | from data import DataLoader, Sydney, UCM, RSICD
7 | import evaluation
8 | from evaluation import PTBTokenizer, Cider
9 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, GlobalGroupingAttention_with_DC
10 | import torch
11 | from torch.optim import Adam
12 | from torch.optim.lr_scheduler import LambdaLR
13 | from torch.nn import NLLLoss
14 | from tqdm import tqdm
15 | from torch.utils.tensorboard import SummaryWriter
16 | import argparse, os, pickle
17 | import numpy as np
18 | import itertools
19 | import multiprocessing
20 | from shutil import copyfile
21 |
22 | import warnings
23 |
24 | warnings.filterwarnings("ignore")
25 |
26 |
27 |
28 | def seed_torch(seed=1):
29 | random.seed(seed)
30 |     os.environ['PYTHONHASHSEED'] = str(seed) # disable hash randomization so runs are reproducible
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.backends.cudnn.deterministic = True
35 | torch.backends.cudnn.benchmark = False
36 | seed_torch()
37 |
38 |
39 | def evaluate_loss(model, dataloader, loss_fn, text_field):
40 | # Validation loss
41 | model.eval()
42 | running_loss = .0
43 | with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar:
44 | with torch.no_grad():
45 | for it, (detections, detections_gl, detections_mask, captions) in enumerate(dataloader):
46 | detections, detections_gl, detections_mask, captions = detections.to(device), detections_gl.to(device), \
47 | detections_mask.to(device), captions.to(device)
48 | out = model(detections, detections_gl, detections_mask, captions, isencoder=True)
49 | captions = captions[:, 1:].contiguous()
50 | out = out[:, :-1].contiguous()
51 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1))
52 | this_loss = loss.item()
53 | running_loss += this_loss
54 | pbar.set_postfix(loss=running_loss / (it + 1))
55 | pbar.update()
56 | val_loss = running_loss / len(dataloader)
57 | return val_loss
58 |
59 |
60 | def evaluate_metrics(model, dataloader, text_field):
61 | import itertools
62 | model.eval()
63 | gen = {}
64 | gts = {}
65 | with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar:
66 | for it, (images, caps_gt) in enumerate(dataloader):
67 | detections = images[0].to(device)
68 | detections_gl = images[1].to(device)
69 | detections_mask = images[2].to(device)
70 |
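            # Caption generation settings: decode at most 20 tokens per image with a beam of
            # width 5, stop at the <eos> index, and keep only the single best sequence
            # (out_size=1) for scoring against the ground-truth references.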
71 | with torch.no_grad():
72 | out, _ = model.beam_search(detections, detections_gl, detections_mask, 20,
73 |                                            text_field.vocab.stoi['<eos>'], 5, out_size=1)
74 | caps_gen = text_field.decode(out, join_words=False)
75 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
76 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
77 | gen['%d_%d' % (it, i)] = [gen_i, ]
78 | gts['%d_%d' % (it, i)] = gts_i
79 | pbar.update()
80 |
81 | gts = evaluation.PTBTokenizer.tokenize(gts)
82 | gen = evaluation.PTBTokenizer.tokenize(gen)
83 | scores, _ = evaluation.compute_scores(gts, gen)
84 | return scores
85 |
86 |
87 | def train_xe(model, dataloader, optim, text_field):
88 | # Training with cross-entropy
89 | model.train()
90 | scheduler.step()
91 | running_loss = .0
92 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
93 | for it, (detections, detections_gl, detections_mask, captions) in enumerate(
94 | dataloader):
95 |
96 | detections, detections_gl, detections_mask, captions = detections.to(device), detections_gl.to(device), \
97 | detections_mask.to(device), captions.to(device)
98 | out = model(detections, detections_gl, detections_mask, captions, isencoder=True)
99 | optim.zero_grad()
100 | captions_gt = captions[:, 1:].contiguous()
101 | out = out[:, :-1].contiguous()
102 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1))
103 | loss.backward()
104 |
105 | optim.step()
106 | this_loss = loss.item()
107 | running_loss += this_loss
108 |
109 | pbar.set_postfix(loss=running_loss / (it + 1))
110 | pbar.update()
111 | scheduler.step()
112 |
113 | loss = running_loss / len(dataloader)
114 | return loss
115 |
116 |
117 | def train_scst(model, dataloader, optim, cider, text_field):
118 | # Training with self-critical
119 | tokenizer_pool = multiprocessing.Pool()
120 | running_reward = .0
121 | running_reward_baseline = .0
122 | model.train()
123 | running_loss = .0
124 | seq_len = 20
125 | beam_size = 5
126 |
127 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
128 | for it, (images, caps_gt) in enumerate(dataloader):
129 | detections = images[0].to(device)
130 | detections_gl = images[1].to(device)
131 | detections_mask = images[2].to(device)
132 |
133 | outs, log_probs = model.beam_search(detections, detections_gl, detections_mask, 20,
134 |                                                 text_field.vocab.stoi['<eos>'], beam_size, out_size=beam_size)
135 |
136 | optim.zero_grad()
137 |
138 | # Rewards
139 | caps_gen = text_field.decode(outs.view(-1, seq_len))
140 | caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
141 | caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt])
142 | reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32)
143 | reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
144 | reward_baseline = torch.mean(reward, -1, keepdim=True)
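                # Self-critical baseline: the mean reward over the beam's own candidates serves
                # as the baseline, so each sampled caption's log-probability is weighted by how
                # much its CIDEr beats (or trails) its beam's average; no learned critic is used.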
145 | loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
146 |
147 | loss = loss.mean()
148 | loss.backward()
149 | optim.step()
150 |
151 | running_loss += loss.item()
152 | running_reward += reward.mean().item()
153 | running_reward_baseline += reward_baseline.mean().item()
154 | pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1),
155 | reward_baseline=running_reward_baseline / (it + 1))
156 | pbar.update()
157 |
158 | loss = running_loss / len(dataloader)
159 | reward = running_reward / len(dataloader)
160 | reward_baseline = running_reward_baseline / len(dataloader)
161 | return loss, reward, reward_baseline
162 |
163 |
164 |
165 | if __name__ == '__main__':
166 | device = torch.device('cuda')
167 | parser = argparse.ArgumentParser(description='MG-Transformer')
168 | parser.add_argument('--exp_name', type=str, default='Sydney') # Sydney,UCM,RSICD
169 | parser.add_argument('--batch_size', type=int, default=50)
170 | parser.add_argument('--workers', type=int, default=0)
171 | parser.add_argument('--head', type=int, default=8)
172 | parser.add_argument('--warmup', type=int, default=10000) # 1000
173 | parser.add_argument('--warm_up_epochs', type=int, default=10) # 10
174 | parser.add_argument('--epochs', type=int, default=100)
175 | parser.add_argument('--resume_last', action='store_true')
176 | parser.add_argument('--resume_best', action='store_true')
177 | ################################################################################################################
178 | parser.add_argument('--annotation_folder', type=str,
179 | default='/media/dmd/ours/mlw/rs/Sydney_Captions')
180 | parser.add_argument('--res_features_path', type=str,
181 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224')
182 |
183 | parser.add_argument('--clip_features_path', type=str,
184 | default='/media/dmd/ours/mlw/rs/clip_feature/sydney_224')
185 |
186 | parser.add_argument('--mask_features_path', type=str,
187 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6')
188 |
189 |
190 | parser.add_argument('--logs_folder', type=str, default='tensorboard_logs')
191 | args = parser.parse_args()
192 | print(args)
193 |
194 | print('Transformer Training')
195 |     # Logging
196 | writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name))
197 |
198 | # Pipeline for image regions
199 | image_field = ImageDetectionsField(clip_features_path=args.clip_features_path,
200 | res_features_path=args.res_features_path,
201 | mask_features_path=args.mask_features_path)
202 |
203 | # Pipeline for text
204 |     text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy',
205 | remove_punctuation=True, nopoints=False)
206 |
207 | # Create the dataset
208 | if args.exp_name == 'Sydney':
209 | dataset = Sydney(image_field, text_field, 'Sydney/images/', args.annotation_folder, args.annotation_folder)
210 | elif args.exp_name == 'UCM':
211 | dataset = UCM(image_field, text_field, 'UCM/images/', args.annotation_folder, args.annotation_folder)
212 | elif args.exp_name == 'RSICD':
213 | dataset = RSICD(image_field, text_field, 'RSICD/images/', args.annotation_folder, args.annotation_folder)
214 |
215 |
216 | train_dataset, val_dataset, test_dataset = dataset.splits
217 |
218 |
219 | if not os.path.isfile('vocab_%s.pkl' % args.exp_name):
220 | print("Building vocabulary")
221 | text_field.build_vocab(train_dataset, val_dataset, min_freq=5)
222 | pickle.dump(text_field.vocab, open('vocab_%s.pkl' % args.exp_name, 'wb'))
223 | else:
224 | text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb'))
225 |
226 |
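    # Model assembly. Argument meanings are inferred from the m2-style signatures (an
    # assumption): the encoder takes (num_layers, visual padding index, attention module);
    # the decoder takes (vocab size, max caption length, num_layers, <pad> index).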
227 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=GlobalGroupingAttention_with_DC)
228 |     decoder = MeshedDecoder(len(text_field.vocab), 127, 3, text_field.vocab.stoi['<pad>'])
229 |     model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)
230 |
231 | dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()})
232 | ref_caps_train = list(train_dataset.text)
233 | cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))
234 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
235 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()})
236 |
237 |
238 | def lambda_lr(s):
239 | warm_up = args.warmup
240 | s += 1
241 | return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5)
242 | # return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -2)
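    # Noam-style warmup: lr(s) = d_model**-0.5 * min(s**-0.5, s * warmup**-1.5); the rate
    # grows for the first `warmup` scheduler steps and then decays as 1/sqrt(s). The Adam
    # optimizer below uses base lr=1, so the value returned here is the effective learning rate.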
243 |
244 |
245 |
246 | # Initial conditions
247 | optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98))
248 | scheduler = LambdaLR(optim, lambda_lr)
249 |     loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>'])
250 | use_rl = False
251 | best_cider = .0
252 | best_test_cider = .0
253 | patience = 0
254 | start_epoch = 0
255 |
256 | if args.resume_last or args.resume_best:
257 | if args.resume_last:
258 | fname = 'saved_models/%s_last.pth' % args.exp_name
259 | else:
260 | fname = 'saved_models/%s_best.pth' % args.exp_name
261 |
262 | if os.path.exists(fname):
263 | data = torch.load(fname)
264 | torch.set_rng_state(data['torch_rng_state'])
265 | torch.cuda.set_rng_state(data['cuda_rng_state'])
266 | np.random.set_state(data['numpy_rng_state'])
267 | random.setstate(data['random_rng_state'])
268 | model.load_state_dict(data['state_dict'], strict=False)
269 | optim.load_state_dict(data['optimizer'])
270 | scheduler.load_state_dict(data['scheduler'])
271 | start_epoch = data['epoch'] + 1
272 | best_cider = data['best_cider']
273 | best_test_cider = data['best_test_cider']
274 | patience = data['patience']
275 | use_rl = data['use_rl']
276 | # print('Resuming from epoch %d, validation loss %f, best cider %f, and best_test_cider %f' % (
277 | # data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
278 |
279 | print("Training starts")
280 |
281 | for e in range(start_epoch, start_epoch + 100):
282 | dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
283 | drop_last=True)
284 | dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
285 | dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True,
286 | num_workers=args.workers)
287 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5)
288 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5)
289 | if not use_rl:
290 | train_loss = train_xe(model, dataloader_train, optim, text_field)
291 | writer.add_scalar('data/train_loss', train_loss, e)
292 | else:
293 | train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim, cider_train,
294 | text_field)
295 | writer.add_scalar('data/train_loss', train_loss, e)
296 | writer.add_scalar('data/reward', reward, e)
297 | writer.add_scalar('data/reward_baseline', reward_baseline, e)
298 |
299 |
300 | # Validation loss
301 | val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field)
302 | writer.add_scalar('data/val_loss', val_loss, e)
303 |
304 | # Validation scores
305 | scores = evaluate_metrics(model, dict_dataloader_val, text_field)
306 | # print("Validation scores", scores)
307 | val_cider = scores['CIDEr']
308 | writer.add_scalar('data/val_cider', val_cider, e)
309 | writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e)
310 | writer.add_scalar('data/val_bleu2', scores['BLEU'][1], e)
311 | writer.add_scalar('data/val_bleu3', scores['BLEU'][2], e)
312 | writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e)
313 | writer.add_scalar('data/val_meteor', scores['METEOR'], e)
314 | writer.add_scalar('data/val_rouge', scores['ROUGE'], e)
315 | writer.add_scalar('data/val_spice', scores['SPICE'], e)
316 | writer.add_scalar('data/val_S*',
317 | (scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr']) / 4, e)
318 | writer.add_scalar('data/val_Sm', (
319 | scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr'] + scores['SPICE']) / 5, e)
320 |
321 | # Test scores
322 | scores = evaluate_metrics(model, dict_dataloader_test, text_field)
323 | print("Test scores", scores)
324 | test_cider = scores['CIDEr']
325 | writer.add_scalar('data/test_cider', scores['CIDEr'], e)
326 | writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e)
327 | writer.add_scalar('data/test_bleu2', scores['BLEU'][1], e)
328 | writer.add_scalar('data/test_bleu3', scores['BLEU'][2], e)
329 | writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e)
330 | writer.add_scalar('data/test_meteor', scores['METEOR'], e)
331 | writer.add_scalar('data/test_rouge', scores['ROUGE'], e)
332 | writer.add_scalar('data/test_spice', scores['SPICE'], e)
333 | writer.add_scalar('data/test_S*',
334 | (scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr']) / 4, e)
335 | writer.add_scalar('data/test_Sm', (
336 | scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr'] + scores['SPICE']) / 5, e)
337 |
338 |
339 |
340 | # Prepare for next epoch
341 | best = False
342 | if val_cider >= best_cider:
343 | best_cider = val_cider
344 | patience = 0
345 | best = True
346 | else:
347 | patience += 1
348 |
349 | best_test = False
350 | if test_cider >= best_test_cider:
351 | best_test_cider = test_cider
352 | best_test = True
353 |
354 | switch_to_rl = False
355 | exit_train = False
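        # Two-stage schedule: train with cross-entropy until validation CIDEr stalls for 5
        # epochs, then switch to self-critical (SCST) fine-tuning with a fixed lr of 5e-6;
        # a second 5-epoch stall during SCST ends training.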
356 | if patience == 5:
357 | if not use_rl:
358 | use_rl = True
359 | switch_to_rl = True
360 | patience = 0
361 | optim = Adam(model.parameters(), lr=5e-6)
362 | # print("Switching to RL")
363 | else:
364 | print('patience reached.')
365 | exit_train = True
366 |
367 |
368 | if switch_to_rl and not best:
369 | data = torch.load('saved_models/%s_best.pth' % args.exp_name)
370 | torch.set_rng_state(data['torch_rng_state'])
371 | torch.cuda.set_rng_state(data['cuda_rng_state'])
372 | np.random.set_state(data['numpy_rng_state'])
373 | random.setstate(data['random_rng_state'])
374 | model.load_state_dict(data['state_dict'])
375 | # print('Resuming from epoch %d, validation loss %f, best_cider %f, and best test_cider %f' % (
376 | # data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider']))
377 |
378 |
379 | torch.save({
380 | 'torch_rng_state': torch.get_rng_state(),
381 | 'cuda_rng_state': torch.cuda.get_rng_state(),
382 | 'numpy_rng_state': np.random.get_state(),
383 | 'random_rng_state': random.getstate(),
384 | 'epoch': e,
385 | 'val_loss': val_loss,
386 | 'val_cider': val_cider,
387 | 'state_dict': model.state_dict(),
388 | 'optimizer': optim.state_dict(),
389 | 'scheduler': scheduler.state_dict(),
390 | 'patience': patience,
391 | 'best_cider': best_cider,
392 | 'best_test_cider': best_test_cider,
393 | 'use_rl': use_rl,
394 | }, 'saved_models/%s_last.pth' % args.exp_name)
395 |
396 | if best:
397 | copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best.pth' % args.exp_name)
398 |
399 |
400 | if best_test:
401 | copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best_test.pth' % args.exp_name)
402 |
403 | if exit_train:
404 | writer.close()
405 | break
406 |
407 |
--------------------------------------------------------------------------------
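A minimal launch sketch for train.py (the dataset name picks the split class and the vocab_*.pkl file; every path below is a placeholder for wherever the captions and the pre-extracted ResNet, CLIP, and group-mask features were saved):

    python train.py --exp_name Sydney --batch_size 50 --workers 0 \
        --annotation_folder /path/to/Sydney_Captions \
        --res_features_path /path/to/sydney_res152_features \
        --clip_features_path /path/to/sydney_clip_features \
        --mask_features_path /path/to/sydney_group_masks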