├── README.md ├── dorefa.py └── qnn_mnist.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # DoReFa-Net in PyTorch 4 | 5 | Link to the paper: https://arxiv.org/abs/1606.06160. 6 | 7 | Partial implementation supporting 1-bit weights and k-bit activations only. 8 | 9 | Contains a single example, on MNIST. 10 | 11 | 12 | 13 | ### Use with QNN-MO-PYNQ 14 | 15 | The scripts trains a LeNet run ``python qnn_mnist.py --ab k`` (where k is the activation bits). For use with the Finnthesizer in my [QNN-MO-PYNQ Fork](https://github.com/mohdumar644/QNN-MO-PYNQ), run the ``python qnn_mnist.py --export`` to create a compatible NPZ archive. 16 | 17 | 18 | -------------------------------------------------------------------------------- /dorefa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | import torch.nn as nn 4 | import math 5 | from torch.autograd import Variable 6 | from torch.autograd import Function 7 | 8 | import numpy as np 9 | 10 | 11 | def Binarize(tensor): 12 | E = tensor.abs().mean() 13 | return tensor.sign() * E 14 | 15 | class Clamper(nn.Module): 16 | def __init__(self,minval,maxval): 17 | super(Clamper, self).__init__() 18 | self.minval = minval 19 | self.maxval = maxval 20 | 21 | def forward(self, x): 22 | return x.clamp_(self.minval, self.maxval) 23 | 24 | class Quantizer(nn.Module): 25 | def __init__(self, k): 26 | super(Quantizer, self).__init__() 27 | self.numbits = k 28 | 29 | def forward(self, input): 30 | return Quantize.apply(input, self.numbits) 31 | 32 | class Quantize(torch.autograd.Function): 33 | """ 34 | We can implement our own custom autograd Functions by subclassing 35 | torch.autograd.Function and implementing the forward and backward passes 36 | which operate on Tensors. 37 | """ 38 | 39 | @staticmethod 40 | def forward(ctx, input, k): 41 | """ 42 | In the forward pass we receive a Tensor containing the input and return 43 | a Tensor containing the output. ctx is a context object that can be used 44 | to stash information for backward computation. You can cache arbitrary 45 | objects for use in the backward pass using the ctx.save_for_backward method. 46 | """ 47 | ctx.save_for_backward(input) 48 | n = float(2 ** k - 1) 49 | input = input*n 50 | return input.round() / n 51 | 52 | @staticmethod 53 | def backward(ctx, grad_output): 54 | """ 55 | In the backward pass we receive a Tensor containing the gradient of the loss 56 | with respect to the output, and we need to compute the gradient of the loss 57 | with respect to the input. 58 | """ 59 | #input, = ctx.saved_tensors 60 | grad_input = grad_output.clone() 61 | #grad_input[input < 0] = 0 62 | return grad_input, None 63 | 64 | 65 | 66 | 67 | class BinarizeLinear(nn.Linear): 68 | 69 | def __init__(self, *kargs, **kwargs): 70 | super(BinarizeLinear, self).__init__(*kargs, **kwargs) 71 | 72 | def forward(self, input): 73 | 74 | if not hasattr(self.weight,'org'): 75 | self.weight.org=self.weight.data.clone() 76 | self.weight.data=Binarize(self.weight.org) 77 | 78 | out = nn.functional.linear(input, self.weight) 79 | 80 | 81 | if not self.bias is None: 82 | self.bias.org=self.bias.data.clone() 83 | out += self.bias.view(1, -1).expand_as(out) 84 | 85 | return out 86 | 87 | class BinarizeConv2d(nn.Conv2d): 88 | 89 | def __init__(self, *kargs, **kwargs): 90 | super(BinarizeConv2d, self).__init__(*kargs, **kwargs) 91 | 92 | 93 | def forward(self, input): 94 | 95 | if not hasattr(self.weight,'org'): 96 | self.weight.org=self.weight.data.clone() 97 | self.weight.data=Binarize(self.weight.org) 98 | 99 | out = nn.functional.conv2d(input, self.weight, None, self.stride, 100 | self.padding, self.dilation, self.groups) 101 | 102 | if not self.bias is None: 103 | self.bias.org=self.bias.data.clone() 104 | out += self.bias.view(1, -1, 1, 1).expand_as(out) 105 | 106 | return out 107 | -------------------------------------------------------------------------------- /qnn_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from torch.autograd import Variable 10 | from dorefa import * 11 | # from tensorboardX import SummaryWriter 12 | 13 | # Training settings 14 | parser = argparse.ArgumentParser(description='PyTorch QNN-MO-PYNQ MNIST Example') 15 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 16 | help='input batch size for training (default: 128)') 17 | parser.add_argument('--test-batch-size', type=int, default=128, metavar='N', 18 | help='input batch size for testing (default: 128)') 19 | parser.add_argument('--epochs', type=int, default=1000, metavar='N', 20 | help='number of epochs to train (default: 10000)') 21 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 22 | help='learning rate (default: 0.01)') 23 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 24 | help='SGD momentum (default: 0.5)') 25 | parser.add_argument('--no-cuda', action='store_true', default=False, 26 | help='disables CUDA training') 27 | parser.add_argument('--seed', type=int, default=1, metavar='S', 28 | help='random seed (default: 1)') 29 | parser.add_argument('--gpus', default=1, 30 | help='gpus used for training - e.g 0,1,3') 31 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 32 | help='how many batches to wait before logging training status') 33 | 34 | parser.add_argument('--resume', default=False, action='store_true', help='Perform only evaluation on val dataset.') 35 | parser.add_argument('--ab', type=int, default=2, metavar='N', help='number of bits for activations (default: 2)') 36 | parser.add_argument('--eval', default=False, action='store_true', help='perform evaluation of trained model') 37 | parser.add_argument('--export', default=False, action='store_true', help='perform weights export as npz of trained model') 38 | args = parser.parse_args() 39 | args.cuda = not args.no_cuda and torch.cuda.is_available() 40 | prev_acc = 0 41 | save_path = 'results/mnist-w1a{}.pt'.format(args.ab) 42 | class Net(nn.Module): 43 | def __init__(self): 44 | super(Net, self).__init__() 45 | 46 | self.features = nn.Sequential( 47 | BinarizeConv2d(1, 64, kernel_size=3, stride=1, padding=1,bias=True), 48 | nn.MaxPool2d(2, stride=2, padding=0), 49 | Clamper(0, 1), 50 | Quantizer(args.ab), 51 | 52 | BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=1,bias=False), 53 | nn.BatchNorm2d(64,momentum=0.9,eps=1e-4), 54 | Clamper(0, 1), 55 | Quantizer(args.ab), 56 | nn.MaxPool2d(2, stride=2, padding=0), 57 | 58 | 59 | BinarizeConv2d(64, 64, kernel_size=3, stride=1, padding=0,bias=False), 60 | nn.BatchNorm2d(64,momentum=0.9,eps=1e-4), 61 | Clamper(0, 1), 62 | Quantizer(args.ab)) 63 | 64 | self.classifier = nn.Sequential( 65 | BinarizeLinear(64*5*5, 512, bias=True), 66 | nn.Linear(512, 10), 67 | nn.LogSoftmax()) 68 | 69 | def forward(self, x): 70 | x = x.view(-1, 1,28,28) 71 | 72 | x = self.features(x) 73 | 74 | x = x.permute((0,2,3,1)) 75 | x = x.contiguous() 76 | x = x.view(-1, 64*5*5) 77 | 78 | x = self.classifier(x) 79 | 80 | return x 81 | 82 | def export(self): 83 | import numpy as np 84 | dic = {} 85 | i = 0 86 | j = 0 87 | # process conv and BN layers 88 | for k in range(len(self.features)): 89 | if hasattr(self.features[k], 'weight') and not hasattr(self.features[k], 'running_mean'): 90 | dic['conv'+str(i)+'/W:0'] = np.transpose(self.features[k].weight.detach().numpy(),(2,3,1,0)) 91 | if self.features[k].bias is not None: 92 | dic['conv'+str(i)+'/b:0'] = np.transpose(self.features[k].bias.detach().numpy()) 93 | i = i + 1 94 | elif hasattr(self.features[k], 'running_mean'): 95 | dic['bn'+str(j)+'/beta:0'] = self.features[k].bias.detach().numpy() 96 | dic['bn'+str(j)+'/gamma:0'] = self.features[k].weight.detach().numpy() 97 | dic['bn'+str(j)+'/mean/EMA:0'] = self.features[k].running_mean.detach().numpy() 98 | dic['bn'+str(j)+'/variance/EMA:0'] = self.features[k].running_var.detach().numpy() 99 | j = j + 1 100 | i = 0 101 | j = 0 102 | # process linear and BN layers 103 | for k in range(len(self.classifier)): 104 | if hasattr(self.classifier[k], 'weight') and not hasattr(self.classifier[k], 'running_mean'): 105 | dic['fc'+str(i)+'/W:0'] = np.transpose(self.classifier[k].weight.detach().numpy()) 106 | if self.classifier[k].bias is not None: 107 | dic['fc'+str(i)+'/b:0'] = self.classifier[k].bias.detach().numpy() 108 | i = i + 1 109 | elif hasattr(self.classifier[k], 'running_mean'): 110 | dic['bn'+str(j)+'/beta:0'] = self.classifier[k].bias.detach().numpy() 111 | dic['bn'+str(j)+'/gamma:0'] = self.classifier[k].weight.detach().numpy() 112 | dic['bn'+str(j)+'/mean/EMA:0'] = self.classifier[k].running_mean.detach().numpy() 113 | dic['bn'+str(j)+'/variance/EMA:0'] = self.classifier[k].running_var.detach().numpy() 114 | j = j + 1 115 | 116 | save_file = 'results/mnist-w1a{}.npz'.format(args.ab) 117 | np.savez(save_file, **dic) 118 | print("Model exported at: ", save_file) 119 | 120 | def train(epoch): 121 | model.train() 122 | for batch_idx, (data, target) in enumerate(train_loader): 123 | if args.cuda: 124 | data, target = data.cuda(), target.cuda() 125 | data, target = Variable(data), Variable(target) 126 | optimizer.zero_grad() 127 | output = model(data) 128 | loss = criterion(output, target) 129 | 130 | if epoch%40==0: 131 | optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1 132 | 133 | optimizer.zero_grad() 134 | loss.backward() 135 | for p in list(model.parameters()): 136 | if hasattr(p,'org'): 137 | p.data.copy_(p.org) 138 | optimizer.step() 139 | for p in list(model.parameters()): 140 | if hasattr(p,'org'): 141 | p.org.copy_(p.data) 142 | 143 | if batch_idx % args.log_interval == 0: 144 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 145 | epoch, batch_idx * len(data), len(train_loader.dataset), 146 | 100. * batch_idx / len(train_loader), loss.data)) 147 | 148 | def test(save_model=False): 149 | model.eval() 150 | test_loss = 0 151 | correct = 0 152 | global prev_acc 153 | with torch.no_grad(): 154 | for data, target in test_loader: 155 | if args.cuda: 156 | data, target = data.cuda(), target.cuda() 157 | output = model(data) 158 | test_loss += criterion(output, target).data # sum up batch loss 159 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 160 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 161 | 162 | test_loss /= len(test_loader.dataset) 163 | new_acc = 100. * correct.float() / len(test_loader.dataset) 164 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 165 | test_loss, correct, len(test_loader.dataset), new_acc)) 166 | if new_acc > prev_acc: 167 | # save model 168 | if save_model: 169 | torch.save(model, save_path) 170 | print("Model saved at: ", save_path, "\n") 171 | prev_acc = new_acc 172 | 173 | if __name__ == '__main__': 174 | torch.manual_seed(args.seed) 175 | if args.cuda: 176 | torch.cuda.manual_seed(args.seed) 177 | 178 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 179 | train_loader = torch.utils.data.DataLoader( 180 | datasets.MNIST('data', train=True, download=True, 181 | transform=transforms.Compose([ 182 | transforms.ToTensor() 183 | ])), 184 | batch_size=args.batch_size, shuffle=True, **kwargs) 185 | test_loader = torch.utils.data.DataLoader( 186 | datasets.MNIST('data', train=False, transform=transforms.Compose([ 187 | transforms.ToTensor() 188 | ])), 189 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 190 | 191 | model = Net() 192 | if args.cuda: 193 | torch.cuda.set_device(0) 194 | print(torch.cuda.get_device_name(0)) 195 | model.cuda() 196 | dummy_input = Variable(torch.rand(1, 1, 28, 28)).cuda() 197 | else: 198 | dummy_input = Variable(torch.rand(1, 1, 28, 28)) 199 | 200 | # with SummaryWriter(comment='Net1') as w: 201 | # w.add_graph(model, (dummy_input, ), verbose=True) 202 | 203 | criterion = nn.CrossEntropyLoss() 204 | # test model 205 | if args.eval: 206 | model = torch.load(save_path) 207 | test() 208 | # export npz 209 | elif args.export: 210 | model = torch.load(save_path, map_location = 'cpu') 211 | model.export() 212 | # train model 213 | else: 214 | if args.resume: 215 | model = torch.load(save_path) 216 | test() 217 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 218 | for epoch in range(1, args.epochs + 1): 219 | train(epoch) 220 | test(save_model=True) 221 | 222 | --------------------------------------------------------------------------------