├── LICENSE
├── README.md
├── evaluate_packed.py
├── main.py
├── requirements.txt
└── wrn_mcdonnell.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Sergey Zagoruyko

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1-bit Wide ResNet
===========

PyTorch implementation of training 1-bit Wide ResNets from this paper:

*Training wide residual networks for deployment using a single bit for each weight* by **Mark D. McDonnell** at ICLR 2018

The idea is very simple but surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization as a PyTorch autograd function:

```python
class ForwardSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        return g
```

On the forward pass we take the sign of the weights and scale it by the He-init constant; on the backward pass we propagate the gradient unchanged. WRN-20-10 trained with this parameterization ends up only slightly behind its full-precision variant. Here is what I got myself with this code on CIFAR-100:

| network | accuracy (5 runs, mean +- std) | checkpoint size |
|:---|:---:|:---:|
| WRN-20-10 | 80.5 +- 0.24 | 205 Mb |
| WRN-20-10-1bit | 80.0 +- 0.26 | 3.5 Mb |

## Details

Here are the differences from the original WRN code:

* BatchNorm has no affine weight and bias parameters
* First layer has 16 * width channels
* Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
* Downsample is done by F.avg_pool2d + torch.cat instead of strided conv (see the short sketch below)
* SGD with cosine annealing and warm restarts

I used PyTorch 0.4.1 and Python 3.6 to run the code.
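For illustration, here is what the downsample shortcut mentioned above does. The snippet is not part of the repository; it just mirrors the shortcut branch of `DownsampleBlock` in `wrn_mcdonnell.py`:

```python
import torch
import torch.nn.functional as F

# Shortcut branch of DownsampleBlock: halve the spatial resolution with average
# pooling, then double the channels by concatenating zeros; no weights involved.
x = torch.randn(8, 160, 32, 32)                            # e.g. a group-0 feature map of WRN-20-10
x_d = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)  # -> (8, 160, 16, 16)
x_d = torch.cat([x_d, torch.zeros_like(x_d)], dim=1)       # -> (8, 320, 16, 16)
```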
Reproduce WRN-20-10 with 1-bit training on CIFAR-100:

```bash
python main.py --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100
```

Convergence plot (train error dashed):

I've also put a 3.5 Mb checkpoint with binary weights packed with `np.packbits`, and a very short script to evaluate it:

```bash
python evaluate_packed.py --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100
```

S3 url to checkpoint:
--------------------------------------------------------------------------------
/evaluate_packed.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torchnet.meter import ClassErrorMeter
from wrn_mcdonnell import WRN_McDonnell
from main import create_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Binary Wide Residual Networks')
    # Model options
    parser.add_argument('--depth', default=20, type=int)
    parser.add_argument('--width', default=1, type=float)
    parser.add_argument('--dataset', default='CIFAR10', type=str)
    parser.add_argument('--dataroot', default='.', type=str)
    parser.add_argument('--checkpoint', required=True, type=str)
    return parser.parse_args()


def main():
    args = parse_args()
    num_classes = 10 if args.dataset == 'CIFAR10' else 100

    have_cuda = torch.cuda.is_available()
    def cast(x):
        return x.cuda() if have_cuda else x

    checkpoint = torch.load(args.checkpoint, map_location='cpu')

    weights_unpacked = {}
    for k, w in checkpoint.items():
        if w.dtype == torch.uint8:
            # weights are packed with the np.packbits function along dim 1,
            # so the original channel count is w.shape[1] * 8
            scale = np.sqrt(2 / (w.shape[1] * w.shape[2] * w.shape[3] * 8))
            signed = np.unpackbits(w.numpy(), axis=1).astype(np.int64) * 2 - 1
            # k[7:] strips the 'module.' prefix added by DataParallel
            weights_unpacked[k[7:]] = torch.from_numpy(signed).float() * scale
        else:
            weights_unpacked[k[7:]] = w

    model = WRN_McDonnell(args.depth, args.width, num_classes)
    model.load_state_dict(weights_unpacked)
    model = cast(model)
    model.eval()

    class_acc = ClassErrorMeter(accuracy=True)

    for inputs, targets in tqdm(DataLoader(create_dataset(args, train=False), 256)):
        with torch.no_grad():
            class_acc.add(model(cast(inputs)).cpu(), targets)

    print(class_acc.value())


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
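The packing script itself is not included in the repository. Below is a minimal sketch of how a trained float checkpoint could be turned into the uint8 format that `evaluate_packed.py` expects; the function name and the `float_keys` default are illustrative assumptions, the only firm requirement is that sign bits are packed with `np.packbits` along dim 1:

```python
# Hypothetical packing sketch: one sign bit per binarized conv weight,
# packed 8 weights per byte along dim 1; float tensors pass through unchanged.
import numpy as np
import torch

def pack_checkpoint(state_dict, float_keys=('conv0', 'conv_last')):
    packed = {}
    for k, w in state_dict.items():
        name = k.replace('module.', '')  # drop the DataParallel prefix for the check
        if w.dim() == 4 and name not in float_keys:
            bits = (w.sign().cpu().numpy() > 0).astype(np.uint8)  # 1 for positive, 0 for negative
            packed[k] = torch.from_numpy(np.packbits(bits, axis=1))
        else:
            # first/last convolutions and BatchNorm buffers stay in float
            packed[k] = w
    return packed
```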
/main.py:
--------------------------------------------------------------------------------
"""Implementation of 1-bit Wide Residual Networks

Implements ICLR 2018 paper:
"Training wide residual networks for deployment using a single bit for each weight"
by Mark D. McDonnell

2018 Sergey Zagoruyko
"""
from pathlib import Path
import argparse
import json
import numpy as np
from tqdm import tqdm
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.utils.data
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
from torch.nn import DataParallel
from torch.backends import cudnn
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchnet as tnt
from wrn_mcdonnell import WRN_McDonnell

cudnn.benchmark = True

def parse_args():
    parser = argparse.ArgumentParser(description='Binary Wide Residual Networks')
    # Model options
    parser.add_argument('--depth', default=20, type=int)
    parser.add_argument('--width', default=1, type=float)
    parser.add_argument('--dataset', default='CIFAR10', type=str)
    parser.add_argument('--dataroot', default='.', type=str)
    parser.add_argument('--nthread', default=4, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--binarize', action='store_true')

    # Training options
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--lr-min', default=0.0001, type=float)
    parser.add_argument('--epochs', default=256, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--weight_decay', default=0.0005, type=float)
    parser.add_argument('--restarts', default='[2,4,8,16,32,64,128]', type=json.loads,
                        help='json list with epochs to restart the cosine schedule on')
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--save', default='', type=str,
                        help='save parameters and logs in this folder')
    return parser.parse_args()


def create_dataset(args, train):
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0,
                    np.array([63.0, 62.1, 66.7]) / 255.0),
    ])
    if train:
        transform = T.Compose([
            T.Pad(4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
            T.RandomCrop(32),
            transform
        ])
    return getattr(datasets, args.dataset)(args.dataroot, train=train, download=True, transform=transform)


def main():
    args = parse_args()
    print('parsed options:', vars(args))

    have_cuda = torch.cuda.is_available()
    def cast(x):
        return x.cuda() if have_cuda else x

    torch.manual_seed(args.seed)

    num_classes = 10 if args.dataset == 'CIFAR10' else 100

    def create_iterator(mode):
        return DataLoader(create_dataset(args, mode), args.batch_size, shuffle=mode,
                          num_workers=args.nthread, pin_memory=torch.cuda.is_available())

    train_loader = create_iterator(True)
    test_loader = create_iterator(False)

    model = WRN_McDonnell(args.depth, args.width, num_classes, args.binarize)
    model = cast(DataParallel(model))

    n_parameters = sum(p.numel() for p in model.parameters())

    optimizer = SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=args.lr_min)

    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']

    if not Path(args.save).exists():
        Path(args.save).mkdir()

    def log(log_data):
        torch.save({'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': log_data['epoch'],
                    }, Path(args.save) / 'checkpoint.pth.tar')
        z = {**vars(args), **log_data}
        with open(Path(args.save) / 'log.txt', 'a') as f:
            f.write(json.dumps(z) + '\n')
        print(z)

    def train():
        model.train()
        meter_loss = tnt.meter.AverageValueMeter()
        classacc = tnt.meter.ClassErrorMeter(accuracy=True)
        train_iterator = tqdm(train_loader, dynamic_ncols=True)
        for x, y in train_iterator:
            optimizer.zero_grad()
            outputs = model(cast(x))
            loss = cross_entropy(outputs, cast(y))
            loss.backward()
            optimizer.step()
            meter_loss.add(loss.item())
            train_iterator.set_postfix(loss=loss.item())
            classacc.add(outputs.data.cpu(), y.cpu())
        return meter_loss.mean, classacc.value()[0]

    def test():
        model.eval()
        meter_loss = tnt.meter.AverageValueMeter()
        classacc = tnt.meter.ClassErrorMeter(accuracy=True)
        test_iterator = tqdm(test_loader, dynamic_ncols=True)
        # evaluation only: no gradients are needed here
        with torch.no_grad():
            for x, y in test_iterator:
                outputs = model(cast(x))
                loss = cross_entropy(outputs, cast(y))
                meter_loss.add(loss.item())
                classacc.add(outputs.data.cpu(), y.cpu())
        return meter_loss.mean, classacc.value()[0]

    for epoch in range(start_epoch, args.epochs):
        scheduler.step()
        if epoch in args.restarts:
            # warm restart: begin a new cosine cycle whose length equals the current epoch
            scheduler = CosineAnnealingLR(optimizer, T_max=epoch, eta_min=args.lr_min)
        train_loss, train_acc = train()
        test_loss, test_acc = test()
        log_data = {
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "epoch": epoch,
            "num_classes": num_classes,
            "n_parameters": n_parameters,
            "lr": scheduler.get_lr(),
        }
        log(log_data)
        print('==> id: %s (%d/%d), test_acc: \033[91m%.2f\033[0m' %
              (args.save, epoch, args.epochs, test_acc))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
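One detail of main.py worth spelling out: at every epoch listed in `--restarts` a new cosine cycle is started with `T_max` set to that epoch number, so the cycle lengths come out as 2, 2, 4, 8, ..., 128 and together cover the full 256 epochs. A quick sanity check (illustrative snippet, not a file from the repository):

```python
# Cycle lengths implied by --restarts '[2,4,8,16,32,64,128]' and --epochs 256;
# each value equals the T_max of the cosine cycle that starts at that boundary.
restarts, epochs = [2, 4, 8, 16, 32, 64, 128], 256
boundaries = [0] + restarts + [epochs]
print([b - a for a, b in zip(boundaries, boundaries[1:])])  # [2, 2, 4, 8, 16, 32, 64, 128]
```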
/requirements.txt:
--------------------------------------------------------------------------------
torch >= 0.4.1
torchnet
tqdm
torchvision
numpy
--------------------------------------------------------------------------------
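main.py appends one JSON object per epoch to `log.txt` inside the `--save` folder. A small sketch (not a file from the repository; the path below is just an example) for pulling the best test accuracy out of such a log:

```python
# Read the JSON-lines training log written by main.py and report the best epoch.
import json
from pathlib import Path

log_path = Path('logs/WRN-20-10-1bit_12345/log.txt')  # example path, matching --save above
records = [json.loads(line) for line in log_path.read_text().splitlines()]
best = max(records, key=lambda r: r['test_acc'])
print('best test_acc %.2f at epoch %d' % (best['test_acc'], best['epoch']))
```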
/wrn_mcdonnell.py:
--------------------------------------------------------------------------------
from collections import OrderedDict
import math
import torch
from torch import nn
import torch.nn.functional as F


def init_weight(*args):
    return nn.Parameter(nn.init.kaiming_normal_(torch.zeros(*args), mode='fan_out', nonlinearity='relu'))


class ForwardSign(torch.autograd.Function):
    """Fake sign op for 1-bit weights.

    See eq. (1) in https://arxiv.org/abs/1802.08530

    Does He-init like forward, and nothing on backward.
    """

    @staticmethod
    def forward(ctx, x):
        return math.sqrt(2. / (x.shape[1] * x.shape[2] * x.shape[3])) * x.sign()

    @staticmethod
    def backward(ctx, g):
        return g


class ModuleBinarizable(nn.Module):

    def __init__(self, binarize=False):
        super().__init__()
        self.binarize = binarize

    def _get_weight(self, name):
        w = getattr(self, name)
        return ForwardSign.apply(w) if self.binarize else w

    def forward(self):
        pass


class Block(ModuleBinarizable):
    """Pre-activated ResNet block.
    """

    def __init__(self, width, binarize=False):
        super().__init__(binarize)
        self.bn0 = nn.BatchNorm2d(width, affine=False)
        self.register_parameter('conv0', init_weight(width, width, 3, 3))
        self.bn1 = nn.BatchNorm2d(width, affine=False)
        self.register_parameter('conv1', init_weight(width, width, 3, 3))

    def forward(self, x):
        h = F.conv2d(F.relu(self.bn0(x)), self._get_weight('conv0'), padding=1)
        h = F.conv2d(F.relu(self.bn1(h)), self._get_weight('conv1'), padding=1)
        return x + h


class DownsampleBlock(ModuleBinarizable):
    """Downsample block.

    Does F.avg_pool2d + torch.cat instead of strided conv.
    """

    def __init__(self, width, binarize=False):
        super().__init__(binarize)
        self.bn0 = nn.BatchNorm2d(width // 2, affine=False)
        self.register_parameter('conv0', init_weight(width, width // 2, 3, 3))
        self.bn1 = nn.BatchNorm2d(width, affine=False)
        self.register_parameter('conv1', init_weight(width, width, 3, 3))

    def forward(self, x):
        h = F.conv2d(F.relu(self.bn0(x)), self._get_weight('conv0'), padding=1, stride=2)
        h = F.conv2d(F.relu(self.bn1(h)), self._get_weight('conv1'), padding=1)
        x_d = F.avg_pool2d(x, kernel_size=3, padding=1, stride=2)
        x_d = torch.cat([x_d, torch.zeros_like(x_d)], dim=1)
        return x_d + h


class WRN_McDonnell(ModuleBinarizable):
    """Implementation of modified Wide Residual Network.

    Differences with pre-activated ResNet and Wide ResNet:
    * BatchNorm has no affine weight and bias parameters
    * First layer has 16 * width channels
    * Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
    * Downsample is done by F.avg_pool2d + torch.cat instead of strided conv

    First and last convolutional layers are kept in float32.
    """

    def __init__(self, depth, width, num_classes, binarize=False):
        super().__init__()
        self.binarize = binarize
        widths = [int(v * width) for v in (16, 32, 64)]
        n = (depth - 2) // 6

        self.register_parameter('conv0', init_weight(widths[0], 3, 3, 3))

        self.group0 = self._make_block(widths[0], n)
        self.group1 = self._make_block(widths[1], n, downsample=True)
        self.group2 = self._make_block(widths[2], n, downsample=True)

        self.bn = nn.BatchNorm2d(widths[2], affine=False)
        self.register_parameter('conv_last', init_weight(num_classes, widths[2], 1, 1))
        self.bn_last = nn.BatchNorm2d(num_classes)

    def _make_block(self, width, n, downsample=False):
        def select_block(j):
            if downsample and j == 0:
                return DownsampleBlock(width, self.binarize)
            return Block(width, self.binarize)
        return nn.Sequential(OrderedDict(('block%d' % i, select_block(i))
                                         for i in range(n)))

    def forward(self, x):
        h = F.conv2d(x, self.conv0, padding=1)
        h = self.group0(h)
        h = self.group1(h)
        h = self.group2(h)
        h = F.relu(self.bn(h))
        h = F.conv2d(h, self.conv_last)
        h = self.bn_last(h)
        return F.avg_pool2d(h, kernel_size=h.shape[-2:]).view(h.shape[0], -1)
--------------------------------------------------------------------------------
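A quick way to sanity-check `WRN_McDonnell` and to relate its parameter count to the checkpoint sizes quoted in the README (illustrative snippet, not a file from the repository; the 205 Mb float checkpoint is bigger than 4 bytes per weight because main.py also saves optimizer and scheduler state):

```python
# Smoke test: build WRN-20-10 for CIFAR-100, run a forward pass, count weights.
import torch
from wrn_mcdonnell import WRN_McDonnell

model = WRN_McDonnell(depth=20, width=10, num_classes=100, binarize=True)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)                                            # torch.Size([2, 100])

n_weights = sum(p.numel() for p in model.parameters())
print('%.1fM parameters' % (n_weights / 1e6))               # roughly 27M for WRN-20-10
print('%.1f MB at 1 bit/weight' % (n_weights / 8 / 2**20))  # in the ballpark of the 3.5 Mb packed file
```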