├── .gitignore ├── Cifar10 ├── main.py └── models │ ├── Bin_VGG.py │ └── VGG.py ├── LICENSE ├── MNIST ├── main.py └── models │ ├── Bin_LeNet.py │ ├── LeNet.py │ └── __init__.py ├── README.md ├── csrc └── binop │ ├── Makefile │ ├── build.py │ ├── include │ ├── binop.h │ ├── binop_cuda.h │ ├── binop_cuda_kernel.h │ ├── libpopcnt.h │ └── matmul.h │ ├── make.sh │ └── src │ ├── binop.c │ ├── binop_cuda.c │ └── binop_cuda_kernel.cu └── util ├── __init__.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Cifar10/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import os 9 | import sys 10 | import itertools 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | tqdm.monitor_interval = 0 15 | from sklearn.metrics import classification_report 16 | from sklearn.metrics import confusion_matrix 17 | import models as models 18 | sys.path.append('../') 19 | import binop 20 | from util import binop_train 21 | from util import bin_save_state 22 | 23 | class RunningMean: 24 | def __init__(self, value=0, count=0): 25 | self.total_value = value 26 | self.count = count 27 | 28 | def update(self, value, count=1): 29 | self.total_value += value 30 | self.count += count 31 | 32 | @property 33 | def value(self): 34 | if self.count: 35 | return self.total_value / self.count 36 | else: 37 | return float("inf") 38 | 39 | def __str__(self): 40 | return str(self.value) 41 | 42 | def plot_confusion_matrix(cm, 
classes, normalize=False, 43 | title='Confusion matrix', cmap=plt.cm.Blues): 44 | """ 45 | This function prints and plots the confusion matrix. 46 | Normalization can be applied by setting `normalize=True`. 47 | """ 48 | 49 | if normalize: 50 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 51 | print("Normalized confusion matrix") 52 | else: 53 | print('Confusion matrix') 54 | 55 | print(cm) 56 | 57 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 58 | plt.title(title) 59 | plt.colorbar() 60 | tick_marks = np.arange(len(classes)) 61 | plt.xticks(tick_marks, classes, rotation=45) 62 | plt.yticks(tick_marks, classes) 63 | 64 | fmt = '.2f' if normalize else 'd' 65 | thresh = cm.max() / 2. 66 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 67 | plt.text(j, i, format(cm[i, j], fmt), 68 | horizontalalignment="center", 69 | color="white" if cm[i, j] > thresh else "black") 70 | 71 | plt.tight_layout() 72 | plt.ylabel('True label') 73 | plt.xlabel('Predicted label') 74 | 75 | def eval_test(y_pred, y_true): 76 | cnf_matrix = confusion_matrix(y_true, y_pred) 77 | np.set_printoptions(precision=6) 78 | 79 | # Plot non-normalized confusion matrix 80 | plt.figure() 81 | plot_confusion_matrix(cnf_matrix, classes=target_name, 82 | title='Confusion matrix') 83 | 84 | # Plot normalized confusion matrix 85 | plt.figure() 86 | plot_confusion_matrix(cnf_matrix, classes=target_name, normalize=True, 87 | title='Normalized confusion matrix') 88 | 89 | print(classification_report(y_true=y_true, y_pred=y_pred, target_names=target_name, digits=6)) 90 | plt.show() 91 | 92 | 93 | def save_state(model, mname=""): 94 | print('==> Saving model ...') 95 | torch.save(model.state_dict(), 'models/' + args.arch + mname +'.pth') 96 | 97 | def train_bin(epoch): 98 | running_loss = RunningMean() 99 | running_score = RunningMean() 100 | model_train.train() 101 | pbar = tqdm(train_loader, total=len(train_loader)) 102 | 103 | for data, target in pbar: 104 | batch_size = data.size(0) 105 | if args.cuda: 106 | data, target = data.cuda(), target.cuda() 107 | data, target = Variable(data), Variable(target) 108 | 109 | optimizer.zero_grad() 110 | binop_train.binarization() 111 | 112 | output = model_train(data) 113 | _, preds = torch.max(output.data, dim=1) 114 | loss = criterion(output, target) 115 | running_loss.update(loss.data[0], 1) 116 | running_score.update(torch.sum(preds != target.data), batch_size) 117 | loss.backward() 118 | 119 | # restore weights 120 | binop_train.restore() 121 | # update 122 | binop_train.updateBinaryGradWeight() 123 | 124 | optimizer.step() 125 | pbar.set_description('%.6f %.6f' % (running_loss.value, running_score.value)) 126 | print('[+] epoch %d: \nTraining: Average loss: %.6f, Average error: %.6f' % ( 127 | epoch, running_loss.value, running_score.value)) 128 | bin_save_state(args, model_train) 129 | 130 | def train(epoch): 131 | running_loss = RunningMean() 132 | running_score = RunningMean() 133 | model_ori.train() 134 | pbar = tqdm(train_loader, total=len(train_loader),) 135 | for data, target in pbar: 136 | batch_size = data.size(0) 137 | if args.cuda: 138 | data, target = data.cuda(), target.cuda() 139 | data, target = Variable(data), Variable(target) 140 | 141 | optimizer.zero_grad() 142 | output = model_ori(data) 143 | _, preds = torch.max(output.data, dim=1) 144 | loss = criterion(output, target) 145 | running_loss.update(loss.data[0], 1) 146 | running_score.update(torch.sum(preds != target.data), batch_size) 147 | loss.backward() 148 | 149 | 
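# Note the contrast with train_bin() above: this plain loop steps directly on the
# full-precision weights. In train_bin(), binop_train.binarization() runs before the
# forward pass so the loss is computed with binarized weights, and
# binop_train.restore() / binop_train.updateBinaryGradWeight() run after
# loss.backward() so the gradient update is applied to the stored real-valued
# weights, following the XNOR-Net training procedure cited in the README.
# binop_train itself comes from util/ and is not shown in this section.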
optimizer.step() 150 | 151 | pbar.set_description('%.6f %.6f' % (running_loss.value, running_score.value)) 152 | print('[+] epoch %d: \nTraining: Average loss: %.6f, Average error: %.6f' % ( 153 | epoch, running_loss.value, running_score.value)) 154 | save_state(model_ori) 155 | 156 | # def test_train(model): 157 | # test_loss = 0 158 | # correct = 0 159 | # model.eval() 160 | # binop_train.binarization() 161 | # pbar = tqdm(test_loader, total=len(test_loader)) 162 | # for data, target in pbar: 163 | # if args.cuda: 164 | # data, target = data.cuda(), target.cuda() 165 | # data, target = Variable(data), Variable(target) 166 | # output = model(data) 167 | # test_loss += criterion(output, target).data[0] 168 | # pred = output.data.max(1, keepdim=False)[1] 169 | # correct += pred.eq(target.data).cpu().sum() 170 | # acc = 100. * correct / len(test_loader.dataset) 171 | # test_loss /= len(test_loader.dataset) 172 | # print('\nTrain Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 173 | # test_loss * args.batch_size, correct, len(test_loader.dataset), 174 | # 100. * correct / len(test_loader.dataset))) 175 | # print('Best Accuracy: {:.2f}%\n'.format(best_acc)) 176 | # binop_train.restore() 177 | 178 | def test(model, evaluate=False): 179 | global best_acc 180 | test_loss = 0 181 | correct = 0 182 | model.eval() 183 | if evaluate: 184 | model.load_state_dict(torch.load(args.pretrained)) 185 | else: 186 | model.load_state_dict(torch.load('models/' + args.arch + '.pth')) 187 | pbar = tqdm(test_loader, total=len(test_loader)) 188 | if evaluate: 189 | pred = torch.LongTensor() 190 | true = torch.LongTensor() 191 | if args.cuda: 192 | pred = pred.cuda() 193 | true = true.cuda() 194 | for data, target in pbar: 195 | if args.cuda: 196 | data, target = data.cuda(), target.cuda() 197 | 198 | data, target = Variable(data), Variable(target) 199 | output = model(data) 200 | test_loss += criterion(output, target).data[0] 201 | if evaluate: 202 | pred = torch.cat((pred, output.data.max(1, keepdim=False)[1])) 203 | true = torch.cat((true, target.data)) 204 | else: 205 | pred = output.data.max(1, keepdim=False)[1] 206 | correct += pred.eq(target.data).cpu().sum() 207 | if evaluate: 208 | correct = pred.eq(true).cpu().sum() 209 | acc = 100. * correct / len(test_loader.dataset) 210 | test_loss /= len(test_loader.dataset) 211 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 212 | test_loss * args.batch_size, correct, len(test_loader.dataset), 213 | 100. 
* correct / len(test_loader.dataset))) 214 | 215 | if not evaluate: 216 | if (acc > best_acc): 217 | best_acc = acc 218 | os.rename('models/' + args.arch + '.pth', 'models/' + args.arch + '.best.pth') 219 | if model_train is not None: 220 | save_state(model_train, "bak") 221 | else: 222 | os.remove('models/' + args.arch + '.pth') 223 | print('Best Accuracy: {:.2f}%\n'.format(best_acc)) 224 | else: 225 | eval_test(y_pred=pred.tolist(), y_true=true.tolist()) 226 | 227 | 228 | 229 | def adjust_learning_rate(optimizer, epoch): 230 | """Sets the learning rate to the initial LR decayed by 10 every 15 epochs""" 231 | update_list = [150, 200, 250] 232 | if epoch in update_list: 233 | for param_group in optimizer.param_groups: 234 | param_group['lr'] = param_group['lr'] * 0.1 235 | # lr = args.lr * (0.1 ** (epoch // 50)) 236 | # for param_group in optimizer.param_groups: 237 | # param_group['lr'] = lr 238 | print('Learning Rate: {}'.format(optimizer.param_groups[0]['lr'])) 239 | 240 | if __name__ == '__main__': 241 | # Training settings 242 | 243 | parser = argparse.ArgumentParser(description='PyTorch Cifar-10') 244 | 245 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 246 | help='input batch size for training (default: 128)') 247 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 248 | help='input batch size for testing (default: 128)') 249 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 250 | help='number of epochs to train (default: 100)') 251 | parser.add_argument('--lr-epochs', type=int, default=100, metavar='N', 252 | help='number of epochs to decay the lr (default: 20)') 253 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 254 | help='learning rate (default: 0.01)') 255 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 256 | metavar='W', help='weight decay (default: 5e-4)') 257 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 258 | help='SGD momentum (default: 0.9)') 259 | parser.add_argument('--seed', type=int, default=1, metavar='S', 260 | help='random seed (default: 1)') 261 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 262 | help='how many batches to wait before logging training status') 263 | parser.add_argument('--arch', action='store', default='Bin_VGG16', 264 | help='the MNIST network structure: Bin_VGG19') 265 | parser.add_argument('--pretrained', action='store', default=None, 266 | help='pretrained model') 267 | parser.add_argument('--evaluate', action='store_true', default=False, 268 | help='whether to run evaluation') 269 | parser.add_argument('--no_cuda', action='store_true', default=False, 270 | help='disables CUDA training') 271 | args = parser.parse_args() 272 | args.cuda = not args.no_cuda and torch.cuda.is_available() 273 | target_name = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 274 | print(args) 275 | cnt = 0 276 | lr_cnt = 0 277 | 278 | if args.cuda: 279 | torch.cuda.manual_seed(args.seed) 280 | else: 281 | torch.manual_seed(args.seed) 282 | 283 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 284 | std=[0.229, 0.224, 0.225]) 285 | # load data 286 | kwargs = {'num_workers': 1, 'pin_memory': True} 287 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 288 | std=[0.229, 0.224, 0.225]) 289 | train_loader = torch.utils.data.DataLoader( 290 | datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ 291 | 
transforms.RandomHorizontalFlip(), 292 | transforms.RandomCrop(32, 4), 293 | transforms.ToTensor(), 294 | normalize, 295 | ]), download=True),batch_size=args.batch_size, shuffle=True, **kwargs) 296 | test_loader = torch.utils.data.DataLoader( 297 | datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([ 298 | transforms.ToTensor(), 299 | normalize, 300 | ])), 301 | batch_size=args.batch_size, shuffle=False, **kwargs) 302 | 303 | best_acc = 0.0 304 | model_ori = None 305 | model_train = None 306 | model_test = None 307 | 308 | # generate the model 309 | model_names = ['RESNET18','NIN', 'VGG11', 'VGG13', 'VGG16', 'VGG19'] 310 | args.arch = args.arch.upper() 311 | if '_' in args.arch: 312 | _, name = args.arch.split('_') 313 | else: 314 | name = args.arch 315 | 316 | if name in model_names: 317 | if 'BIN' in args.arch: 318 | if 'VGG' in name: 319 | model_train = models.Bin_VGG_train(name) 320 | model_test = models.Bin_VGG_test(name) 321 | elif 'NIN' in name: 322 | model_train = models.Bin_NIN_train() 323 | model_test = models.Bin_NIN_test() 324 | elif 'RESNET18' in name: 325 | pass 326 | if args.cuda: 327 | model_train.cuda() 328 | model_test.cuda() 329 | 330 | if args.pretrained: 331 | if args.evaluate: 332 | model_test.load_state_dict(torch.load(args.pretrained)) 333 | else: 334 | model_train.load_state_dict(torch.load(args.pretrained)) 335 | binop_train = binop_train(model_train) 336 | else: 337 | binop_train = binop_train(model_train) 338 | 339 | else: 340 | if 'VGG' in name: 341 | model_ori = models.VGG(name) 342 | elif 'NIN' in name: 343 | model_ori = models.NIN() 344 | elif "RESNET18" in name: 345 | pass 346 | if args.cuda: 347 | model_ori.cuda() 348 | 349 | if args.pretrained: 350 | model_ori.load_state_dict(torch.load(args.pretrained)) 351 | 352 | else: 353 | print('ERROR: specified arch is not suppported') 354 | exit() 355 | 356 | param_dict = dict(model_train.named_parameters()) if model_ori is None else dict(model_ori.named_parameters()) 357 | params = [] 358 | 359 | for key, value in param_dict.items(): 360 | if value.requires_grad: 361 | params += [{ 362 | 'params': [value], 363 | 'lr': args.lr, 364 | 'key': key 365 | }] 366 | 367 | optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum) 368 | 369 | #optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 370 | #optimizer = optim.Adam(params, lr=args.lr) 371 | 372 | 373 | criterion = nn.CrossEntropyLoss() 374 | if args.cuda: 375 | criterion.cuda() 376 | 377 | 378 | if args.evaluate: 379 | if model_ori is None: 380 | test(model_test,evaluate=True) 381 | else: 382 | test(model_ori, evaluate=True) 383 | exit() 384 | 385 | for epoch in range(1, args.epochs + 1): 386 | adjust_learning_rate(optimizer, epoch) 387 | if model_ori is None: 388 | train_bin(epoch) 389 | test(model_test) 390 | 391 | else: 392 | train(epoch) 393 | test(model_ori) 394 | 395 | -------------------------------------------------------------------------------- /Cifar10/models/Bin_VGG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | import sys 8 | sys.path.append("..") 9 | from util import BinLinear 10 | from util import BinConv2d 11 | 12 | cfg = { 13 | 'VGG11': ['M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 14 | 'VGG13': [64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 
'M'], 15 | 'VGG16': [64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 16 | 'VGG19': [64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 17 | } 18 | 19 | class Bin_VGG_train(nn.Module): 20 | def __init__(self, vgg_name): 21 | super(Bin_VGG_train, self).__init__() 22 | self.features = self._make_layers(cfg[vgg_name]) 23 | self.classifier = nn.Linear(512, 10) 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 27 | m.weight.data.normal_(0, math.sqrt(2./n)) 28 | m.bias.data.zero_() 29 | 30 | def forward(self, x): 31 | out = self.features(x) 32 | out = out.view(out.size(0), -1) 33 | out = self.classifier(out) 34 | return out 35 | 36 | def _make_layers(self, cfg): 37 | layers = OrderedDict([ 38 | ('conv0', nn.Conv2d(3, 64, kernel_size=3, padding=1)), 39 | ('bn0', nn.BatchNorm2d(64)), 40 | ('relu0', nn.ReLU(inplace=True)) 41 | ]) 42 | in_channels = 64 43 | cnt = 1 44 | for x in cfg: 45 | if x == 'M': 46 | layers['pool'+str(cnt)] = nn.MaxPool2d(kernel_size=2, stride=2) 47 | cnt += 1 48 | else: 49 | layers['conv'+str(cnt)] = BinConv2d(in_channels=in_channels, out_channels=x, kernel_size=3, padding=1, istrain=True) 50 | cnt += 1 51 | layers['bn'+str(cnt)] = nn.BatchNorm2d(x) 52 | cnt += 1 53 | layers['relu'+str(cnt)] = nn.ReLU(inplace=True) 54 | cnt += 1 55 | in_channels = x 56 | layers['pool'+str(cnt)] = nn.AvgPool2d(kernel_size=1, stride=1) 57 | return nn.Sequential(layers) 58 | 59 | 60 | class Bin_VGG_test(nn.Module): 61 | def __init__(self, vgg_name): 62 | super(Bin_VGG_test, self).__init__() 63 | self.features = self._make_layers(cfg[vgg_name]) 64 | self.classifier = nn.Linear(512, 10) 65 | 66 | def forward(self, x): 67 | out = self.features(x) 68 | out = out.view(out.size(0), -1) 69 | out = self.classifier(out) 70 | return out 71 | 72 | def _make_layers(self, cfg): 73 | layers = OrderedDict([ 74 | ('conv0', nn.Conv2d(3, 64, kernel_size=3, padding=1)), 75 | ('bn0', nn.BatchNorm2d(64)), 76 | ('relu0', nn.ReLU(inplace=True)) 77 | ]) 78 | in_channels = 64 79 | cnt = 1 80 | for x in cfg: 81 | if x == 'M': 82 | layers['pool'+str(cnt)] = nn.MaxPool2d(kernel_size=2, stride=2) 83 | cnt += 1 84 | else: 85 | layers['conv'+str(cnt)] = BinConv2d(in_channels=in_channels, out_channels=x, kernel_size=3, padding=1, istrain=False) 86 | cnt += 1 87 | layers['bn'+str(cnt)] = nn.BatchNorm2d(x) 88 | cnt += 1 89 | layers['relu'+str(cnt)] = nn.ReLU(inplace=True) 90 | cnt += 1 91 | in_channels = x 92 | layers['pool'+str(cnt)] = nn.AvgPool2d(kernel_size=1, stride=1) 93 | return nn.Sequential(layers) 94 | 95 | 96 | 97 | 98 | 99 | class NIN_train(nn.Module): 100 | def __init__(self): 101 | super(NIN_train, self).__init__() 102 | self.conv1 = nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2) 103 | self.bn1 = nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False) 104 | self.conv2 = BinConv2d(192, 160, kernel_size=1, stride=1, padding=0, istrain=True) 105 | self.conv3 = BinConv2d(160, 96, kernel_size=1, stride=1, padding=0, istrain=True) 106 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.conv4 = BinConv2d(96, 192, kernel_size=5, stride=1, padding=2, istrain=True, drop=0.5) 108 | self.conv5 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=True) 109 | self.conv6 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=True) 110 | self.pool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 111 | self.conv7 = 
BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, istrain=True, drop=0.5) 112 | self.conv8 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=True) 113 | self.bn2 = nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False) 114 | self.conv9 = nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0) 115 | self.pool3 = nn.AvgPool2d(kernel_size=8, stride=1, padding=0) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | m.weight.data.normal_(0, 0.05) 120 | m.bias.data.zero_() 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = F.relu(x,inplace=True) 125 | x = self.conv2(x) 126 | x = F.relu(x, inplace=True) 127 | x = self.conv3(x) 128 | x = F.relu(x, inplace=True) 129 | x = self.pool1(x) 130 | x = self.conv4(x) 131 | x = F.relu(x, inplace=True) 132 | x = self.conv5(x) 133 | x = F.relu(x, inplace=True) 134 | x = self.conv6(x) 135 | x = F.relu(x, inplace=True) 136 | x = self.pool2(x) 137 | x = self.conv7(x) 138 | x = F.relu(x, inplace=True) 139 | x = self.conv8(x) 140 | x = F.relu(x, inplace=True) 141 | x = self.bn2(x) 142 | x = self.conv9(x) 143 | x = F.relu(x, inplace=True) 144 | x = self.pool3(x) 145 | return x.view(x.size(0), 10) 146 | class NIN_test(nn.Module): 147 | def __init__(self): 148 | super(NIN_test, self).__init__() 149 | self.conv1 = nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2) 150 | self.bn1 = nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False) 151 | self.conv2 = BinConv2d(192, 160, kernel_size=1, stride=1, padding=0, istrain=False) 152 | self.conv3 = BinConv2d(160, 96, kernel_size=1, stride=1, padding=0, istrain=False) 153 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 154 | self.conv4 = BinConv2d(96, 192, kernel_size=5, stride=1, padding=2, istrain=False, drop=0.5) 155 | self.conv5 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=False) 156 | self.conv6 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=False) 157 | self.pool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 158 | self.conv7 = BinConv2d(192, 192, kernel_size=3, stride=1, padding=1, istrain=False, drop=0.5) 159 | self.conv8 = BinConv2d(192, 192, kernel_size=1, stride=1, padding=0, istrain=False) 160 | self.bn2 = nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=False) 161 | self.conv9 = nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0) 162 | self.pool3 = nn.AvgPool2d(kernel_size=8, stride=1, padding=0) 163 | def forward(self, x): 164 | x = self.conv1(x) 165 | x = self.bn1(x) 166 | x = F.relu(x, inplace=True) 167 | x = self.conv2(x) 168 | x = F.relu(x, inplace=True) 169 | x = self.conv3(x) 170 | x = F.relu(x, inplace=True) 171 | x = self.pool1(x) 172 | x = self.conv4(x) 173 | x = F.relu(x, inplace=True) 174 | x = self.conv5(x) 175 | x = F.relu(x, inplace=True) 176 | x = self.conv6(x) 177 | x = F.relu(x, inplace=True) 178 | x = self.pool2(x) 179 | x = self.conv7(x) 180 | x = F.relu(x, inplace=True) 181 | x = self.conv8(x) 182 | x = F.relu(x, inplace=True) 183 | x = self.bn2(x) 184 | x = self.conv9(x) 185 | x = F.relu(x, inplace=True) 186 | x = self.pool3(x) 187 | return x.view(x.size(0), 10) 188 | 189 | -------------------------------------------------------------------------------- /Cifar10/models/VGG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | cfg = { 6 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 9 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 10 | } 11 | 12 | 13 | class VGG(nn.Module): 14 | def __init__(self, vgg_name): 15 | super(VGG, self).__init__() 16 | self.features = self._make_layers(cfg[vgg_name]) 17 | self.classifier = nn.Linear(512, 10) 18 | 19 | def forward(self, x): 20 | out = self.features(x) 21 | out = out.view(out.size(0), -1) 22 | out = self.classifier(out) 23 | return out 24 | 25 | def _make_layers(self, cfg): 26 | layers = [] 27 | in_channels = 3 28 | for x in cfg: 29 | if x == 'M': 30 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 31 | else: 32 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(x), 34 | nn.ReLU(inplace=True)] 35 | in_channels = x 36 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 37 | return nn.Sequential(*layers) 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /MNIST/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import os 9 | import sys 10 | import itertools 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | tqdm.monitor_interval = 0 15 | from sklearn.metrics import classification_report 16 | from sklearn.metrics import confusion_matrix 17 | 18 | import models as models 19 | sys.path.append('../') 20 | import binop 21 | from util import binop_train 22 | from util import bin_save_state 23 | 24 | class RunningMean: 25 | def __init__(self, value=0, count=0): 26 | self.total_value = value 27 | self.count = count 28 | 29 | def update(self, value, count=1): 30 | self.total_value += value 31 | self.count += count 32 | 33 | @property 34 | def value(self): 35 | if self.count: 36 | return self.total_value / self.count 37 | else: 38 | return float("inf") 39 | 40 | def __str__(self): 41 | return str(self.value) 42 | 43 | def plot_confusion_matrix(cm, classes, normalize=False, 44 | title='Confusion matrix', cmap=plt.cm.Blues): 45 | """ 46 | This function prints and plots the confusion matrix. 47 | Normalization can be applied by setting `normalize=True`. 48 | """ 49 | 50 | if normalize: 51 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 52 | print("Normalized confusion matrix") 53 | else: 54 | print('Confusion matrix, without normalization') 55 | 56 | print(cm) 57 | 58 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 59 | plt.title(title) 60 | plt.colorbar() 61 | tick_marks = np.arange(len(classes)) 62 | plt.xticks(tick_marks, classes, rotation=45) 63 | plt.yticks(tick_marks, classes) 64 | 65 | fmt = '.2f' if normalize else 'd' 66 | thresh = cm.max() / 2. 
67 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 68 | plt.text(j, i, format(cm[i, j], fmt), 69 | horizontalalignment="center", 70 | color="white" if cm[i, j] > thresh else "black") 71 | 72 | plt.tight_layout() 73 | plt.ylabel('True label') 74 | plt.xlabel('Predicted label') 75 | 76 | def eval_test(y_pred, y_true): 77 | cnf_matrix = confusion_matrix(y_true, y_pred) 78 | np.set_printoptions(precision=6) 79 | 80 | # Plot non-normalized confusion matrix 81 | plt.figure() 82 | plot_confusion_matrix(cnf_matrix, classes=target_name, 83 | title='Confusion matrix, without normalization') 84 | 85 | # Plot normalized confusion matrix 86 | plt.figure() 87 | plot_confusion_matrix(cnf_matrix, classes=target_name, normalize=True, 88 | title='Normalized confusion matrix') 89 | 90 | print(classification_report(y_true=y_true, y_pred=y_pred, target_names=target_name, digits=6)) 91 | plt.show() 92 | 93 | 94 | def save_state(model): 95 | print('==> Saving model ...') 96 | torch.save(model.state_dict(), 'models/' + args.arch + '.pth') 97 | 98 | def train_bin(epoch): 99 | running_loss = RunningMean() 100 | running_score = RunningMean() 101 | model_train.train() 102 | pbar = tqdm(train_loader, total=len(train_loader)) 103 | 104 | for data, target in pbar: 105 | batch_size = data.size(0) 106 | if args.cuda: 107 | data, target = data.cuda(), target.cuda() 108 | data, target = Variable(data), Variable(target) 109 | 110 | optimizer.zero_grad() 111 | binop_train.binarization() 112 | 113 | output = model_train(data) 114 | _, preds = torch.max(output.data, dim=1) 115 | loss = criterion(output, target) 116 | running_loss.update(loss.data[0], 1) 117 | running_score.update(torch.sum(preds != target.data), batch_size) 118 | loss.backward() 119 | 120 | # restore weights 121 | binop_train.restore() 122 | # update 123 | binop_train.updateBinaryGradWeight() 124 | 125 | optimizer.step() 126 | pbar.set_description('%.6f %.6f' % (running_loss.value, running_score.value)) 127 | print('[+] epoch %d: \nTraining: Average loss: %.6f, Average error: %.6f' % ( 128 | epoch, running_loss.value, running_score.value)) 129 | bin_save_state(args, model_train) 130 | 131 | def train(epoch): 132 | running_loss = RunningMean() 133 | running_score = RunningMean() 134 | model_ori.train() 135 | pbar = tqdm(train_loader, total=len(train_loader),) 136 | for data, target in pbar: 137 | batch_size = data.size(0) 138 | if args.cuda: 139 | data, target = data.cuda(), target.cuda() 140 | data, target = Variable(data), Variable(target) 141 | 142 | optimizer.zero_grad() 143 | output = model_ori(data) 144 | _, preds = torch.max(output.data, dim=1) 145 | loss = criterion(output, target) 146 | running_loss.update(loss.data[0], 1) 147 | running_score.update(torch.sum(preds != target.data), batch_size) 148 | loss.backward() 149 | 150 | optimizer.step() 151 | 152 | pbar.set_description('%.6f %.6f' % (running_loss.value, running_score.value)) 153 | print('[+] epoch %d: \nTraining: Average loss: %.6f, Average error: %.6f' % ( 154 | epoch, running_loss.value, running_score.value)) 155 | save_state(model_ori) 156 | 157 | def test(model, evaluate=False): 158 | global best_acc 159 | test_loss = 0 160 | correct = 0 161 | if evaluate: 162 | model.load_state_dict(torch.load(args.pretrained)) 163 | else: 164 | model.load_state_dict(torch.load('models/' + args.arch + '.pth')) 165 | model.eval() 166 | pbar = tqdm(test_loader, total=len(test_loader)) 167 | if evaluate: 168 | pred = torch.LongTensor() 169 | true = torch.LongTensor() 170 | if args.cuda: 
171 | pred = pred.cuda() 172 | true = true.cuda() 173 | for data, target in pbar: 174 | if args.cuda: 175 | data, target = data.cuda(), target.cuda() 176 | 177 | data, target = Variable(data), Variable(target) 178 | output = model(data) 179 | test_loss += criterion(output, target).data[0] 180 | if evaluate: 181 | pred = torch.cat((pred, output.data.max(1, keepdim=False)[1])) 182 | true = torch.cat((true, target.data)) 183 | else: 184 | pred = output.data.max(1, keepdim=False)[1] 185 | correct += pred.eq(target.data).cpu().sum() 186 | if evaluate: 187 | correct = pred.eq(true).cpu().sum() 188 | acc = 100. * correct / len(test_loader.dataset) 189 | test_loss /= len(test_loader.dataset) 190 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 191 | test_loss * args.batch_size, correct, len(test_loader.dataset), 192 | 100. * correct / len(test_loader.dataset))) 193 | 194 | if not evaluate: 195 | if (acc > best_acc): 196 | best_acc = acc 197 | os.rename('models/' + args.arch + '.pth', 'models/' + args.arch + '.best.pth') 198 | else: 199 | os.remove('models/' + args.arch + '.pth') 200 | print('Best Accuracy: {:.2f}%\n'.format(best_acc)) 201 | else: 202 | eval_test(y_pred=pred.tolist(), y_true=true.tolist()) 203 | 204 | 205 | def adjust_learning_rate(optimizer, epoch): 206 | """Sets the learning rate to the initial LR decayed by 10 every 15 epochs""" 207 | lr = args.lr * (0.1 ** (epoch // args.lr_epochs)) 208 | print('Learning rate:', lr) 209 | for param_group in optimizer.param_groups: 210 | param_group['lr'] = lr 211 | return lr 212 | 213 | 214 | if __name__ == '__main__': 215 | # Training settings 216 | 217 | parser = argparse.ArgumentParser(description='PyTorch XNOR MNIST') 218 | 219 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 220 | help='input batch size for training (default: 128)') 221 | parser.add_argument('--test-batch-size', type=int, default=128, metavar='N', 222 | help='input batch size for testing (default: 128)') 223 | parser.add_argument('--epochs', type=int, default=60, metavar='N', 224 | help='number of epochs to train (default: 100)') 225 | parser.add_argument('--lr-epochs', type=int, default=15, metavar='N', 226 | help='number of epochs to decay the lr (default: 20)') 227 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 228 | help='learning rate (default: 0.01)') 229 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 230 | help='SGD momentum (default: 0.9)') 231 | parser.add_argument('--seed', type=int, default=1, metavar='S', 232 | help='random seed (default: 1)') 233 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 234 | help='how many batches to wait before logging training status') 235 | parser.add_argument('--arch', action='store', default='Bin_LeNet', 236 | help='the MNIST network structure: Bin_LeNet') 237 | parser.add_argument('--pretrained', action='store', default=None, 238 | help='pretrained model') 239 | parser.add_argument('--evaluate', action='store_true', default=False, 240 | help='whether to run evaluation') 241 | parser.add_argument('--no_cuda', action='store_true', default=False, 242 | help='disables CUDA training') 243 | args = parser.parse_args() 244 | args.cuda = not args.no_cuda and torch.cuda.is_available() 245 | target_name = ['0','1','2','3','4','5','6','7','8','9'] 246 | print(args) 247 | 248 | if args.cuda: 249 | torch.cuda.manual_seed(args.seed) 250 | else: 251 | torch.manual_seed(args.seed) 252 | 253 | # load data 254 | 
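# The Normalize values used below, (0.1307,) and (0.3081,), are the commonly used
# per-pixel mean and standard deviation of the MNIST training images (after ToTensor
# scales pixels to [0, 1]). They can be re-derived with, for example:
#   data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
#   imgs = torch.stack([img for img, _ in data])
#   print(imgs.mean(), imgs.std())   # ~0.1307, ~0.3081
# Also note that pin_memory=True only pays off when batches are copied to a CUDA
# device, so the kwargs below could be made conditional on args.cuda.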
kwargs = {'num_workers': 1, 'pin_memory': True} 255 | train_loader = torch.utils.data.DataLoader( 256 | datasets.MNIST('data', train=True, download=True, 257 | transform=transforms.Compose([ 258 | transforms.ToTensor(), 259 | transforms.Normalize((0.1307,), (0.3081,)) 260 | ])), 261 | batch_size=args.batch_size, shuffle=True, **kwargs) 262 | test_loader = torch.utils.data.DataLoader( 263 | datasets.MNIST('data', train=False, transform=transforms.Compose([ 264 | transforms.ToTensor(), 265 | transforms.Normalize((0.1307,), (0.3081,)) 266 | ])), 267 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 268 | 269 | best_acc = 0.0 270 | model_ori = None 271 | model_train = None 272 | model_test = None 273 | 274 | # generate the model 275 | if args.arch == 'LeNet': 276 | model_ori = models.LeNet() 277 | if args.cuda: 278 | model_ori.cuda() 279 | if args.pretrained: 280 | model_ori.load_state_dict(torch.load(args.pretrained)) 281 | 282 | 283 | elif args.arch == 'Bin_LeNet': 284 | model_train = models.Bin_LeNet_train() 285 | model_test = models.Bin_LeNet_test() 286 | if args.cuda: 287 | model_train = model_train.cuda() 288 | model_test = model_test.cuda() 289 | 290 | if args.pretrained: 291 | model_test.load_state_dict(torch.load(args.pretrained)) 292 | else: 293 | binop_train = binop_train(model_train) 294 | 295 | else: 296 | print('ERROR: specified arch is not suppported') 297 | exit() 298 | 299 | param_dict = dict(model_train.named_parameters()) if model_ori is None else dict(model_ori.named_parameters()) 300 | params = [] 301 | 302 | for key, value in param_dict.items(): 303 | if value.requires_grad: 304 | params += [{ 305 | 'params': [value], 306 | 'lr': args.lr, 307 | 'key': key 308 | }] 309 | optimizer = optim.Adam(params, lr=args.lr) 310 | criterion = nn.CrossEntropyLoss() 311 | 312 | 313 | if args.evaluate: 314 | if model_ori is None: 315 | test(model_test, evaluate=True) 316 | else: 317 | test(model_ori, evaluate=True) 318 | exit() 319 | 320 | for epoch in range(1, args.epochs + 1): 321 | adjust_learning_rate(optimizer, epoch) 322 | if model_ori is None: 323 | train_bin(epoch) 324 | test(model_test) 325 | else: 326 | train(epoch) 327 | test(model_ori) 328 | -------------------------------------------------------------------------------- /MNIST/models/Bin_LeNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append("..") 6 | from util import BinLinear 7 | from util import BinConv2d 8 | 9 | 10 | 11 | class Bin_LeNet_train(nn.Module): 12 | def __init__(self): 13 | super(Bin_LeNet_train, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False) 16 | self.bn1 = nn.BatchNorm2d(num_features=20) 17 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 18 | self.conv2 = BinConv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False, istrain=True) 19 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 20 | self.fc1 = BinLinear(in_features=50 * 4 * 4, out_features=500, bias=False, istrain=True) 21 | self.fc2 = nn.Linear(in_features=500, out_features=10, bias=True) 22 | 23 | self.bn2 = nn.BatchNorm2d(num_features=50) 24 | self.bn3 = nn.BatchNorm1d(num_features=500) 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) 28 | x = self.bn1(x) 29 | x = F.relu(x) 30 | x = self.pool1(x) 31 | 32 | x = self.conv2(x) 33 | x = self.bn2(x) 34 | x = F.relu(x) 35 | x = self.pool2(x) 36 | 37 | x = x.view(-1, 4 
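# Why the flatten uses 4 * 4 * 50: MNIST inputs are 1x28x28; conv1 (5x5, no padding)
# gives 20x24x24, pool1 halves that to 20x12x12, conv2 (5x5) gives 50x8x8, and pool2
# halves it to 50x4x4, so 50 * 4 * 4 = 800 features are fed into fc1.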
* 4 * 50) 38 | 39 | x = self.fc1(x) 40 | x = self.bn3(x) 41 | x = F.relu(x) 42 | return self.fc2(x) 43 | 44 | 45 | class Bin_LeNet_test(nn.Module): 46 | def __init__(self): 47 | super(Bin_LeNet_test, self).__init__() 48 | 49 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False) 50 | self.bn1 = nn.BatchNorm2d(num_features=20) 51 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | self.conv2 = BinConv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False, istrain=False) 53 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 54 | self.fc1 = BinLinear(in_features=50 * 4 * 4, out_features=500, bias=False, istrain=False) 55 | self.fc2 = nn.Linear(in_features=500, out_features=10, bias=True) 56 | 57 | self.bn2 = nn.BatchNorm2d(num_features=50) 58 | self.bn3 = nn.BatchNorm1d(num_features=500) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.bn1(x) 63 | x = F.relu(x) 64 | x = self.pool1(x) 65 | 66 | x = self.conv2(x) 67 | x = self.bn2(x) 68 | x = F.relu(x) 69 | x = self.pool2(x) 70 | 71 | x = x.view(-1, 4 * 4 * 50) 72 | 73 | x = self.fc1(x) 74 | x = self.bn3(x) 75 | x = F.relu(x) 76 | return self.fc2(x) -------------------------------------------------------------------------------- /MNIST/models/LeNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False) 10 | self.bn1 = nn.BatchNorm2d(num_features=20) 11 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 12 | self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False) 13 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 14 | self.fc1 = nn.Linear(in_features=50 * 4 * 4, out_features=500, bias=False) 15 | 16 | self.fc2 = nn.Linear(in_features=500, out_features=10, bias=True) 17 | 18 | self.bn2 = nn.BatchNorm2d(num_features=50) 19 | self.bn3 = nn.BatchNorm1d(num_features=500) 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = self.bn1(x) 24 | x = F.relu(x) 25 | x = self.pool1(x) 26 | 27 | x = self.conv2(x) 28 | x = self.bn2(x) 29 | x = F.relu(x) 30 | x = self.pool2(x) 31 | 32 | x = x.view(-1, 4 * 4 * 50) 33 | 34 | x = self.fc1(x) 35 | x = self.bn3(x) 36 | x = F.relu(x) 37 | return self.fc2(x) -------------------------------------------------------------------------------- /MNIST/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .LeNet import LeNet 2 | from .Bin_LeNet import Bin_LeNet_test 3 | from .Bin_LeNet import Bin_LeNet_train -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-XNOR-Net 2 | 3 | # Build 4 | ~~~shell 5 | cd /csrc/binop 6 | make 7 | ~~~ 8 | # MNIST 9 | 10 | ## Usage 11 | ### Train: 12 | ~~~shell 13 | cd /MNIST/ 14 | python3 main.py --arch Bin_LeNet 15 | python3 main.py --arch LeNet 16 | ~~~ 17 | ### Evaluate: 18 | ~~~shell 19 | cd /MNIST/ 20 | python3 main.py --arch Bin_LeNet --evaluate --pretrained ./models/Bin_LeNet.best.pth # --no_cuda (Use CPU) 21 | python3 main.py --arch LeNet --evaluate --pretrained ./models/LeNet.best.pth # --no_cuda (Use CPU) 22 | ~~~ 23 | ## Result 24 | | Network | Accuracy | Size | 25 | | ------- | -------- | ---- | 26 | | 
LeNet | 99.50% | 1.7 MB | 27 | | Bin_LeNet | 99.45% | 102 KB | 28 | 29 | 30 | # Cifar10 31 | 32 | ## Usage 33 | ### Train: 34 | ~~~shell 35 | cd /Cifar10/ 36 | python3 main.py --arch Bin_VGG16 #(11, 13, 16, 19) 37 | python3 main.py --arch VGG16 #(11, 13, 16, 19) 38 | ~~~ 39 | ### Evaluate: 40 | ~~~shell 41 | cd /Cifar10/ 42 | python3 main.py --arch Bin_VGG16 --evaluate --pretrained ./models/Bin_VGG16.best.pth # --no_cuda (Use CPU) 43 | python3 main.py --arch VGG16 --evaluate --pretrained ./models/VGG16.best.pth # --no_cuda (Use CPU) 44 | ~~~ 45 | ## Result 46 | | Network | Accuracy | Size   | 47 | | ------- | -------- | ---- | 48 | | VGG13 | 92.40% | 37.7 MB | 49 | | Bin_VGG13 | 88.74% | 1.3 MB | 50 | | VGG16 | 92.29% | 59.0 MB | 51 | | Bin_VGG16 | 87.78% | 2.0 MB | 52 | 53 | # Pre-trained models 54 | [Google Drive](https://drive.google.com/open?id=13KAF89w1-OnGTgHlhblnzBafpz-sTCVT) 55 | 56 | # Environment 57 | ## Software 58 | * Ubuntu 16.04 59 | * Python 3.5 60 | * Pytorch 0.3.1 61 | * CUDA 8.0 62 | * gcc 5.4 63 | 64 | ## Hardware 65 | 66 | * NVIDIA GTX 1080 67 | * Intel i5-6500 CPU @ 3.20GHz × 4 68 | 69 | 70 | # Reference 71 | * [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/pdf/1602.02830.pdf) 72 | * [XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks](https://arxiv.org/pdf/1603.05279.pdf) 73 | * https://github.com/jiecaoyu/XNOR-Net-PyTorch 74 | * [cpu-gemm](http://apfel.mathematik.uni-ulm.de/~lehn/sghpc/gemm/page02/index.html) 75 | * [cpu-conv2d](https://github.com/pytorch/pytorch/blob/f23feca681c5066c70f0fe1516fc2e269d615e93/aten/src/THNN/generic/SpatialConvolutionMM.c) 76 | * [gpu-gemm and gpu-conv2d](https://github.com/1adrianb/bnn.torch/blob/master/BinarySpatialConvolution.cu) 77 | * [popcount](https://github.com/kimwalisch/libpopcnt) 78 | -------------------------------------------------------------------------------- /csrc/binop/Makefile: -------------------------------------------------------------------------------- 1 | # Unix commands. 2 | PYTHON := python3 3 | NVCC_COMPILE := nvcc -c -o 4 | RM_RF := rm -rf 5 | 6 | # Library compilation rules. 7 | NVCC_FLAGS := -x cu -Xcompiler -fPIC -shared -arch=sm_61 \ 8 | -gencode=arch=compute_50,code=sm_50 \ 9 | -gencode=arch=compute_52,code=sm_52 \ 10 | -gencode=arch=compute_60,code=sm_60 \ 11 | -gencode=arch=compute_61,code=sm_61 \ 12 | -gencode=arch=compute_61,code=compute_61 13 | 14 | # File structure. 15 | BUILD_DIR := build 16 | INCLUDE_DIRS := include 17 | TORCH_FFI_BUILD := build.py 18 | MATHUTIL_KERNEL := $(BUILD_DIR)/binop_cuda_kernel.so 19 | TORCH_FFI_TARGET := $(BUILD_DIR)/binop/_binop.so 20 | 21 | INCLUDE_FLAGS := $(foreach d, $(INCLUDE_DIRS), -I$d) 22 | 23 | all: $(TORCH_FFI_TARGET) 24 | 25 | $(TORCH_FFI_TARGET): $(MATHUTIL_KERNEL) $(TORCH_FFI_BUILD) 26 | $(PYTHON) $(TORCH_FFI_BUILD) 27 | 28 | $(BUILD_DIR)/%.so: src/%.cu 29 | @ mkdir -p $(BUILD_DIR) 30 | # Separate cpp shared library that will be loaded to the extern C ffi 31 | $(NVCC_COMPILE) $@ $? 
$(NVCC_FLAGS) $(INCLUDE_FLAGS) 32 | 33 | clean: 34 | $(RM_RF) $(BUILD_DIR) $(MATHUTIL_KERNEL) 35 | -------------------------------------------------------------------------------- /csrc/binop/build.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | from os import path as osp 4 | from torch.utils.ffi import create_extension 5 | 6 | abs_path = osp.dirname(osp.realpath(__file__)) 7 | extra_objects = [] 8 | sources = ['src/binop.c'] 9 | headers = ['include/binop.h'] 10 | defines = [] 11 | with_cuda = False 12 | 13 | if torch.cuda.is_available(): 14 | extra_objects += [osp.join(abs_path, 'build/binop_cuda_kernel.so')] 15 | extra_objects += glob.glob('/usr/local/cuda/lib64/*.a') 16 | sources += ['src/binop_cuda.c'] 17 | headers += ['include/binop_cuda.h'] 18 | defines += [('WITH_CUDA', None)] 19 | with_cuda = True 20 | ffi = create_extension( 21 | 'binop', 22 | headers=headers, 23 | sources=sources, 24 | define_macros=defines, 25 | relative_to=__file__, 26 | with_cuda=with_cuda, 27 | extra_objects=extra_objects, 28 | include_dirs=[osp.join(abs_path, 'include')], 29 | extra_compile_args=["-std=c99", "-Ofast", "-fopenmp", "-mtune=native", "-march=x86-64"] 30 | ) 31 | 32 | if __name__ == '__main__': 33 | ffi.build() 34 | -------------------------------------------------------------------------------- /csrc/binop/include/binop.h: -------------------------------------------------------------------------------- 1 | void encode_rows_cpu(THFloatTensor* input, THIntTensor* output); 2 | void encode_cols_cpu(THFloatTensor* input, THIntTensor* output); 3 | void binary_gemm_cpu(THIntTensor* a, THIntTensor* b, THFloatTensor* c, int m, int nn, int k, int transb, int beta, int alpha, THFloatTensor* alphas); 4 | void THNN_Bin_SpatialConvolutionMM_updateOutput( 5 | THFloatTensor *input, 6 | THFloatTensor *output, 7 | THIntTensor *weight, 8 | THFloatTensor *bias, 9 | THFloatTensor *columns, 10 | THFloatTensor *alphas, 11 | int kH, int kW, 12 | int dH, int dW, 13 | int padH, int padW); -------------------------------------------------------------------------------- /csrc/binop/include/binop_cuda.h: -------------------------------------------------------------------------------- 1 | void binary_gemm(THCudaIntTensor* weight, THCudaIntTensor* columns_binary, THCudaTensor* output_n, int m, int nn, int k, int transb, int alpha, int beta, THCudaTensor *alphas); 2 | 3 | void im2col(THCudaTensor* data_im, int channels, int height, int width, int ksize_h, int ksize_w, int pad_h, int pad_w, int stride_h, int stride_w, int dilation_h, int dilation_w, THCudaTensor* data_col); 4 | 5 | void encode_rows(THCudaTensor* input, THCudaIntTensor* output); 6 | 7 | void encode_cols(THCudaTensor* input, THCudaIntTensor* output); 8 | 9 | void BinarySpatialConvolution_updateOutput( 10 | THCudaTensor *input, THCudaTensor *output, THCudaIntTensor *weight, THCudaTensor *columns, 11 | THCudaTensor *bias, THCudaTensor *alphas, int nInputPlane, 12 | int kH, int kW, int sH, int sW, int padH, int padW); 13 | -------------------------------------------------------------------------------- /csrc/binop/include/binop_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _BINCONV_CUDA_KERNEL 2 | #define _BINCONV_CUDA_KERNEL 3 | 4 | #define BLOCK_SIZE 16 5 | #define BLOCK_DIM 16 6 | #define CUDA_NUM_THREADS 1024 7 | #define ENCODE_BITS 32 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void 
binary_gemm_cuda(uint32_t* A, uint32_t* B, float* C, int m, int nn, int k, int transb, int alpha, int beta, float *alphas, cudaStream_t stream); 14 | 15 | void im2col_cuda(int n, float* data_im, int height, int width, 16 | int ksize_h, int ksize_w, int pad_h, int pad_w, 17 | int stride_h, int stride_w, int dilation_h, int dilation_w, 18 | int height_col, int width_col, float* data_col, cudaStream_t stream); 19 | 20 | void encode_rows_cuda(float* input, uint32_t* output, int m, int n, int l, cudaStream_t stream); 21 | void encode_cols_cuda(float* input, uint32_t* output, int n, int k, cudaStream_t stream); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /csrc/binop/include/libpopcnt.h: -------------------------------------------------------------------------------- 1 | /* 2 | * libpopcnt.h - C/C++ library for counting the number of 1 bits (bit 3 | * population count) in an array as quickly as possible using 4 | * specialized CPU instructions i.e. POPCNT, AVX2, AVX512, NEON. 5 | * 6 | * Copyright (c) 2016 - 2018, Kim Walisch 7 | * Copyright (c) 2016 - 2018, Wojciech Muła 8 | * 9 | * All rights reserved. 10 | * 11 | * Redistribution and use in source and binary forms, with or without 12 | * modification, are permitted provided that the following conditions are met: 13 | * 14 | * 1. Redistributions of source code must retain the above copyright notice, this 15 | * list of conditions and the following disclaimer. 16 | * 2. Redistributions in binary form must reproduce the above copyright notice, 17 | * this list of conditions and the following disclaimer in the documentation 18 | * and/or other materials provided with the distribution. 19 | * 20 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 21 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 22 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 24 | * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 25 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 26 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 27 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 29 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | */ 31 | 32 | #ifndef LIBPOPCNT_H 33 | #define LIBPOPCNT_H 34 | 35 | #include 36 | 37 | #ifndef __has_builtin 38 | #define __has_builtin(x) 0 39 | #endif 40 | 41 | #ifndef __has_attribute 42 | #define __has_attribute(x) 0 43 | #endif 44 | 45 | #ifdef __GNUC__ 46 | #define GNUC_PREREQ(x, y) \ 47 | (__GNUC__ > x || (__GNUC__ == x && __GNUC_MINOR__ >= y)) 48 | #else 49 | #define GNUC_PREREQ(x, y) 0 50 | #endif 51 | 52 | #ifdef __clang__ 53 | #define CLANG_PREREQ(x, y) \ 54 | (__clang_major__ > x || (__clang_major__ == x && __clang_minor__ >= y)) 55 | #else 56 | #define CLANG_PREREQ(x, y) 0 57 | #endif 58 | 59 | #if (_MSC_VER < 1900) && \ 60 | !defined(__cplusplus) 61 | #define inline __inline 62 | #endif 63 | 64 | #if (defined(__i386__) || \ 65 | defined(__x86_64__) || \ 66 | defined(_M_IX86) || \ 67 | defined(_M_X64)) 68 | #define X86_OR_X64 69 | #endif 70 | 71 | #if defined(X86_OR_X64) && \ 72 | (defined(__cplusplus) || \ 73 | defined(_MSC_VER) || \ 74 | (GNUC_PREREQ(4, 2) || \ 75 | __has_builtin(__sync_val_compare_and_swap))) 76 | #define HAVE_CPUID 77 | #endif 78 | 79 | #if GNUC_PREREQ(4, 2) || \ 80 | __has_builtin(__builtin_popcount) 81 | #define HAVE_BUILTIN_POPCOUNT 82 | #endif 83 | 84 | #if GNUC_PREREQ(4, 2) || \ 85 | CLANG_PREREQ(3, 0) 86 | #define HAVE_ASM_POPCNT 87 | #endif 88 | 89 | #if defined(HAVE_CPUID) && \ 90 | (defined(HAVE_ASM_POPCNT) || \ 91 | defined(_MSC_VER)) 92 | #define HAVE_POPCNT 93 | #endif 94 | 95 | #if defined(HAVE_CPUID) && \ 96 | GNUC_PREREQ(4, 9) 97 | #define HAVE_AVX2 98 | #endif 99 | 100 | #if defined(HAVE_CPUID) && \ 101 | GNUC_PREREQ(5, 0) 102 | #define HAVE_AVX512 103 | #endif 104 | 105 | #if defined(HAVE_CPUID) && \ 106 | defined(_MSC_VER) && \ 107 | defined(__AVX2__) 108 | #define HAVE_AVX2 109 | #endif 110 | 111 | #if defined(HAVE_CPUID) && \ 112 | defined(_MSC_VER) && \ 113 | defined(__AVX512__) 114 | #define HAVE_AVX512 115 | #endif 116 | 117 | #if defined(HAVE_CPUID) && \ 118 | CLANG_PREREQ(3, 8) && \ 119 | __has_attribute(target) && \ 120 | (!defined(_MSC_VER) || defined(__AVX2__)) && \ 121 | (!defined(__apple_build_version__) || __apple_build_version__ >= 8000000) 122 | #define HAVE_AVX2 123 | #define HAVE_AVX512 124 | #endif 125 | 126 | /* 127 | * This uses fewer arithmetic operations than any other known 128 | * implementation on machines with fast multiplication. 129 | * It uses 12 arithmetic operations, one of which is a multiply. 
130 | * http://en.wikipedia.org/wiki/Hamming_weight#Efficient_implementation 131 | */ 132 | static inline uint64_t popcount64(uint64_t x) 133 | { 134 | uint64_t m1 = 0x5555555555555555ll; 135 | uint64_t m2 = 0x3333333333333333ll; 136 | uint64_t m4 = 0x0F0F0F0F0F0F0F0Fll; 137 | uint64_t h01 = 0x0101010101010101ll; 138 | 139 | x -= (x >> 1) & m1; 140 | x = (x & m2) + ((x >> 2) & m2); 141 | x = (x + (x >> 4)) & m4; 142 | 143 | return (x * h01) >> 56; 144 | } 145 | 146 | #if defined(HAVE_ASM_POPCNT) && \ 147 | defined(__x86_64__) 148 | 149 | static inline uint64_t popcnt64(uint64_t x) 150 | { 151 | __asm__ ("popcnt %1, %0" : "=r" (x) : "0" (x)); 152 | return x; 153 | } 154 | 155 | #elif defined(HAVE_ASM_POPCNT) && \ 156 | defined(__i386__) 157 | 158 | static inline uint32_t popcnt32(uint32_t x) 159 | { 160 | __asm__ ("popcnt %1, %0" : "=r" (x) : "0" (x)); 161 | return x; 162 | } 163 | 164 | static inline uint64_t popcnt64(uint64_t x) 165 | { 166 | return popcnt32((uint32_t) x) + 167 | popcnt32((uint32_t)(x >> 32)); 168 | } 169 | 170 | #elif defined(_MSC_VER) && \ 171 | defined(_M_X64) 172 | 173 | #include 174 | 175 | static inline uint64_t popcnt64(uint64_t x) 176 | { 177 | return _mm_popcnt_u64(x); 178 | } 179 | 180 | #elif defined(_MSC_VER) && \ 181 | defined(_M_IX86) 182 | 183 | #include 184 | 185 | static inline uint64_t popcnt64(uint64_t x) 186 | { 187 | return _mm_popcnt_u32((uint32_t) x) + 188 | _mm_popcnt_u32((uint32_t)(x >> 32)); 189 | } 190 | 191 | /* non x86 CPUs */ 192 | #elif defined(HAVE_BUILTIN_POPCOUNT) 193 | 194 | static inline uint64_t popcnt64(uint64_t x) 195 | { 196 | return __builtin_popcountll(x); 197 | } 198 | 199 | /* no hardware POPCNT, 200 | * use pure integer algorithm */ 201 | #else 202 | 203 | static inline uint64_t popcnt64(uint64_t x) 204 | { 205 | return popcount64(x); 206 | } 207 | 208 | #endif 209 | 210 | static inline uint64_t popcnt64_unrolled(const uint64_t* data, uint64_t size) 211 | { 212 | uint64_t i = 0; 213 | uint64_t limit = size - size % 4; 214 | uint64_t cnt = 0; 215 | 216 | for (; i < limit; i += 4) 217 | { 218 | cnt += popcnt64(data[i+0]); 219 | cnt += popcnt64(data[i+1]); 220 | cnt += popcnt64(data[i+2]); 221 | cnt += popcnt64(data[i+3]); 222 | } 223 | 224 | for (; i < size; i++) 225 | cnt += popcnt64(data[i]); 226 | 227 | return cnt; 228 | } 229 | 230 | #if defined(HAVE_CPUID) 231 | 232 | #if defined(_MSC_VER) 233 | #include 234 | #include 235 | #endif 236 | 237 | /* %ecx bit flags */ 238 | #define bit_POPCNT (1 << 23) 239 | 240 | /* %ebx bit flags */ 241 | #define bit_AVX2 (1 << 5) 242 | #define bit_AVX512 (1 << 30) 243 | 244 | /* xgetbv bit flags */ 245 | #define XSTATE_SSE (1 << 1) 246 | #define XSTATE_YMM (1 << 2) 247 | #define XSTATE_ZMM (7 << 5) 248 | 249 | static inline void run_cpuid(int eax, int ecx, int* abcd) 250 | { 251 | #if defined(_MSC_VER) 252 | __cpuidex(abcd, eax, ecx); 253 | #else 254 | int ebx = 0; 255 | int edx = 0; 256 | 257 | #if defined(__i386__) && \ 258 | defined(__PIC__) 259 | /* in case of PIC under 32-bit EBX cannot be clobbered */ 260 | __asm__ ("movl %%ebx, %%edi;" 261 | "cpuid;" 262 | "xchgl %%ebx, %%edi;" 263 | : "=D" (ebx), 264 | "+a" (eax), 265 | "+c" (ecx), 266 | "=d" (edx)); 267 | #else 268 | __asm__ ("cpuid;" 269 | : "+b" (ebx), 270 | "+a" (eax), 271 | "+c" (ecx), 272 | "=d" (edx)); 273 | #endif 274 | 275 | abcd[0] = eax; 276 | abcd[1] = ebx; 277 | abcd[2] = ecx; 278 | abcd[3] = edx; 279 | #endif 280 | } 281 | 282 | #if defined(HAVE_AVX2) || \ 283 | defined(HAVE_AVX512) 284 | 285 | static inline int 
get_xcr0(void) 286 | { 287 | int xcr0; 288 | 289 | #if defined(_MSC_VER) 290 | xcr0 = (int) _xgetbv(0); 291 | #else 292 | __asm__ ("xgetbv" : "=a" (xcr0) : "c" (0) : "%edx" ); 293 | #endif 294 | 295 | return xcr0; 296 | } 297 | 298 | #endif 299 | 300 | static inline int get_cpuid(void) 301 | { 302 | int flags = 0; 303 | int abcd[4]; 304 | 305 | run_cpuid(1, 0, abcd); 306 | 307 | if ((abcd[2] & bit_POPCNT) == bit_POPCNT) 308 | flags |= bit_POPCNT; 309 | 310 | #if defined(HAVE_AVX2) || \ 311 | defined(HAVE_AVX512) 312 | 313 | int osxsave_mask = (1 << 27); 314 | 315 | /* ensure OS supports extended processor state management */ 316 | if ((abcd[2] & osxsave_mask) != osxsave_mask) 317 | return 0; 318 | 319 | int ymm_mask = XSTATE_SSE | XSTATE_YMM; 320 | int zmm_mask = XSTATE_SSE | XSTATE_YMM | XSTATE_ZMM; 321 | 322 | int xcr0 = get_xcr0(); 323 | 324 | if ((xcr0 & ymm_mask) == ymm_mask) 325 | { 326 | run_cpuid(7, 0, abcd); 327 | 328 | if ((abcd[1] & bit_AVX2) == bit_AVX2) 329 | flags |= bit_AVX2; 330 | 331 | if ((xcr0 & zmm_mask) == zmm_mask) 332 | { 333 | if ((abcd[1] & bit_AVX512) == bit_AVX512) 334 | flags |= bit_AVX512; 335 | } 336 | } 337 | 338 | #endif 339 | 340 | return flags; 341 | } 342 | 343 | #endif /* cpuid */ 344 | 345 | #if defined(HAVE_AVX2) 346 | 347 | #include 348 | 349 | #if !defined(_MSC_VER) 350 | __attribute__ ((target ("avx2"))) 351 | #endif 352 | static inline void CSA256(__m256i* h, __m256i* l, __m256i a, __m256i b, __m256i c) 353 | { 354 | __m256i u = _mm256_xor_si256(a, b); 355 | *h = _mm256_or_si256(_mm256_and_si256(a, b), _mm256_and_si256(u, c)); 356 | *l = _mm256_xor_si256(u, c); 357 | } 358 | 359 | #if !defined(_MSC_VER) 360 | __attribute__ ((target ("avx2"))) 361 | #endif 362 | static inline __m256i popcnt256(__m256i v) 363 | { 364 | __m256i lookup1 = _mm256_setr_epi8( 365 | 4, 5, 5, 6, 5, 6, 6, 7, 366 | 5, 6, 6, 7, 6, 7, 7, 8, 367 | 4, 5, 5, 6, 5, 6, 6, 7, 368 | 5, 6, 6, 7, 6, 7, 7, 8 369 | ); 370 | 371 | __m256i lookup2 = _mm256_setr_epi8( 372 | 4, 3, 3, 2, 3, 2, 2, 1, 373 | 3, 2, 2, 1, 2, 1, 1, 0, 374 | 4, 3, 3, 2, 3, 2, 2, 1, 375 | 3, 2, 2, 1, 2, 1, 1, 0 376 | ); 377 | 378 | __m256i low_mask = _mm256_set1_epi8(0x0f); 379 | __m256i lo = _mm256_and_si256(v, low_mask); 380 | __m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask); 381 | __m256i popcnt1 = _mm256_shuffle_epi8(lookup1, lo); 382 | __m256i popcnt2 = _mm256_shuffle_epi8(lookup2, hi); 383 | 384 | return _mm256_sad_epu8(popcnt1, popcnt2); 385 | } 386 | 387 | /* 388 | * AVX2 Harley-Seal popcount (4th iteration). 389 | * The algorithm is based on the paper "Faster Population Counts 390 | * using AVX2 Instructions" by Daniel Lemire, Nathan Kurz and 391 | * Wojciech Mula (23 Nov 2016). 
392 | * @see https://arxiv.org/abs/1611.07612 393 | */ 394 | #if !defined(_MSC_VER) 395 | __attribute__ ((target ("avx2"))) 396 | #endif 397 | static inline uint64_t popcnt_avx2(const __m256i* data, uint64_t size) 398 | { 399 | __m256i cnt = _mm256_setzero_si256(); 400 | __m256i ones = _mm256_setzero_si256(); 401 | __m256i twos = _mm256_setzero_si256(); 402 | __m256i fours = _mm256_setzero_si256(); 403 | __m256i eights = _mm256_setzero_si256(); 404 | __m256i sixteens = _mm256_setzero_si256(); 405 | __m256i twosA, twosB, foursA, foursB, eightsA, eightsB; 406 | 407 | uint64_t i = 0; 408 | uint64_t limit = size - size % 16; 409 | uint64_t* cnt64; 410 | 411 | for(; i < limit; i += 16) 412 | { 413 | CSA256(&twosA, &ones, ones, data[i+0], data[i+1]); 414 | CSA256(&twosB, &ones, ones, data[i+2], data[i+3]); 415 | CSA256(&foursA, &twos, twos, twosA, twosB); 416 | CSA256(&twosA, &ones, ones, data[i+4], data[i+5]); 417 | CSA256(&twosB, &ones, ones, data[i+6], data[i+7]); 418 | CSA256(&foursB, &twos, twos, twosA, twosB); 419 | CSA256(&eightsA, &fours, fours, foursA, foursB); 420 | CSA256(&twosA, &ones, ones, data[i+8], data[i+9]); 421 | CSA256(&twosB, &ones, ones, data[i+10], data[i+11]); 422 | CSA256(&foursA, &twos, twos, twosA, twosB); 423 | CSA256(&twosA, &ones, ones, data[i+12], data[i+13]); 424 | CSA256(&twosB, &ones, ones, data[i+14], data[i+15]); 425 | CSA256(&foursB, &twos, twos, twosA, twosB); 426 | CSA256(&eightsB, &fours, fours, foursA, foursB); 427 | CSA256(&sixteens, &eights, eights, eightsA, eightsB); 428 | 429 | cnt = _mm256_add_epi64(cnt, popcnt256(sixteens)); 430 | } 431 | 432 | cnt = _mm256_slli_epi64(cnt, 4); 433 | cnt = _mm256_add_epi64(cnt, _mm256_slli_epi64(popcnt256(eights), 3)); 434 | cnt = _mm256_add_epi64(cnt, _mm256_slli_epi64(popcnt256(fours), 2)); 435 | cnt = _mm256_add_epi64(cnt, _mm256_slli_epi64(popcnt256(twos), 1)); 436 | cnt = _mm256_add_epi64(cnt, popcnt256(ones)); 437 | 438 | for(; i < size; i++) 439 | cnt = _mm256_add_epi64(cnt, popcnt256(data[i])); 440 | 441 | cnt64 = (uint64_t*) &cnt; 442 | 443 | return cnt64[0] + 444 | cnt64[1] + 445 | cnt64[2] + 446 | cnt64[3]; 447 | } 448 | 449 | /* Align memory to 32 bytes boundary */ 450 | static inline void align_avx2(const uint8_t** p, uint64_t* size, uint64_t* cnt) 451 | { 452 | for (; (uintptr_t) *p % 8; (*p)++) 453 | { 454 | *cnt += popcnt64(**p); 455 | *size -= 1; 456 | } 457 | for (; (uintptr_t) *p % 32; (*p) += 8) 458 | { 459 | *cnt += popcnt64( 460 | *(const uint64_t*) *p); 461 | *size -= 8; 462 | } 463 | } 464 | 465 | #endif 466 | 467 | #if defined(HAVE_AVX512) 468 | 469 | #include 470 | 471 | #if !defined(_MSC_VER) 472 | __attribute__ ((target ("avx512bw"))) 473 | #endif 474 | static inline __m512i popcnt512(__m512i v) 475 | { 476 | __m512i m1 = _mm512_set1_epi8(0x55); 477 | __m512i m2 = _mm512_set1_epi8(0x33); 478 | __m512i m4 = _mm512_set1_epi8(0x0F); 479 | __m512i t1 = _mm512_sub_epi8(v, (_mm512_srli_epi16(v, 1) & m1)); 480 | __m512i t2 = _mm512_add_epi8(t1 & m2, (_mm512_srli_epi16(t1, 2) & m2)); 481 | __m512i t3 = _mm512_add_epi8(t2, _mm512_srli_epi16(t2, 4)) & m4; 482 | 483 | return _mm512_sad_epu8(t3, _mm512_setzero_si512()); 484 | } 485 | 486 | #if !defined(_MSC_VER) 487 | __attribute__ ((target ("avx512bw"))) 488 | #endif 489 | static inline void CSA512(__m512i* h, __m512i* l, __m512i a, __m512i b, __m512i c) 490 | { 491 | *l = _mm512_ternarylogic_epi32(c, b, a, 0x96); 492 | *h = _mm512_ternarylogic_epi32(c, b, a, 0xe8); 493 | } 494 | 495 | /* 496 | * AVX512 Harley-Seal popcount (4th iteration). 
497 | * The algorithm is based on the paper "Faster Population Counts 498 | * using AVX2 Instructions" by Daniel Lemire, Nathan Kurz and 499 | * Wojciech Mula (23 Nov 2016). 500 | * @see https://arxiv.org/abs/1611.07612 501 | */ 502 | #if !defined(_MSC_VER) 503 | __attribute__ ((target ("avx512bw"))) 504 | #endif 505 | static inline uint64_t popcnt_avx512(const __m512i* data, const uint64_t size) 506 | { 507 | __m512i cnt = _mm512_setzero_si512(); 508 | __m512i ones = _mm512_setzero_si512(); 509 | __m512i twos = _mm512_setzero_si512(); 510 | __m512i fours = _mm512_setzero_si512(); 511 | __m512i eights = _mm512_setzero_si512(); 512 | __m512i sixteens = _mm512_setzero_si512(); 513 | __m512i twosA, twosB, foursA, foursB, eightsA, eightsB; 514 | 515 | uint64_t i = 0; 516 | uint64_t limit = size - size % 16; 517 | uint64_t* cnt64; 518 | 519 | for(; i < limit; i += 16) 520 | { 521 | CSA512(&twosA, &ones, ones, data[i+0], data[i+1]); 522 | CSA512(&twosB, &ones, ones, data[i+2], data[i+3]); 523 | CSA512(&foursA, &twos, twos, twosA, twosB); 524 | CSA512(&twosA, &ones, ones, data[i+4], data[i+5]); 525 | CSA512(&twosB, &ones, ones, data[i+6], data[i+7]); 526 | CSA512(&foursB, &twos, twos, twosA, twosB); 527 | CSA512(&eightsA, &fours, fours, foursA, foursB); 528 | CSA512(&twosA, &ones, ones, data[i+8], data[i+9]); 529 | CSA512(&twosB, &ones, ones, data[i+10], data[i+11]); 530 | CSA512(&foursA, &twos, twos, twosA, twosB); 531 | CSA512(&twosA, &ones, ones, data[i+12], data[i+13]); 532 | CSA512(&twosB, &ones, ones, data[i+14], data[i+15]); 533 | CSA512(&foursB, &twos, twos, twosA, twosB); 534 | CSA512(&eightsB, &fours, fours, foursA, foursB); 535 | CSA512(&sixteens, &eights, eights, eightsA, eightsB); 536 | 537 | cnt = _mm512_add_epi64(cnt, popcnt512(sixteens)); 538 | } 539 | 540 | cnt = _mm512_slli_epi64(cnt, 4); 541 | cnt = _mm512_add_epi64(cnt, _mm512_slli_epi64(popcnt512(eights), 3)); 542 | cnt = _mm512_add_epi64(cnt, _mm512_slli_epi64(popcnt512(fours), 2)); 543 | cnt = _mm512_add_epi64(cnt, _mm512_slli_epi64(popcnt512(twos), 1)); 544 | cnt = _mm512_add_epi64(cnt, popcnt512(ones)); 545 | 546 | for(; i < size; i++) 547 | cnt = _mm512_add_epi64(cnt, popcnt512(data[i])); 548 | 549 | cnt64 = (uint64_t*) &cnt; 550 | 551 | return cnt64[0] + 552 | cnt64[1] + 553 | cnt64[2] + 554 | cnt64[3] + 555 | cnt64[4] + 556 | cnt64[5] + 557 | cnt64[6] + 558 | cnt64[7]; 559 | } 560 | 561 | /* Align memory to 64 bytes boundary */ 562 | static inline void align_avx512(const uint8_t** p, uint64_t* size, uint64_t* cnt) 563 | { 564 | for (; (uintptr_t) *p % 8; (*p)++) 565 | { 566 | *cnt += popcnt64(**p); 567 | *size -= 1; 568 | } 569 | for (; (uintptr_t) *p % 64; (*p) += 8) 570 | { 571 | *cnt += popcnt64( 572 | *(const uint64_t*) *p); 573 | *size -= 8; 574 | } 575 | } 576 | 577 | #endif 578 | 579 | /* x86 CPUs */ 580 | #if defined(X86_OR_X64) 581 | 582 | /* Align memory to 8 bytes boundary */ 583 | static inline void align_8(const uint8_t** p, uint64_t* size, uint64_t* cnt) 584 | { 585 | for (; *size > 0 && (uintptr_t) *p % 8; (*p)++) 586 | { 587 | *cnt += popcount64(**p); 588 | *size -= 1; 589 | } 590 | } 591 | 592 | static inline uint64_t popcount64_unrolled(const uint64_t* data, uint64_t size) 593 | { 594 | uint64_t i = 0; 595 | uint64_t limit = size - size % 4; 596 | uint64_t cnt = 0; 597 | 598 | for (; i < limit; i += 4) 599 | { 600 | cnt += popcount64(data[i+0]); 601 | cnt += popcount64(data[i+1]); 602 | cnt += popcount64(data[i+2]); 603 | cnt += popcount64(data[i+3]); 604 | } 605 | 606 | for (; i < size; i++) 607 | 
cnt += popcount64(data[i]); 608 | 609 | return cnt; 610 | } 611 | 612 | /* 613 | * Count the number of 1 bits in the data array 614 | * @data: An array 615 | * @size: Size of data in bytes 616 | */ 617 | static inline uint64_t popcnt(const void* data, uint64_t size) 618 | { 619 | const uint8_t* ptr = (const uint8_t*) data; 620 | uint64_t cnt = 0; 621 | uint64_t i; 622 | 623 | #if defined(HAVE_CPUID) 624 | #if defined(__cplusplus) 625 | /* C++11 thread-safe singleton */ 626 | static const int cpuid = get_cpuid(); 627 | #else 628 | static int cpuid_ = -1; 629 | int cpuid = cpuid_; 630 | if (cpuid == -1) 631 | { 632 | cpuid = get_cpuid(); 633 | 634 | #if defined(_MSC_VER) 635 | _InterlockedCompareExchange(&cpuid_, cpuid, -1); 636 | #else 637 | __sync_val_compare_and_swap(&cpuid_, -1, cpuid); 638 | #endif 639 | } 640 | #endif 641 | #endif 642 | 643 | #if defined(HAVE_AVX512) 644 | 645 | /* AVX512 requires arrays >= 1024 bytes */ 646 | if ((cpuid & bit_AVX512) && 647 | size >= 1024) 648 | { 649 | align_avx512(&ptr, &size, &cnt); 650 | cnt += popcnt_avx512((const __m512i*) ptr, size / 64); 651 | ptr += size - size % 64; 652 | size = size % 64; 653 | } 654 | 655 | #endif 656 | 657 | #if defined(HAVE_AVX2) 658 | 659 | /* AVX2 requires arrays >= 512 bytes */ 660 | if ((cpuid & bit_AVX2) && 661 | size >= 512) 662 | { 663 | align_avx2(&ptr, &size, &cnt); 664 | cnt += popcnt_avx2((const __m256i*) ptr, size / 32); 665 | ptr += size - size % 32; 666 | size = size % 32; 667 | } 668 | 669 | #endif 670 | 671 | #if defined(HAVE_POPCNT) 672 | 673 | if (cpuid & bit_POPCNT) 674 | { 675 | cnt += popcnt64_unrolled((const uint64_t*) ptr, size / 8); 676 | ptr += size - size % 8; 677 | size = size % 8; 678 | for (i = 0; i < size; i++) 679 | cnt += popcnt64(ptr[i]); 680 | 681 | return cnt; 682 | } 683 | 684 | #endif 685 | 686 | /* pure integer popcount algorithm */ 687 | if (size >= 8) 688 | { 689 | align_8(&ptr, &size, &cnt); 690 | cnt += popcount64_unrolled((const uint64_t*) ptr, size / 8); 691 | ptr += size - size % 8; 692 | size = size % 8; 693 | } 694 | 695 | /* pure integer popcount algorithm */ 696 | for (i = 0; i < size; i++) 697 | cnt += popcount64(ptr[i]); 698 | 699 | return cnt; 700 | } 701 | 702 | #elif defined(__ARM_NEON) || \ 703 | defined(__aarch64__) 704 | 705 | #include 706 | 707 | static inline uint64x2_t vpadalq(uint64x2_t sum, uint8x16_t t) 708 | { 709 | return vpadalq_u32(sum, vpaddlq_u16(vpaddlq_u8(t))); 710 | } 711 | 712 | /* 713 | * Count the number of 1 bits in the data array 714 | * @data: An array 715 | * @size: Size of data in bytes 716 | */ 717 | static inline uint64_t popcnt(const void* data, uint64_t size) 718 | { 719 | const uint8_t* ptr = (const uint8_t*) data; 720 | uint64_t cnt = 0; 721 | uint64_t tmp[2]; 722 | uint64_t chunk_size = 64; 723 | uint64_t n = size / chunk_size; 724 | uint64_t is_sum = 30; 725 | uint64_t i; 726 | 727 | uint64x2_t sum = vcombine_u64(vcreate_u64(0), vcreate_u64(0)); 728 | uint8x16_t zero = vcombine_u8(vcreate_u8(0), vcreate_u8(0)); 729 | 730 | uint8x16_t t0 = zero; 731 | uint8x16_t t1 = zero; 732 | uint8x16_t t2 = zero; 733 | uint8x16_t t3 = zero; 734 | uint8x16x4_t input; 735 | 736 | for (i = 0; i < n; i++, ptr += chunk_size) 737 | { 738 | input = vld4q_u8(ptr); 739 | 740 | t0 = vaddq_u8(t0, vcntq_u8(input.val[0])); 741 | t1 = vaddq_u8(t1, vcntq_u8(input.val[1])); 742 | t2 = vaddq_u8(t2, vcntq_u8(input.val[2])); 743 | t3 = vaddq_u8(t3, vcntq_u8(input.val[3])); 744 | 745 | if (i == is_sum) 746 | { 747 | is_sum += 30; 748 | sum = vpadalq(sum, t0); 749 | 
sum = vpadalq(sum, t1); 750 | sum = vpadalq(sum, t2); 751 | sum = vpadalq(sum, t3); 752 | t0 = t1 = t2 = t3 = zero; 753 | } 754 | } 755 | 756 | sum = vpadalq(sum, t0); 757 | sum = vpadalq(sum, t1); 758 | sum = vpadalq(sum, t2); 759 | sum = vpadalq(sum, t3); 760 | 761 | vst1q_u64(tmp, sum); 762 | for (i = 0; i < 2; i++) 763 | cnt += tmp[i]; 764 | 765 | size %= chunk_size; 766 | cnt += popcnt64_unrolled((const uint64_t*) ptr, size / 8); 767 | ptr += size - size % 8; 768 | size = size % 8; 769 | for (i = 0; i < size; i++) 770 | cnt += popcnt64(ptr[i]); 771 | 772 | return cnt; 773 | } 774 | 775 | /* all other CPUs */ 776 | #else 777 | 778 | /* Align memory to 8 bytes boundary */ 779 | static inline void align_8(const uint8_t** p, uint64_t* size, uint64_t* cnt) 780 | { 781 | for (; *size > 0 && (uintptr_t) *p % 8; (*p)++) 782 | { 783 | *cnt += popcnt64(**p); 784 | *size -= 1; 785 | } 786 | } 787 | 788 | /* 789 | * Count the number of 1 bits in the data array 790 | * @data: An array 791 | * @size: Size of data in bytes 792 | */ 793 | static inline uint64_t popcnt(const void* data, uint64_t size) 794 | { 795 | const uint8_t* ptr = (const uint8_t*) data; 796 | uint64_t cnt = 0; 797 | uint64_t i; 798 | 799 | align_8(&ptr, &size, &cnt); 800 | cnt += popcnt64_unrolled((const uint64_t*) ptr, size / 8); 801 | ptr += size - size % 8; 802 | size = size % 8; 803 | for (i = 0; i < size; i++) 804 | cnt += popcnt64(ptr[i]); 805 | 806 | return cnt; 807 | } 808 | 809 | #endif 810 | 811 | #endif /* LIBPOPCNT_H */ 812 | -------------------------------------------------------------------------------- /csrc/binop/include/matmul.h: -------------------------------------------------------------------------------- 1 | #ifndef MATMUL_H 2 | #define MATMUL_H 3 | #include 4 | #include 5 | #include "libpopcnt.h" 6 | #define MC 256 7 | #define KC 64 8 | #define NC 256 9 | 10 | #define MR 4 11 | #define NR 4 12 | #define ENCODE_BIT 32 13 | #define MASK(a) ( (a) + ( -(a) & -((0)>(a)) ) ) 14 | const uint32_t UBIT = ~0; 15 | // 16 | // Local buffers for storing panels from A, B and C 17 | // 18 | static uint32_t _A[MC*KC]; 19 | static uint32_t _B[KC*NC]; 20 | 21 | static inline uint32_t popcnt32(uint32_t x) 22 | { 23 | __asm__ ("popcnt %1, %0" : "=r" (x) : "0" (x)); 24 | return x; 25 | } 26 | // 27 | // Packing complete panels from A (i.e. 
without padding) 28 | // 29 | static void 30 | pack_MRxk(int k, uint32_t *A, int incRowA, int incColA, uint32_t *buffer){ 31 | int i, j; 32 | 33 | for (j=0; j0) { 55 | for (j=0; j0) { 97 | for (i=0; i 2 | #include 3 | #include 4 | #include "matmul.h" 5 | 6 | inline uint32_t encode_val(float* array, int n) { 7 | uint32_t sign, r = 0; 8 | for(int i=0; i0; 10 | r |= (sign<0; 37 | rvalue |= (sign << k); 38 | } 39 | 40 | columns_binary[j + n * i] = rvalue; 41 | } 42 | } 43 | } 44 | 45 | void encode_rows_cpu(THFloatTensor* input, THIntTensor* output) { 46 | int m = input->size[0]; 47 | int n = input->size[1]; 48 | int l = 1+(n-1)/ENCODE_BIT; 49 | 50 | THIntTensor_resize2d(output, m, l); 51 | float* a = THFloatTensor_data(input); 52 | uint32_t* b = (uint32_t*)THIntTensor_data(output); 53 | 54 | encode_rows_cpu_kernel(a, b, m, n); 55 | } 56 | 57 | void encode_cols_cpu(THFloatTensor* input, THIntTensor* output) { 58 | int n = input->size[0]; 59 | int k = input->size[1]; 60 | int l = 1+(n-1)/ENCODE_BIT; 61 | 62 | THIntTensor_resize2d(output, l, k); 63 | float* a = THFloatTensor_data(input); 64 | uint32_t* b = (uint32_t*)THIntTensor_data(output); 65 | 66 | encode_cols_cpu_kernel(a, b, n, k); 67 | } 68 | 69 | void binary_gemm_cpu(THIntTensor* a, THIntTensor* b, THFloatTensor* c, int m, int nn, int k, int transb, int beta, int alpha, THFloatTensor* alphas){ 70 | if (c->nDimension != 2 || c->size[0]*c->size[1] < m*k) { 71 | THFloatTensor_resize2d(c, m, k); 72 | } 73 | uint32_t *A = (uint32_t*)THIntTensor_data(a); 74 | uint32_t *B = (uint32_t*)THIntTensor_data(b); 75 | float *C = THFloatTensor_data(c); 76 | float *D = THFloatTensor_data(alphas); 77 | int n = 1 + (nn-1) / ENCODE_BIT, brow = transb? 1:k, bcol = transb? n:1; 78 | dgemm_nn(m, k, nn, A, n, 1, B, brow, bcol, C, k, 1, beta, alpha, D); 79 | } 80 | 81 | void THNN_unfolded_copy( 82 | THFloatTensor *columns, 83 | THFloatTensor *input, 84 | int kW, int kH, 85 | int dW, int dH, 86 | int padW, int padH, 87 | int nInputPlane, 88 | int inputWidth, int inputHeight, 89 | int outputWidth, int outputHeight) 90 | { 91 | // This function assumes that 92 | // kH*kW does not overflow an int 93 | // nInputPlane*kH*kW does not overflow a int64_t 94 | // outputHeight*dH does not overflow a int64_t 95 | // outputWidth*dW does not overflow a int64_t 96 | 97 | int64_t k; 98 | float *input_data = THFloatTensor_data(input); 99 | float *columns_data = THFloatTensor_data(columns); 100 | 101 | #pragma omp parallel for private(k) 102 | for(k = 0; k < (int64_t)nInputPlane*kH*kW; k++) { 103 | int64_t nip = k / (kH*kW); 104 | int64_t rest = k % (kH*kW); 105 | int64_t kh = rest / kW; 106 | int64_t kw = rest % kW; 107 | int x, y; 108 | int64_t ix, iy; 109 | float *dst = columns_data + nip*((size_t)kH*kW*outputHeight*outputWidth) + kh*((size_t)kW*outputHeight*outputWidth) + kw*((size_t)outputHeight*outputWidth); 110 | float *src = input_data + nip*((size_t)inputHeight*inputWidth); 111 | if (padW > 0 || padH > 0) { 112 | int64_t lpad,rpad; 113 | for(y = 0; y < outputHeight; y++) { 114 | iy = (int64_t)y*dH - padH + kh; 115 | if (iy < 0 || iy >= inputHeight) { 116 | memset(dst+(size_t)y*outputWidth, 0, sizeof(float)*outputWidth); 117 | } else { 118 | if (dW==1){ 119 | ix = 0 - padW + kw; 120 | lpad = fmaxf(0,padW-kw); 121 | rpad = fmaxf(0,padW-(kW-kw-1)); 122 | if (outputWidth-rpad-lpad <= 0) { 123 | memset(dst+(size_t)y*outputWidth, 0, sizeof(float)*outputWidth); 124 | } else { 125 | if (lpad > 0) memset(dst+(size_t)y*outputWidth, 0, sizeof(float)*lpad); 126 | 
memcpy(dst+(size_t)y*outputWidth+lpad, src+(size_t)iy*inputWidth+ix+lpad, sizeof(float)*(outputWidth-rpad-lpad)); 127 | if (rpad > 0) memset(dst+(size_t)y*outputWidth + outputWidth - rpad, 0, sizeof(float)*rpad); 128 | } 129 | } 130 | else{ 131 | for (x=0; x= inputWidth) 134 | memset(dst+(size_t)y*outputWidth+x, 0, sizeof(float)*1); 135 | else 136 | memcpy(dst+(size_t)y*outputWidth+x, src+(size_t)iy*inputWidth+ix, sizeof(float)*(1)); 137 | } 138 | } 139 | } 140 | } 141 | } else { 142 | for(y = 0; y < outputHeight; y++) { 143 | iy = (int64_t)y*dH + kh; 144 | ix = 0 + kw; 145 | if (dW == 1) 146 | memcpy(dst+(size_t)y*outputWidth, src+(size_t)iy*inputWidth+ix, sizeof(float)*outputWidth); 147 | else{ 148 | for (x=0; xstorage, output->storageOffset, nOutputPlane, -1, outputHeight*outputWidth, -1); 174 | THFloatTensor_zero(output2d); 175 | 176 | binary_gemm_cpu(weight, bin_col, output2d, nOutputPlane, kW*kH*nInputPlane, outputHeight*outputWidth, 0, 1, 1, alphas); 177 | if (bias->nDimension) { 178 | THFloatTensor_addmm(output2d, 1, output2d, 1, bias, ones); 179 | } 180 | THFloatTensor_free(output2d); 181 | } 182 | 183 | void THNN_Bin_SpatialConvolutionMM_updateOutput( 184 | THFloatTensor *input, 185 | THFloatTensor *output, 186 | THIntTensor *weight, 187 | THFloatTensor *bias, 188 | THFloatTensor *columns, 189 | THFloatTensor *alphas, 190 | int kH, int kW, 191 | int dH, int dW, 192 | int padH, int padW) 193 | { 194 | THIntTensor *bin_col = THIntTensor_new(); 195 | THFloatTensor *ones = THFloatTensor_new(); 196 | input = THFloatTensor_newContiguous(input); 197 | int ndim = input->nDimension; 198 | int dimf = 0; 199 | int dimh = 1; 200 | int dimw = 2; 201 | 202 | if (ndim == 4) { 203 | dimf++; 204 | dimh++; 205 | dimw++; 206 | } 207 | 208 | int64_t nInputPlane = input->size[dimf]; 209 | int64_t inputHeight = input->size[dimh]; 210 | int64_t inputWidth = input->size[dimw]; 211 | int64_t nOutputPlane = weight->size[0]; 212 | int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; 213 | int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; 214 | 215 | if (bias->nDimension ==1) { 216 | THFloatTensor_resize2d(bias, bias->size[0], 1); 217 | } 218 | 219 | 220 | THFloatTensor_resize2d(ones, 1, outputHeight*outputWidth); 221 | THFloatTensor_fill(ones, 1); 222 | 223 | int64_t T = input->size[0]; 224 | int64_t t; 225 | 226 | THFloatTensor_resize4d(output, T, nOutputPlane, outputHeight, outputWidth); 227 | THFloatTensor_resize3d(columns, T, kW*kH*nInputPlane, outputHeight*outputWidth); 228 | THIntTensor_resize3d(bin_col, T, weight->size[0], outputHeight*outputWidth); 229 | #pragma omp parallel for private(t) 230 | for(t = 0; t < T; t++) 231 | { 232 | THFloatTensor *input_t = THFloatTensor_newSelect(input, 0, t); 233 | THFloatTensor *columns_t = THFloatTensor_newSelect(columns, 0, t); 234 | THIntTensor *bin_col_t = THIntTensor_newSelect(bin_col, 0, t); 235 | 236 | THNN_unfolded_copy( 237 | columns_t, input_t, kW, kH, dW, dH, padW, padH, 238 | nInputPlane, inputWidth, inputHeight, outputWidth, outputHeight 239 | ); 240 | encode_cols_cpu(columns_t, bin_col_t); 241 | 242 | THFloatTensor_free(input_t); 243 | THFloatTensor_free(columns_t); 244 | THIntTensor_free(bin_col_t); 245 | } 246 | 247 | for(t = 0; t < T; t++){ 248 | THFloatTensor *output_t = THFloatTensor_newSelect(output, 0, t); 249 | THIntTensor *bin_col_t = THIntTensor_newSelect(bin_col, 0, t); 250 | 251 | THNN_Bin_SpatialConvolutionMM_updateOutput_frame( 252 | output_t, weight, bias, ones, bin_col_t, alphas, kW, kH, dW, dH, padW, padH, 253 | 
nInputPlane, inputWidth, inputHeight, nOutputPlane, outputWidth, outputHeight 254 | ); 255 | 256 | THFloatTensor_free(output_t); 257 | THIntTensor_free(bin_col_t); 258 | } 259 | THFloatTensor_free(input); 260 | THFloatTensor_free(ones); 261 | THIntTensor_free(bin_col); 262 | } 263 | -------------------------------------------------------------------------------- /csrc/binop/src/binop_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "binop_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | void binary_gemm(THCudaIntTensor* a, THCudaIntTensor* b, THCudaTensor* c, int m, int nn, int k, int transb, int alpha, int beta, THCudaTensor *alphas){ 8 | if (c->nDimension != 2 || c->size[0]*c->size[1] < m*k) { 9 | THCudaTensor_resize2d(state, c, m, k); 10 | } 11 | uint32_t *A = (uint32_t*)THCudaIntTensor_data(state, a); 12 | uint32_t *B = (uint32_t*)THCudaIntTensor_data(state, b); 13 | float *C = THCudaTensor_data(state, c); 14 | float *D = alpha? THCudaTensor_data(state, alphas) : NULL; 15 | cudaStream_t stream = THCState_getCurrentStream(state); 16 | 17 | binary_gemm_cuda(A, B, C, m, nn, k, transb, alpha, beta, D, stream); 18 | } 19 | 20 | void im2col(THCudaTensor* data_im, int channels, 21 | int height, int width, 22 | int ksize_h, int ksize_w, int pad_h, 23 | int pad_w, int stride_h, int stride_w, 24 | int dilation_h, int dilation_w, THCudaTensor* data_col) { 25 | // We are going to launch channels * height_col * width_col kernels, each 26 | // kernel responsible for copying a single-channel grid. 27 | int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; 28 | int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; 29 | int num_kernels = channels * height_col * width_col; 30 | 31 | float* d_im = THCudaTensor_data(state, data_im); 32 | float* d_col = THCudaTensor_data(state, data_col); 33 | cudaStream_t stream = THCState_getCurrentStream(state); 34 | 35 | im2col_cuda( 36 | num_kernels, d_im, height, width, ksize_h, ksize_w, 37 | pad_h, pad_w, stride_h, stride_w, 38 | dilation_h, dilation_w, 39 | height_col, width_col, d_col, stream 40 | ); 41 | 42 | THCudaCheck(cudaGetLastError()); 43 | } 44 | 45 | void encode_rows(THCudaTensor* input, THCudaIntTensor* output) { 46 | //THCUNN_assertSameGPU(state, 2, input, output); 47 | 48 | int m = input->size[0]; 49 | int n = input->size[1]; 50 | int l = 1+(n-1)/ENCODE_BITS; 51 | 52 | THCudaIntTensor_resize2d(state, output, m, l); 53 | float* a = THCudaTensor_data(state, input); 54 | uint32_t* b = (uint32_t*)THCudaIntTensor_data(state, output); 55 | cudaStream_t stream = THCState_getCurrentStream(state); 56 | 57 | encode_rows_cuda(a, b, m, n, l, stream); 58 | } 59 | 60 | void encode_cols(THCudaTensor* input, THCudaIntTensor* output) { 61 | //THCUNN_assertSameGPU(state, 2, input, output); 62 | 63 | int n = input->size[0]; 64 | int k = input->size[1]; 65 | int l = 1+(n-1)/ENCODE_BITS; 66 | 67 | THCudaIntTensor_resize2d(state, output, l, k); 68 | float* a = THCudaTensor_data(state, input); 69 | uint32_t* b = (uint32_t*)THCudaIntTensor_data(state, output); 70 | cudaStream_t stream = THCState_getCurrentStream(state); 71 | 72 | encode_cols_cuda(a, b, n, k, stream); 73 | } 74 | 75 | 76 | // Based on the torch SpatialConvolutionMM_updateOutput 77 | void BinarySpatialConvolution_updateOutput( 78 | THCudaTensor *input, 79 | THCudaTensor *output, 80 | THCudaIntTensor *weight, 81 | THCudaTensor *columns, 82 | THCudaTensor *bias, 
83 | THCudaTensor *alphas, 84 | int nInputPlane, 85 | int kH, int kW, 86 | int sH, int sW, 87 | int padH, int padW) { 88 | 89 | //THCUNN_assertSameGPU(state, 5, input, output, weight, columns, columns_binary); 90 | 91 | // Params: 92 | // int nInputPlane = weight->size[1]; 93 | int nOutputPlane = weight->size[0]; 94 | 95 | input = THCudaTensor_newContiguous(state, input); 96 | int batch = 1; 97 | if (input->nDimension == 3) { 98 | // Force batch 99 | batch = 0; 100 | THCudaTensor_resize4d(state, input, 1, input->size[0], input->size[1], input->size[2]); 101 | } 102 | 103 | int64_t inputWidth = input->size[3]; 104 | int64_t inputHeight = input->size[2]; 105 | int64_t outputWidth = (inputWidth + 2*padW - kW) / sW + 1; 106 | int64_t outputHeight = (inputHeight + 2*padH - kH) / sH + 1; 107 | 108 | // Batch size + input planes 109 | int64_t batchSize = input->size[0]; 110 | 111 | // Resize output 112 | THCudaTensor_resize4d(state, output, batchSize, nOutputPlane, outputHeight, outputWidth); 113 | 114 | // Resize temporary columns 115 | THCudaTensor_resize2d(state, columns, nInputPlane*kW*kH, outputHeight*outputWidth); 116 | 117 | // Define a buffer of ones, for bias accumulation 118 | // Note: this buffer can be shared with other modules, it only ever gets increased, 119 | // and always contains ones. 120 | THCudaTensor *ones = THCudaTensor_new(state); 121 | THCudaTensor_resize2d(state, ones, outputHeight, outputWidth); 122 | THCudaTensor_fill(state, ones, 1); 123 | 124 | THCudaIntTensor *columns_binary = THCudaIntTensor_new(state); 125 | THCudaIntTensor_resize2d(state, columns_binary, weight->size[1], outputHeight*outputWidth); 126 | 127 | // Helpers 128 | THCudaTensor *input_n = THCudaTensor_new(state); 129 | THCudaTensor *output_n = THCudaTensor_new(state); 130 | 131 | // For each elt in batch, do: 132 | for (int elt = 0; elt < batchSize; elt ++) { 133 | // Matrix mulitply per output: 134 | THCudaTensor_select(state, input_n, input, 0, elt); 135 | THCudaTensor_select(state, output_n, output, 0, elt); 136 | 137 | // Do Bias first: 138 | // M,N,K are dims of matrix A and B 139 | // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) 140 | int64_t m_ = nOutputPlane; 141 | int64_t n_ = outputHeight * outputWidth; 142 | int64_t k_ = 1; 143 | 144 | // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) 145 | if (bias->nDimension) { 146 | THCudaBlas_Sgemm( 147 | state, 148 | 't', 'n', 149 | n_, m_, k_, 150 | 1, 151 | THCudaTensor_data(state, ones), k_, 152 | THCudaTensor_data(state, bias), k_, 153 | 0, 154 | THCudaTensor_data(state, output_n), n_ 155 | ); 156 | } else { 157 | THCudaTensor_zero(state, output_n); 158 | } 159 | 160 | // Extract columns: 161 | im2col(input_n, nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, sH, sW, 1, 1, columns); 162 | 163 | // M,N,K are dims of matrix A and B 164 | // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) 165 | // row-major to column-major change 166 | int m = weight->size[0]; 167 | //int n = weight->size[1]; 168 | int k = columns->size[1]; 169 | 170 | encode_cols(columns, columns_binary); 171 | binary_gemm(weight, columns_binary, output_n, m, nInputPlane*kW*kH, k, 0, 1, 1, alphas); 172 | } 173 | 174 | if (batch==0) { 175 | THCudaTensor_resize3d(state, output, nOutputPlane, outputHeight, outputWidth); 176 | THCudaTensor_resize3d(state, input, nInputPlane, inputHeight, inputWidth); 177 | } 178 | 179 | // Free 180 | THCudaTensor_free(state, input_n); 181 | THCudaTensor_free(state, output_n); 182 
| THCudaTensor_free(state, ones); 183 | 184 | THCudaTensor_free(state, input); 185 | THCudaTensor_free(state, columns); 186 | THCudaIntTensor_free(state, columns_binary); 187 | } 188 | 189 | 190 | -------------------------------------------------------------------------------- /csrc/binop/src/binop_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "binop_cuda_kernel.h" 4 | 5 | int GET_BLOCKS(int N){ 6 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 7 | } 8 | 9 | __global__ void binary_gemm_kernel(uint32_t* A, uint32_t* B, float* C, int m, int nn, int k, int transb, int alpha, int beta, float *alphas) { 10 | int blockRow = blockIdx.y; 11 | int blockCol = blockIdx.x; 12 | 13 | int row = threadIdx.y; 14 | int col = threadIdx.x; 15 | 16 | int n = 1 + (nn-1)/ENCODE_BITS; 17 | int startLocation = BLOCK_SIZE * k * blockRow + BLOCK_SIZE * blockCol; 18 | 19 | float* Csub = &C[BLOCK_SIZE * k * blockRow + BLOCK_SIZE * blockCol]; 20 | 21 | __shared__ uint32_t As[BLOCK_SIZE][BLOCK_SIZE]; 22 | __shared__ uint32_t Bs[BLOCK_SIZE][BLOCK_SIZE]; 23 | 24 | int Cvalue = 0; 25 | 26 | int c = blockIdx.x*blockDim.x + threadIdx.x; 27 | int r = blockIdx.y*blockDim.y + threadIdx.y; 28 | int lim = 1+( (n-1) / BLOCK_SIZE); 29 | for (int i = 0; i < lim; ++i) { 30 | 31 | // Get sub-matrix Asub of A 32 | uint32_t* Asub = &A[BLOCK_SIZE * blockRow * n + BLOCK_SIZE * i]; 33 | 34 | // Get sub-matrix Bsub of B 35 | uint32_t* Bsub = transb? &B[BLOCK_SIZE * blockCol * n + BLOCK_SIZE * i] : &B[BLOCK_SIZE * k * i + BLOCK_SIZE * blockCol]; 36 | 37 | if ((BLOCK_SIZE*i+col)= 0 && w >= 0 && h < height && w < width) ? 79 | data_im[i * dilation_h * width + j * dilation_w] : 0; 80 | data_col += height_col * width_col; 81 | } 82 | } 83 | } 84 | } 85 | 86 | __forceinline__ __device__ uint32_t encode_val(float* array, int n) { 87 | uint32_t r = 0; 88 | for(int i=0; i0)<0)<>>(A, B, C, m, n, k, transb, alpha, beta, alphas); 117 | } 118 | 119 | void im2col_cuda(int n, float* data_im, int height, int width, 120 | int ksize_h, int ksize_w, int pad_h, int pad_w, 121 | int stride_h, int stride_w, int dilation_h, int dilation_w, 122 | int height_col, int width_col, float* data_col, cudaStream_t stream){ 123 | im2col_kernel <<< GET_BLOCKS(n), CUDA_NUM_THREADS, 0, stream >>> ( 124 | n, data_im, height, width, ksize_h, ksize_w, 125 | pad_h, pad_w, stride_h, stride_w, 126 | dilation_h, dilation_w, 127 | height_col, width_col, data_col 128 | ); 129 | } 130 | 131 | void encode_rows_cuda(float* input, uint32_t* output, int m, int n, int l, cudaStream_t stream) { 132 | encode_rows_kernel <<< GET_BLOCKS(m*l), CUDA_NUM_THREADS, 0, stream >>>(input, output, m, n, l); 133 | } 134 | 135 | void encode_cols_cuda(float* input, uint32_t* output, int n, int k, cudaStream_t stream) { 136 | dim3 blockDim(ENCODE_BITS, ENCODE_BITS, 1); 137 | dim3 gridDim(k/ENCODE_BITS+1, n/ENCODE_BITS+1, 1); 138 | 139 | encode_cols_kernel <<< gridDim, blockDim, 0, stream >>>(input, output, n, k); 140 | } 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import binop 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional 
as F 5 | from torch.autograd import Variable 6 | from torch.autograd import Function 7 | 8 | 9 | def bin_save_state(args, model): 10 | print('==> Binarizing and Saving model ...') 11 | state = model.state_dict() 12 | weight_ = [] 13 | for key in state.keys(): 14 | if 'weight' in key and 'bn' not in key: 15 | weight_.append((key, state.get(key))) 16 | 17 | # except the first and last layer 18 | weight_.pop(0) 19 | weight_.pop() 20 | 21 | for key, weight in weight_: 22 | s = weight.size() 23 | if len(s) == 4: 24 | weight = weight.view(s[0], s[1] * s[2] * s[3]) 25 | 26 | if args.cuda: 27 | bin_weight = torch.cuda.IntTensor() 28 | binop.encode_rows(weight, bin_weight) 29 | else: 30 | bin_weight = torch.IntTensor() 31 | binop.encode_rows_cpu(weight, bin_weight) 32 | 33 | state[key] = bin_weight 34 | torch.save(state, 'models/' + args.arch + '.pth') 35 | 36 | 37 | def bin_conv2d(input, weight, bias, alpha, kernel_size, stride, padding): 38 | out_tensor = torch.FloatTensor() 39 | col_tensor = torch.FloatTensor() 40 | use_cuda = input.is_cuda 41 | if use_cuda: 42 | out_tensor = out_tensor.cuda() 43 | col_tensor = col_tensor.cuda() 44 | output = Variable(out_tensor, requires_grad=False) 45 | if bias is None: 46 | if use_cuda: 47 | bias = Variable(torch.cuda.FloatTensor(), requires_grad=False) 48 | else: 49 | bias = Variable(torch.FloatTensor(), requires_grad=False) 50 | if use_cuda: 51 | binop.BinarySpatialConvolution_updateOutput( 52 | input.data, output.data, weight.data, col_tensor, bias.data, alpha.data, 53 | input.data.shape[1], kernel_size[0], kernel_size[1], stride[0], stride[1], padding[0], padding[1] 54 | ) 55 | else: 56 | binop.THNN_Bin_SpatialConvolutionMM_updateOutput( 57 | input.data, output.data, weight.data, bias.data, col_tensor, alpha.data, 58 | kernel_size[0], kernel_size[1], stride[0], stride[1], padding[0], padding[1] 59 | ) 60 | return output 61 | 62 | 63 | def bin_linear(input, weight, bias, alpha): 64 | m = input.data.shape[0] 65 | n = input.data.shape[1] 66 | k = weight.data.shape[0] 67 | out_tensor = torch.FloatTensor() 68 | bin_input = torch.IntTensor() 69 | use_cuda = input.is_cuda 70 | 71 | if use_cuda: 72 | bin_input = bin_input.cuda() 73 | out_tensor = out_tensor.cuda() 74 | 75 | output = Variable(out_tensor, requires_grad=False) 76 | if use_cuda: 77 | binop.encode_rows(input.data, bin_input) 78 | binop.binary_gemm(bin_input, weight.data, output.data, m, n, k, 1, 0, 0, alpha.data) 79 | else: 80 | binop.encode_rows_cpu(input.data, bin_input) 81 | binop.binary_gemm_cpu(bin_input, weight.data, output.data, m, n, k, 1, 0, 0, alpha.data) 82 | output.data.mul_(alpha.data.t().expand(output.shape)) 83 | if bias is not None: 84 | output.data.add_(bias.data.expand(output.shape)) 85 | return output 86 | 87 | 88 | class BinActive(Function): 89 | @staticmethod 90 | def forward(self, input): 91 | self.save_for_backward(input) 92 | input = input.sign() 93 | return input 94 | 95 | @staticmethod 96 | def backward(self, grad_output): 97 | input, = self.saved_tensors 98 | grad_input = grad_output.clone() 99 | grad_input[input.ge(1)] = 0 100 | grad_input[input.le(-1)] = 0 101 | return grad_input 102 | 103 | 104 | class BinConv2d(nn.Conv2d): 105 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, istrain=True, drop=0): 106 | super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) 107 | 108 | self.alpha = nn.Parameter(torch.FloatTensor(out_channels, 1, 1, 1), requires_grad=False) 109 | 
self.istrain = istrain 110 | self.bn = nn.BatchNorm2d(in_channels) 111 | self.dropout_ratio = drop 112 | 113 | if drop != 0: 114 | self.drop = nn.Dropout(drop) 115 | if not istrain: 116 | self.weight = nn.Parameter(torch.IntTensor(out_channels, 1 + ( in_channels * self.kernel_size[0] * self.kernel_size[1] - 1) // 32)) 117 | 118 | def forward(self, input): 119 | input = self.bn(input) 120 | if self.istrain: 121 | input = BinActive.apply(input) 122 | if self.dropout_ratio != 0: 123 | input = self.drop(input) 124 | input = F.conv2d(input, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding) 125 | else: 126 | input = bin_conv2d(input, self.weight, self.bias, self.alpha, self.kernel_size, self.stride, self.padding) 127 | return input 128 | 129 | 130 | class BinLinear(nn.Linear): 131 | def __init__(self, in_features, out_features, bias=True, istrain=True, drop=0): 132 | super().__init__(in_features, out_features, bias) 133 | 134 | self.alpha = nn.Parameter(torch.FloatTensor(out_features, 1), requires_grad=False) 135 | self.istrain = istrain 136 | self.bn = nn.BatchNorm1d(in_features) 137 | self.dropout_ratio = drop 138 | if drop != 0: 139 | self.drop = nn.Dropout(drop) 140 | if not istrain: 141 | self.weight = nn.Parameter(torch.IntTensor(out_features, 1 + (in_features - 1) // 32)) 142 | 143 | def forward(self, input): 144 | input = self.bn(input) 145 | if self.istrain: 146 | input = BinActive.apply(input) 147 | if self.dropout_ratio != 0: 148 | input = self.drop(input) 149 | input = F.linear(input, weight=self.weight, bias=self.bias) 150 | else: 151 | input = bin_linear(input, weight=self.weight, bias=self.bias, alpha=self.alpha) 152 | return input 153 | 154 | 155 | class binop_train: 156 | def __init__(self, model): 157 | self.alpha_to_save = [] 158 | self.saved_params = [] 159 | self.target_modules = [] 160 | for m in model.modules(): 161 | if type(m).__name__ in ['BinConv2d', 'BinLinear']: 162 | tmp = m.weight.data.clone() 163 | self.saved_params.append(tmp) 164 | self.target_modules.append(m.weight) 165 | self.alpha_to_save.append(m.alpha) 166 | self.num_of_params = len(self.target_modules) 167 | 168 | def binarization(self): 169 | for index in range(self.num_of_params): 170 | n = self.target_modules[index].data[0].nelement() 171 | s = self.target_modules[index].data.size() 172 | 173 | # meancenter 174 | negMean = self.target_modules[index].data.mean(1, keepdim=True).mul(-1).expand_as( 175 | self.target_modules[index].data) 176 | self.target_modules[index].data.add_(negMean) 177 | # clamp 178 | self.target_modules[index].data.clamp_(-1.0, 1.0) 179 | # save param 180 | self.saved_params[index].copy_(self.target_modules[index].data) 181 | 182 | # get alpha, binarize weight and mutiply alpha 183 | if len(s) == 4: 184 | self.alpha_to_save[index].data = \ 185 | self.target_modules[index].data.norm(1, 3, keepdim=True).sum(2, keepdim=True).sum(1, 186 | keepdim=True).div( 187 | n) 188 | elif len(s) == 2: 189 | self.alpha_to_save[index].data = \ 190 | self.target_modules[index].data.norm(1, 1, keepdim=True).div(n) 191 | self.target_modules[index].data.sign().mul( 192 | self.alpha_to_save[index].data.expand(s), out=self.target_modules[index].data) 193 | 194 | def restore(self): 195 | for index in range(self.num_of_params): 196 | self.target_modules[index].data.copy_(self.saved_params[index]) 197 | 198 | def updateBinaryGradWeight(self): 199 | for index in range(self.num_of_params): 200 | weight = self.target_modules[index].data 201 | alpha = 
self.alpha_to_save[index].data.clone() 202 | n = weight[0].nelement() 203 | s = weight.size() 204 | alpha = alpha.expand(s) 205 | alpha[weight.le(-1.0)] = 0 206 | alpha[weight.ge(1.0)] = 0 207 | alpha = alpha.mul(self.target_modules[index].grad.data) 208 | add = weight.sign().mul(self.target_modules[index].grad.data) 209 | if len(s) == 4: 210 | add = add.sum(3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True).div(n).expand(s) 211 | elif len(s) == 2: 212 | add = add.sum(1, keepdim=True).div(n).expand(s) 213 | add = add.mul(weight.sign()) 214 | self.target_modules[index].grad.data = alpha.add(add).mul(1.0 - 1.0 / s[1]).mul(n) 215 | --------------------------------------------------------------------------------
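Notes on the binary kernels above.

The portable fallback popcount64() in csrc/binop/include/libpopcnt.h is the classic SWAR bit-counting trick referenced by the Wikipedia link above it. A minimal pure-Python sketch of the same constant-mask algorithm, checked against Python's own bit counting (the name swar_popcount64 is mine, for illustration only):

import random

# Same 64-bit masks as popcount64() in libpopcnt.h; "& MASK64" emulates
# the wrap-around of unsigned 64-bit arithmetic in C.
MASK64 = (1 << 64) - 1
M1, M2, M4 = 0x5555555555555555, 0x3333333333333333, 0x0F0F0F0F0F0F0F0F
H01 = 0x0101010101010101

def swar_popcount64(x: int) -> int:
    x &= MASK64
    x -= (x >> 1) & M1                  # bit counts in 2-bit chunks
    x = (x & M2) + ((x >> 2) & M2)      # 4-bit chunks
    x = (x + (x >> 4)) & M4             # 8-bit chunks
    return ((x * H01) & MASK64) >> 56   # sum all bytes into the top byte

if __name__ == "__main__":
    for _ in range(1000):
        v = random.getrandbits(64)
        assert swar_popcount64(v) == bin(v).count("1")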
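popcnt_avx2() and popcnt_avx512() implement the Harley-Seal scheme from the cited Lemire/Kurz/Mula paper: a carry-save adder (CSA256/CSA512) folds 16 input words per round into ones/twos/fours/eights accumulators, so the relatively expensive per-word popcount runs only once per 16 words, on the weight-16 output. The sketch below replays the same bookkeeping on plain Python integers, with bin(v).count("1") standing in for the SIMD popcount; it illustrates the algorithm, not the intrinsics:

import random

def popcount(v: int) -> int:
    return bin(v).count("1")   # stand-in for popcnt256()/popcnt512()

def csa(a: int, b: int, c: int):
    # Carry-save adder: the same boolean identities computed by CSA256()/CSA512().
    u = a ^ b
    return (a & b) | (u & c), u ^ c        # (carries, partial sums)

def harley_seal_popcount(words):
    # Total popcount of a list of 64-bit words, 16 words per CSA round,
    # mirroring the accumulator scheme of popcnt_avx2()/popcnt_avx512().
    ones = twos = fours = eights = 0
    total = 0
    limit = len(words) - len(words) % 16
    for i in range(0, limit, 16):
        twos_a, ones = csa(ones, words[i + 0], words[i + 1])
        twos_b, ones = csa(ones, words[i + 2], words[i + 3])
        fours_a, twos = csa(twos, twos_a, twos_b)
        twos_a, ones = csa(ones, words[i + 4], words[i + 5])
        twos_b, ones = csa(ones, words[i + 6], words[i + 7])
        fours_b, twos = csa(twos, twos_a, twos_b)
        eights_a, fours = csa(fours, fours_a, fours_b)
        twos_a, ones = csa(ones, words[i + 8], words[i + 9])
        twos_b, ones = csa(ones, words[i + 10], words[i + 11])
        fours_a, twos = csa(twos, twos_a, twos_b)
        twos_a, ones = csa(ones, words[i + 12], words[i + 13])
        twos_b, ones = csa(ones, words[i + 14], words[i + 15])
        fours_b, twos = csa(twos, twos_a, twos_b)
        eights_b, fours = csa(fours, eights_a * 0 + fours_a, fours_b)[0], csa(fours, fours_a, fours_b)[1] if False else fours  # placeholder, see below
        eights_b, fours = csa(fours, fours_a, fours_b)
        sixteens, eights = csa(eights, eights_a, eights_b)
        total += 16 * popcount(sixteens)   # bits produced here carry weight 16
    total += 8 * popcount(eights) + 4 * popcount(fours) + 2 * popcount(twos) + popcount(ones)
    total += sum(popcount(w) for w in words[limit:])   # leftover words, weight 1
    return total

if __name__ == "__main__":
    data = [random.getrandbits(64) for _ in range(100)]
    assert harley_seal_popcount(data) == sum(popcount(w) for w in data)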
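On the binary-network side, encode_rows_cpu()/encode_cols_cpu() in csrc/binop/src/binop.c (and their CUDA counterparts) pack the sign of each float into one bit, ENCODE_BIT = 32 signs per uint32, so that a weight row or an unfolded input column becomes a short array of words. A deliberately simple NumPy sketch of the row-packing convention, bit k of word j set iff element 32*j + k is positive, matching encode_val() as far as it survives in the dump; the function name is mine:

import numpy as np

ENCODE_BIT = 32  # bits packed per uint32 word, as in matmul.h / binop.c

def encode_rows(x: np.ndarray) -> np.ndarray:
    # Pack the signs of a 2-D float array row-wise into uint32 words.
    # Didactic O(m*n) version; the C kernel does the same thing word by word.
    m, n = x.shape
    words = (n + ENCODE_BIT - 1) // ENCODE_BIT
    out = np.zeros((m, words), dtype=np.uint32)
    for i in range(m):
        for j in range(n):
            if x[i, j] > 0:
                out[i, j // ENCODE_BIT] |= np.uint32(1 << (j % ENCODE_BIT))
    return out

if __name__ == "__main__":
    w = np.random.randn(4, 70).astype(np.float32)
    packed = encode_rows(w)
    assert packed.shape == (4, 3)   # 70 signs fit into 3 uint32 words per row
    # Bit 5 of the first word corresponds to the sign of element 5 of row 0.
    assert bool((packed[0, 0] >> 5) & 1) == (w[0, 5] > 0)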
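With both operands packed this way, binary_gemm_cpu() and the binary_gemm_kernel never multiply floats: the dot product of two {-1,+1} vectors reduces to an XNOR plus a population count, which is exactly what libpopcnt.h is vendored for, and the alphas tensor then supplies the per-output-channel XNOR-Net scale. A self-contained sketch of that identity; the tail-word masking here is my own simplification of the MASK macro in matmul.h:

import random

ENCODE_BIT = 32

def pack_signs(values):
    # Pack a list of floats into words, bit set where the value is > 0.
    words = [0] * ((len(values) + ENCODE_BIT - 1) // ENCODE_BIT)
    for i, v in enumerate(values):
        if v > 0:
            words[i // ENCODE_BIT] |= 1 << (i % ENCODE_BIT)
    return words

def xnor_dot(wa, wb, n):
    # Dot product of two length-n {-1,+1} vectors given their packed words.
    matches = 0
    for j, (a, b) in enumerate(zip(wa, wb)):
        valid = n - j * ENCODE_BIT                    # real bits left in this word
        mask = (1 << min(valid, ENCODE_BIT)) - 1
        matches += bin(~(a ^ b) & mask).count("1")    # XNOR + popcount
    return 2 * matches - n                            # +1 per agreement, -1 per disagreement

if __name__ == "__main__":
    n = 70
    a = [random.uniform(-1, 1) for _ in range(n)]
    b = [random.uniform(-1, 1) for _ in range(n)]
    sa = [1 if v > 0 else -1 for v in a]
    sb = [1 if v > 0 else -1 for v in b]
    assert xnor_dot(pack_signs(a), pack_signs(b), n) == sum(x * y for x, y in zip(sa, sb))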
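THNN_Bin_SpatialConvolutionMM_updateOutput (CPU) and BinarySpatialConvolution_updateOutput (CUDA) both follow the usual SpatialConvolutionMM recipe: unfold the input into columns (THNN_unfolded_copy / im2col), pack the column signs, and run one binary GEMM per image, with the bias added through a rank-1 GEMM against a vector of ones. The float sketch below, written for a recent PyTorch, reproduces that data flow and the alpha scaling with ordinary ops standing in for the packed kernels; function and variable names are mine:

import torch
import torch.nn.functional as F

def bin_conv2d_reference(x, weight, bias, alpha, stride=1, padding=1):
    # x: (N, C, H, W), weight: (O, C, kH, kW), alpha: (O,) per-filter scale.
    n, c, h, w = x.shape
    o, _, kh, kw = weight.shape
    cols = F.unfold(x, (kh, kw), stride=stride, padding=padding)  # (N, C*kH*kW, L)
    w2d = weight.sign().reshape(o, -1)                            # binarised filters
    out = torch.einsum('ok,nkl->nol', w2d, cols.sign())           # binary-GEMM stand-in
    out = out * alpha.view(1, o, 1)                               # XNOR-Net scaling
    if bias is not None:
        out = out + bias.view(1, o, 1)                            # rank-1 bias term
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    return out.reshape(n, o, oh, ow)

if __name__ == "__main__":
    x = torch.randn(2, 8, 16, 16)
    wgt = torch.randn(4, 8, 3, 3)
    alpha = wgt.abs().mean(dim=(1, 2, 3))
    ref = F.conv2d(x.sign(), wgt.sign() * alpha.view(-1, 1, 1, 1), padding=1)
    out = bin_conv2d_reference(x, wgt, None, alpha)
    assert torch.allclose(out, ref, atol=1e-4)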
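In util/util.py, binop_train.binarization() applies the XNOR-Net weight transform before each forward pass: mean-center and clamp each filter, stash the real-valued weights, set alpha to the mean absolute value of the filter, and overwrite the weights with alpha * sign(W); restore() puts the real weights back after the backward pass. The core transform as a standalone sketch for a 4-D conv weight (names are mine; this shows the math, not the bookkeeping class):

import torch

def binarize_filters(w: torch.Tensor):
    # Returns (w_bin, alpha) with w_bin = alpha * sign(centered, clamped w) and
    # alpha the per-output-filter mean absolute value, as in binop_train.binarization().
    n = w[0].nelement()
    w = w - w.mean(dim=1, keepdim=True)                    # mean-center over input channels
    w = w.clamp(-1.0, 1.0)
    alpha = w.abs().sum(dim=(1, 2, 3), keepdim=True) / n   # L1 norm / n, per filter
    return alpha * w.sign(), alpha

if __name__ == "__main__":
    w = torch.randn(4, 8, 3, 3)
    w_bin, alpha = binarize_filters(w)
    # Every non-zero entry of filter o has magnitude alpha[o].
    nz = w_bin != 0
    assert torch.allclose(w_bin.abs()[nz], alpha.expand_as(w_bin)[nz])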
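Finally, updateBinaryGradWeight() maps the gradient computed for the binarized weights back onto the real-valued weights: the alpha term is passed straight through but clipped outside [-1, 1], the sign(W) term accounts for alpha's own dependence on W, and the result is rescaled by (1 - 1/s[1]) * n. The same arithmetic for a single conv weight tensor, as a hypothetical standalone function mirroring the loop body:

import torch

def binary_grad(weight: torch.Tensor, grad_bin: torch.Tensor, alpha: torch.Tensor):
    # weight: (O, C, kH, kW) real weights, grad_bin: gradient w.r.t. alpha*sign(w),
    # alpha: (O, 1, 1, 1). Mirrors one iteration of updateBinaryGradWeight().
    n = weight[0].nelement()
    s = weight.size()
    a = alpha.expand(s).clone()
    a[weight.le(-1.0)] = 0                   # clip where sign() has zero gradient
    a[weight.ge(1.0)] = 0
    term1 = a * grad_bin                     # alpha * straight-through gradient
    term2 = (weight.sign() * grad_bin).sum(dim=(1, 2, 3), keepdim=True).div(n)
    term2 = term2.expand(s) * weight.sign()  # contribution of d(alpha)/d(w)
    return (term1 + term2) * (1.0 - 1.0 / s[1]) * n

if __name__ == "__main__":
    w = torch.randn(4, 8, 3, 3)
    g = torch.randn_like(w)
    alpha = w.abs().mean(dim=(1, 2, 3), keepdim=True)
    print(binary_grad(w, g, alpha).shape)    # torch.Size([4, 8, 3, 3])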