# Repository layout (flattened dump):
#   ├── README.md
#   ├── ULSAM_Poster_WACV2020.pdf
#   ├── ULSAM_Slides_WACV2020.pdf
#   └── ulsam.py
#
# /README.md
#   ULSAM
#   =====
#   ULSAM: Ultra-Lightweight Subspace Attention Module for Compact
#   Convolutional Neural Networks
#   [Full Paper](https://arxiv.org/abs/2006.15102)
#
# /ULSAM_Poster_WACV2020.pdf
#   https://raw.githubusercontent.com/Nandan91/ULSAM/9a07d686b69ff41b6b640791719f6b6889eae84b/ULSAM_Poster_WACV2020.pdf
#
# /ULSAM_Slides_WACV2020.pdf
#   https://raw.githubusercontent.com/Nandan91/ULSAM/9a07d686b69ff41b6b640791719f6b6889eae84b/ULSAM_Slides_WACV2020.pdf
#
# /ulsam.py
import torch
import torch.nn as nn

# The original module switched the process-wide default tensor type to CUDA
# unconditionally, which fails on hosts without a usable GPU.  The switch is
# kept (so behaviour on CUDA hosts is unchanged) but only attempted when CUDA
# is actually available.
# NOTE(review): this is a global side effect of importing the module; callers
# that want explicit device placement should prefer `.to(device)`.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)


class SubSpace(nn.Module):
    """
    Single attention subspace.

    Derives one spatial attention map from the input feature maps and uses
    it to re-weight every input channel, followed by a residual connection.

    Attributes
    ----------
    nin : int
        number of input feature maps handled by this subspace.

    Methods
    -------
    __init__(nin)
        initialize method.
    forward(x)
        forward pass.

    """

    def __init__(self, nin: int) -> None:
        super().__init__()
        # Depthwise 1x1 convolution: groups == channels, so each channel is
        # scaled/shifted independently.
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        # NOTE(review): PyTorch's `momentum` is the weight given to the
        # *current* batch statistic, so 0.9 makes the running estimates follow
        # the latest batch closely.  Left unchanged; confirm it is intended.
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        # 3x3 max-pool with stride 1 / padding 1 keeps the spatial size.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Pointwise convolution collapsing all channels into one map.
        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        # Applied to a (batch, maps, H*W) view, i.e. over spatial positions.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply subspace attention to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            4-D input whose second dimension has ``nin`` channels.

        Returns
        -------
        torch.Tensor
            ``attention * x + x`` with the same shape as ``x``.
        """
        attn = self.relu_dws(self.bn_dws(self.conv_dws(x)))
        attn = self.maxpool(attn)
        attn = self.relu_point(self.bn_point(self.conv_point(attn)))

        # Flatten the spatial dimensions so the softmax normalises over all
        # positions of the single-channel map, then restore the 4-D shape.
        batch, maps, height, width = attn.shape
        attn = self.softmax(attn.view(batch, maps, -1))
        attn = attn.view(batch, maps, height, width)

        # Share the single attention map across every channel of x.
        attn = attn.expand_as(x)

        # Re-weight the input and add the residual connection.
        return torch.mul(attn, x) + x


class ULSAM(nn.Module):
    """
    Grouped Attention Block having multiple (num_splits) Subspaces.

    The input channels are divided into ``num_splits`` equally sized groups;
    each group is processed by its own ``SubSpace`` and the results are
    concatenated back along the channel dimension.

    Attributes
    ----------
    nin : int
        number of input feature maps.

    nout : int
        number of output feature maps (stored only; not used by this module).

    h : int
        height of an input feature map (stored only; not used by this module).

    w : int
        width of an input feature map (stored only; not used by this module).

    num_splits : int
        number of subspaces.

    Methods
    -------
    __init__(nin, nout, h, w, num_splits)
        initialize method.
    forward(x)
        forward pass.

    """

    def __init__(self, nin: int, nout: int, h: int, w: int, num_splits: int) -> None:
        super().__init__()

        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O.
        if num_splits <= 0 or nin % num_splits != 0:
            raise ValueError(
                f"nin ({nin}) must be divisible by a positive "
                f"num_splits ({num_splits})"
            )

        self.nin = nin
        self.nout = nout
        self.h = h
        self.w = w
        self.num_splits = num_splits

        group_size = self.nin // self.num_splits
        self.subspaces = nn.ModuleList(
            [SubSpace(group_size) for _ in range(self.num_splits)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run every channel group through its subspace and re-assemble.

        Parameters
        ----------
        x : torch.Tensor
            4-D input whose second dimension has ``nin`` channels.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``.

        Raises
        ------
        ValueError
            If the channel dimension of ``x`` does not equal ``nin``.
        """
        if x.size(1) != self.nin:
            raise ValueError(
                f"expected input with {self.nin} channels, got {x.size(1)}"
            )

        # Split along the channel dimension (dim=1), one chunk per subspace.
        groups = torch.chunk(x, self.num_splits, dim=1)

        attended = [
            subspace(group) for subspace, group in zip(self.subspaces, groups)
        ]

        return torch.cat(attended, dim=1)


# for debug
# print(ULSAM(64, 64, 112, 112, 4))