├── misc ├── __init__.py ├── resnet_utils.py ├── loss_wrapper.py ├── resnet.py ├── rewards.py └── utils.py ├── models ├── beam_search │ ├── __init__.py │ └── beam_search.py ├── __init__.py ├── transformer │ ├── __init__.py │ ├── utils.py │ ├── transformer.py │ ├── encoders.py │ ├── decoders.py │ └── attention.py ├── containers.py └── captioning_model.py ├── evaluation └── readme.txt ├── vocab_UCM.pkl ├── vocab_RSICD.pkl ├── vocab_Sydney.pkl ├── MG-Transformer.png ├── images └── MG-Transformer.png ├── utils ├── typing.py ├── __init__.py └── utils.py ├── data ├── __init__.py ├── example.py ├── utils.py ├── field.py ├── dataset.py └── vocab.py ├── feature_pro ├── pre_CLIP_feature.py ├── pre_region_feature.py └── split_group.py ├── LICENSE ├── README.md ├── environment.yml ├── test.py └── train.py /misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | -------------------------------------------------------------------------------- /evaluation/readme.txt: -------------------------------------------------------------------------------- 1 | 链接: https://pan.baidu.com/s/13ZfH-CMYbW3RsW0-RX7KKQ 提取码: wuiu -------------------------------------------------------------------------------- /vocab_UCM.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_UCM.pkl -------------------------------------------------------------------------------- /vocab_RSICD.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_RSICD.pkl -------------------------------------------------------------------------------- /vocab_Sydney.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/vocab_Sydney.pkl -------------------------------------------------------------------------------- /MG-Transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/MG-Transformer.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .captioning_model import CaptioningModel 3 | -------------------------------------------------------------------------------- /images/MG-Transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/One-paper-luck/MG-Transformer/HEAD/images/MG-Transformer.png -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .encoders import * 3 | from .decoders import * 4 | from .attention import * 5 | -------------------------------------------------------------------------------- /utils/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, 
Sequence, Tuple 2 | import torch 3 | 4 | TensorOrSequence = Union[Sequence[torch.Tensor], torch.Tensor] 5 | TensorOrNone = Union[torch.Tensor, None] 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import download_from_url 2 | from .typing import * 3 | 4 | def get_batch_size(x: TensorOrSequence) -> int: 5 | if isinstance(x, torch.Tensor): # x 2,5,2048 6 | b_s = x.size(0) 7 | else: 8 | b_s = x[0].size(0) 9 | return b_s 10 | 11 | 12 | def get_device(x: TensorOrSequence) -> int: 13 | if isinstance(x, torch.Tensor): 14 | b_s = x.device 15 | else: 16 | b_s = x[0].device 17 | return b_s 18 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import RawField, Merge, ImageDetectionsField, TextField 2 | from .dataset import Sydney,UCM,RSICD 3 | from torch.utils.data import DataLoader as TorchDataLoader 4 | 5 | class DataLoader(TorchDataLoader): 6 | def __init__(self, dataset, *args, **kwargs): 7 | super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs) 8 | 9 | def __len__(self): 10 | return len(self.dataset.examples)//self.batch_size 11 | -------------------------------------------------------------------------------- /misc/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class myResnet(nn.Module): # 7 | def __init__(self, resnet): 8 | super(myResnet, self).__init__() 9 | self.resnet = resnet 10 | 11 | def forward(self, img, att_size=14): 12 | x = img.unsqueeze(0) 13 | x = self.resnet.conv1(x) 14 | x = self.resnet.bn1(x) 15 | x = self.resnet.relu(x) 16 | x = self.resnet.maxpool(x) 17 | x = self.resnet.layer1(x) 18 | x = self.resnet.layer2(x) 19 | x = self.resnet.layer3(x) 20 | x = self.resnet.layer4(x) 21 | 22 | return x.squeeze() 23 | -------------------------------------------------------------------------------- /data/example.py: -------------------------------------------------------------------------------- 1 | class Example(object): 2 | """Defines a single training or test example. 3 | Stores each column of the example as an attribute. 
4 | 定义单个训练或测试example。 5 | 将example的每一列存储为一个属性。 6 | """ 7 | 8 | @classmethod 9 | def fromdict(cls, data): 10 | ex = cls(data) 11 | return ex 12 | 13 | def __init__(self, data): 14 | for key, val in data.items(): 15 | super(Example, self).__setattr__(key, val) 16 | 17 | def __setattr__(self, key, value): 18 | raise AttributeError 19 | 20 | def __hash__(self): 21 | return hash(tuple(x for x in self.__dict__.values())) 22 | 23 | def __eq__(self, other): 24 | this = tuple(x for x in self.__dict__.values()) 25 | other = tuple(x for x in other.__dict__.values()) 26 | return this == other 27 | 28 | def __ne__(self, other): 29 | return not self.__eq__(other) 30 | -------------------------------------------------------------------------------- /feature_pro/pre_CLIP_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | import numpy as np 5 | import os 6 | 7 | device = "cuda" if torch.cuda.is_available() else "cpu" 8 | model, preprocess = clip.load("ViT-B/32", device=device) 9 | 10 | # sydney 11 | path = r"/media/dmd/ours/mlw/rs/Sydney_Captions/imgs" 12 | output_dir='/media/dmd/ours/mlw/rs/clip_feature/sydney_224' 13 | 14 | 15 | for root, dirs, files in os.walk(path): 16 | for i in files: 17 | i = os.path.join(path, i) 18 | image_name = i.split('/')[-1] 19 | image = preprocess(Image.open(i)).unsqueeze(0).to(device) # 1,3,224,224 20 | with torch.no_grad(): 21 | image_features = model.encode_image(image) # 22 | """ 23 | Note: Delete part of the code in CLIP VisionTransformer as follows: 24 | if self.proj is not None: 25 | x = x @ self.proj 26 | """ 27 | 28 | np.save(os.path.join(output_dir, str(image_name)), image_features.data.cpu().float().numpy()) 29 | 30 | 31 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_from_url(url, path): 4 | """Download file, with logic (from tensor2tensor) for Google Drive""" 5 | if 'drive.google.com' not in url: 6 | print('Downloading %s; may take a few minutes' % url) 7 | r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}) 8 | with open(path, "wb") as file: 9 | file.write(r.content) 10 | return 11 | print('Downloading from Google Drive; may take a few minutes') 12 | confirm_token = None 13 | session = requests.Session() 14 | response = session.get(url, stream=True) 15 | for k, v in response.cookies.items(): 16 | if k.startswith("download_warning"): 17 | confirm_token = v 18 | 19 | if confirm_token: 20 | url = url + "&confirm=" + confirm_token 21 | response = session.get(url, stream=True) 22 | 23 | chunk_size = 16 * 1024 24 | with open(path, "wb") as f: 25 | for chunk in response.iter_content(chunk_size): 26 | if chunk: 27 | f.write(chunk) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 One-paper-luck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject 
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MG-Transformer 2 |

3 | ![MG-Transformer](images/MG-Transformer.png) 4 |

5 | 6 | ## Installation and Dependencies 7 | Create the `m2` conda environment using the `environment.yml` file: 8 | ``` 9 | conda env create -f environment.yml 10 | conda activate m2 11 | ``` 12 | ## Data preparation 13 | For the evaluation metrics, Please download the [evaluation.zip](https://pan.baidu.com/s/13ZfH-CMYbW3RsW0-RX7KKQ)(BaiduPan code:wuiu) and extract it to `./evaluation`. 14 | 15 | 16 | For Feature Extraction: 17 | Region feature: `./feature_pro/pre_region_feature.py` 18 | CLIP image embedding: `./feature_pro/pre_CLIP_feature.py` 19 | Group mask matrix: `./feature_pro/split_group.py` 20 | 21 | 22 | ## Train 23 | ``` 24 | python train.py 25 | ``` 26 | 27 | ## Evaluate 28 | ``` 29 | python test.py 30 | ``` 31 | 32 | 33 | # Citation: 34 | ``` 35 | @ARTICLE{10298250, 36 | author={Meng, Lingwu and Wang, Jing and Meng, Ran and Yang, Yang and Xiao, Liang}, 37 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 38 | title={A Multiscale Grouping Transformer with CLIP Latents for Remote Sensing Image Captioning}, 39 | year={2024}, 40 | volume={62}, 41 | number={4703515}, 42 | pages={1-15}, 43 | doi={10.1109/TGRS.2024.3385500}} 44 | ``` 45 | 46 | 47 | 48 | ## Reference: 49 | 1. https://github.com/tylin/coco-caption 50 | 2. https://github.com/aimagelab/meshed-memory-transformer 51 | -------------------------------------------------------------------------------- /misc/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import misc.utils as utils 3 | from misc.rewards import init_scorer, get_self_critical_reward 4 | 5 | 6 | class LossWrapper(torch.nn.Module): 7 | def __init__(self, model, opt): 8 | super(LossWrapper, self).__init__() 9 | self.opt = opt 10 | self.model = model 11 | if opt.label_smoothing > 0: 12 | self.crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) 13 | else: 14 | self.crit = utils.LanguageModelCriterion() 15 | self.rl_crit = utils.RewardCriterion() 16 | 17 | def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, 18 | sc_flag): 19 | out = {} 20 | if not sc_flag: 21 | loss = self.crit(self.model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:]) 22 | else: 23 | self.model.eval() 24 | with torch.no_grad(): 25 | greedy_res, _ = self.model(fc_feats, att_feats, att_masks, mode='sample') 26 | self.model.train() 27 | gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, opt={'sample_method': 'sample'}, 28 | mode='sample') 29 | gts = [gts[_] for _ in gt_indices.tolist()] 30 | reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt) 31 | reward = torch.from_numpy(reward).float().to(gen_result.device) 32 | loss = self.rl_crit(sample_logprobs, gen_result.data, reward) 33 | out['reward'] = reward[:, 0].mean() 34 | out['loss'] = loss 35 | return out 36 | -------------------------------------------------------------------------------- /models/transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | def l2norm(X, dim=-1, eps=1e-12): 6 | """L2-normalize columns of X 7 | """ 8 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 9 | X = torch.div(X, norm) 10 | return X 11 | 12 | def position_embedding(input, d_model): 13 | input = input.view(-1, 1) 14 | dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1) 15 | sin = torch.sin(input / 10000 
** (2 * dim / d_model)) 16 | cos = torch.cos(input / 10000 ** (2 * dim / d_model)) 17 | 18 | out = torch.zeros((input.shape[0], d_model), device=input.device) 19 | out[:, ::2] = sin 20 | out[:, 1::2] = cos 21 | return out 22 | 23 | 24 | def sinusoid_encoding_table(max_len, d_model, padding_idx=None): 25 | pos = torch.arange(max_len, dtype=torch.float32) 26 | out = position_embedding(pos, d_model) 27 | 28 | if padding_idx is not None: 29 | out[padding_idx] = 0 30 | return out 31 | 32 | 33 | class PositionWiseFeedForward(nn.Module): 34 | ''' 35 | Position-wise feed forward layer 36 | ''' 37 | 38 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 39 | super(PositionWiseFeedForward, self).__init__() 40 | self.identity_map_reordering = identity_map_reordering 41 | self.fc1 = nn.Linear(d_model, d_ff) 42 | self.fc2 = nn.Linear(d_ff, d_model) 43 | self.dropout = nn.Dropout(p=dropout) 44 | self.dropout_2 = nn.Dropout(p=dropout) 45 | self.layer_norm = nn.LayerNorm(d_model) 46 | 47 | def forward(self, input): 48 | if self.identity_map_reordering: 49 | out = self.layer_norm(input) 50 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 51 | out = input + self.dropout(torch.relu(out)) 52 | else: 53 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 54 | out = self.dropout(out) 55 | out = self.layer_norm(input + out) 56 | return out -------------------------------------------------------------------------------- /misc/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils import model_zoo 4 | import torchvision.models.resnet 5 | from torchvision.models.resnet import BasicBlock, Bottleneck 6 | 7 | class ResNet(torchvision.models.resnet.ResNet): 8 | def __init__(self, block, layers, num_classes=1000): 9 | super(ResNet, self).__init__(block, layers, num_classes) 10 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 11 | for i in range(2, 5): 12 | getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) 13 | getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) 14 | 15 | def resnet18(pretrained=False): 16 | """Constructs a ResNet-18 model. 17 | 18 | Args: 19 | pretrained (bool): If True, returns a model pre-trained on ImageNet 20 | """ 21 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 22 | if pretrained: 23 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 24 | return model 25 | 26 | 27 | def resnet34(pretrained=False): 28 | """Constructs a ResNet-34 model. 29 | 30 | Args: 31 | pretrained (bool): If True, returns a model pre-trained on ImageNet 32 | """ 33 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 34 | if pretrained: 35 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 36 | return model 37 | 38 | 39 | def resnet50(pretrained=False): 40 | """Constructs a ResNet-50 model. 41 | 42 | Args: 43 | pretrained (bool): If True, returns a model pre-trained on ImageNet 44 | """ 45 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 46 | if pretrained: 47 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 48 | return model 49 | 50 | 51 | def resnet101(pretrained=False): 52 | """Constructs a ResNet-101 model. 
53 | 54 | Args: 55 | pretrained (bool): If True, returns a model pre-trained on ImageNet 56 | """ 57 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 58 | if pretrained: 59 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 60 | return model 61 | 62 | 63 | def resnet152(pretrained=False): 64 | """Constructs a ResNet-152 model. 65 | 66 | Args: 67 | pretrained (bool): If True, returns a model pre-trained on ImageNet 68 | """ 69 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 70 | if pretrained: 71 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 72 | return model -------------------------------------------------------------------------------- /misc/rewards.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import time 7 | import misc.utils as utils 8 | from collections import OrderedDict 9 | import torch 10 | 11 | 12 | from cider.pyciderevalcap.ciderD.ciderD import CiderD 13 | from cococaption.pycocoevalcap.bleu.bleu import Bleu 14 | 15 | CiderD_scorer = None 16 | Bleu_scorer = None 17 | #CiderD_scorer = CiderD(df='corpus') 18 | 19 | def init_scorer(cached_tokens): 20 | global CiderD_scorer 21 | CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) 22 | global Bleu_scorer 23 | Bleu_scorer = Bleu_scorer or Bleu(4) 24 | 25 | def array_to_str(arr): 26 | out = '' 27 | for i in range(len(arr)): 28 | out += str(arr[i]) + ' ' 29 | if arr[i] == 0: 30 | break 31 | return out.strip() 32 | 33 | def get_self_critical_reward(greedy_res, data_gts, gen_result, opt): 34 | batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img 35 | seq_per_img = batch_size // len(data_gts) 36 | 37 | res = OrderedDict() 38 | 39 | gen_result = gen_result.data.cpu().numpy() 40 | greedy_res = greedy_res.data.cpu().numpy() 41 | for i in range(batch_size): 42 | res[i] = [array_to_str(gen_result[i])] 43 | for i in range(batch_size): 44 | res[batch_size + i] = [array_to_str(greedy_res[i])] 45 | 46 | gts = OrderedDict() 47 | for i in range(len(data_gts)): 48 | gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))] 49 | 50 | res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] 51 | res__ = {i: res[i] for i in range(2 * batch_size)} 52 | gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} 53 | if opt.cider_reward_weight > 0: 54 | _, cider_scores = CiderD_scorer.compute_score(gts, res_) 55 | print('Cider scores:', _) 56 | else: 57 | cider_scores = 0 58 | if opt.bleu_reward_weight > 0: 59 | _, bleu_scores = Bleu_scorer.compute_score(gts, res__) 60 | bleu_scores = np.array(bleu_scores[3]) 61 | print('Bleu scores:', _[3]) 62 | else: 63 | bleu_scores = 0 64 | scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores 65 | 66 | scores = scores[:batch_size] - scores[batch_size:] 67 | 68 | rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1) 69 | 70 | return rewards 71 | -------------------------------------------------------------------------------- /models/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import copy 4 | from models.containers import ModuleList 5 | from ..captioning_model import CaptioningModel 6 | 7 | 8 | class Transformer(CaptioningModel): 9 | def __init__(self, bos_idx, encoder, 
decoder): 10 | super(Transformer, self).__init__() 11 | self.bos_idx = bos_idx 12 | self.encoder = encoder 13 | self.decoder = decoder 14 | self.register_state('enc_output', None) 15 | self.register_state('mask_enc', None) 16 | self.init_weights() 17 | 18 | @property 19 | def d_model(self): 20 | return self.decoder.d_model 21 | 22 | def init_weights(self): 23 | for p in self.parameters(): 24 | if p.dim() > 1: 25 | nn.init.xavier_uniform_(p) 26 | 27 | def forward(self, images, images_gl, detections_mask, seq, isencoder=None, *args): 28 | 29 | enc_output, mask_enc = self.encoder(images, images_gl,detections_mask,isencoder=isencoder) 30 | dec_output = self.decoder(seq, enc_output, mask_enc) 31 | return dec_output 32 | 33 | def init_state(self, b_s, device): 34 | return [torch.zeros((b_s, 0), dtype=torch.long, device=device), 35 | None, None] 36 | 37 | def step(self, t, prev_output, visual, visual_gl,detections_mask, seq, mode='teacher_forcing', **kwargs): 38 | it = None 39 | if mode == 'teacher_forcing': 40 | raise NotImplementedError 41 | elif mode == 'feedback': 42 | if t == 0: 43 | self.enc_output, self.mask_enc = self.encoder(visual,visual_gl,detections_mask,isencoder=True) 44 | if isinstance(visual, torch.Tensor): 45 | it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long() 46 | else: 47 | it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long() 48 | else: 49 | it = prev_output 50 | 51 | return self.decoder(it, self.enc_output, self.mask_enc) 52 | 53 | 54 | class TransformerEnsemble(CaptioningModel): 55 | def __init__(self, model: Transformer, weight_files): 56 | super(TransformerEnsemble, self).__init__() 57 | self.n = len(weight_files) 58 | self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)]) 59 | for i in range(self.n): 60 | state_dict_i = torch.load(weight_files[i])['state_dict'] 61 | self.models[i].load_state_dict(state_dict_i) 62 | 63 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 64 | out_ensemble = [] 65 | for i in range(self.n): 66 | out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs) 67 | out_ensemble.append(out_i.unsqueeze(0)) 68 | 69 | return torch.mean(torch.cat(out_ensemble, 0), dim=0) 70 | -------------------------------------------------------------------------------- /feature_pro/pre_region_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import json 7 | import argparse 8 | import numpy as np 9 | import torch 10 | import skimage.io 11 | import skimage.transform 12 | from torchvision import transforms as trn 13 | from misc.resnet_utils import myResnet 14 | import misc.resnet as resnet 15 | 16 | preprocess = trn.Compose([ 17 | # trn.ToTensor(), 18 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | 22 | def main(params): 23 | net = getattr(resnet, params['model'])() 24 | net.load_state_dict(torch.load(os.path.join(params['model_root'], params['model'] + '.pth'))) 25 | my_resnet = myResnet(net) 26 | my_resnet.cuda() 27 | my_resnet.eval() 28 | 29 | imgs = json.load(open(params['input_json'], 'r')) 30 | imgs = imgs['images'] 31 | N = len(imgs) 32 | 33 | for i, img in enumerate(imgs): 34 | I = skimage.io.imread(os.path.join(params['images_root'], img['filename'])) # sydney 500 500 3 35 | I = skimage.transform.resize(I, (224, 224)) 36 | 37 | if len(I.shape) == 2: 
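# grayscale input: add a channel axis and replicate it three times so the image has the expected RGB shape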
38 | I = I[:, :, np.newaxis] 39 | I = np.concatenate((I, I, I), axis=2) 40 | 41 | I = I.astype('float32') / 255.0 42 | I = torch.from_numpy(I.transpose([2, 0, 1])).cuda() 43 | I = preprocess(I) 44 | 45 | with torch.no_grad(): 46 | conv5= my_resnet(I,params['att_size']) 47 | 48 | # np.save(os.path.join(params['output_dir'], str(img['filename'])), conv5.data.cpu().float().numpy()) 49 | 50 | if i % 1000 == 0: 51 | print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N)) 52 | print('wrote ', params['output_dir']) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--input_json', default='/media/dmd/ours/mlw/rs/Sydney_Captions/dataset.json') 58 | # parser.add_argument('--output_dir', default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_500') 59 | parser.add_argument('--output_dir', default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_name') 60 | parser.add_argument('--images_root', default='/media/dmd/ours/mlw/rs/Sydney_Captions/imgs', 61 | help='root location in which images are stored, to be prepended to file_path in input json') 62 | 63 | parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') 64 | parser.add_argument('--model', default='resnet152', type=str, help='resnet50,resnet101, resnet152') 65 | parser.add_argument('--model_root', default='/media/dmd/ours/mlw/pre_model', type=str, 66 | help='model root') 67 | 68 | args = parser.parse_args() 69 | params = vars(args) # convert to ordinary dict 70 | print('parsed input parameters:') 71 | print(json.dumps(params, indent=2)) 72 | main(params) 73 | -------------------------------------------------------------------------------- /models/containers.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from torch import nn 3 | from utils.typing import * 4 | 5 | 6 | class Module(nn.Module): 7 | def __init__(self): 8 | super(Module, self).__init__() 9 | self._is_stateful = False 10 | self._state_names = [] 11 | self._state_defaults = dict() 12 | 13 | def register_state(self, name: str, default: TensorOrNone): 14 | self._state_names.append(name) 15 | if default is None: 16 | self._state_defaults[name] = None 17 | else: 18 | self._state_defaults[name] = default.clone().detach() 19 | self.register_buffer(name, default) 20 | 21 | def states(self): 22 | for name in self._state_names: 23 | yield self._buffers[name] 24 | for m in self.children(): 25 | if isinstance(m, Module): 26 | yield from m.states() 27 | 28 | def apply_to_states(self, fn): 29 | for name in self._state_names: 30 | self._buffers[name] = fn(self._buffers[name]) 31 | for m in self.children(): 32 | if isinstance(m, Module): 33 | m.apply_to_states(fn) 34 | 35 | def _init_states(self, batch_size: int): 36 | for name in self._state_names: 37 | if self._state_defaults[name] is None: 38 | self._buffers[name] = None 39 | else: 40 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 41 | self._buffers[name] = self._buffers[name].unsqueeze(0) 42 | self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:])) 43 | self._buffers[name] = self._buffers[name].contiguous() 44 | 45 | def _reset_states(self): 46 | for name in self._state_names: 47 | if self._state_defaults[name] is None: 48 | self._buffers[name] = None 49 | else: 50 | self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device) 51 | 52 | def 
enable_statefulness(self, batch_size: int): 53 | for m in self.children(): 54 | if isinstance(m, Module): 55 | m.enable_statefulness(batch_size) 56 | self._init_states(batch_size) 57 | self._is_stateful = True 58 | 59 | def disable_statefulness(self): 60 | for m in self.children(): 61 | if isinstance(m, Module): 62 | m.disable_statefulness() 63 | self._reset_states() 64 | self._is_stateful = False 65 | 66 | @contextmanager 67 | def statefulness(self, batch_size: int): 68 | self.enable_statefulness(batch_size) 69 | try: 70 | yield 71 | finally: 72 | self.disable_statefulness() 73 | 74 | 75 | class ModuleList(nn.ModuleList, Module): 76 | pass 77 | 78 | 79 | class ModuleDict(nn.ModuleDict, Module): 80 | pass 81 | -------------------------------------------------------------------------------- /models/captioning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | import utils 4 | from models.containers import Module 5 | from models.beam_search import * 6 | 7 | 8 | class CaptioningModel(Module): 9 | def __init__(self): 10 | super(CaptioningModel, self).__init__() 11 | 12 | def init_weights(self): 13 | raise NotImplementedError 14 | 15 | def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs): 16 | raise NotImplementedError 17 | 18 | def forward(self, images, seq, *args): 19 | device = images.device 20 | b_s = images.size(0) 21 | seq_len = seq.size(1) 22 | state = self.init_state(b_s, device) 23 | out = None 24 | 25 | outputs = [] 26 | for t in range(seq_len): 27 | out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing') 28 | outputs.append(out) 29 | 30 | outputs = torch.cat([o.unsqueeze(1) for o in outputs], 1) 31 | return outputs 32 | 33 | def test(self, visual: utils.TensorOrSequence, max_len: int, eos_idx: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 34 | b_s = utils.get_batch_size(visual) 35 | device = utils.get_device(visual) 36 | outputs = [] 37 | log_probs = [] 38 | 39 | mask = torch.ones((b_s,), device=device) 40 | with self.statefulness(b_s): 41 | out = None 42 | for t in range(max_len): 43 | log_probs_t = self.step(t, out, visual, None, mode='feedback', **kwargs) 44 | out = torch.max(log_probs_t, -1)[1] 45 | mask = mask * (out.squeeze(-1) != eos_idx).float() 46 | log_probs.append(log_probs_t * mask.unsqueeze(-1).unsqueeze(-1)) 47 | outputs.append(out) 48 | 49 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 50 | 51 | def sample_rl(self, visual: utils.TensorOrSequence, max_len: int, **kwargs) -> utils.Tuple[torch.Tensor, torch.Tensor]: 52 | b_s = utils.get_batch_size(visual) 53 | outputs = [] 54 | log_probs = [] 55 | 56 | with self.statefulness(b_s): 57 | out = None 58 | for t in range(max_len): 59 | out = self.step(t, out, visual, None, mode='feedback', **kwargs) 60 | distr = distributions.Categorical(logits=out[:, 0]) 61 | out = distr.sample().unsqueeze(1) 62 | outputs.append(out) 63 | log_probs.append(distr.log_prob(out).unsqueeze(1)) 64 | 65 | return torch.cat(outputs, 1), torch.cat(log_probs, 1) 66 | 67 | def beam_search(self, visual: utils.TensorOrSequence, visual_gl: utils.TensorOrSequence,detections_mask: utils.TensorOrSequence, 68 | max_len: int, eos_idx: int, beam_size: int, out_size=1,return_probs=False, **kwargs): 69 | bs = BeamSearch(self, max_len, eos_idx, beam_size) 70 | return bs.apply(visual, visual_gl,detections_mask, out_size, return_probs, **kwargs) 71 | 
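# Illustrative usage sketch (not part of the original file): how the beam_search entry
# point above is typically called at inference time, mirroring test.py. `model`,
# `text_field`, and the three feature tensors are assumed to be built and loaded as in
# test.py, and '<eos>' is assumed to be the end-of-sequence token in the vocabulary.
#
# def generate_captions(model, text_field, detections, detections_gl, detections_mask):
#     model.eval()
#     with torch.no_grad():
#         out, _ = model.beam_search(detections, detections_gl, detections_mask,
#                                    max_len=20, eos_idx=text_field.vocab.stoi['<eos>'],
#                                    beam_size=5, out_size=1)
#     return text_field.decode(out, join_words=False)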
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: m2 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - blas=1.0=mkl 8 | - ca-certificates=2022.07.19=h06a4308_0 9 | - certifi=2021.5.30=py36h06a4308_0 10 | - intel-openmp=2022.0.1=h06a4308_3633 11 | - joblib=1.0.1=pyhd3eb1b0_0 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.3=he6710b0_2 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgfortran-ng=7.5.0=ha8ba4b0_17 16 | - libgfortran4=7.5.0=ha8ba4b0_17 17 | - libgomp=11.2.0=h1234567_1 18 | - libstdcxx-ng=11.2.0=h1234567_1 19 | - mkl=2020.2=256 20 | - mkl-service=2.3.0=py36he8ac12f_0 21 | - mkl_fft=1.3.0=py36h54f3939_0 22 | - mkl_random=1.1.1=py36h0573a6f_0 23 | - ncurses=6.3=h7f8727e_2 24 | - numpy-base=1.19.2=py36hfa32c7d_0 25 | - openssl=1.1.1q=h7f8727e_0 26 | - pip=21.2.2=py36h06a4308_0 27 | - python=3.6.13=h12debd9_1 28 | - readline=8.1.2=h7f8727e_1 29 | - scikit-learn=0.24.2=py36ha9443f7_0 30 | - setuptools=58.0.4=py36h06a4308_0 31 | - six=1.16.0=pyhd3eb1b0_1 32 | - sqlite=3.38.3=hc218d9a_0 33 | - threadpoolctl=2.2.0=pyh0d69192_0 34 | - tk=8.6.12=h1ccaba5_0 35 | - wheel=0.37.1=pyhd3eb1b0_0 36 | - xz=5.2.5=h7f8727e_1 37 | - zlib=1.2.12=h7f8727e_2 38 | - pip: 39 | - absl-py==1.1.0 40 | - blis==0.2.4 41 | - cached-property==1.5.2 42 | - cachetools==4.2.4 43 | - charset-normalizer==2.0.12 44 | - cycler==0.11.0 45 | - cymem==2.0.6 46 | - cython==0.29.30 47 | - dataclasses==0.8 48 | - decorator==4.4.2 49 | - en-core-web-md==2.1.0 50 | - en-core-web-sm==2.1.0 51 | - google-auth==2.7.0 52 | - google-auth-oauthlib==0.4.6 53 | - grpcio==1.46.3 54 | - h5py==3.1.0 55 | - idna==3.3 56 | - imagecodecs==2020.5.30 57 | - imageio==2.15.0 58 | - importlib-metadata==4.8.3 59 | - importlib-resources==5.4.0 60 | - jsonschema==2.6.0 61 | - kiwisolver==1.3.1 62 | - markdown==3.3.7 63 | - matplotlib==3.3.4 64 | - murmurhash==1.0.7 65 | - networkx==2.5.1 66 | - numpy==1.19.5 67 | - oauthlib==3.2.0 68 | - opencv-contrib-python==4.6.0.66 69 | - opencv-python==4.5.1.48 70 | - param==1.13.0 71 | - pillow==8.4.0 72 | - plac==0.9.6 73 | - preshed==2.0.1 74 | - protobuf==3.19.4 75 | - pyasn1==0.4.8 76 | - pyasn1-modules==0.2.8 77 | - pycocotools==2.0.1 78 | - pyct==0.4.8 79 | - pyheatmap==0.1.12 80 | - pyparsing==3.0.9 81 | - python-dateutil==2.8.2 82 | - pywavelets==1.1.1 83 | - requests==2.27.1 84 | - requests-oauthlib==1.3.1 85 | - rsa==4.8 86 | - scikit-image==0.17.2 87 | - scipy==1.5.4 88 | - spacy==2.1.0 89 | - srsly==1.0.5 90 | - tensorboard==2.9.0 91 | - tensorboard-data-server==0.6.1 92 | - tensorboard-plugin-wit==1.8.1 93 | - thinc==7.0.8 94 | - thop==0.1.1-2209072238 95 | - tifffile==2020.9.3 96 | - torch==1.10.0+cu113 97 | - torchvision==0.11.0+cu113 98 | - tqdm==4.64.0 99 | - typing-extensions==4.1.1 100 | - urllib3==1.26.9 101 | - wasabi==0.9.1 102 | - werkzeug==2.0.3 103 | - zipp==3.6.0 104 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib, sys 2 | 3 | 4 | class DummyFile(object): 5 | def write(self, x): pass 6 | 7 | 8 | @contextlib.contextmanager 9 | def nostdout(): 10 | save_stdout = sys.stdout 11 | sys.stdout = DummyFile() 12 | yield 13 | sys.stdout = save_stdout 14 | 15 | 16 | def reporthook(t): 17 | """https://github.com/tqdm/tqdm""" 18 | last_b = 
[0] 19 | 20 | def inner(b=1, bsize=1, tsize=None): 21 | """ 22 | b: int, optionala 23 | Number of blocks just transferred [default: 1]. 24 | bsize: int, optional 25 | Size of each block (in tqdm units) [default: 1]. 26 | tsize: int, optional 27 | Total size (in tqdm units). If [default: None] remains unchanged. 28 | """ 29 | if tsize is not None: 30 | t.total = tsize 31 | t.update((b - last_b[0]) * bsize) 32 | last_b[0] = b 33 | 34 | return inner 35 | 36 | 37 | def get_tokenizer(tokenizer): 38 | if callable(tokenizer): 39 | return tokenizer 40 | if tokenizer == "spacy": 41 | try: 42 | import spacy 43 | # spacy_en = spacy.load('en') 44 | spacy_en = spacy.load('en_core_web_sm') 45 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 46 | except ImportError: 47 | print("Please install SpaCy and the SpaCy English tokenizer. " 48 | "See the docs at https://spacy.io for more information.") 49 | raise 50 | except AttributeError: 51 | print("Please install SpaCy and the SpaCy English tokenizer. " 52 | "See the docs at https://spacy.io for more information.") 53 | raise 54 | elif tokenizer == "moses": 55 | try: 56 | from nltk.tokenize.moses import MosesTokenizer 57 | moses_tokenizer = MosesTokenizer() 58 | return moses_tokenizer.tokenize 59 | except ImportError: 60 | print("Please install NLTK. " 61 | "See the docs at http://nltk.org for more information.") 62 | raise 63 | except LookupError: 64 | print("Please install the necessary NLTK corpora. " 65 | "See the docs at http://nltk.org for more information.") 66 | raise 67 | elif tokenizer == 'revtok': 68 | try: 69 | import revtok 70 | return revtok.tokenize 71 | except ImportError: 72 | print("Please install revtok.") 73 | raise 74 | elif tokenizer == 'subword': 75 | try: 76 | import revtok 77 | return lambda x: revtok.tokenize(x, decap=True) 78 | except ImportError: 79 | print("Please install revtok.") 80 | raise 81 | raise ValueError("Requested tokenizer {}, valid choices are a " 82 | "callable that takes a single string as input, " 83 | "\"revtok\" for the revtok reversible tokenizer, " 84 | "\"subword\" for the revtok caps-aware tokenizer, " 85 | "\"spacy\" for the SpaCy English tokenizer, or " 86 | "\"moses\" for the NLTK port of the Moses tokenization " 87 | "script.".format(tokenizer)) 88 | -------------------------------------------------------------------------------- /feature_pro/split_group.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import random 6 | random.seed(1234) 7 | torch.manual_seed(1234) 8 | np.random.seed(1234) 9 | 10 | def readInfo(filePath): 11 | name = os.listdir(filePath) 12 | return name 13 | 14 | def data_normal(data): 15 | d_min=data.min() 16 | if d_min<0: 17 | data+=torch.abs(d_min) 18 | d_min=data.min() 19 | 20 | d_max=data.max() 21 | dst=d_max-d_min 22 | nor_data=(data-d_min).true_divide(dst) 23 | return nor_data 24 | 25 | 26 | features_path='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224' 27 | out_path='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6' 28 | 29 | num=6 30 | 31 | 32 | 33 | fileList = readInfo(features_path) 34 | 35 | for i, img in enumerate(fileList): 36 | res_feature=np.load(features_path + '/%s' % img) 37 | r_feature = res_feature.reshape(res_feature.shape[0], res_feature.shape[1] * res_feature.shape[2]) 38 | r_feature=torch.Tensor(r_feature) 39 | 40 | res_feature=torch.Tensor(res_feature) 41 | avg_feature = 
F.adaptive_avg_pool2d(res_feature, [1, 1]).squeeze() 42 | avg_feature=avg_feature.unsqueeze(0) 43 | p=torch.matmul(avg_feature, r_feature) 44 | 45 | pp=data_normal(p) 46 | 47 | # divide group 8 48 | group_mask = torch.ones((8, 49, 49)).cuda() # res 49 | 50 | index_1 = torch.nonzero(pp <= 1 / 8, as_tuple=False)[:, 1] 51 | index_2 = torch.nonzero((pp <= 2 / 8) & (pp > 1 / 8), as_tuple=False)[:, 1] 52 | index_3 = torch.nonzero((pp <= 3 / 8) & (pp > 2 / 8), as_tuple=False)[:, 1] 53 | index_4 = torch.nonzero((pp <= 4 / 8) & (pp > 3 / 8), as_tuple=False)[:, 1] 54 | index_5 = torch.nonzero((pp <= 5 / 8) & (pp > 4 / 8), as_tuple=False)[:, 1] 55 | index_6 = torch.nonzero((pp <= 6 / 8) & (pp > 5 / 8), as_tuple=False)[:, 1] 56 | index_7 = torch.nonzero((pp <= 7 / 8) & (pp > 3 / 8), as_tuple=False)[:, 1] 57 | index_8 = torch.nonzero((pp <= 8 / 8) & (pp > 7 / 8), as_tuple=False)[:, 1] 58 | 59 | 60 | 61 | for k in range(8): 62 | list = random.sample(range(1, 9), num) 63 | for j in range(2): 64 | if list[j] == 1: 65 | length = index_1.__len__() 66 | for i in range(length): 67 | group_mask[k][:, index_1[i]] = 0 68 | elif list[j] == 2: 69 | length = index_2.__len__() 70 | for i in range(length): 71 | group_mask[k][:, index_2[i]] = 0 72 | elif list[j] == 3: 73 | length = index_3.__len__() 74 | for i in range(length): 75 | group_mask[k][:, index_3[i]] = 0 76 | elif list[j] == 4: 77 | length = index_4.__len__() 78 | for i in range(length): 79 | group_mask[k][:, index_4[i]] = 0 80 | elif list[j] == 5: 81 | length = index_5.__len__() 82 | for i in range(length): 83 | group_mask[k][:, index_5[i]] = 0 84 | elif list[j] == 6: 85 | length = index_6.__len__() 86 | for i in range(length): 87 | group_mask[k][:, index_6[i]] = 0 88 | elif list[j] == 7: 89 | length = index_7.__len__() 90 | for i in range(length): 91 | group_mask[k][:, index_7[i]] = 0 92 | elif list[j] == 8: 93 | length = index_8.__len__() 94 | for i in range(length): 95 | group_mask[k][:, index_8[i]] = 0 96 | 97 | np.save(os.path.join(out_path, img), group_mask.data.cpu().float().numpy()) # size: 8,49,49 98 | 99 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import random 2 | from data import ImageDetectionsField, TextField, RawField 3 | from data import DataLoader, Sydney, UCM, RSICD 4 | import evaluation 5 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, GlobalGroupingAttention_with_DC 6 | import torch 7 | from tqdm import tqdm 8 | import argparse, os, pickle 9 | import numpy as np 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | 15 | def seed_torch(seed=1): 16 | random.seed(seed) 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | seed_torch() 24 | 25 | 26 | def predict_captions(model, dataloader, text_field): 27 | import itertools 28 | model.eval() 29 | gen = {} 30 | gts = {} 31 | with tqdm(desc='Evaluation', unit='it', total=len(dataloader)) as pbar: 32 | for it, (images, caps_gt) in enumerate(dataloader): 33 | detections = images[0].to(device) 34 | detections_gl = images[1].to(device) 35 | detections_mask = images[2].to(device) 36 | 37 | with torch.no_grad(): 38 | out, _ = model.beam_search(detections, detections_gl, detections_mask, 20, 39 | text_field.vocab.stoi[''], 5, out_size=1) 
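# `out` holds the top beam's token ids for each image in the batch; text_field.decode
# maps them back to word lists, and the itertools.groupby call below collapses
# immediately repeated words before the captions are tokenized and scored.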
40 | caps_gen = text_field.decode(out, join_words=False) 41 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 42 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 43 | gen['%d_%d' % (it, i)] = [gen_i, ] 44 | gts['%d_%d' % (it, i)] = gts_i 45 | pbar.update() 46 | 47 | gts = evaluation.PTBTokenizer.tokenize(gts) 48 | gen = evaluation.PTBTokenizer.tokenize(gen) 49 | scores, _ = evaluation.compute_scores(gts, gen) 50 | return scores 51 | 52 | 53 | 54 | 55 | if __name__ == '__main__': 56 | device = torch.device('cuda') 57 | 58 | parser = argparse.ArgumentParser(description='MG-Transformer') 59 | parser.add_argument('--exp_name', type=str, default='Sydney') # Sydney,UCM,RSICD 60 | parser.add_argument('--batch_size', type=int, default=50) 61 | parser.add_argument('--workers', type=int, default=0) 62 | 63 | # sydney 64 | parser.add_argument('--annotation_folder', type=str, 65 | default='/media/dmd/ours/mlw/rs/Sydney_Captions') 66 | parser.add_argument('--res_features_path', type=str, 67 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224') 68 | 69 | parser.add_argument('--clip_features_path', type=str, 70 | default='/media/dmd/ours/mlw/rs/clip_feature/sydney_224') 71 | 72 | parser.add_argument('--mask_features_path', type=str, 73 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6') 74 | 75 | 76 | args = parser.parse_args() 77 | 78 | print('Transformer Evaluation') 79 | 80 | # Pipeline for image regions 81 | image_field = ImageDetectionsField(clip_features_path=args.clip_features_path, 82 | res_features_path=args.res_features_path, 83 | mask_features_path=args.mask_features_path) 84 | # Pipeline for text 85 | text_field = TextField(init_token='', eos_token='', lower=True, tokenize='spacy', 86 | remove_punctuation=True, nopoints=False) 87 | 88 | 89 | # Create the dataset Sydney,UCM,RSICD 90 | if args.exp_name == 'Sydney': 91 | dataset = Sydney(image_field, text_field, 'Sydney/images/', args.annotation_folder, args.annotation_folder) 92 | elif args.exp_name == 'UCM': 93 | dataset = UCM(image_field, text_field, 'UCM/images/', args.annotation_folder, args.annotation_folder) 94 | elif args.exp_name == 'RSICD': 95 | dataset = RSICD(image_field, text_field, 'RSICD/images/', args.annotation_folder, args.annotation_folder) 96 | 97 | _, _, test_dataset = dataset.splits 98 | 99 | text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb')) 100 | 101 | 102 | 103 | # Model and dataloaders 104 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=GlobalGroupingAttention_with_DC) 105 | decoder = MeshedDecoder(len(text_field.vocab), 127, 3, text_field.vocab.stoi['']) 106 | model = Transformer(text_field.vocab.stoi[''], encoder, decoder).to(device) 107 | 108 | 109 | data = torch.load('./saved_models/Sydney_best.pth') 110 | # data = torch.load('./saved_models/Sydney_best_test.pth') 111 | 112 | model.load_state_dict(data['state_dict']) 113 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 114 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size, num_workers=args.workers) 115 | 116 | scores = predict_captions(model, dict_dataloader_test, text_field) 117 | 118 | 119 | -------------------------------------------------------------------------------- /models/transformer/encoders.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from models.transformer.utils import PositionWiseFeedForward 3 | 
import torch 4 | from torch import nn 5 | from models.transformer.attention import MultiHeadAttention 6 | from models.transformer.attention import ScaledDotProductAttention 7 | 8 | 9 | from models.transformer.utils import * 10 | 11 | 12 | 13 | class EncoderLayer(nn.Module): 14 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False, 15 | attention_module=None, attention_module_kwargs=None, dilation=None): 16 | super(EncoderLayer, self).__init__() 17 | self.identity_map_reordering = identity_map_reordering 18 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 19 | attention_module=attention_module, 20 | attention_module_kwargs=attention_module_kwargs, dilation=dilation) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 22 | 23 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None, 24 | input_gl=None, memory=None, isencoder=None): 25 | att = self.mhatt(queries, keys, values, attention_mask, attention_weights, group_mask=group_mask, 26 | input_gl=input_gl, memory=memory, isencoder=isencoder) 27 | ff = self.pwff(att) 28 | return ff 29 | 30 | 31 | class MultiLevelEncoder(nn.Module): 32 | def __init__(self, N, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 33 | identity_map_reordering=False, attention_module=None, attention_module_kwargs=None): 34 | super(MultiLevelEncoder, self).__init__() 35 | self.d_model = d_model 36 | self.dropout = dropout 37 | 38 | # aoa 39 | self.aoa_layer = nn.Sequential(nn.Linear(2 * self.d_model, 2 * self.d_model), nn.GLU()) 40 | self.dropout_aoa = nn.Dropout(p=self.dropout) 41 | 42 | 43 | self.layers = nn.Sequential(EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 44 | identity_map_reordering=identity_map_reordering, 45 | attention_module=attention_module, 46 | attention_module_kwargs=attention_module_kwargs, dilation=2), 47 | EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 48 | identity_map_reordering=identity_map_reordering, 49 | attention_module=attention_module, 50 | attention_module_kwargs=attention_module_kwargs, dilation=4), 51 | EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 52 | identity_map_reordering=identity_map_reordering, 53 | attention_module=attention_module, 54 | attention_module_kwargs=attention_module_kwargs, dilation=8) 55 | ) 56 | 57 | self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering, 58 | attention_module=ScaledDotProductAttention, 59 | attention_module_kwargs=attention_module_kwargs) 60 | self.padding_idx = padding_idx 61 | def forward(self, input, input_gl=None, isencoder=None, detections_mask=None, memory=None, attention_weights=None): 62 | # fine module 63 | attention_mask = (torch.sum(input, -1) == self.padding_idx).unsqueeze(1).unsqueeze(1) 64 | input = self.mhatt(input, input, input, attention_mask) 65 | out = self.aoa_layer(self.dropout_aoa(torch.cat([input, input_gl], -1))) 66 | group_mask = detections_mask 67 | outs = [] 68 | for l in self.layers: 69 | out = l(out, out, out, attention_mask, attention_weights, group_mask=group_mask, input_gl=input_gl, 70 | memory=memory,isencoder=isencoder) 71 | outs.append(out.unsqueeze(1)) 72 | 73 | outs = torch.cat(outs, 1) 74 | return outs, attention_mask 75 | 76 | 77 | class MemoryAugmentedEncoder(MultiLevelEncoder): 78 | def __init__(self, N, padding_idx, d_in=2048, d_clip=768, 
**kwargs): 79 | super(MemoryAugmentedEncoder, self).__init__(N, padding_idx, **kwargs) 80 | self.fc = nn.Linear(d_in, self.d_model) 81 | self.fc_clip = nn.Linear(d_clip, self.d_model) 82 | 83 | self.dropout_1 = nn.Dropout(p=self.dropout) 84 | self.layer_norm_1 = nn.LayerNorm(self.d_model) 85 | 86 | self.dropout_2 = nn.Dropout(p=self.dropout) 87 | self.layer_norm_2 = nn.LayerNorm(self.d_model) 88 | 89 | 90 | def forward(self, input, input_gl=None, detections_mask=None, isencoder=None, attention_weights=None): 91 | input_gl = F.relu(self.fc_clip(input_gl)) 92 | input_gl = self.dropout_1(input_gl) 93 | input_gl = self.layer_norm_1(input_gl) 94 | 95 | out = F.relu(self.fc(input)) 96 | out = self.dropout_2(out) 97 | out = self.layer_norm_2(out) 98 | return super(MemoryAugmentedEncoder, self).forward(out, input_gl=input_gl, detections_mask=detections_mask, 99 | isencoder=isencoder, attention_weights=attention_weights) 100 | -------------------------------------------------------------------------------- /models/transformer/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from models.transformer.attention import MultiHeadAttention 7 | from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward 8 | from models.containers import Module, ModuleList 9 | 10 | 11 | class MeshedDecoderLayer(Module): 12 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None, 13 | enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 14 | super(MeshedDecoderLayer, self).__init__() 15 | self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True, 16 | attention_module=self_att_module, 17 | attention_module_kwargs=self_att_module_kwargs) 18 | self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False, 19 | attention_module=enc_att_module, 20 | attention_module_kwargs=enc_att_module_kwargs) 21 | self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout) 22 | 23 | self.fc_alpha1 = nn.Linear(d_model + d_model, d_model) 24 | self.fc_alpha2 = nn.Linear(d_model + d_model, d_model) 25 | self.fc_alpha3 = nn.Linear(d_model + d_model, d_model) 26 | 27 | self.init_weights() 28 | 29 | def init_weights(self): 30 | nn.init.xavier_uniform_(self.fc_alpha1.weight) 31 | nn.init.xavier_uniform_(self.fc_alpha2.weight) 32 | nn.init.xavier_uniform_(self.fc_alpha3.weight) 33 | nn.init.constant_(self.fc_alpha1.bias, 0) 34 | nn.init.constant_(self.fc_alpha2.bias, 0) 35 | nn.init.constant_(self.fc_alpha3.bias, 0) 36 | 37 | 38 | def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att): # decoder(seq, enc_output, mask_enc) 39 | self_att = self.self_att(input, input, input, mask_self_att, input_gl=None, 40 | isencoder=False) 41 | self_att = self_att * mask_pad 42 | enc_att1 = self.enc_att(self_att, enc_output[:, 0], enc_output[:, 0], mask_enc_att) * mask_pad # 10 15 512 43 | enc_att2 = self.enc_att(self_att, enc_output[:, 1], enc_output[:, 1], mask_enc_att) * mask_pad # 10 15 512 44 | enc_att3 = self.enc_att(self_att, enc_output[:, 2], enc_output[:, 2], mask_enc_att) * mask_pad # 10 15 512 45 | alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([self_att, enc_att1], -1))) # 10 15 512 46 | alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([self_att, enc_att2], -1))) # 10 15 512 47 | alpha3 = 
torch.sigmoid(self.fc_alpha3(torch.cat([self_att, enc_att3], -1))) # 10 15 512 48 | enc_att = (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / np.sqrt(3) 49 | enc_att = enc_att * mask_pad # 10 15 512 50 | 51 | ff = self.pwff(enc_att) 52 | ff = ff * mask_pad 53 | return ff 54 | 55 | 56 | class MeshedDecoder(Module): 57 | def __init__(self, vocab_size, max_len, N_dec, padding_idx, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, 58 | self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None): 59 | super(MeshedDecoder, self).__init__() 60 | self.d_model = d_model 61 | self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 62 | self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len+1, d_model, 0), 63 | freeze=True) 64 | self.layers = ModuleList( 65 | [MeshedDecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module, 66 | enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs, 67 | enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)]) 68 | self.fc = nn.Linear(d_model, vocab_size, bias=False) 69 | self.max_len = max_len 70 | self.padding_idx = padding_idx 71 | self.N = N_dec 72 | 73 | self.register_state('running_mask_self_attention', torch.zeros((1, 1, 0)).byte()) 74 | self.register_state('running_seq', torch.zeros((1,)).long()) 75 | 76 | def forward(self, input, encoder_output, mask_encoder): 77 | b_s, seq_len = input.shape[:2] # 10 15 78 | mask_queries = (input != self.padding_idx).unsqueeze(-1).float() 79 | mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),diagonal=1) 80 | mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) 81 | mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte() 82 | mask_self_attention = mask_self_attention.gt(0) 83 | 84 | if self._is_stateful: # false 85 | device = self.running_mask_self_attention.device 86 | mask_self_attention = torch.tensor(mask_self_attention, dtype=torch.uint8).to(device) 87 | self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],-1) 88 | mask_self_attention = self.running_mask_self_attention 89 | 90 | seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) 91 | seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0) 92 | 93 | if self._is_stateful: 94 | self.running_seq.add_(1) 95 | seq = self.running_seq 96 | 97 | out = self.word_emb(input) + self.pos_emb(seq) 98 | 99 | for i, l in enumerate(self.layers): 100 | out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder) 101 | 102 | 103 | out = self.fc(out) 104 | return F.log_softmax(out, dim=-1) 105 | -------------------------------------------------------------------------------- /models/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | 5 | class BeamSearch(object): 6 | def __init__(self, model, max_len: int, eos_idx: int, beam_size: int): 7 | self.model = model 8 | self.max_len = max_len 9 | self.eos_idx = eos_idx 10 | self.beam_size = beam_size 11 | self.b_s = None 12 | self.device = None 13 | self.seq_mask = None 14 | self.seq_logprob = None 15 | self.outputs = None 16 | self.log_probs = None 17 | self.selected_words = None 18 | self.all_log_probs = None 19 | 20 | def _expand_state(self, selected_beam, cur_beam_size): 21 | def 
fn(s): 22 | shape = [int(sh) for sh in s.shape] 23 | beam = selected_beam 24 | for _ in shape[1:]: 25 | beam = beam.unsqueeze(-1) 26 | # import pdb 27 | # pdb.set_trace() 28 | s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1, 29 | beam.expand(*([self.b_s, self.beam_size] + shape[1:]))) 30 | s = s.view(*([-1, ] + shape[1:])) 31 | return s 32 | 33 | return fn 34 | 35 | def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor): 36 | if isinstance(visual, torch.Tensor): 37 | visual_shape = visual.shape 38 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 39 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 40 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 41 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 42 | visual_exp = visual.view(visual_exp_shape) 43 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 44 | visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 45 | else: 46 | new_visual = [] 47 | for im in visual: 48 | visual_shape = im.shape 49 | visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:] 50 | visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:] 51 | selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2)) 52 | selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:] 53 | visual_exp = im.view(visual_exp_shape) 54 | selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size) 55 | new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape) 56 | new_visual.append(new_im) 57 | visual = tuple(new_visual) 58 | return visual 59 | 60 | def apply(self, visual: utils.TensorOrSequence, visual_gl: utils.TensorOrSequence,detections_mask:utils.TensorOrSequence, out_size=1, return_probs=False, **kwargs): 61 | self.b_s = utils.get_batch_size(visual) 62 | self.device = utils.get_device(visual) 63 | self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device) 64 | self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device) 65 | self.log_probs = [] 66 | self.selected_words = None 67 | if return_probs: 68 | self.all_log_probs = [] 69 | 70 | outputs = [] 71 | with self.model.statefulness(self.b_s): 72 | for t in range(self.max_len): 73 | visual, outputs = self.iter(t, visual,visual_gl, detections_mask,outputs, return_probs, **kwargs) 74 | 75 | 76 | # Sort result 77 | seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True) 78 | outputs = torch.cat(outputs, -1) 79 | outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 80 | log_probs = torch.cat(self.log_probs, -1) 81 | log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len)) 82 | if return_probs: 83 | all_log_probs = torch.cat(self.all_log_probs, 2) 84 | all_log_probs = torch.gather(all_log_probs, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size, 85 | self.max_len, 86 | all_log_probs.shape[-1])) 87 | 88 | outputs = outputs.contiguous()[:, :out_size] 89 | log_probs = log_probs.contiguous()[:, :out_size] 90 | if out_size == 1: 91 | outputs = outputs.squeeze(1) 92 | log_probs = log_probs.squeeze(1) 93 | 94 | if return_probs: 95 | return outputs, log_probs, all_log_probs 96 | else: 97 | return outputs, log_probs 98 | 99 | def 
select(self, t, candidate_logprob, **kwargs): 100 | selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True) 101 | selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size] 102 | return selected_idx, selected_logprob 103 | 104 | def iter(self, t: int, visual: utils.TensorOrSequence,visual_gl:utils.TensorOrSequence,detections_mask:utils.TensorOrSequence,outputs, return_probs, **kwargs): 105 | cur_beam_size = 1 if t == 0 else self.beam_size 106 | # word_logprob = self.model.step(t, self.selected_words, visual, visual_gl, None, mode='feedback', **kwargs) 107 | 108 | word_logprob = self.model.step(t, self.selected_words, visual, visual_gl, detections_mask,None, mode='feedback', **kwargs) 109 | word_logprob = word_logprob.view(self.b_s, cur_beam_size, -1) 110 | candidate_logprob = self.seq_logprob + word_logprob 111 | 112 | # Mask sequence if it reaches EOS 113 | if t > 0: 114 | mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).float().unsqueeze(-1) 115 | self.seq_mask = self.seq_mask * mask 116 | word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob) 117 | old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous() 118 | old_seq_logprob[:, :, 1:] = -999 119 | candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask) 120 | 121 | selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs) 122 | selected_beam = selected_idx // candidate_logprob.shape[-1] 123 | selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1] 124 | 125 | self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size)) 126 | visual = self._expand_visual(visual, cur_beam_size, selected_beam) 127 | 128 | self.seq_logprob = selected_logprob.unsqueeze(-1) 129 | self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1)) 130 | outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs) 131 | outputs.append(selected_words.unsqueeze(-1)) 132 | 133 | if return_probs: 134 | if t == 0: 135 | self.all_log_probs.append(word_logprob.expand((self.b_s, self.beam_size, -1)).unsqueeze(2)) 136 | else: 137 | self.all_log_probs.append(word_logprob.unsqueeze(2)) 138 | 139 | this_word_logprob = torch.gather(word_logprob, 1, 140 | selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 141 | word_logprob.shape[-1])) 142 | this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1)) 143 | self.log_probs = list( 144 | torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs) 145 | self.log_probs.append(this_word_logprob) 146 | self.selected_words = selected_words.view(-1, 1) 147 | 148 | return visual, outputs 149 | -------------------------------------------------------------------------------- /models/transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.containers import Module 5 | from torch.nn import functional as F 6 | from models.transformer.utils import * 7 | 8 | 9 | class SELayer(nn.Module): 10 | def __init__(self, channel, reduction=16): 11 | super(SELayer, self).__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // 
reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | def forward(self, x): 21 | b, c, _, _ = x.size() 22 | y = self.avg_pool(x).view(b, c) 23 | y = self.fc(y).view(b, c, 1, 1) 24 | return x * y.expand_as(x) 25 | 26 | 27 | class ScaledDotProductAttention(nn.Module): 28 | """ 29 | Scaled dot-product attention 30 | """ 31 | 32 | def __init__(self, d_model, d_k, d_v, h, dilation=None): 33 | """ 34 | :param d_model: Output dimensionality of the model 35 | :param d_k: Dimensionality of queries and keys 36 | :param d_v: Dimensionality of values 37 | :param h: Number of heads 38 | """ 39 | super(ScaledDotProductAttention, self).__init__() 40 | self.fc_q = nn.Linear(d_model, h * d_k) 41 | self.fc_k = nn.Linear(d_model, h * d_k) 42 | self.fc_v = nn.Linear(d_model, h * d_v) 43 | self.fc_o = nn.Linear(h * d_v, d_model) 44 | 45 | self.d_model = d_model 46 | self.d_k = d_k 47 | self.d_v = d_v 48 | self.h = h 49 | 50 | self.init_weights() 51 | 52 | def init_weights(self): 53 | nn.init.xavier_uniform_(self.fc_q.weight) 54 | nn.init.xavier_uniform_(self.fc_k.weight) 55 | nn.init.xavier_uniform_(self.fc_v.weight) 56 | nn.init.xavier_uniform_(self.fc_o.weight) 57 | nn.init.constant_(self.fc_q.bias, 0) 58 | nn.init.constant_(self.fc_k.bias, 0) 59 | nn.init.constant_(self.fc_v.bias, 0) 60 | nn.init.constant_(self.fc_o.bias, 0) 61 | 62 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None, 63 | input_gl=None, memory=None, 64 | isencoder=None, dilation=None): 65 | """ 66 | Computes 67 | :param queries: Queries (b_s, nq, d_model) 68 | :param keys: Keys (b_s, nk, d_model) 69 | :param values: Values (b_s, nk, d_model) 70 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 71 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
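        :param group_mask, input_gl, memory, isencoder, dilation: accepted only for interface compatibility with the grouping attention variant below; they are not used in this plain scaled dot-product attention.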
72 | :return: 73 | """ 74 | # att[0][0].argmax() 75 | b_s, nq = queries.shape[:2] 76 | nk = keys.shape[1] 77 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 10 8 50 64 78 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 10 8 64 50 79 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 10 8 50 64 80 | 81 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 10 8 50 50 82 | 83 | if attention_weights is not None: 84 | att = att * attention_weights 85 | if attention_mask is not None: 86 | att = att.masked_fill(attention_mask.bool(), -np.inf) 87 | 88 | att = torch.softmax(att, -1) 89 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq,self.h * self.d_v) 90 | out = self.fc_o(out) # (b_s, nq, d_model) 10 50 512 91 | return out 92 | 93 | 94 | 95 | # ScaledDotProductAttention_dilated 96 | class GlobalGroupingAttention_with_DC(nn.Module): 97 | """ 98 | Scaled dot-product attention 99 | """ 100 | 101 | def __init__(self, d_model, d_k, d_v, h, dilation=None): 102 | """ 103 | :param d_model: Output dimensionality of the model 104 | :param d_k: Dimensionality of queries and keys 105 | :param d_v: Dimensionality of values 106 | :param h: Number of heads 107 | """ 108 | super(GlobalGroupingAttention_with_DC, self).__init__() 109 | self.fc_c = nn.Linear(49, 1) 110 | 111 | self.fc_q = nn.Linear(d_model, h * d_k) 112 | self.fc_k = nn.Linear(d_model, h * d_k) 113 | self.fc_v = nn.Linear(d_model, h * d_v) 114 | self.fc_o = nn.Linear(h * d_v, d_model) 115 | 116 | self.d_model = d_model 117 | self.d_k = d_k 118 | self.d_v = d_v 119 | self.h = h 120 | 121 | self.net = nn.Sequential( 122 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=1, padding=dilation, 123 | dilation=dilation), 124 | nn.BatchNorm2d(512, affine=False), nn.ReLU(inplace=True)) 125 | 126 | # self.SE = SELayer(channel=512) 127 | self.init_weights() 128 | 129 | def init_weights(self): 130 | nn.init.xavier_uniform_(self.fc_c.weight) 131 | nn.init.constant_(self.fc_c.bias, 0) 132 | 133 | nn.init.xavier_uniform_(self.fc_q.weight) 134 | nn.init.xavier_uniform_(self.fc_k.weight) 135 | nn.init.xavier_uniform_(self.fc_v.weight) 136 | nn.init.xavier_uniform_(self.fc_o.weight) 137 | nn.init.constant_(self.fc_q.bias, 0) 138 | nn.init.constant_(self.fc_k.bias, 0) 139 | nn.init.constant_(self.fc_v.bias, 0) 140 | nn.init.constant_(self.fc_o.bias, 0) 141 | 142 | 143 | 144 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None, 145 | input_gl=None, memory=None, isencoder=None, dilation=None): 146 | """ 147 | Computes 148 | :param queries: Queries (b_s, nq, d_model) 149 | :param keys: Keys (b_s, nk, d_model) 150 | :param values: Values (b_s, nk, d_model) 151 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 152 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
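        :param group_mask: grouping mask over the attention scores; its non-zero entries are filled with -1e9 and added to the logits before the softmax, effectively masking those positions.
        :param input_gl, memory, isencoder, dilation: accepted for interface compatibility; unused in this forward pass.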
153 | :return: 154 | """ 155 | ######################################################################### 156 | # dilation 157 | x = queries.permute(0, 2, 1) 158 | x = x.reshape(x.shape[0], x.shape[1], 7, 7) # res 159 | # x = x.reshape(x.shape[0], x.shape[1], 14, 14) # vgg 160 | x = self.net(x) 161 | x_s = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(0, 2, 1) 162 | 163 | # channel 164 | # x_c=x_s.permute(0, 2, 1) 165 | # c=torch.softmax(self.fc_c(x_c),1).permute(0, 2, 1) 166 | # c = c.repeat(1, 49, 1) 167 | 168 | # spatial 169 | # queries = queries*c 170 | # keys = keys*c 171 | # values = values*c 172 | 173 | queries = x_s + queries 174 | keys = x_s + keys 175 | values = x_s + values 176 | 177 | b_s, nq = queries.shape[:2] 178 | nk = keys.shape[1] 179 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 180 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 181 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 182 | 183 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 184 | if attention_weights is not None: 185 | att = att * attention_weights 186 | if attention_mask is not None: 187 | att = att.masked_fill(attention_mask.bool(), -np.inf) 188 | if group_mask is not None: 189 | # (1) 190 | # att = att.masked_fill(group_mask.bool(), -np.inf) 191 | # (2) 192 | # att = att.masked_fill(group_mask.bool(), torch.tensor(-1e9)) 193 | # (3) 194 | group_mask_mat=group_mask.masked_fill(group_mask.bool(), torch.tensor(-1e9)) 195 | att=att+group_mask_mat 196 | 197 | att = torch.softmax(att, -1) 198 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 199 | out = self.fc_o(out) # (b_s, nq, d_model) 200 | return out 201 | 202 | 203 | 204 | class MultiHeadAttention(Module): 205 | """ 206 | Multi-head attention layer with Dropout and Layer Normalization. 
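    When `can_be_stateful` is True and the module is run statefully, incoming keys and values are appended to the `running_keys` / `running_values` buffers, so autoregressive decoding only needs to feed the newly generated token at each step.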
207 | """ 208 | 209 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False, 210 | attention_module=None, attention_module_kwargs=None, isenc=None, dilation=None): 211 | super(MultiHeadAttention, self).__init__() 212 | self.identity_map_reordering = identity_map_reordering 213 | if attention_module is not None: 214 | if attention_module_kwargs is not None: 215 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs, 216 | dilation=dilation) 217 | else: 218 | self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, dilation=dilation) 219 | else: 220 | self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h) 221 | 222 | 223 | self.dropout = nn.Dropout(p=dropout) 224 | self.layer_norm = nn.LayerNorm(d_model) 225 | 226 | self.can_be_stateful = can_be_stateful 227 | if self.can_be_stateful: 228 | self.register_state('running_keys', torch.zeros((0, d_model))) 229 | self.register_state('running_values', torch.zeros((0, d_model))) 230 | 231 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None, group_mask=None, 232 | input_gl=None, memory=None,isencoder=None): 233 | if self.can_be_stateful and self._is_stateful: 234 | self.running_keys = torch.cat([self.running_keys, keys], 1) 235 | keys = self.running_keys 236 | 237 | self.running_values = torch.cat([self.running_values, values], 1) 238 | values = self.running_values 239 | 240 | if self.identity_map_reordering: 241 | q_norm = self.layer_norm(queries) 242 | k_norm = self.layer_norm(keys) 243 | v_norm = self.layer_norm(values) 244 | out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights, input_gl=input_gl) 245 | out = queries + self.dropout(torch.relu(out)) 246 | else: 247 | if isencoder == True: 248 | out = self.attention(queries, keys, values, attention_mask, attention_weights, group_mask=group_mask, 249 | input_gl=input_gl, memory=memory,isencoder=isencoder) 250 | else: 251 | out = self.attention(queries, keys, values, attention_mask, attention_weights, 252 | input_gl=None, memory=memory, isencoder=isencoder) 253 | out = self.dropout(out) 254 | out = self.layer_norm(queries + out) 255 | return out 256 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.optim as optim 10 | import os 11 | 12 | import six 13 | from six.moves import cPickle 14 | 15 | bad_endings = ['with', 'in', 'on', 'of', 'a', 'at', 'to', 'for', 'an', 'this', 'his', 'her', 'that'] 16 | bad_endings += ['the'] 17 | 18 | 19 | def pickle_load(f): 20 | """ Load a pickle. 21 | Parameters 22 | ---------- 23 | f: file-like object 24 | """ 25 | if six.PY3: 26 | return cPickle.load(f, encoding='latin-1') 27 | else: 28 | return cPickle.load(f) 29 | 30 | 31 | def pickle_dump(obj, f): 32 | """ Dump a pickle. 
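    Uses pickle protocol 2 under Python 3 so the resulting file stays readable from Python 2.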
33 | Parameters 34 | ---------- 35 | obj: pickled object 36 | f: file-like object 37 | """ 38 | if six.PY3: 39 | return cPickle.dump(obj, f, protocol=2) 40 | else: 41 | return cPickle.dump(obj, f) 42 | 43 | 44 | def if_use_feat(caption_model): 45 | # Decide if load attention feature according to caption model 46 | if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']: 47 | use_att, use_fc = False, True 48 | elif caption_model == 'language_model': 49 | use_att, use_fc = False, False 50 | elif caption_model in ['topdown', 'aoa']: 51 | use_fc, use_att = True, True 52 | else: 53 | use_att, use_fc = True, False 54 | return use_fc, use_att 55 | 56 | 57 | # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token. 58 | def decode_sequence(ix_to_word, seq): 59 | N, D = seq.size() 60 | out = [] 61 | for i in range(N): 62 | txt = '' 63 | for j in range(D): 64 | ix = seq[i, j] 65 | if ix > 0: 66 | if j >= 1: 67 | txt = txt + ' ' 68 | txt = txt + ix_to_word[str(ix.item())] 69 | else: 70 | break 71 | if int(os.getenv('REMOVE_BAD_ENDINGS', '0')): 72 | flag = 0 73 | words = txt.split(' ') 74 | for j in range(len(words)): 75 | if words[-j - 1] not in bad_endings: 76 | flag = -j 77 | break 78 | txt = ' '.join(words[0:len(words) + flag]) 79 | out.append(txt.replace('@@ ', '')) 80 | return out 81 | 82 | 83 | def to_contiguous(tensor): 84 | if tensor.is_contiguous(): 85 | return tensor 86 | else: 87 | return tensor.contiguous() 88 | 89 | 90 | class RewardCriterion(nn.Module): 91 | def __init__(self): 92 | super(RewardCriterion, self).__init__() 93 | 94 | def forward(self, input, seq, reward): 95 | input = to_contiguous(input).view(-1) 96 | reward = to_contiguous(reward).view(-1) 97 | mask = (seq > 0).float() 98 | mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) 99 | output = - input * reward * mask 100 | output = torch.sum(output) / torch.sum(mask) 101 | 102 | return output 103 | 104 | 105 | class LanguageModelCriterion(nn.Module): 106 | def __init__(self): 107 | super(LanguageModelCriterion, self).__init__() 108 | 109 | def forward(self, input, target, mask): 110 | # truncate to the same size 111 | target = target[:, :input.size(1)] 112 | mask = mask[:, :input.size(1)] 113 | 114 | output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask 115 | output = torch.sum(output) / torch.sum(mask) 116 | 117 | return output 118 | 119 | 120 | class LabelSmoothing(nn.Module): 121 | "Implement label smoothing." 
122 | 123 | def __init__(self, size=0, padding_idx=0, smoothing=0.0): 124 | super(LabelSmoothing, self).__init__() 125 | """ nn.KLDivLoss计算KL散度的损失函数,要将模型输出的原始预测值要先进行softmax,然后进行log运算 126 | (torch.nn.functional.log_softmax可以直接实现),得到结果作为input输入到KLDivLoss中。target是二维的,形状与input一样 127 | 128 | torch.nn.CrossEntropyLoss()交叉熵计算,输入的预测值不需要进行softmax,不需要进行log运算!直接用原始的预测输出,标签用整数序列。 129 | """ 130 | self.criterion = nn.KLDivLoss(size_average=False, reduce=False) 131 | # self.padding_idx = padding_idx 132 | self.confidence = 1.0 - smoothing # 0.8 133 | self.smoothing = smoothing 134 | # self.size = size 135 | self.true_dist = None 136 | 137 | def forward(self, input, target, mask): 138 | # truncate to the same size 139 | target = target[:, :input.size(1)] 140 | mask = mask[:, :input.size(1)] 141 | 142 | input = to_contiguous(input).view(-1, input.size(-1)) 143 | target = to_contiguous(target).view(-1) 144 | mask = to_contiguous(mask).view(-1) 145 | 146 | # assert x.size(1) == self.size 147 | self.size = input.size(1) 148 | # true_dist = x.data.clone() 149 | true_dist = input.data.clone() 150 | # true_dist.fill_(self.smoothing / (self.size - 2)) 151 | true_dist.fill_(self.smoothing / (self.size - 1)) 152 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 153 | # true_dist[:, self.padding_idx] = 0 154 | # mask = torch.nonzero(target.data == self.padding_idx) 155 | # self.true_dist = true_dist 156 | return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() 157 | 158 | 159 | def set_lr(optimizer, lr): 160 | for group in optimizer.param_groups: 161 | group['lr'] = lr 162 | 163 | 164 | def get_lr(optimizer): 165 | for group in optimizer.param_groups: 166 | return group['lr'] 167 | 168 | 169 | def clip_gradient(optimizer, grad_clip): # 梯度截断Clip, 将梯度约束在某一个区间之内. 避免梯度爆炸 170 | for group in optimizer.param_groups: 171 | for param in group['params']: 172 | param.grad.data.clamp_(-grad_clip, grad_clip) 173 | 174 | 175 | def build_optimizer(params, opt): 176 | if opt.optim == 'rmsprop': 177 | return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, 178 | weight_decay=opt.weight_decay) 179 | elif opt.optim == 'adagrad': 180 | return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) 181 | elif opt.optim == 'sgd': 182 | return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) 183 | elif opt.optim == 'sgdm': 184 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) 185 | elif opt.optim == 'sgdmom': 186 | return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) 187 | elif opt.optim == 'adam': 188 | return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, 189 | weight_decay=opt.weight_decay) 190 | else: 191 | raise Exception("bad option opt.optim: {}".format(opt.optim)) 192 | 193 | 194 | def penalty_builder(penalty_config): 195 | if penalty_config == '': 196 | return lambda x, y: y 197 | pen_type, alpha = penalty_config.split('_') 198 | alpha = float(alpha) 199 | if pen_type == 'wu': 200 | return lambda x, y: length_wu(x, y, alpha) 201 | if pen_type == 'avg': 202 | return lambda x, y: length_average(x, y, alpha) 203 | 204 | 205 | def length_wu(length, logprobs, alpha=0.): 206 | """ 207 | NMT length re-ranking score from 208 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 
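    Returns logprobs / lp(length), with lp(length) = ((5 + length) ** alpha) / ((5 + 1) ** alpha).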
209 | """ 210 | 211 | modifier = (((5 + length) ** alpha) / 212 | ((5 + 1) ** alpha)) 213 | return (logprobs / modifier) 214 | 215 | 216 | def length_average(length, logprobs, alpha=0.): 217 | """ 218 | Returns the average probability of tokens in a sequence. 219 | """ 220 | return logprobs / length 221 | 222 | 223 | class NoamOpt(object): 224 | "Optim wrapper that implements rate." 225 | 226 | def __init__(self, model_size, factor, warmup, optimizer): 227 | self.optimizer = optimizer 228 | self._step = 0 229 | self.warmup = warmup 230 | self.factor = factor 231 | self.model_size = model_size 232 | self._rate = 0 233 | 234 | def step(self): 235 | "Update parameters and rate" 236 | self._step += 1 237 | rate = self.rate() 238 | for p in self.optimizer.param_groups: 239 | p['lr'] = rate 240 | self._rate = rate 241 | self.optimizer.step() 242 | 243 | def rate(self, step=None): 244 | "Implement `lrate` above" 245 | if step is None: 246 | step = self._step 247 | return self.factor * \ 248 | (self.model_size ** (-0.5) * 249 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 250 | 251 | def __getattr__(self, name): 252 | return getattr(self.optimizer, name) 253 | 254 | 255 | class ReduceLROnPlateau(object): 256 | "Optim wrapper that implements rate." 257 | 258 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, 259 | threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): 260 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, 261 | threshold_mode, cooldown, min_lr, eps) 262 | self.optimizer = optimizer 263 | self.current_lr = get_lr(optimizer) 264 | 265 | def step(self): 266 | "Update parameters and rate" 267 | self.optimizer.step() 268 | 269 | def scheduler_step(self, val): 270 | self.scheduler.step(val) 271 | self.current_lr = get_lr(self.optimizer) 272 | 273 | def state_dict(self): 274 | return {'current_lr': self.current_lr, 275 | 'scheduler_state_dict': self.scheduler.state_dict(), 276 | 'optimizer_state_dict': self.optimizer.state_dict()} 277 | 278 | def load_state_dict(self, state_dict): 279 | if 'current_lr' not in state_dict: 280 | # it's normal optimizer 281 | self.optimizer.load_state_dict(state_dict) 282 | set_lr(self.optimizer, self.current_lr) # use the lr fromt the option 283 | else: 284 | # it's a schduler 285 | self.current_lr = state_dict['current_lr'] 286 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) 287 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 288 | # current_lr is actually useless in this case 289 | 290 | def rate(self, step=None): 291 | "Implement `lrate` above" 292 | if step is None: 293 | step = self._step 294 | return self.factor * \ 295 | (self.model_size ** (-0.5) * 296 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 297 | 298 | def __getattr__(self, name): 299 | return getattr(self.optimizer, name) 300 | 301 | 302 | def get_std_opt(model, factor=1, warmup=2000): 303 | # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, 304 | # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 305 | return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup, 306 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 307 | -------------------------------------------------------------------------------- /data/field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | from 
torch.utils.data.dataloader import default_collate 4 | from itertools import chain 5 | import six 6 | import torch 7 | import numpy as np 8 | import h5py 9 | import os 10 | import warnings 11 | import shutil 12 | 13 | from .dataset import Dataset 14 | from .vocab import Vocab 15 | from .utils import get_tokenizer 16 | 17 | class RawField(object): 18 | """ Defines a general datatype. 19 | 20 | Every dataset consists of one or more types of data. For instance, 21 | a machine translation dataset contains paired examples of text, while 22 | an image captioning dataset contains images and texts. 23 | Each of these types of data is represented by a RawField object. 24 | An RawField object does not assume any property of the data type and 25 | it holds parameters relating to how a datatype should be processed. 26 | 27 | Attributes: 28 | preprocessing: The Pipeline that will be applied to examples 29 | using this field before creating an example. 30 | Default: None. 31 | postprocessing: A Pipeline that will be applied to a list of examples 32 | using this field before assigning to a batch. 33 | Function signature: (batch(list)) -> object 34 | Default: None. 35 | """ 36 | 37 | def __init__(self, preprocessing=None, postprocessing=None): 38 | self.preprocessing = preprocessing 39 | self.postprocessing = postprocessing 40 | 41 | def preprocess(self, x): 42 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 43 | if self.preprocessing is not None: 44 | return self.preprocessing(x) 45 | else: 46 | return x 47 | 48 | def process(self, batch, *args, **kwargs): 49 | """ Process a list of examples to create a batch. 50 | 51 | Postprocess the batch with user-provided Pipeline. 52 | 53 | Args: 54 | batch (list(object)): A list of object from a batch of examples. 55 | Returns: 56 | object: Processed object given the input and custom 57 | postprocessing Pipeline. 
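        The (optionally post-processed) list of examples is collated into tensors with torch's default_collate.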
58 | """ 59 | if self.postprocessing is not None: 60 | batch = self.postprocessing(batch) 61 | return default_collate(batch) 62 | 63 | 64 | class Merge(RawField): 65 | def __init__(self, *fields): 66 | super(Merge, self).__init__() 67 | self.fields = fields 68 | 69 | def preprocess(self, x): 70 | return tuple(f.preprocess(x) for f in self.fields) 71 | 72 | def process(self, batch, *args, **kwargs): 73 | if len(self.fields) == 1: 74 | batch = [batch, ] 75 | else: 76 | batch = list(zip(*batch)) 77 | 78 | out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch)) 79 | return out 80 | 81 | 82 | class ImageDetectionsField(RawField): 83 | def __init__(self, preprocessing=None, postprocessing=None, clip_features_path=None, 84 | res_features_path=None,mask_features_path=None): 85 | self.clip_features_path = clip_features_path 86 | self.res_features_path = res_features_path 87 | self.mask_features_path = mask_features_path 88 | super(ImageDetectionsField, self).__init__(preprocessing, postprocessing) 89 | 90 | 91 | def preprocess(self, x, avoid_precomp=False): 92 | image_name = x.split('/')[-1] 93 | 94 | # resnet feature 95 | precomp_data = np.load(self.res_features_path + '/%s.npy' % image_name) # res 2048, 7, 7 96 | precomp_data = precomp_data.reshape(precomp_data.shape[0], precomp_data.shape[1] * precomp_data.shape[2]) 97 | precomp_data=precomp_data.transpose(1, 0) 98 | 99 | # clip feature 100 | precomp_data_clip = np.load(self.clip_features_path + '/%s.npy' % image_name) 101 | precomp_data_clip = precomp_data_clip[0] 102 | precomp_data_clip = np.expand_dims(precomp_data_clip, axis=0) 103 | precomp_data_clip = precomp_data_clip.repeat([49], axis=0) 104 | 105 | # group mask matrix 106 | precomp_data_mask = np.load(self.mask_features_path + '/%s.npy' % image_name) 107 | 108 | 109 | return precomp_data.astype(np.float32), precomp_data_clip.astype(np.float32), precomp_data_mask.astype(np.float32) 110 | 111 | 112 | 113 | 114 | class TextField(RawField): 115 | vocab_cls = Vocab 116 | # Dictionary mapping PyTorch tensor dtypes to the appropriate Python 117 | # numeric type. 
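    # numericalize() falls back to this mapping when use_vocab=False, using it to coerce raw string entries to Python ints/floats before the output tensor is built.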
118 | dtypes = { 119 | torch.float32: float, 120 | torch.float: float, 121 | torch.float64: float, 122 | torch.double: float, 123 | torch.float16: float, 124 | torch.half: float, 125 | 126 | torch.uint8: int, 127 | torch.int8: int, 128 | torch.int16: int, 129 | torch.short: int, 130 | torch.int32: int, 131 | torch.int: int, 132 | torch.int64: int, 133 | torch.long: int, 134 | } 135 | punctuations = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 136 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 137 | 138 | def __init__(self, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.long, 139 | preprocessing=None, postprocessing=None, lower=False, tokenize=(lambda s: s.split()), 140 | remove_punctuation=False, include_lengths=False, batch_first=True, pad_token="", 141 | unk_token="", pad_first=False, truncate_first=False, vectors=None, nopoints=True): 142 | self.use_vocab = use_vocab 143 | self.init_token = init_token 144 | self.eos_token = eos_token 145 | self.fix_length = fix_length 146 | self.dtype = dtype 147 | self.lower = lower 148 | self.tokenize = get_tokenizer(tokenize) 149 | self.remove_punctuation = remove_punctuation 150 | self.include_lengths = include_lengths 151 | self.batch_first = batch_first 152 | self.pad_token = pad_token 153 | self.unk_token = unk_token 154 | self.pad_first = pad_first 155 | self.truncate_first = truncate_first 156 | self.vocab = None 157 | self.vectors = vectors 158 | if nopoints: 159 | self.punctuations.append("..") 160 | 161 | super(TextField, self).__init__(preprocessing, postprocessing) 162 | 163 | def preprocess(self, x): 164 | 165 | if isinstance(x, list): 166 | x = str(x) 167 | 168 | if six.PY2 and isinstance(x, six.string_types) and not isinstance(x, six.text_type): 169 | x = six.text_type(x, encoding='utf-8') 170 | if self.lower: 171 | x = six.text_type.lower(x) 172 | x = self.tokenize(x.rstrip('\n')) 173 | if self.remove_punctuation: 174 | x = [w for w in x if w not in self.punctuations] 175 | if self.preprocessing is not None: 176 | return self.preprocessing(x) 177 | else: 178 | return x 179 | 180 | def process(self, batch, device=None): 181 | padded = self.pad(batch) 182 | tensor = self.numericalize(padded, device=device) 183 | return tensor 184 | 185 | def build_vocab(self, *args, **kwargs): 186 | counter = Counter() 187 | sources = [] 188 | for arg in args: 189 | if isinstance(arg, Dataset): 190 | # sources += [getattr(arg, name) for name, field in arg.fields.items() if field is self] 191 | for name, field in arg.fields.items(): 192 | if field is self: 193 | sources += [getattr(arg, name)] 194 | else: 195 | sources.append(arg) 196 | 197 | for data in sources: 198 | for x in data: 199 | x = self.preprocess(x) # 返回list 200 | try: 201 | counter.update(x) 202 | except TypeError: 203 | counter.update(chain.from_iterable(x)) 204 | 205 | specials = list(OrderedDict.fromkeys([ 206 | tok for tok in [self.unk_token, self.pad_token, self.init_token, 207 | self.eos_token] 208 | if tok is not None])) 209 | self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 210 | 211 | def pad(self, minibatch): 212 | """Pad a batch of examples using this field. 213 | Pads to self.fix_length if provided, otherwise pads to the length of 214 | the longest example in the batch. Prepends self.init_token and appends 215 | self.eos_token if those attributes are not None. 
Returns a tuple of the 216 | padded list and a list containing lengths of each example if 217 | `self.include_lengths` is `True`, else just 218 | returns the padded list. 219 | """ 220 | minibatch = list(minibatch) 221 | if self.fix_length is None: 222 | max_len = max(len(x) for x in minibatch) 223 | else: 224 | max_len = self.fix_length + ( 225 | self.init_token, self.eos_token).count(None) - 2 226 | padded, lengths = [], [] 227 | for x in minibatch: 228 | if self.pad_first: 229 | padded.append( 230 | [self.pad_token] * max(0, max_len - len(x)) + 231 | ([] if self.init_token is None else [self.init_token]) + 232 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 233 | ([] if self.eos_token is None else [self.eos_token])) 234 | else: 235 | padded.append( 236 | ([] if self.init_token is None else [self.init_token]) + 237 | list(x[-max_len:] if self.truncate_first else x[:max_len]) + 238 | ([] if self.eos_token is None else [self.eos_token]) + 239 | [self.pad_token] * max(0, max_len - len(x))) 240 | lengths.append(len(padded[-1]) - max(0, max_len - len(x))) 241 | if self.include_lengths: 242 | return padded, lengths 243 | return padded # 返回处理好的 batch text 10, 13 244 | 245 | def numericalize(self, arr, device=None): 246 | """Turn a batch of examples that use this field into a list of Variables. 247 | If the field has include_lengths=True, a tensor of lengths will be 248 | included in the return value. 249 | Arguments: 250 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 251 | List of tokenized and padded examples, or tuple of List of 252 | tokenized and padded examples and List of lengths of each 253 | example if self.include_lengths is True. 254 | device (str or torch.device): A string or instance of `torch.device` 255 | specifying which device the Variables are going to be created on. 256 | If left as default, the tensors will be created on cpu. Default: None. 257 | """ 258 | if self.include_lengths and not isinstance(arr, tuple): 259 | raise ValueError("Field has include_lengths set to True, but " 260 | "input data is not a tuple of " 261 | "(data batch, batch lengths).") 262 | if isinstance(arr, tuple): 263 | arr, lengths = arr 264 | lengths = torch.tensor(lengths, dtype=self.dtype, device=device) 265 | 266 | if self.use_vocab: 267 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 268 | 269 | if self.postprocessing is not None: 270 | arr = self.postprocessing(arr, self.vocab) 271 | 272 | var = torch.tensor(arr, dtype=self.dtype, device=device) 273 | else: 274 | if self.vectors: 275 | arr = [[self.vectors[x] for x in ex] for ex in arr] 276 | if self.dtype not in self.dtypes: 277 | raise ValueError( 278 | "Specified Field dtype {} can not be used with " 279 | "use_vocab=False because we do not know how to numericalize it. " 280 | "Please raise an issue at " 281 | "https://github.com/pytorch/text/issues".format(self.dtype)) 282 | numericalization_func = self.dtypes[self.dtype] 283 | # It doesn't make sense to explictly coerce to a numeric type if 284 | # the data is sequential, since it's unclear how to coerce padding tokens 285 | # to a numeric type. 
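            # Only entries that are still strings are coerced; values that are already numeric (e.g. produced by postprocessing) are passed through unchanged.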
286 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 287 | else x for x in arr] 288 | 289 | if self.postprocessing is not None: 290 | arr = self.postprocessing(arr, None) 291 | 292 | var = torch.cat([torch.cat([a.unsqueeze(0) for a in ar]).unsqueeze(0) for ar in arr]) 293 | 294 | # var = torch.tensor(arr, dtype=self.dtype, device=device) 295 | if not self.batch_first: 296 | var.t_() # 10 13的张量 297 | var = var.contiguous() 298 | 299 | if self.include_lengths: 300 | return var, lengths 301 | return var 302 | 303 | def decode(self, word_idxs, join_words=True): 304 | if isinstance(word_idxs, list) and len(word_idxs) == 0: 305 | return self.decode([word_idxs, ], join_words)[0] 306 | if isinstance(word_idxs, list) and isinstance(word_idxs[0], int): 307 | return self.decode([word_idxs, ], join_words)[0] 308 | elif isinstance(word_idxs, np.ndarray) and word_idxs.ndim == 1: 309 | return self.decode(word_idxs.reshape((1, -1)), join_words)[0] 310 | elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1: 311 | return self.decode(word_idxs.unsqueeze(0), join_words)[0] 312 | 313 | captions = [] 314 | for wis in word_idxs: 315 | caption = [] 316 | for wi in wis: 317 | word = self.vocab.itos[int(wi)] 318 | if word == self.eos_token: 319 | break 320 | caption.append(word) 321 | if join_words: 322 | caption = ' '.join(caption) 323 | captions.append(caption) 324 | 325 | return captions 326 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import collections 5 | import torch 6 | from .example import Example 7 | from .utils import nostdout 8 | import json 9 | 10 | 11 | class Dataset(object): 12 | def __init__(self, examples, fields): 13 | self.examples = examples 14 | self.fields = dict(fields) 15 | 16 | def collate_fn(self): 17 | def collate(batch): 18 | if len(self.fields) == 1: 19 | batch = [batch, ] 20 | else: 21 | batch = list(zip(*batch)) 22 | 23 | tensors = [] 24 | for field, data in zip(self.fields.values(), batch): 25 | tensor = field.process(data) 26 | if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor): 27 | tensors.extend(tensor) 28 | else: 29 | tensors.append(tensor) 30 | 31 | if len(tensors) > 1: 32 | return tensors 33 | else: 34 | return tensors[0] 35 | 36 | return collate 37 | 38 | def __getitem__(self, i): 39 | example = self.examples[i] 40 | data = [] 41 | for field_name, field in self.fields.items(): 42 | tem = getattr(example, field_name) 43 | data.append(field.preprocess(tem)) 44 | 45 | if len(data) == 1: 46 | data = data[0] 47 | return data 48 | 49 | def __len__(self): 50 | return len(self.examples) 51 | 52 | def __getattr__(self, attr): 53 | if attr in self.fields: 54 | for x in self.examples: 55 | yield getattr(x, attr) 56 | 57 | 58 | class ValueDataset(Dataset): 59 | def __init__(self, examples, fields, dictionary): 60 | self.dictionary = dictionary 61 | super(ValueDataset, self).__init__(examples, fields) 62 | 63 | def collate_fn(self): 64 | def collate(batch): 65 | value_batch_flattened = list(itertools.chain(*batch)) 66 | value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened) 67 | 68 | lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch])) 69 | if isinstance(value_tensors_flattened, collections.Sequence) \ 70 | and any(isinstance(t, torch.Tensor) for t in 
value_tensors_flattened): 71 | value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in 72 | value_tensors_flattened] 73 | else: 74 | value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] 75 | 76 | return value_tensors 77 | 78 | return collate 79 | 80 | def __getitem__(self, i): 81 | if i not in self.dictionary: 82 | raise IndexError 83 | 84 | values_data = [] 85 | for idx in self.dictionary[i]: 86 | value_data = super(ValueDataset, self).__getitem__(idx) 87 | values_data.append(value_data) 88 | return values_data 89 | 90 | def __len__(self): 91 | return len(self.dictionary) 92 | 93 | 94 | class DictionaryDataset(Dataset): 95 | def __init__(self, examples, fields, key_fields): 96 | if not isinstance(key_fields, (tuple, list)): 97 | key_fields = (key_fields,) 98 | for field in key_fields: 99 | assert (field in fields) 100 | 101 | dictionary = collections.defaultdict(list) 102 | key_fields = {k: fields[k] for k in key_fields} 103 | value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields} 104 | key_examples = [] 105 | key_dict = dict() 106 | value_examples = [] 107 | 108 | for i, e in enumerate(examples): 109 | key_example = Example.fromdict({k: getattr(e, k) for k in key_fields}) 110 | value_example = Example.fromdict({v: getattr(e, v) for v in value_fields}) 111 | if key_example not in key_dict: 112 | key_dict[key_example] = len(key_examples) 113 | key_examples.append(key_example) 114 | 115 | value_examples.append(value_example) 116 | dictionary[key_dict[key_example]].append(i) 117 | 118 | self.key_dataset = Dataset(key_examples, key_fields) 119 | self.value_dataset = ValueDataset(value_examples, value_fields, dictionary) 120 | super(DictionaryDataset, self).__init__(examples, fields) 121 | 122 | def collate_fn(self): 123 | def collate(batch): 124 | key_batch, value_batch = list(zip(*batch)) 125 | key_tensors = self.key_dataset.collate_fn()(key_batch) 126 | value_tensors = self.value_dataset.collate_fn()(value_batch) 127 | return key_tensors, value_tensors 128 | 129 | return collate 130 | 131 | def __getitem__(self, i): 132 | return self.key_dataset[i], self.value_dataset[i] 133 | 134 | def __len__(self): 135 | return len(self.key_dataset) 136 | 137 | 138 | def unique(sequence): 139 | seen = set() 140 | if isinstance(sequence[0], list): 141 | return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))] 142 | else: 143 | return [x for x in sequence if not (x in seen or seen.add(x))] 144 | 145 | 146 | class PairedDataset(Dataset): 147 | def __init__(self, examples, fields): 148 | assert ('image' in fields) 149 | assert ('text' in fields) 150 | super(PairedDataset, self).__init__(examples, fields) 151 | self.image_field = self.fields['image'] 152 | self.text_field = self.fields['text'] 153 | 154 | def image_set(self): 155 | img_list = [e.image for e in self.examples] 156 | image_set = unique(img_list) 157 | examples = [Example.fromdict({'image': i}) for i in image_set] 158 | dataset = Dataset(examples, {'image': self.image_field}) 159 | return dataset 160 | 161 | def text_set(self): 162 | text_list = [e.text for e in self.examples] 163 | text_list = unique(text_list) 164 | examples = [Example.fromdict({'text': t}) for t in text_list] 165 | dataset = Dataset(examples, {'text': self.text_field}) 166 | return dataset 167 | 168 | def image_dictionary(self, fields=None): 169 | if not fields: 170 | fields = self.fields 171 | dataset = DictionaryDataset(self.examples, fields, key_fields='image') 172 | 
return dataset 173 | 174 | def text_dictionary(self, fields=None): 175 | if not fields: 176 | fields = self.fields 177 | dataset = DictionaryDataset(self.examples, fields, key_fields='text') 178 | return dataset 179 | 180 | @property 181 | def splits(self): 182 | raise NotImplementedError 183 | 184 | 185 | class Sydney(PairedDataset): 186 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 187 | cut_validation=False): 188 | roots = {} 189 | img_root = os.path.join(img_root, '') 190 | roots['train'] = { 191 | 'img': os.path.join(img_root, ''), 192 | 'cap': os.path.join(ann_root, 'dataset.json') 193 | } 194 | roots['val'] = { 195 | 'img': os.path.join(img_root, ''), 196 | 'cap': os.path.join(ann_root, 'dataset.json') 197 | } 198 | roots['test'] = { 199 | 'img': os.path.join(img_root, ''), 200 | 'cap': os.path.join(ann_root, 'dataset.json') 201 | } 202 | ids = {'train': None, 'val': None, 'test': None} 203 | 204 | with nostdout(): 205 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root, 206 | ids) 207 | examples = self.train_examples + self.val_examples + self.test_examples 208 | super(Sydney, self).__init__(examples, {'image': image_field, 'text': text_field}) 209 | 210 | @property 211 | def splits(self): 212 | train_split = PairedDataset(self.train_examples, self.fields) 213 | val_split = PairedDataset(self.val_examples, self.fields) 214 | test_split = PairedDataset(self.test_examples, self.fields) 215 | return train_split, val_split, test_split 216 | 217 | @classmethod 218 | def get_samples(cls, img_root, ann_root, ids_dataset=None): 219 | train_samples = [] 220 | val_samples = [] 221 | test_samples = [] 222 | 223 | dataset = json.load(open(os.path.join(ann_root, 'dataset.json'), 'r'))['images'] 224 | index = list(range(613)) 225 | np.random.shuffle(index) 226 | # 497 58 58 227 | a1 = index[:497] 228 | a2 = index[497:555] 229 | a3 = index[555:613] 230 | for i, d in enumerate(dataset): 231 | if i in a1: 232 | for x in range(len(d['sentences'])): 233 | captions = d['sentences'][x]['raw'] 234 | filename = d['filename'] 235 | # imgid = d['imgid'] 236 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 237 | train_samples.append(example) 238 | elif i in a2: 239 | for x in range(len(d['sentences'])): 240 | captions = d['sentences'][x]['raw'] 241 | filename = d['filename'] 242 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 243 | val_samples.append(example) 244 | elif i in a3: 245 | for x in range(len(d['sentences'])): 246 | captions = d['sentences'][x]['raw'] 247 | filename = d['filename'] 248 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 249 | test_samples.append(example) 250 | 251 | return train_samples, val_samples, test_samples 252 | 253 | 254 | class UCM(PairedDataset): 255 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 256 | cut_validation=False): 257 | 258 | roots = {} 259 | img_root = os.path.join(img_root, '') 260 | roots['train'] = { 261 | 'img': os.path.join(img_root, ''), 262 | 'cap': os.path.join(ann_root, 'dataset.json') 263 | } 264 | roots['val'] = { 265 | 'img': os.path.join(img_root, ''), 266 | 'cap': os.path.join(ann_root, 'dataset.json') 267 | } 268 | roots['test'] = { 269 | 'img': os.path.join(img_root, ''), 270 | 'cap': os.path.join(ann_root, 'dataset.json') 271 | } 272 | ids = {'train': 
None, 'val': None, 'test': None} 273 | 274 | with nostdout(): 275 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root, ids) 276 | examples = self.train_examples + self.val_examples + self.test_examples # 三个数据集合在一起 277 | super(UCM, self).__init__(examples, {'image': image_field, 'text': text_field}) # 调用PairedDataset 278 | 279 | @property 280 | def splits(self): 281 | train_split = PairedDataset(self.train_examples, self.fields) 282 | val_split = PairedDataset(self.val_examples, self.fields) 283 | test_split = PairedDataset(self.test_examples, self.fields) 284 | return train_split, val_split, test_split 285 | 286 | @classmethod 287 | def get_samples(cls, img_root, ann_root, ids_dataset=None): 288 | train_samples = [] 289 | val_samples = [] 290 | test_samples = [] 291 | 292 | dataset = json.load(open(os.path.join(ann_root, 'dataset.json'), 'r'))['images'] 293 | index = list(range(2100)) 294 | np.random.shuffle(index) 295 | 296 | # UCM 1680 210 210 297 | a1 = index[:1680] 298 | a2 = index[1680:1890] 299 | a3 = index[1890:2100] 300 | 301 | for i, d in enumerate(dataset): 302 | if i in a1: 303 | for x in range(len(d['sentences'])): 304 | captions = d['sentences'][x]['raw'] 305 | filename = d['filename'] 306 | # imgid = d['imgid'] 307 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 308 | train_samples.append(example) 309 | elif i in a2: 310 | for x in range(len(d['sentences'])): 311 | captions = d['sentences'][x]['raw'] 312 | filename = d['filename'] 313 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 314 | val_samples.append(example) 315 | elif i in a3: 316 | for x in range(len(d['sentences'])): 317 | captions = d['sentences'][x]['raw'] 318 | filename = d['filename'] 319 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 320 | test_samples.append(example) 321 | 322 | return train_samples, val_samples, test_samples 323 | 324 | 325 | class RSICD(PairedDataset): 326 | def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True, 327 | cut_validation=False): 328 | 329 | roots = {} 330 | img_root = os.path.join(img_root, '') 331 | roots['train'] = { 332 | 'img': os.path.join(img_root, ''), 333 | 'cap': os.path.join(ann_root, 'dataset.json') 334 | } 335 | roots['val'] = { 336 | 'img': os.path.join(img_root, ''), 337 | 'cap': os.path.join(ann_root, 'dataset.json') 338 | } 339 | roots['test'] = { 340 | 'img': os.path.join(img_root, ''), 341 | 'cap': os.path.join(ann_root, 'dataset.json') 342 | } 343 | ids = {'train': None, 'val': None, 'test': None} 344 | 345 | with nostdout(): 346 | self.train_examples, self.val_examples, self.test_examples = self.get_samples(img_root, ann_root, ids) 347 | examples = self.train_examples + self.val_examples + self.test_examples 348 | super(RSICD, self).__init__(examples, {'image': image_field, 'text': text_field}) 349 | 350 | @property 351 | def splits(self): 352 | train_split = PairedDataset(self.train_examples, self.fields) 353 | val_split = PairedDataset(self.val_examples, self.fields) 354 | test_split = PairedDataset(self.test_examples, self.fields) 355 | return train_split, val_split, test_split 356 | 357 | @classmethod 358 | def get_samples(cls, img_root, ann_root, ids_dataset=None): 359 | train_samples = [] 360 | val_samples = [] 361 | test_samples = [] 362 | 363 | dataset = json.load(open(os.path.join(ann_root, 'dataset_rsicd.json'), 
'r'))['images'] 364 | index = list(range(10921)) 365 | np.random.shuffle(index) 366 | 367 | # 10921 8734 1094 1093 368 | a1 = index[:8734] 369 | a2 = index[8734:9828] 370 | a3 = index[9828:10921] 371 | for i, d in enumerate(dataset): 372 | if i in a1: 373 | for x in range(len(d['sentences'])): 374 | captions = d['sentences'][x]['raw'] 375 | filename = d['filename'] 376 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 377 | train_samples.append(example) 378 | elif i in a2: 379 | for x in range(len(d['sentences'])): 380 | captions = d['sentences'][x]['raw'] 381 | filename = d['filename'] 382 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 383 | val_samples.append(example) 384 | elif i in a3: 385 | for x in range(len(d['sentences'])): 386 | captions = d['sentences'][x]['raw'] 387 | filename = d['filename'] 388 | example = Example.fromdict({'image': os.path.join(img_root, str(filename)), 'text': captions}) 389 | test_samples.append(example) 390 | 391 | return train_samples, val_samples, test_samples 392 | 393 | -------------------------------------------------------------------------------- /data/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from .utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Vocab(object): 22 | """Defines a vocabulary object that will be used to numericalize a field. 23 | 24 | Attributes: 25 | freqs: A collections.Counter object holding the frequencies of tokens 26 | in the data used to build the Vocab. 27 | stoi: A collections.defaultdict instance mapping token strings to 28 | numerical identifiers. 29 | itos: A list of token strings indexed by their numerical identifiers. 30 | """ 31 | def __init__(self, counter, max_size=None, min_freq=1, specials=[''], 32 | vectors=None, unk_init=None, vectors_cache=None): 33 | """Create a Vocab object from a collections.Counter. 34 | 35 | Arguments: 36 | counter: collections.Counter object holding the frequencies of 37 | each value found in the data. 38 | max_size: The maximum size of the vocabulary, or None for no 39 | maximum. Default: None. 40 | min_freq: The minimum frequency needed to include a token in the 41 | vocabulary. Values less than 1 will be set to 1. Default: 1. 42 | specials: The list of special tokens (e.g., padding or eos) that 43 | will be prepended to the vocabulary in addition to an 44 | token. Default: [''] 45 | vectors: One of either the available pretrained vectors 46 | or custom pretrained vectors (see Vocab.load_vectors); 47 | or a list of aforementioned vectors 48 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 49 | to zero vectors; can be any function that takes in a Tensor and 50 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 51 | vectors_cache: directory for cached vectors. 
Default: '.vector_cache' 52 | """ 53 | self.freqs = counter 54 | counter = counter.copy() 55 | min_freq = max(min_freq, 1) 56 | 57 | self.itos = list(specials) 58 | # frequencies of special tokens are not counted when building vocabulary 59 | # in frequency order 60 | for tok in specials: 61 | del counter[tok] 62 | 63 | max_size = None if max_size is None else max_size + len(self.itos) 64 | 65 | # sort by frequency, then alphabetically 66 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 67 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 68 | 69 | for word, freq in words_and_frequencies: 70 | if freq < min_freq or len(self.itos) == max_size: 71 | break 72 | self.itos.append(word) 73 | 74 | self.stoi = defaultdict(_default_unk_index) 75 | # stoi is simply a reverse dict for itos 76 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 77 | 78 | self.vectors = None 79 | if vectors is not None: 80 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 81 | else: 82 | assert unk_init is None and vectors_cache is None 83 | 84 | def __eq__(self, other): 85 | if self.freqs != other.freqs: 86 | return False 87 | if self.stoi != other.stoi: 88 | return False 89 | if self.itos != other.itos: 90 | return False 91 | if self.vectors != other.vectors: 92 | return False 93 | return True 94 | 95 | def __len__(self): 96 | return len(self.itos) 97 | 98 | def extend(self, v, sort=False): 99 | words = sorted(v.itos) if sort else v.itos 100 | for w in words: 101 | if w not in self.stoi: 102 | self.itos.append(w) 103 | self.stoi[w] = len(self.itos) - 1 104 | 105 | def load_vectors(self, vectors, **kwargs): 106 | """ 107 | Arguments: 108 | vectors: one of or a list containing instantiations of the 109 | GloVe, CharNGram, or Vectors classes. Alternatively, one 110 | of or a list of available pretrained vectors: 111 | charngram.100d 112 | fasttext.en.300d 113 | fasttext.simple.300d 114 | glove.42B.300d 115 | glove.840B.300d 116 | glove.twitter.27B.25d 117 | glove.twitter.27B.50d 118 | glove.twitter.27B.100d 119 | glove.twitter.27B.200d 120 | glove.6B.50d 121 | glove.6B.100d 122 | glove.6B.200d 123 | glove.6B.300d 124 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 
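            Example (illustrative): vocab.load_vectors('glove.6B.300d') or vocab.load_vectors(GloVe(name='6B', dim=300)).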
125 | """ 126 | if not isinstance(vectors, list): 127 | vectors = [vectors] 128 | for idx, vector in enumerate(vectors): 129 | if six.PY2 and isinstance(vector, str): 130 | vector = six.text_type(vector) 131 | if isinstance(vector, six.string_types): 132 | # Convert the string pretrained vector identifier 133 | # to a Vectors object 134 | if vector not in pretrained_aliases: 135 | raise ValueError( 136 | "Got string input vector {}, but allowed pretrained " 137 | "vectors are {}".format( 138 | vector, list(pretrained_aliases.keys()))) 139 | vectors[idx] = pretrained_aliases[vector](**kwargs) 140 | elif not isinstance(vector, Vectors): 141 | raise ValueError( 142 | "Got input vectors of type {}, expected str or " 143 | "Vectors object".format(type(vector))) 144 | 145 | tot_dim = sum(v.dim for v in vectors) 146 | self.vectors = torch.Tensor(len(self), tot_dim) 147 | for i, token in enumerate(self.itos): 148 | start_dim = 0 149 | for v in vectors: 150 | end_dim = start_dim + v.dim 151 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 152 | start_dim = end_dim 153 | assert(start_dim == tot_dim) 154 | 155 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 156 | """ 157 | Set the vectors for the Vocab instance from a collection of Tensors. 158 | 159 | Arguments: 160 | stoi: A dictionary of string to the index of the associated vector 161 | in the `vectors` input argument. 162 | vectors: An indexed iterable (or other structure supporting __getitem__) that 163 | given an input index, returns a FloatTensor representing the vector 164 | for the token associated with the index. For example, 165 | vector[stoi["string"]] should return the vector for "string". 166 | dim: The dimensionality of the vectors. 167 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 168 | to zero vectors; can be any function that takes in a Tensor and 169 | returns a Tensor of the same size. 
Default: torch.Tensor.zero_ 170 | """ 171 | self.vectors = torch.Tensor(len(self), dim) 172 | for i, token in enumerate(self.itos): 173 | wv_index = stoi.get(token, None) 174 | if wv_index is not None: 175 | self.vectors[i] = vectors[wv_index] 176 | else: 177 | self.vectors[i] = unk_init(self.vectors[i]) 178 | 179 | 180 | class Vectors(object): 181 | 182 | def __init__(self, name, cache=None, 183 | url=None, unk_init=None): 184 | """ 185 | Arguments: 186 | name: name of the file that contains the vectors 187 | cache: directory for cached vectors 188 | url: url for download if vectors not found in cache 189 | unk_init (callback): by default, initalize out-of-vocabulary word vectors 190 | to zero vectors; can be any function that takes in a Tensor and 191 | returns a Tensor of the same size 192 | """ 193 | cache = '.vector_cache' if cache is None else cache 194 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 195 | self.cache(name, cache, url=url) 196 | 197 | def __getitem__(self, token): 198 | if token in self.stoi: 199 | return self.vectors[self.stoi[token]] 200 | else: 201 | return self.unk_init(torch.Tensor(self.dim)) # self.unk_init(torch.Tensor(1, self.dim)) 202 | 203 | def cache(self, name, cache, url=None): 204 | if os.path.isfile(name): 205 | path = name 206 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 207 | else: 208 | path = os.path.join(cache, name) 209 | path_pt = path + '.pt' 210 | 211 | if not os.path.isfile(path_pt): 212 | if not os.path.isfile(path) and url: 213 | logger.info('Downloading vectors from {}'.format(url)) 214 | if not os.path.exists(cache): 215 | os.makedirs(cache) 216 | dest = os.path.join(cache, os.path.basename(url)) 217 | if not os.path.isfile(dest): 218 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 219 | try: 220 | urlretrieve(url, dest, reporthook=reporthook(t)) 221 | except KeyboardInterrupt as e: # remove the partial zip file 222 | os.remove(dest) 223 | raise e 224 | logger.info('Extracting vectors into {}'.format(cache)) 225 | ext = os.path.splitext(dest)[1][1:] 226 | if ext == 'zip': 227 | with zipfile.ZipFile(dest, "r") as zf: 228 | zf.extractall(cache) 229 | elif ext == 'gz': 230 | with tarfile.open(dest, 'r:gz') as tar: 231 | tar.extractall(path=cache) 232 | if not os.path.isfile(path): 233 | raise RuntimeError('no vectors found at {}'.format(path)) 234 | 235 | # str call is necessary for Python 2/3 compatibility, since 236 | # argument must be Python 2 str (Python 3 bytes) or 237 | # Python 3 str (Python 2 unicode) 238 | itos, vectors, dim = [], array.array(str('d')), None 239 | 240 | # Try to read the whole file with utf-8 encoding. 241 | binary_lines = False 242 | try: 243 | with io.open(path, encoding="utf8") as f: 244 | lines = [line for line in f] 245 | # If there are malformed lines, read in binary mode 246 | # and manually decode each word from utf-8 247 | except: 248 | logger.warning("Could not read {} as UTF8 file, " 249 | "reading file as bytes and skipping " 250 | "words with malformed UTF8.".format(path)) 251 | with open(path, 'rb') as f: 252 | lines = [line for line in f] 253 | binary_lines = True 254 | 255 | logger.info("Loading vectors from {}".format(path)) 256 | for line in tqdm(lines, total=len(lines)): 257 | # Explicitly splitting on " " is important, so we don't 258 | # get rid of Unicode non-breaking spaces in the vectors. 
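                # Each non-header line is expected to look like '<token> <v_1> ... <v_dim>': the first field is the word, the remaining fields are its vector components.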
259 | entries = line.rstrip().split(b" " if binary_lines else " ") 260 | 261 | word, entries = entries[0], entries[1:] 262 | if dim is None and len(entries) > 1: 263 | dim = len(entries) 264 | elif len(entries) == 1: 265 | logger.warning("Skipping token {} with 1-dimensional " 266 | "vector {}; likely a header".format(word, entries)) 267 | continue 268 | elif dim != len(entries): 269 | raise RuntimeError( 270 | "Vector for token {} has {} dimensions, but previously " 271 | "read vectors have {} dimensions. All vectors must have " 272 | "the same number of dimensions.".format(word, len(entries), dim)) 273 | 274 | if binary_lines: 275 | try: 276 | if isinstance(word, six.binary_type): 277 | word = word.decode('utf-8') 278 | except: 279 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 280 | continue 281 | vectors.extend(float(x) for x in entries) 282 | itos.append(word) 283 | 284 | self.itos = itos 285 | self.stoi = {word: i for i, word in enumerate(itos)} 286 | self.vectors = torch.Tensor(vectors).view(-1, dim) 287 | self.dim = dim 288 | logger.info('Saving vectors to {}'.format(path_pt)) 289 | if not os.path.exists(cache): 290 | os.makedirs(cache) 291 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 292 | else: 293 | logger.info('Loading vectors from {}'.format(path_pt)) 294 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 295 | 296 | 297 | class GloVe(Vectors): 298 | url = { 299 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 300 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 301 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 302 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 303 | } 304 | 305 | def __init__(self, name='840B', dim=300, **kwargs): 306 | url = self.url[name] 307 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 308 | super(GloVe, self).__init__(name, url=url, **kwargs) 309 | 310 | 311 | class FastText(Vectors): 312 | 313 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 314 | 315 | def __init__(self, language="en", **kwargs): 316 | url = self.url_base.format(language) 317 | name = os.path.basename(url) 318 | super(FastText, self).__init__(name, url=url, **kwargs) 319 | 320 | 321 | class CharNGram(Vectors): 322 | 323 | name = 'charNgram.txt' 324 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 325 | 'jmt_pre-trained_embeddings.tar.gz') 326 | 327 | def __init__(self, **kwargs): 328 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 329 | 330 | def __getitem__(self, token): 331 | vector = torch.Tensor(1, self.dim).zero_() 332 | if token == "": 333 | return self.unk_init(vector) 334 | # These literals need to be coerced to unicode for Python 2 compatibility 335 | # when we try to join them with read ngrams from the files. 
336 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 337 | num_vectors = 0 338 | for n in [2, 3, 4]: 339 | end = len(chars) - n + 1 340 | grams = [chars[i:(i + n)] for i in range(end)] 341 | for gram in grams: 342 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 343 | if gram_key in self.stoi: 344 | vector += self.vectors[self.stoi[gram_key]] 345 | num_vectors += 1 346 | if num_vectors > 0: 347 | vector /= num_vectors 348 | else: 349 | vector = self.unk_init(vector) 350 | return vector 351 | 352 | 353 | def _default_unk_index(): 354 | return 0 355 | 356 | 357 | pretrained_aliases = { 358 | "charngram.100d": partial(CharNGram), 359 | "fasttext.en.300d": partial(FastText, language="en"), 360 | "fasttext.simple.300d": partial(FastText, language="simple"), 361 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 362 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 363 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 364 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 365 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 366 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 367 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 368 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 369 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 370 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 371 | } 372 | """Mapping from string name to factory function""" 373 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torch.backends import cudnn 4 | 5 | from data import ImageDetectionsField, TextField, RawField 6 | from data import DataLoader, Sydney, UCM, RSICD 7 | import evaluation 8 | from evaluation import PTBTokenizer, Cider 9 | from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, GlobalGroupingAttention_with_DC 10 | import torch 11 | from torch.optim import Adam 12 | from torch.optim.lr_scheduler import LambdaLR 13 | from torch.nn import NLLLoss 14 | from tqdm import tqdm 15 | from torch.utils.tensorboard import SummaryWriter 16 | import argparse, os, pickle 17 | import numpy as np 18 | import itertools 19 | import multiprocessing 20 | from shutil import copyfile 21 | 22 | import warnings 23 | 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | 28 | def seed_torch(seed=1): 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) # disable hash randomization so that experiments are reproducible 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | seed_torch() 37 | 38 | 39 | def evaluate_loss(model, dataloader, loss_fn, text_field): 40 | # Validation loss 41 | model.eval() 42 | running_loss = .0 43 | with tqdm(desc='Epoch %d - validation' % e, unit='it', total=len(dataloader)) as pbar: 44 | with torch.no_grad(): 45 | for it, (detections, detections_gl, detections_mask, captions) in enumerate(dataloader): 46 | detections, detections_gl, detections_mask, captions = detections.to(device), detections_gl.to(device), \ 47 | detections_mask.to(device), captions.to(device) 48 | out = model(detections, detections_gl, detections_mask, captions, isencoder=True) 49 | captions = captions[:, 1:].contiguous() 50 | out = out[:, :-1].contiguous() 51 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions.view(-1))
52 | this_loss = loss.item() 53 | running_loss += this_loss 54 | pbar.set_postfix(loss=running_loss / (it + 1)) 55 | pbar.update() 56 | val_loss = running_loss / len(dataloader) 57 | return val_loss 58 | 59 | 60 | def evaluate_metrics(model, dataloader, text_field): 61 | import itertools 62 | model.eval() 63 | gen = {} 64 | gts = {} 65 | with tqdm(desc='Epoch %d - evaluation' % e, unit='it', total=len(dataloader)) as pbar: 66 | for it, (images, caps_gt) in enumerate(dataloader): 67 | detections = images[0].to(device) 68 | detections_gl = images[1].to(device) 69 | detections_mask = images[2].to(device) 70 | 71 | with torch.no_grad(): 72 | out, _ = model.beam_search(detections, detections_gl, detections_mask, 20, 73 | text_field.vocab.stoi['<eos>'], 5, out_size=1) 74 | caps_gen = text_field.decode(out, join_words=False) 75 | for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)): 76 | gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)]) 77 | gen['%d_%d' % (it, i)] = [gen_i, ] 78 | gts['%d_%d' % (it, i)] = gts_i 79 | pbar.update() 80 | 81 | gts = evaluation.PTBTokenizer.tokenize(gts) 82 | gen = evaluation.PTBTokenizer.tokenize(gen) 83 | scores, _ = evaluation.compute_scores(gts, gen) 84 | return scores 85 | 86 | 87 | def train_xe(model, dataloader, optim, text_field): 88 | # Training with cross-entropy 89 | model.train() 90 | scheduler.step() 91 | running_loss = .0 92 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar: 93 | for it, (detections, detections_gl, detections_mask, captions) in enumerate( 94 | dataloader): 95 | 96 | detections, detections_gl, detections_mask, captions = detections.to(device), detections_gl.to(device), \ 97 | detections_mask.to(device), captions.to(device) 98 | out = model(detections, detections_gl, detections_mask, captions, isencoder=True) 99 | optim.zero_grad() 100 | captions_gt = captions[:, 1:].contiguous() 101 | out = out[:, :-1].contiguous() 102 | loss = loss_fn(out.view(-1, len(text_field.vocab)), captions_gt.view(-1)) 103 | loss.backward() 104 | 105 | optim.step() 106 | this_loss = loss.item() 107 | running_loss += this_loss 108 | 109 | pbar.set_postfix(loss=running_loss / (it + 1)) 110 | pbar.update() 111 | scheduler.step() 112 | 113 | loss = running_loss / len(dataloader) 114 | return loss 115 | 116 | 117 | def train_scst(model, dataloader, optim, cider, text_field): 118 | # Training with self-critical 119 | tokenizer_pool = multiprocessing.Pool() 120 | running_reward = .0 121 | running_reward_baseline = .0 122 | model.train() 123 | running_loss = .0 124 | seq_len = 20 125 | beam_size = 5 126 | 127 | with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar: 128 | for it, (images, caps_gt) in enumerate(dataloader): 129 | detections = images[0].to(device) 130 | detections_gl = images[1].to(device) 131 | detections_mask = images[2].to(device) 132 | 133 | outs, log_probs = model.beam_search(detections, detections_gl, detections_mask, 20, 134 | text_field.vocab.stoi['<eos>'], beam_size, out_size=beam_size) 135 | 136 | optim.zero_grad() 137 | 138 | # Rewards 139 | caps_gen = text_field.decode(outs.view(-1, seq_len)) 140 | caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt))) 141 | caps_gen, caps_gt = tokenizer_pool.map(evaluation.PTBTokenizer.tokenize, [caps_gen, caps_gt]) 142 | reward = cider.compute_score(caps_gt, caps_gen)[1].astype(np.float32) 143 | reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size) 144 | reward_baseline = torch.mean(reward, -1, keepdim=True)
145 | loss = -torch.mean(log_probs, -1) * (reward - reward_baseline) 146 | 147 | loss = loss.mean() 148 | loss.backward() 149 | optim.step() 150 | 151 | running_loss += loss.item() 152 | running_reward += reward.mean().item() 153 | running_reward_baseline += reward_baseline.mean().item() 154 | pbar.set_postfix(loss=running_loss / (it + 1), reward=running_reward / (it + 1), 155 | reward_baseline=running_reward_baseline / (it + 1)) 156 | pbar.update() 157 | 158 | loss = running_loss / len(dataloader) 159 | reward = running_reward / len(dataloader) 160 | reward_baseline = running_reward_baseline / len(dataloader) 161 | return loss, reward, reward_baseline 162 | 163 | 164 | 165 | if __name__ == '__main__': 166 | device = torch.device('cuda') 167 | parser = argparse.ArgumentParser(description='MG-Transformer') 168 | parser.add_argument('--exp_name', type=str, default='Sydney') # Sydney,UCM,RSICD 169 | parser.add_argument('--batch_size', type=int, default=50) 170 | parser.add_argument('--workers', type=int, default=0) 171 | parser.add_argument('--head', type=int, default=8) 172 | parser.add_argument('--warmup', type=int, default=10000) # 1000 173 | parser.add_argument('--warm_up_epochs', type=int, default=10) # 10 174 | parser.add_argument('--epochs', type=int, default=100) 175 | parser.add_argument('--resume_last', action='store_true') 176 | parser.add_argument('--resume_best', action='store_true') 177 | ################################################################################################################ 178 | parser.add_argument('--annotation_folder', type=str, 179 | default='/media/dmd/ours/mlw/rs/Sydney_Captions') 180 | parser.add_argument('--res_features_path', type=str, 181 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_res152_con5_224') 182 | 183 | parser.add_argument('--clip_features_path', type=str, 184 | default='/media/dmd/ours/mlw/rs/clip_feature/sydney_224') 185 | 186 | parser.add_argument('--mask_features_path', type=str, 187 | default='/media/dmd/ours/mlw/rs/multi_scale_T/sydney_group_mask_8_6') 188 | 189 | 190 | parser.add_argument('--logs_folder', type=str, default='tensorboard_logs') 191 | args = parser.parse_args() 192 | print(args) 193 | 194 | print('Transformer Training') 195 | # TensorBoard logging 196 | writer = SummaryWriter(log_dir=os.path.join(args.logs_folder, args.exp_name)) 197 | 198 | # Pipeline for image regions 199 | image_field = ImageDetectionsField(clip_features_path=args.clip_features_path, 200 | res_features_path=args.res_features_path, 201 | mask_features_path=args.mask_features_path) 202 | 203 | # Pipeline for text 204 | text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', 205 | remove_punctuation=True, nopoints=False) 206 | 207 | # Create the dataset 208 | if args.exp_name == 'Sydney': 209 | dataset = Sydney(image_field, text_field, 'Sydney/images/', args.annotation_folder, args.annotation_folder) 210 | elif args.exp_name == 'UCM': 211 | dataset = UCM(image_field, text_field, 'UCM/images/', args.annotation_folder, args.annotation_folder) 212 | elif args.exp_name == 'RSICD': 213 | dataset = RSICD(image_field, text_field, 'RSICD/images/', args.annotation_folder, args.annotation_folder) 214 | 215 | 216 | train_dataset, val_dataset, test_dataset = dataset.splits 217 | 218 | 219 | if not os.path.isfile('vocab_%s.pkl' % args.exp_name): 220 | print("Building vocabulary") 221 | text_field.build_vocab(train_dataset, val_dataset, min_freq=5) 222 | pickle.dump(text_field.vocab, open('vocab_%s.pkl' % args.exp_name, 'wb'))
223 | else: 224 | text_field.vocab = pickle.load(open('vocab_%s.pkl' % args.exp_name, 'rb')) 225 | 226 | 227 | encoder = MemoryAugmentedEncoder(3, 0, attention_module=GlobalGroupingAttention_with_DC) 228 | decoder = MeshedDecoder(len(text_field.vocab), 127, 3, text_field.vocab.stoi['<pad>']) 229 | model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device) 230 | 231 | dict_dataset_train = train_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 232 | ref_caps_train = list(train_dataset.text) 233 | cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train)) 234 | dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 235 | dict_dataset_test = test_dataset.image_dictionary({'image': image_field, 'text': RawField()}) 236 | 237 | 238 | def lambda_lr(s): 239 | warm_up = args.warmup 240 | s += 1 241 | return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -1.5) 242 | # return (model.d_model ** -.5) * min(s ** -.5, s * warm_up ** -2) 243 | 244 | 245 | 246 | # Initial conditions 247 | optim = Adam(model.parameters(), lr=1, betas=(0.9, 0.98)) 248 | scheduler = LambdaLR(optim, lambda_lr) 249 | loss_fn = NLLLoss(ignore_index=text_field.vocab.stoi['<pad>']) 250 | use_rl = False 251 | best_cider = .0 252 | best_test_cider = .0 253 | patience = 0 254 | start_epoch = 0 255 | 256 | if args.resume_last or args.resume_best: 257 | if args.resume_last: 258 | fname = 'saved_models/%s_last.pth' % args.exp_name 259 | else: 260 | fname = 'saved_models/%s_best.pth' % args.exp_name 261 | 262 | if os.path.exists(fname): 263 | data = torch.load(fname) 264 | torch.set_rng_state(data['torch_rng_state']) 265 | torch.cuda.set_rng_state(data['cuda_rng_state']) 266 | np.random.set_state(data['numpy_rng_state']) 267 | random.setstate(data['random_rng_state']) 268 | model.load_state_dict(data['state_dict'], strict=False) 269 | optim.load_state_dict(data['optimizer']) 270 | scheduler.load_state_dict(data['scheduler']) 271 | start_epoch = data['epoch'] + 1 272 | best_cider = data['best_cider'] 273 | best_test_cider = data['best_test_cider'] 274 | patience = data['patience'] 275 | use_rl = data['use_rl'] 276 | # print('Resuming from epoch %d, validation loss %f, best cider %f, and best_test_cider %f' % ( 277 | # data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider'])) 278 | 279 | print("Training starts") 280 | 281 | for e in range(start_epoch, start_epoch + 100): 282 | dataloader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 283 | drop_last=True) 284 | dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 285 | dict_dataloader_train = DataLoader(dict_dataset_train, batch_size=args.batch_size // 5, shuffle=True, 286 | num_workers=args.workers) 287 | dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size // 5) 288 | dict_dataloader_test = DataLoader(dict_dataset_test, batch_size=args.batch_size // 5) 289 | if not use_rl: 290 | train_loss = train_xe(model, dataloader_train, optim, text_field) 291 | writer.add_scalar('data/train_loss', train_loss, e) 292 | else: 293 | train_loss, reward, reward_baseline = train_scst(model, dict_dataloader_train, optim, cider_train, 294 | text_field) 295 | writer.add_scalar('data/train_loss', train_loss, e) 296 | writer.add_scalar('data/reward', reward, e) 297 | writer.add_scalar('data/reward_baseline', reward_baseline, e) 298 | 299 |
300 | # Validation loss 301 | val_loss = evaluate_loss(model, dataloader_val, loss_fn, text_field) 302 | writer.add_scalar('data/val_loss', val_loss, e) 303 | 304 | # Validation scores 305 | scores = evaluate_metrics(model, dict_dataloader_val, text_field) 306 | # print("Validation scores", scores) 307 | val_cider = scores['CIDEr'] 308 | writer.add_scalar('data/val_cider', val_cider, e) 309 | writer.add_scalar('data/val_bleu1', scores['BLEU'][0], e) 310 | writer.add_scalar('data/val_bleu2', scores['BLEU'][1], e) 311 | writer.add_scalar('data/val_bleu3', scores['BLEU'][2], e) 312 | writer.add_scalar('data/val_bleu4', scores['BLEU'][3], e) 313 | writer.add_scalar('data/val_meteor', scores['METEOR'], e) 314 | writer.add_scalar('data/val_rouge', scores['ROUGE'], e) 315 | writer.add_scalar('data/val_spice', scores['SPICE'], e) 316 | writer.add_scalar('data/val_S*', 317 | (scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr']) / 4, e) 318 | writer.add_scalar('data/val_Sm', ( 319 | scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr'] + scores['SPICE']) / 5, e) 320 | 321 | # Test scores 322 | scores = evaluate_metrics(model, dict_dataloader_test, text_field) 323 | print("Test scores", scores) 324 | test_cider = scores['CIDEr'] 325 | writer.add_scalar('data/test_cider', scores['CIDEr'], e) 326 | writer.add_scalar('data/test_bleu1', scores['BLEU'][0], e) 327 | writer.add_scalar('data/test_bleu2', scores['BLEU'][1], e) 328 | writer.add_scalar('data/test_bleu3', scores['BLEU'][2], e) 329 | writer.add_scalar('data/test_bleu4', scores['BLEU'][3], e) 330 | writer.add_scalar('data/test_meteor', scores['METEOR'], e) 331 | writer.add_scalar('data/test_rouge', scores['ROUGE'], e) 332 | writer.add_scalar('data/test_spice', scores['SPICE'], e) 333 | writer.add_scalar('data/test_S*', 334 | (scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr']) / 4, e) 335 | writer.add_scalar('data/test_Sm', ( 336 | scores['BLEU'][3] + scores['METEOR'] + scores['ROUGE'] + scores['CIDEr'] + scores['SPICE']) / 5, e) 337 | 338 | 339 | 340 | # Prepare for next epoch 341 | best = False 342 | if val_cider >= best_cider: 343 | best_cider = val_cider 344 | patience = 0 345 | best = True 346 | else: 347 | patience += 1 348 | 349 | best_test = False 350 | if test_cider >= best_test_cider: 351 | best_test_cider = test_cider 352 | best_test = True 353 | 354 | switch_to_rl = False 355 | exit_train = False 356 | if patience == 5: 357 | if not use_rl: 358 | use_rl = True 359 | switch_to_rl = True 360 | patience = 0 361 | optim = Adam(model.parameters(), lr=5e-6) 362 | # print("Switching to RL") 363 | else: 364 | print('patience reached.') 365 | exit_train = True 366 | 367 | 368 | if switch_to_rl and not best: 369 | data = torch.load('saved_models/%s_best.pth' % args.exp_name) 370 | torch.set_rng_state(data['torch_rng_state']) 371 | torch.cuda.set_rng_state(data['cuda_rng_state']) 372 | np.random.set_state(data['numpy_rng_state']) 373 | random.setstate(data['random_rng_state']) 374 | model.load_state_dict(data['state_dict']) 375 | # print('Resuming from epoch %d, validation loss %f, best_cider %f, and best test_cider %f' % ( 376 | # data['epoch'], data['val_loss'], data['best_cider'], data['best_test_cider'])) 377 | 378 | 379 | torch.save({ 380 | 'torch_rng_state': torch.get_rng_state(), 381 | 'cuda_rng_state': torch.cuda.get_rng_state(), 382 | 'numpy_rng_state': np.random.get_state(), 383 | 'random_rng_state': random.getstate(), 384 | 'epoch': e, 385 | 'val_loss': val_loss, 386 | 
'val_cider': val_cider, 387 | 'state_dict': model.state_dict(), 388 | 'optimizer': optim.state_dict(), 389 | 'scheduler': scheduler.state_dict(), 390 | 'patience': patience, 391 | 'best_cider': best_cider, 392 | 'best_test_cider': best_test_cider, 393 | 'use_rl': use_rl, 394 | }, 'saved_models/%s_last.pth' % args.exp_name) 395 | 396 | if best: 397 | copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best.pth' % args.exp_name) 398 | 399 | 400 | if best_test: 401 | copyfile('saved_models/%s_last.pth' % args.exp_name, 'saved_models/%s_best_test.pth' % args.exp_name) 402 | 403 | if exit_train: 404 | writer.close() 405 | break 406 | 407 | --------------------------------------------------------------------------------
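
A minimal usage sketch for the vocabulary and pretrained-vector classes defined in data/vocab.py above. The import path data.vocab, the GloVe 6B/300d choice, and the reuse of the shipped vocab_UCM.pkl are assumptions made for illustration only:

    import pickle
    import torch
    from data.vocab import GloVe  # assumed import path for the GloVe class shown above

    # A Vocab built by train.py; vocab_UCM.pkl ships with the repository.
    vocab = pickle.load(open('vocab_UCM.pkl', 'rb'))

    # Attach pretrained embeddings; this downloads and caches glove.6B.300d.txt under .vector_cache/.
    vocab.load_vectors(GloVe(name='6B', dim=300))  # the alias string 'glove.6B.300d' would work as well

    # vocab.vectors is a (len(vocab), 300) tensor with one row per token and zero rows for OOV words.
    embedding = torch.nn.Embedding.from_pretrained(vocab.vectors, freeze=False)
    print(embedding.weight.shape)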
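
The learning-rate schedule in train.py follows the usual transformer warm-up: Adam is created with lr=1 and LambdaLR rescales it by (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5), so the rate rises linearly for the first warmup steps and then decays as step ** -0.5. A small self-contained sketch of that curve, assuming d_model = 512 (train.py actually reads model.d_model):

    # Stand-alone sketch of lambda_lr from train.py; d_model = 512 is an assumption.
    d_model, warmup = 512, 10000

    def lambda_lr(s):
        s += 1
        return (d_model ** -0.5) * min(s ** -0.5, s * warmup ** -1.5)

    for step in (0, 1000, 10000, 40000):
        print(step, lambda_lr(step))  # effective learning rate at a few selected steps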
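
train_scst implements self-critical sequence training: each image is decoded with a beam of 5, every beam caption receives a CIDEr reward, and the mean reward over the beams of the same image serves as the baseline, so captions scoring above that average have their log-probabilities reinforced. A minimal numeric sketch with made-up tensors standing in for the beam_search outputs and CIDEr scores:

    import torch

    batch_size, beam_size, seq_len = 2, 5, 20
    log_probs = torch.randn(batch_size, beam_size, seq_len)  # stand-in for beam_search log-probabilities
    reward = torch.rand(batch_size, beam_size)               # stand-in for per-caption CIDEr rewards

    reward_baseline = torch.mean(reward, -1, keepdim=True)   # mean reward over the beams of each image
    loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
    print(loss.mean())                                       # scalar loss that train_scst backpropagates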
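
To train on one of the three datasets, train.py is launched with the dataset name and the directories of the precomputed ResNet, CLIP and group-mask features; the paths below are placeholders, not paths from the repository:

    python train.py --exp_name UCM --batch_size 50 --workers 0 \
        --annotation_folder /path/to/UCM_captions \
        --res_features_path /path/to/ucm_res152_features \
        --clip_features_path /path/to/ucm_clip_features \
        --mask_features_path /path/to/ucm_group_masks

Adding --resume_last (or --resume_best) reloads saved_models/UCM_last.pth (or saved_models/UCM_best.pth) together with its optimizer, scheduler and RNG states before training continues.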