├── README.md ├── quan_conv.py ├── vgg_small.py └── resnet18.py /README.md: -------------------------------------------------------------------------------- 1 | # AutoBNN 2 | 3 | This is the implementation of [Searching for Accurate Binary Neural Architectures](http://openaccess.thecvf.com/content_ICCVW_2019/papers/NeurArch/Shen_Searching_for_Accurate_Binary_Neural_Architectures_ICCVW_2019_paper.pdf) 4 | 5 | ## Network 6 | 7 | The implementation of VGG_Small and ResNet18 is in vgg_small.py and resnet18.py 8 | 9 | | Network | Expansion Ratio | 10 | | ---- | ---- | 11 | | VGG-Auto-A | [0.5, 2, 1, 1, 1, 0.5] 12 | | VGG-Auto-B | [2, 2, 4, 2, 4, 0.5] | 13 | | Res18-Auto-A | [2, 4, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4] | 14 | | Res18-Auto-B | [3, 4, 4, 3, 1, 4, 4, 4, 3, 3, 3, 3] | 15 | 16 | 17 | ## Citation 18 | 19 | @inproceedings{shen2019searching, 20 | title={Searching for accurate binary neural architectures}, 21 | author={Shen, Mingzhu and Han, Kai and Xu, Chunjing and Wang, Yunhe}, 22 | booktitle={ICCV Neural Architecture Workshop}, 23 | pages={0--0}, 24 | year={2019} 25 | } 26 | -------------------------------------------------------------------------------- /quan_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | import torch.nn.functional as F 7 | 8 | 9 | class ScaleSigner(Function): 10 | """take a real value x, output sign(x)*E(|x|)""" 11 | @staticmethod 12 | def forward(ctx, input): 13 | return torch.sign(input) * torch.mean(torch.abs(input)) 14 | @staticmethod 15 | def backward(ctx, grad_output): 16 | return grad_output 17 | 18 | def scale_sign(input): 19 | return ScaleSigner.apply(input) 20 | 21 | class Quantizer(Function): 22 | @staticmethod 23 | def forward(ctx, input, nbit): 24 | scale = 2 ** nbit -1 25 | return torch.round(input * scale) / scale 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | return grad_output, None 30 | 31 | def quantize(input, nbit): 32 | return Quantizer.apply(input, nbit) 33 | 34 | def dorefa_w(w, nbit_w): 35 | if nbit_w == 1: 36 | w = scale_sign(w) 37 | else: 38 | w = torch.tanh(w) 39 | w = w / (2 * torch.max(torc.abs(w))) + 0.5 40 | w = 2 * quantize(w, nbit_w) - 1 41 | 42 | return w 43 | 44 | def dorefa_a(input, nbit_a): 45 | return quantize(torch.clamp(input, 0, 1), nbit_a) 46 | 47 | 48 | class QuanConv(nn.Conv2d): 49 | def __init__(self, in_channels, out_channels, kernel_size, quan_name_w='dorefa', quan_name_a='dorefa', nbit_w=1, nbit_a=1, stride=1, padding=0, dilation=1, groups=1, bias=True): 50 | super(QuanConv, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 51 | self.nbit_w=nbit_w 52 | self.nbit_a=nbit_a 53 | name_w_dict={'dorefa':dorefa_w} 54 | name_a_dict={'dorefa':dorefa_a} 55 | self.quan_w = name_w_dict[quan_name_w] 56 | self.quan_a = name_a_dict[quan_name_a] 57 | 58 | def forward(self, input): 59 | if self.nbit_w<32: 60 | w = self.quan_w(self.weight, self.nbit_w) 61 | else: 62 | w = self.weight 63 | 64 | if self.nbit_a<32: 65 | x = self.quan_a(input, self.nbit_a) 66 | else: 67 | x = F.relu(input) 68 | 69 | output = F.conv2d(x, w, None, self.stride, self.padding, self.dilation, self.groups) 70 | return output 71 | -------------------------------------------------------------------------------- /vgg_small.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | from quan_conv import QuanConv as Conv 4 | 5 | 6 | class VGG_Cifar10(nn.Module): 7 | 8 | def __init__(self, self, ratio_code, num_classes=10): 9 | super(VGG_Cifar10, self).__init__() 10 | in_channels = [3, 128, 128, 256, 256, 512] 11 | out_channels = [128, 128, 256, 256, 512, 512] 12 | for i in range(6): 13 | if i != 5: 14 | in_channels[i+1] = int(in_channels[i+1]*ratio_code[i]) 15 | out_channels[i] = int(out_channels[i]*ratio_code[i]) 16 | self.in_planes = int(512*4*4*ratio_code[5]) 17 | self.features = nn.Sequential( 18 | nn.Conv2d(in_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1, 19 | bias=False), 20 | nn.BatchNorm2d(out_channels[0]), 21 | 22 | Conv(in_channels[1], out_channels[1], kernel_size=3, padding=1, bias=False), 23 | nn.MaxPool2d(kernel_size=2, stride=2), 24 | nn.BatchNorm2d(out_channels[1]), 25 | 26 | Conv(in_channels[2], out_channels[2], kernel_size=3, padding=1, bias=False), 27 | nn.BatchNorm2d(out_channels[2]), 28 | 29 | Conv(in_channels[3], out_channels[3], kernel_size=3, padding=1, bias=False), 30 | nn.MaxPool2d(kernel_size=2, stride=2), 31 | nn.BatchNorm2d(out_channels[3]), 32 | 33 | Conv(in_channels[4], out_channels[4], kernel_size=3, padding=1, bias=False), 34 | nn.BatchNorm2d(out_channels[4]), 35 | 36 | Conv(in_channels[5], out_channels[5], kernel_size=3, padding=1, bias=False), 37 | nn.MaxPool2d(kernel_size=2, stride=2), 38 | nn.BatchNorm2d(out_channels[5]), 39 | ) 40 | self.classifier = nn.Sequential( 41 | nn.Linear(self.in_planes, 10, bias=True), 42 | ) 43 | 44 | 45 | def forward(self, x): 46 | x = self.features(x) 47 | x = x.view(-1, self.in_planes) 48 | x = self.classifier(x) 49 | return x 50 | 51 | 52 | def vgg_small(ratio_code, num_classes=10, **kwargs): 53 | return VGG_Cifar10(ratio_code, num_classes) 54 | -------------------------------------------------------------------------------- /resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from quan_conv import QuanConv as Conv 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18'] 7 | 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return Conv(in_planes, out_planes, 3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | def conv1x1(in_planes, out_planes, stride=1): 16 | """1x1 convolution""" 17 | return Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, midplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, midplanes, stride) 26 | self.bn1 = nn.BatchNorm2d(midplanes) 27 | self.conv2 = conv3x3(midplanes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | identity = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | 46 | return out 47 | 48 | 49 | class ResNet(nn.Module): 50 | 51 | def __init__(self, block, ratio_code, layers, num_classes=1000): 52 | super(ResNet, self).__init__() 53 | 54 | self.inplanes = int(64*ratio_code[0]) 55 | 56 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 57 | bias=False) 58 | self.bn1 = nn.BatchNorm2d(self.inplanes) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 61 | self.layer1 = self._make_layer(block, 64, int(64*ratio_code[0]), layers[0], ratio_code, 4) 62 | self.layer2 = self._make_layer(block, 128, int(128*ratio_code[1]), layers[1], ratio_code, 6, stride=2) 63 | self.layer3 = self._make_layer(block, 256, int(256*ratio_code[2]), layers[2], ratio_code, 8, stride=2) 64 | self.layer4 = self._make_layer(block, 512, int(512*ratio_code[3]), layers[3], ratio_code, 10, stride=2) 65 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 66 | self.fc = nn.Linear(int(512*ratio_code[3]) * block.expansion, num_classes) 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 71 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 72 | nn.init.constant_(m.weight, 1) 73 | nn.init.constant_(m.bias, 0) 74 | 75 | def _make_layer(self, block, midplanes, planes, blocks, ratio_code, start, stride=1): 76 | downsample = None 77 | if stride != 1 or self.inplanes != planes: 78 | downsample = nn.Sequential( 79 | conv1x1(self.inplanes, planes, stride=stride, bias=False), 80 | nn.BatchNorm2d(planes * block.expansion), 81 | ) 82 | 83 | j = start 84 | mid_planes = int(midplanes * ratio_code[j]) 85 | layers = [] 86 | layers.append(block(self.inplanes, mid_planes, planes, stride, downsample)) 87 | self.inplanes = planes 88 | for _ in range(1, blocks): 89 | mid_planes = int(midplanes * ratio_code[j+1]) 90 | layers.append(block(self.inplanes, mid_planes, planes)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x): 95 | x = self.conv1(x) 96 | x = self.bn1(x) 97 | x = self.maxpool(x) 98 | 99 | x = self.layer1(x) 100 | x = self.layer2(x) 101 | x = self.layer3(x) 102 | x = self.layer4(x) 103 | x = self.relu(x) 104 | 105 | x = self.avgpool(x) 106 | x = torch.flatten(x, 1) 107 | x = self.fc(x) 108 | 109 | return x 110 | 111 | 112 | def resnet18(ratio_code, **kwargs): 113 | model = ResNet(BasicBlock, ratio_code, [2, 2, 2, 2]) 114 | return model 115 | 116 | 117 | --------------------------------------------------------------------------------