├── README.md
├── SE-3D.py
├── SCSE-3D.py
├── self-attention-3D.py
├── dual-attention-3D.py
└── non-local-3D.py

/README.md:
--------------------------------------------------------------------------------
# Pytorch_3D_Attention_Modules
PyTorch implementations of 3D attention modules.
--------------------------------------------------------------------------------
/SE-3D.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class SELayer(nn.Module):
    """Channel squeeze-and-excitation (cSE) block for 3D feature maps."""

    def __init__(self, channel=32, reduction=8):
        super(SELayer, self).__init__()
        self.cse_avg_pool = nn.AdaptiveAvgPool3d(1)
        self.cse_fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        # squeeze: global average pool to (b, c); excite: bottleneck MLP with sigmoid gate
        cse_y = self.cse_avg_pool(x).view(b, c)
        cse_y = self.cse_fc(cse_y).view(b, c, 1, 1, 1)

        return x * cse_y.expand_as(x)
--------------------------------------------------------------------------------
/SCSE-3D.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class SCSELayer(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) block for 3D feature maps."""

    def __init__(self, channel=32, reduction=8):
        super(SCSELayer, self).__init__()
        # channel squeeze-and-excitation (cSE) branch
        self.cse_avg_pool = nn.AdaptiveAvgPool3d(1)
        self.cse_fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
        # spatial squeeze-and-excitation (sSE) branch: 1x1x1 conv down to one channel
        self.sse_conv = nn.Conv3d(channel, 1, 1, padding=0)

    def forward(self, x):
        b, c, _, _, _ = x.size()
        cse_y = self.cse_avg_pool(x).view(b, c)
        cse_y = self.cse_fc(cse_y).view(b, c, 1, 1, 1)
        # the sigmoid squashes the spatial attention map into [0, 1], as in the scSE paper
        sse_y = torch.sigmoid(self.sse_conv(x))

        return x * cse_y.expand_as(x) + x * sse_y.expand_as(x)
--------------------------------------------------------------------------------
/self-attention-3D.py:
--------------------------------------------------------------------------------
'''forked from https://github.com/openseg-group/OCNet.pytorch'''
import torch
import torch.nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    '''
    The basic implementation of the self-attention / non-local block.
    Input:
        N x C x Z x Y x X
    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        scale          : downsampling factor for the input feature maps (saves memory)
    Return:
        N x C x Z x Y x X
        position-aware context features (neither concatenated with nor added to the input);
        note that for scale > 1 the output stays at the pooled resolution.
    '''

    def __init__(self, in_channels, key_channels=None, value_channels=None, out_channels=None, scale=1):
        super(SelfAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = in_channels
        if key_channels is None:
            self.key_channels = in_channels // 2
        if value_channels is None:
            self.value_channels = in_channels // 2
        # an int kernel size pools all three spatial dimensions equally
        self.pool = nn.MaxPool3d(kernel_size=scale)
        self.f_key = nn.Sequential(
            nn.Conv3d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm3d(self.key_channels),
        )
        # query and key share the same transform, as in OCNet
        self.f_query = self.f_key
        self.f_value = nn.Conv3d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv3d(in_channels=self.value_channels, out_channels=self.out_channels,
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        batch_size, c, d, h, w = x.size()
        if self.scale > 1:
            x = self.pool(x)

        # flatten the three spatial dimensions into a single axis of length z*y*x
        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)

        # scaled dot-product attention over all positions
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)
        return context


if __name__ == '__main__':
    img = torch.randn(2, 32, 8, 20, 20)
    net = SelfAttentionBlock(in_channels=32)
    out = net(img)
    print(out.size())
--------------------------------------------------------------------------------
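Note: unlike the other files, SE-3D.py and SCSE-3D.py ship without a smoke test. A minimal sketch in the style of the other modules is below; the module names are hypothetical, since the hyphenated filenames would need renaming (e.g. to se_3d.py and scse_3d.py) before they can be imported.

import torch
from se_3d import SELayer      # hypothetical module name after renaming SE-3D.py
from scse_3d import SCSELayer  # hypothetical module name after renaming SCSE-3D.py

img = torch.randn(2, 32, 8, 20, 20)
for net in (SELayer(channel=32, reduction=8), SCSELayer(channel=32, reduction=8)):
    out = net(img)
    assert out.size() == img.size()  # both modules are shape-preserving
    print(out.size())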
/dual-attention-3D.py:
--------------------------------------------------------------------------------
'''forked from https://github.com/junfu1115/DANet'''
import torch
import torch.nn as nn


class PAM_Module(nn.Module):
    """ Position attention module"""

    # Ref from SAGAN
    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim

        self.query_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable residual weight, initialized to 0 so the block starts as an identity
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x Z x Y x X)
        returns :
            out : attention value + input feature
            attention : B x (z*y*x) x (z*y*x)
        """
        m_batchsize, C, depth, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, depth * height * width).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, depth * height * width)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, depth * height * width)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, depth, height, width)

        out = self.gamma * out + x
        return out


class CAM_Module(nn.Module):
    """ Channel attention module"""

    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x Z x Y x X)
        returns :
            out : attention value + input feature
            attention : B x C x C
        """
        m_batchsize, C, depth, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        # subtract each row's max from the row before the softmax, as in the original DANet code
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, depth, height, width)

        out = self.gamma * out + x
        return out


class DANetHead(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(nn.Conv3d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

        self.conv5c = nn.Sequential(nn.Conv3d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module(inter_channels)
        self.conv51 = nn.Sequential(nn.Conv3d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())
        self.conv52 = nn.Sequential(nn.Conv3d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

        # self.conv6 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(inter_channels, out_channels, 1))
        # self.conv7 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(inter_channels, out_channels, 1))

        self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(inter_channels, out_channels, 1))

    def forward(self, x):
        # position-attention branch
        feat1 = self.conv5a(x)
        sa_feat = self.sa(feat1)
        sa_conv = self.conv51(sa_feat)
        # sa_output = self.conv6(sa_conv)

        # channel-attention branch
        feat2 = self.conv5c(x)
        sc_feat = self.sc(feat2)
        sc_conv = self.conv52(sc_feat)
        # sc_output = self.conv7(sc_conv)

        # fuse the two branches by element-wise sum
        feat_sum = sa_conv + sc_conv

        sasc_output = self.conv8(feat_sum)
        return sasc_output


if __name__ == '__main__':
    from Model.nnUnet import MyGroupNorm

    img = torch.randn(2, 32, 8, 20, 20)
    net = DANetHead(in_channels=32, out_channels=32, norm_layer=MyGroupNorm)
    out = net(img)
    print(out.size())
--------------------------------------------------------------------------------
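Note: the smoke test above imports MyGroupNorm from Model.nnUnet, which is not part of this repository. Any norm layer that can be constructed from the channel count alone (e.g. nn.BatchNorm3d) works as norm_layer; a minimal stand-in, assuming that same interface, could be:

import torch.nn as nn

class MyGroupNorm(nn.GroupNorm):
    # hypothetical stand-in: a GroupNorm built from the channel count alone,
    # so it is call-compatible with norm_layer(inter_channels) in DANetHead
    def __init__(self, num_channels, num_groups=8):
        super(MyGroupNorm, self).__init__(num_groups=num_groups, num_channels=num_channels)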
/non-local-3D.py:
--------------------------------------------------------------------------------
'''forked from https://github.com/tea1528/Non-Local-NN-Pytorch/blob/master/3D_experiment/models/non_local.py'''

import torch
from torch import nn
from torch.nn import functional as F


class NLBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, mode='embedded',
                 dimension=3, bn_layer=True):
        """Implementation of the Non-Local Block with four different pairwise functions.
        args:
            in_channels: original channel size (1024 in the paper)
            inter_channels: channel size inside the block; if not specified, reduced to half (512 in the paper)
            mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation
            dimension: can be 1 (temporal), 2 (spatial), or 3 (spatiotemporal)
            bn_layer: whether to add batch norm
        """
        super(NLBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
            raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')

        self.mode = mode
        self.dimension = dimension

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        # the channel size is reduced to half inside the block
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # assign the appropriate convolution, max pool, and batch norm layers for each dimension
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.BatchNorm1d
        # note: max_pool_layer is never applied below; it is retained from the original implementation

        # function g in the paper: a 1x1 convolution
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        # add a BatchNorm layer after the last conv layer
        if bn_layer:
            self.W_z = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                bn(self.in_channels)
            )
            # zero-init so the block starts out as an identity mapping (residual branch outputs 0)
            nn.init.constant_(self.W_z[1].weight, 0)
            nn.init.constant_(self.W_z[1].bias, 0)
        else:
            self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)

        # define theta and phi for all operations except gaussian
        if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

        if self.mode == "concatenate":
            # Conv2d is correct for every `dimension` here: the concatenated pairwise
            # tensor built in forward() is always 4-D, (N, 2 * inter_channels, THW, THW)
            self.W_f = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                nn.ReLU()
            )

    def forward(self, x):
        """
        args:
            x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension=2; (N, C, T) for dimension=1
        """

        batch_size = x.size(0)

        # (N, C, THW) -> (N, THW, C)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        if self.mode == "gaussian":
            theta_x = x.view(batch_size, self.in_channels, -1)
            phi_x = x.view(batch_size, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)

        elif self.mode == "embedded" or self.mode == "dot":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            f = torch.matmul(theta_x, phi_x)

        elif self.mode == "concatenate":
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

            # tile theta and phi against each other to enumerate all position pairs
            h = theta_x.size(2)
            w = phi_x.size(3)
            theta_x = theta_x.repeat(1, 1, 1, w)
            phi_x = phi_x.repeat(1, 1, h, 1)

            concat = torch.cat([theta_x, phi_x], dim=1)
            f = self.W_f(concat)
            f = f.view(f.size(0), f.size(2), f.size(3))

        if self.mode == "gaussian" or self.mode == "embedded":
            f_div_C = F.softmax(f, dim=-1)
        elif self.mode == "dot" or self.mode == "concatenate":
            N = f.size(-1)  # number of positions in x
            f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)

        # contiguous() here just allocates a contiguous chunk of memory
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        W_y = self.W_z(y)
        # residual connection
        z = W_y + x

        return z


if __name__ == '__main__':
    for bn_layer in [True, False]:
        img = torch.zeros(2, 3, 20)
        net = NLBlockND(in_channels=3, mode='concatenate', dimension=1, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 3, 20, 20)
        net = NLBlockND(in_channels=3, mode='concatenate', dimension=2, bn_layer=bn_layer)
        out = net(img)
        print(out.size())

        img = torch.randn(2, 3, 8, 20, 20)
        net = NLBlockND(in_channels=3, mode='concatenate', dimension=3, bn_layer=bn_layer)
        out = net(img)
        print(out.size())
--------------------------------------------------------------------------------
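Note: the built-in test only exercises the 'concatenate' mode. A quick sketch covering the remaining pairwise functions in the 3-D case (toy sizes, chosen arbitrarily):

import torch

img = torch.randn(2, 4, 8, 10, 10)
for mode in ['gaussian', 'embedded', 'dot']:
    net = NLBlockND(in_channels=4, mode=mode, dimension=3)
    out = net(img)
    assert out.size() == img.size()  # the residual connection preserves the input shape
    print(mode, out.size())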