├── lib ├── sa │ ├── modules │ │ ├── __init__.py │ │ ├── subtraction.py │ │ ├── aggregation.py │ │ └── subtraction2.py │ ├── functions │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── subtraction_zeropad.py │ │ ├── subtraction_refpad.py │ │ ├── subtraction2_zeropad.py │ │ ├── aggregation_zeropad.py │ │ ├── subtraction2_refpad.py │ │ └── aggregation_refpad.py │ └── functional.py ├── mmcv_custom │ ├── __init__.py │ └── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── checkpoint.cpython-37.pyc ├── _utils.py ├── mask_predictor.py ├── segmentation.py ├── text_aware_multiscale_enhancement.py └── transformer.py ├── requirements.txt ├── bert ├── activations.py ├── configuration_bert.py └── configuration_utils.py ├── loss └── loss.py ├── README.md ├── args.py ├── transforms.py ├── utils.py ├── test.py ├── data ├── refsegrs_refer_bert.py └── rrsisd_refer_bert.py ├── refer └── refer.py └── train.py /lib/sa/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /lib/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ['load_checkpoint'] 6 | -------------------------------------------------------------------------------- /lib/mmcv_custom/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shaosifan/FIANet/HEAD/lib/mmcv_custom/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/mmcv_custom/__pycache__/checkpoint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shaosifan/FIANet/HEAD/lib/mmcv_custom/__pycache__/checkpoint.cpython-37.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | filelock 3 | tqdm 4 | timm 5 | mmcv-full==1.4.0 6 | mmsegmentation==0.17.0 7 | ftfy 8 | regex 9 | scipy 10 | scikit-image 11 | pycocotools==2.0.7 12 | opencv-python==4.5.3.56 13 | tokenizers==0.8.1rc1 14 | h5py -------------------------------------------------------------------------------- /lib/sa/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /lib/sa/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | 
code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /bert/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | 37 | 38 | def gelu_fast(x): 39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 40 | 41 | 42 | ACT2FN = { 43 | "relu": F.relu, 44 | "swish": swish, 45 | "gelu": gelu, 46 | "tanh": torch.tanh, 47 | "gelu_new": gelu_new, 48 | "gelu_fast": gelu_fast, 49 | } 50 | 51 | 52 | def get_activation(activation_string): 53 | if activation_string in ACT2FN: 54 | return ACT2FN[activation_string] 55 | else: 56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 57 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class DiceLoss: 6 | "Dice loss for segmentation" 7 | 8 | def __init__(self, 9 | axis: int = 1, # Class axis 10 | smooth: float = 1e-6, # Helps with numerical stabilities in the IoU division 11 | reduction: str = "sum", # PyTorch reduction to apply to the output 12 | square_in_union: bool = False # Squares predictions to increase slope of gradients 13 | ): 14 | self.axis = axis 15 | self.smooth = smooth 16 | self.reduction = reduction 17 | self.square_in_union = square_in_union 18 | 19 | def __call__(self, pred, targ): 20 | "One-hot encodes targ, then runs IoU calculation then takes 1-dice value" 21 | targ = self._one_hot(targ, pred.shape[self.axis]) 22 | assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs' 23 | pred = self.activation(pred) 24 | sum_dims = list(range(2, len(pred.shape))) 25 | inter = torch.sum(pred * targ, dim=sum_dims) 26 | union = (torch.sum(pred ** 2 + targ, dim=sum_dims) if self.square_in_union 27 | else torch.sum(pred + targ, dim=sum_dims)) 28 | dice_score = (2. 
* inter + self.smooth) / (union + self.smooth) 29 | loss = 1 - dice_score 30 | if self.reduction == 'mean': 31 | loss = loss.mean() 32 | elif self.reduction == 'sum': 33 | loss = loss.sum() 34 | return loss 35 | 36 | @staticmethod 37 | def _one_hot( 38 | x, # Non one-hot encoded targs 39 | classes: int, # The number of classes 40 | axis: int = 1 # The axis to stack for encoding (class dimension) 41 | ): 42 | "Creates one binary mask per class" 43 | return torch.stack([torch.where(x == c, 1, 0) for c in range(classes)], axis=axis) 44 | 45 | def activation(self, x): 46 | "Activation function applied to model output" 47 | return F.softmax(x, dim=self.axis) 48 | 49 | def decodes(self, x): 50 | "Converts model output to target format" 51 | return x.argmax(dim=self.axis) 52 | 53 | 54 | class Loss(): 55 | def __init__(self, weight=0.1): 56 | self.dice_loss = DiceLoss() 57 | self.ce_loss = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([0.9, 1.1]).cuda()) 58 | self.weight = weight 59 | 60 | def __call__(self, pred, targ): 61 | dice_loss = self.dice_loss(pred, targ) 62 | ce_loss = self.ce_loss(pred, targ) 63 | return (1 - self.weight) * ce_loss + self.weight * dice_loss -------------------------------------------------------------------------------- /lib/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from bert.modeling_bert import BertModel 5 | 6 | 7 | def load_weights(model, load_path): 8 | dict_trained = torch.load(load_path)['model'] 9 | dict_new = model.state_dict().copy() 10 | for key in dict_new.keys(): 11 | if key in dict_trained.keys(): 12 | dict_new[key] = dict_trained[key] 13 | model.load_state_dict(dict_new) 14 | del dict_new 15 | del dict_trained 16 | torch.cuda.empty_cache() 17 | print('load weights from {}'.format(load_path)) 18 | return model 19 | 20 | 21 | class _LAVTSimpleDecode(nn.Module): 22 | def __init__(self, backbone, classifier): 23 | super(_LAVTSimpleDecode, self).__init__() 24 | self.backbone = backbone 25 | self.classifier = classifier 26 | 27 | def forward(self, x, l_feats, l_mask): 28 | input_shape = x.shape[-2:] 29 | features = self.backbone(x, l_feats, l_mask) 30 | x_c1, x_c2, x_c3, x_c4 = features 31 | 32 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 33 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 34 | 35 | return x 36 | 37 | 38 | class LAVT(_LAVTSimpleDecode): 39 | pass 40 | 41 | 42 | ############################################### 43 | # LAVT One: put BERT inside the overall model # 44 | ############################################### 45 | class _LAVTOneSimpleDecode(nn.Module): 46 | def __init__(self, backbone, classifier, args): 47 | super(_LAVTOneSimpleDecode, self).__init__() 48 | self.backbone = backbone 49 | self.classifier = classifier 50 | self.text_encoder = BertModel.from_pretrained(args.ck_bert) 51 | self.text_encoder.pooler = None 52 | 53 | def forward(self, x, text, l_mask, t_mask, p_mask): 54 | input_shape = x.shape[-2:] 55 | ### language inference ### 56 | l_feats = self.text_encoder(text, attention_mask=l_mask)[0] 57 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) 58 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 59 | 60 | t_feats = self.text_encoder(text, attention_mask=t_mask)[0] 61 | t_feats = t_feats.permute(0, 2, 1) # (B, 768, N_l) 62 | t_mask = t_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 63 | 64 | p_feats = self.text_encoder(text, attention_mask=p_mask)[0] 65 | 
p_feats = p_feats.permute(0, 2, 1) # (B, 768, N_l) 66 | p_mask = p_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 67 | 68 | ########################## 69 | features = self.backbone(x, l_feats, l_mask, t_feats, t_mask, p_feats, p_mask) 70 | x_c1, x_c2, x_c3, x_c4 = features # e.g. x_c1:[B, 128, 120, 120], x_c2:[B, 256, 60, 60], x_c3:[B, 512, 30, 30], x_c4:[B, 1024, 15, 15] 71 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 72 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 73 | return x 74 | 75 | 76 | class LAVTOne(_LAVTOneSimpleDecode): #change 77 | pass 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FIANet 2 | This repository is the offical implementation for "Exploring Fine-Grained Image-Text Alignment for Referring Remote Sensing Image Segmentation."[[IEEE TGRS](https://ieeexplore.ieee.org/document/10816052)] [[arXiv](https://arxiv.org/abs/2409.13637)] 3 | 4 | ## Setting Up 5 | ### Preliminaries 6 | The code has been verified to work with PyTorch v1.12.1 and Python 3.7. 7 | 1. Clone this repository. 8 | 2. Change directory to root of this repository. 9 | ### Package Dependencies 10 | 1. Create a new Conda environment with Python 3.7 then activate it: 11 | ```shell 12 | conda create -n FIANet python==3.7 13 | conda activate FIANet 14 | ``` 15 | 16 | 2. Install PyTorch v1.12.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example): 17 | ```shell 18 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch 19 | ``` 20 | 21 | 3. Install the packages in `requirements.txt` via `pip`: 22 | ```shell 23 | pip install -r requirements.txt 24 | ``` 25 | ### The Initialization Weights for Training 26 | 1. Create the `./pretrained_weights` directory where we will be storing the weights. 27 | ```shell 28 | mkdir ./pretrained_weights 29 | ``` 30 | 2. Download [pre-trained classification weights of 31 | the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth), 32 | and put the `pth` file in `./pretrained_weights`. 33 | These weights are needed for training to initialize the visual encoder. 34 | 3. Download [BERT weights from HuggingFace’s Transformer library](https://huggingface.co/google-bert/bert-base-uncased), 35 | and put it in the root directory. 36 | 37 | ## Datasets 38 | We perform the experiments on two dataset including [RefSegRS](https://github.com/zhu-xlab/rrsis) and [RRSIS-D](https://github.com/Lsan2401/RMSIN). 39 | 40 | ## Training 41 | We use one GPU to train our model. 42 | For training on RefSegRS dataset: 43 | ```shell 44 | python train.py --dataset refsegrs --model_id FIANet --epochs 60 --lr 5e-5 --num_tmem 1 45 | ``` 46 | 47 | For training on RRSIS-D dataset: 48 | ```shell 49 | python train.py --dataset rrsisd --model_id FIANet --epochs 40 --lr 3e-5 --num_tmem 3 50 | ``` 51 | The pretrained models can be downloaded from [[BaiduNetDisk](https://pan.baidu.com/s/1WgvKFn9nXiny1pzcvVJjwQ?pwd=65g4)](extract code: 65g4). 
52 | 53 | ## Testing 54 | For RefSegRS dataset: 55 | ```shell 56 | python test.py --swin_type base --dataset refsegrs --resume ./your_checkpoints_path --split test --window12 --img_size 480 --num_tmem 1 57 | ``` 58 | For RRSIS-D dataset: 59 | ```shell 60 | python test.py --swin_type base --dataset rrsisd --resume ./your_checkpoints_path --split test --window12 --img_size 480 --num_tmem 3 61 | ``` 62 | 63 | ## Citation 64 | If you find this code useful for your research, please cite our paper: 65 | `````` 66 | @ARTICLE{10816052, 67 | author={Lei, Sen and Xiao, Xinyu and Zhang, Tianlin and Li, Heng-Chao and Shi, Zhenwei and Zhu, Qing}, 68 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 69 | title={Exploring Fine-Grained Image-Text Alignment for Referring Remote Sensing Image Segmentation}, 70 | year={2025}, 71 | volume={63}, 72 | number={}, 73 | pages={1-11}, 74 | doi={10.1109/TGRS.2024.3522293}} 75 | `````` 76 | 77 | ## Acknowledgements 78 | Code in this repository is built on [RMSIN](https://github.com/Lsan2401/RMSIN) and [LAVT](https://github.com/yz93/LAVT-RIS). We'd like to thank the authors for open sourcing their project. 79 | -------------------------------------------------------------------------------- /lib/mask_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | class SimpleDecoding(nn.Module): 8 | def __init__(self, c4_dims, factor=2): 9 | super(SimpleDecoding, self).__init__() 10 | 11 | hidden_size = c4_dims//factor 12 | c4_size = c4_dims 13 | c3_size = c4_dims//(factor**1) 14 | c2_size = c4_dims//(factor**2) 15 | c1_size = c4_dims//(factor**3) 16 | 17 | self.conv1_4 = nn.Conv2d(c4_size+c3_size, hidden_size, 3, padding=1, bias=False) 18 | self.bn1_4 = nn.BatchNorm2d(hidden_size) 19 | self.relu1_4 = nn.ReLU() 20 | self.conv2_4 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 21 | self.bn2_4 = nn.BatchNorm2d(hidden_size) 22 | self.relu2_4 = nn.ReLU() 23 | 24 | self.conv1_3 = nn.Conv2d(hidden_size + c2_size, hidden_size, 3, padding=1, bias=False) 25 | self.bn1_3 = nn.BatchNorm2d(hidden_size) 26 | self.relu1_3 = nn.ReLU() 27 | self.conv2_3 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 28 | self.bn2_3 = nn.BatchNorm2d(hidden_size) 29 | self.relu2_3 = nn.ReLU() 30 | 31 | self.conv1_2 = nn.Conv2d(hidden_size + c1_size, hidden_size, 3, padding=1, bias=False) 32 | self.bn1_2 = nn.BatchNorm2d(hidden_size) 33 | self.relu1_2 = nn.ReLU() 34 | self.conv2_2 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 35 | self.bn2_2 = nn.BatchNorm2d(hidden_size) 36 | self.relu2_2 = nn.ReLU() 37 | 38 | self.conv1_1 = nn.Conv2d(hidden_size, 2, 1) 39 | 40 | def forward(self, x_c4, x_c3, x_c2, x_c1): 41 | 42 | # import matplotlib.pyplot as plt 43 | # import numpy as np 44 | # input = x_c4 45 | # ttt = torch.mean(input, dim=1) 46 | # x_c1_show = ttt[0, :, :].cpu().numpy() 47 | # xc1_min = np.min(x_c1_show) 48 | # xc2_max = np.max(x_c1_show) 49 | # x_c1_show = (x_c1_show - xc1_min) / (xc2_max - xc1_min) 50 | # plt.imshow(x_c1_show, cmap='viridis') # viridis 51 | # plt.show() 52 | 53 | # fuse Y4 and Y3 54 | if x_c4.size(-2) < x_c3.size(-2) or x_c4.size(-1) < x_c3.size(-1): 55 | x_c4 = F.interpolate(input=x_c4, size=(x_c3.size(-2), x_c3.size(-1)), mode='bilinear', align_corners=True) 56 | x = torch.cat([x_c4, x_c3], dim=1) 57 | x = self.conv1_4(x) 58 | x = self.bn1_4(x) 59 | 
x = self.relu1_4(x) 60 | x = self.conv2_4(x) 61 | x = self.bn2_4(x) 62 | x = self.relu2_4(x) 63 | # fuse top-down features and Y2 features 64 | if x.size(-2) < x_c2.size(-2) or x.size(-1) < x_c2.size(-1): 65 | x = F.interpolate(input=x, size=(x_c2.size(-2), x_c2.size(-1)), mode='bilinear', align_corners=True) 66 | x = torch.cat([x, x_c2], dim=1) 67 | x = self.conv1_3(x) 68 | x = self.bn1_3(x) 69 | x = self.relu1_3(x) 70 | x = self.conv2_3(x) 71 | x = self.bn2_3(x) 72 | x = self.relu2_3(x) 73 | # fuse top-down features and Y1 features 74 | if x.size(-2) < x_c1.size(-2) or x.size(-1) < x_c1.size(-1): 75 | x = F.interpolate(input=x, size=(x_c1.size(-2), x_c1.size(-1)), mode='bilinear', align_corners=True) 76 | x = torch.cat([x, x_c1], dim=1) 77 | x = self.conv1_2(x) 78 | x = self.bn1_2(x) 79 | x = self.relu1_2(x) 80 | x = self.conv2_2(x) 81 | x = self.bn2_2(x) 82 | x = self.relu2_2(x) 83 | 84 | return self.conv1_1(x) 85 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_parser(): 4 | parser = argparse.ArgumentParser(description='FIANet training and testing') 5 | parser.add_argument('--amsgrad', action='store_true', 6 | help='if true, set amsgrad to True in an Adam or AdamW optimizer.') 7 | parser.add_argument('-b', '--batch-size', default=8, type=int) 8 | parser.add_argument('--bert_tokenizer', default='./bert-base-uncased', help='BERT tokenizer') 9 | parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights') 10 | parser.add_argument('--dataset', default='rrsisd', help='refcoco, refcoco+, or refcocog') 11 | parser.add_argument('--ddp_trained_weights', action='store_true', 12 | help='Only needs specified when testing,' 13 | 'whether the weights to be loaded are from a DDP-trained model') 14 | parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine 15 | parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') 16 | parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs') 17 | parser.add_argument('--img_size', default=480, type=int, help='input image size') 18 | parser.add_argument("--local_rank", type=int,default=0,help='local rank for DistributedDataParallel') 19 | parser.add_argument('--lr', default=5e-5, type=float, help='the initial learning rate') # 5e-5 for RefSegRS, 3e-5 for RRSIS-D 20 | parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,' 21 | 'where a, b, c, and d refer to the numbers of heads in stage-1,' 22 | 'stage-2, stage-3, and stage-4 PWAMs') 23 | parser.add_argument('--model', default='lavt_one', help='model: lavt, lavt_one') 24 | parser.add_argument('--model_id', default='FIANet', help='name to identify the model') 25 | parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights') 26 | parser.add_argument('--pin_mem', action='store_true', 27 | help='If true, pin memory when using the data loader.') 28 | parser.add_argument('--pretrained_swin_weights', default='./pretrained_weights/swin_base_patch4_window12_384_22k.pth', 29 | help='path to pre-trained Swin backbone weights') 30 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 31 | parser.add_argument('--refer_data_root', default='C:/Dataset/refer_seg/RefSegRS/', 
help='REFER dataset root directory') 32 | parser.add_argument('--resume', default='', help='resume from checkpoint') 33 | parser.add_argument('--split', default='test', help='only used when testing') 34 | parser.add_argument('--splitBy', default='unc', help='change to umd or google when the datasset is G-Ref (RefCOCOg)') 35 | parser.add_argument('--swin_type', default='base', 36 | help='tiny, small, base, or large variants of the Swin Transformer') 37 | parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay', 38 | dest='weight_decay') 39 | parser.add_argument('--window12', action='store_true', 40 | help='only needs specified when testing,' 41 | 'when training, window size is inferred from pre-trained weights file name' 42 | '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.') 43 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers') 44 | parser.add_argument('--num_tmem', default=1, type=int, help='number of tmem layers') # 1 for RefSegRS, 3 for RRSIS-D 45 | return parser 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = get_parser() 50 | args_dict = parser.parse_args() 51 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, h, w): 32 | self.h = h 33 | self.w = w 34 | 35 | def __call__(self, image, target): 36 | image = F.resize(image, (self.h, self.w)) 37 | # If size is a sequence like (h, w), the output size will be matched to this. 38 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 39 | target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST) 40 | return image, target 41 | 42 | 43 | class RandomResize(object): 44 | def __init__(self, min_size, max_size=None): 45 | self.min_size = min_size 46 | if max_size is None: 47 | max_size = min_size 48 | self.max_size = max_size 49 | 50 | def __call__(self, image, target): 51 | size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1) 52 | image = F.resize(image, size) 53 | # If size is a sequence like (h, w), the output size will be matched to this. 
54 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 55 | target = F.resize(target, size, interpolation=Image.NEAREST) 56 | return image, target 57 | 58 | 59 | class RandomHorizontalFlip(object): 60 | def __init__(self, flip_prob): 61 | self.flip_prob = flip_prob 62 | 63 | def __call__(self, image, target): 64 | if random.random() < self.flip_prob: 65 | image = F.hflip(image) 66 | target = F.hflip(target) 67 | return image, target 68 | 69 | 70 | class RandomCrop(object): 71 | def __init__(self, size): 72 | self.size = size 73 | 74 | def __call__(self, image, target): 75 | image = pad_if_smaller(image, self.size) 76 | target = pad_if_smaller(target, self.size, fill=255) 77 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 78 | image = F.crop(image, *crop_params) 79 | target = F.crop(target, *crop_params) 80 | return image, target 81 | 82 | 83 | class CenterCrop(object): 84 | def __init__(self, size): 85 | self.size = size 86 | 87 | def __call__(self, image, target): 88 | image = F.center_crop(image, self.size) 89 | target = F.center_crop(target, self.size) 90 | return image, target 91 | 92 | 93 | class ToTensor(object): 94 | def __call__(self, image, target): 95 | image = F.to_tensor(image) 96 | target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64) 97 | return image, target 98 | 99 | 100 | class RandomAffine(object): 101 | def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None): 102 | self.angle = angle 103 | self.translate = translate 104 | self.scale = scale 105 | self.shear = shear 106 | self.resample = resample 107 | self.fillcolor = fillcolor 108 | 109 | def __call__(self, image, target): 110 | affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size) 111 | image = F.affine(image, *affine_params) 112 | target = F.affine(target, *affine_params) 113 | return image, target 114 | 115 | 116 | class Normalize(object): 117 | def __init__(self, mean, std): 118 | self.mean = mean 119 | self.std = std 120 | 121 | def __call__(self, image, target): 122 | image = F.normalize(image, mean=self.mean, std=self.std) 123 | return image, target 124 | 125 | -------------------------------------------------------------------------------- /lib/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .mask_predictor import SimpleDecoding 4 | from .backbone import MultiModalSwinTransformer 5 | from ._utils import LAVT, LAVTOne 6 | 7 | 8 | __all__ = ['lavt', 'lavt_one'] 9 | 10 | 11 | # LAVT 12 | def _segm_lavt(pretrained, args): 13 | # initialize the SwinTransformer backbone with the specified version 14 | if args.swin_type == 'tiny': 15 | embed_dim = 96 16 | depths = [2, 2, 6, 2] 17 | num_heads = [3, 6, 12, 24] 18 | elif args.swin_type == 'small': 19 | embed_dim = 96 20 | depths = [2, 2, 18, 2] 21 | num_heads = [3, 6, 12, 24] 22 | elif args.swin_type == 'base': 23 | embed_dim = 128 24 | depths = [2, 2, 18, 2] 25 | num_heads = [4, 8, 16, 32] 26 | elif args.swin_type == 'large': 27 | embed_dim = 192 28 | depths = [2, 2, 18, 2] 29 | num_heads = [6, 12, 24, 48] 30 | else: 31 | assert False 32 | # args.window12 added for test.py because state_dict is loaded after model initialization 33 | if 'window12' in pretrained or args.window12: 34 | print('Window size 12!') 35 | window_size = 12 36 | else: 37 | window_size = 7 38 | 39 | if args.mha: 40 | mha = 
args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 41 | mha = [int(a) for a in mha] 42 | else: 43 | mha = [1, 1, 1, 1] 44 | 45 | out_indices = (0, 1, 2, 3) 46 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 47 | window_size=window_size, 48 | num_tmem=args.num_tmem, 49 | ape=False, drop_path_rate=0.3, patch_norm=True, 50 | out_indices=out_indices, 51 | use_checkpoint=False, num_heads_fusion=mha, 52 | fusion_drop=args.fusion_drop 53 | ) 54 | if pretrained: 55 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 56 | backbone.init_weights(pretrained=pretrained) 57 | else: 58 | print('Randomly initialize Multi-modal Swin Transformer weights.') 59 | backbone.init_weights() 60 | 61 | model_map = [SimpleDecoding, LAVT] 62 | 63 | classifier = model_map[0](8*embed_dim) 64 | base_model = model_map[1] 65 | 66 | model = base_model(backbone, classifier) 67 | return model 68 | 69 | 70 | def _load_model_lavt(pretrained, args): 71 | model = _segm_lavt(pretrained, args) 72 | return model 73 | 74 | 75 | def lavt(pretrained='', args=None): 76 | return _load_model_lavt(pretrained, args) 77 | 78 | 79 | ############################################### 80 | # LAVT One: put BERT inside the overall model # 81 | ############################################### 82 | def _segm_lavt_one(pretrained, args): 83 | # initialize the SwinTransformer backbone with the specified version 84 | if args.swin_type == 'tiny': 85 | embed_dim = 96 86 | depths = [2, 2, 6, 2] 87 | num_heads = [3, 6, 12, 24] 88 | elif args.swin_type == 'small': 89 | embed_dim = 96 90 | depths = [2, 2, 18, 2] 91 | num_heads = [3, 6, 12, 24] 92 | elif args.swin_type == 'base': 93 | embed_dim = 128 94 | depths = [2, 2, 18, 2] 95 | num_heads = [4, 8, 16, 32] 96 | elif args.swin_type == 'large': 97 | embed_dim = 192 98 | depths = [2, 2, 18, 2] 99 | num_heads = [6, 12, 24, 48] 100 | else: 101 | assert False 102 | # args.window12 added for test.py because state_dict is loaded after model initialization 103 | if 'window12' in pretrained or args.window12: 104 | print('Window size 12!') 105 | window_size = 12 106 | else: 107 | window_size = 7 108 | 109 | if args.mha: 110 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 111 | mha = [int(a) for a in mha] 112 | else: 113 | mha = [1, 1, 1, 1] 114 | 115 | out_indices = (0, 1, 2, 3) 116 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 117 | window_size=window_size, 118 | num_tmem=args.num_tmem, 119 | ape=False, drop_path_rate=0.3, patch_norm=True, 120 | out_indices=out_indices, 121 | use_checkpoint=False, num_heads_fusion=mha, 122 | fusion_drop=args.fusion_drop, 123 | ) 124 | if pretrained: 125 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 126 | backbone.init_weights(pretrained=pretrained) 127 | else: 128 | print('Randomly initialize Multi-modal Swin Transformer weights.') 129 | backbone.init_weights() 130 | 131 | model_map = [SimpleDecoding, LAVTOne] 132 | classifier = model_map[0](8*embed_dim) 133 | base_model = model_map[1] 134 | 135 | model = base_model(backbone, classifier, args) 136 | return model 137 | 138 | 139 | def _load_model_lavt_one(pretrained, args): 140 | model = _segm_lavt_one(pretrained, args) 141 | return model 142 | 143 | 144 | def lavt_one(pretrained='', args=None): 145 | return _load_model_lavt_one(pretrained, args) 146 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import math 5 | import time 6 | import torch 7 | import torch.distributed as dist 8 | import torch.backends.cudnn as cudnn 9 | 10 | import errno 11 | import os 12 | 13 | import sys 14 | 15 | 16 | class SmoothedValue(object): 17 | """Track a series of values and provide access to smoothed values over a 18 | window or the global series average. 19 | """ 20 | 21 | def __init__(self, window_size=20, fmt=None): 22 | if fmt is None: 23 | fmt = "{median:.4f} ({global_avg:.4f})" 24 | self.deque = deque(maxlen=window_size) 25 | self.total = 0.0 26 | self.count = 0 27 | self.fmt = fmt 28 | 29 | def update(self, value, n=1): 30 | self.deque.append(value) 31 | self.count += n 32 | self.total += value * n 33 | 34 | def synchronize_between_processes(self): 35 | """ 36 | Warning: does not synchronize the deque! 37 | """ 38 | if not is_dist_avail_and_initialized(): 39 | return 40 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 41 | dist.barrier() 42 | dist.all_reduce(t) 43 | t = t.tolist() 44 | self.count = int(t[0]) 45 | self.total = t[1] 46 | 47 | @property 48 | def median(self): 49 | d = torch.tensor(list(self.deque)) 50 | return d.median().item() 51 | 52 | @property 53 | def avg(self): 54 | d = torch.tensor(list(self.deque), dtype=torch.float32) 55 | return d.mean().item() 56 | 57 | @property 58 | def global_avg(self): 59 | return self.total / self.count 60 | 61 | @property 62 | def max(self): 63 | return max(self.deque) 64 | 65 | @property 66 | def value(self): 67 | return self.deque[-1] 68 | 69 | def __str__(self): 70 | return self.fmt.format( 71 | median=self.median, 72 | avg=self.avg, 73 | global_avg=self.global_avg, 74 | max=self.max, 75 | value=self.value) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | assert isinstance(v, (float, int)) 88 | self.meters[k].update(v) 89 | 90 | 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | print(iterable) 117 | i = 0 118 | if not header: 119 | header = '' 120 | start_time = time.time() 121 | end = time.time() 122 | iter_time = SmoothedValue(fmt='{avg:.4f}') 123 | data_time = SmoothedValue(fmt='{avg:.4f}') 124 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 125 | log_msg = self.delimiter.join([ 126 | header, 127 | '[{0' + space_fmt + '}/{1}]', 128 | 'eta: {eta}', 129 | '{meters}', 130 | 'time: {time}', 131 | 'data: {data}', 132 | 'max mem: {memory:.0f}' 133 | ]) 134 | MB = 1024.0 * 1024.0 
135 | for obj in iterable: 136 | data_time.update(time.time() - end) 137 | yield obj 138 | iter_time.update(time.time() - end) 139 | if i % print_freq == 0: 140 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 141 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 142 | print(log_msg.format( 143 | i, len(iterable), eta=eta_string, 144 | meters=str(self), 145 | time=str(iter_time), data=str(data_time), 146 | memory=torch.cuda.max_memory_allocated() / MB)) 147 | sys.stdout.flush() 148 | 149 | i += 1 150 | end = time.time() 151 | total_time = time.time() - start_time 152 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 153 | print('{} Total time: {}'.format(header, total_time_str)) 154 | 155 | 156 | def mkdir(path): 157 | try: 158 | os.makedirs(path) 159 | except OSError as e: 160 | if e.errno != errno.EEXIST: 161 | raise 162 | 163 | 164 | def setup_for_distributed(is_master): 165 | """ 166 | This function disables printing when not in master process 167 | """ 168 | import builtins as __builtin__ 169 | builtin_print = __builtin__.print 170 | 171 | def print(*args, **kwargs): 172 | force = kwargs.pop('force', False) 173 | if is_master or force: 174 | builtin_print(*args, **kwargs) 175 | 176 | __builtin__.print = print 177 | 178 | 179 | def is_dist_avail_and_initialized(): 180 | if not dist.is_available(): 181 | return False 182 | if not dist.is_initialized(): 183 | return False 184 | return True 185 | 186 | 187 | def get_world_size(): 188 | if not is_dist_avail_and_initialized(): 189 | return 1 190 | return dist.get_world_size() 191 | 192 | 193 | def get_rank(): 194 | if not is_dist_avail_and_initialized(): 195 | return 0 196 | return dist.get_rank() 197 | 198 | 199 | def is_main_process(): 200 | return get_rank() == 0 201 | 202 | 203 | def save_on_master(*args, **kwargs): 204 | if is_main_process(): 205 | torch.save(*args, **kwargs) 206 | 207 | 208 | def init_distributed_mode(args): 209 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 210 | rank = int(os.environ["RANK"]) 211 | world_size = int(os.environ['WORLD_SIZE']) 212 | print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}") 213 | else: 214 | rank = -1 215 | world_size = -1 216 | 217 | # torch.cuda.set_device(args.local_rank) 218 | # torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 219 | # torch.distributed.barrier() 220 | # setup_for_distributed(is_main_process()) 221 | 222 | if args.output_dir: 223 | mkdir(args.output_dir) 224 | if args.model_id: 225 | mkdir(os.path.join('./models/', args.model_id)) 226 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import torch.utils.data 4 | import utils 5 | import numpy as np 6 | import transforms as T 7 | from torchvision.transforms import functional as F 8 | import random 9 | from bert.modeling_bert import BertModel 10 | from lib import segmentation 11 | import os 12 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 13 | 14 | import time 15 | import datetime 16 | 17 | def get_dataset(image_set, transform, args): 18 | 19 | if args.dataset == "rrsisd": 20 | from data.rrsisd_refer_bert import ReferDataset 21 | else: 22 | from data.refsegrs_refer_bert import ReferDataset 23 | ds = ReferDataset(args, 24 | split=image_set, 25 | image_transforms=transform, 26 | target_transforms=None, 27 | eval_mode=True 28 | ) 29 | num_classes = 
2 30 | return ds, num_classes 31 | 32 | 33 | def evaluate(model, data_loader, bert_model, device): 34 | model.eval() 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | 37 | # evaluation variables 38 | cum_I, cum_U = 0, 0 39 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 40 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 41 | seg_total = 0 42 | mean_IoU = [] 43 | header = 'Test:' 44 | save_dir = "experiments/test_vis" 45 | 46 | start_time = time.time() 47 | 48 | with torch.no_grad(): 49 | 50 | for data in metric_logger.log_every(data_loader, 100, header): 51 | 52 | image, target, sentences, attentions, target_masks, position_masks, save_prefix = data 53 | image = image.to(device) 54 | target = target.to(device) 55 | sentences = sentences.to(device) 56 | attentions = attentions.to(device) 57 | target = target.to(device) 58 | 59 | target_masks = target_masks.to(device) 60 | position_masks = position_masks.to(device) 61 | 62 | sentences = sentences.squeeze(1) 63 | attentions = attentions.squeeze(1) 64 | target_masks = target_masks.squeeze(1) 65 | position_masks = position_masks.squeeze(1) 66 | 67 | target = target.cpu().data.numpy() 68 | for j in range(sentences.size(0)): 69 | if bert_model is not None: 70 | last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] 71 | embedding = last_hidden_states.permute(0, 2, 1) 72 | output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) 73 | else: 74 | # output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j]) 75 | 76 | # output = model(image, sentences[:, :, j], attentions[:, :, j], target_masks[:, :, j], position_masks[:, :, j]) 77 | output = model(image, sentences, attentions, target_masks, position_masks) 78 | 79 | output = output.cpu() 80 | 81 | output_mask = output.argmax(1).data.numpy() 82 | 83 | I, U = computeIoU(output_mask, target) 84 | 85 | # save pred results 86 | # save_path = os.path.join(save_dir, str(seg_total+1)) 87 | # save_pred_targ_results(output_mask, target, image, save_path) 88 | 89 | if U == 0: 90 | this_iou = 0.0 91 | else: 92 | this_iou = I*1.0/U 93 | mean_IoU.append(this_iou) 94 | cum_I += I 95 | cum_U += U 96 | for n_eval_iou in range(len(eval_seg_iou_list)): 97 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 98 | seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) 99 | 100 | # mask = output_mask * 255 101 | # cv2.imwrite(os.path.join(save_dir, save_prefix[0] + "_" + str(this_iou) + "_pred.png"), mask[0, :, :]) 102 | 103 | seg_total += 1 104 | 105 | del image, target, sentences, attentions, output,output_mask 106 | if bert_model is not None: 107 | del last_hidden_states, embedding 108 | 109 | mean_IoU = np.array(mean_IoU) 110 | mIoU = np.mean(mean_IoU) 111 | print('Final results:') 112 | print('Mean IoU is %.2f\n' % (mIoU*100.)) 113 | results_str = '' 114 | for n_eval_iou in range(len(eval_seg_iou_list)): 115 | results_str += ' precision@%s = %.2f\n' % \ 116 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 117 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) 118 | print(results_str) 119 | 120 | # summarize 121 | total_time = time.time() - start_time 122 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 123 | print('Total test time {}'.format(total_time_str)) 124 | print('Test time for one image %.2f ' % (total_time / seg_total)) 125 | 126 | 127 | def get_transform(args): 128 | transforms = [T.Resize(args.img_size, args.img_size), 129 | T.ToTensor(), 130 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 131 | ] 132 | 133 | return T.Compose(transforms) 134 | 135 | 136 | def computeIoU(pred_seg, gd_seg): 137 | I = np.sum(np.logical_and(pred_seg, gd_seg)) 138 | U = np.sum(np.logical_or(pred_seg, gd_seg)) 139 | 140 | return I, U 141 | 142 | def save_pred_targ_results(output_mask, target, image, save_path): 143 | "leisen: show the predict and target outcomes" 144 | pred = output_mask[0, :, :] 145 | targ = target[0, :, :] 146 | # pred = np.uint8(pred * 255) 147 | # targ = np.uint8(targ * 255) 148 | # cv2.imwrite("pred.png", pred) 149 | # cv2.imwrite("targ.png", targ) 150 | 151 | mean = [0.485, 0.456, 0.406] 152 | std = [0.229, 0.224, 0.225] 153 | inv_mean = [-m / s for m, s in zip(mean, std)] 154 | inv_std = [1 / s for s in std] 155 | im = F.normalize(image, mean=inv_mean, std=inv_std) 156 | im = im[0, :, :, :].cpu().detach().numpy() 157 | im = im.transpose([1, 2, 0]) 158 | im = np.uint8(im * 255) 159 | 160 | pred_red_mask = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8) 161 | pred_red_mask[:, :] = (0, 0, 255) 162 | pred_mask = np.repeat(pred[:, :, np.newaxis], 3, -1) 163 | pred_mask = pred_red_mask * pred_mask 164 | pred_mask = np.uint8(pred_mask) 165 | pred_img_write = cv2.addWeighted(im, 0.5, pred_mask, 0.5, 0) 166 | cv2.imwrite((save_path + "_pred.png"), pred_img_write) 167 | 168 | targ_red_mask = np.zeros((targ.shape[0], targ.shape[1], 3), dtype=np.uint8) 169 | targ_red_mask[:, :] = (0, 0, 255) 170 | targ_mask = np.repeat(targ[:, :, np.newaxis], 3, -1) 171 | targ_mask = targ_red_mask * targ_mask 172 | targ_mask = np.uint8(targ_mask) 173 | targ_img_write = cv2.addWeighted(im, 0.5, targ_mask, 0.5, 0) 174 | cv2.imwrite((save_path + "_targ.png"), targ_img_write) 175 | 176 | 177 | def main(args): 178 | device = torch.device(args.device) 179 | dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) 180 | 181 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 182 | data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, 183 | sampler=test_sampler, num_workers=args.workers) 184 | print(args.model) 185 | single_model = segmentation.__dict__[args.model](pretrained='',args=args) 186 | checkpoint = torch.load(args.resume, map_location='cpu') 187 | single_model.load_state_dict(checkpoint['model'], strict=False) 188 | model = single_model.to(device) 189 | 190 | if args.model != 'lavt_one': 191 | model_class = BertModel 192 | single_bert_model = model_class.from_pretrained(args.ck_bert) 193 | # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines 194 | if args.ddp_trained_weights: 195 | single_bert_model.pooler = None 196 | single_bert_model.load_state_dict(checkpoint['bert_model']) 197 | bert_model = single_bert_model.to(device) 198 | else: 199 | bert_model = None 200 | 201 | evaluate(model, data_loader_test, bert_model, device=device) 202 | 203 | 204 | if __name__ == "__main__": 205 | from args import get_parser 206 | parser = get_parser() 207 | args = parser.parse_args() 208 | print('Image 
size: {}'.format(str(args.img_size))) 209 | main(args) 210 | -------------------------------------------------------------------------------- /data/refsegrs_refer_bert.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.utils.data as data 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | from bert.tokenization_bert import BertTokenizer 7 | from args import get_parser 8 | import cv2 9 | 10 | import re 11 | import nltk 12 | from nltk.tokenize import word_tokenize 13 | 14 | # Dataset configuration initialization 15 | parser = get_parser() 16 | args = parser.parse_args() 17 | 18 | data_root = args.refer_data_root 19 | 20 | def build_rsris_batches(setname): 21 | im_dir1 = f'{data_root}/images/' 22 | seg_label_dir = f'{data_root}/masks/' 23 | if setname == 'train': 24 | setfile = 'output_phrase_train.txt' 25 | if setname == 'val': 26 | setfile = 'output_phrase_val.txt' 27 | if setname == 'test': 28 | setfile = 'output_phrase_test.txt' 29 | 30 | n_batch = 0 31 | train_ids = [] 32 | tf = f'{data_root}/'+setfile 33 | nn = 0 34 | imgnames = set() 35 | imname = 'start' 36 | all_imgs1 = [] 37 | all_labels = [] 38 | all_sentences = [] 39 | 40 | test_sentence = [] 41 | 42 | with open(tf,'r') as rf: 43 | rlines = rf.readlines() 44 | for idx,line in enumerate(rlines): 45 | lsplit = line.split(' ') 46 | if True: 47 | im_name1 = im_dir1 + lsplit[0] + '.tif' 48 | seg = seg_label_dir + lsplit[0] + '.tif' 49 | del(lsplit[0]) 50 | if False and setname != 'train': 51 | del(lsplit[-1]) 52 | sentence = ' '.join(lsplit) 53 | sent = sentence 54 | 55 | im_1 = im_name1 56 | label_mask = seg 57 | all_imgs1.append(im_name1) 58 | all_labels.append(label_mask) 59 | all_sentences.append(sent) 60 | 61 | print("Dataset Loaded.") 62 | return all_imgs1, all_labels, all_sentences 63 | 64 | class ReferDataset(data.Dataset): 65 | 66 | def __init__(self, 67 | args, 68 | image_transforms=None, 69 | target_transforms=None, 70 | split='train', 71 | eval_mode=False): 72 | 73 | self.classes = [] 74 | self.image_transforms = image_transforms 75 | self.target_transform = target_transforms 76 | self.split = split 77 | self.max_tokens = 20 78 | 79 | all_imgs1, all_labels, all_sentences = build_rsris_batches(self.split) 80 | self.sentences = all_sentences 81 | self.imgs1 = all_imgs1 82 | self.labels = all_labels 83 | 84 | self.input_ids = [] 85 | self.attention_masks = [] 86 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer) 87 | 88 | self.target_masks = [] 89 | self.position_masks = [] 90 | 91 | self.sentences_raw = [] 92 | self.pp_phrase = [] 93 | 94 | # debug 95 | self.max_len = 0 96 | 97 | # for RefSegRS dataset 98 | self.target_cls = {"road", "vehicle", "car", "van", "building", "truck", "trailer", "bus", 99 | "road marking", "bikeway", "sidewalk", "tree", "low vegetation", 100 | "impervious surface"} 101 | 102 | self.eval_mode = eval_mode 103 | # if we are testing on a dataset, test all sentences of an object; 104 | # o/w, we are validating during training, randomly sample one sentence for efficiency 105 | for r in range(len(self.imgs1)): 106 | img_sentences = [self.sentences[r]] 107 | sentences_for_ref = [] 108 | attentions_for_ref = [] 109 | 110 | target_for_ref = [] 111 | position_for_ref = [] 112 | 113 | for i, el in enumerate(img_sentences): 114 | sentence_raw = el 115 | attention_mask = [0] * self.max_tokens 116 | padded_input_ids = [0] * self.max_tokens 117 | 118 | target_masks = [0] * self.max_tokens 119 | position_masks = [0] * 
self.max_tokens 120 | 121 | input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) 122 | 123 | # truncation of tokens 124 | input_ids = input_ids[:self.max_tokens] 125 | 126 | padded_input_ids[:len(input_ids)] = input_ids 127 | attention_mask[:len(input_ids)] = [1]*len(input_ids) 128 | 129 | sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) 130 | attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) 131 | 132 | # extract the ground object 133 | # print(sentence_raw) 134 | self.sentences_raw.append(sentence_raw) 135 | tokenized_sentence = word_tokenize(sentence_raw) 136 | 137 | for cls in self.target_cls: 138 | if re.findall(cls, sentence_raw): 139 | 140 | tokenized_cls = word_tokenize(cls) 141 | nums_cls = len(tokenized_cls) 142 | index = 0 143 | for i, token in enumerate(tokenized_sentence): 144 | if re.findall(tokenized_cls[0], token): 145 | index = i 146 | break 147 | target_masks[index + 1: index + nums_cls + 1] = [1] * nums_cls 148 | 149 | target_for_ref.append(torch.tensor(target_masks).unsqueeze(0)) 150 | 151 | # extract the spatial position 152 | grammar = r""" 153 | PP: {
<IN><DT>?<NN.*>?} 154 | {<IN><JJ>?<NN.*>?} 155 | {<IN><NN.*>
?} 156 | """ 157 | chunkr = nltk.RegexpParser(grammar) 158 | # grammar parsing 159 | tree = chunkr.parse(nltk.pos_tag(tokenized_sentence)) 160 | pp_phrases = [] 161 | for subtree in tree.subtrees(): 162 | if subtree.label() == 'PP': 163 | pp_phrases.append(' '.join(word for word, pos in subtree.leaves())) 164 | 165 | new_pp_phrase = [] 166 | for phrase in pp_phrases: 167 | if not re.findall("of", phrase): 168 | new_pp_phrase.append(phrase) 169 | 170 | if len(new_pp_phrase) > 0: 171 | tokenized_sentence = word_tokenize(sentence_raw) 172 | for pp in new_pp_phrase: 173 | tokenized_pos = word_tokenize(pp) 174 | nums_pos = len(tokenized_pos) 175 | index = 0 176 | for i, token in enumerate(tokenized_sentence): 177 | if tokenized_pos[0] == token: 178 | index = i 179 | break 180 | position_masks[index + 1: index + nums_pos +1] = [1] * nums_pos 181 | 182 | self.pp_phrase.append(new_pp_phrase) 183 | position_for_ref.append(torch.tensor(position_masks).unsqueeze(0)) 184 | # if there are no pp, the position_for_ref equals attentions_for_ref 185 | if torch.sum(position_for_ref[0]) == 0: 186 | position_for_ref = attentions_for_ref 187 | 188 | self.input_ids.append(sentences_for_ref) 189 | self.attention_masks.append(attentions_for_ref) 190 | self.target_masks.append(target_for_ref) 191 | self.position_masks.append(position_for_ref) 192 | 193 | def get_classes(self): 194 | return self.classes 195 | 196 | def __len__(self): 197 | return len(self.imgs1) 198 | 199 | def __getitem__(self, index): 200 | this_img1 = self.imgs1[index] 201 | 202 | img1 = Image.open(this_img1).convert("RGB") 203 | label_mask = cv2.imread(self.labels[index],2) 204 | 205 | ref_mask = np.array(label_mask) > 50 206 | annot = np.zeros(ref_mask.shape) 207 | annot[ref_mask == 1] = 1 208 | 209 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 210 | save_prefix = str(index) + "_" + self.sentences_raw[index][:-1] 211 | if self.image_transforms is not None: 212 | # resize, from PIL to tensor, and mean and std normalization 213 | img1, target = self.image_transforms(img1, annot) 214 | 215 | choice_sent = np.random.choice(len(self.input_ids[index])) 216 | tensor_embeddings = self.input_ids[index][choice_sent] 217 | attention_mask = self.attention_masks[index][choice_sent] 218 | target_mask = self.target_masks[index][choice_sent] 219 | position_mask = self.position_masks[index][choice_sent] 220 | 221 | return img1, target, tensor_embeddings, attention_mask, target_mask, position_mask, save_prefix 222 | -------------------------------------------------------------------------------- /bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 
56 | It is used to instantiate a BERT model according to the specified arguments, defining the model 57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 58 | the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 62 | for more information. 63 | 64 | 65 | Args: 66 | vocab_size (:obj:`int`, optional, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the different tokens that 68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 69 | hidden_size (:obj:`int`, optional, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, optional, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, optional, defaults to 3072): 76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 78 | The non-linear activation function (function or string) in the encoder and pooler. 79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 81 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. 86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, optional, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 89 | initializer_range (:obj:`float`, optional, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 94 | If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
95 | 96 | Example:: 97 | 98 | >>> from transformers import BertModel, BertConfig 99 | 100 | >>> # Initializing a BERT bert-base-uncased style configuration 101 | >>> configuration = BertConfig() 102 | 103 | >>> # Initializing a model from the bert-base-uncased style configuration 104 | >>> model = BertModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | """ 109 | model_type = "bert" 110 | 111 | def __init__( 112 | self, 113 | vocab_size=30522, 114 | hidden_size=768, 115 | num_hidden_layers=12, 116 | num_attention_heads=12, 117 | intermediate_size=3072, 118 | hidden_act="gelu", 119 | hidden_dropout_prob=0.1, 120 | attention_probs_dropout_prob=0.1, 121 | max_position_embeddings=512, 122 | type_vocab_size=2, 123 | initializer_range=0.02, 124 | layer_norm_eps=1e-12, 125 | pad_token_id=0, 126 | gradient_checkpointing=False, 127 | **kwargs 128 | ): 129 | super().__init__(pad_token_id=pad_token_id, **kwargs) 130 | 131 | self.vocab_size = vocab_size 132 | self.hidden_size = hidden_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.hidden_act = hidden_act 136 | self.intermediate_size = intermediate_size 137 | self.hidden_dropout_prob = hidden_dropout_prob 138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 139 | self.max_position_embeddings = max_position_embeddings 140 | self.type_vocab_size = type_vocab_size 141 | self.initializer_range = initializer_range 142 | self.layer_norm_eps = layer_norm_eps 143 | self.gradient_checkpointing = gradient_checkpointing 144 | -------------------------------------------------------------------------------- /data/rrsisd_refer_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import random 7 | from bert.tokenization_bert import BertTokenizer 8 | from refer.refer import REFER 9 | 10 | from args import get_parser 11 | 12 | import re 13 | import nltk 14 | from nltk.tokenize import word_tokenize 15 | 16 | # Dataset configuration initialization 17 | parser = get_parser() 18 | args = parser.parse_args() 19 | 20 | 21 | def add_random_boxes(img, min_num=20, max_num=60, size=32): 22 | h,w = size, size 23 | img = np.asarray(img).copy() 24 | img_size = img.shape[1] 25 | boxes = [] 26 | num = random.randint(min_num, max_num) 27 | for k in range(num): 28 | y, x = random.randint(0, img_size-w), random.randint(0, img_size-h) 29 | img[y:y+h, x: x+w] = 0 30 | boxes. 
append((x,y,h,w) ) 31 | img = Image.fromarray(img.astype('uint8'), 'RGB') 32 | return img 33 | 34 | 35 | class ReferDataset(data.Dataset): 36 | 37 | def __init__(self, 38 | args, 39 | image_transforms=None, 40 | target_transforms=None, 41 | split='train', 42 | eval_mode=False): 43 | 44 | self.classes = [] 45 | self.image_transforms = image_transforms 46 | self.target_transform = target_transforms 47 | self.split = split 48 | self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy) 49 | 50 | self.max_tokens = 22 51 | 52 | ref_ids = self.refer.getRefIds(split=self.split) 53 | img_ids = self.refer.getImgIds(ref_ids) 54 | 55 | num_images_to_mask = int(len(ref_ids) * 0.2) 56 | self.images_to_mask = random.sample(ref_ids, num_images_to_mask) 57 | 58 | all_imgs = self.refer.Imgs 59 | self.imgs = list(all_imgs[i] for i in img_ids) 60 | self.ref_ids = ref_ids 61 | 62 | self.input_ids = [] 63 | self.attention_masks = [] 64 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer) 65 | 66 | # for ground target and spatial position 67 | self.target_masks = [] 68 | self.position_masks = [] 69 | 70 | self.sentense_raw = [] 71 | self.pp_phrase = [] 72 | 73 | # debug 74 | self.max_len = 0 75 | 76 | # for RRSIS-D dataset 77 | self.target_cls = {"airplane", "airport", "golf field", "expressway service area", "baseball field","stadium", 78 | "ground track field", "storage tank", "basketball court", "chimney", "tennis court", "overpass", 79 | "train station", "ship", "expressway toll station", "dam", "harbor", "bridge", "vehicle", 80 | "windmill"} 81 | 82 | self.eval_mode = eval_mode 83 | # if we are testing on a dataset, test all sentences of an object; 84 | # o/w, we are validating during training, randomly sample one sentence for efficiency 85 | for r in ref_ids: 86 | ref = self.refer.Refs[r] 87 | 88 | sentences_for_ref = [] 89 | attentions_for_ref = [] 90 | 91 | target_for_ref = [] 92 | position_for_ref = [] 93 | 94 | for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])): 95 | sentence_raw = el['raw'] 96 | attention_mask = [0] * self.max_tokens 97 | padded_input_ids = [0] * self.max_tokens 98 | 99 | target_masks = [0] * self.max_tokens 100 | position_masks = [0] * self.max_tokens 101 | 102 | input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) 103 | 104 | # truncation of tokens 105 | input_ids = input_ids[:self.max_tokens] 106 | 107 | padded_input_ids[:len(input_ids)] = input_ids 108 | attention_mask[:len(input_ids)] = [1]*len(input_ids) 109 | 110 | sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) 111 | attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) 112 | 113 | # extract the ground object 114 | # print(sentence_raw) 115 | self.sentense_raw.append(sentence_raw) 116 | tokenized_sentence = word_tokenize(sentence_raw) 117 | 118 | for cls in self.target_cls: 119 | if re.findall(cls, sentence_raw): 120 | 121 | tokenized_cls = word_tokenize(cls) 122 | nums_cls = len(tokenized_cls) 123 | index = 0 124 | for i, token in enumerate(tokenized_sentence): 125 | if re.findall(tokenized_cls[0], token): 126 | index = i 127 | break 128 | target_masks[index + 1: index + nums_cls +1] = [1] * nums_cls 129 | # print(target_masks) 130 | 131 | target_for_ref.append(torch.tensor(target_masks).unsqueeze(0)) 132 | 133 | # extract the spatial position 134 | grammar = r""" 135 | PP: {
??} 136 | {
??} 137 | {
?} 138 | """ 139 | chunkr = nltk.RegexpParser(grammar) 140 | # grammar parsing 141 | tree = chunkr.parse(nltk.pos_tag(tokenized_sentence)) 142 | pp_phrases = [] 143 | for subtree in tree.subtrees(): 144 | if subtree.label() == 'PP': 145 | pp_phrases.append(' '.join(word for word, pos in subtree.leaves())) 146 | 147 | new_pp_phrase = [] 148 | for phrase in pp_phrases: 149 | if not re.findall("of", phrase): 150 | new_pp_phrase.append(phrase) 151 | 152 | if len(new_pp_phrase) > 0: 153 | tokenized_sentence = word_tokenize(sentence_raw) 154 | for pp in new_pp_phrase: 155 | tokenized_pos = word_tokenize(pp) 156 | nums_pos = len(tokenized_pos) 157 | index = 0 158 | for i, token in enumerate(tokenized_sentence): 159 | if tokenized_pos[0] == token: 160 | index = i 161 | break 162 | position_masks[index + 1: index + nums_pos +1] = [1] * nums_pos 163 | 164 | self.pp_phrase.append(new_pp_phrase) 165 | position_for_ref.append(torch.tensor(position_masks).unsqueeze(0)) 166 | # if there are no pp, the position_for_ref equals attentions_for_ref 167 | if torch.sum(position_for_ref[0]) == 0: 168 | position_for_ref = attentions_for_ref 169 | 170 | self.input_ids.append(sentences_for_ref) 171 | self.attention_masks.append(attentions_for_ref) 172 | self.target_masks.append(target_for_ref) 173 | self.position_masks.append(position_for_ref) 174 | 175 | 176 | def get_classes(self): 177 | return self.classes 178 | 179 | def __len__(self): 180 | return len(self.ref_ids) 181 | 182 | def __getitem__(self, index): 183 | this_ref_id = self.ref_ids[index] 184 | this_img_id = self.refer.getImgIds(this_ref_id) 185 | this_img = self.refer.Imgs[this_img_id[0]] 186 | 187 | img = Image.open(os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])) 188 | if self.split == 'train' and this_ref_id in self.images_to_mask: 189 | img = add_random_boxes(img) 190 | 191 | ref = self.refer.loadRefs(this_ref_id) 192 | 193 | ref_mask = np.array(self.refer.getMask(ref[0])['mask']) 194 | annot = np.zeros(ref_mask.shape) 195 | annot[ref_mask == 1] = 1 196 | 197 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 198 | 199 | sentence = ref[0]['sentences'][0]['raw'] 200 | save_prefix = str(ref[0]['image_id']) + "_" + sentence 201 | 202 | if self.image_transforms is not None: 203 | # resize, from PIL to tensor, and mean and std normalization 204 | # Leisen Debug: write the input images and labels 205 | SHOW_INPUT = False 206 | if SHOW_INPUT: 207 | import cv2 208 | save_dir = "experiments/input_vis" 209 | 210 | # write in the type of image and label 211 | img.save(os.path.join(save_dir, save_prefix + "_image.png")) 212 | mask = ref_mask * 255 213 | cv2.imwrite(os.path.join(save_dir, save_prefix + "_label.png"), mask) 214 | 215 | img, target = self.image_transforms(img, annot) 216 | 217 | choice_sent = np.random.choice(len(self.input_ids[index])) 218 | tensor_embeddings = self.input_ids[index][choice_sent] 219 | attention_mask = self.attention_masks[index][choice_sent] 220 | target_mask = self.target_masks[index][choice_sent] 221 | position_mask = self.position_masks[index][choice_sent] 222 | 223 | # bebug 224 | # print(img.size(), target.size(), tensor_embeddings.size(), attention_mask.size(), 225 | # target_masks.size(), position_masks.size()) 226 | # print(self.sentense_raw[index]) 227 | # print(self.pp_phrase[index]) 228 | # print(self.position_masks[index]) 229 | return img, target, tensor_embeddings, attention_mask, target_mask, position_mask, save_prefix 230 | 231 | 232 | 
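Editor's note: the block below is an illustrative usage sketch for the ReferDataset defined above, not part of the repository. It assumes the argument names the class actually reads from `args` (refer_data_root, dataset, splitBy, bert_tokenizer); the concrete values, the 480x480 target size, and the `joint_transform` helper are placeholders standing in for the project's own args.py and transforms.py.

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

from args import get_parser
from data.rrsisd_refer_bert import ReferDataset


def joint_transform(img, annot):
    # Hypothetical stand-in for the repository's transforms.py pipeline:
    # resize image (assumed RGB) and mask jointly, then convert to tensors.
    img = img.resize((480, 480), Image.BILINEAR)
    annot = annot.resize((480, 480), Image.NEAREST)
    img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
    annot = torch.from_numpy(np.array(annot)).long()
    return img, annot


args = get_parser().parse_args([])                 # assumes every option has a default
args.dataset, args.splitBy = 'rrsisd', 'unc'       # assumed dataset / split identifiers
args.refer_data_root = './refer/data'              # placeholder data root
args.bert_tokenizer = 'bert-base-uncased'          # assumed HuggingFace tokenizer name

train_ds = ReferDataset(args, image_transforms=joint_transform, split='train')
loader = DataLoader(train_ds, batch_size=2, shuffle=True)

img, target, emb, att_mask, tgt_mask, pos_mask, prefix = train_ds[0]
# img: (3, 480, 480) float image; target: (480, 480) segmentation mask;
# emb / att_mask / tgt_mask / pos_mask: (1, 22) token-id, attention,
# ground-object and spatial-position masks (max_tokens = 22);
# prefix: "<image_id>_<raw sentence>" string used when saving visualisations.

Note that, as written, __getitem__ draws one sentence per referred object at random via np.random.choice, so repeated reads of the same index can return different expressions.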
-------------------------------------------------------------------------------- /lib/sa/functions/subtraction_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction_zeropad_forward_kernel( 25 | const ${Dtype}* bottom_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 40 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 41 | top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; 42 | } 43 | else 44 | top_data[offset_top] = bottom_data[offset_center]; 45 | } 46 | } 47 | } 48 | } 49 | ''' 50 | 51 | 52 | _subtraction_zeropad_input_backward_kernel = kernel_loop + ''' 53 | extern "C" 54 | __global__ void subtraction_zeropad_input_backward_kernel( 55 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 56 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 57 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 58 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 59 | const int h = (index / ${bottom_width}) % ${bottom_height}; 60 | const int w = index % ${bottom_width}; 61 | ${Dtype} value = 0; 62 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 63 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 64 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 65 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 66 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 67 | const int h_out = h_out_s / ${stride_h}; 68 | const int w_out = w_out_s / ${stride_w}; 69 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 70 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 71 | value 
+= -top_diff[offset_top]; 72 | } 73 | } 74 | } 75 | } 76 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 77 | const int h_out = h / ${stride_h}; 78 | const int w_out = w / ${stride_w}; 79 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 80 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 81 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 82 | value += top_diff[offset_top]; 83 | } 84 | } 85 | } 86 | bottom_diff[index] = value; 87 | } 88 | } 89 | ''' 90 | 91 | 92 | class SubtractionZeropad(Function): 93 | @staticmethod 94 | def forward(ctx, input, kernel_size, stride, padding, dilation): 95 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 96 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 97 | assert input.dim() == 4 and input.is_cuda 98 | batch_size, input_channels, input_height, input_width = input.size() 99 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 100 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 101 | output = input.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 102 | n = output.numel() // output.shape[2] 103 | with torch.cuda.device_of(input): 104 | f = load_kernel('subtraction_zeropad_forward_kernel', _subtraction_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n, 105 | num=batch_size, input_channels=input_channels, 106 | bottom_height=input_height, bottom_width=input_width, 107 | top_height=output_height, top_width=output_width, 108 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 109 | stride_h=stride[0], stride_w=stride[1], 110 | dilation_h=dilation[0], dilation_w=dilation[1], 111 | pad_h=padding[0], pad_w=padding[1]) 112 | f(block=(CUDA_NUM_THREADS, 1, 1), 113 | grid=(GET_BLOCKS(n), 1, 1), 114 | args=[input.data_ptr(), output.data_ptr()], 115 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 116 | ctx.save_for_backward(input) 117 | return output 118 | 119 | @staticmethod 120 | def backward(ctx, grad_output): 121 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 122 | input, = ctx.saved_tensors 123 | assert grad_output.is_cuda 124 | if not grad_output.is_contiguous(): 125 | grad_output = grad_output.contiguous() 126 | batch_size, input_channels, input_height, input_width = input.size() 127 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 128 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 129 | grad_input = None 130 | opt = dict(Dtype=Dtype(grad_output), 131 | num=batch_size, input_channels=input_channels, 132 | bottom_height=input_height, bottom_width=input_width, 133 | top_height=output_height, top_width=output_width, 134 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 135 | stride_h=stride[0], stride_w=stride[1], 136 | dilation_h=dilation[0], dilation_w=dilation[1], 137 | pad_h=padding[0], pad_w=padding[1]) 138 | with torch.cuda.device_of(input): 139 | if ctx.needs_input_grad[0]: 140 | grad_input = input.new(input.size()) 141 | n = grad_input.numel() 142 | opt['nthreads'] = n 143 | f = load_kernel('subtraction_zeropad_input_backward_kernel', 
_subtraction_zeropad_input_backward_kernel, **opt) 144 | f(block=(CUDA_NUM_THREADS, 1, 1), 145 | grid=(GET_BLOCKS(n), 1, 1), 146 | args=[grad_output.data_ptr(), grad_input.data_ptr()], 147 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 148 | return grad_input, None, None, None, None 149 | 150 | 151 | def subtraction_zeropad(input, kernel_size=3, stride=1, padding=0, dilation=1): 152 | assert input.dim() == 4 153 | if input.is_cuda: 154 | out = SubtractionZeropad.apply(input, kernel_size, stride, padding, dilation) 155 | else: 156 | raise NotImplementedError 157 | return out 158 | 159 | 160 | def test_subtraction_zeropad(): 161 | import os 162 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 163 | kernel_size, stride, dilation = 5, 4, 2 164 | padding = (dilation * (kernel_size - 1) + 1) // 2 165 | n, c, in_height, in_width = 2, 8, 9, 9 166 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 167 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 168 | x = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 169 | 170 | y1 = subtraction_zeropad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 171 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 172 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 173 | y2 = unfold_i(x).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 174 | # y2 = unfold_i(x[:, :, kernel_size//2:-(kernel_size//2), kernel_size//2:-(kernel_size//2)]).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 175 | assert (y1 - y2).abs().max() < 1e-9 176 | 177 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 178 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 179 | assert (gx1 - gx2).abs().max() < 1e-9 180 | 181 | from functools import partial 182 | assert torch.autograd.gradcheck(partial(subtraction_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), x) 183 | print('test case passed') 184 | 185 | 186 | if __name__ == '__main__': 187 | test_subtraction_zeropad() 188 | -------------------------------------------------------------------------------- /lib/text_aware_multiscale_enhancement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from backbone import SpatialImageLanguageAttention 5 | 6 | class SpatialImageLanguageAttention(nn.Module): 7 | def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1): 8 | super(SpatialImageLanguageAttention, self).__init__() 9 | # x shape: (B, H*W, v_in_channels) 10 | # l input shape: (B, l_in_channels, N_l) 11 | # l_mask shape: (B, N_l, 1) 12 | self.v_in_channels = v_in_channels 13 | self.l_in_channels = l_in_channels 14 | self.out_channels = out_channels 15 | self.key_channels = key_channels 16 | self.value_channels = value_channels 17 | self.num_heads = num_heads 18 | if out_channels is None: 19 | self.out_channels = self.value_channels 20 | 21 | # Keys: language features: (B, l_in_channels, #words) 22 | # avoid any form of spatial normalization because a sentence contains many padding 0s 23 | self.f_key = nn.Sequential( 24 | nn.Conv1d(self.l_in_channels, 
self.key_channels, kernel_size=1, stride=1), 25 | ) 26 | 27 | # Queries: visual features: (B, H*W, v_in_channels) 28 | self.f_query = nn.Sequential( 29 | nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1), 30 | nn.InstanceNorm1d(self.key_channels), 31 | ) 32 | 33 | # Values: language features: (B, l_in_channels, #words) 34 | self.f_value = nn.Sequential( 35 | nn.Conv1d(self.l_in_channels, self.value_channels, kernel_size=1, stride=1), 36 | ) 37 | 38 | # Out projection 39 | self.W = nn.Sequential( 40 | nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1), 41 | nn.InstanceNorm1d(self.out_channels), 42 | ) 43 | 44 | def forward(self, x, l, l_mask): 45 | # x shape: (B, H*W, v_in_channels) 46 | # l input shape: (B, l_in_channels, N_l) 47 | # l_mask shape: (B, N_l, 1) 48 | B, HW = x.size(0), x.size(1) 49 | x = x.permute(0, 2, 1) # (B, key_channels, H*W) 50 | l_mask = l_mask.permute(0, 2, 1) # (B, N_l, 1) -> (B, 1, N_l) 51 | 52 | query = self.f_query(x) # (B, key_channels, H*W) if Conv1D 53 | query = query.permute(0, 2, 1) # (B, H*W, key_channels) 54 | key = self.f_key(l) # (B, key_channels, N_l) 55 | value = self.f_value(l) # (B, self.value_channels, N_l) 56 | key = key * l_mask # (B, key_channels, N_l) 57 | value = value * l_mask # (B, self.value_channels, N_l) 58 | n_l = value.size(-1) 59 | query = query.reshape(B, HW, self.num_heads, self.key_channels//self.num_heads).permute(0, 2, 1, 3) 60 | # (b, num_heads, H*W, self.key_channels//self.num_heads) 61 | key = key.reshape(B, self.num_heads, self.key_channels//self.num_heads, n_l) 62 | # (b, num_heads, self.key_channels//self.num_heads, n_l) 63 | value = value.reshape(B, self.num_heads, self.value_channels//self.num_heads, n_l) 64 | # # (b, num_heads, self.value_channels//self.num_heads, n_l) 65 | l_mask = l_mask.unsqueeze(1) # (b, 1, 1, n_l) 66 | 67 | sim_map = torch.matmul(query, key) # (B, self.num_heads, H*W, N_l) 68 | sim_map = (self.key_channels ** -.5) * sim_map # scaled dot product 69 | 70 | sim_map = sim_map + (1e4*l_mask - 1e4) # assign a very small number to padding positions 71 | sim_map = F.softmax(sim_map, dim=-1) # (B, num_heads, h*w, N_l) 72 | out = torch.matmul(sim_map, value.permute(0, 1, 3, 2)) # (B, num_heads, H*W, self.value_channels//num_heads) 73 | out = out.permute(0, 2, 1, 3).contiguous().reshape(B, HW, self.value_channels) # (B, H*W, value_channels) 74 | out = out.permute(0, 2, 1) # (B, value_channels, HW) 75 | out = self.W(out) # (B, value_channels, HW) 76 | out = out.permute(0, 2, 1) # (B, HW, value_channels) 77 | 78 | return out 79 | 80 | 81 | 82 | class h_sigmoid(nn.Module): 83 | def __init__(self, inplace=True): 84 | super(h_sigmoid, self).__init__() 85 | self.relu = nn.ReLU6(inplace=inplace) 86 | 87 | def forward(self, x): 88 | return self.relu(x + 3) / 6 89 | 90 | 91 | class Linear_BN(torch.nn.Sequential): 92 | def __init__(self, a, b, bn_weight_init=1): 93 | super().__init__() 94 | self.add_module('c', torch.nn.Linear(a, b, bias=False)) 95 | bn = torch.nn.BatchNorm1d(b) 96 | torch.nn.init.constant_(bn.weight, bn_weight_init) 97 | torch.nn.init.constant_(bn.bias, 0) 98 | self.add_module('bn', bn) 99 | 100 | @torch.no_grad() 101 | def fuse(self): 102 | l, bn = self._modules.values() 103 | w = bn.weight / (bn.running_var + bn.eps)**0.5 104 | w = l.weight * w[:, None] 105 | b = bn.bias - bn.running_mean * bn.weight / \ 106 | (bn.running_var + bn.eps)**0.5 107 | m = torch.nn.Linear(w.size(1), w.size(0)) 108 | m.weight.data.copy_(w) 109 | m.bias.data.copy_(b) 110 | return 
m 111 | 112 | def forward(self, x): 113 | l, bn = self._modules.values() 114 | x = l(x) 115 | return bn(x.flatten(0, 1)).reshape_as(x) 116 | 117 | 118 | class Residual(torch.nn.Module): 119 | def __init__(self, m): 120 | super().__init__() 121 | self.m = m 122 | 123 | def forward(self, x): 124 | return x + self.m(x) 125 | 126 | 127 | class ScaleAwareGate(nn.Module): 128 | def __init__(self, inp, oup): 129 | super(ScaleAwareGate, self).__init__() 130 | 131 | self.local_embedding = nn.Conv2d(inp, oup, kernel_size=1) 132 | self.bn1 = nn.BatchNorm2d(oup) 133 | 134 | self.global_embedding = nn.Conv2d(inp, oup, kernel_size=1) 135 | self.bn2 = nn.BatchNorm2d(oup) 136 | 137 | self.global_act = nn.Conv2d(inp, oup, kernel_size=1) 138 | self.bn3 = nn.BatchNorm2d(oup) 139 | self.act = h_sigmoid() 140 | 141 | def forward(self, x_l, x_g): 142 | B, C, H, W = x_l.shape 143 | local_feat = self.local_embedding(x_l) 144 | local_feat = self.bn1(local_feat) 145 | 146 | global_feat = self.global_embedding(x_g) 147 | global_feat = self.bn2(global_feat) 148 | global_feat = F.interpolate(global_feat, size=(H, W), mode='bilinear', align_corners=False) 149 | 150 | global_act = self.global_act(x_g) 151 | global_act = self.bn3(global_act) 152 | sig_act = F.interpolate(self.act(global_act), size=(H, W), mode='bilinear', align_corners=False) 153 | 154 | out = local_feat * sig_act + global_feat 155 | return out 156 | 157 | 158 | class FeedForward(nn.Module): 159 | def __init__(self, dim, hidden_dim, dropout = 0.): 160 | super().__init__() 161 | self.net = nn.Sequential( 162 | nn.Linear(dim, hidden_dim), 163 | nn.GELU(), 164 | nn.Dropout(dropout), 165 | nn.Linear(hidden_dim, dim), 166 | nn.Dropout(dropout) 167 | ) 168 | 169 | def forward(self, x): 170 | return self.net(x) 171 | 172 | 173 | class TMEMBlock(nn.Module): 174 | def __init__(self, dim, channels, mlp_ratio=2): 175 | super().__init__() 176 | # self.csa1 = Residual(CrossScaleAttention(dim)) 177 | # self.intra_ff = Residual(IntraFeedForward(channels, mlp_ratio)) 178 | # self.csa2 = Residual(CrossScaleAttention(dim)) 179 | # self.ff = Residual(FeedForward(dim, dim*mlp_ratio)) 180 | 181 | self.norm1 = nn.LayerNorm(dim) 182 | self.tma = SpatialImageLanguageAttention(dim, 768, dim, dim) 183 | 184 | self.norm2 = nn.LayerNorm(dim) 185 | self.ff = FeedForward(dim, 2*dim) 186 | 187 | def forward(self, x, l, l_mask): 188 | B, C, H, W = x.shape 189 | x = x.flatten(2).transpose(1, 2) 190 | out = self.norm1(x) 191 | out = x + self.tma(out, l, l_mask) 192 | 193 | # out = self.norm2(out) 194 | 195 | out = out + self.ff(self.norm2(out)) 196 | out = out.reshape(B, H, W, C).permute(0, 3, 1, 2) 197 | 198 | return out 199 | 200 | 201 | 202 | class PyramidPoolAgg(nn.Module): 203 | def __init__(self, stride): 204 | super().__init__() 205 | self.stride = stride 206 | 207 | def forward(self, inputs): 208 | B, C, H, W = inputs[-1].shape 209 | H = (H - 1) // self.stride + 1 210 | W = (W - 1) // self.stride + 1 211 | return torch.cat([nn.functional.adaptive_avg_pool2d(inp, (H, W)) for inp in inputs], dim=1) 212 | 213 | 214 | class TMEM(nn.Module): 215 | def __init__(self, dim, num_blocks=1, channels=[128, 256, 512, 1024], downsample=1): 216 | super().__init__() 217 | self.hidden_dim = dim // 4 218 | self.channels = channels 219 | self.stride = downsample 220 | 221 | self.down_channel = nn.Conv2d(dim, self.hidden_dim, 1) 222 | self.up_channel = nn.Conv2d(self.hidden_dim, dim, 1) 223 | 224 | # downsample to h/32, w/32 225 | self.pool = PyramidPoolAgg(stride=self.stride) 226 | self.block = 
nn.ModuleList([ 227 | TMEMBlock(self.hidden_dim, channels) 228 | for _ in range(num_blocks) 229 | ]) 230 | self.bn = nn.BatchNorm2d(self.hidden_dim) 231 | self.fusion = nn.ModuleList([ 232 | ScaleAwareGate(channels[i], channels[i]) 233 | for i in range(len(channels)) 234 | ]) 235 | 236 | def forward(self, input, l, l_mask): # [B, C, H, W] 237 | out = self.pool(input) 238 | out = self.down_channel(out) 239 | for layer in self.block: 240 | out = layer(out, l, l_mask) 241 | out = self.bn(out) 242 | out = self.up_channel(out) 243 | xx = out.split(self.channels, dim=1) 244 | results = [] 245 | for i in range(len(self.channels)): 246 | TMEM_before = input[i] 247 | TMEM_after = xx[i] 248 | out_ = self.fusion[i](TMEM_before, TMEM_after) 249 | results.append(out_) 250 | return results 251 | 252 | 253 | if __name__ == '__main__': 254 | model = TMEM(1920) 255 | # model = CIM(1920) 256 | l = torch.randn(2, 768, 20) 257 | l_mask = torch.ones(2, 20, 1) 258 | x1 = torch.randn(2, 128, 120, 120) 259 | x2 = torch.randn(2, 256, 60, 60) 260 | x3 = torch.randn(2, 512, 30, 30) 261 | x4 = torch.randn(2, 1024, 15, 15) 262 | x = tuple([x1, x2, x3, x4]) 263 | y = model(x, l, l_mask) 264 | # y = model(x) 265 | print(y.shape) -------------------------------------------------------------------------------- /lib/sa/functions/subtraction_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction_refpad_forward_kernel( 25 | const ${Dtype}* bottom_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | int offset_bottom; 40 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 41 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 42 | } 43 | else { 44 | if (h_in < 0) h_in = -h_in; 45 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 46 | if (w_in < 0) w_in = -w_in; 47 | if (w_in >= ${bottom_width}) w_in = 2 * 
(${bottom_width} - 1) - w_in; 48 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 49 | } 50 | top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; 51 | } 52 | } 53 | } 54 | } 55 | ''' 56 | 57 | 58 | _subtraction_refpad_input_backward_kernel = kernel_loop + ''' 59 | extern "C" 60 | __global__ void subtraction_refpad_input_backward_kernel( 61 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 62 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 63 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 64 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 65 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 66 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 67 | ${Dtype} value = 0; 68 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 69 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 70 | const int h_out_s = h - kh * ${dilation_h}; 71 | const int w_out_s = w - kw * ${dilation_w}; 72 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 73 | const int h_out = h_out_s / ${stride_h}; 74 | const int w_out = w_out_s / ${stride_w}; 75 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 76 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 77 | value += -top_diff[offset_top]; 78 | } 79 | } 80 | } 81 | } 82 | const int hh = h - ${pad_h}; 83 | const int ww = w - ${pad_w}; 84 | if ((hh >= 0) && (hh < ${bottom_height}) && (ww >= 0) && (ww < ${bottom_width})) { 85 | if (((hh % ${stride_h}) == 0) && ((ww % ${stride_w}) == 0)) { 86 | const int h_out = hh / ${stride_h}; 87 | const int w_out = ww / ${stride_w}; 88 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 89 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 90 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 91 | value += top_diff[offset_top]; 92 | } 93 | } 94 | } 95 | } 96 | bottom_diff[index] = value; 97 | } 98 | } 99 | ''' 100 | 101 | 102 | class SubtractionRefpad(Function): 103 | @staticmethod 104 | def forward(ctx, input, kernel_size, stride, padding, dilation): 105 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 106 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 107 | assert input.dim() == 4 and input.is_cuda 108 | batch_size, input_channels, input_height, input_width = input.size() 109 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 110 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 111 | output = input.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 112 | n = output.numel() // output.shape[2] 113 | with torch.cuda.device_of(input): 114 | f = load_kernel('subtraction_refpad_forward_kernel', _subtraction_refpad_forward_kernel, Dtype=Dtype(input), nthreads=n, 115 | num=batch_size, input_channels=input_channels, 116 | bottom_height=input_height, bottom_width=input_width, 117 | top_height=output_height, top_width=output_width, 118 
| kernel_h=kernel_size[0], kernel_w=kernel_size[1], 119 | stride_h=stride[0], stride_w=stride[1], 120 | dilation_h=dilation[0], dilation_w=dilation[1], 121 | pad_h=padding[0], pad_w=padding[1]) 122 | f(block=(CUDA_NUM_THREADS, 1, 1), 123 | grid=(GET_BLOCKS(n), 1, 1), 124 | args=[input.data_ptr(), output.data_ptr()], 125 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 126 | ctx.save_for_backward(input) 127 | return output 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 132 | input, = ctx.saved_tensors 133 | assert grad_output.is_cuda 134 | if not grad_output.is_contiguous(): 135 | grad_output = grad_output.contiguous() 136 | batch_size, input_channels, input_height, input_width = input.size() 137 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 138 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 139 | grad_input = None 140 | opt = dict(Dtype=Dtype(grad_output), 141 | num=batch_size, input_channels=input_channels, 142 | bottom_height=input_height, bottom_width=input_width, 143 | top_height=output_height, top_width=output_width, 144 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 145 | stride_h=stride[0], stride_w=stride[1], 146 | dilation_h=dilation[0], dilation_w=dilation[1], 147 | pad_h=padding[0], pad_w=padding[1]) 148 | with torch.cuda.device_of(input): 149 | if ctx.needs_input_grad[0]: 150 | grad_input = input.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 151 | n = grad_input.numel() 152 | opt['nthreads'] = n 153 | f = load_kernel('subtraction_refpad_input_backward_kernel', _subtraction_refpad_input_backward_kernel, **opt) 154 | f(block=(CUDA_NUM_THREADS, 1, 1), 155 | grid=(GET_BLOCKS(n), 1, 1), 156 | args=[grad_output.data_ptr(), grad_input.data_ptr()], 157 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 158 | grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input[:, :, :padding[0], :], dims=[2]) 159 | grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input[:, :, input_height + padding[0]:, :], dims=[2]) 160 | grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input[:, :, :, :padding[1]], dims=[3]) 161 | grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input[:, :, :, input_width + padding[1]:], dims=[3]) 162 | grad_input = grad_input[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] 163 | return grad_input, None, None, None, None 164 | 165 | 166 | def subtraction_refpad(input, kernel_size=3, stride=1, padding=0, dilation=1): 167 | assert input.dim() == 4 168 | if input.is_cuda: 169 | out = SubtractionRefpad.apply(input, kernel_size, stride, padding, dilation) 170 | else: 171 | raise NotImplementedError 172 | return out 173 | 174 | 175 | def test_subtraction_refpad(): 176 | import os 177 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 178 | kernel_size, stride, dilation = 5, 4, 2 179 | padding = (dilation * (kernel_size - 1) + 1) // 2 180 | n, c, in_height, in_width = 2, 8, 5, 5 181 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 182 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 183 | x = torch.randn(n, c, in_height, in_width, 
requires_grad=True).double().cuda() 184 | 185 | y1 = subtraction_refpad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 186 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 187 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 188 | pad = torch.nn.ReflectionPad2d(padding) 189 | y2 = unfold_i(x).view(n, c, 1, out_height * out_width) - unfold_j(pad(x)).view(n, c, pow(kernel_size, 2), out_height * out_width) 190 | assert (y1 - y2).abs().max() < 1e-9 191 | 192 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 193 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 194 | assert (gx1 - gx2).abs().max() < 1e-9 195 | 196 | from functools import partial 197 | assert torch.autograd.gradcheck(partial(subtraction_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), x) 198 | print('test case passed') 199 | 200 | 201 | if __name__ == '__main__': 202 | test_subtraction_refpad() 203 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction2_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction2_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction2_zeropad_forward_kernel( 25 | const ${Dtype}* bottom1_data, const ${Dtype}* bottom2_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 40 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 41 | top_data[offset_top] = bottom1_data[offset_center] - bottom2_data[offset_bottom]; 42 | } 43 | else 44 | top_data[offset_top] = bottom1_data[offset_center]; 45 | } 46 | } 47 | } 48 | } 49 | ''' 50 | 51 | 52 | _subtraction2_zeropad_input1_backward_kernel = kernel_loop + ''' 53 | extern "C" 54 | __global__ void 
subtraction2_zeropad_input1_backward_kernel( 55 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 56 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 57 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 58 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 59 | const int h = (index / ${bottom_width}) % ${bottom_height}; 60 | const int w = index % ${bottom_width}; 61 | ${Dtype} value = 0; 62 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 63 | const int h_out = h / ${stride_h}; 64 | const int w_out = w / ${stride_w}; 65 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 66 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 67 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 68 | value += top_diff[offset_top]; 69 | } 70 | } 71 | } 72 | bottom_diff[index] = value; 73 | } 74 | } 75 | ''' 76 | 77 | 78 | _subtraction2_zeropad_input2_backward_kernel = kernel_loop + ''' 79 | extern "C" 80 | __global__ void subtraction2_zeropad_input2_backward_kernel( 81 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 82 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 83 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 84 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 85 | const int h = (index / ${bottom_width}) % ${bottom_height}; 86 | const int w = index % ${bottom_width}; 87 | ${Dtype} value = 0; 88 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 89 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 90 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 91 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 92 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 93 | const int h_out = h_out_s / ${stride_h}; 94 | const int w_out = w_out_s / ${stride_w}; 95 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 96 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 97 | value += -top_diff[offset_top]; 98 | } 99 | } 100 | } 101 | } 102 | bottom_diff[index] = value; 103 | } 104 | } 105 | ''' 106 | 107 | 108 | class Subtraction2Zeropad(Function): 109 | @staticmethod 110 | def forward(ctx, input1, input2, kernel_size, stride, padding, dilation): 111 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 112 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 113 | assert input1.dim() == 4 and input1.is_cuda 114 | batch_size, input_channels, input_height, input_width = input1.size() 115 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 116 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 117 | output = input1.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 118 | n = output.numel() // output.shape[2] 119 | with torch.cuda.device_of(input1): 120 | f = load_kernel('subtraction2_zeropad_forward_kernel', _subtraction2_zeropad_forward_kernel, Dtype=Dtype(input1), nthreads=n, 121 | num=batch_size, input_channels=input_channels, 122 | bottom_height=input_height, bottom_width=input_width, 123 | top_height=output_height, 
top_width=output_width, 124 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 125 | stride_h=stride[0], stride_w=stride[1], 126 | dilation_h=dilation[0], dilation_w=dilation[1], 127 | pad_h=padding[0], pad_w=padding[1]) 128 | f(block=(CUDA_NUM_THREADS, 1, 1), 129 | grid=(GET_BLOCKS(n), 1, 1), 130 | args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()], 131 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 132 | ctx.save_for_backward(input1, input2) 133 | return output 134 | 135 | @staticmethod 136 | def backward(ctx, grad_output): 137 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 138 | input1, input2 = ctx.saved_tensors 139 | assert grad_output.is_cuda 140 | if not grad_output.is_contiguous(): 141 | grad_output = grad_output.contiguous() 142 | batch_size, input_channels, input_height, input_width = input1.size() 143 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 144 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 145 | grad_input1, grad_input2 = None, None 146 | opt = dict(Dtype=Dtype(grad_output), 147 | num=batch_size, input_channels=input_channels, 148 | bottom_height=input_height, bottom_width=input_width, 149 | top_height=output_height, top_width=output_width, 150 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 151 | stride_h=stride[0], stride_w=stride[1], 152 | dilation_h=dilation[0], dilation_w=dilation[1], 153 | pad_h=padding[0], pad_w=padding[1]) 154 | with torch.cuda.device_of(input1): 155 | if ctx.needs_input_grad[0]: 156 | grad_input1 = input1.new(input1.size()) 157 | n = grad_input1.numel() 158 | opt['nthreads'] = n 159 | f = load_kernel('subtraction2_zeropad_input1_backward_kernel', _subtraction2_zeropad_input1_backward_kernel, **opt) 160 | f(block=(CUDA_NUM_THREADS, 1, 1), 161 | grid=(GET_BLOCKS(n), 1, 1), 162 | args=[grad_output.data_ptr(), grad_input1.data_ptr()], 163 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 164 | with torch.cuda.device_of(input2): 165 | if ctx.needs_input_grad[1]: 166 | grad_input2 = input2.new(input2.size()) 167 | n = grad_input2.numel() 168 | opt['nthreads'] = n 169 | f = load_kernel('subtraction2_zeropad_input2_backward_kernel', _subtraction2_zeropad_input2_backward_kernel, **opt) 170 | f(block=(CUDA_NUM_THREADS, 1, 1), 171 | grid=(GET_BLOCKS(n), 1, 1), 172 | args=[grad_output.data_ptr(), grad_input2.data_ptr()], 173 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 174 | return grad_input1, grad_input2, None, None, None, None 175 | 176 | 177 | def subtraction2_zeropad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): 178 | assert input1.dim() == 4 179 | if input1.is_cuda: 180 | out = Subtraction2Zeropad.apply(input1, input2, kernel_size, stride, padding, dilation) 181 | else: 182 | raise NotImplementedError 183 | return out 184 | 185 | 186 | def test_subtraction2_zeropad(): 187 | import os 188 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 189 | kernel_size, stride, dilation = 5, 4, 2 190 | padding = (dilation * (kernel_size - 1) + 1) // 2 191 | n, c, in_height, in_width = 2, 8, 9, 9 192 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 193 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 194 | x1 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 195 | x2 = torch.randn(n, c, in_height, 
in_width, requires_grad=True).double().cuda() 196 | 197 | y1 = subtraction2_zeropad(x1, x2, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 198 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 199 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 200 | y2 = unfold_i(x1).view(n, c, 1, out_height * out_width) - unfold_j(x2).view(n, c, pow(kernel_size, 2), out_height * out_width) 201 | # y2 = unfold_i(x[:, :, kernel_size//2:-(kernel_size//2), kernel_size//2:-(kernel_size//2)]).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 202 | assert (y1 - y2).abs().max() < 1e-9 203 | 204 | gx11 = torch.autograd.grad(y1.mean(), x1, retain_graph=True)[0] 205 | gx12 = torch.autograd.grad(y1.mean(), x2, retain_graph=True)[0] 206 | gx21 = torch.autograd.grad(y2.mean(), x1, retain_graph=True)[0] 207 | gx22 = torch.autograd.grad(y2.mean(), x2, retain_graph=True)[0] 208 | assert (gx11 - gx21).abs().max() < 1e-9 209 | assert (gx12 - gx22).abs().max() < 1e-9 210 | 211 | from functools import partial 212 | assert torch.autograd.gradcheck(partial(subtraction2_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x1, x2)) 213 | print('test case passed') 214 | 215 | 216 | if __name__ == '__main__': 217 | test_subtraction2_zeropad() 218 | -------------------------------------------------------------------------------- /lib/sa/functions/aggregation_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _aggregation_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void aggregation_zeropad_forward_kernel( 25 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | ${Dtype} value = 0; 32 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 33 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 34 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 35 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 36 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 37 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 38 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | value += weight_data[offset_weight] * bottom_data[offset_bottom]; 40 | } 41 | } 42 | } 43 | top_data[index] = value; 44 | } 45 | } 46 | ''' 47 | 48 | 49 | _aggregation_zeropad_input_backward_kernel = kernel_loop + ''' 50 | 
extern "C" 51 | __global__ void aggregation_zeropad_input_backward_kernel( 52 | const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* bottom_diff) { 53 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 54 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 55 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 56 | const int h = (index / ${bottom_width}) % ${bottom_height}; 57 | const int w = index % ${bottom_width}; 58 | ${Dtype} value = 0; 59 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 60 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 61 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 62 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 63 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 64 | const int h_out = h_out_s / ${stride_h}; 65 | const int w_out = w_out_s / ${stride_w}; 66 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 67 | const int offset_top = ((n * ${input_channels} + c) * ${top_height} + h_out) * ${top_width} + w_out; 68 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 69 | value += weight_data[offset_weight] * top_diff[offset_top]; 70 | } 71 | } 72 | } 73 | } 74 | bottom_diff[index] = value; 75 | } 76 | } 77 | ''' 78 | 79 | 80 | _aggregation_zeropad_weight_backward_kernel = kernel_loop + ''' 81 | extern "C" 82 | __global__ void aggregation_zeropad_weight_backward_kernel( 83 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* weight_diff) { 84 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 85 | const int n = index / ${weight_channels} / ${top_height} / ${top_width}; 86 | const int c = (index / ${top_height} / ${top_width}) % ${weight_channels}; 87 | const int h = (index / ${top_width}) % ${top_height}; 88 | const int w = index % ${top_width}; 89 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 90 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 91 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 92 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 93 | const int offset_weight = ((n * ${weight_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 94 | ${Dtype} value = 0; 95 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 96 | for (int cc = c; cc < ${input_channels}; cc += ${weight_channels}) { 97 | const int offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 98 | const int offset_top = ((n * ${input_channels} + cc) * ${top_height} + h) * ${top_width} + w; 99 | value += bottom_data[offset_bottom] * top_diff[offset_top]; 100 | } 101 | } 102 | weight_diff[offset_weight] = value; 103 | } 104 | } 105 | } 106 | } 107 | ''' 108 | 109 | 110 | class AggregationZeropad(Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, kernel_size, stride, padding, dilation): 113 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 114 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 115 | assert input.dim() == 4 and input.is_cuda and weight.is_cuda 116 | batch_size, input_channels, input_height, input_width = input.size() 117 | _, weight_channels, weight_height, 
weight_width = weight.size() 118 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 119 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 120 | assert output_height * output_width == weight_width 121 | output = input.new(batch_size, input_channels, output_height, output_width) 122 | n = output.numel() 123 | with torch.cuda.device_of(input): 124 | f = load_kernel('aggregation_zeropad_forward_kernel', _aggregation_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n, 125 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 126 | bottom_height=input_height, bottom_width=input_width, 127 | top_height=output_height, top_width=output_width, 128 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 129 | stride_h=stride[0], stride_w=stride[1], 130 | dilation_h=dilation[0], dilation_w=dilation[1], 131 | pad_h=padding[0], pad_w=padding[1]) 132 | f(block=(CUDA_NUM_THREADS, 1, 1), 133 | grid=(GET_BLOCKS(n), 1, 1), 134 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 135 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 136 | ctx.save_for_backward(input, weight) 137 | return output 138 | 139 | @staticmethod 140 | def backward(ctx, grad_output): 141 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 142 | input, weight = ctx.saved_tensors 143 | assert grad_output.is_cuda 144 | if not grad_output.is_contiguous(): 145 | grad_output = grad_output.contiguous() 146 | batch_size, input_channels, input_height, input_width = input.size() 147 | _, weight_channels, weight_height, weight_width = weight.size() 148 | output_height, output_width = grad_output.size()[2:] 149 | grad_input, grad_weight = None, None 150 | opt = dict(Dtype=Dtype(grad_output), 151 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 152 | bottom_height=input_height, bottom_width=input_width, 153 | top_height=output_height, top_width=output_width, 154 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 155 | stride_h=stride[0], stride_w=stride[1], 156 | dilation_h=dilation[0], dilation_w=dilation[1], 157 | pad_h=padding[0], pad_w=padding[1]) 158 | with torch.cuda.device_of(input): 159 | if ctx.needs_input_grad[0]: 160 | grad_input = input.new(input.size()) 161 | n = grad_input.numel() 162 | opt['nthreads'] = n 163 | f = load_kernel('aggregation_zeropad_input_backward_kernel', _aggregation_zeropad_input_backward_kernel, **opt) 164 | f(block=(CUDA_NUM_THREADS, 1, 1), 165 | grid=(GET_BLOCKS(n), 1, 1), 166 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 167 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 168 | if ctx.needs_input_grad[1]: 169 | grad_weight = weight.new(weight.size()) 170 | n = grad_weight.numel() // weight.shape[2] 171 | opt['nthreads'] = n 172 | f = load_kernel('aggregation_zeropad_weight_backward_kernel', _aggregation_zeropad_weight_backward_kernel, **opt) 173 | f(block=(CUDA_NUM_THREADS, 1, 1), 174 | grid=(GET_BLOCKS(n), 1, 1), 175 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 176 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 177 | return grad_input, grad_weight, None, None, None, None 178 | 179 | 180 | def aggregation_zeropad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1): 181 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) 182 | if 
input.is_cuda: 183 | out = AggregationZeropad.apply(input, weight, kernel_size, stride, padding, dilation) 184 | else: 185 | raise NotImplementedError 186 | return out 187 | 188 | 189 | def test_aggregation_zeropad(): 190 | import os 191 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 192 | kernel_size, stride, dilation = 5, 4, 2 193 | padding = (dilation * (kernel_size - 1) + 1) // 2 194 | n, c_x, c_w, in_height, in_width = 2, 8, 4, 9, 9 195 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 196 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 197 | x = torch.randn(n, c_x, in_height, in_width, requires_grad=True).double().cuda() 198 | w = torch.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True).double().cuda() 199 | 200 | y1 = aggregation_zeropad(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 201 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 202 | x2 = unfold_j(x).view(n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width) 203 | y2 = (w.unsqueeze(1) * x2).sum(-2).view(n, c_x, out_height, out_width) 204 | assert (y1 - y2).abs().max() < 1e-9 205 | 206 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 207 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 208 | assert (gx1 - gx2).abs().max() < 1e-9 209 | 210 | gw1 = torch.autograd.grad(y1.mean(), w, retain_graph=True)[0] 211 | gw2 = torch.autograd.grad(y2.mean(), w, retain_graph=True)[0] 212 | assert (gw1 - gw2).abs().max() < 1e-9 213 | 214 | from functools import partial 215 | assert torch.autograd.gradcheck(partial(aggregation_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x, w)) 216 | print('test case passed') 217 | 218 | 219 | if __name__ == '__main__': 220 | test_aggregation_zeropad() 221 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction2_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction2_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction2_refpad_forward_kernel( 25 | const ${Dtype}* bottom1_data, const ${Dtype}* bottom2_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for 
(int kw = 0; kw < ${kernel_w}; ++kw) { 36 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | int offset_bottom; 40 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 41 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 42 | } 43 | else { 44 | if (h_in < 0) h_in = -h_in; 45 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 46 | if (w_in < 0) w_in = -w_in; 47 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 48 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 49 | } 50 | top_data[offset_top] = bottom1_data[offset_center] - bottom2_data[offset_bottom]; 51 | } 52 | } 53 | } 54 | } 55 | ''' 56 | 57 | 58 | _subtraction2_refpad_input1_backward_kernel = kernel_loop + ''' 59 | extern "C" 60 | __global__ void subtraction2_refpad_input1_backward_kernel( 61 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 62 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 63 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 64 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 65 | const int h = (index / ${bottom_width}) % ${bottom_height}; 66 | const int w = index % ${bottom_width}; 67 | ${Dtype} value = 0; 68 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 69 | const int h_out = h / ${stride_h}; 70 | const int w_out = w / ${stride_w}; 71 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 72 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 73 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 74 | value += top_diff[offset_top]; 75 | } 76 | } 77 | } 78 | bottom_diff[index] = value; 79 | } 80 | } 81 | ''' 82 | 83 | 84 | _subtraction2_refpad_input2_backward_kernel = kernel_loop + ''' 85 | extern "C" 86 | __global__ void subtraction2_refpad_input2_backward_kernel( 87 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 88 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 89 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 90 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 91 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 92 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 93 | ${Dtype} value = 0; 94 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 95 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 96 | const int h_out_s = h - kh * ${dilation_h}; 97 | const int w_out_s = w - kw * ${dilation_w}; 98 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 99 | const int h_out = h_out_s / ${stride_h}; 100 | const int w_out = w_out_s / ${stride_w}; 101 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 102 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 103 | value += -top_diff[offset_top]; 104 | } 105 | } 106 | } 107 | } 108 | bottom_diff[index] = value; 
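    // note: each position of the padded canvas accumulates -top_diff from every kernel tap that reads it; the Python-side backward() below folds the padded borders back into the interior (reflection) and crops to the original size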
109 | } 110 | } 111 | ''' 112 | 113 | 114 | class Subtraction2Refpad(Function): 115 | @staticmethod 116 | def forward(ctx, input1, input2, kernel_size, stride, padding, dilation): 117 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 118 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 119 | assert input1.dim() == 4 and input1.is_cuda 120 | batch_size, input_channels, input_height, input_width = input1.size() 121 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 122 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 123 | output = input1.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 124 | n = output.numel() // output.shape[2] 125 | with torch.cuda.device_of(input1): 126 | f = load_kernel('subtraction2_refpad_forward_kernel', _subtraction2_refpad_forward_kernel, Dtype=Dtype(input1), nthreads=n, 127 | num=batch_size, input_channels=input_channels, 128 | bottom_height=input_height, bottom_width=input_width, 129 | top_height=output_height, top_width=output_width, 130 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 131 | stride_h=stride[0], stride_w=stride[1], 132 | dilation_h=dilation[0], dilation_w=dilation[1], 133 | pad_h=padding[0], pad_w=padding[1]) 134 | f(block=(CUDA_NUM_THREADS, 1, 1), 135 | grid=(GET_BLOCKS(n), 1, 1), 136 | args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()], 137 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 138 | ctx.save_for_backward(input1, input2) 139 | return output 140 | 141 | @staticmethod 142 | def backward(ctx, grad_output): 143 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 144 | input1, input2 = ctx.saved_tensors 145 | assert grad_output.is_cuda 146 | if not grad_output.is_contiguous(): 147 | grad_output = grad_output.contiguous() 148 | batch_size, input_channels, input_height, input_width = input1.size() 149 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 150 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 151 | grad_input1, grad_input2 = None, None 152 | opt = dict(Dtype=Dtype(grad_output), 153 | num=batch_size, input_channels=input_channels, 154 | bottom_height=input_height, bottom_width=input_width, 155 | top_height=output_height, top_width=output_width, 156 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 157 | stride_h=stride[0], stride_w=stride[1], 158 | dilation_h=dilation[0], dilation_w=dilation[1], 159 | pad_h=padding[0], pad_w=padding[1]) 160 | with torch.cuda.device_of(input1): 161 | if ctx.needs_input_grad[0]: 162 | grad_input1 = input1.new(input1.size()) 163 | n = grad_input1.numel() 164 | opt['nthreads'] = n 165 | f = load_kernel('subtraction2_refpad_input1_backward_kernel', _subtraction2_refpad_input1_backward_kernel, **opt) 166 | f(block=(CUDA_NUM_THREADS, 1, 1), 167 | grid=(GET_BLOCKS(n), 1, 1), 168 | args=[grad_output.data_ptr(), grad_input1.data_ptr()], 169 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 170 | with torch.cuda.device_of(input2): 171 | if ctx.needs_input_grad[1]: 172 | grad_input2 = input2.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 173 | n = grad_input2.numel() 174 | 
opt['nthreads'] = n 175 | f = load_kernel('subtraction2_refpad_input2_backward_kernel', _subtraction2_refpad_input2_backward_kernel, **opt) 176 | f(block=(CUDA_NUM_THREADS, 1, 1), 177 | grid=(GET_BLOCKS(n), 1, 1), 178 | args=[grad_output.data_ptr(), grad_input2.data_ptr()], 179 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 180 | grad_input2[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input2[:, :, :padding[0], :], dims=[2]) 181 | grad_input2[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input2[:, :, input_height + padding[0]:, :], dims=[2]) 182 | grad_input2[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input2[:, :, :, :padding[1]], dims=[3]) 183 | grad_input2[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input2[:, :, :, input_width + padding[1]:], dims=[3]) 184 | grad_input2 = grad_input2[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] 185 | return grad_input1, grad_input2, None, None, None, None 186 | 187 | 188 | def subtraction2_refpad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): 189 | assert input1.dim() == 4 190 | if input1.is_cuda: 191 | out = Subtraction2Refpad.apply(input1, input2, kernel_size, stride, padding, dilation) 192 | else: 193 | raise NotImplementedError 194 | return out 195 | 196 | 197 | def test_subtraction2_refpad(): 198 | import os 199 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 200 | kernel_size, stride, dilation = 5, 4, 2 # 3, 1, 1 201 | padding = (dilation * (kernel_size - 1) + 1) // 2 202 | n, c, in_height, in_width = 2, 8, 9, 9 203 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 204 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 205 | x1 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 206 | x2 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 207 | 208 | y1 = subtraction2_refpad(x1, x2, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 209 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 210 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 211 | pad = torch.nn.ReflectionPad2d(padding) 212 | y2 = unfold_i(x1).view(n, c, 1, out_height * out_width) - unfold_j(pad(x2)).view(n, c, pow(kernel_size, 2), out_height * out_width) 213 | assert (y1 - y2).abs().max() < 1e-9 214 | 215 | gx11 = torch.autograd.grad(y1.mean(), x1, retain_graph=True)[0] 216 | gx12 = torch.autograd.grad(y1.mean(), x2, retain_graph=True)[0] 217 | gx21 = torch.autograd.grad(y2.mean(), x1, retain_graph=True)[0] 218 | gx22 = torch.autograd.grad(y2.mean(), x2, retain_graph=True)[0] 219 | assert (gx11 - gx21).abs().max() < 1e-9 220 | assert (gx12 - gx22).abs().max() < 1e-9 221 | 222 | from functools import partial 223 | assert torch.autograd.gradcheck(partial(subtraction2_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x1, x2)) 224 | print('test case passed') 225 | 226 | if __name__ == '__main__': 227 | test_subtraction2_refpad() 228 | -------------------------------------------------------------------------------- /lib/sa/functions/aggregation_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from 
lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _aggregation_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void aggregation_refpad_forward_kernel( 25 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | ${Dtype} value = 0; 32 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 33 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 34 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 35 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 36 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 37 | int offset_bottom; 38 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 39 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 40 | } 41 | else { 42 | if (h_in < 0) h_in = -h_in; 43 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 44 | if (w_in < 0) w_in = -w_in; 45 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 46 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 47 | } 48 | value += weight_data[offset_weight] * bottom_data[offset_bottom]; 49 | } 50 | } 51 | top_data[index] = value; 52 | } 53 | } 54 | ''' 55 | 56 | 57 | _aggregation_refpad_input_backward_kernel = kernel_loop + ''' 58 | extern "C" 59 | __global__ void aggregation_refpad_input_backward_kernel( 60 | const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* bottom_diff) { 61 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 62 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 63 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 64 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 65 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 66 | ${Dtype} value = 0; 67 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 68 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 69 | const int h_out_s = h - kh * ${dilation_h}; 70 | const int w_out_s = w - kw * ${dilation_w}; 71 | if ((h_out_s % ${stride_h} == 0) && (w_out_s % ${stride_w} == 0)) { 72 | const int h_out = h_out_s / ${stride_h}; 73 | const int w_out = w_out_s / ${stride_w}; 74 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 75 | const int offset_top = ((n * ${input_channels} + c) * ${top_height} + h_out) * ${top_width} + w_out; 76 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 77 | value += 
weight_data[offset_weight] * top_diff[offset_top]; 78 | } 79 | } 80 | } 81 | } 82 | bottom_diff[index] = value; 83 | } 84 | } 85 | ''' 86 | 87 | 88 | _aggregation_refpad_weight_backward_kernel = kernel_loop + ''' 89 | extern "C" 90 | __global__ void aggregation_refpad_weight_backward_kernel( 91 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* weight_diff) { 92 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 93 | const int n = index / ${weight_channels} / ${top_height} / ${top_width}; 94 | const int c = (index / ${top_height} / ${top_width}) % ${weight_channels}; 95 | const int h = (index / ${top_width}) % ${top_height}; 96 | const int w = index % ${top_width}; 97 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 98 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 99 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 100 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 101 | const int offset_weight = ((n * ${weight_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 102 | ${Dtype} value = 0; 103 | for (int cc = c; cc < ${input_channels}; cc += ${weight_channels}) { 104 | const int offset_top = ((n * ${input_channels} + cc) * ${top_height} + h) * ${top_width} + w; 105 | int offset_bottom; 106 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 107 | offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 108 | } 109 | else { 110 | if (h_in < 0) h_in = -h_in; 111 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 112 | if (w_in < 0) w_in = -w_in; 113 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 114 | offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 115 | } 116 | value += bottom_data[offset_bottom] * top_diff[offset_top]; 117 | } 118 | weight_diff[offset_weight] = value; 119 | } 120 | } 121 | } 122 | } 123 | ''' 124 | 125 | 126 | class AggregationRefpad(Function): 127 | @staticmethod 128 | def forward(ctx, input, weight, kernel_size, stride, padding, dilation): 129 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 130 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 131 | assert input.dim() == 4 and input.is_cuda and weight.is_cuda 132 | batch_size, input_channels, input_height, input_width = input.size() 133 | _, weight_channels, weight_height, weight_width = weight.size() 134 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 135 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 136 | assert output_height * output_width == weight_width 137 | output = input.new(batch_size, input_channels, output_height, output_width) 138 | n = output.numel() 139 | with torch.cuda.device_of(input): 140 | f = load_kernel('aggregation_refpad_forward_kernel', _aggregation_refpad_forward_kernel, Dtype=Dtype(input), nthreads=n, 141 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 142 | bottom_height=input_height, bottom_width=input_width, 143 | top_height=output_height, top_width=output_width, 144 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 145 | stride_h=stride[0], stride_w=stride[1], 146 | dilation_h=dilation[0], dilation_w=dilation[1], 147 | 
pad_h=padding[0], pad_w=padding[1]) 148 | f(block=(CUDA_NUM_THREADS, 1, 1), 149 | grid=(GET_BLOCKS(n), 1, 1), 150 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 151 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 152 | ctx.save_for_backward(input, weight) 153 | return output 154 | 155 | @staticmethod 156 | def backward(ctx, grad_output): 157 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 158 | input, weight = ctx.saved_tensors 159 | assert grad_output.is_cuda 160 | if not grad_output.is_contiguous(): 161 | grad_output = grad_output.contiguous() 162 | batch_size, input_channels, input_height, input_width = input.size() 163 | _, weight_channels, weight_height, weight_width = weight.size() 164 | output_height, output_width = grad_output.size()[2:] 165 | grad_input, grad_weight = None, None 166 | opt = dict(Dtype=Dtype(grad_output), 167 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 168 | bottom_height=input_height, bottom_width=input_width, 169 | top_height=output_height, top_width=output_width, 170 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 171 | stride_h=stride[0], stride_w=stride[1], 172 | dilation_h=dilation[0], dilation_w=dilation[1], 173 | pad_h=padding[0], pad_w=padding[1]) 174 | with torch.cuda.device_of(input): 175 | if ctx.needs_input_grad[0]: 176 | grad_input = input.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 177 | n = grad_input.numel() 178 | opt['nthreads'] = n 179 | f = load_kernel('aggregation_refpad_input_backward_kernel', _aggregation_refpad_input_backward_kernel, **opt) 180 | f(block=(CUDA_NUM_THREADS, 1, 1), 181 | grid=(GET_BLOCKS(n), 1, 1), 182 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 183 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 184 | grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input[:, :, :padding[0], :], dims=[2]) 185 | grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input[:, :, input_height + padding[0]:, :], dims=[2]) 186 | grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input[:, :, :, :padding[1]], dims=[3]) 187 | grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input[:, :, :, input_width + padding[1]:], dims=[3]) 188 | grad_input = grad_input[:, :, padding[0]:padding[0]+input_height, padding[1]:padding[1]+input_width] 189 | 190 | if ctx.needs_input_grad[1]: 191 | grad_weight = weight.new(weight.size()) 192 | n = grad_weight.numel() // weight.shape[2] 193 | opt['nthreads'] = n 194 | f = load_kernel('aggregation_refpad_weight_backward_kernel', _aggregation_refpad_weight_backward_kernel, **opt) 195 | f(block=(CUDA_NUM_THREADS, 1, 1), 196 | grid=(GET_BLOCKS(n), 1, 1), 197 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 198 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 199 | return grad_input, grad_weight, None, None, None, None 200 | 201 | 202 | def aggregation_refpad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1): 203 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) 204 | if input.is_cuda: 205 | out = AggregationRefpad.apply(input, weight, kernel_size, stride, padding, dilation) 206 | else: 207 | raise NotImplementedError 208 | return out 209 | 210 | 211 | def test_aggregation_refpad(): 212 | import os 213 | 
os.environ["CUDA_VISIBLE_DEVICES"] = '0' 214 | kernel_size, stride, dilation = 5, 4, 2 215 | padding = (dilation * (kernel_size - 1) + 1) // 2 216 | n, c_x, c_w, in_height, in_width = 2, 8, 4, 5, 5 217 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 218 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 219 | x = torch.randn(n, c_x, in_height, in_width, requires_grad=True).double().cuda() 220 | w = torch.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True).double().cuda() 221 | 222 | y1 = aggregation_refpad(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 223 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 224 | pad = torch.nn.ReflectionPad2d(padding) 225 | x2 = unfold_j(pad(x)).view(n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width) 226 | y2 = (w.unsqueeze(1) * x2).sum(-2).view(n, c_x, out_height, out_width) 227 | assert (y1 - y2).abs().max() < 1e-9 228 | 229 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 230 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 231 | assert (gx1 - gx2).abs().max() < 1e-9 232 | 233 | gw1 = torch.autograd.grad(y1.mean(), w, retain_graph=True)[0] 234 | gw2 = torch.autograd.grad(y2.mean(), w, retain_graph=True)[0] 235 | assert (gw1 - gw2).abs().max() < 1e-9 236 | 237 | from functools import partial 238 | assert torch.autograd.gradcheck(partial(aggregation_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x, w)) 239 | print('test case passed') 240 | 241 | 242 | if __name__ == '__main__': 243 | test_aggregation_refpad() 244 | -------------------------------------------------------------------------------- /refer/refer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This interface provides access to four datasets: 3 | 1) refclef 4 | 2) refcoco 5 | 3) refcoco+ 6 | 4) refcocog 7 | split by unc and google 8 | 9 | The following API functions are defined: 10 | REFER - REFER api class 11 | getRefIds - get ref ids that satisfy given filter conditions. 12 | getAnnIds - get ann ids that satisfy given filter conditions. 13 | getImgIds - get image ids that satisfy given filter conditions. 14 | getCatIds - get category ids that satisfy given filter conditions. 15 | loadRefs - load refs with the specified ref ids. 16 | loadAnns - load anns with the specified ann ids. 17 | loadImgs - load images with the specified image ids. 18 | loadCats - load category names with the specified category ids. 
19 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 20 | showRef - show image, segmentation or box of the referred object with the ref 21 | getMask - get mask and area of the referred object given ref 22 | showMask - show mask of the referred object given ref 23 | """ 24 | 25 | import sys 26 | import os.path as osp 27 | import json 28 | import pickle as pickle 29 | import time 30 | import itertools 31 | import skimage.io as io 32 | import matplotlib.pyplot as plt 33 | from matplotlib.collections import PatchCollection 34 | from matplotlib.patches import Polygon, Rectangle 35 | from pprint import pprint 36 | import numpy as np 37 | from pycocotools import mask 38 | 39 | 40 | class REFER: 41 | 42 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 43 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 44 | # also provide dataset name and splitBy information 45 | # e.g., dataset = 'refcoco', splitBy = 'unc' 46 | print('loading dataset %s into memory...' % dataset) 47 | if dataset == 'refcocog': 48 | print('Split by {}!'.format(splitBy)) 49 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 50 | self.DATA_DIR = osp.join(data_root, dataset) 51 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 52 | self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014') 53 | elif dataset == 'refclef': 54 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 55 | elif dataset == 'rrsisd': 56 | self.IMAGE_DIR = osp.join(data_root, 'images/rrsisd/JPEGImages') 57 | else: 58 | print('No refer dataset is called [%s]' % dataset) 59 | sys.exit() 60 | 61 | # load refs from data/dataset/refs(dataset).json 62 | tic = time.time() 63 | ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p') 64 | self.data = {} 65 | self.data['dataset'] = dataset 66 | f = open(ref_file, 'r') 67 | self.data['refs'] = pickle.load(open(ref_file, 'rb')) 68 | 69 | # load annotations from data/dataset/instances.json 70 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 71 | instances = json.load(open(instances_file, 'r')) 72 | self.data['images'] = instances['images'] 73 | self.data['annotations'] = instances['annotations'] 74 | self.data['categories'] = instances['categories'] 75 | 76 | # create index 77 | self.createIndex() 78 | print('DONE (t=%.2fs)' % (time.time() - tic)) 79 | 80 | def createIndex(self): 81 | # create sets of mapping 82 | # 1) Refs: {ref_id: ref} 83 | # 2) Anns: {ann_id: ann} 84 | # 3) Imgs: {image_id: image} 85 | # 4) Cats: {category_id: category_name} 86 | # 5) Sents: {sent_id: sent} 87 | # 6) imgToRefs: {image_id: refs} 88 | # 7) imgToAnns: {image_id: anns} 89 | # 8) refToAnn: {ref_id: ann} 90 | # 9) annToRef: {ann_id: ref} 91 | # 10) catToRefs: {category_id: refs} 92 | # 11) sentToRef: {sent_id: ref} 93 | # 12) sentToTokens: {sent_id: tokens} 94 | print('creating index...') 95 | # fetch info from instances 96 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 97 | for ann in self.data['annotations']: 98 | Anns[ann['id']] = ann 99 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 100 | for img in self.data['images']: 101 | Imgs[img['id']] = img 102 | for cat in self.data['categories']: 103 | Cats[cat['id']] = cat['name'] 104 | 105 | # fetch info from refs 106 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 107 | Sents, sentToRef, sentToTokens = {}, {}, {} 108 | for ref in self.data['refs']: 109 | # ids 110 | ref_id = ref['ref_id'] 111 | ann_id = ref['ann_id'] 112 | category_id 
= ref['category_id'] 113 | image_id = ref['image_id'] 114 | 115 | # add mapping related to ref 116 | Refs[ref_id] = ref 117 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 118 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 119 | refToAnn[ref_id] = Anns[ann_id] 120 | annToRef[ann_id] = ref 121 | 122 | # add mapping of sent 123 | for sent in ref['sentences']: 124 | Sents[sent['sent_id']] = sent 125 | sentToRef[sent['sent_id']] = ref 126 | sentToTokens[sent['sent_id']] = sent['tokens'] 127 | 128 | # create class members 129 | self.Refs = Refs 130 | self.Anns = Anns 131 | self.Imgs = Imgs 132 | self.Cats = Cats 133 | self.Sents = Sents 134 | self.imgToRefs = imgToRefs 135 | self.imgToAnns = imgToAnns 136 | self.refToAnn = refToAnn 137 | self.annToRef = annToRef 138 | self.catToRefs = catToRefs 139 | self.sentToRef = sentToRef 140 | self.sentToTokens = sentToTokens 141 | print('index created.') 142 | 143 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 144 | image_ids = image_ids if type(image_ids) == list else [image_ids] 145 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 146 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 147 | 148 | if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: 149 | refs = self.data['refs'] 150 | else: 151 | if not len(image_ids) == 0: 152 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 153 | else: 154 | refs = self.data['refs'] 155 | if not len(cat_ids) == 0: 156 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 157 | if not len(ref_ids) == 0: 158 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 159 | if not len(split) == 0: 160 | if split in ['testA', 'testB', 'testC']: 161 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ... 162 | elif split in ['testAB', 'testBC', 'testAC']: 163 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 
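                # the plain 'test' split below matches any ref whose split name contains 'test' (test, testA, testB, ...)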
164 | elif split == 'test': 165 | refs = [ref for ref in refs if 'test' in ref['split']] 166 | elif split == 'train' or split == 'val': 167 | refs = [ref for ref in refs if ref['split'] == split] 168 | else: 169 | print('No such split [%s]' % split) 170 | sys.exit() 171 | ref_ids = [ref['ref_id'] for ref in refs] 172 | return ref_ids 173 | 174 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 175 | image_ids = image_ids if type(image_ids) == list else [image_ids] 176 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 177 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 178 | 179 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 180 | ann_ids = [ann['id'] for ann in self.data['annotations']] 181 | else: 182 | if not len(image_ids) == 0: 183 | lists = [self.imgToAnns[image_id] for image_id in image_ids if 184 | image_id in self.imgToAnns] # list of [anns] 185 | anns = list(itertools.chain.from_iterable(lists)) 186 | else: 187 | anns = self.data['annotations'] 188 | if not len(cat_ids) == 0: 189 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 190 | ann_ids = [ann['id'] for ann in anns] 191 | if not len(ref_ids) == 0: 192 | ann_ids = list(set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]))) # keep only anns that belong to the given refs 193 | return ann_ids 194 | 195 | def getImgIds(self, ref_ids=[]): 196 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 197 | 198 | if not len(ref_ids) == 0: 199 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 200 | else: 201 | image_ids = self.Imgs.keys() 202 | return image_ids 203 | 204 | def getCatIds(self): 205 | return self.Cats.keys() 206 | 207 | def loadRefs(self, ref_ids=[]): 208 | if type(ref_ids) == list: 209 | return [self.Refs[ref_id] for ref_id in ref_ids] 210 | elif type(ref_ids) == int: 211 | return [self.Refs[ref_ids]] 212 | 213 | def loadAnns(self, ann_ids=[]): 214 | if type(ann_ids) == list: 215 | return [self.Anns[ann_id] for ann_id in ann_ids] 216 | elif type(ann_ids) == int or type(ann_ids) == str: 217 | return [self.Anns[ann_ids]] 218 | 219 | def loadImgs(self, image_ids=[]): 220 | if type(image_ids) == list: 221 | return [self.Imgs[image_id] for image_id in image_ids] 222 | elif type(image_ids) == int: 223 | return [self.Imgs[image_ids]] 224 | 225 | def loadCats(self, cat_ids=[]): 226 | if type(cat_ids) == list: 227 | return [self.Cats[cat_id] for cat_id in cat_ids] 228 | elif type(cat_ids) == int: 229 | return [self.Cats[cat_ids]] 230 | 231 | def getRefBox(self, ref_id): 232 | ref = self.Refs[ref_id] 233 | ann = self.refToAnn[ref_id] 234 | return ann['bbox'] # [x, y, w, h] 235 | 236 | def showRef(self, ref, seg_box='seg'): 237 | ax = plt.gca() 238 | # show image 239 | image = self.Imgs[ref['image_id']] 240 | I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 241 | ax.imshow(I) 242 | # show refer expression 243 | for sid, sent in enumerate(ref['sentences']): 244 | print('%s.
%s' % (sid + 1, sent['sent'])) 245 | # show segmentations 246 | if seg_box == 'seg': 247 | ann_id = ref['ann_id'] 248 | ann = self.Anns[ann_id] 249 | polygons = [] 250 | color = [] 251 | c = 'none' 252 | if type(ann['segmentation'][0]) == list: 253 | # polygon used for refcoco* 254 | for seg in ann['segmentation']: 255 | poly = np.array(seg).reshape((len(seg) // 2, 2)) # integer division (Python 3); seg is a flat [x1, y1, x2, y2, ...] list 256 | polygons.append(Polygon(poly, True, alpha=0.4)) 257 | color.append(c) 258 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1) 259 | ax.add_collection(p) # thick yellow polygon 260 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1) 261 | ax.add_collection(p) # thin red polygon 262 | else: 263 | # mask used for refclef 264 | rle = ann['segmentation'] 265 | m = mask.decode(rle) 266 | img = np.ones((m.shape[0], m.shape[1], 3)) 267 | color_mask = np.array([2.0, 166.0, 101.0]) / 255 268 | for i in range(3): 269 | img[:, :, i] = color_mask[i] 270 | ax.imshow(np.dstack((img, m * 0.5))) 271 | # show bounding-box 272 | elif seg_box == 'box': 273 | ann_id = ref['ann_id'] 274 | ann = self.Anns[ann_id] 275 | bbox = self.getRefBox(ref['ref_id']) 276 | box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3) 277 | ax.add_patch(box_plot) 278 | 279 | def getMask(self, ref): 280 | # return mask and area of the referred object 281 | ann = self.refToAnn[ref['ref_id']] 282 | image = self.Imgs[ref['image_id']] 283 | if type(ann['segmentation'][0]) == list: # polygon 284 | rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width']) 285 | else: 286 | rle = ann['segmentation'] 287 | 288 | m = mask.decode(rle) 289 | m = np.sum(m, axis=2) # sometimes there are multiple binary maps (corresponding to multiple segs) 290 | m = m.astype(np.uint8) # convert to np.uint8 291 | # compute area 292 | area = sum(mask.area(rle)) # should be close to ann['area'] 293 | 294 | return {'mask': m, 'area': area} 295 | 296 | 297 | def showMask(self, ref): 298 | M = self.getMask(ref) 299 | msk = M['mask'] 300 | ax = plt.gca() 301 | ax.imshow(msk) 302 | 303 | 304 | if __name__ == '__main__': 305 | # refer = REFER(dataset='refcocog', splitBy='google') 306 | refer = REFER(data_root='data', dataset='rrsisd', splitBy='google') 307 | ref_ids = refer.getRefIds() 308 | 309 | ref_ids = refer.getRefIds(split='train') 310 | print('There are %s training referred objects.' % len(ref_ids)) 311 | 312 | for ref_id in ref_ids: 313 | ref = refer.loadRefs(ref_id)[0] 314 | if len(ref['sentences']) < 2: 315 | continue 316 | print('The label is %s.'
% refer.Cats[ref['category_id']]) 317 | plt.figure() 318 | refer.showRef(ref, seg_box='box') 319 | plt.show() 320 | 321 | -------------------------------------------------------------------------------- /lib/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | 8 | class Transformer_vis(nn.Module): 9 | 10 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,dim_feedforward=2048, 11 | dropout=0.1, activation="relu", normalize_before=False): 12 | super().__init__() 13 | 14 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 15 | dropout, activation, normalize_before) 16 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 17 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 18 | 19 | self._reset_parameters() 20 | 21 | self.d_model = d_model 22 | self.nhead = nhead 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.xavier_uniform_(p) 28 | 29 | def forward(self, src, mask, pos_embed): 30 | # flatten NxCxHxW to HWxNxC 31 | bs, c, h, w = src.shape 32 | src = src.flatten(2).permute(2, 0, 1) 33 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 34 | mask = mask.flatten(1) 35 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 36 | return memory.permute(1, 2, 0).view(bs, c, h, w) 37 | 38 | class Transformer_Decoder(nn.Module): 39 | def __init__(self, d_model=512, nhead=8, 40 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 41 | activation="relu", normalize_before=False, 42 | return_intermediate_dec=False): 43 | super().__init__() 44 | 45 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 46 | dropout, activation, normalize_before) 47 | decoder_norm = nn.LayerNorm(d_model) 48 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 49 | return_intermediate=return_intermediate_dec) 50 | 51 | self._reset_parameters() 52 | 53 | self.d_model = d_model 54 | self.nhead = nhead 55 | 56 | def _reset_parameters(self): 57 | for p in self.parameters(): 58 | if p.dim() > 1: 59 | nn.init.xavier_uniform_(p) 60 | 61 | def forward(self, tgt, memory, mask,pos_embed, query_embed): 62 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,pos=pos_embed, query_pos=query_embed) 63 | return hs 64 | 65 | class Transformer(nn.Module): 66 | 67 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, 68 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 69 | activation="relu", normalize_before=False, 70 | return_intermediate_dec=False): 71 | super().__init__() 72 | 73 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 74 | dropout, activation, normalize_before) 75 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 76 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 77 | 78 | self._reset_parameters() 79 | 80 | self.d_model = d_model 81 | self.nhead = nhead 82 | 83 | def _reset_parameters(self): 84 | for p in self.parameters(): 85 | if p.dim() > 1: 86 | nn.init.xavier_uniform_(p) 87 | 88 | def forward(self, src, mask, pos_embed): 89 | # flatten NxCxHxW to HWxNxC 90 | # permute NxCxW to WxNxC 91 | src = src.permute(2, 0, 1) 92 | pos_embed = pos_embed.permute(1, 0, 2) 93 | memory = self.encoder(src, 
src_key_padding_mask=mask, pos=pos_embed) 94 | return memory 95 | 96 | 97 | class TransformerEncoder(nn.Module): 98 | def __init__(self, encoder_layer, num_layers, norm=None): 99 | super().__init__() 100 | self.layers = _get_clones(encoder_layer, num_layers) 101 | self.num_layers = num_layers 102 | self.norm = norm 103 | 104 | def forward(self, src, 105 | mask: Optional[Tensor] = None, # mask is not used here 106 | src_key_padding_mask: Optional[Tensor] = None, 107 | pos: Optional[Tensor] = None): 108 | output = src 109 | 110 | for layer in self.layers: 111 | output = layer(output, src_mask=mask, 112 | src_key_padding_mask=src_key_padding_mask, pos=pos) 113 | 114 | if self.norm is not None: 115 | output = self.norm(output) 116 | 117 | return output 118 | 119 | 120 | class TransformerDecoder(nn.Module): 121 | 122 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 123 | super().__init__() 124 | self.layers = _get_clones(decoder_layer, num_layers) 125 | self.num_layers = num_layers 126 | self.norm = norm 127 | self.return_intermediate = return_intermediate 128 | 129 | def forward(self, tgt, memory, 130 | tgt_mask: Optional[Tensor] = None, 131 | memory_mask: Optional[Tensor] = None, 132 | tgt_key_padding_mask: Optional[Tensor] = None, 133 | memory_key_padding_mask: Optional[Tensor] = None, 134 | pos: Optional[Tensor] = None, 135 | query_pos: Optional[Tensor] = None): 136 | 137 | output = tgt 138 | 139 | intermediate = [] 140 | 141 | for layer in self.layers: 142 | output = layer(output, memory, tgt_mask=tgt_mask, 143 | memory_mask=memory_mask, 144 | tgt_key_padding_mask=tgt_key_padding_mask, 145 | memory_key_padding_mask=memory_key_padding_mask, 146 | pos=pos, query_pos=query_pos) 147 | if self.return_intermediate: 148 | intermediate.append(self.norm(output)) 149 | 150 | if self.norm is not None: 151 | output = self.norm(output) 152 | if self.return_intermediate: 153 | intermediate.pop() 154 | intermediate.append(output) 155 | 156 | if self.return_intermediate: 157 | return torch.stack(intermediate) 158 | 159 | return output.unsqueeze(0) 160 | 161 | 162 | class TransformerEncoderLayer(nn.Module): 163 | 164 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 165 | activation="relu", normalize_before=False): 166 | super().__init__() 167 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 168 | # Implementation of Feedforward model 169 | self.linear1 = nn.Linear(d_model, dim_feedforward) 170 | self.dropout = nn.Dropout(dropout) 171 | self.linear2 = nn.Linear(dim_feedforward, d_model) 172 | 173 | self.norm1 = nn.LayerNorm(d_model) 174 | self.norm2 = nn.LayerNorm(d_model) 175 | self.dropout1 = nn.Dropout(dropout) 176 | self.dropout2 = nn.Dropout(dropout) 177 | 178 | self.activation = _get_activation_fn(activation) 179 | self.normalize_before = normalize_before 180 | 181 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 182 | return tensor if pos is None else tensor + pos 183 | 184 | def forward_post(self, 185 | src, 186 | src_mask: Optional[Tensor] = None, 187 | src_key_padding_mask: Optional[Tensor] = None, 188 | pos: Optional[Tensor] = None): 189 | q = k = self.with_pos_embed(src, pos) 190 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 191 | key_padding_mask=src_key_padding_mask)[0] 192 | src = src + self.dropout1(src2) 193 | src = self.norm1(src) 194 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 195 | src = src + self.dropout2(src2) 196 | src = self.norm2(src) 197 | return src 198 |
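    # forward_post (above) uses the post-norm ordering of the original Transformer
    # (LayerNorm after each residual addition), while forward_pre (below) normalizes the
    # input of each sub-layer first (pre-norm); the `normalize_before` flag passed to
    # __init__ selects which variant forward() dispatches to.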
199 | def forward_pre(self, src, 200 | src_mask: Optional[Tensor] = None, 201 | src_key_padding_mask: Optional[Tensor] = None, 202 | pos: Optional[Tensor] = None): 203 | src2 = self.norm1(src) 204 | q = k = self.with_pos_embed(src2, pos) 205 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 206 | key_padding_mask=src_key_padding_mask)[0] 207 | src = src + self.dropout1(src2) 208 | src2 = self.norm2(src) 209 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 210 | src = src + self.dropout2(src2) 211 | return src 212 | 213 | def forward(self, src, 214 | src_mask: Optional[Tensor] = None, 215 | src_key_padding_mask: Optional[Tensor] = None, 216 | pos: Optional[Tensor] = None): 217 | if self.normalize_before: 218 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 219 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 220 | 221 | 222 | class TransformerDecoderLayer(nn.Module): 223 | 224 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 225 | activation="relu", normalize_before=False): 226 | super().__init__() 227 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 228 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 229 | # Implementation of Feedforward model 230 | self.linear1 = nn.Linear(d_model, dim_feedforward) 231 | self.dropout = nn.Dropout(dropout) 232 | self.linear2 = nn.Linear(dim_feedforward, d_model) 233 | 234 | self.norm1 = nn.LayerNorm(d_model) 235 | self.norm2 = nn.LayerNorm(d_model) 236 | self.norm3 = nn.LayerNorm(d_model) 237 | self.dropout1 = nn.Dropout(dropout) 238 | self.dropout2 = nn.Dropout(dropout) 239 | self.dropout3 = nn.Dropout(dropout) 240 | 241 | self.activation = _get_activation_fn(activation) 242 | self.normalize_before = normalize_before 243 | 244 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 245 | return tensor if pos is None else tensor + pos# tensor 16 pos 4096 246 | 247 | def forward_post(self, tgt, memory, 248 | tgt_mask: Optional[Tensor] = None, 249 | memory_mask: Optional[Tensor] = None, 250 | tgt_key_padding_mask: Optional[Tensor] = None, 251 | memory_key_padding_mask: Optional[Tensor] = None, 252 | pos: Optional[Tensor] = None, 253 | query_pos: Optional[Tensor] = None): 254 | q = k = self.with_pos_embed(tgt, query_pos) 255 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 256 | key_padding_mask=tgt_key_padding_mask)[0] 257 | tgt = tgt + self.dropout1(tgt2) 258 | tgt = self.norm1(tgt) 259 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 260 | key=self.with_pos_embed(memory, pos), 261 | value=memory, attn_mask=memory_mask, 262 | key_padding_mask=memory_key_padding_mask)[0] 263 | tgt = tgt + self.dropout2(tgt2) 264 | tgt = self.norm2(tgt) 265 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 266 | tgt = tgt + self.dropout3(tgt2) 267 | tgt = self.norm3(tgt) 268 | return tgt 269 | 270 | def forward_pre(self, tgt, memory, 271 | tgt_mask: Optional[Tensor] = None, 272 | memory_mask: Optional[Tensor] = None, 273 | tgt_key_padding_mask: Optional[Tensor] = None, 274 | memory_key_padding_mask: Optional[Tensor] = None, 275 | pos: Optional[Tensor] = None, 276 | query_pos: Optional[Tensor] = None): 277 | tgt2 = self.norm1(tgt) 278 | q = k = self.with_pos_embed(tgt2, query_pos) 279 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 280 | key_padding_mask=tgt_key_padding_mask)[0] 281 | tgt = tgt + self.dropout1(tgt2) 282 | tgt2 = self.norm2(tgt) 283 | 
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 284 | key=self.with_pos_embed(memory, pos), 285 | value=memory, attn_mask=memory_mask, 286 | key_padding_mask=memory_key_padding_mask)[0] 287 | tgt = tgt + self.dropout2(tgt2) 288 | tgt2 = self.norm3(tgt) 289 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 290 | tgt = tgt + self.dropout3(tgt2) 291 | return tgt 292 | 293 | def forward(self, tgt, memory, 294 | tgt_mask: Optional[Tensor] = None, 295 | memory_mask: Optional[Tensor] = None, 296 | tgt_key_padding_mask: Optional[Tensor] = None, 297 | memory_key_padding_mask: Optional[Tensor] = None, 298 | pos: Optional[Tensor] = None, 299 | query_pos: Optional[Tensor] = None): 300 | if self.normalize_before: 301 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 302 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 303 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 304 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 305 | 306 | def _get_clones(module, N): 307 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 308 | 309 | 310 | 311 | def _get_activation_fn(activation): 312 | """Return an activation function given a string""" 313 | if activation == "relu": 314 | return F.relu 315 | if activation == "gelu": 316 | return F.gelu 317 | if activation == "glu": 318 | return F.glu 319 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 320 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | import torch 5 | import torch.utils.data 6 | import wandb 7 | import random 8 | import transforms as T 9 | import utils 10 | import numpy as np 11 | import gc 12 | import operator 13 | from functools import reduce 14 | from bert.modeling_bert import BertModel 15 | from lib import segmentation 16 | from loss.loss import Loss 17 | 18 | os.environ["WANDB_API_KEY"] = '1ae5903bce9def26f040e6a15cc95aba3a99cc91' 19 | os.environ["WANDB_MODE"] = "offline" 20 | 21 | import os 22 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 23 | 24 | def seed_everything(seed=2401): 25 | random.seed(seed) 26 | os.environ['PYTHONHASHSEED'] = str(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | 34 | 35 | def get_dataset(image_set, transform, args): 36 | if args.dataset == "rrsisd": 37 | from data.rrsisd_refer_bert import ReferDataset 38 | else: 39 | from data.refsegrs_refer_bert import ReferDataset 40 | ds = ReferDataset(args, 41 | split=image_set, 42 | image_transforms=transform, 43 | target_transforms=None 44 | ) 45 | num_classes = 2 46 | 47 | return ds, num_classes 48 | 49 | 50 | def IoU(pred, gt): 51 | pred = pred.argmax(1) 52 | 53 | intersection = torch.sum(torch.mul(pred, gt)) 54 | union = torch.sum(torch.add(pred, gt)) - intersection 55 | 56 | if intersection == 0 or union == 0: 57 | iou = 0 58 | else: 59 | iou = float(intersection) / float(union) 60 | return iou, intersection, union 61 | 62 | 63 | def get_transform(args): 64 | transforms = [ 65 | T.Resize(args.img_size, args.img_size), 66 | T.ToTensor(), 67 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 68 | ] 69 | return T.Compose(transforms) 70 | 71 | 72 | def criterion(input, 
target, weight=0.1): 73 | return Loss(weight=weight)(input, target) 74 | 75 | 76 | def evaluate(model, data_loader, bert_model, epoch): 77 | model.eval() 78 | metric_logger = utils.MetricLogger(delimiter=" ") 79 | header = "Test: " 80 | total_its = 0 81 | acc_ious = 0 82 | 83 | # evaluation variables 84 | cum_I, cum_U = 0, 0 85 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 86 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 87 | seg_total = 0 88 | mean_IoU = [] 89 | total_loss = 0 90 | 91 | with torch.no_grad(): 92 | for data in metric_logger.log_every(data_loader, 100, header): 93 | total_its += 1 94 | image, target, sentences, attentions, target_masks, position_masks, _ = data 95 | image = image.cuda(non_blocking=True) 96 | target = target.cuda(non_blocking=True) 97 | sentences = sentences.cuda(non_blocking=True) 98 | attentions = attentions.cuda(non_blocking=True) 99 | target_masks = target_masks.cuda(non_blocking=True) 100 | position_masks = position_masks.cuda(non_blocking=True) 101 | 102 | sentences = sentences.squeeze(1) 103 | attentions = attentions.squeeze(1) 104 | target_masks = target_masks.squeeze(1) 105 | position_masks = position_masks.squeeze(1) 106 | 107 | if bert_model is not None: 108 | last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] 109 | embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 110 | attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1) 111 | output = model(image, embedding, l_mask=attentions) 112 | else: 113 | output = model(image, sentences, attentions, target_masks, position_masks) 114 | 115 | iou, I, U = IoU(output, target) 116 | loss = criterion(output, target) 117 | total_loss += loss.item() 118 | acc_ious += iou 119 | mean_IoU.append(iou) 120 | cum_I += I 121 | cum_U += U 122 | for n_eval_iou in range(len(eval_seg_iou_list)): 123 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 124 | seg_correct[n_eval_iou] += (iou >= eval_seg_iou) 125 | seg_total += 1 126 | iou = acc_ious / total_its 127 | 128 | mean_IoU = np.array(mean_IoU) 129 | mIoU = np.mean(mean_IoU) 130 | print('Final results:') 131 | print('Mean IoU is %.2f\n' % (mIoU * 100.)) 132 | results_str = '' 133 | for n_eval_iou in range(len(eval_seg_iou_list)): 134 | results_str += ' precision@%s = %.2f\n' % \ 135 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 136 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U) 137 | print(results_str) 138 | 139 | if args.local_rank == 0: 140 | wandb.log({ 141 | "val mIoU": mIoU, 142 | "val oiou": cum_I * 100. 
/ cum_U, 143 | "val Loss": total_loss / total_its}) 144 | 145 | return 100 * iou, 100 * cum_I / cum_U 146 | 147 | 148 | def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq, 149 | iterations, bert_model): 150 | model.train() 151 | metric_logger = utils.MetricLogger(delimiter=" ") 152 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 153 | header = 'Epoch: [{}]'.format(epoch) 154 | train_loss = 0 155 | total_its = 0 156 | 157 | # for data in data_loader: 158 | for i, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 159 | total_its += 1 160 | image, target, sentences, attentions, target_masks, position_masks, _ = data 161 | image = image.cuda(non_blocking=True) 162 | target = target.cuda(non_blocking=True) 163 | sentences = sentences.cuda(non_blocking=True) 164 | attentions = attentions.cuda(non_blocking=True) 165 | target_masks = target_masks.cuda(non_blocking=True) 166 | position_masks = position_masks.cuda(non_blocking=True) 167 | 168 | sentences = sentences.squeeze(1) 169 | attentions = attentions.squeeze(1) 170 | target_masks = target_masks.squeeze(1) 171 | position_masks = position_masks.squeeze(1) 172 | 173 | if bert_model is not None: 174 | last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] # (6, 10, 768) 175 | embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 176 | attentions = attentions.unsqueeze(dim=-1) # (batch, N_l, 1) 177 | output = model(image, embedding, attentions)#, sentences_hidden_state)# [4,2,120,120] 178 | else: 179 | output = model(image, sentences, attentions, target_masks, position_masks) 180 | # output = model(image, sentences, attentions) 181 | optimizer.zero_grad() 182 | loss = criterion(output, target) 183 | 184 | loss.backward() 185 | optimizer.step() 186 | lr_scheduler.step() 187 | 188 | torch.cuda.synchronize() 189 | train_loss += loss.item() 190 | iterations += 1 191 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 192 | 193 | del image, target, sentences, attentions, loss, output, data 194 | if bert_model is not None: 195 | del last_hidden_states, embedding 196 | 197 | gc.collect() 198 | torch.cuda.empty_cache() 199 | torch.cuda.synchronize() 200 | if args.local_rank == 0: 201 | wandb.log({ 202 | "Train Loss": train_loss / total_its,}) 203 | 204 | 205 | def main(args): 206 | 207 | # make folders 208 | if not os.path.exists(args.output_dir): 209 | os.mkdir(args.output_dir) 210 | 211 | # set datasets 212 | print("\n[***] Set Datasets") 213 | dataset, num_classes = get_dataset("train", 214 | get_transform(args=args), 215 | args=args) 216 | 217 | dataset_test, _ = get_dataset("val", 218 | get_transform(args=args), 219 | args=args) 220 | 221 | # Debug = False 222 | # if Debug: 223 | # dataset = torch.utils.data.Subset(dataset, list(range(200))) 224 | # dataset_test = torch.utils.data.Subset(dataset_test, list(range(20))) 225 | 226 | # build batch sampler 227 | print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.") 228 | train_sampler = torch.utils.data.RandomSampler(dataset) 229 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 230 | 231 | # build data loader 232 | data_loader = torch.utils.data.DataLoader( 233 | dataset, batch_size=args.batch_size, 234 | sampler=train_sampler, num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True) 235 | data_loader_test = torch.utils.data.DataLoader( 236 | 
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers) 237 | 238 | # model initialization 239 | print("\n[***] Build Model") 240 | model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights, 241 | args=args) 242 | model.cuda() 243 | # print(args.model) 244 | 245 | if args.model != 'lavt_one': 246 | # need to load bert outside the model 247 | model_class = BertModel 248 | bert_model = model_class.from_pretrained(args.ck_bert) 249 | bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel 250 | bert_model.cuda() 251 | else: 252 | bert_model = None 253 | single_bert_model = None 254 | 255 | # resume training 256 | if args.resume: 257 | checkpoint = torch.load(args.resume, map_location='cpu') 258 | model.load_state_dict(checkpoint['model'], strict=False) 259 | if args.model != 'lavt_one': 260 | bert_model.load_state_dict(checkpoint['bert_model']) 261 | 262 | # parameters to optimize 263 | backbone_no_decay = list() 264 | backbone_decay = list() 265 | for name, m in model.backbone.named_parameters(): 266 | if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name: 267 | backbone_no_decay.append(m) 268 | else: 269 | backbone_decay.append(m) 270 | 271 | if args.model != 'lavt_one': 272 | params_to_optimize = [ 273 | {'params': backbone_no_decay, 'weight_decay': 0.0}, 274 | {'params': backbone_decay}, 275 | {"params": [p for p in model.classifier.parameters() if p.requires_grad]}, 276 | # the following are the parameters of bert 277 | {"params": reduce(operator.concat, 278 | [[p for p in bert_model.module.encoder.layer[i].parameters() 279 | if p.requires_grad] for i in range(10)])}, 280 | ] 281 | else: 282 | params_to_optimize = [ 283 | {'params': backbone_no_decay, 'weight_decay': 0.0}, 284 | {'params': backbone_decay}, 285 | {"params": [p for p in model.classifier.parameters() if p.requires_grad]}, 286 | # the following are the parameters of bert 287 | {"params": reduce(operator.concat, 288 | [[p for p in model.text_encoder.encoder.layer[i].parameters() 289 | if p.requires_grad] for i in range(10)])}, 290 | ] 291 | 292 | # optimizer 293 | optimizer = torch.optim.AdamW(params_to_optimize, 294 | lr=args.lr, 295 | weight_decay=args.weight_decay, 296 | amsgrad=args.amsgrad 297 | ) 298 | 299 | # learning rate scheduler 300 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 301 | lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) 302 | 303 | # housekeeping 304 | start_time = time.time() 305 | iterations = 0 306 | best_oIoU = -0.1 307 | 308 | # resume training (optimizer, lr scheduler, and the epoch) 309 | if args.resume: 310 | optimizer.load_state_dict(checkpoint['optimizer']) 311 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 312 | resume_epoch = checkpoint['epoch'] 313 | best_oIoU = checkpoint['best_oIoU'] 314 | print("Resume training from Epoch {}".format(resume_epoch+1)) 315 | print("Initail best IoU is {}".format(best_oIoU)) 316 | 317 | else: 318 | resume_epoch = -999 319 | 320 | # training loops 321 | if args.local_rank == 0: 322 | wandb.watch(model, log="all") 323 | 324 | for epoch in range(max(0, resume_epoch+1), args.epochs): 325 | # data_loader.sampler.set_epoch(epoch) 326 | train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq, 327 | iterations, bert_model) 328 | iou, overallIoU = evaluate(model, data_loader_test, bert_model, epoch) 329 | print('Average object IoU {}'.format(iou)) 
330 | print('Overall IoU {}'.format(overallIoU)) 331 | best = (best_oIoU < overallIoU) 332 | if bert_model is not None: 333 | dict_to_save = {'model': model.state_dict(), 'bert_model': bert_model.state_dict(), 334 | 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, 335 | 'lr_scheduler': lr_scheduler.state_dict(), 336 | 'best_oIoU': best_oIoU} 337 | else: 338 | dict_to_save = {'model': model.state_dict(), 339 | 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, 340 | 'lr_scheduler': lr_scheduler.state_dict(), 341 | 'best_oIoU': best_oIoU} 342 | 343 | if best: 344 | best_oIoU = overallIoU 345 | print('Better epoch: {}\n'.format(epoch)) 346 | dict_to_save['best_oIoU'] = best_oIoU 347 | utils.save_on_master( 348 | dict_to_save, os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))) 349 | 350 | if epoch % 10 == 0: 351 | utils.save_on_master(dict_to_save, os.path.join(args.output_dir, 352 | 'model_last_{}_{}.pth'.format(args.model_id, epoch))) 353 | utils.save_on_master(dict_to_save, os.path.join(args.output_dir, 354 | 'model_last_{}.pth'.format(args.model_id))) 355 | if args.local_rank == 0: 356 | wandb.save('model.h5') 357 | 358 | # summarize 359 | total_time = time.time() - start_time 360 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 361 | print('Training time {}'.format(total_time_str)) 362 | 363 | 364 | if __name__ == "__main__": 365 | from args import get_parser 366 | seed_everything() 367 | parser = get_parser() 368 | args = parser.parse_args() 369 | if args.local_rank == 0: 370 | wandb.init(project="fianet_2080") 371 | print('Image size: {}'.format(str(args.img_size))) 372 | main(args) 373 | -------------------------------------------------------------------------------- /bert/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | 19 | import copy 20 | import json 21 | import logging 22 | import os 23 | from typing import Dict, Tuple 24 | 25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 
38 | 39 | Class attributes (overridden by derived classes): 40 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. 41 | 42 | Args: 43 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`): 44 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 45 | num_labels (:obj:`int`, `optional`, defaults to `2`): 46 | Number of classes to use when the model is a classification model (sequences/tokens) 47 | output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`): 48 | Should the model returns all hidden-states. 49 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): 50 | Should the model returns all attentions. 51 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`): 52 | Is the model used with Torchscript (for PyTorch models). 53 | """ 54 | model_type: str = "" 55 | 56 | def __init__(self, **kwargs): 57 | # Attributes with defaults 58 | self.output_hidden_states = kwargs.pop("output_hidden_states", False) 59 | self.output_attentions = kwargs.pop("output_attentions", False) 60 | self.use_cache = kwargs.pop("use_cache", True) # Not used by all models 61 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models 62 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False) 63 | self.pruned_heads = kwargs.pop("pruned_heads", {}) 64 | 65 | # Is decoder is used in encoder-decoder models to differentiate encoder from decoder 66 | self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) 67 | self.is_decoder = kwargs.pop("is_decoder", False) 68 | 69 | # Parameters for sequence generation 70 | self.max_length = kwargs.pop("max_length", 20) 71 | self.min_length = kwargs.pop("min_length", 0) 72 | self.do_sample = kwargs.pop("do_sample", False) 73 | self.early_stopping = kwargs.pop("early_stopping", False) 74 | self.num_beams = kwargs.pop("num_beams", 1) 75 | self.temperature = kwargs.pop("temperature", 1.0) 76 | self.top_k = kwargs.pop("top_k", 50) 77 | self.top_p = kwargs.pop("top_p", 1.0) 78 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) 79 | self.length_penalty = kwargs.pop("length_penalty", 1.0) 80 | self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) 81 | self.bad_words_ids = kwargs.pop("bad_words_ids", None) 82 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1) 83 | 84 | # Fine-tuning task arguments 85 | self.architectures = kwargs.pop("architectures", None) 86 | self.finetuning_task = kwargs.pop("finetuning_task", None) 87 | self.id2label = kwargs.pop("id2label", None) 88 | self.label2id = kwargs.pop("label2id", None) 89 | if self.id2label is not None: 90 | kwargs.pop("num_labels", None) 91 | self.id2label = dict((int(key), value) for key, value in self.id2label.items()) 92 | # Keys are always strings in JSON so convert ids to int here. 
93 | else: 94 | self.num_labels = kwargs.pop("num_labels", 2) 95 | 96 | # Tokenizer arguments TODO: eventually tokenizer and models should share the same config 97 | self.prefix = kwargs.pop("prefix", None) 98 | self.bos_token_id = kwargs.pop("bos_token_id", None) 99 | self.pad_token_id = kwargs.pop("pad_token_id", None) 100 | self.eos_token_id = kwargs.pop("eos_token_id", None) 101 | self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) 102 | 103 | # task specific arguments 104 | self.task_specific_params = kwargs.pop("task_specific_params", None) 105 | 106 | # TPU arguments 107 | self.xla_device = kwargs.pop("xla_device", None) 108 | 109 | # Additional attributes without default values 110 | for key, value in kwargs.items(): 111 | try: 112 | setattr(self, key, value) 113 | except AttributeError as err: 114 | logger.error("Can't set {} with value {} for {}".format(key, value, self)) 115 | raise err 116 | 117 | @property 118 | def num_labels(self): 119 | return len(self.id2label) 120 | 121 | @num_labels.setter 122 | def num_labels(self, num_labels): 123 | self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} 124 | self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) 125 | 126 | def save_pretrained(self, save_directory): 127 | """ 128 | Save a configuration object to the directory `save_directory`, so that it 129 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 130 | 131 | Args: 132 | save_directory (:obj:`string`): 133 | Directory where the configuration JSON file will be saved. 134 | """ 135 | if os.path.isfile(save_directory): 136 | raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory)) 137 | os.makedirs(save_directory, exist_ok=True) 138 | # If we save using the predefined names, we can load using `from_pretrained` 139 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 140 | 141 | self.to_json_file(output_config_file, use_diff=True) 142 | logger.info("Configuration saved in {}".format(output_config_file)) 143 | 144 | @classmethod 145 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig": 146 | r""" 147 | 148 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 149 | 150 | Args: 151 | pretrained_model_name_or_path (:obj:`string`): 152 | either: 153 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or 154 | download, e.g.: ``bert-base-uncased``. 155 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to 156 | our S3, e.g.: ``dbmdz/bert-base-german-cased``. 157 | - a path to a `directory` containing a configuration file saved using the 158 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 159 | - a path or url to a saved configuration JSON `file`, e.g.: 160 | ``./my_model_directory/configuration.json``. 161 | cache_dir (:obj:`string`, `optional`): 162 | Path to a directory in which a downloaded pre-trained model 163 | configuration should be cached if the standard cache should not be used. 164 | kwargs (:obj:`Dict[str, any]`, `optional`): 165 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded 166 | values. 
Behavior concerning key/value pairs whose keys are *not* configuration attributes is 167 | controlled by the `return_unused_kwargs` keyword parameter. 168 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 169 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 170 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 171 | Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. 172 | proxies (:obj:`Dict`, `optional`): 173 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: 174 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` 175 | The proxies are used on each request. 176 | return_unused_kwargs: (`optional`) bool: 177 | If False, then this function returns just the final configuration object. 178 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a 179 | dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part 180 | of kwargs which has not been used to update `config` and is otherwise ignored. 181 | 182 | Returns: 183 | :class:`PretrainedConfig`: An instance of a configuration object 184 | 185 | Examples:: 186 | 187 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 188 | # derived class: BertConfig 189 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 190 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` 191 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 192 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 193 | assert config.output_attention == True 194 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 195 | foo=False, return_unused_kwargs=True) 196 | assert config.output_attention == True 197 | assert unused_kwargs == {'foo': False} 198 | 199 | """ 200 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 201 | return cls.from_dict(config_dict, **kwargs) 202 | 203 | @classmethod 204 | def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]: 205 | """ 206 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used 207 | for instantiating a Config using `from_dict`. 208 | 209 | Parameters: 210 | pretrained_model_name_or_path (:obj:`string`): 211 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. 212 | 213 | Returns: 214 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. 
215 | 216 | """ 217 | cache_dir = kwargs.pop("cache_dir", None) 218 | force_download = kwargs.pop("force_download", False) 219 | resume_download = kwargs.pop("resume_download", False) 220 | proxies = kwargs.pop("proxies", None) 221 | local_files_only = kwargs.pop("local_files_only", False) 222 | 223 | if os.path.isdir(pretrained_model_name_or_path): 224 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 225 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 226 | config_file = pretrained_model_name_or_path 227 | else: 228 | config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) 229 | 230 | try: 231 | # Load from URL or cache if already cached 232 | resolved_config_file = cached_path( 233 | config_file, 234 | cache_dir=cache_dir, 235 | force_download=force_download, 236 | proxies=proxies, 237 | resume_download=resume_download, 238 | local_files_only=local_files_only, 239 | ) 240 | # Load config dict 241 | if resolved_config_file is None: 242 | raise EnvironmentError 243 | config_dict = cls._dict_from_json_file(resolved_config_file) 244 | 245 | except EnvironmentError: 246 | msg = ( 247 | f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n" 248 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" 249 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n" 250 | ) 251 | raise EnvironmentError(msg) 252 | 253 | except json.JSONDecodeError: 254 | msg = ( 255 | "Couldn't reach server at '{}' to download configuration file or " 256 | "configuration file is not a valid JSON file. " 257 | "Please check network or file content here: {}.".format(config_file, resolved_config_file) 258 | ) 259 | raise EnvironmentError(msg) 260 | 261 | if resolved_config_file == config_file: 262 | logger.info("loading configuration file {}".format(config_file)) 263 | else: 264 | logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) 265 | 266 | return config_dict, kwargs 267 | 268 | @classmethod 269 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": 270 | """ 271 | Constructs a `Config` from a Python dictionary of parameters. 272 | 273 | Args: 274 | config_dict (:obj:`Dict[str, any]`): 275 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved 276 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` 277 | method. 278 | kwargs (:obj:`Dict[str, any]`): 279 | Additional parameters from which to initialize the configuration object. 
280 | 281 | Returns: 282 | :class:`PretrainedConfig`: An instance of a configuration object 283 | """ 284 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) 285 | 286 | config = cls(**config_dict) 287 | 288 | if hasattr(config, "pruned_heads"): 289 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) 290 | 291 | # Update config with kwargs if needed 292 | to_remove = [] 293 | for key, value in kwargs.items(): 294 | if hasattr(config, key): 295 | setattr(config, key, value) 296 | to_remove.append(key) 297 | for key in to_remove: 298 | kwargs.pop(key, None) 299 | 300 | logger.info("Model config %s", str(config)) 301 | if return_unused_kwargs: 302 | return config, kwargs 303 | else: 304 | return config 305 | 306 | @classmethod 307 | def from_json_file(cls, json_file: str) -> "PretrainedConfig": 308 | """ 309 | Constructs a `Config` from the path to a json file of parameters. 310 | 311 | Args: 312 | json_file (:obj:`string`): 313 | Path to the JSON file containing the parameters. 314 | 315 | Returns: 316 | :class:`PretrainedConfig`: An instance of a configuration object 317 | 318 | """ 319 | config_dict = cls._dict_from_json_file(json_file) 320 | return cls(**config_dict) 321 | 322 | @classmethod 323 | def _dict_from_json_file(cls, json_file: str): 324 | with open(json_file, "r", encoding="utf-8") as reader: 325 | text = reader.read() 326 | return json.loads(text) 327 | 328 | def __eq__(self, other): 329 | return self.__dict__ == other.__dict__ 330 | 331 | def __repr__(self): 332 | return "{} {}".format(self.__class__.__name__, self.to_json_string()) 333 | 334 | def to_diff_dict(self): 335 | """ 336 | Removes all attributes from config which correspond to the default 337 | config attributes for better readability and serializes to a Python 338 | dictionary. 339 | 340 | Returns: 341 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 342 | """ 343 | config_dict = self.to_dict() 344 | 345 | # get the default config dict 346 | default_config_dict = PretrainedConfig().to_dict() 347 | 348 | serializable_config_dict = {} 349 | 350 | # only serialize values that differ from the default config 351 | for key, value in config_dict.items(): 352 | if key not in default_config_dict or value != default_config_dict[key]: 353 | serializable_config_dict[key] = value 354 | 355 | return serializable_config_dict 356 | 357 | def to_dict(self): 358 | """ 359 | Serializes this instance to a Python dictionary. 360 | 361 | Returns: 362 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 363 | """ 364 | output = copy.deepcopy(self.__dict__) 365 | if hasattr(self.__class__, "model_type"): 366 | output["model_type"] = self.__class__.model_type 367 | return output 368 | 369 | def to_json_string(self, use_diff=True): 370 | """ 371 | Serializes this instance to a JSON string. 372 | 373 | Args: 374 | use_diff (:obj:`bool`): 375 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string. 376 | 377 | Returns: 378 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. 
379 | """ 380 | if use_diff is True: 381 | config_dict = self.to_diff_dict() 382 | else: 383 | config_dict = self.to_dict() 384 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" 385 | 386 | def to_json_file(self, json_file_path, use_diff=True): 387 | """ 388 | Save this instance to a json file. 389 | 390 | Args: 391 | json_file_path (:obj:`string`): 392 | Path to the JSON file in which this configuration instance's parameters will be saved. 393 | use_diff (:obj:`bool`): 394 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file. 395 | """ 396 | with open(json_file_path, "w", encoding="utf-8") as writer: 397 | writer.write(self.to_json_string(use_diff=use_diff)) 398 | 399 | def update(self, config_dict: Dict): 400 | """ 401 | Updates attributes of this class 402 | with attributes from `config_dict`. 403 | 404 | Args: 405 | :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class. 406 | """ 407 | for key, value in config_dict.items(): 408 | setattr(self, key, value) 409 | --------------------------------------------------------------------------------