├── README.md
├── figures
│   ├── SKConv.png
│   ├── SKNet50.png
│   └── empty
└── sknet.py

/README.md:
--------------------------------------------------------------------------------
# SKNet-Pytorch
Nearly Perfect & Easily Understandable PyTorch Implementation of [SKNet (Selective Kernel Networks)](https://arxiv.org/abs/1903.06586)

I reimplemented SKNet in PyTorch. Although there are several PyTorch implementations of SKNet, they differ from the implementation described in the original paper, and I found them hard to understand. So I wrote my own.

## Selective Kernel Convolution

![SKConv](./figures/SKConv.png)

Refer to [this part](https://github.com/developer0hye/SKNet-PyTorch/blob/4e299e61a9acba35704112078746348150bf4dd4/sknet.py#L7-L59) for the implementation of Selective Kernel Convolution.

# References
- Paper: [Selective Kernel Networks](https://arxiv.org/abs/1903.06586)
- Paper: [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/abs/1611.05431)
- Repository: [pppLang/SKNet](https://github.com/pppLang/SKNet)
- Repository: [ResearchingDexter/SKNet_pytorch](https://github.com/ResearchingDexter/SKNet_pytorch)
- Repository: [bearpaw/pytorch-classification](https://github.com/bearpaw/pytorch-classification/blob/master/models/imagenet/resnext.py)

# To Do List
- Experiment on CIFAR-100 with ResNet-18

--------------------------------------------------------------------------------
/figures/SKConv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/developer0hye/SKNet-PyTorch/f78ea0c8e9394a77e58b33b228a4209186a6ad9e/figures/SKConv.png
--------------------------------------------------------------------------------
/figures/SKNet50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/developer0hye/SKNet-PyTorch/f78ea0c8e9394a77e58b33b228a4209186a6ad9e/figures/SKNet50.png
--------------------------------------------------------------------------------
/figures/empty:
--------------------------------------------------------------------------------
empty
--------------------------------------------------------------------------------
/sknet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

# from thop import profile
# from thop import clever_format

class SKConv(nn.Module):
    def __init__(self, features, M=2, G=32, r=16, stride=1, L=32):
        """ Constructor
        Args:
            features: input channel dimensionality.
            M: the number of branches.
            G: number of convolution groups.
            r: the reduction ratio for computing d, the length of z.
            stride: stride, default 1.
            L: the minimum dimension of the vector z in the paper, default 32.
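
        Example (an illustrative sketch; the input shape below is an assumption,
        not something fixed by this module):
            >>> sk = SKConv(64)                    # features=64, M=2 branches by default
            >>> y = sk(torch.rand(2, 64, 32, 32))  # input of shape (N, C, H, W) with C == features
            >>> tuple(y.shape)                     # spatial size is preserved when stride=1
            (2, 64, 32, 32)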
17 | """ 18 | super(SKConv, self).__init__() 19 | d = max(int(features/r), L) 20 | self.M = M 21 | self.features = features 22 | self.convs = nn.ModuleList([]) 23 | for i in range(M): 24 | self.convs.append(nn.Sequential( 25 | nn.Conv2d(features, features, kernel_size=3, stride=stride, padding=1+i, dilation=1+i, groups=G, bias=False), 26 | nn.BatchNorm2d(features), 27 | nn.ReLU(inplace=True) 28 | )) 29 | self.gap = nn.AdaptiveAvgPool2d((1,1)) 30 | self.fc = nn.Sequential(nn.Conv2d(features, d, kernel_size=1, stride=1, bias=False), 31 | nn.BatchNorm2d(d), 32 | nn.ReLU(inplace=True)) 33 | self.fcs = nn.ModuleList([]) 34 | for i in range(M): 35 | self.fcs.append( 36 | nn.Conv2d(d, features, kernel_size=1, stride=1) 37 | ) 38 | self.softmax = nn.Softmax(dim=1) 39 | 40 | def forward(self, x): 41 | 42 | batch_size = x.shape[0] 43 | 44 | feats = [conv(x) for conv in self.convs] 45 | feats = torch.cat(feats, dim=1) 46 | feats = feats.view(batch_size, self.M, self.features, feats.shape[2], feats.shape[3]) 47 | 48 | feats_U = torch.sum(feats, dim=1) 49 | feats_S = self.gap(feats_U) 50 | feats_Z = self.fc(feats_S) 51 | 52 | attention_vectors = [fc(feats_Z) for fc in self.fcs] 53 | attention_vectors = torch.cat(attention_vectors, dim=1) 54 | attention_vectors = attention_vectors.view(batch_size, self.M, self.features, 1, 1) 55 | attention_vectors = self.softmax(attention_vectors) 56 | 57 | feats_V = torch.sum(feats*attention_vectors, dim=1) 58 | 59 | return feats_V 60 | 61 | 62 | class SKUnit(nn.Module): 63 | def __init__(self, in_features, mid_features, out_features, M=2, G=32, r=16, stride=1, L=32): 64 | """ Constructor 65 | Args: 66 | in_features: input channel dimensionality. 67 | out_features: output channel dimensionality. 68 | M: the number of branchs. 69 | G: num of convolution groups. 70 | r: the ratio for compute d, the length of z. 71 | mid_features: the channle dim of the middle conv with stride not 1, default out_features/2. 72 | stride: stride. 73 | L: the minimum dim of the vector z in paper. 
74 | """ 75 | super(SKUnit, self).__init__() 76 | 77 | self.conv1 = nn.Sequential( 78 | nn.Conv2d(in_features, mid_features, 1, stride=1, bias=False), 79 | nn.BatchNorm2d(mid_features), 80 | nn.ReLU(inplace=True) 81 | ) 82 | 83 | self.conv2_sk = SKConv(mid_features, M=M, G=G, r=r, stride=stride, L=L) 84 | 85 | self.conv3 = nn.Sequential( 86 | nn.Conv2d(mid_features, out_features, 1, stride=1, bias=False), 87 | nn.BatchNorm2d(out_features) 88 | ) 89 | 90 | 91 | if in_features == out_features: # when dim not change, input_features could be added diectly to out 92 | self.shortcut = nn.Sequential() 93 | else: # when dim not change, input_features should also change dim to be added to out 94 | self.shortcut = nn.Sequential( 95 | nn.Conv2d(in_features, out_features, 1, stride=stride, bias=False), 96 | nn.BatchNorm2d(out_features) 97 | ) 98 | 99 | self.relu = nn.ReLU(inplace=True) 100 | 101 | def forward(self, x): 102 | residual = x 103 | 104 | out = self.conv1(x) 105 | out = self.conv2_sk(out) 106 | out = self.conv3(out) 107 | 108 | return self.relu(out + self.shortcut(residual)) 109 | 110 | class SKNet(nn.Module): 111 | def __init__(self, class_num, nums_block_list = [3, 4, 6, 3], strides_list = [1, 2, 2, 2]): 112 | super(SKNet, self).__init__() 113 | self.basic_conv = nn.Sequential( 114 | nn.Conv2d(3, 64, 7, 2, 3, bias=False), 115 | nn.BatchNorm2d(64), 116 | nn.ReLU(inplace=True), 117 | ) 118 | 119 | self.maxpool = nn.MaxPool2d(3,2,1) 120 | 121 | self.stage_1 = self._make_layer(64, 128, 256, nums_block=nums_block_list[0], stride=strides_list[0]) 122 | self.stage_2 = self._make_layer(256, 256, 512, nums_block=nums_block_list[1], stride=strides_list[1]) 123 | self.stage_3 = self._make_layer(512, 512, 1024, nums_block=nums_block_list[2], stride=strides_list[2]) 124 | self.stage_4 = self._make_layer(1024, 1024, 2048, nums_block=nums_block_list[3], stride=strides_list[3]) 125 | 126 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 127 | self.classifier = nn.Linear(2048, class_num) 128 | 129 | def _make_layer(self, in_feats, mid_feats, out_feats, nums_block, stride=1): 130 | layers=[SKUnit(in_feats, mid_feats, out_feats, stride=stride)] 131 | for _ in range(1,nums_block): 132 | layers.append(SKUnit(out_feats, mid_feats, out_feats)) 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | fea = self.basic_conv(x) 137 | fea = self.maxpool(fea) 138 | fea = self.stage_1(fea) 139 | fea = self.stage_2(fea) 140 | fea = self.stage_3(fea) 141 | fea = self.stage_4(fea) 142 | fea = self.gap(fea) 143 | fea = torch.squeeze(fea) 144 | fea = self.classifier(fea) 145 | return fea 146 | 147 | def SKNet26(nums_class=1000): 148 | return SKNet(nums_class, [2, 2, 2, 2]) 149 | def SKNet50(nums_class=1000): 150 | return SKNet(nums_class, [3, 4, 6, 3]) 151 | def SKNet101(nums_class=1000): 152 | return SKNet(nums_class, [3, 4, 23, 3]) 153 | 154 | if __name__=='__main__': 155 | x = torch.rand(8, 3, 224, 224) 156 | model = SKNet26() 157 | out = model(x) 158 | 159 | #flops, params = profile(model, (x, )) 160 | #flops, params = clever_format([flops, params], "%.5f") 161 | 162 | #print(flops, params) 163 | #print('out shape : {}'.format(out.shape)) 164 | 165 | --------------------------------------------------------------------------------