├── README.md ├── arc ├── __init__.py ├── adaptive_rotated_conv.py ├── routing_function.py └── weight_init.py ├── args.py ├── bert ├── activations.py ├── configuration_bert.py ├── configuration_utils.py ├── file_utils.py ├── generation_utils.py ├── modeling_bert.py ├── modeling_utils.py ├── tokenization_bert.py ├── tokenization_utils.py └── tokenization_utils_base.py ├── data └── dataset_refer_bert.py ├── lib ├── _utils.py ├── backbone.py ├── cross_scale_interaction.py ├── mask_predictor.py ├── mmcv_custom │ ├── __init__.py │ └── checkpoint.py ├── sa │ ├── functional.py │ ├── functions │ │ ├── __init__.py │ │ ├── aggregation_refpad.py │ │ ├── aggregation_zeropad.py │ │ ├── subtraction2_refpad.py │ │ ├── subtraction2_zeropad.py │ │ ├── subtraction_refpad.py │ │ ├── subtraction_zeropad.py │ │ └── utils.py │ └── modules │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── subtraction.py │ │ └── subtraction2.py ├── segmentation.py ├── transformer.py └── various_receptive.py ├── loss └── loss.py ├── pipeline.jpg ├── refer └── refer.py ├── requirements.txt ├── test.py ├── train.py ├── transforms.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # RMSIN 2 | This repository is the official implementation of ["Rotated Multi-Scale Interaction Network for Referring Remote Sensing Image Segmentation."](https://arxiv.org/abs/2312.12470) 3 | ![Pipeline Image](pipeline.jpg) 4 | 5 | ## Setting Up 6 | ### Preliminaries 7 | The code has been verified to work with PyTorch v1.7.1 and Python 3.7. 8 | 1. Clone this repository. 9 | 2. Change directory to the root of this repository. 10 | ### Package Dependencies 11 | 1. Create a new Conda environment with Python 3.7, then activate it: 12 | ```shell 13 | conda create -n RMSIN python==3.7 14 | conda activate RMSIN 15 | ``` 16 | 17 | 2. Install PyTorch v1.7.1 with a CUDA version that works on your cluster/machine (CUDA 10.2 is used in this example): 18 | ```shell 19 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch 20 | ``` 21 | 22 | 3. Install the packages in `requirements.txt` via `pip`: 23 | ```shell 24 | pip install -r requirements.txt 25 | ``` 26 | ### The Initialization Weights for Training 27 | 1. Create the `./pretrained_weights` directory where we will store the weights. 28 | ```shell 29 | mkdir ./pretrained_weights 30 | ``` 31 | 2. Download [pre-trained classification weights of 32 | the Swin Transformer](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth), 33 | and put the `pth` file in `./pretrained_weights`. 34 | These weights are needed to initialize the model for training. 35 | 36 | ## Datasets 37 | We perform all experiments on our proposed dataset RRSIS-D. RRSIS-D is a new Referring Remote Sensing Image Segmentation benchmark that contains 17,402 image-caption-mask triplets. It can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1Xqi3Am2Vgm4a5tHqiV9tfaqKNovcuK3A?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1yZatV2w_bSXIP9QBv2lCrA?pwd=sjoe) (access code: sjoe). 38 | ### Usage 39 | 1. Download our dataset. 40 | 2. Copy all the downloaded files to `./refer/data/`. 
The dataset folder should be like this: 41 | ``` 42 | $DATA_PATH 43 | ├── rrsisd 44 | │ ├── refs(unc).p 45 | │ ├── instances.json 46 | └── images 47 | └── rrsisd 48 | ├── JPEGImages 49 | ├── ann_split 50 | 51 | ``` 52 | 53 | ## Training 54 | We use DistributedDataParallel from PyTorch for training. To run on 4 GPUs (with IDs 0, 1, 2, and 3) on a single node: 55 | ```shell 56 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train.py --dataset rrsisd --model_id RMSIN --epochs 40 --img_size 480 2>&1 | tee ./output 57 | ``` 58 | 59 | ## Testing 60 | ```shell 61 | python test.py --swin_type base --dataset rrsisd --resume ./your_checkpoints_path --split val --workers 4 --window12 --img_size 480 62 | ``` 63 | 64 | ## Acknowledgements 65 | Code in this repository is built on [LAVT](https://github.com/yz93/LAVT-RIS). We'd like to thank the authors for open sourcing their project. 66 | -------------------------------------------------------------------------------- /arc/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptive_rotated_conv import AdaptiveRotatedConv2d 2 | from .routing_function import RountingFunction 3 | 4 | __all__ = [ 5 | 'AdaptiveRotatedConv2d', 'RountingFunction', 6 | ] 7 | -------------------------------------------------------------------------------- /arc/adaptive_rotated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | __all__ = ['AdaptiveRotatedConv2d'] 7 | 8 | 9 | def _get_rotation_matrix(thetas): 10 | bs, g = thetas.shape 11 | device = thetas.device 12 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n] 13 | 14 | x = torch.cos(thetas) 15 | y = torch.sin(thetas) 16 | x = x.unsqueeze(0).unsqueeze(0) # shape = [1, 1, bs * g] 17 | y = y.unsqueeze(0).unsqueeze(0) 18 | a = x - y 19 | b = x * y 20 | c = x + y 21 | 22 | rot_mat_positive = torch.cat(( 23 | torch.cat((a, 1-a, torch.zeros(1, 7, bs*g, device=device)), dim=1), 24 | torch.cat((torch.zeros(1, 1, bs*g, device=device), x-b, b, torch.zeros(1, 1, bs*g, device=device), 1-c+b, y-b, torch.zeros(1, 3, bs*g, device=device)), dim=1), 25 | torch.cat((torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device), 1-a, torch.zeros(1, 3, bs*g, device=device)), dim=1), 26 | torch.cat((b, y-b, torch.zeros(1,1 , bs*g, device=device), x-b, 1-c+b, torch.zeros(1, 4, bs*g, device=device)), dim=1), 27 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1), 28 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-c+b, x-b, torch.zeros(1, 1, bs*g, device=device), y-b, b), dim=1), 29 | torch.cat((torch.zeros(1, 3, bs*g, device=device), 1-a, torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device)), dim=1), 30 | torch.cat((torch.zeros(1, 3, bs*g, device=device), y-b, 1-c+b, torch.zeros(1, 1, bs*g, device=device), b, x-b, torch.zeros(1, 1, bs*g, device=device)), dim=1), 31 | torch.cat((torch.zeros(1, 7, bs*g, device=device), 1-a, a), dim=1) 32 | ), dim=0) # shape = [k^2, k^2, bs*g] 33 | 34 | rot_mat_negative = torch.cat(( 35 | torch.cat((c, torch.zeros(1, 2, bs*g, device=device), 1-c, torch.zeros(1, 5, bs*g, device=device)), dim=1), 36 | torch.cat((-b, x+b, torch.zeros(1, 1, bs*g, device=device), b-y, 1-a-b, torch.zeros(1, 4, bs*g, device=device)), dim=1), 37 | 
torch.cat((torch.zeros(1, 1, bs*g, device=device), 1-c, c, torch.zeros(1, 6, bs*g, device=device)), dim=1), 38 | torch.cat((torch.zeros(1, 3, bs*g, device=device), x+b, 1-a-b, torch.zeros(1, 1, bs*g, device=device), -b, b-y, torch.zeros(1, 1, bs*g, device=device)), dim=1), 39 | torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1), 40 | torch.cat((torch.zeros(1, 1, bs*g, device=device), b-y, -b, torch.zeros(1, 1, bs*g, device=device), 1-a-b, x+b, torch.zeros(1, 3, bs*g, device=device)), dim=1), 41 | torch.cat((torch.zeros(1, 6, bs*g, device=device), c, 1-c, torch.zeros(1, 1, bs*g, device=device)), dim=1), 42 | torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-a-b, b-y, torch.zeros(1, 1, bs*g, device=device), x+b, -b), dim=1), 43 | torch.cat((torch.zeros(1, 5, bs*g, device=device), 1-c, torch.zeros(1, 2, bs*g, device=device), c), dim=1) 44 | ), dim=0) # shape = [k^2, k^2, bs*g] 45 | 46 | mask = (thetas >= 0).unsqueeze(0).unsqueeze(0) 47 | mask = mask.float() # shape = [1, 1, bs*g] 48 | rot_mat = mask * rot_mat_positive + (1 - mask) * rot_mat_negative # shape = [k*k, k*k, bs*g] 49 | rot_mat = rot_mat.permute(2, 0, 1) # shape = [bs*g, k*k, k*k] 50 | rot_mat = rot_mat.reshape(bs, g, rot_mat.shape[1], rot_mat.shape[2]) # shape = [bs, g, k*k, k*k] 51 | return rot_mat 52 | 53 | 54 | def batch_rotate_multiweight(weights, lambdas, thetas): 55 | """ 56 | Let 57 | batch_size = b 58 | kernel_number = n 59 | kernel_size = 3 60 | Args: 61 | weights: tensor, shape = [kernel_number, Cout, Cin, k, k] 62 | thetas: tensor of thetas, shape = [batch_size, kernel_number] 63 | Return: 64 | weights_out: tensor, shape = [batch_size x Cout, Cin // groups, k, k] 65 | """ 66 | assert(thetas.shape == lambdas.shape) 67 | assert(lambdas.shape[1] == weights.shape[0]) 68 | 69 | b = thetas.shape[0] 70 | n = thetas.shape[1] 71 | k = weights.shape[-1] 72 | _, Cout, Cin, _, _ = weights.shape 73 | 74 | if k == 3 : 75 | # Stage 1: 76 | # input: thetas: [b, n] 77 | # lambdas: [b, n] 78 | # output: rotation_matrix: [b, n, 9, 9] (with gate) --> [b*9, n*9] 79 | 80 | # Sub_Stage 1.1: 81 | # input: [b, n] kernel 82 | # output: [b, n, 9, 9] rotation matrix 83 | rotation_matrix = _get_rotation_matrix(thetas) 84 | 85 | # Sub_Stage 1.2: 86 | # input: [b, n, 9, 9] rotation matrix 87 | # [b, n] lambdas 88 | # --> [b, n, 1, 1] lambdas 89 | # --> [b, n, 1, 1] lambdas dot [b, n, 9, 9] rotation matrix 90 | # --> [b, n, 9, 9] rotation matrix with gate (done) 91 | # output: [b, n, 9, 9] rotation matrix with gate 92 | lambdas = lambdas.unsqueeze(2).unsqueeze(3) 93 | rotation_matrix = torch.mul(rotation_matrix, lambdas) 94 | 95 | # Sub_Stage 1.3: Reshape 96 | # input: [b, n, 9, 9] rotation matrix with gate 97 | # output: [b*9, n*9] rotation matrix with gate 98 | rotation_matrix = rotation_matrix.permute(0, 2, 1, 3) 99 | rotation_matrix = rotation_matrix.reshape(b*k*k, n*k*k) 100 | 101 | # Stage 2: Reshape 102 | # input: weights: [n, Cout, Cin, 3, 3] 103 | # --> [n, 3, 3, Cout, Cin] 104 | # --> [n*9, Cout*Cin] done 105 | # output: weights: [n*9, Cout*Cin] 106 | weights = weights.permute(0, 3, 4, 1, 2) 107 | weights = weights.contiguous().view(n*k*k, Cout*Cin) 108 | 109 | 110 | # Stage 3: torch.mm 111 | # [b*9, n*9] x [n*9, Cout*Cin] 112 | # --> [b*9, Cout*Cin] 113 | weights = torch.mm(rotation_matrix, weights) 114 | 115 | # Stage 4: Reshape Back 116 | # input: [b*9, Cout*Cin] 117 | # --> [b, 3, 3, Cout, Cin] 118 | # --> [b, Cout, Cin, 3, 3] 119 | # --> [b * Cout, 
Cin, 3, 3] done 120 | # output: [b * Cout, Cin, 3, 3] 121 | weights = weights.contiguous().view(b, k, k, Cout, Cin) 122 | weights = weights.permute(0, 3, 4, 1, 2) 123 | weights = weights.reshape(b * Cout, Cin, k, k) 124 | else: 125 | thetas = thetas.reshape(-1) # [bs, n] --> [bs x n] 126 | 127 | x = torch.cos(thetas) 128 | y = torch.sin(thetas) 129 | rotate_matrix = torch.tensor([[x, -y, 0], [y, x, 0]]) 130 | rotate_matrix = rotate_matrix.unsqueeze(0).repeat(n, 1, 1) 131 | 132 | weights = weights.contiguous().view(n, Cout*Cin, k, k) 133 | 134 | grid = F.affine_grid(rotate_matrix, weights.shape) 135 | weights = F.grid_sample(weights, grid, mode='bilinear') 136 | 137 | return weights 138 | 139 | 140 | class AdaptiveRotatedConv2d(nn.Module): 141 | 142 | def __init__(self, in_channels, out_channels, kernel_size, 143 | stride=1, padding=1, dilation=1, groups=1, bias=False, 144 | kernel_number=1, rounting_func=None, rotate_func=batch_rotate_multiweight): 145 | super().__init__() 146 | self.kernel_number = kernel_number 147 | self.in_channels = in_channels 148 | self.out_channels = out_channels 149 | self.kernel_size = kernel_size 150 | self.stride = stride 151 | self.padding = padding 152 | self.dilation = dilation 153 | self.groups = groups 154 | self.bias = bias 155 | 156 | self.rounting_func = rounting_func 157 | self.rotate_func = rotate_func 158 | 159 | self.weight = nn.Parameter( 160 | torch.Tensor( 161 | kernel_number, 162 | out_channels, 163 | in_channels // groups, 164 | kernel_size, 165 | kernel_size, 166 | ) 167 | ) 168 | nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') 169 | 170 | def forward(self, x): 171 | # get alphas, angles 172 | # # [bs, Cin, h, w] --> [bs, n_theta], [bs, n_theta] 173 | alphas, angles = self.rounting_func(x) 174 | 175 | # rotate weight 176 | # # [Cout, Cin, k, k] --> [bs * Cout, Cin, k, k] 177 | # print(self.weight.shape) 178 | rotated_weight = self.rotate_func(self.weight, alphas, angles) 179 | 180 | # reshape images 181 | bs, Cin, h, w = x.shape 182 | x = x.reshape(1, bs * Cin, h, w) # [1, bs * Cin, h, w] 183 | 184 | # adaptive conv over images using group conv 185 | out = F.conv2d(input=x, weight=rotated_weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=(self.groups * bs)) 186 | 187 | # reshape back 188 | out = out.reshape(bs, self.out_channels, *out.shape[2:]) 189 | return out 190 | 191 | def extra_repr(self): 192 | s = ('{in_channels}, {out_channels}, kernel_number={kernel_number}' 193 | ', kernel_size={kernel_size}, stride={stride}, bias={bias}') 194 | 195 | if self.padding != (0,) * len([self.padding]): 196 | s += ', padding={padding}' 197 | if self.dilation != (1,) * len([self.dilation]): 198 | s += ', dilation={dilation}' 199 | if self.groups != 1: 200 | s += ', groups={groups}' 201 | return s.format(**self.__dict__) 202 | -------------------------------------------------------------------------------- /arc/routing_function.py: -------------------------------------------------------------------------------- 1 | import math 2 | import einops 3 | import torch 4 | import torch.nn as nn 5 | from .weight_init import trunc_normal_ 6 | 7 | 8 | class LayerNormProxy(nn.Module): 9 | # copy from https://github.com/LeapLabTHU/DAT/blob/main/models/dat_blocks.py 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.norm = nn.LayerNorm(dim) 13 | 14 | def forward(self, x): 15 | x = einops.rearrange(x, 'b c h w -> b h w c') 16 | x = self.norm(x) 17 | return einops.rearrange(x, 'b h w c -> b c h 
w') 18 | 19 | 20 | class RountingFunction(nn.Module): 21 | 22 | def __init__(self, in_channels, kernel_number, dropout_rate=0.2, proportion=40.0): 23 | super().__init__() 24 | self.kernel_number = kernel_number 25 | self.dwc = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, 26 | groups=in_channels, bias=False) 27 | self.norm = LayerNormProxy(in_channels) 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 31 | 32 | self.dropout1 = nn.Dropout(dropout_rate) 33 | self.fc_alpha = nn.Linear(in_channels, kernel_number, bias=True) 34 | 35 | self.dropout2= nn.Dropout(dropout_rate) 36 | self.fc_theta = nn.Linear(in_channels, kernel_number, bias=False) 37 | 38 | self.act_func = nn.Softsign() 39 | self.proportion = proportion / 180.0 * math.pi 40 | 41 | # init weights 42 | trunc_normal_(self.dwc.weight, std=.02) 43 | trunc_normal_(self.fc_alpha.weight, std=.02) 44 | trunc_normal_(self.fc_theta.weight, std=.02) 45 | 46 | def forward(self, x): 47 | 48 | x = self.dwc(x) 49 | x = self.norm(x) 50 | x = self.relu(x) 51 | 52 | x = self.avg_pool(x).squeeze(dim=-1).squeeze(dim=-1) # avg_x.shape = [batch_size, Cin] 53 | 54 | alphas = self.dropout1(x) 55 | alphas = self.fc_alpha(alphas) 56 | alphas = torch.sigmoid(alphas) 57 | 58 | angles = self.dropout2(x) 59 | angles = self.fc_theta(angles) 60 | angles = self.act_func(angles) 61 | angles = angles * self.proportion 62 | 63 | return alphas, angles 64 | 65 | def extra_repr(self): 66 | s = (f'kernel_number={self.kernel_number}') 67 | return s.format(**self.__dict__) 68 | -------------------------------------------------------------------------------- /arc/weight_init.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # get from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/weight_init.py 3 | # -------------------------------------------------------- 4 | import torch 5 | import math 6 | import warnings 7 | 8 | 9 | def _trunc_normal_(tensor, mean, std, a, b): 10 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 11 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 12 | def norm_cdf(x): 13 | # Computes standard normal cumulative distribution function 14 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 15 | 16 | if (mean < a - 2 * std) or (mean > b + 2 * std): 17 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 18 | "The distribution of values may be incorrect.", 19 | stacklevel=2) 20 | 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 
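# Why [2l-1, 2u-1]: since 2*norm_cdf(z) - 1 = erf(z / sqrt(2)), the uniform draw below equals erf(z / sqrt(2)) for z sampled, via the inverse CDF, from a standard normal truncated to [(a - mean) / std, (b - mean) / std]; erfinv_() then recovers z / sqrt(2), and the final mul_/add_ steps map it to N(mean, std^2) restricted to [a, b].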
29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 53 | applied while sampling the normal with mean/std applied, therefore a, b args 54 | should be adjusted to match the range of mean, std args. 55 | Args: 56 | tensor: an n-dimensional `torch.Tensor` 57 | mean: the mean of the normal distribution 58 | std: the standard deviation of the normal distribution 59 | a: the minimum cutoff value 60 | b: the maximum cutoff value 61 | Examples: 62 | >>> w = torch.empty(3, 5) 63 | >>> nn.init.trunc_normal_(w) 64 | """ 65 | with torch.no_grad(): 66 | return _trunc_normal_(tensor, mean, std, a, b) 67 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser(description='RMSIN training and testing') 6 | parser.add_argument('--amsgrad', action='store_true', 7 | help='if true, set amsgrad to True in an Adam or AdamW optimizer.') 8 | parser.add_argument('-b', '--batch-size', default=2, type=int) 9 | parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer') 10 | parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights') 11 | parser.add_argument('--dataset', default='rrsisd', help='rrsisd, refcoco, refcoco+, or refcocog') 12 | parser.add_argument('--ddp_trained_weights', action='store_true', 13 | help='Only needs to be specified when testing: ' 14 | 'whether the weights to be loaded are from a DDP-trained model') 15 | parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine 16 | parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run') 17 | parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs') 18 | parser.add_argument('--img_size', default=480, type=int, help='input image size') 19 | parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel') 20 | parser.add_argument('--lr', default=0.00003, type=float, help='the initial learning rate') 21 | parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4, ' 22 | 'where a, b, c, and d refer to the numbers of heads in stage-1, ' 23 | 'stage-2, stage-3, and stage-4 PWAMs') 24 | parser.add_argument('--model', default='lavt_one', help='model: lavt, lavt_one') 25 | parser.add_argument('--model_id', default='RMSIN', help='name to identify the model') 26 | 
parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights') 27 | parser.add_argument('--pin_mem', action='store_true', 28 | help='If true, pin memory when using the data loader.') 29 | parser.add_argument('--pretrained_swin_weights', default='./pretrained_weights/swin_base_patch4_window12_384_22k.pth', 30 | help='path to pre-trained Swin backbone weights') 31 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 32 | parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory') 33 | parser.add_argument('--resume', default='', help='resume from checkpoint') 34 | parser.add_argument('--split', default='test', help='only used when testing') 35 | parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)') 36 | parser.add_argument('--swin_type', default='base', 37 | help='tiny, small, base, or large variants of the Swin Transformer') 38 | parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay', 39 | dest='weight_decay') 40 | parser.add_argument('--window12', action='store_true', 41 | help='only needs to be specified when testing; ' 42 | 'when training, the window size is inferred from the pre-trained weights file name ' 43 | '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.') 44 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers') 45 | return parser 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = get_parser() 50 | args_dict = parser.parse_args() 51 | -------------------------------------------------------------------------------- /bert/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 
27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | 37 | 38 | def gelu_fast(x): 39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 40 | 41 | 42 | ACT2FN = { 43 | "relu": F.relu, 44 | "swish": swish, 45 | "gelu": gelu, 46 | "tanh": torch.tanh, 47 | "gelu_new": gelu_new, 48 | "gelu_fast": gelu_fast, 49 | } 50 | 51 | 52 | def get_activation(activation_string): 53 | if activation_string in ACT2FN: 54 | return ACT2FN[activation_string] 55 | else: 56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 57 | -------------------------------------------------------------------------------- /bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 56 | It is used to instantiate an BERT model according to the specified arguments, defining the model 57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 58 | the BERT `bert-base-uncased `__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 62 | for more information. 63 | 64 | 65 | Args: 66 | vocab_size (:obj:`int`, optional, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the different tokens that 68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 69 | hidden_size (:obj:`int`, optional, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, optional, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, optional, defaults to 3072): 76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 78 | The non-linear activation function (function or string) in the encoder and pooler. 79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 81 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 
82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. 86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, optional, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 89 | initializer_range (:obj:`float`, optional, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 94 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 95 | 96 | Example:: 97 | 98 | >>> from transformers import BertModel, BertConfig 99 | 100 | >>> # Initializing a BERT bert-base-uncased style configuration 101 | >>> configuration = BertConfig() 102 | 103 | >>> # Initializing a model from the bert-base-uncased style configuration 104 | >>> model = BertModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | """ 109 | model_type = "bert" 110 | 111 | def __init__( 112 | self, 113 | vocab_size=30522, 114 | hidden_size=768, 115 | num_hidden_layers=12, 116 | num_attention_heads=12, 117 | intermediate_size=3072, 118 | hidden_act="gelu", 119 | hidden_dropout_prob=0.1, 120 | attention_probs_dropout_prob=0.1, 121 | max_position_embeddings=512, 122 | type_vocab_size=2, 123 | initializer_range=0.02, 124 | layer_norm_eps=1e-12, 125 | pad_token_id=0, 126 | gradient_checkpointing=False, 127 | **kwargs 128 | ): 129 | super().__init__(pad_token_id=pad_token_id, **kwargs) 130 | 131 | self.vocab_size = vocab_size 132 | self.hidden_size = hidden_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.hidden_act = hidden_act 136 | self.intermediate_size = intermediate_size 137 | self.hidden_dropout_prob = hidden_dropout_prob 138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 139 | self.max_position_embeddings = max_position_embeddings 140 | self.type_vocab_size = type_vocab_size 141 | self.initializer_range = initializer_range 142 | self.layer_norm_eps = layer_norm_eps 143 | self.gradient_checkpointing = gradient_checkpointing 144 | -------------------------------------------------------------------------------- /bert/configuration_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Configuration base class and utilities.""" 17 | 18 | 19 | import copy 20 | import json 21 | import logging 22 | import os 23 | from typing import Dict, Tuple 24 | 25 | from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class PretrainedConfig(object): 32 | r""" Base class for all configuration classes. 33 | Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. 34 | 35 | Note: 36 | A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. 37 | It only affects the model's configuration. 38 | 39 | Class attributes (overridden by derived classes): 40 | - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`. 41 | 42 | Args: 43 | finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`): 44 | Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. 45 | num_labels (:obj:`int`, `optional`, defaults to `2`): 46 | Number of classes to use when the model is a classification model (sequences/tokens) 47 | output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`): 48 | Should the model returns all hidden-states. 49 | output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`): 50 | Should the model returns all attentions. 51 | torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`): 52 | Is the model used with Torchscript (for PyTorch models). 
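Note: any additional keyword argument passed to the constructor that is not recognized above is simply set as an attribute on the configuration object (see ``__init__`` below).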
53 | """ 54 | model_type: str = "" 55 | 56 | def __init__(self, **kwargs): 57 | # Attributes with defaults 58 | self.output_hidden_states = kwargs.pop("output_hidden_states", False) 59 | self.output_attentions = kwargs.pop("output_attentions", False) 60 | self.use_cache = kwargs.pop("use_cache", True) # Not used by all models 61 | self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models 62 | self.use_bfloat16 = kwargs.pop("use_bfloat16", False) 63 | self.pruned_heads = kwargs.pop("pruned_heads", {}) 64 | 65 | # Is decoder is used in encoder-decoder models to differentiate encoder from decoder 66 | self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) 67 | self.is_decoder = kwargs.pop("is_decoder", False) 68 | 69 | # Parameters for sequence generation 70 | self.max_length = kwargs.pop("max_length", 20) 71 | self.min_length = kwargs.pop("min_length", 0) 72 | self.do_sample = kwargs.pop("do_sample", False) 73 | self.early_stopping = kwargs.pop("early_stopping", False) 74 | self.num_beams = kwargs.pop("num_beams", 1) 75 | self.temperature = kwargs.pop("temperature", 1.0) 76 | self.top_k = kwargs.pop("top_k", 50) 77 | self.top_p = kwargs.pop("top_p", 1.0) 78 | self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) 79 | self.length_penalty = kwargs.pop("length_penalty", 1.0) 80 | self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) 81 | self.bad_words_ids = kwargs.pop("bad_words_ids", None) 82 | self.num_return_sequences = kwargs.pop("num_return_sequences", 1) 83 | 84 | # Fine-tuning task arguments 85 | self.architectures = kwargs.pop("architectures", None) 86 | self.finetuning_task = kwargs.pop("finetuning_task", None) 87 | self.id2label = kwargs.pop("id2label", None) 88 | self.label2id = kwargs.pop("label2id", None) 89 | if self.id2label is not None: 90 | kwargs.pop("num_labels", None) 91 | self.id2label = dict((int(key), value) for key, value in self.id2label.items()) 92 | # Keys are always strings in JSON so convert ids to int here. 93 | else: 94 | self.num_labels = kwargs.pop("num_labels", 2) 95 | 96 | # Tokenizer arguments TODO: eventually tokenizer and models should share the same config 97 | self.prefix = kwargs.pop("prefix", None) 98 | self.bos_token_id = kwargs.pop("bos_token_id", None) 99 | self.pad_token_id = kwargs.pop("pad_token_id", None) 100 | self.eos_token_id = kwargs.pop("eos_token_id", None) 101 | self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) 102 | 103 | # task specific arguments 104 | self.task_specific_params = kwargs.pop("task_specific_params", None) 105 | 106 | # TPU arguments 107 | self.xla_device = kwargs.pop("xla_device", None) 108 | 109 | # Additional attributes without default values 110 | for key, value in kwargs.items(): 111 | try: 112 | setattr(self, key, value) 113 | except AttributeError as err: 114 | logger.error("Can't set {} with value {} for {}".format(key, value, self)) 115 | raise err 116 | 117 | @property 118 | def num_labels(self): 119 | return len(self.id2label) 120 | 121 | @num_labels.setter 122 | def num_labels(self, num_labels): 123 | self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)} 124 | self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) 125 | 126 | def save_pretrained(self, save_directory): 127 | """ 128 | Save a configuration object to the directory `save_directory`, so that it 129 | can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. 
130 | 131 | Args: 132 | save_directory (:obj:`string`): 133 | Directory where the configuration JSON file will be saved. 134 | """ 135 | if os.path.isfile(save_directory): 136 | raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory)) 137 | os.makedirs(save_directory, exist_ok=True) 138 | # If we save using the predefined names, we can load using `from_pretrained` 139 | output_config_file = os.path.join(save_directory, CONFIG_NAME) 140 | 141 | self.to_json_file(output_config_file, use_diff=True) 142 | logger.info("Configuration saved in {}".format(output_config_file)) 143 | 144 | @classmethod 145 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig": 146 | r""" 147 | 148 | Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. 149 | 150 | Args: 151 | pretrained_model_name_or_path (:obj:`string`): 152 | either: 153 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache or 154 | download, e.g.: ``bert-base-uncased``. 155 | - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to 156 | our S3, e.g.: ``dbmdz/bert-base-german-cased``. 157 | - a path to a `directory` containing a configuration file saved using the 158 | :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. 159 | - a path or url to a saved configuration JSON `file`, e.g.: 160 | ``./my_model_directory/configuration.json``. 161 | cache_dir (:obj:`string`, `optional`): 162 | Path to a directory in which a downloaded pre-trained model 163 | configuration should be cached if the standard cache should not be used. 164 | kwargs (:obj:`Dict[str, any]`, `optional`): 165 | The values in kwargs of any keys which are configuration attributes will be used to override the loaded 166 | values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is 167 | controlled by the `return_unused_kwargs` keyword parameter. 168 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 169 | Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 170 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 171 | Do not delete incompletely recieved file. Attempt to resume the download if such a file exists. 172 | proxies (:obj:`Dict`, `optional`): 173 | A dictionary of proxy servers to use by protocol or endpoint, e.g.: 174 | :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` 175 | The proxies are used on each request. 176 | return_unused_kwargs: (`optional`) bool: 177 | If False, then this function returns just the final configuration object. 178 | If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a 179 | dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part 180 | of kwargs which has not been used to update `config` and is otherwise ignored. 181 | 182 | Returns: 183 | :class:`PretrainedConfig`: An instance of a configuration object 184 | 185 | Examples:: 186 | 187 | # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a 188 | # derived class: BertConfig 189 | config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. 190 | config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` 191 | config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') 192 | config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) 193 | assert config.output_attention == True 194 | config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, 195 | foo=False, return_unused_kwargs=True) 196 | assert config.output_attention == True 197 | assert unused_kwargs == {'foo': False} 198 | 199 | """ 200 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 201 | return cls.from_dict(config_dict, **kwargs) 202 | 203 | @classmethod 204 | def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]: 205 | """ 206 | From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used 207 | for instantiating a Config using `from_dict`. 208 | 209 | Parameters: 210 | pretrained_model_name_or_path (:obj:`string`): 211 | The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. 212 | 213 | Returns: 214 | :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. 215 | 216 | """ 217 | cache_dir = kwargs.pop("cache_dir", None) 218 | force_download = kwargs.pop("force_download", False) 219 | resume_download = kwargs.pop("resume_download", False) 220 | proxies = kwargs.pop("proxies", None) 221 | local_files_only = kwargs.pop("local_files_only", False) 222 | 223 | if os.path.isdir(pretrained_model_name_or_path): 224 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 225 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): 226 | config_file = pretrained_model_name_or_path 227 | else: 228 | config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) 229 | 230 | try: 231 | # Load from URL or cache if already cached 232 | resolved_config_file = cached_path( 233 | config_file, 234 | cache_dir=cache_dir, 235 | force_download=force_download, 236 | proxies=proxies, 237 | resume_download=resume_download, 238 | local_files_only=local_files_only, 239 | ) 240 | # Load config dict 241 | if resolved_config_file is None: 242 | raise EnvironmentError 243 | config_dict = cls._dict_from_json_file(resolved_config_file) 244 | 245 | except EnvironmentError: 246 | msg = ( 247 | f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n" 248 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" 249 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n" 250 | ) 251 | raise EnvironmentError(msg) 252 | 253 | except json.JSONDecodeError: 254 | msg = ( 255 | "Couldn't reach server at '{}' to download configuration file or " 256 | "configuration file is not a valid JSON file. 
" 257 | "Please check network or file content here: {}.".format(config_file, resolved_config_file) 258 | ) 259 | raise EnvironmentError(msg) 260 | 261 | if resolved_config_file == config_file: 262 | logger.info("loading configuration file {}".format(config_file)) 263 | else: 264 | logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file)) 265 | 266 | return config_dict, kwargs 267 | 268 | @classmethod 269 | def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": 270 | """ 271 | Constructs a `Config` from a Python dictionary of parameters. 272 | 273 | Args: 274 | config_dict (:obj:`Dict[str, any]`): 275 | Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved 276 | from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` 277 | method. 278 | kwargs (:obj:`Dict[str, any]`): 279 | Additional parameters from which to initialize the configuration object. 280 | 281 | Returns: 282 | :class:`PretrainedConfig`: An instance of a configuration object 283 | """ 284 | return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) 285 | 286 | config = cls(**config_dict) 287 | 288 | if hasattr(config, "pruned_heads"): 289 | config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) 290 | 291 | # Update config with kwargs if needed 292 | to_remove = [] 293 | for key, value in kwargs.items(): 294 | if hasattr(config, key): 295 | setattr(config, key, value) 296 | to_remove.append(key) 297 | for key in to_remove: 298 | kwargs.pop(key, None) 299 | 300 | logger.info("Model config %s", str(config)) 301 | if return_unused_kwargs: 302 | return config, kwargs 303 | else: 304 | return config 305 | 306 | @classmethod 307 | def from_json_file(cls, json_file: str) -> "PretrainedConfig": 308 | """ 309 | Constructs a `Config` from the path to a json file of parameters. 310 | 311 | Args: 312 | json_file (:obj:`string`): 313 | Path to the JSON file containing the parameters. 314 | 315 | Returns: 316 | :class:`PretrainedConfig`: An instance of a configuration object 317 | 318 | """ 319 | config_dict = cls._dict_from_json_file(json_file) 320 | return cls(**config_dict) 321 | 322 | @classmethod 323 | def _dict_from_json_file(cls, json_file: str): 324 | with open(json_file, "r", encoding="utf-8") as reader: 325 | text = reader.read() 326 | return json.loads(text) 327 | 328 | def __eq__(self, other): 329 | return self.__dict__ == other.__dict__ 330 | 331 | def __repr__(self): 332 | return "{} {}".format(self.__class__.__name__, self.to_json_string()) 333 | 334 | def to_diff_dict(self): 335 | """ 336 | Removes all attributes from config which correspond to the default 337 | config attributes for better readability and serializes to a Python 338 | dictionary. 339 | 340 | Returns: 341 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 342 | """ 343 | config_dict = self.to_dict() 344 | 345 | # get the default config dict 346 | default_config_dict = PretrainedConfig().to_dict() 347 | 348 | serializable_config_dict = {} 349 | 350 | # only serialize values that differ from the default config 351 | for key, value in config_dict.items(): 352 | if key not in default_config_dict or value != default_config_dict[key]: 353 | serializable_config_dict[key] = value 354 | 355 | return serializable_config_dict 356 | 357 | def to_dict(self): 358 | """ 359 | Serializes this instance to a Python dictionary. 
360 | 361 | Returns: 362 | :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 363 | """ 364 | output = copy.deepcopy(self.__dict__) 365 | if hasattr(self.__class__, "model_type"): 366 | output["model_type"] = self.__class__.model_type 367 | return output 368 | 369 | def to_json_string(self, use_diff=True): 370 | """ 371 | Serializes this instance to a JSON string. 372 | 373 | Args: 374 | use_diff (:obj:`bool`): 375 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string. 376 | 377 | Returns: 378 | :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. 379 | """ 380 | if use_diff is True: 381 | config_dict = self.to_diff_dict() 382 | else: 383 | config_dict = self.to_dict() 384 | return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" 385 | 386 | def to_json_file(self, json_file_path, use_diff=True): 387 | """ 388 | Save this instance to a json file. 389 | 390 | Args: 391 | json_file_path (:obj:`string`): 392 | Path to the JSON file in which this configuration instance's parameters will be saved. 393 | use_diff (:obj:`bool`): 394 | If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file. 395 | """ 396 | with open(json_file_path, "w", encoding="utf-8") as writer: 397 | writer.write(self.to_json_string(use_diff=use_diff)) 398 | 399 | def update(self, config_dict: Dict): 400 | """ 401 | Updates attributes of this class 402 | with attributes from `config_dict`. 403 | 404 | Args: 405 | :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class. 406 | """ 407 | for key, value in config_dict.items(): 408 | setattr(self, key, value) 409 | -------------------------------------------------------------------------------- /data/dataset_refer_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch.utils.data as data 4 | import torch 5 | from torchvision import transforms 6 | from torch.autograd import Variable 7 | import numpy as np 8 | from PIL import Image 9 | import torchvision.transforms.functional as TF 10 | import random 11 | 12 | from bert.tokenization_bert import BertTokenizer 13 | 14 | import h5py 15 | from refer.refer import REFER 16 | 17 | from args import get_parser 18 | 19 | # Dataset configuration initialization 20 | parser = get_parser() 21 | args = parser.parse_args() 22 | 23 | 24 | def add_random_boxes(img, min_num=20, max_num=60, size=32): 25 | h,w = size, size 26 | img = np.asarray(img).copy() 27 | img_size = img.shape[1] 28 | boxes = [] 29 | num = random.randint(min_num, max_num) 30 | for k in range(num): 31 | y, x = random.randint(0, img_size-w), random.randint(0, img_size-h) 32 | img[y:y+h, x: x+w] = 0 33 | boxes. 
append((x,y,h,w) ) 34 | img = Image.fromarray(img.astype('uint8'), 'RGB') 35 | return img 36 | 37 | 38 | class ReferDataset(data.Dataset): 39 | 40 | def __init__(self, 41 | args, 42 | image_transforms=None, 43 | target_transforms=None, 44 | split='train', 45 | eval_mode=False): 46 | 47 | self.classes = [] 48 | self.image_transforms = image_transforms 49 | self.target_transform = target_transforms 50 | self.split = split 51 | self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy) 52 | 53 | self.max_tokens = 20 54 | 55 | ref_ids = self.refer.getRefIds(split=self.split) 56 | img_ids = self.refer.getImgIds(ref_ids) 57 | 58 | num_images_to_mask = int(len(ref_ids) * 0.2) 59 | self.images_to_mask = random.sample(ref_ids, num_images_to_mask) 60 | 61 | all_imgs = self.refer.Imgs 62 | self.imgs = list(all_imgs[i] for i in img_ids) 63 | self.ref_ids = ref_ids 64 | 65 | self.input_ids = [] 66 | self.attention_masks = [] 67 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer) 68 | 69 | self.eval_mode = eval_mode 70 | # if we are testing on a dataset, test all sentences of an object; 71 | # o/w, we are validating during training, randomly sample one sentence for efficiency 72 | for r in ref_ids: 73 | ref = self.refer.Refs[r] 74 | 75 | sentences_for_ref = [] 76 | attentions_for_ref = [] 77 | 78 | for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])): 79 | sentence_raw = el['raw'] 80 | attention_mask = [0] * self.max_tokens 81 | padded_input_ids = [0] * self.max_tokens 82 | 83 | input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) 84 | 85 | # truncation of tokens 86 | input_ids = input_ids[:self.max_tokens] 87 | 88 | padded_input_ids[:len(input_ids)] = input_ids 89 | attention_mask[:len(input_ids)] = [1]*len(input_ids) 90 | 91 | sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) 92 | attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) 93 | 94 | self.input_ids.append(sentences_for_ref) 95 | self.attention_masks.append(attentions_for_ref) 96 | 97 | def get_classes(self): 98 | return self.classes 99 | 100 | def __len__(self): 101 | return len(self.ref_ids) 102 | 103 | def __getitem__(self, index): 104 | this_ref_id = self.ref_ids[index] 105 | this_img_id = self.refer.getImgIds(this_ref_id) 106 | this_img = self.refer.Imgs[this_img_id[0]] 107 | 108 | img = Image.open(os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])) 109 | if self.split == 'train' and this_ref_id in self.images_to_mask: 110 | img = add_random_boxes(img) 111 | 112 | ref = self.refer.loadRefs(this_ref_id) 113 | 114 | ref_mask = np.array(self.refer.getMask(ref[0])['mask']) 115 | annot = np.zeros(ref_mask.shape) 116 | annot[ref_mask == 1] = 1 117 | 118 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 119 | 120 | if self.image_transforms is not None: 121 | # resize, from PIL to tensor, and mean and std normalization 122 | img, target = self.image_transforms(img, annot) 123 | 124 | if self.eval_mode: 125 | embedding = [] 126 | att = [] 127 | for s in range(len(self.input_ids[index])): 128 | e = self.input_ids[index][s] 129 | a = self.attention_masks[index][s] 130 | embedding.append(e.unsqueeze(-1)) 131 | att.append(a.unsqueeze(-1)) 132 | 133 | tensor_embeddings = torch.cat(embedding, dim=-1) 134 | attention_mask = torch.cat(att, dim=-1) 135 | else: 136 | choice_sent = np.random.choice(len(self.input_ids[index])) 137 | tensor_embeddings = self.input_ids[index][choice_sent] 138 | attention_mask = 
self.attention_masks[index][choice_sent] 139 | 140 | return img, target, tensor_embeddings, attention_mask 141 | -------------------------------------------------------------------------------- /lib/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import sys 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from bert.modeling_bert import BertModel 7 | 8 | 9 | def load_weights(model, load_path): 10 | dict_trained = torch.load(load_path)['model'] 11 | dict_new = model.state_dict().copy() 12 | for key in dict_new.keys(): 13 | if key in dict_trained.keys(): 14 | dict_new[key] = dict_trained[key] 15 | model.load_state_dict(dict_new) 16 | del dict_new 17 | del dict_trained 18 | torch.cuda.empty_cache() 19 | print('load weights from {}'.format(load_path)) 20 | return model 21 | 22 | 23 | class _LAVTSimpleDecode(nn.Module): 24 | def __init__(self, backbone, classifier): 25 | super(_LAVTSimpleDecode, self).__init__() 26 | self.backbone = backbone 27 | self.classifier = classifier 28 | 29 | def forward(self, x, l_feats, l_mask): 30 | input_shape = x.shape[-2:] 31 | features = self.backbone(x, l_feats, l_mask) 32 | x_c1, x_c2, x_c3, x_c4 = features 33 | 34 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 35 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 36 | 37 | return x 38 | 39 | 40 | class LAVT(_LAVTSimpleDecode): 41 | pass 42 | 43 | 44 | ############################################### 45 | # LAVT One: put BERT inside the overall model # 46 | ############################################### 47 | class _LAVTOneSimpleDecode(nn.Module): 48 | def __init__(self, backbone, classifier, args): 49 | super(_LAVTOneSimpleDecode, self).__init__() 50 | self.backbone = backbone 51 | self.classifier = classifier 52 | self.text_encoder = BertModel.from_pretrained(args.ck_bert) 53 | self.text_encoder.pooler = None 54 | 55 | def forward(self, x, text, l_mask): 56 | input_shape = x.shape[-2:] 57 | ### language inference ### 58 | l_feats = self.text_encoder(text, attention_mask=l_mask)[0] # (6, 10, 768) 59 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) 60 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 61 | ########################## 62 | features = self.backbone(x, l_feats, l_mask) 63 | x_c1, x_c2, x_c3, x_c4 = features # e.g. 
x_c1:[B, 128, 120, 120], x_c2:[B, 256, 60, 60], x_c3:[B, 512, 30, 30], x_c4:[B, 1024, 15, 15] 64 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 65 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 66 | return x 67 | 68 | 69 | class LAVTOne(_LAVTOneSimpleDecode): #change 70 | pass 71 | -------------------------------------------------------------------------------- /lib/cross_scale_interaction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class h_sigmoid(nn.Module): 7 | def __init__(self, inplace=True): 8 | super(h_sigmoid, self).__init__() 9 | self.relu = nn.ReLU6(inplace=inplace) 10 | 11 | def forward(self, x): 12 | return self.relu(x + 3) / 6 13 | 14 | 15 | class Linear_BN(torch.nn.Sequential): 16 | def __init__(self, a, b, bn_weight_init=1): 17 | super().__init__() 18 | self.add_module('c', torch.nn.Linear(a, b, bias=False)) 19 | bn = torch.nn.BatchNorm1d(b) 20 | torch.nn.init.constant_(bn.weight, bn_weight_init) 21 | torch.nn.init.constant_(bn.bias, 0) 22 | self.add_module('bn', bn) 23 | 24 | @torch.no_grad() 25 | def fuse(self): 26 | l, bn = self._modules.values() 27 | w = bn.weight / (bn.running_var + bn.eps)**0.5 28 | w = l.weight * w[:, None] 29 | b = bn.bias - bn.running_mean * bn.weight / \ 30 | (bn.running_var + bn.eps)**0.5 31 | m = torch.nn.Linear(w.size(1), w.size(0)) 32 | m.weight.data.copy_(w) 33 | m.bias.data.copy_(b) 34 | return m 35 | 36 | def forward(self, x): 37 | l, bn = self._modules.values() 38 | x = l(x) 39 | return bn(x.flatten(0, 1)).reshape_as(x) 40 | 41 | 42 | class Residual(torch.nn.Module): 43 | def __init__(self, m): 44 | super().__init__() 45 | self.m = m 46 | 47 | def forward(self, x): 48 | return x + self.m(x) 49 | 50 | 51 | class ScaleAwareGate(nn.Module): 52 | def __init__(self, inp, oup): 53 | super(ScaleAwareGate, self).__init__() 54 | 55 | self.local_embedding = nn.Conv2d(inp, oup, kernel_size=1) 56 | self.bn1 = nn.BatchNorm2d(oup) 57 | 58 | self.global_embedding = nn.Conv2d(inp, oup, kernel_size=1) 59 | self.bn2 = nn.BatchNorm2d(oup) 60 | 61 | self.global_act = nn.Conv2d(inp, oup, kernel_size=1) 62 | self.bn3 = nn.BatchNorm2d(oup) 63 | self.act = h_sigmoid() 64 | 65 | def forward(self, x_l, x_g): 66 | B, C, H, W = x_l.shape 67 | local_feat = self.local_embedding(x_l) 68 | local_feat = self.bn1(local_feat) 69 | 70 | global_feat = self.global_embedding(x_g) 71 | global_feat = self.bn2(global_feat) 72 | global_feat = F.interpolate(global_feat, size=(H, W), mode='bilinear', align_corners=False) 73 | 74 | global_act = self.global_act(x_g) 75 | global_act = self.bn3(global_act) 76 | sig_act = F.interpolate(self.act(global_act), size=(H, W), mode='bilinear', align_corners=False) 77 | 78 | out = local_feat * sig_act + global_feat 79 | return out 80 | 81 | 82 | class Attention(torch.nn.Module): 83 | def __init__(self, dim, img_shape, att_shape, key_dim=32, num_heads=8, attn_ratio=2, activation=torch.nn.Hardswish): 84 | super().__init__() 85 | self.num_heads = num_heads 86 | self.scale = key_dim ** -0.5 87 | self.key_dim = key_dim 88 | self.img_shape = img_shape 89 | self.nh_kd = nh_kd = key_dim * num_heads 90 | self.d = int(attn_ratio * key_dim) 91 | self.dh = int(attn_ratio * key_dim) * num_heads 92 | self.attn_ratio = attn_ratio 93 | h = self.dh + nh_kd * 2 94 | self.qkv = Linear_BN(dim, h) 95 | 96 | self.parallel_conv = nn.Sequential( 97 | nn.Hardswish(inplace=False), 98 | nn.Conv2d(self.dh, self.dh, 
kernel_size=3, padding=1, groups=self.dh), 99 | ) 100 | self.to_out = nn.Linear(self.dh, dim) 101 | self.proj = nn.Linear(att_shape, img_shape) 102 | 103 | def forward(self, x): # x (B,N,C) 104 | B, N, C = x.shape 105 | qkv = self.qkv(x) 106 | q, k, v = qkv.view(B, N, self.num_heads, - 107 | 1).split([self.key_dim, self.key_dim, self.d], dim=3) 108 | q = q.permute(0, 2, 1, 3) 109 | k = k.permute(0, 2, 1, 3) 110 | v = v.permute(0, 2, 1, 3) 111 | 112 | v0 = v[:, :, :self.img_shape, :] 113 | 114 | v0 = v0.reshape(B, self.dh, int(self.img_shape ** 0.5), -1) 115 | v_conv = self.parallel_conv(v0).flatten(2) 116 | 117 | attn = ( 118 | (q @ k.transpose(-2, -1)) * self.scale 119 | ) 120 | attn = attn.softmax(dim=-1) 121 | x = (attn @ v).transpose(1, 2).reshape(B, -1, N) 122 | x = self.proj(x) + v_conv 123 | x = self.to_out(x.permute(0, 2, 1)) # + v_conv 124 | return x 125 | 126 | 127 | class CrossScaleAttention(nn.Module): 128 | def __init__(self, dim, img_shape=225, att_shape=314): 129 | super().__init__() 130 | self.bn1 = nn.BatchNorm2d(dim) 131 | 132 | self.DWConv1 = nn.Sequential( 133 | nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1, groups=dim), 134 | nn.BatchNorm2d(dim), 135 | ) 136 | self.DWConv2 = nn.Sequential( 137 | nn.Conv2d(dim, dim, kernel_size=5, stride=3, padding=2, groups=dim), 138 | nn.BatchNorm2d(dim), 139 | ) 140 | self.attention = Attention(dim, img_shape, att_shape) 141 | self.bn4 = nn.BatchNorm2d(dim) 142 | self.activate = nn.Hardswish() 143 | self.conv = nn.Conv2d(dim, dim, 1) 144 | 145 | 146 | def forward(self, x): 147 | x0 = self.bn1(x) 148 | x1 = self.DWConv1(x0) 149 | x2 = self.DWConv2(x0) 150 | # [B, C, H, W] -> [B, C, H*W] 151 | x0, x1, x2 = x0.view(x0.shape[0], x0.shape[1], -1), x1.view(x1.shape[0], x1.shape[1], -1), x2.view(x2.shape[0], x2.shape[1], -1) 152 | attn = torch.cat((x0, x1, x2), dim=2).permute(0, 2, 1) 153 | attn = self.attention(attn) 154 | attn = attn.permute(0, 2, 1).contiguous().view(x0.shape[0], x0.shape[1], 15, 15) 155 | x = self.conv(self.activate(self.bn4(attn))) 156 | return x 157 | 158 | 159 | class FeedForward(nn.Module): 160 | def __init__(self, dim, hidden_dim): 161 | super().__init__() 162 | self.bn1 = nn.BatchNorm2d(dim) 163 | self.conv1 = nn.Conv2d(dim, hidden_dim, 1) 164 | self.bn2 = nn.BatchNorm2d(hidden_dim) 165 | self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=dim) 166 | self.relu = nn.ReLU6() 167 | self.conv3 = nn.Conv2d(hidden_dim, dim, 1) 168 | 169 | def forward(self, x): 170 | out = self.conv3(self.relu(self.conv2(self.bn2(self.conv1(self.bn1(x)))))) 171 | return out 172 | 173 | 174 | class IntraFeedForward(nn.Module): 175 | def __init__(self, channels, mlp_ratio=2): 176 | super().__init__() 177 | self.channels = [channels[i]//4 for i in range(len(channels))] 178 | 179 | self.ff1 = Residual(FeedForward(self.channels[0], mlp_ratio*self.channels[0])) 180 | self.ff2 = Residual(FeedForward(self.channels[1], mlp_ratio*self.channels[1])) 181 | self.ff3 = Residual(FeedForward(self.channels[2], mlp_ratio*self.channels[2])) 182 | self.ff4 = Residual(FeedForward(self.channels[3], mlp_ratio*self.channels[3])) 183 | 184 | def forward(self, x): 185 | x1, x2, x3, x4 = x.split(self.channels, dim=1) 186 | x1 = self.ff1(x1) 187 | x2 = self.ff2(x2) 188 | x3 = self.ff3(x3) 189 | x4 = self.ff4(x4) 190 | return torch.cat([x1, x2, x3, x4], dim=1) 191 | 192 | 193 | class CIMBlock(nn.Module): 194 | def __init__(self, dim, channels, mlp_ratio=2): 195 | super().__init__() 196 | self.csa1 = Residual(CrossScaleAttention(dim)) 197 | 
self.intra_ff = Residual(IntraFeedForward(channels, mlp_ratio)) 198 | self.csa2 = Residual(CrossScaleAttention(dim)) 199 | self.ff = Residual(FeedForward(dim, dim*mlp_ratio)) 200 | 201 | def forward(self, x): 202 | x = self.csa1(x) 203 | x = self.intra_ff(x) 204 | x = self.csa2(x) 205 | x = self.ff(x) 206 | return x 207 | 208 | 209 | class PyramidPoolAgg(nn.Module): 210 | def __init__(self, stride): 211 | super().__init__() 212 | self.stride = stride 213 | 214 | def forward(self, inputs): 215 | B, C, H, W = inputs[-1].shape 216 | H = (H - 1) // self.stride + 1 217 | W = (W - 1) // self.stride + 1 218 | return torch.cat([nn.functional.adaptive_avg_pool2d(inp, (H, W)) for inp in inputs], dim=1) 219 | 220 | 221 | class CIM(nn.Module): 222 | def __init__(self, dim, num_layers=1, channels=[128, 256, 512, 1024], downsample=1): 223 | super().__init__() 224 | self.hidden_dim = dim // 4 225 | self.channels = channels 226 | self.stride = downsample 227 | 228 | self.down_channel = nn.Conv2d(dim, self.hidden_dim, 1) 229 | self.up_channel = nn.Conv2d(self.hidden_dim, dim, 1) 230 | 231 | # downsample to h/32, w/32 232 | self.pool = PyramidPoolAgg(stride=self.stride) 233 | self.block = nn.ModuleList([ 234 | CIMBlock(self.hidden_dim, channels) 235 | for _ in range(num_layers) 236 | ]) 237 | self.bn = nn.BatchNorm2d(self.hidden_dim) 238 | self.fusion = nn.ModuleList([ 239 | ScaleAwareGate(channels[i], channels[i]) 240 | for i in range(len(channels)) 241 | ]) 242 | 243 | def forward(self, input): # [B, C, H, W] 244 | out = self.pool(input) 245 | out = self.down_channel(out) 246 | for layer in self.block: 247 | out = layer(out) 248 | out = self.bn(out) 249 | out = self.up_channel(out) 250 | xx = out.split(self.channels, dim=1) 251 | results = [] 252 | for i in range(len(self.channels)): 253 | CIM_before = input[i] 254 | CIM_after = xx[i] 255 | out_ = self.fusion[i](CIM_before, CIM_after) 256 | results.append(out_) 257 | return results 258 | 259 | 260 | 261 | if __name__ == '__main__': 262 | model = CIM(1920) 263 | x1 = torch.randn(2, 128, 120, 120) 264 | x2 = torch.randn(2, 256, 60, 60) 265 | x3 = torch.randn(2, 512, 30, 30) 266 | x4 = torch.randn(2, 1024, 15, 15) 267 | x = tuple([x1, x2, x3, x4]) 268 | y = model(x) 269 | 270 | 271 | -------------------------------------------------------------------------------- /lib/mask_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from collections import OrderedDict 5 | from arc import AdaptiveRotatedConv2d, RountingFunction 6 | 7 | 8 | class SimpleDecoding(nn.Module): 9 | def __init__(self, c4_dims, factor=2): 10 | super(SimpleDecoding, self).__init__() 11 | 12 | hidden_size = c4_dims//factor 13 | c4_size = c4_dims 14 | c3_size = c4_dims//(factor**1) 15 | c2_size = c4_dims//(factor**2) 16 | c1_size = c4_dims//(factor**3) 17 | 18 | self.conv1_4 = nn.Conv2d(c4_size+c3_size, hidden_size, 3, padding=1, bias=False) 19 | routing_function1 = RountingFunction(in_channels=hidden_size, kernel_number=1) 20 | self.conv2_4 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 21 | kernel_size=3, padding=1, rounting_func=routing_function1, bias=False, kernel_number=1) 22 | 23 | self.bn1_4 = nn.BatchNorm2d(hidden_size) 24 | self.relu1_4 = nn.ReLU() 25 | self.bn2_4 = nn.BatchNorm2d(hidden_size) 26 | self.relu2_4 = nn.ReLU() 27 | 28 | self.conv1_3 = nn.Conv2d(hidden_size + c2_size, hidden_size, 3, padding=1, bias=False) 29 | 
routing_function2 = RountingFunction(in_channels=hidden_size, kernel_number=1) 30 | self.conv2_3 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 31 | kernel_size=3, padding=1, rounting_func=routing_function2, bias=False, kernel_number=1) 32 | self.bn1_3 = nn.BatchNorm2d(hidden_size) 33 | self.relu1_3 = nn.ReLU() 34 | self.bn2_3 = nn.BatchNorm2d(hidden_size) 35 | self.relu2_3 = nn.ReLU() 36 | 37 | self.conv1_2 = nn.Conv2d(hidden_size + c1_size, hidden_size, 3, padding=1, bias=False) 38 | routing_function3 = RountingFunction(in_channels=hidden_size, kernel_number=1) 39 | self.conv2_2 = AdaptiveRotatedConv2d(in_channels=hidden_size, out_channels=hidden_size, 40 | kernel_size=3, padding=1, rounting_func=routing_function3, bias=False, kernel_number=1) 41 | self.bn1_2 = nn.BatchNorm2d(hidden_size) 42 | self.relu1_2 = nn.ReLU() 43 | self.bn2_2 = nn.BatchNorm2d(hidden_size) 44 | self.relu2_2 = nn.ReLU() 45 | 46 | self.conv1_1 = nn.Conv2d(hidden_size, 2, 1) 47 | 48 | def forward(self, x_c4, x_c3, x_c2, x_c1): 49 | # fuse Y4 and Y3 50 | if x_c4.size(-2) < x_c3.size(-2) or x_c4.size(-1) < x_c3.size(-1): 51 | x_c4 = F.interpolate(input=x_c4, scale_factor=2, mode='bilinear', align_corners=True) 52 | x = torch.cat([x_c4, x_c3], dim=1) 53 | x = self.conv1_4(x) 54 | x = self.bn1_4(x) 55 | x = self.relu1_4(x) 56 | x = self.conv2_4(x) 57 | x = self.bn2_4(x) 58 | x = self.relu2_4(x) 59 | 60 | # fuse top-down features and Y2 features 61 | if x.size(-2) < x_c2.size(-2) or x.size(-1) < x_c2.size(-1): 62 | x = F.interpolate(input=x, scale_factor=2, mode='bilinear', align_corners=True) 63 | x = torch.cat([x, x_c2], dim=1) 64 | x = self.conv1_3(x) 65 | x = self.bn1_3(x) 66 | x = self.relu1_3(x) 67 | x = self.conv2_3(x) 68 | x = self.bn2_3(x) 69 | x = self.relu2_3(x) 70 | 71 | # fuse top-down features and Y1 features 72 | if x.size(-2) < x_c1.size(-2) or x.size(-1) < x_c1.size(-1): 73 | x = F.interpolate(input=x, scale_factor=2, mode='bilinear', align_corners=True) 74 | x = torch.cat([x, x_c1], dim=1) 75 | x = self.conv1_2(x) 76 | x = self.bn1_2(x) 77 | x = self.relu1_2(x) 78 | x = self.conv2_2(x) 79 | x = self.bn2_2(x) 80 | x = self.relu2_2(x) 81 | 82 | return self.conv1_1(x) 83 | -------------------------------------------------------------------------------- /lib/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ['load_checkpoint'] 6 | -------------------------------------------------------------------------------- /lib/sa/functional.py: -------------------------------------------------------------------------------- 1 | from . 
import functions 2 | 3 | 4 | def aggregation(input, weight, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 5 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) and pad_mode in [0, 1] 6 | if input.is_cuda: 7 | if pad_mode == 0: 8 | out = functions.aggregation_zeropad(input, weight, kernel_size, stride, padding, dilation) 9 | elif pad_mode == 1: 10 | out = functions.aggregation_refpad(input, weight, kernel_size, stride, padding, dilation) 11 | else: 12 | raise NotImplementedError 13 | return out 14 | 15 | 16 | def subtraction(input, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 17 | assert input.dim() == 4 and pad_mode in [0, 1] 18 | if input.is_cuda: 19 | if pad_mode == 0: 20 | out = functions.subtraction_zeropad(input, kernel_size, stride, padding, dilation) 21 | elif pad_mode == 1: 22 | out = functions.subtraction_refpad(input, kernel_size, stride, padding, dilation) 23 | else: 24 | raise NotImplementedError 25 | return out 26 | 27 | 28 | def subtraction2(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1, pad_mode=1): 29 | assert input1.dim() == 4 and input2.dim() == 4 and pad_mode in [0, 1] 30 | if input1.is_cuda: 31 | if pad_mode == 0: 32 | out = functions.subtraction2_zeropad(input1, input2, kernel_size, stride, padding, dilation) 33 | elif pad_mode == 1: 34 | out = functions.subtraction2_refpad(input1, input2, kernel_size, stride, padding, dilation) 35 | else: 36 | raise NotImplementedError 37 | return out 38 | -------------------------------------------------------------------------------- /lib/sa/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation_zeropad import * 2 | from .aggregation_refpad import * 3 | from .subtraction_zeropad import * 4 | from .subtraction_refpad import * 5 | from .subtraction2_zeropad import * 6 | from .subtraction2_refpad import * 7 | from .utils import * 8 | -------------------------------------------------------------------------------- /lib/sa/functions/aggregation_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _aggregation_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void aggregation_refpad_forward_kernel( 25 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | ${Dtype} value = 0; 32 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 33 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 34 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 35 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 36 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + 
(kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 37 | int offset_bottom; 38 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 39 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 40 | } 41 | else { 42 | if (h_in < 0) h_in = -h_in; 43 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 44 | if (w_in < 0) w_in = -w_in; 45 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 46 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 47 | } 48 | value += weight_data[offset_weight] * bottom_data[offset_bottom]; 49 | } 50 | } 51 | top_data[index] = value; 52 | } 53 | } 54 | ''' 55 | 56 | 57 | _aggregation_refpad_input_backward_kernel = kernel_loop + ''' 58 | extern "C" 59 | __global__ void aggregation_refpad_input_backward_kernel( 60 | const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* bottom_diff) { 61 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 62 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 63 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 64 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 65 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 66 | ${Dtype} value = 0; 67 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 68 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 69 | const int h_out_s = h - kh * ${dilation_h}; 70 | const int w_out_s = w - kw * ${dilation_w}; 71 | if ((h_out_s % ${stride_h} == 0) && (w_out_s % ${stride_w} == 0)) { 72 | const int h_out = h_out_s / ${stride_h}; 73 | const int w_out = w_out_s / ${stride_w}; 74 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 75 | const int offset_top = ((n * ${input_channels} + c) * ${top_height} + h_out) * ${top_width} + w_out; 76 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 77 | value += weight_data[offset_weight] * top_diff[offset_top]; 78 | } 79 | } 80 | } 81 | } 82 | bottom_diff[index] = value; 83 | } 84 | } 85 | ''' 86 | 87 | 88 | _aggregation_refpad_weight_backward_kernel = kernel_loop + ''' 89 | extern "C" 90 | __global__ void aggregation_refpad_weight_backward_kernel( 91 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* weight_diff) { 92 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 93 | const int n = index / ${weight_channels} / ${top_height} / ${top_width}; 94 | const int c = (index / ${top_height} / ${top_width}) % ${weight_channels}; 95 | const int h = (index / ${top_width}) % ${top_height}; 96 | const int w = index % ${top_width}; 97 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 98 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 99 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 100 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 101 | const int offset_weight = ((n * ${weight_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 102 | ${Dtype} value = 0; 103 | for (int cc = c; cc < ${input_channels}; cc += ${weight_channels}) { 104 | const int offset_top = ((n * ${input_channels} + cc) * ${top_height} + h) * 
${top_width} + w; 105 | int offset_bottom; 106 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 107 | offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 108 | } 109 | else { 110 | if (h_in < 0) h_in = -h_in; 111 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 112 | if (w_in < 0) w_in = -w_in; 113 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 114 | offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 115 | } 116 | value += bottom_data[offset_bottom] * top_diff[offset_top]; 117 | } 118 | weight_diff[offset_weight] = value; 119 | } 120 | } 121 | } 122 | } 123 | ''' 124 | 125 | 126 | class AggregationRefpad(Function): 127 | @staticmethod 128 | def forward(ctx, input, weight, kernel_size, stride, padding, dilation): 129 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 130 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 131 | assert input.dim() == 4 and input.is_cuda and weight.is_cuda 132 | batch_size, input_channels, input_height, input_width = input.size() 133 | _, weight_channels, weight_height, weight_width = weight.size() 134 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 135 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 136 | assert output_height * output_width == weight_width 137 | output = input.new(batch_size, input_channels, output_height, output_width) 138 | n = output.numel() 139 | with torch.cuda.device_of(input): 140 | f = load_kernel('aggregation_refpad_forward_kernel', _aggregation_refpad_forward_kernel, Dtype=Dtype(input), nthreads=n, 141 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 142 | bottom_height=input_height, bottom_width=input_width, 143 | top_height=output_height, top_width=output_width, 144 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 145 | stride_h=stride[0], stride_w=stride[1], 146 | dilation_h=dilation[0], dilation_w=dilation[1], 147 | pad_h=padding[0], pad_w=padding[1]) 148 | f(block=(CUDA_NUM_THREADS, 1, 1), 149 | grid=(GET_BLOCKS(n), 1, 1), 150 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 151 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 152 | ctx.save_for_backward(input, weight) 153 | return output 154 | 155 | @staticmethod 156 | def backward(ctx, grad_output): 157 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 158 | input, weight = ctx.saved_tensors 159 | assert grad_output.is_cuda 160 | if not grad_output.is_contiguous(): 161 | grad_output = grad_output.contiguous() 162 | batch_size, input_channels, input_height, input_width = input.size() 163 | _, weight_channels, weight_height, weight_width = weight.size() 164 | output_height, output_width = grad_output.size()[2:] 165 | grad_input, grad_weight = None, None 166 | opt = dict(Dtype=Dtype(grad_output), 167 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 168 | bottom_height=input_height, bottom_width=input_width, 169 | top_height=output_height, top_width=output_width, 170 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 171 | stride_h=stride[0], stride_w=stride[1], 172 | dilation_h=dilation[0], dilation_w=dilation[1], 173 
| pad_h=padding[0], pad_w=padding[1]) 174 | with torch.cuda.device_of(input): 175 | if ctx.needs_input_grad[0]: 176 | grad_input = input.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 177 | n = grad_input.numel() 178 | opt['nthreads'] = n 179 | f = load_kernel('aggregation_refpad_input_backward_kernel', _aggregation_refpad_input_backward_kernel, **opt) 180 | f(block=(CUDA_NUM_THREADS, 1, 1), 181 | grid=(GET_BLOCKS(n), 1, 1), 182 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 183 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 184 | grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input[:, :, :padding[0], :], dims=[2]) 185 | grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input[:, :, input_height + padding[0]:, :], dims=[2]) 186 | grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input[:, :, :, :padding[1]], dims=[3]) 187 | grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input[:, :, :, input_width + padding[1]:], dims=[3]) 188 | grad_input = grad_input[:, :, padding[0]:padding[0]+input_height, padding[1]:padding[1]+input_width] 189 | 190 | if ctx.needs_input_grad[1]: 191 | grad_weight = weight.new(weight.size()) 192 | n = grad_weight.numel() // weight.shape[2] 193 | opt['nthreads'] = n 194 | f = load_kernel('aggregation_refpad_weight_backward_kernel', _aggregation_refpad_weight_backward_kernel, **opt) 195 | f(block=(CUDA_NUM_THREADS, 1, 1), 196 | grid=(GET_BLOCKS(n), 1, 1), 197 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 198 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 199 | return grad_input, grad_weight, None, None, None, None 200 | 201 | 202 | def aggregation_refpad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1): 203 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) 204 | if input.is_cuda: 205 | out = AggregationRefpad.apply(input, weight, kernel_size, stride, padding, dilation) 206 | else: 207 | raise NotImplementedError 208 | return out 209 | 210 | 211 | def test_aggregation_refpad(): 212 | import os 213 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 214 | kernel_size, stride, dilation = 5, 4, 2 215 | padding = (dilation * (kernel_size - 1) + 1) // 2 216 | n, c_x, c_w, in_height, in_width = 2, 8, 4, 5, 5 217 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 218 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 219 | x = torch.randn(n, c_x, in_height, in_width, requires_grad=True).double().cuda() 220 | w = torch.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True).double().cuda() 221 | 222 | y1 = aggregation_refpad(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 223 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 224 | pad = torch.nn.ReflectionPad2d(padding) 225 | x2 = unfold_j(pad(x)).view(n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width) 226 | y2 = (w.unsqueeze(1) * x2).sum(-2).view(n, c_x, out_height, out_width) 227 | assert (y1 - y2).abs().max() < 1e-9 228 | 229 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 230 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 231 | assert (gx1 - gx2).abs().max() < 1e-9 232 | 233 | gw1 = 
torch.autograd.grad(y1.mean(), w, retain_graph=True)[0] 234 | gw2 = torch.autograd.grad(y2.mean(), w, retain_graph=True)[0] 235 | assert (gw1 - gw2).abs().max() < 1e-9 236 | 237 | from functools import partial 238 | assert torch.autograd.gradcheck(partial(aggregation_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x, w)) 239 | print('test case passed') 240 | 241 | 242 | if __name__ == '__main__': 243 | test_aggregation_refpad() 244 | -------------------------------------------------------------------------------- /lib/sa/functions/aggregation_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _aggregation_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void aggregation_zeropad_forward_kernel( 25 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | ${Dtype} value = 0; 32 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 33 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 34 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 35 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 36 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 37 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 38 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | value += weight_data[offset_weight] * bottom_data[offset_bottom]; 40 | } 41 | } 42 | } 43 | top_data[index] = value; 44 | } 45 | } 46 | ''' 47 | 48 | 49 | _aggregation_zeropad_input_backward_kernel = kernel_loop + ''' 50 | extern "C" 51 | __global__ void aggregation_zeropad_input_backward_kernel( 52 | const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* bottom_diff) { 53 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 54 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 55 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 56 | const int h = (index / ${bottom_width}) % ${bottom_height}; 57 | const int w = index % ${bottom_width}; 58 | ${Dtype} value = 0; 59 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 60 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 61 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 62 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 63 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 64 | const int h_out = h_out_s / ${stride_h}; 65 | const int w_out = w_out_s / ${stride_w}; 66 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) 
&& (w_out < ${top_width})) { 67 | const int offset_top = ((n * ${input_channels} + c) * ${top_height} + h_out) * ${top_width} + w_out; 68 | const int offset_weight = ((n * ${weight_channels} + c % ${weight_channels}) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 69 | value += weight_data[offset_weight] * top_diff[offset_top]; 70 | } 71 | } 72 | } 73 | } 74 | bottom_diff[index] = value; 75 | } 76 | } 77 | ''' 78 | 79 | 80 | _aggregation_zeropad_weight_backward_kernel = kernel_loop + ''' 81 | extern "C" 82 | __global__ void aggregation_zeropad_weight_backward_kernel( 83 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* weight_diff) { 84 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 85 | const int n = index / ${weight_channels} / ${top_height} / ${top_width}; 86 | const int c = (index / ${top_height} / ${top_width}) % ${weight_channels}; 87 | const int h = (index / ${top_width}) % ${top_height}; 88 | const int w = index % ${top_width}; 89 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 90 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 91 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 92 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 93 | const int offset_weight = ((n * ${weight_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 94 | ${Dtype} value = 0; 95 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 96 | for (int cc = c; cc < ${input_channels}; cc += ${weight_channels}) { 97 | const int offset_bottom = ((n * ${input_channels} + cc) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 98 | const int offset_top = ((n * ${input_channels} + cc) * ${top_height} + h) * ${top_width} + w; 99 | value += bottom_data[offset_bottom] * top_diff[offset_top]; 100 | } 101 | } 102 | weight_diff[offset_weight] = value; 103 | } 104 | } 105 | } 106 | } 107 | ''' 108 | 109 | 110 | class AggregationZeropad(Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, kernel_size, stride, padding, dilation): 113 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 114 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 115 | assert input.dim() == 4 and input.is_cuda and weight.is_cuda 116 | batch_size, input_channels, input_height, input_width = input.size() 117 | _, weight_channels, weight_height, weight_width = weight.size() 118 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 119 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 120 | assert output_height * output_width == weight_width 121 | output = input.new(batch_size, input_channels, output_height, output_width) 122 | n = output.numel() 123 | with torch.cuda.device_of(input): 124 | f = load_kernel('aggregation_zeropad_forward_kernel', _aggregation_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n, 125 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 126 | bottom_height=input_height, bottom_width=input_width, 127 | top_height=output_height, top_width=output_width, 128 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 129 | stride_h=stride[0], stride_w=stride[1], 130 | dilation_h=dilation[0], dilation_w=dilation[1], 131 | 
pad_h=padding[0], pad_w=padding[1]) 132 | f(block=(CUDA_NUM_THREADS, 1, 1), 133 | grid=(GET_BLOCKS(n), 1, 1), 134 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 135 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 136 | ctx.save_for_backward(input, weight) 137 | return output 138 | 139 | @staticmethod 140 | def backward(ctx, grad_output): 141 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 142 | input, weight = ctx.saved_tensors 143 | assert grad_output.is_cuda 144 | if not grad_output.is_contiguous(): 145 | grad_output = grad_output.contiguous() 146 | batch_size, input_channels, input_height, input_width = input.size() 147 | _, weight_channels, weight_height, weight_width = weight.size() 148 | output_height, output_width = grad_output.size()[2:] 149 | grad_input, grad_weight = None, None 150 | opt = dict(Dtype=Dtype(grad_output), 151 | num=batch_size, input_channels=input_channels, weight_channels=weight_channels, 152 | bottom_height=input_height, bottom_width=input_width, 153 | top_height=output_height, top_width=output_width, 154 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 155 | stride_h=stride[0], stride_w=stride[1], 156 | dilation_h=dilation[0], dilation_w=dilation[1], 157 | pad_h=padding[0], pad_w=padding[1]) 158 | with torch.cuda.device_of(input): 159 | if ctx.needs_input_grad[0]: 160 | grad_input = input.new(input.size()) 161 | n = grad_input.numel() 162 | opt['nthreads'] = n 163 | f = load_kernel('aggregation_zeropad_input_backward_kernel', _aggregation_zeropad_input_backward_kernel, **opt) 164 | f(block=(CUDA_NUM_THREADS, 1, 1), 165 | grid=(GET_BLOCKS(n), 1, 1), 166 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 167 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 168 | if ctx.needs_input_grad[1]: 169 | grad_weight = weight.new(weight.size()) 170 | n = grad_weight.numel() // weight.shape[2] 171 | opt['nthreads'] = n 172 | f = load_kernel('aggregation_zeropad_weight_backward_kernel', _aggregation_zeropad_weight_backward_kernel, **opt) 173 | f(block=(CUDA_NUM_THREADS, 1, 1), 174 | grid=(GET_BLOCKS(n), 1, 1), 175 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 176 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 177 | return grad_input, grad_weight, None, None, None, None 178 | 179 | 180 | def aggregation_zeropad(input, weight, kernel_size=3, stride=1, padding=0, dilation=1): 181 | assert input.shape[0] == weight.shape[0] and (input.shape[1] % weight.shape[1] == 0) 182 | if input.is_cuda: 183 | out = AggregationZeropad.apply(input, weight, kernel_size, stride, padding, dilation) 184 | else: 185 | raise NotImplementedError 186 | return out 187 | 188 | 189 | def test_aggregation_zeropad(): 190 | import os 191 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 192 | kernel_size, stride, dilation = 5, 4, 2 193 | padding = (dilation * (kernel_size - 1) + 1) // 2 194 | n, c_x, c_w, in_height, in_width = 2, 8, 4, 9, 9 195 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 196 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 197 | x = torch.randn(n, c_x, in_height, in_width, requires_grad=True).double().cuda() 198 | w = torch.randn(n, c_w, pow(kernel_size, 2), out_height * out_width, requires_grad=True).double().cuda() 199 | 200 | y1 = aggregation_zeropad(x, w, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 201 | 
unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 202 | x2 = unfold_j(x).view(n, c_x // c_w, c_w, pow(kernel_size, 2), out_height * out_width) 203 | y2 = (w.unsqueeze(1) * x2).sum(-2).view(n, c_x, out_height, out_width) 204 | assert (y1 - y2).abs().max() < 1e-9 205 | 206 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 207 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 208 | assert (gx1 - gx2).abs().max() < 1e-9 209 | 210 | gw1 = torch.autograd.grad(y1.mean(), w, retain_graph=True)[0] 211 | gw2 = torch.autograd.grad(y2.mean(), w, retain_graph=True)[0] 212 | assert (gw1 - gw2).abs().max() < 1e-9 213 | 214 | from functools import partial 215 | assert torch.autograd.gradcheck(partial(aggregation_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x, w)) 216 | print('test case passed') 217 | 218 | 219 | if __name__ == '__main__': 220 | test_aggregation_zeropad() 221 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction2_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction2_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction2_refpad_forward_kernel( 25 | const ${Dtype}* bottom1_data, const ${Dtype}* bottom2_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | int offset_bottom; 40 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 41 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 42 | } 43 | else { 44 | if (h_in < 0) h_in = -h_in; 45 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 46 | if (w_in < 0) w_in = -w_in; 47 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 48 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 49 | } 50 | top_data[offset_top] = 
bottom1_data[offset_center] - bottom2_data[offset_bottom]; 51 | } 52 | } 53 | } 54 | } 55 | ''' 56 | 57 | 58 | _subtraction2_refpad_input1_backward_kernel = kernel_loop + ''' 59 | extern "C" 60 | __global__ void subtraction2_refpad_input1_backward_kernel( 61 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 62 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 63 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 64 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 65 | const int h = (index / ${bottom_width}) % ${bottom_height}; 66 | const int w = index % ${bottom_width}; 67 | ${Dtype} value = 0; 68 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 69 | const int h_out = h / ${stride_h}; 70 | const int w_out = w / ${stride_w}; 71 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 72 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 73 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 74 | value += top_diff[offset_top]; 75 | } 76 | } 77 | } 78 | bottom_diff[index] = value; 79 | } 80 | } 81 | ''' 82 | 83 | 84 | _subtraction2_refpad_input2_backward_kernel = kernel_loop + ''' 85 | extern "C" 86 | __global__ void subtraction2_refpad_input2_backward_kernel( 87 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 88 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 89 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 90 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 91 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 92 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 93 | ${Dtype} value = 0; 94 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 95 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 96 | const int h_out_s = h - kh * ${dilation_h}; 97 | const int w_out_s = w - kw * ${dilation_w}; 98 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 99 | const int h_out = h_out_s / ${stride_h}; 100 | const int w_out = w_out_s / ${stride_w}; 101 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 102 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 103 | value += -top_diff[offset_top]; 104 | } 105 | } 106 | } 107 | } 108 | bottom_diff[index] = value; 109 | } 110 | } 111 | ''' 112 | 113 | 114 | class Subtraction2Refpad(Function): 115 | @staticmethod 116 | def forward(ctx, input1, input2, kernel_size, stride, padding, dilation): 117 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 118 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 119 | assert input1.dim() == 4 and input1.is_cuda 120 | batch_size, input_channels, input_height, input_width = input1.size() 121 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 122 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 123 | output = input1.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 124 | n = output.numel() // output.shape[2] 125 | with 
torch.cuda.device_of(input1): 126 | f = load_kernel('subtraction2_refpad_forward_kernel', _subtraction2_refpad_forward_kernel, Dtype=Dtype(input1), nthreads=n, 127 | num=batch_size, input_channels=input_channels, 128 | bottom_height=input_height, bottom_width=input_width, 129 | top_height=output_height, top_width=output_width, 130 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 131 | stride_h=stride[0], stride_w=stride[1], 132 | dilation_h=dilation[0], dilation_w=dilation[1], 133 | pad_h=padding[0], pad_w=padding[1]) 134 | f(block=(CUDA_NUM_THREADS, 1, 1), 135 | grid=(GET_BLOCKS(n), 1, 1), 136 | args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()], 137 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 138 | ctx.save_for_backward(input1, input2) 139 | return output 140 | 141 | @staticmethod 142 | def backward(ctx, grad_output): 143 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 144 | input1, input2 = ctx.saved_tensors 145 | assert grad_output.is_cuda 146 | if not grad_output.is_contiguous(): 147 | grad_output = grad_output.contiguous() 148 | batch_size, input_channels, input_height, input_width = input1.size() 149 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 150 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 151 | grad_input1, grad_input2 = None, None 152 | opt = dict(Dtype=Dtype(grad_output), 153 | num=batch_size, input_channels=input_channels, 154 | bottom_height=input_height, bottom_width=input_width, 155 | top_height=output_height, top_width=output_width, 156 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 157 | stride_h=stride[0], stride_w=stride[1], 158 | dilation_h=dilation[0], dilation_w=dilation[1], 159 | pad_h=padding[0], pad_w=padding[1]) 160 | with torch.cuda.device_of(input1): 161 | if ctx.needs_input_grad[0]: 162 | grad_input1 = input1.new(input1.size()) 163 | n = grad_input1.numel() 164 | opt['nthreads'] = n 165 | f = load_kernel('subtraction2_refpad_input1_backward_kernel', _subtraction2_refpad_input1_backward_kernel, **opt) 166 | f(block=(CUDA_NUM_THREADS, 1, 1), 167 | grid=(GET_BLOCKS(n), 1, 1), 168 | args=[grad_output.data_ptr(), grad_input1.data_ptr()], 169 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 170 | with torch.cuda.device_of(input2): 171 | if ctx.needs_input_grad[1]: 172 | grad_input2 = input2.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 173 | n = grad_input2.numel() 174 | opt['nthreads'] = n 175 | f = load_kernel('subtraction2_refpad_input2_backward_kernel', _subtraction2_refpad_input2_backward_kernel, **opt) 176 | f(block=(CUDA_NUM_THREADS, 1, 1), 177 | grid=(GET_BLOCKS(n), 1, 1), 178 | args=[grad_output.data_ptr(), grad_input2.data_ptr()], 179 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 180 | grad_input2[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input2[:, :, :padding[0], :], dims=[2]) 181 | grad_input2[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input2[:, :, input_height + padding[0]:, :], dims=[2]) 182 | grad_input2[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input2[:, :, :, :padding[1]], dims=[3]) 183 | grad_input2[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input2[:, :, :, input_width + padding[1]:], dims=[3]) 184 | grad_input2 = grad_input2[:, :, 
padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] 185 | return grad_input1, grad_input2, None, None, None, None 186 | 187 | 188 | def subtraction2_refpad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): 189 | assert input1.dim() == 4 190 | if input1.is_cuda: 191 | out = Subtraction2Refpad.apply(input1, input2, kernel_size, stride, padding, dilation) 192 | else: 193 | raise NotImplementedError 194 | return out 195 | 196 | 197 | def test_subtraction2_refpad(): 198 | import os 199 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 200 | kernel_size, stride, dilation = 5, 4, 2 # 3, 1, 1 201 | padding = (dilation * (kernel_size - 1) + 1) // 2 202 | n, c, in_height, in_width = 2, 8, 9, 9 203 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 204 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 205 | x1 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 206 | x2 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 207 | 208 | y1 = subtraction2_refpad(x1, x2, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 209 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 210 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 211 | pad = torch.nn.ReflectionPad2d(padding) 212 | y2 = unfold_i(x1).view(n, c, 1, out_height * out_width) - unfold_j(pad(x2)).view(n, c, pow(kernel_size, 2), out_height * out_width) 213 | assert (y1 - y2).abs().max() < 1e-9 214 | 215 | gx11 = torch.autograd.grad(y1.mean(), x1, retain_graph=True)[0] 216 | gx12 = torch.autograd.grad(y1.mean(), x2, retain_graph=True)[0] 217 | gx21 = torch.autograd.grad(y2.mean(), x1, retain_graph=True)[0] 218 | gx22 = torch.autograd.grad(y2.mean(), x2, retain_graph=True)[0] 219 | assert (gx11 - gx21).abs().max() < 1e-9 220 | assert (gx12 - gx22).abs().max() < 1e-9 221 | 222 | from functools import partial 223 | assert torch.autograd.gradcheck(partial(subtraction2_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x1, x2)) 224 | print('test case passed') 225 | 226 | if __name__ == '__main__': 227 | test_subtraction2_refpad() 228 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction2_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction2_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction2_zeropad_forward_kernel( 25 | const ${Dtype}* bottom1_data, const ${Dtype}* bottom2_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const 
int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 40 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 41 | top_data[offset_top] = bottom1_data[offset_center] - bottom2_data[offset_bottom]; 42 | } 43 | else 44 | top_data[offset_top] = bottom1_data[offset_center]; 45 | } 46 | } 47 | } 48 | } 49 | ''' 50 | 51 | 52 | _subtraction2_zeropad_input1_backward_kernel = kernel_loop + ''' 53 | extern "C" 54 | __global__ void subtraction2_zeropad_input1_backward_kernel( 55 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 56 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 57 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 58 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 59 | const int h = (index / ${bottom_width}) % ${bottom_height}; 60 | const int w = index % ${bottom_width}; 61 | ${Dtype} value = 0; 62 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 63 | const int h_out = h / ${stride_h}; 64 | const int w_out = w / ${stride_w}; 65 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 66 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 67 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 68 | value += top_diff[offset_top]; 69 | } 70 | } 71 | } 72 | bottom_diff[index] = value; 73 | } 74 | } 75 | ''' 76 | 77 | 78 | _subtraction2_zeropad_input2_backward_kernel = kernel_loop + ''' 79 | extern "C" 80 | __global__ void subtraction2_zeropad_input2_backward_kernel( 81 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 82 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 83 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 84 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 85 | const int h = (index / ${bottom_width}) % ${bottom_height}; 86 | const int w = index % ${bottom_width}; 87 | ${Dtype} value = 0; 88 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 89 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 90 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 91 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 92 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 93 | const int h_out = h_out_s / ${stride_h}; 94 | const int w_out = w_out_s / ${stride_w}; 95 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 96 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 97 | value += -top_diff[offset_top]; 98 | } 99 | } 100 | } 101 | } 102 | bottom_diff[index] = value; 103 
| } 104 | } 105 | ''' 106 | 107 | 108 | class Subtraction2Zeropad(Function): 109 | @staticmethod 110 | def forward(ctx, input1, input2, kernel_size, stride, padding, dilation): 111 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 112 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 113 | assert input1.dim() == 4 and input1.is_cuda 114 | batch_size, input_channels, input_height, input_width = input1.size() 115 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 116 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 117 | output = input1.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 118 | n = output.numel() // output.shape[2] 119 | with torch.cuda.device_of(input1): 120 | f = load_kernel('subtraction2_zeropad_forward_kernel', _subtraction2_zeropad_forward_kernel, Dtype=Dtype(input1), nthreads=n, 121 | num=batch_size, input_channels=input_channels, 122 | bottom_height=input_height, bottom_width=input_width, 123 | top_height=output_height, top_width=output_width, 124 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 125 | stride_h=stride[0], stride_w=stride[1], 126 | dilation_h=dilation[0], dilation_w=dilation[1], 127 | pad_h=padding[0], pad_w=padding[1]) 128 | f(block=(CUDA_NUM_THREADS, 1, 1), 129 | grid=(GET_BLOCKS(n), 1, 1), 130 | args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()], 131 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 132 | ctx.save_for_backward(input1, input2) 133 | return output 134 | 135 | @staticmethod 136 | def backward(ctx, grad_output): 137 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 138 | input1, input2 = ctx.saved_tensors 139 | assert grad_output.is_cuda 140 | if not grad_output.is_contiguous(): 141 | grad_output = grad_output.contiguous() 142 | batch_size, input_channels, input_height, input_width = input1.size() 143 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 144 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 145 | grad_input1, grad_input2 = None, None 146 | opt = dict(Dtype=Dtype(grad_output), 147 | num=batch_size, input_channels=input_channels, 148 | bottom_height=input_height, bottom_width=input_width, 149 | top_height=output_height, top_width=output_width, 150 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 151 | stride_h=stride[0], stride_w=stride[1], 152 | dilation_h=dilation[0], dilation_w=dilation[1], 153 | pad_h=padding[0], pad_w=padding[1]) 154 | with torch.cuda.device_of(input1): 155 | if ctx.needs_input_grad[0]: 156 | grad_input1 = input1.new(input1.size()) 157 | n = grad_input1.numel() 158 | opt['nthreads'] = n 159 | f = load_kernel('subtraction2_zeropad_input1_backward_kernel', _subtraction2_zeropad_input1_backward_kernel, **opt) 160 | f(block=(CUDA_NUM_THREADS, 1, 1), 161 | grid=(GET_BLOCKS(n), 1, 1), 162 | args=[grad_output.data_ptr(), grad_input1.data_ptr()], 163 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 164 | with torch.cuda.device_of(input2): 165 | if ctx.needs_input_grad[1]: 166 | grad_input2 = input2.new(input2.size()) 167 | n = grad_input2.numel() 168 | opt['nthreads'] = n 169 | f = 
load_kernel('subtraction2_zeropad_input2_backward_kernel', _subtraction2_zeropad_input2_backward_kernel, **opt) 170 | f(block=(CUDA_NUM_THREADS, 1, 1), 171 | grid=(GET_BLOCKS(n), 1, 1), 172 | args=[grad_output.data_ptr(), grad_input2.data_ptr()], 173 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 174 | return grad_input1, grad_input2, None, None, None, None 175 | 176 | 177 | def subtraction2_zeropad(input1, input2, kernel_size=3, stride=1, padding=0, dilation=1): 178 | assert input1.dim() == 4 179 | if input1.is_cuda: 180 | out = Subtraction2Zeropad.apply(input1, input2, kernel_size, stride, padding, dilation) 181 | else: 182 | raise NotImplementedError 183 | return out 184 | 185 | 186 | def test_subtraction2_zeropad(): 187 | import os 188 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 189 | kernel_size, stride, dilation = 5, 4, 2 190 | padding = (dilation * (kernel_size - 1) + 1) // 2 191 | n, c, in_height, in_width = 2, 8, 9, 9 192 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 193 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 194 | x1 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 195 | x2 = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 196 | 197 | y1 = subtraction2_zeropad(x1, x2, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 198 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 199 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 200 | y2 = unfold_i(x1).view(n, c, 1, out_height * out_width) - unfold_j(x2).view(n, c, pow(kernel_size, 2), out_height * out_width) 201 | # y2 = unfold_i(x[:, :, kernel_size//2:-(kernel_size//2), kernel_size//2:-(kernel_size//2)]).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 202 | assert (y1 - y2).abs().max() < 1e-9 203 | 204 | gx11 = torch.autograd.grad(y1.mean(), x1, retain_graph=True)[0] 205 | gx12 = torch.autograd.grad(y1.mean(), x2, retain_graph=True)[0] 206 | gx21 = torch.autograd.grad(y2.mean(), x1, retain_graph=True)[0] 207 | gx22 = torch.autograd.grad(y2.mean(), x2, retain_graph=True)[0] 208 | assert (gx11 - gx21).abs().max() < 1e-9 209 | assert (gx12 - gx22).abs().max() < 1e-9 210 | 211 | from functools import partial 212 | assert torch.autograd.gradcheck(partial(subtraction2_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), (x1, x2)) 213 | print('test case passed') 214 | 215 | 216 | if __name__ == '__main__': 217 | test_subtraction2_zeropad() 218 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction_refpad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction_refpad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction_refpad_forward_kernel( 25 
| const ${Dtype}* bottom_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | int offset_bottom; 40 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 41 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 42 | } 43 | else { 44 | if (h_in < 0) h_in = -h_in; 45 | if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in; 46 | if (w_in < 0) w_in = -w_in; 47 | if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in; 48 | offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 49 | } 50 | top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; 51 | } 52 | } 53 | } 54 | } 55 | ''' 56 | 57 | 58 | _subtraction_refpad_input_backward_kernel = kernel_loop + ''' 59 | extern "C" 60 | __global__ void subtraction_refpad_input_backward_kernel( 61 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 62 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 63 | const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w}); 64 | const int c = (index / (${bottom_height} + 2 * ${pad_h}) / (${bottom_width} + 2 * ${pad_w})) % ${input_channels}; 65 | const int h = (index / (${bottom_width} + 2 * ${pad_w})) % (${bottom_height} + 2 * ${pad_h}); 66 | const int w = index % (${bottom_width} + 2 * ${pad_w}); 67 | ${Dtype} value = 0; 68 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 69 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 70 | const int h_out_s = h - kh * ${dilation_h}; 71 | const int w_out_s = w - kw * ${dilation_w}; 72 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 73 | const int h_out = h_out_s / ${stride_h}; 74 | const int w_out = w_out_s / ${stride_w}; 75 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 76 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 77 | value += -top_diff[offset_top]; 78 | } 79 | } 80 | } 81 | } 82 | const int hh = h - ${pad_h}; 83 | const int ww = w - ${pad_w}; 84 | if ((hh >= 0) && (hh < ${bottom_height}) && (ww >= 0) && (ww < ${bottom_width})) { 85 | if (((hh % ${stride_h}) == 0) && ((ww % ${stride_w}) == 0)) { 86 | const int h_out = hh / ${stride_h}; 87 | const int w_out = ww / ${stride_w}; 88 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 89 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 90 | const 
int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 91 | value += top_diff[offset_top]; 92 | } 93 | } 94 | } 95 | } 96 | bottom_diff[index] = value; 97 | } 98 | } 99 | ''' 100 | 101 | 102 | class SubtractionRefpad(Function): 103 | @staticmethod 104 | def forward(ctx, input, kernel_size, stride, padding, dilation): 105 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 106 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 107 | assert input.dim() == 4 and input.is_cuda 108 | batch_size, input_channels, input_height, input_width = input.size() 109 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 110 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 111 | output = input.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 112 | n = output.numel() // output.shape[2] 113 | with torch.cuda.device_of(input): 114 | f = load_kernel('subtraction_refpad_forward_kernel', _subtraction_refpad_forward_kernel, Dtype=Dtype(input), nthreads=n, 115 | num=batch_size, input_channels=input_channels, 116 | bottom_height=input_height, bottom_width=input_width, 117 | top_height=output_height, top_width=output_width, 118 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 119 | stride_h=stride[0], stride_w=stride[1], 120 | dilation_h=dilation[0], dilation_w=dilation[1], 121 | pad_h=padding[0], pad_w=padding[1]) 122 | f(block=(CUDA_NUM_THREADS, 1, 1), 123 | grid=(GET_BLOCKS(n), 1, 1), 124 | args=[input.data_ptr(), output.data_ptr()], 125 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 126 | ctx.save_for_backward(input) 127 | return output 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 132 | input, = ctx.saved_tensors 133 | assert grad_output.is_cuda 134 | if not grad_output.is_contiguous(): 135 | grad_output = grad_output.contiguous() 136 | batch_size, input_channels, input_height, input_width = input.size() 137 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 138 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 139 | grad_input = None 140 | opt = dict(Dtype=Dtype(grad_output), 141 | num=batch_size, input_channels=input_channels, 142 | bottom_height=input_height, bottom_width=input_width, 143 | top_height=output_height, top_width=output_width, 144 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 145 | stride_h=stride[0], stride_w=stride[1], 146 | dilation_h=dilation[0], dilation_w=dilation[1], 147 | pad_h=padding[0], pad_w=padding[1]) 148 | with torch.cuda.device_of(input): 149 | if ctx.needs_input_grad[0]: 150 | grad_input = input.new(batch_size, input_channels, input_height + 2 * padding[0], input_width + 2 * padding[1]) 151 | n = grad_input.numel() 152 | opt['nthreads'] = n 153 | f = load_kernel('subtraction_refpad_input_backward_kernel', _subtraction_refpad_input_backward_kernel, **opt) 154 | f(block=(CUDA_NUM_THREADS, 1, 1), 155 | grid=(GET_BLOCKS(n), 1, 1), 156 | args=[grad_output.data_ptr(), grad_input.data_ptr()], 157 | 
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 158 | grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input[:, :, :padding[0], :], dims=[2]) 159 | grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input[:, :, input_height + padding[0]:, :], dims=[2]) 160 | grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input[:, :, :, :padding[1]], dims=[3]) 161 | grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input[:, :, :, input_width + padding[1]:], dims=[3]) 162 | grad_input = grad_input[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width] 163 | return grad_input, None, None, None, None 164 | 165 | 166 | def subtraction_refpad(input, kernel_size=3, stride=1, padding=0, dilation=1): 167 | assert input.dim() == 4 168 | if input.is_cuda: 169 | out = SubtractionRefpad.apply(input, kernel_size, stride, padding, dilation) 170 | else: 171 | raise NotImplementedError 172 | return out 173 | 174 | 175 | def test_subtraction_refpad(): 176 | import os 177 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 178 | kernel_size, stride, dilation = 5, 4, 2 179 | padding = (dilation * (kernel_size - 1) + 1) // 2 180 | n, c, in_height, in_width = 2, 8, 5, 5 181 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 182 | out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 183 | x = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 184 | 185 | y1 = subtraction_refpad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 186 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 187 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride) 188 | pad = torch.nn.ReflectionPad2d(padding) 189 | y2 = unfold_i(x).view(n, c, 1, out_height * out_width) - unfold_j(pad(x)).view(n, c, pow(kernel_size, 2), out_height * out_width) 190 | assert (y1 - y2).abs().max() < 1e-9 191 | 192 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 193 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 194 | assert (gx1 - gx2).abs().max() < 1e-9 195 | 196 | from functools import partial 197 | assert torch.autograd.gradcheck(partial(subtraction_refpad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), x) 198 | print('test case passed') 199 | 200 | 201 | if __name__ == '__main__': 202 | test_subtraction_refpad() 203 | -------------------------------------------------------------------------------- /lib/sa/functions/subtraction_zeropad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from lib.sa.functions.utils import Dtype, Stream, load_kernel 6 | 7 | 8 | CUDA_NUM_THREADS = 1024 9 | 10 | kernel_loop = ''' 11 | #define CUDA_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 13 | i < (n); \ 14 | i += blockDim.x * gridDim.x) 15 | ''' 16 | 17 | 18 | def GET_BLOCKS(N): 19 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 20 | 21 | 22 | _subtraction_zeropad_forward_kernel = kernel_loop + ''' 23 | extern "C" 24 | __global__ void subtraction_zeropad_forward_kernel( 25 | const ${Dtype}* bottom_data, ${Dtype}* top_data) { 26 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 27 | const 
int n = index / ${input_channels} / ${top_height} / ${top_width}; 28 | const int c = (index / ${top_height} / ${top_width}) % ${input_channels}; 29 | const int h = (index / ${top_width}) % ${top_height}; 30 | const int w = index % ${top_width}; 31 | const int h_in_center = -${pad_h} + h * ${stride_h} + (${kernel_h} - 1) / 2 * ${dilation_h}; 32 | const int w_in_center = -${pad_w} + w * ${stride_w} + (${kernel_w} - 1) / 2 * ${dilation_w}; 33 | const int offset_center = ((n * ${input_channels} + c) * ${bottom_height} + h_in_center) * ${bottom_width} + w_in_center; 34 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 35 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 36 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 37 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 38 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w; 39 | if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) { 40 | const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in; 41 | top_data[offset_top] = bottom_data[offset_center] - bottom_data[offset_bottom]; 42 | } 43 | else 44 | top_data[offset_top] = bottom_data[offset_center]; 45 | } 46 | } 47 | } 48 | } 49 | ''' 50 | 51 | 52 | _subtraction_zeropad_input_backward_kernel = kernel_loop + ''' 53 | extern "C" 54 | __global__ void subtraction_zeropad_input_backward_kernel( 55 | const ${Dtype}* const top_diff, ${Dtype}* bottom_diff) { 56 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 57 | const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width}; 58 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels}; 59 | const int h = (index / ${bottom_width}) % ${bottom_height}; 60 | const int w = index % ${bottom_width}; 61 | ${Dtype} value = 0; 62 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 63 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 64 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 65 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 66 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 67 | const int h_out = h_out_s / ${stride_h}; 68 | const int w_out = w_out_s / ${stride_w}; 69 | if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) { 70 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 71 | value += -top_diff[offset_top]; 72 | } 73 | } 74 | } 75 | } 76 | if (((h % ${stride_h}) == 0) && ((w % ${stride_w}) == 0)) { 77 | const int h_out = h / ${stride_h}; 78 | const int w_out = w / ${stride_w}; 79 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 80 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 81 | const int offset_top = ((n * ${input_channels} + c) * ${kernel_h} * ${kernel_w} + (kh * ${kernel_w} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out; 82 | value += top_diff[offset_top]; 83 | } 84 | } 85 | } 86 | bottom_diff[index] = value; 87 | } 88 | } 89 | ''' 90 | 91 | 92 | class SubtractionZeropad(Function): 93 | @staticmethod 94 | def forward(ctx, input, kernel_size, stride, padding, dilation): 95 | kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation) 96 | ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation 97 | 
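# Note on layout: the forward kernel stores, for every output position, the difference between the window's centre pixel and each of its kernel_h x kernel_w neighbours, so the output tensor allocated below has shape (batch, channels, kernel_h * kernel_w, out_height * out_width).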
assert input.dim() == 4 and input.is_cuda 98 | batch_size, input_channels, input_height, input_width = input.size() 99 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 100 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 101 | output = input.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width) 102 | n = output.numel() // output.shape[2] 103 | with torch.cuda.device_of(input): 104 | f = load_kernel('subtraction_zeropad_forward_kernel', _subtraction_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n, 105 | num=batch_size, input_channels=input_channels, 106 | bottom_height=input_height, bottom_width=input_width, 107 | top_height=output_height, top_width=output_width, 108 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 109 | stride_h=stride[0], stride_w=stride[1], 110 | dilation_h=dilation[0], dilation_w=dilation[1], 111 | pad_h=padding[0], pad_w=padding[1]) 112 | f(block=(CUDA_NUM_THREADS, 1, 1), 113 | grid=(GET_BLOCKS(n), 1, 1), 114 | args=[input.data_ptr(), output.data_ptr()], 115 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 116 | ctx.save_for_backward(input) 117 | return output 118 | 119 | @staticmethod 120 | def backward(ctx, grad_output): 121 | kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation 122 | input, = ctx.saved_tensors 123 | assert grad_output.is_cuda 124 | if not grad_output.is_contiguous(): 125 | grad_output = grad_output.contiguous() 126 | batch_size, input_channels, input_height, input_width = input.size() 127 | output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1) 128 | output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1) 129 | grad_input = None 130 | opt = dict(Dtype=Dtype(grad_output), 131 | num=batch_size, input_channels=input_channels, 132 | bottom_height=input_height, bottom_width=input_width, 133 | top_height=output_height, top_width=output_width, 134 | kernel_h=kernel_size[0], kernel_w=kernel_size[1], 135 | stride_h=stride[0], stride_w=stride[1], 136 | dilation_h=dilation[0], dilation_w=dilation[1], 137 | pad_h=padding[0], pad_w=padding[1]) 138 | with torch.cuda.device_of(input): 139 | if ctx.needs_input_grad[0]: 140 | grad_input = input.new(input.size()) 141 | n = grad_input.numel() 142 | opt['nthreads'] = n 143 | f = load_kernel('subtraction_zeropad_input_backward_kernel', _subtraction_zeropad_input_backward_kernel, **opt) 144 | f(block=(CUDA_NUM_THREADS, 1, 1), 145 | grid=(GET_BLOCKS(n), 1, 1), 146 | args=[grad_output.data_ptr(), grad_input.data_ptr()], 147 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 148 | return grad_input, None, None, None, None 149 | 150 | 151 | def subtraction_zeropad(input, kernel_size=3, stride=1, padding=0, dilation=1): 152 | assert input.dim() == 4 153 | if input.is_cuda: 154 | out = SubtractionZeropad.apply(input, kernel_size, stride, padding, dilation) 155 | else: 156 | raise NotImplementedError 157 | return out 158 | 159 | 160 | def test_subtraction_zeropad(): 161 | import os 162 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 163 | kernel_size, stride, dilation = 5, 4, 2 164 | padding = (dilation * (kernel_size - 1) + 1) // 2 165 | n, c, in_height, in_width = 2, 8, 9, 9 166 | out_height = int((in_height + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 167 | 
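# The reference result y2 below reproduces the kernel output with nn.Unfold: unfold_i extracts the centre pixels (kernel_size=1) and unfold_j the full k x k neighbourhoods, so their difference serves as a pure-PyTorch check against the hand-written CUDA kernel.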
out_width = int((in_width + 2 * padding - (dilation * (kernel_size - 1) + 1)) / stride + 1) 168 | x = torch.randn(n, c, in_height, in_width, requires_grad=True).double().cuda() 169 | 170 | y1 = subtraction_zeropad(x, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 171 | unfold_i = torch.nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride) 172 | unfold_j = torch.nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) 173 | y2 = unfold_i(x).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 174 | # y2 = unfold_i(x[:, :, kernel_size//2:-(kernel_size//2), kernel_size//2:-(kernel_size//2)]).view(n, c, 1, out_height * out_width) - unfold_j(x).view(n, c, pow(kernel_size, 2), out_height * out_width) 175 | assert (y1 - y2).abs().max() < 1e-9 176 | 177 | gx1 = torch.autograd.grad(y1.mean(), x, retain_graph=True)[0] 178 | gx2 = torch.autograd.grad(y2.mean(), x, retain_graph=True)[0] 179 | assert (gx1 - gx2).abs().max() < 1e-9 180 | 181 | from functools import partial 182 | assert torch.autograd.gradcheck(partial(subtraction_zeropad, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), x) 183 | print('test case passed') 184 | 185 | 186 | if __name__ == '__main__': 187 | test_subtraction_zeropad() 188 | -------------------------------------------------------------------------------- /lib/sa/functions/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from string import Template 3 | import cupy 4 | import torch 5 | 6 | 7 | Stream = namedtuple('Stream', ['ptr']) 8 | 9 | 10 | def Dtype(t): 11 | if isinstance(t, torch.cuda.FloatTensor): 12 | return 'float' 13 | elif isinstance(t, torch.cuda.DoubleTensor): 14 | return 'double' 15 | 16 | 17 | @cupy.memoize(for_each_device=True) 18 | def load_kernel(kernel_name, code, **kwargs): 19 | code = Template(code).substitute(**kwargs) 20 | kernel_code = cupy.cuda.compile_with_cache(code) 21 | return kernel_code.get_function(kernel_name) 22 | -------------------------------------------------------------------------------- /lib/sa/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import * 2 | from .subtraction import * 3 | from .subtraction2 import * 4 | -------------------------------------------------------------------------------- /lib/sa/modules/aggregation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Aggregation(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Aggregation, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input, weight): 18 | return F.aggregation(input, weight, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import functional as F 5 | 6 | 7 | class Subtraction(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input): 18 | return F.subtraction(input, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/sa/modules/subtraction2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. import functional as F 5 | 6 | 7 | class Subtraction2(nn.Module): 8 | 9 | def __init__(self, kernel_size, stride, padding, dilation, pad_mode): 10 | super(Subtraction2, self).__init__() 11 | self.kernel_size = _pair(kernel_size) 12 | self.stride = _pair(stride) 13 | self.padding = _pair(padding) 14 | self.dilation = _pair(dilation) 15 | self.pad_mode = pad_mode 16 | 17 | def forward(self, input1, input2): 18 | return F.subtraction2(input1, input2, self.kernel_size, self.stride, self.padding, self.dilation, self.pad_mode) 19 | -------------------------------------------------------------------------------- /lib/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .mask_predictor import SimpleDecoding 4 | from .backbone import MultiModalSwinTransformer 5 | from ._utils import LAVT, LAVTOne 6 | 7 | 8 | __all__ = ['lavt', 'lavt_one'] 9 | 10 | 11 | # LAVT 12 | def _segm_lavt(pretrained, args): 13 | # initialize the SwinTransformer backbone with the specified version 14 | if args.swin_type == 'tiny': 15 | embed_dim = 96 16 | depths = [2, 2, 6, 2] 17 | num_heads = [3, 6, 12, 24] 18 | elif args.swin_type == 'small': 19 | embed_dim = 96 20 | depths = [2, 2, 18, 2] 21 | num_heads = [3, 6, 12, 24] 22 | elif args.swin_type == 'base': 23 | embed_dim = 128 24 | depths = [2, 2, 18, 2] 25 | num_heads = [4, 8, 16, 32] 26 | elif args.swin_type == 'large': 27 | embed_dim = 192 28 | depths = [2, 2, 18, 2] 29 | num_heads = [6, 12, 24, 48] 30 | else: 31 | assert False 32 | # args.window12 added for test.py because state_dict is loaded after model initialization 33 | if 'window12' in pretrained or args.window12: 34 | print('Window size 12!') 35 | window_size = 12 36 | else: 37 | window_size = 7 38 | 39 | if args.mha: 40 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 41 | mha = [int(a) for a in mha] 42 | else: 43 | mha = [1, 1, 1, 1] 44 | 45 | out_indices = (0, 1, 2, 3) 46 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 47 | window_size=window_size, 48 | ape=False, drop_path_rate=0.3, patch_norm=True, 49 | out_indices=out_indices, 50 | use_checkpoint=False, num_heads_fusion=mha, 51 | fusion_drop=args.fusion_drop 52 | ) 53 | if pretrained: 54 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 55 | backbone.init_weights(pretrained=pretrained) 56 | else: 57 | print('Randomly initialize Multi-modal Swin Transformer weights.') 58 | backbone.init_weights() 59 | 60 | model_map = [SimpleDecoding, LAVT] 61 | 62 | classifier = model_map[0](8*embed_dim) 63 | base_model = model_map[1] 64 | 65 | model = base_model(backbone, classifier) 66 | return model 67 
| 68 | 69 | def _load_model_lavt(pretrained, args): 70 | model = _segm_lavt(pretrained, args) 71 | return model 72 | 73 | 74 | def lavt(pretrained='', args=None): 75 | return _load_model_lavt(pretrained, args) 76 | 77 | 78 | ############################################### 79 | # LAVT One: put BERT inside the overall model # 80 | ############################################### 81 | def _segm_lavt_one(pretrained, args): 82 | # initialize the SwinTransformer backbone with the specified version 83 | if args.swin_type == 'tiny': 84 | embed_dim = 96 85 | depths = [2, 2, 6, 2] 86 | num_heads = [3, 6, 12, 24] 87 | elif args.swin_type == 'small': 88 | embed_dim = 96 89 | depths = [2, 2, 18, 2] 90 | num_heads = [3, 6, 12, 24] 91 | elif args.swin_type == 'base': 92 | embed_dim = 128 93 | depths = [2, 2, 18, 2] 94 | num_heads = [4, 8, 16, 32] 95 | elif args.swin_type == 'large': 96 | embed_dim = 192 97 | depths = [2, 2, 18, 2] 98 | num_heads = [6, 12, 24, 48] 99 | else: 100 | assert False 101 | # args.window12 added for test.py because state_dict is loaded after model initialization 102 | if 'window12' in pretrained or args.window12: 103 | print('Window size 12!') 104 | window_size = 12 105 | else: 106 | window_size = 7 107 | 108 | if args.mha: 109 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 110 | mha = [int(a) for a in mha] 111 | else: 112 | mha = [1, 1, 1, 1] 113 | 114 | out_indices = (0, 1, 2, 3) 115 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 116 | window_size=window_size, 117 | ape=False, drop_path_rate=0.3, patch_norm=True, 118 | out_indices=out_indices, 119 | use_checkpoint=False, num_heads_fusion=mha, 120 | fusion_drop=args.fusion_drop, 121 | # frozen_stages=args.frozen_stages, 122 | # only_fusion=args.only_fusion, 123 | ) 124 | if pretrained: 125 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 126 | backbone.init_weights(pretrained=pretrained) 127 | else: 128 | print('Randomly initialize Multi-modal Swin Transformer weights.') 129 | backbone.init_weights() 130 | 131 | model_map = [SimpleDecoding, LAVTOne] 132 | classifier = model_map[0](8*embed_dim) 133 | base_model = model_map[1] 134 | 135 | model = base_model(backbone, classifier, args) 136 | return model 137 | 138 | 139 | def _load_model_lavt_one(pretrained, args): 140 | model = _segm_lavt_one(pretrained, args) 141 | return model 142 | 143 | 144 | def lavt_one(pretrained='', args=None): 145 | return _load_model_lavt_one(pretrained, args) 146 | -------------------------------------------------------------------------------- /lib/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | 7 | 8 | class Transformer_vis(nn.Module): 9 | 10 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,dim_feedforward=2048, 11 | dropout=0.1, activation="relu", normalize_before=False): 12 | super().__init__() 13 | 14 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 15 | dropout, activation, normalize_before) 16 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 17 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 18 | 19 | self._reset_parameters() 20 | 21 | self.d_model = d_model 22 | self.nhead = nhead 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if 
p.dim() > 1: 27 | nn.init.xavier_uniform_(p) 28 | 29 | def forward(self, src, mask, pos_embed): 30 | # flatten NxCxHxW to HWxNxC 31 | bs, c, h, w = src.shape 32 | src = src.flatten(2).permute(2, 0, 1) 33 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 34 | mask = mask.flatten(1) 35 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 36 | return memory.permute(1, 2, 0).view(bs, c, h, w) 37 | 38 | class Transformer_Decoder(nn.Module): 39 | def __init__(self, d_model=512, nhead=8, 40 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 41 | activation="relu", normalize_before=False, 42 | return_intermediate_dec=False): 43 | super().__init__() 44 | 45 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 46 | dropout, activation, normalize_before) 47 | decoder_norm = nn.LayerNorm(d_model) 48 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 49 | return_intermediate=return_intermediate_dec) 50 | 51 | self._reset_parameters() 52 | 53 | self.d_model = d_model 54 | self.nhead = nhead 55 | 56 | def _reset_parameters(self): 57 | for p in self.parameters(): 58 | if p.dim() > 1: 59 | nn.init.xavier_uniform_(p) 60 | 61 | def forward(self, tgt, memory, mask, pos_embed, query_embed): 62 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed) 63 | return hs 64 | 65 | class Transformer(nn.Module): 66 | 67 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, 68 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 69 | activation="relu", normalize_before=False, 70 | return_intermediate_dec=False): 71 | super().__init__() 72 | 73 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 74 | dropout, activation, normalize_before) 75 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 76 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 77 | 78 | self._reset_parameters() 79 | 80 | self.d_model = d_model 81 | self.nhead = nhead 82 | 83 | def _reset_parameters(self): 84 | for p in self.parameters(): 85 | if p.dim() > 1: 86 | nn.init.xavier_uniform_(p) 87 | 88 | def forward(self, src, mask, pos_embed): 89 | # flatten NxCxHxW to HWxNxC 90 | # permute NxCxW to WxNxC 91 | src = src.permute(2, 0, 1) 92 | pos_embed = pos_embed.permute(1, 0, 2) 93 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 94 | return memory 95 | 96 | 97 | class TransformerEncoder(nn.Module): 98 | def __init__(self, encoder_layer, num_layers, norm=None): 99 | super().__init__() 100 | self.layers = _get_clones(encoder_layer, num_layers) 101 | self.num_layers = num_layers 102 | self.norm = norm 103 | 104 | def forward(self, src, 105 | mask: Optional[Tensor] = None, # mask is not used here 106 | src_key_padding_mask: Optional[Tensor] = None, 107 | pos: Optional[Tensor] = None): 108 | output = src 109 | 110 | for layer in self.layers: 111 | output = layer(output, src_mask=mask, 112 | src_key_padding_mask=src_key_padding_mask, pos=pos) 113 | 114 | if self.norm is not None: 115 | output = self.norm(output) 116 | 117 | return output 118 | 119 | 120 | class TransformerDecoder(nn.Module): 121 | 122 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 123 | super().__init__() 124 | self.layers = _get_clones(decoder_layer, num_layers) 125 | self.num_layers = num_layers 126 | self.norm = norm 127 | self.return_intermediate = return_intermediate 128 | 129 | def forward(self, tgt, memory, 130 | tgt_mask: 
Optional[Tensor] = None, 131 | memory_mask: Optional[Tensor] = None, 132 | tgt_key_padding_mask: Optional[Tensor] = None, 133 | memory_key_padding_mask: Optional[Tensor] = None, 134 | pos: Optional[Tensor] = None, 135 | query_pos: Optional[Tensor] = None): 136 | 137 | output = tgt 138 | 139 | intermediate = [] 140 | 141 | for layer in self.layers: 142 | output = layer(output, memory, tgt_mask=tgt_mask, 143 | memory_mask=memory_mask, 144 | tgt_key_padding_mask=tgt_key_padding_mask, 145 | memory_key_padding_mask=memory_key_padding_mask, 146 | pos=pos, query_pos=query_pos) 147 | if self.return_intermediate: 148 | intermediate.append(self.norm(output)) 149 | 150 | if self.norm is not None: 151 | output = self.norm(output) 152 | if self.return_intermediate: 153 | intermediate.pop() 154 | intermediate.append(output) 155 | 156 | if self.return_intermediate: 157 | return torch.stack(intermediate) 158 | 159 | return output.unsqueeze(0) 160 | 161 | 162 | class TransformerEncoderLayer(nn.Module): 163 | 164 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 165 | activation="relu", normalize_before=False): 166 | super().__init__() 167 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 168 | # Implementation of Feedforward model 169 | self.linear1 = nn.Linear(d_model, dim_feedforward) 170 | self.dropout = nn.Dropout(dropout) 171 | self.linear2 = nn.Linear(dim_feedforward, d_model) 172 | 173 | self.norm1 = nn.LayerNorm(d_model) 174 | self.norm2 = nn.LayerNorm(d_model) 175 | self.dropout1 = nn.Dropout(dropout) 176 | self.dropout2 = nn.Dropout(dropout) 177 | 178 | self.activation = _get_activation_fn(activation) 179 | self.normalize_before = normalize_before 180 | 181 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 182 | return tensor if pos is None else tensor + pos 183 | 184 | def forward_post(self, 185 | src, 186 | src_mask: Optional[Tensor] = None, 187 | src_key_padding_mask: Optional[Tensor] = None, 188 | pos: Optional[Tensor] = None): 189 | q = k = self.with_pos_embed(src, pos) 190 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 191 | key_padding_mask=src_key_padding_mask)[0] 192 | src = src + self.dropout1(src2) 193 | src = self.norm1(src) 194 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 195 | src = src + self.dropout2(src2) 196 | src = self.norm2(src) 197 | return src 198 | 199 | def forward_pre(self, src, 200 | src_mask: Optional[Tensor] = None, 201 | src_key_padding_mask: Optional[Tensor] = None, 202 | pos: Optional[Tensor] = None): 203 | src2 = self.norm1(src) 204 | q = k = self.with_pos_embed(src2, pos) 205 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 206 | key_padding_mask=src_key_padding_mask)[0] 207 | src = src + self.dropout1(src2) 208 | src2 = self.norm2(src) 209 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 210 | src = src + self.dropout2(src2) 211 | return src 212 | 213 | def forward(self, src, 214 | src_mask: Optional[Tensor] = None, 215 | src_key_padding_mask: Optional[Tensor] = None, 216 | pos: Optional[Tensor] = None): 217 | if self.normalize_before: 218 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 219 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 220 | 221 | 222 | class TransformerDecoderLayer(nn.Module): 223 | 224 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 225 | activation="relu", normalize_before=False): 226 | super().__init__() 227 | self.self_attn = 
nn.MultiheadAttention(d_model, nhead, dropout=dropout) 228 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 229 | # Implementation of Feedforward model 230 | self.linear1 = nn.Linear(d_model, dim_feedforward) 231 | self.dropout = nn.Dropout(dropout) 232 | self.linear2 = nn.Linear(dim_feedforward, d_model) 233 | 234 | self.norm1 = nn.LayerNorm(d_model) 235 | self.norm2 = nn.LayerNorm(d_model) 236 | self.norm3 = nn.LayerNorm(d_model) 237 | self.dropout1 = nn.Dropout(dropout) 238 | self.dropout2 = nn.Dropout(dropout) 239 | self.dropout3 = nn.Dropout(dropout) 240 | 241 | self.activation = _get_activation_fn(activation) 242 | self.normalize_before = normalize_before 243 | 244 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 245 | return tensor if pos is None else tensor + pos# tensor 16 pos 4096 246 | 247 | def forward_post(self, tgt, memory, 248 | tgt_mask: Optional[Tensor] = None, 249 | memory_mask: Optional[Tensor] = None, 250 | tgt_key_padding_mask: Optional[Tensor] = None, 251 | memory_key_padding_mask: Optional[Tensor] = None, 252 | pos: Optional[Tensor] = None, 253 | query_pos: Optional[Tensor] = None): 254 | q = k = self.with_pos_embed(tgt, query_pos) 255 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 256 | key_padding_mask=tgt_key_padding_mask)[0] 257 | tgt = tgt + self.dropout1(tgt2) 258 | tgt = self.norm1(tgt) 259 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 260 | key=self.with_pos_embed(memory, pos), 261 | value=memory, attn_mask=memory_mask, 262 | key_padding_mask=memory_key_padding_mask)[0] 263 | tgt = tgt + self.dropout2(tgt2) 264 | tgt = self.norm2(tgt) 265 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 266 | tgt = tgt + self.dropout3(tgt2) 267 | tgt = self.norm3(tgt) 268 | return tgt 269 | 270 | def forward_pre(self, tgt, memory, 271 | tgt_mask: Optional[Tensor] = None, 272 | memory_mask: Optional[Tensor] = None, 273 | tgt_key_padding_mask: Optional[Tensor] = None, 274 | memory_key_padding_mask: Optional[Tensor] = None, 275 | pos: Optional[Tensor] = None, 276 | query_pos: Optional[Tensor] = None): 277 | tgt2 = self.norm1(tgt) 278 | q = k = self.with_pos_embed(tgt2, query_pos) 279 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 280 | key_padding_mask=tgt_key_padding_mask)[0] 281 | tgt = tgt + self.dropout1(tgt2) 282 | tgt2 = self.norm2(tgt) 283 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 284 | key=self.with_pos_embed(memory, pos), 285 | value=memory, attn_mask=memory_mask, 286 | key_padding_mask=memory_key_padding_mask)[0] 287 | tgt = tgt + self.dropout2(tgt2) 288 | tgt2 = self.norm3(tgt) 289 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 290 | tgt = tgt + self.dropout3(tgt2) 291 | return tgt 292 | 293 | def forward(self, tgt, memory, 294 | tgt_mask: Optional[Tensor] = None, 295 | memory_mask: Optional[Tensor] = None, 296 | tgt_key_padding_mask: Optional[Tensor] = None, 297 | memory_key_padding_mask: Optional[Tensor] = None, 298 | pos: Optional[Tensor] = None, 299 | query_pos: Optional[Tensor] = None): 300 | if self.normalize_before: 301 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 302 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 303 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 304 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 305 | 306 | def _get_clones(module, N): 307 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 
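# Illustrative usage sketch for the encoder wrapper above (tensor sizes are arbitrary example values):
#   enc = Transformer_vis(d_model=256, nhead=8, num_encoder_layers=2)
#   feats = torch.randn(2, 256, 15, 15)              # N x C x H x W visual features
#   pos = torch.zeros_like(feats)                    # positional embedding with the same shape
#   pad = torch.zeros(2, 15, 15, dtype=torch.bool)   # key-padding mask, flattened to N x HW internally
#   out = enc(feats, pad, pos)                       # returns an N x C x H x W tensor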
308 | 309 | 310 | 311 | def _get_activation_fn(activation): 312 | """Return an activation function given a string""" 313 | if activation == "relu": 314 | return F.relu 315 | if activation == "gelu": 316 | return F.gelu 317 | if activation == "glu": 318 | return F.glu 319 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 320 | -------------------------------------------------------------------------------- /lib/various_receptive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import mean, nn 3 | from collections import OrderedDict 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from numpy import random 7 | import os 8 | 9 | 10 | 11 | def transI_fusebn(kernel, bn): 12 | gamma = bn.weight 13 | std = (bn.running_var + bn.eps).sqrt() 14 | return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std 15 | 16 | 17 | def transIV_depthconcat(kernels, biases): 18 | return torch.cat(kernels, dim=0), torch.cat(biases) 19 | 20 | 21 | def transIII_1x1_kxk(k1, b1, k2, b2, groups): 22 | if groups == 1: 23 | k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) # 24 | b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) 25 | else: 26 | k_slices = [] 27 | b_slices = [] 28 | k1_T = k1.permute(1, 0, 2, 3) 29 | k1_group_width = k1.size(0) // groups 30 | k2_group_width = k2.size(0) // groups 31 | for g in range(groups): 32 | k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :] 33 | k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :] 34 | k_slices.append(F.conv2d(k2_slice, k1_T_slice)) 35 | b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3))) 36 | k, b_hat = transIV_depthconcat(k_slices, b_slices) 37 | return k, b_hat + b2 38 | 39 | 40 | def _conv_bn(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1): 41 | res=nn.Sequential() 42 | res.add_module('conv',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 43 | res.add_module('bn',nn.BatchNorm2d(output_channel)) 44 | return res 45 | 46 | 47 | def _conv_bn2(input_channel,output_channel,kernel_size=3,padding=1,stride=1,groups=1): 48 | res=nn.Sequential() 49 | res.add_module('conv1',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=1,padding=0,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 50 | res.add_module('bn1',nn.BatchNorm2d(output_channel)) 51 | res.add_module('conv2',nn.Conv2d(in_channels=input_channel,out_channels=output_channel,kernel_size=kernel_size,padding=padding,padding_mode='zeros',stride=stride,groups=groups,bias=False)) 52 | res.add_module('bn2',nn.BatchNorm2d(output_channel)) 53 | return res 54 | 55 | 56 | class RepBlock(nn.Module): 57 | def __init__(self,input_channel,output_channel,kernel_size=3,groups=1,stride=1): 58 | super().__init__() 59 | self.input_channel=input_channel 60 | self.output_channel=output_channel 61 | self.kernel_size=kernel_size 62 | self.padding=kernel_size//2 63 | self.groups=groups 64 | self.activation=nn.ReLU() 65 | self.sigmoid=nn.Sigmoid() 66 | 67 | #make sure kernel_size=3 padding=1 68 | assert self.kernel_size==3 69 | assert self.padding==1 70 | 71 | self.brb_3x3=_conv_bn2(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,groups=groups) 72 | 
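# In the spirit of RepVGG-style structural re-parameterisation, the block builds three parallel branches: a stacked 1x1 + 3x3 conv branch (brb_3x3, from _conv_bn2), a plain 1x1 conv branch (brb_1x1), and, when input and output channels match, a BatchNorm identity branch; forward() sums them. The transI/transIII/transIV helpers above are the standard fusion rules (conv-BN fusion, 1x1-kxk merging, depth concatenation) for collapsing such branches into a single convolution at inference time.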
self.brb_1x1=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,groups=groups) 73 | self.brb_identity=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel else None 74 | 75 | self.brb_3x3_2=_conv_bn2(input_channel,output_channel,kernel_size=self.kernel_size,padding=self.padding,groups=groups) 76 | self.brb_1x1_2=_conv_bn(input_channel,output_channel,kernel_size=1,padding=0,groups=groups) 77 | self.brb_identity_2=nn.BatchNorm2d(self.input_channel) if self.input_channel == self.output_channel else None 78 | 79 | def forward(self, inputs): 80 | if(self.brb_identity==None): 81 | identity_out=0 82 | else: 83 | identity_out=self.brb_identity(inputs) 84 | out1=self.activation(self.brb_1x1(inputs)+self.brb_3x3(inputs)+identity_out) 85 | 86 | 87 | if(self.brb_identity_2==None): 88 | identity_out_2=0 89 | else: 90 | identity_out_2=self.brb_identity_2(out1) 91 | out2=self.brb_1x1_2(out1)+self.brb_3x3_2(out1)+identity_out_2 92 | 93 | # print('relu') 94 | 95 | return self.sigmoid(out2) 96 | 97 | 98 | class VariousReceptive(nn.Module): 99 | def __init__(self,dim): 100 | super().__init__() 101 | self.repblock = RepBlock(1, 1) 102 | 103 | def forward(self, x): 104 | bs, n, dim = x.shape 105 | h, w = int(np.sqrt(n)), int(np.sqrt(n)) 106 | 107 | input = x.view(bs, h, w, dim).permute(0, 3, 1, 2) # bs,dim,h,w 108 | mean_input = torch.mean(input,dim=1,keepdim=True) # bs,1,h,w 109 | weight = self.repblock(mean_input) # bs,1,h,w 110 | out = input * weight 111 | out = out.reshape(bs, dim, -1).permute(0, 2, 1) # bs,n,dim 112 | return out 113 | 114 | 115 | ###test 116 | if __name__ == '__main__': 117 | input=torch.randn(50,1,49,49) 118 | repblock=RepBlock(1,1) 119 | repblock.eval() 120 | out=repblock(input) 121 | 122 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class DiceLoss: 6 | "Dice loss for segmentation" 7 | 8 | def __init__(self, 9 | axis: int = 1, # Class axis 10 | smooth: float = 1e-6, # Helps with numerical stabilities in the IoU division 11 | reduction: str = "sum", # PyTorch reduction to apply to the output 12 | square_in_union: bool = False # Squares predictions to increase slope of gradients 13 | ): 14 | self.axis = axis 15 | self.smooth = smooth 16 | self.reduction = reduction 17 | self.square_in_union = square_in_union 18 | 19 | def __call__(self, pred, targ): 20 | "One-hot encodes targ, then runs IoU calculation then takes 1-dice value" 21 | targ = self._one_hot(targ, pred.shape[self.axis]) 22 | assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs' 23 | pred = self.activation(pred) 24 | sum_dims = list(range(2, len(pred.shape))) 25 | inter = torch.sum(pred * targ, dim=sum_dims) 26 | union = (torch.sum(pred ** 2 + targ, dim=sum_dims) if self.square_in_union 27 | else torch.sum(pred + targ, dim=sum_dims)) 28 | dice_score = (2. 
* inter + self.smooth) / (union + self.smooth) 29 | loss = 1 - dice_score 30 | if self.reduction == 'mean': 31 | loss = loss.mean() 32 | elif self.reduction == 'sum': 33 | loss = loss.sum() 34 | return loss 35 | 36 | @staticmethod 37 | def _one_hot( 38 | x, # Non one-hot encoded targs 39 | classes: int, # The number of classes 40 | axis: int = 1 # The axis to stack for encoding (class dimension) 41 | ): 42 | "Creates one binary mask per class" 43 | return torch.stack([torch.where(x == c, 1, 0) for c in range(classes)], axis=axis) 44 | 45 | def activation(self, x): 46 | "Activation function applied to model output" 47 | return F.softmax(x, dim=self.axis) 48 | 49 | def decodes(self, x): 50 | "Converts model output to target format" 51 | return x.argmax(dim=self.axis) 52 | 53 | 54 | class Loss(): 55 | def __init__(self, weight=0.1): 56 | self.dice_loss = DiceLoss() 57 | self.ce_loss = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([0.9, 1.1]).cuda()) 58 | self.weight = weight 59 | 60 | def __call__(self, pred, targ): 61 | dice_loss = self.dice_loss(pred, targ) 62 | ce_loss = self.ce_loss(pred, targ) 63 | return (1 - self.weight) * ce_loss + self.weight * dice_loss -------------------------------------------------------------------------------- /pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lsan2401/RMSIN/b4a9b27c40cdc875d15de5ae2d36e7c10012371e/pipeline.jpg -------------------------------------------------------------------------------- /refer/refer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This interface provides access to four datasets: 3 | 1) refclef 4 | 2) refcoco 5 | 3) refcoco+ 6 | 4) refcocog 7 | split by unc and google 8 | 9 | The following API functions are defined: 10 | REFER - REFER api class 11 | getRefIds - get ref ids that satisfy given filter conditions. 12 | getAnnIds - get ann ids that satisfy given filter conditions. 13 | getImgIds - get image ids that satisfy given filter conditions. 14 | getCatIds - get category ids that satisfy given filter conditions. 15 | loadRefs - load refs with the specified ref ids. 16 | loadAnns - load anns with the specified ann ids. 17 | loadImgs - load images with the specified image ids. 18 | loadCats - load category names with the specified category ids. 19 | getRefBox - get ref's bounding box [x, y, w, h] given the ref_id 20 | showRef - show image, segmentation or box of the referred object with the ref 21 | getMask - get mask and area of the referred object given ref 22 | showMask - show mask of the referred object given ref 23 | """ 24 | 25 | import sys 26 | import os.path as osp 27 | import json 28 | import pickle as pickle 29 | import time 30 | import itertools 31 | import skimage.io as io 32 | import matplotlib.pyplot as plt 33 | from matplotlib.collections import PatchCollection 34 | from matplotlib.patches import Polygon, Rectangle 35 | from pprint import pprint 36 | import numpy as np 37 | from pycocotools import mask 38 | 39 | 40 | class REFER: 41 | 42 | def __init__(self, data_root, dataset='refcoco', splitBy='unc'): 43 | # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog 44 | # also provide dataset name and splitBy information 45 | # e.g., dataset = 'refcoco', splitBy = 'unc' 46 | print('loading dataset %s into memory...' 
% dataset) 47 | if dataset == 'refcocog': 48 | print('Split by {}!'.format(splitBy)) 49 | self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) 50 | self.DATA_DIR = osp.join(data_root, dataset) 51 | if dataset in ['refcoco', 'refcoco+', 'refcocog']: 52 | self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014') 53 | elif dataset == 'refclef': 54 | self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12') 55 | elif dataset == 'rrsisd': 56 | self.IMAGE_DIR = osp.join(data_root, 'images/rrsisd/JPEGImages') 57 | else: 58 | print('No refer dataset is called [%s]' % dataset) 59 | sys.exit() 60 | 61 | # load refs from data/dataset/refs(dataset).json 62 | tic = time.time() 63 | ref_file = osp.join(self.DATA_DIR, 'refs(' + splitBy + ').p') 64 | self.data = {} 65 | self.data['dataset'] = dataset 66 | f = open(ref_file, 'r') 67 | self.data['refs'] = pickle.load(open(ref_file, 'rb')) 68 | 69 | # load annotations from data/dataset/instances.json 70 | instances_file = osp.join(self.DATA_DIR, 'instances.json') 71 | instances = json.load(open(instances_file, 'r')) 72 | self.data['images'] = instances['images'] 73 | self.data['annotations'] = instances['annotations'] 74 | self.data['categories'] = instances['categories'] 75 | 76 | # create index 77 | self.createIndex() 78 | print('DONE (t=%.2fs)' % (time.time() - tic)) 79 | 80 | def createIndex(self): 81 | # create sets of mapping 82 | # 1) Refs: {ref_id: ref} 83 | # 2) Anns: {ann_id: ann} 84 | # 3) Imgs: {image_id: image} 85 | # 4) Cats: {category_id: category_name} 86 | # 5) Sents: {sent_id: sent} 87 | # 6) imgToRefs: {image_id: refs} 88 | # 7) imgToAnns: {image_id: anns} 89 | # 8) refToAnn: {ref_id: ann} 90 | # 9) annToRef: {ann_id: ref} 91 | # 10) catToRefs: {category_id: refs} 92 | # 11) sentToRef: {sent_id: ref} 93 | # 12) sentToTokens: {sent_id: tokens} 94 | print('creating index...') 95 | # fetch info from instances 96 | Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} 97 | for ann in self.data['annotations']: 98 | Anns[ann['id']] = ann 99 | imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann] 100 | for img in self.data['images']: 101 | Imgs[img['id']] = img 102 | for cat in self.data['categories']: 103 | Cats[cat['id']] = cat['name'] 104 | 105 | # fetch info from refs 106 | Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} 107 | Sents, sentToRef, sentToTokens = {}, {}, {} 108 | for ref in self.data['refs']: 109 | # ids 110 | ref_id = ref['ref_id'] 111 | ann_id = ref['ann_id'] 112 | category_id = ref['category_id'] 113 | image_id = ref['image_id'] 114 | 115 | # add mapping related to ref 116 | Refs[ref_id] = ref 117 | imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] 118 | catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] 119 | refToAnn[ref_id] = Anns[ann_id] 120 | annToRef[ann_id] = ref 121 | 122 | # add mapping of sent 123 | for sent in ref['sentences']: 124 | Sents[sent['sent_id']] = sent 125 | sentToRef[sent['sent_id']] = ref 126 | sentToTokens[sent['sent_id']] = sent['tokens'] 127 | 128 | # create class members 129 | self.Refs = Refs 130 | self.Anns = Anns 131 | self.Imgs = Imgs 132 | self.Cats = Cats 133 | self.Sents = Sents 134 | self.imgToRefs = imgToRefs 135 | self.imgToAnns = imgToAnns 136 | self.refToAnn = refToAnn 137 | self.annToRef = annToRef 138 | self.catToRefs = catToRefs 139 | self.sentToRef = sentToRef 140 | self.sentToTokens = sentToTokens 141 | print('index created.') 142 | 143 | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''): 144 | 
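# Filtering proceeds in stages: restrict by image_ids first, then drop refs outside cat_ids, then outside ref_ids, and finally keep only the requested split (train / val / test and the testA/testB-style variants).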
image_ids = image_ids if type(image_ids) == list else [image_ids] 145 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 146 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 147 | 148 | if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: 149 | refs = self.data['refs'] 150 | else: 151 | if not len(image_ids) == 0: 152 | refs = [self.imgToRefs[image_id] for image_id in image_ids] 153 | else: 154 | refs = self.data['refs'] 155 | if not len(cat_ids) == 0: 156 | refs = [ref for ref in refs if ref['category_id'] in cat_ids] 157 | if not len(ref_ids) == 0: 158 | refs = [ref for ref in refs if ref['ref_id'] in ref_ids] 159 | if not len(split) == 0: 160 | if split in ['testA', 'testB', 'testC']: 161 | refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ... 162 | elif split in ['testAB', 'testBC', 'testAC']: 163 | refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess... 164 | elif split == 'test': 165 | refs = [ref for ref in refs if 'test' in ref['split']] 166 | elif split == 'train' or split == 'val': 167 | refs = [ref for ref in refs if ref['split'] == split] 168 | else: 169 | print('No such split [%s]' % split) 170 | sys.exit() 171 | ref_ids = [ref['ref_id'] for ref in refs] 172 | return ref_ids 173 | 174 | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): 175 | image_ids = image_ids if type(image_ids) == list else [image_ids] 176 | cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] 177 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 178 | 179 | if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: 180 | ann_ids = [ann['id'] for ann in self.data['annotations']] 181 | else: 182 | if not len(image_ids) == 0: 183 | lists = [self.imgToAnns[image_id] for image_id in image_ids if 184 | image_id in self.imgToAnns] # list of [anns] 185 | anns = list(itertools.chain.from_iterable(lists)) 186 | else: 187 | anns = self.data['annotations'] 188 | if not len(cat_ids) == 0: 189 | anns = [ann for ann in anns if ann['category_id'] in cat_ids] 190 | ann_ids = [ann['id'] for ann in anns] 191 | if not len(ref_ids) == 0: 192 | ids = set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])) 193 | return ann_ids 194 | 195 | def getImgIds(self, ref_ids=[]): 196 | ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] 197 | 198 | if not len(ref_ids) == 0: 199 | image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids])) 200 | else: 201 | image_ids = self.Imgs.keys() 202 | return image_ids 203 | 204 | def getCatIds(self): 205 | return self.Cats.keys() 206 | 207 | def loadRefs(self, ref_ids=[]): 208 | if type(ref_ids) == list: 209 | return [self.Refs[ref_id] for ref_id in ref_ids] 210 | elif type(ref_ids) == int: 211 | return [self.Refs[ref_ids]] 212 | 213 | def loadAnns(self, ann_ids=[]): 214 | if type(ann_ids) == list: 215 | return [self.Anns[ann_id] for ann_id in ann_ids] 216 | elif type(ann_ids) == int or type(ann_ids) == str: # str replaces the Python 2 unicode type 217 | return [self.Anns[ann_ids]] 218 | 219 | def loadImgs(self, image_ids=[]): 220 | if type(image_ids) == list: 221 | return [self.Imgs[image_id] for image_id in image_ids] 222 | elif type(image_ids) == int: 223 | return [self.Imgs[image_ids]] 224 | 225 | def loadCats(self, cat_ids=[]): 226 | if type(cat_ids) == list: 227 | return [self.Cats[cat_id] for cat_id in cat_ids] 228 | elif type(cat_ids) == int: 229 | return [self.Cats[cat_ids]] 230 | 231 | def getRefBox(self, ref_id): 232 | ref = self.Refs[ref_id] 233 | 
ann = self.refToAnn[ref_id] 234 | return ann['bbox'] # [x, y, w, h] 235 | 236 | def showRef(self, ref, seg_box='seg'): 237 | ax = plt.gca() 238 | # show image 239 | image = self.Imgs[ref['image_id']] 240 | I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) 241 | ax.imshow(I) 242 | # show refer expression 243 | for sid, sent in enumerate(ref['sentences']): 244 | print('%s. %s' % (sid + 1, sent['sent'])) 245 | # show segmentations 246 | if seg_box == 'seg': 247 | ann_id = ref['ann_id'] 248 | ann = self.Anns[ann_id] 249 | polygons = [] 250 | color = [] 251 | c = 'none' 252 | if type(ann['segmentation'][0]) == list: 253 | # polygon used for refcoco* 254 | for seg in ann['segmentation']: 255 | poly = np.array(seg).reshape((len(seg) / 2, 2)) 256 | polygons.append(Polygon(poly, True, alpha=0.4)) 257 | color.append(c) 258 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 1, 0, 0), linewidths=3, alpha=1) 259 | ax.add_collection(p) # thick yellow polygon 260 | p = PatchCollection(polygons, facecolors=color, edgecolors=(1, 0, 0, 0), linewidths=1, alpha=1) 261 | ax.add_collection(p) # thin red polygon 262 | else: 263 | # mask used for refclef 264 | rle = ann['segmentation'] 265 | m = mask.decode(rle) 266 | img = np.ones((m.shape[0], m.shape[1], 3)) 267 | color_mask = np.array([2.0, 166.0, 101.0]) / 255 268 | for i in range(3): 269 | img[:, :, i] = color_mask[i] 270 | ax.imshow(np.dstack((img, m * 0.5))) 271 | # show bounding-box 272 | elif seg_box == 'box': 273 | ann_id = ref['ann_id'] 274 | ann = self.Anns[ann_id] 275 | bbox = self.getRefBox(ref['ref_id']) 276 | box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3) 277 | ax.add_patch(box_plot) 278 | 279 | def getMask(self, ref): 280 | # return mask, area and mask-center 281 | ann = self.refToAnn[ref['ref_id']] 282 | image = self.Imgs[ref['image_id']] 283 | if type(ann['segmentation'][0]) == list: # polygon 284 | rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width']) 285 | else: 286 | rle = ann['segmentation'] 287 | 288 | m = mask.decode(rle) 289 | m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs) 290 | m = m.astype(np.uint8) # convert to np.uint8 291 | # compute area 292 | area = sum(mask.area(rle)) # should be close to ann['area'] 293 | 294 | return {'mask': m, 'area': area} 295 | 296 | 297 | def showMask(self, ref): 298 | M = self.getMask(ref) 299 | msk = M['mask'] 300 | ax = plt.gca() 301 | ax.imshow(msk) 302 | 303 | 304 | if __name__ == '__main__': 305 | refer = REFER(dataset='refcocog', splitBy='google') 306 | ref_ids = refer.getRefIds() 307 | 308 | ref_ids = refer.getRefIds(split='train') 309 | print('There are %s training referred objects.' % len(ref_ids)) 310 | 311 | for ref_id in ref_ids: 312 | ref = refer.loadRefs(ref_id)[0] 313 | if len(ref['sentences']) < 2: 314 | continue 315 | print('The label is %s.' 
% refer.Cats[ref['category_id']]) 316 | plt.figure() 317 | refer.showRef(ref, seg_box='box') 318 | plt.show() 319 | 320 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | filelock 3 | tqdm 4 | timm 5 | mmcv-full==1.3.12 6 | mmsegmentation==0.17.0 7 | ftfy 8 | regex 9 | scipy 10 | scikit-image 11 | pycocotools==2.0.2 12 | opencv-python==4.5.3.56 13 | tokenizers==0.8.1rc1 14 | h5py -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import utils 4 | import numpy as np 5 | import transforms as T 6 | import random 7 | from bert.modeling_bert import BertModel 8 | from lib import segmentation 9 | 10 | 11 | 12 | 13 | def get_dataset(image_set, transform, args): 14 | from data.dataset_refer_bert import ReferDataset 15 | ds = ReferDataset(args, 16 | split=image_set, 17 | image_transforms=transform, 18 | target_transforms=None, 19 | eval_mode=True 20 | ) 21 | num_classes = 2 22 | return ds, num_classes 23 | 24 | 25 | def evaluate(model, data_loader, bert_model, device): 26 | model.eval() 27 | metric_logger = utils.MetricLogger(delimiter=" ") 28 | 29 | # evaluation variables 30 | cum_I, cum_U = 0, 0 31 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 32 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 33 | seg_total = 0 34 | mean_IoU = [] 35 | header = 'Test:' 36 | 37 | with torch.no_grad(): 38 | 39 | for data in metric_logger.log_every(data_loader, 100, header): 40 | image, target, sentences, attentions = data 41 | image, target, sentences, attentions = image.to(device), target.to(device), \ 42 | sentences.to(device), attentions.to(device) 43 | sentences = sentences.squeeze(1) 44 | attentions = attentions.squeeze(1) 45 | target = target.cpu().data.numpy() 46 | for j in range(sentences.size(-1)): 47 | if bert_model is not None: 48 | last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] 49 | embedding = last_hidden_states.permute(0, 2, 1) 50 | output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) 51 | else: 52 | output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j]) 53 | 54 | output = output.cpu() 55 | 56 | output_mask = output.argmax(1).data.numpy() 57 | 58 | I, U = computeIoU(output_mask, target) 59 | if U == 0: 60 | this_iou = 0.0 61 | else: 62 | this_iou = I*1.0/U 63 | mean_IoU.append(this_iou) 64 | cum_I += I 65 | cum_U += U 66 | for n_eval_iou in range(len(eval_seg_iou_list)): 67 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 68 | seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) 69 | 70 | seg_total += 1 71 | 72 | 73 | del image, target, sentences, attentions, output,output_mask 74 | if bert_model is not None: 75 | del last_hidden_states, embedding 76 | 77 | mean_IoU = np.array(mean_IoU) 78 | mIoU = np.mean(mean_IoU) 79 | print('Final results:') 80 | print('Mean IoU is %.2f\n' % (mIoU*100.)) 81 | results_str = '' 82 | for n_eval_iou in range(len(eval_seg_iou_list)): 83 | results_str += ' precision@%s = %.2f\n' % \ 84 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 85 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) 86 | print(results_str) 87 | 88 | 89 | 90 | 91 | def get_transform(args): 92 | transforms = [T.Resize(args.img_size, args.img_size), 93 | T.ToTensor(), 94 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 95 | ] 96 | 97 | return T.Compose(transforms) 98 | 99 | 100 | def computeIoU(pred_seg, gd_seg): 101 | I = np.sum(np.logical_and(pred_seg, gd_seg)) 102 | U = np.sum(np.logical_or(pred_seg, gd_seg)) 103 | 104 | return I, U 105 | 106 | 107 | def main(args): 108 | device = torch.device(args.device) 109 | dataset_test, _ = get_dataset(args.split, get_transform(args=args), args) 110 | 111 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 112 | data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, 113 | sampler=test_sampler, num_workers=args.workers) 114 | print(args.model) 115 | single_model = segmentation.__dict__[args.model](pretrained='',args=args) 116 | checkpoint = torch.load(args.resume, map_location='cpu') 117 | single_model.load_state_dict(checkpoint['model'], strict=False) 118 | model = single_model.to(device) 119 | 120 | if args.model != 'lavt_one': 121 | model_class = BertModel 122 | single_bert_model = model_class.from_pretrained(args.ck_bert) 123 | # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines 124 | if args.ddp_trained_weights: 125 | single_bert_model.pooler = None 126 | single_bert_model.load_state_dict(checkpoint['bert_model']) 127 | bert_model = single_bert_model.to(device) 128 | else: 129 | bert_model = None 130 | 131 | evaluate(model, data_loader_test, bert_model, device=device) 132 | 133 | 134 | if __name__ == "__main__": 135 | from args import get_parser 136 | parser = get_parser() 137 | args = parser.parse_args() 138 | print('Image size: {}'.format(str(args.img_size))) 139 | main(args) 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | import torch 5 | import torch.utils.data 6 | import wandb 7 | import cv2 8 | import random 9 | import transforms as T 10 | import utils 11 | import numpy as np 12 | import gc 13 | import operator 14 | from functools import reduce 15 | from bert.modeling_bert import BertModel 16 | from lib import segmentation 17 | from loss.loss import Loss 18 | 19 | 20 | 21 | def seed_everything(seed=2401): 22 | random.seed(seed) 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | 31 | 32 | def get_dataset(image_set, transform, args): 33 | from data.dataset_refer_bert import ReferDataset 34 | ds = ReferDataset(args, 35 | split=image_set, 36 | image_transforms=transform, 37 | target_transforms=None 38 | ) 39 | num_classes = 2 40 | 41 | return ds, num_classes 42 | 43 | 44 | def IoU(pred, gt): 45 | pred = pred.argmax(1) 46 | 47 | intersection = torch.sum(torch.mul(pred, gt)) 48 | union = torch.sum(torch.add(pred, gt)) - intersection 49 | 50 | if intersection == 0 or union == 0: 51 | iou = 0 52 | else: 53 | iou = float(intersection) / float(union) 54 | return iou, intersection, union 55 | 56 | 57 | def get_transform(args): 58 | transforms = [ 59 | T.Resize(args.img_size, args.img_size), 60 | T.ToTensor(), 61 | T.Normalize(mean=[0.485, 0.456, 
0.406], std=[0.229, 0.224, 0.225]) 62 | ] 63 | return T.Compose(transforms) 64 | 65 | 66 | def criterion(input, target, weight=0.1): 67 | return Loss(weight=weight)(input, target) 68 | 69 | 70 | def evaluate(model, data_loader, bert_model, epoch): 71 | model.eval() 72 | metric_logger = utils.MetricLogger(delimiter=" ") 73 | header = "Test: " 74 | total_its = 0 75 | acc_ious = 0 76 | 77 | # evaluation variables 78 | cum_I, cum_U = 0, 0 79 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 80 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 81 | seg_total = 0 82 | mean_IoU = [] 83 | total_loss = 0 84 | 85 | with torch.no_grad(): 86 | for data in metric_logger.log_every(data_loader, 100, header): 87 | total_its += 1 88 | image, target, sentences, attentions = data 89 | pixels = cv2.countNonZero(target.data.numpy()[0]) / 230400. 90 | image, target, sentences, attentions = image.cuda(non_blocking=True),\ 91 | target.cuda(non_blocking=True),\ 92 | sentences.cuda(non_blocking=True),\ 93 | attentions.cuda(non_blocking=True) 94 | 95 | sentences = sentences.squeeze(1) 96 | attentions = attentions.squeeze(1) 97 | 98 | 99 | if bert_model is not None: 100 | last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] 101 | embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 102 | attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1) 103 | output = model(image, embedding, l_mask=attentions) 104 | else: 105 | output = model(image, sentences, l_mask=attentions) 106 | 107 | iou, I, U = IoU(output, target) 108 | loss = criterion(output, target) 109 | total_loss += loss.item() 110 | acc_ious += iou 111 | mean_IoU.append(iou) 112 | cum_I += I 113 | cum_U += U 114 | for n_eval_iou in range(len(eval_seg_iou_list)): 115 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 116 | seg_correct[n_eval_iou] += (iou >= eval_seg_iou) 117 | seg_total += 1 118 | iou = acc_ious / total_its 119 | 120 | mean_IoU = np.array(mean_IoU) 121 | mIoU = np.mean(mean_IoU) 122 | print('Final results:') 123 | print('Mean IoU is %.2f\n' % (mIoU * 100.)) 124 | results_str = '' 125 | for n_eval_iou in range(len(eval_seg_iou_list)): 126 | results_str += ' precision@%s = %.2f\n' % \ 127 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 128 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U) 129 | print(results_str) 130 | 131 | if args.local_rank == 0: 132 | wandb.log({ 133 | "val mIoU": mIoU, 134 | "val oiou": cum_I * 100. 
/ cum_U, 135 | "val Loss": total_loss / total_its}) 136 | 137 | return 100 * iou, 100 * cum_I / cum_U 138 | 139 | 140 | def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq, 141 | iterations, bert_model): 142 | model.train() 143 | metric_logger = utils.MetricLogger(delimiter=" ") 144 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 145 | header = 'Epoch: [{}]'.format(epoch) 146 | train_loss = 0 147 | total_its = 0 148 | 149 | # for data in data_loader: 150 | for i, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 151 | total_its += 1 152 | image, target, sentences, attentions = data 153 | image, target, sentences, attentions = image.cuda(non_blocking=True),\ 154 | target.cuda(non_blocking=True),\ 155 | sentences.cuda(non_blocking=True),\ 156 | attentions.cuda(non_blocking=True) 157 | 158 | sentences = sentences.squeeze(1) 159 | attentions = attentions.squeeze(1) 160 | 161 | if bert_model is not None: 162 | 163 | last_hidden_states = bert_model(sentences, attention_mask=attentions)[0] # (6, 10, 768) 164 | embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 165 | attentions = attentions.unsqueeze(dim=-1) # (batch, N_l, 1) 166 | output = model(image, embedding, attentions)#, sentences_hidden_state)# [4,2,120,120] 167 | else: 168 | output = model(image, sentences, attentions)#, sentences_hidden_state) 169 | optimizer.zero_grad() 170 | loss = criterion(output, target) 171 | 172 | loss.backward() 173 | optimizer.step() 174 | lr_scheduler.step() 175 | 176 | 177 | torch.cuda.synchronize() 178 | train_loss += loss.item() 179 | iterations += 1 180 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 181 | 182 | del image, target, sentences, attentions, loss, output, data 183 | if bert_model is not None: 184 | del last_hidden_states, embedding 185 | 186 | gc.collect() 187 | torch.cuda.empty_cache() 188 | torch.cuda.synchronize() 189 | if args.local_rank == 0: 190 | wandb.log({ 191 | "Train Loss": train_loss / total_its,}) 192 | 193 | 194 | def main(args): 195 | dataset, num_classes = get_dataset("train", 196 | get_transform(args=args), 197 | args=args) 198 | 199 | dataset_test, _ = get_dataset("val", 200 | get_transform(args=args), 201 | args=args) 202 | 203 | # batch sampler 204 | print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.") 205 | num_tasks = utils.get_world_size() 206 | global_rank = utils.get_rank() 207 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, 208 | shuffle=True) 209 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 210 | 211 | # data loader 212 | data_loader = torch.utils.data.DataLoader( 213 | dataset, batch_size=args.batch_size, 214 | sampler=train_sampler, num_workers=args.workers, pin_memory=args.pin_mem, drop_last=True) 215 | 216 | data_loader_test = torch.utils.data.DataLoader( 217 | dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers) 218 | 219 | # model initialization 220 | print(args.model) 221 | model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights, 222 | args=args) 223 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 224 | model.cuda() 225 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True) 226 | single_model = model.module #ddp 227 | 228 | # 
print(model) 229 | if args.model != 'lavt_one': 230 | model_class = BertModel 231 | bert_model = model_class.from_pretrained(args.ck_bert) 232 | bert_model.pooler = None # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel 233 | bert_model.cuda() 234 | bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model) 235 | bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[args.local_rank]) 236 | single_bert_model = bert_model 237 | else: 238 | bert_model = None 239 | single_bert_model = None 240 | 241 | # resume training 242 | if args.resume: 243 | checkpoint = torch.load(args.resume, map_location='cpu') 244 | single_model.load_state_dict(checkpoint['model'], strict=False) 245 | if args.model != 'lavt_one': 246 | single_bert_model.load_state_dict(checkpoint['bert_model']) 247 | 248 | # parameters to optimize 249 | backbone_no_decay = list() 250 | backbone_decay = list() 251 | for name, m in single_model.backbone.named_parameters(): 252 | if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name: 253 | backbone_no_decay.append(m) 254 | else: 255 | backbone_decay.append(m) 256 | 257 | if args.model != 'lavt_one': 258 | params_to_optimize = [ 259 | {'params': backbone_no_decay, 'weight_decay': 0.0}, 260 | {'params': backbone_decay}, 261 | {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]}, 262 | # the following are the parameters of bert 263 | {"params": reduce(operator.concat, 264 | [[p for p in single_bert_model.module.encoder.layer[i].parameters() 265 | if p.requires_grad] for i in range(10)])}, 266 | ] 267 | else: 268 | params_to_optimize = [ 269 | {'params': backbone_no_decay, 'weight_decay': 0.0}, 270 | {'params': backbone_decay}, 271 | {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]}, 272 | # the following are the parameters of bert 273 | {"params": reduce(operator.concat, 274 | [[p for p in single_model.text_encoder.encoder.layer[i].parameters() 275 | if p.requires_grad] for i in range(10)])}, 276 | ] 277 | 278 | # optimizer 279 | optimizer = torch.optim.AdamW(params_to_optimize, 280 | lr=args.lr, 281 | weight_decay=args.weight_decay, 282 | amsgrad=args.amsgrad 283 | ) 284 | 285 | # learning rate scheduler 286 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 287 | lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) 288 | 289 | # housekeeping 290 | start_time = time.time() 291 | iterations = 0 292 | best_oIoU = -0.1 293 | 294 | # resume training (optimizer, lr scheduler, and the epoch) 295 | if args.resume: 296 | optimizer.load_state_dict(checkpoint['optimizer']) 297 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 298 | resume_epoch = checkpoint['epoch'] 299 | 300 | else: 301 | resume_epoch = -999 302 | 303 | # training loops 304 | if args.local_rank == 0: 305 | wandb.watch(model, log="all") 306 | 307 | 308 | for epoch in range(max(0, resume_epoch+1), args.epochs): 309 | data_loader.sampler.set_epoch(epoch) 310 | train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq, 311 | iterations, bert_model) 312 | iou, overallIoU = evaluate(model, data_loader_test, bert_model, epoch) 313 | print('Average object IoU {}'.format(iou)) 314 | print('Overall IoU {}'.format(overallIoU)) 315 | best = (best_oIoU < overallIoU) 316 | if single_bert_model is not None: 317 | dict_to_save = {'model': single_model.state_dict(), 'bert_model': single_bert_model.state_dict(), 
318 | 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, 319 | 'lr_scheduler': lr_scheduler.state_dict()} 320 | else: 321 | dict_to_save = {'model': single_model.state_dict(), 322 | 'optimizer': optimizer.state_dict(), 'epoch': epoch, 'args': args, 323 | 'lr_scheduler': lr_scheduler.state_dict()} 324 | 325 | if best: 326 | print('Better epoch: {}\n'.format(epoch)) 327 | utils.save_on_master(dict_to_save, os.path.join(args.output_dir, 328 | 'model_best_{}.pth'.format(args.model_id))) 329 | best_oIoU = overallIoU 330 | utils.save_on_master(dict_to_save, os.path.join(args.output_dir, 331 | 'model_last_{}.pth'.format(args.model_id))) 332 | if args.local_rank == 0: 333 | wandb.save('model.h5') 334 | 335 | # summarize 336 | total_time = time.time() - start_time 337 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 338 | print('Training time {}'.format(total_time_str)) 339 | 340 | 341 | if __name__ == "__main__": 342 | from args import get_parser 343 | seed_everything() 344 | parser = get_parser() 345 | args = parser.parse_args() 346 | if args.local_rank == 0: 347 | wandb.init(project="rmsin_2080") 348 | # set up distributed learning 349 | utils.init_distributed_mode(args) 350 | print('Image size: {}'.format(str(args.img_size))) 351 | main(args) 352 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, h, w): 32 | self.h = h 33 | self.w = w 34 | 35 | def __call__(self, image, target): 36 | image = F.resize(image, (self.h, self.w)) 37 | # If size is a sequence like (h, w), the output size will be matched to this. 38 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 39 | target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST) 40 | return image, target 41 | 42 | 43 | class RandomResize(object): 44 | def __init__(self, min_size, max_size=None): 45 | self.min_size = min_size 46 | if max_size is None: 47 | max_size = min_size 48 | self.max_size = max_size 49 | 50 | def __call__(self, image, target): 51 | size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1) 52 | image = F.resize(image, size) 53 | # If size is a sequence like (h, w), the output size will be matched to this. 
54 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 55 | target = F.resize(target, size, interpolation=Image.NEAREST) 56 | return image, target 57 | 58 | 59 | class RandomHorizontalFlip(object): 60 | def __init__(self, flip_prob): 61 | self.flip_prob = flip_prob 62 | 63 | def __call__(self, image, target): 64 | if random.random() < self.flip_prob: 65 | image = F.hflip(image) 66 | target = F.hflip(target) 67 | return image, target 68 | 69 | 70 | class RandomCrop(object): 71 | def __init__(self, size): 72 | self.size = size 73 | 74 | def __call__(self, image, target): 75 | image = pad_if_smaller(image, self.size) 76 | target = pad_if_smaller(target, self.size, fill=255) 77 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 78 | image = F.crop(image, *crop_params) 79 | target = F.crop(target, *crop_params) 80 | return image, target 81 | 82 | 83 | class CenterCrop(object): 84 | def __init__(self, size): 85 | self.size = size 86 | 87 | def __call__(self, image, target): 88 | image = F.center_crop(image, self.size) 89 | target = F.center_crop(target, self.size) 90 | return image, target 91 | 92 | 93 | class ToTensor(object): 94 | def __call__(self, image, target): 95 | image = F.to_tensor(image) 96 | target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64) 97 | return image, target 98 | 99 | 100 | class RandomAffine(object): 101 | def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None): 102 | self.angle = angle 103 | self.translate = translate 104 | self.scale = scale 105 | self.shear = shear 106 | self.resample = resample 107 | self.fillcolor = fillcolor 108 | 109 | def __call__(self, image, target): 110 | affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size) 111 | image = F.affine(image, *affine_params) 112 | target = F.affine(target, *affine_params) 113 | return image, target 114 | 115 | 116 | class Normalize(object): 117 | def __init__(self, mean, std): 118 | self.mean = mean 119 | self.std = std 120 | 121 | def __call__(self, image, target): 122 | image = F.normalize(image, mean=self.mean, std=self.std) 123 | return image, target 124 | 125 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import math 5 | import time 6 | import torch 7 | import torch.distributed as dist 8 | import torch.backends.cudnn as cudnn 9 | 10 | import errno 11 | import os 12 | 13 | import sys 14 | 15 | 16 | class SmoothedValue(object): 17 | """Track a series of values and provide access to smoothed values over a 18 | window or the global series average. 19 | """ 20 | 21 | def __init__(self, window_size=20, fmt=None): 22 | if fmt is None: 23 | fmt = "{median:.4f} ({global_avg:.4f})" 24 | self.deque = deque(maxlen=window_size) 25 | self.total = 0.0 26 | self.count = 0 27 | self.fmt = fmt 28 | 29 | def update(self, value, n=1): 30 | self.deque.append(value) 31 | self.count += n 32 | self.total += value * n 33 | 34 | def synchronize_between_processes(self): 35 | """ 36 | Warning: does not synchronize the deque! 
37 | """ 38 | if not is_dist_avail_and_initialized(): 39 | return 40 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 41 | dist.barrier() 42 | dist.all_reduce(t) 43 | t = t.tolist() 44 | self.count = int(t[0]) 45 | self.total = t[1] 46 | 47 | @property 48 | def median(self): 49 | d = torch.tensor(list(self.deque)) 50 | return d.median().item() 51 | 52 | @property 53 | def avg(self): 54 | d = torch.tensor(list(self.deque), dtype=torch.float32) 55 | return d.mean().item() 56 | 57 | @property 58 | def global_avg(self): 59 | return self.total / self.count 60 | 61 | @property 62 | def max(self): 63 | return max(self.deque) 64 | 65 | @property 66 | def value(self): 67 | return self.deque[-1] 68 | 69 | def __str__(self): 70 | return self.fmt.format( 71 | median=self.median, 72 | avg=self.avg, 73 | global_avg=self.global_avg, 74 | max=self.max, 75 | value=self.value) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | assert isinstance(v, (float, int)) 88 | self.meters[k].update(v) 89 | 90 | 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | print(iterable) 117 | i = 0 118 | if not header: 119 | header = '' 120 | start_time = time.time() 121 | end = time.time() 122 | iter_time = SmoothedValue(fmt='{avg:.4f}') 123 | data_time = SmoothedValue(fmt='{avg:.4f}') 124 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 125 | log_msg = self.delimiter.join([ 126 | header, 127 | '[{0' + space_fmt + '}/{1}]', 128 | 'eta: {eta}', 129 | '{meters}', 130 | 'time: {time}', 131 | 'data: {data}', 132 | 'max mem: {memory:.0f}' 133 | ]) 134 | MB = 1024.0 * 1024.0 135 | for obj in iterable: 136 | data_time.update(time.time() - end) 137 | yield obj 138 | iter_time.update(time.time() - end) 139 | if i % print_freq == 0: 140 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 141 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 142 | print(log_msg.format( 143 | i, len(iterable), eta=eta_string, 144 | meters=str(self), 145 | time=str(iter_time), data=str(data_time), 146 | memory=torch.cuda.max_memory_allocated() / MB)) 147 | sys.stdout.flush() 148 | 149 | i += 1 150 | end = time.time() 151 | total_time = time.time() - start_time 152 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 153 | print('{} Total time: {}'.format(header, total_time_str)) 154 | 155 | 156 | def mkdir(path): 157 | try: 158 | os.makedirs(path) 159 | except OSError as e: 160 | if e.errno != errno.EEXIST: 161 | raise 162 | 163 | 164 | def setup_for_distributed(is_master): 165 | """ 166 | This function disables printing when not in master 
process 167 | """ 168 | import builtins as __builtin__ 169 | builtin_print = __builtin__.print 170 | 171 | def print(*args, **kwargs): 172 | force = kwargs.pop('force', False) 173 | if is_master or force: 174 | builtin_print(*args, **kwargs) 175 | 176 | __builtin__.print = print 177 | 178 | 179 | def is_dist_avail_and_initialized(): 180 | if not dist.is_available(): 181 | return False 182 | if not dist.is_initialized(): 183 | return False 184 | return True 185 | 186 | 187 | def get_world_size(): 188 | if not is_dist_avail_and_initialized(): 189 | return 1 190 | return dist.get_world_size() 191 | 192 | 193 | def get_rank(): 194 | if not is_dist_avail_and_initialized(): 195 | return 0 196 | return dist.get_rank() 197 | 198 | 199 | def is_main_process(): 200 | return get_rank() == 0 201 | 202 | 203 | def save_on_master(*args, **kwargs): 204 | if is_main_process(): 205 | torch.save(*args, **kwargs) 206 | 207 | 208 | def init_distributed_mode(args): 209 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 210 | rank = int(os.environ["RANK"]) 211 | world_size = int(os.environ['WORLD_SIZE']) 212 | print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}") 213 | else: 214 | rank = -1 215 | world_size = -1 216 | 217 | torch.cuda.set_device(args.local_rank) 218 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 219 | torch.distributed.barrier() 220 | setup_for_distributed(is_main_process()) 221 | 222 | if args.output_dir: 223 | mkdir(args.output_dir) 224 | if args.model_id: 225 | mkdir(os.path.join('./models/', args.model_id)) 226 | --------------------------------------------------------------------------------
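Usage note (not part of the repository): the logging helpers defined in `utils.py` above are driven by `train.py` and `test.py` through `MetricLogger.update(...)` and `log_every(...)`. The snippet below is a minimal, self-contained sketch of that pattern with placeholder loss values; it assumes it is run from the repository root so that `utils.py` is importable, and it avoids `log_every` so no CUDA device is required.

```python
# Illustrative only: drive the logging helpers from utils.py with fake values.
import utils

metric_logger = utils.MetricLogger(delimiter="  ")
# Track the learning rate with a window of 1, mirroring the add_meter call in train.py.
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))

for step in range(5):
    fake_loss = 1.0 / (step + 1)   # placeholder for criterion(output, target).item()
    metric_logger.update(loss=fake_loss, lr=5e-5)

# Each meter exposes both windowed and global statistics.
print(metric_logger)                               # loss: <median> (<global_avg>)  lr: 5e-05
print(metric_logger.meters['loss'].global_avg)     # running average over all updates
```

`SmoothedValue` keeps a fixed-size deque for windowed statistics (median, avg, max) while separately accumulating `count` and `total` for the global average, which is why the per-iteration log can report both a recent value and a running value; `synchronize_between_processes` only all-reduces the global counters, not the deque.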