├── README.md
└── DoubleAttentionLayer.py

/README.md:
--------------------------------------------------------------------------------
# DoubleAttentionNet
A PyTorch implementation of the Double Attention block proposed in ***A2-Nets: Double Attention Networks***.

# Requirements

- Python >= 3.0
- PyTorch >= 0.3

# Usage

```
from DoubleAttentionLayer import DoubleAttentionLayer

doubleA = DoubleAttentionLayer(in_channels, c_m, c_n)
```

The layer takes a feature map of shape `(batch, in_channels, h, w)` and returns a tensor of shape `(batch, c_m, h, w)`.

--------------------------------------------------------------------------------
/DoubleAttentionLayer.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
from torch.autograd import Variable  # only required on PyTorch 0.3; a no-op on newer versions


class DoubleAttentionLayer(nn.Module):
    def __init__(self, in_channels, c_m, c_n, k=1):
        super(DoubleAttentionLayer, self).__init__()

        self.K = k
        self.c_m = c_m
        self.c_n = c_n
        self.softmax = nn.Softmax(dim=1)
        self.in_channels = in_channels

        # 1x1 convolutions producing the feature maps A, the attention maps B
        # and the attention vectors V of the double-attention block
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)

    def forward(self, x):

        b, c, h, w = x.size()

        assert c == self.in_channels, 'input channels do not match in_channels'
        assert b % self.K == 0, 'batch size must be divisible by k'

        A = self.convA(x)
        B = self.convB(x)
        V = self.convV(x)

        batch = b // self.K

        tmpA = A.view(batch, self.K, self.c_m, h * w).permute(0, 2, 1, 3).contiguous().view(batch, self.c_m, self.K * h * w)
        tmpB = B.view(batch, self.K, self.c_n, h * w).permute(0, 2, 1, 3).contiguous().view(batch * self.c_n, self.K * h * w)
        tmpV = V.view(batch, self.K, self.c_n, h * w).permute(0, 1, 3, 2).contiguous().view(b * h * w, self.c_n)

        # attention maps: softmax over the K*h*w spatial positions
        softmaxB = self.softmax(tmpB).view(batch, self.c_n, self.K * h * w).permute(0, 2, 1)  # batch, K*h*w, c_n
        # attention vectors: softmax over the c_n channels
        softmaxV = self.softmax(tmpV).view(batch, self.K * h * w, self.c_n).permute(0, 2, 1)  # batch, c_n, K*h*w

        # step 1: feature gathering into global descriptors
        tmpG = tmpA.matmul(softmaxB)  # batch, c_m, c_n
        # step 2: feature distribution back to every position
        tmpZ = tmpG.matmul(softmaxV)  # batch, c_m, K*h*w
        tmpZ = tmpZ.view(batch, self.c_m, self.K, h * w).permute(0, 2, 1, 3).contiguous().view(b, self.c_m, h, w)

        return tmpZ


if __name__ == "__main__":

    in_channels = 10
    c_m = 4
    c_n = 3

    doubleA = DoubleAttentionLayer(in_channels, c_m, c_n)

    x = torch.ones(2, in_channels, 6, 8)
    x = Variable(x)
    tmp = doubleA(x)

    print("result size:", tmp.size())  # expected: (2, c_m, 6, 8)
--------------------------------------------------------------------------------
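
As a usage note beyond the files above, the following is a minimal sketch (not part of the original repository) of one common way to drop the layer into an existing backbone as a residual unit. The `A2Block` wrapper name, the `c_m = c_n = in_channels // 4` defaults, and the 1x1 projection back to `in_channels` are illustrative assumptions, not settings taken from the paper.

```
import torch
from torch import nn
from DoubleAttentionLayer import DoubleAttentionLayer


class A2Block(nn.Module):
    """Illustrative wrapper (assumption, not from the repo): runs the
    double-attention layer and projects its c_m output channels back to
    in_channels so the block can be used as a residual unit."""

    def __init__(self, in_channels, c_m=None, c_n=None):
        super().__init__()
        # hypothetical defaults; pick c_m/c_n to fit your memory budget
        c_m = c_m or in_channels // 4
        c_n = c_n or in_channels // 4
        self.attention = DoubleAttentionLayer(in_channels, c_m, c_n)
        self.project = nn.Conv2d(c_m, in_channels, kernel_size=1)

    def forward(self, x):
        # residual connection around the double-attention block
        return x + self.project(self.attention(x))


if __name__ == "__main__":
    block = A2Block(in_channels=64)
    x = torch.randn(2, 64, 16, 16)
    print(block(x).shape)  # expected: torch.Size([2, 64, 16, 16])
```

Because the output shape matches the input shape, a block like this can be inserted between stages of a CNN without changing any other layer; this mirrors how attention blocks are typically spliced into backbones, though the exact placement used in the paper is not reproduced here.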