├── README.md ├── data.py ├── main_binary.py ├── main_binary_hinge.py ├── main_mnist.py ├── models ├── __init__.py ├── alexnet.py ├── alexnet_binary.py ├── binarized_modules.py ├── resnet.py ├── resnet_binary.py ├── vgg_cifar10.py └── vgg_cifar10_binary.py ├── preprocess.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # BNN.pytorch 2 | Binarized Neural Network (BNN) for pytorch 3 | This is the pytorch version for the BNN code, fro VGG and resnet models 4 | Link to the paper: https://papers.nips.cc/paper/6573-binarized-neural-networks 5 | 6 | The code is based on https://github.com/eladhoffer/convNet.pytorch 7 | Please install torch and torchvision by following the instructions at: http://pytorch.org/ 8 | To run resnet18 for cifar10 dataset use: python main_binary.py --model resnet_binary --save resnet18_binary --dataset cifar10 9 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | 5 | _DATASETS_MAIN_PATH = '/home/Datasets' 6 | _dataset_path = { 7 | 'cifar10': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR10'), 8 | 'cifar100': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR100'), 9 | 'stl10': os.path.join(_DATASETS_MAIN_PATH, 'STL10'), 10 | 'mnist': os.path.join(_DATASETS_MAIN_PATH, 'MNIST'), 11 | 'imagenet': { 12 | 'train': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/train'), 13 | 'val': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/val') 14 | } 15 | } 16 | 17 | 18 | def get_dataset(name, split='train', transform=None, 19 | target_transform=None, download=True): 20 | train = (split == 'train') 21 | if name == 'cifar10': 22 | return datasets.CIFAR10(root=_dataset_path['cifar10'], 23 | train=train, 24 | transform=transform, 25 | target_transform=target_transform, 26 | download=download) 27 
| elif name == 'cifar100': 28 | return datasets.CIFAR100(root=_dataset_path['cifar100'], 29 | train=train, 30 | transform=transform, 31 | target_transform=target_transform, 32 | download=download) 33 | elif name == 'imagenet': 34 | path = _dataset_path[name][split] 35 | return datasets.ImageFolder(root=path, 36 | transform=transform, 37 | target_transform=target_transform) 38 | -------------------------------------------------------------------------------- /main_binary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | from torch.autograd import Variable 13 | from data import get_dataset 14 | from preprocess import get_transform 15 | from utils import * 16 | from datetime import datetime 17 | from ast import literal_eval 18 | from torchvision.utils import save_image 19 | 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 26 | 27 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', 28 | help='results dir') 29 | parser.add_argument('--save', metavar='SAVE', default='', 30 | help='saved folder') 31 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 32 | help='dataset name or folder') 33 | parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: alexnet)') 38 | parser.add_argument('--input_size', type=int, default=None, 39 | help='image input size') 40 | parser.add_argument('--model_config', default='', 41 | 
help='additional architecture configuration') 42 | parser.add_argument('--type', default='torch.cuda.FloatTensor', 43 | help='type of tensor - e.g torch.cuda.HalfTensor') 44 | parser.add_argument('--gpus', default='0', 45 | help='gpus used for training - e.g 0,1,3') 46 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 47 | help='number of data loading workers (default: 8)') 48 | parser.add_argument('--epochs', default=2500, type=int, metavar='N', 49 | help='number of total epochs to run') 50 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 51 | help='manual epoch number (useful on restarts)') 52 | parser.add_argument('-b', '--batch-size', default=256, type=int, 53 | metavar='N', help='mini-batch size (default: 256)') 54 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 55 | help='optimizer function used') 56 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 57 | metavar='LR', help='initial learning rate') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)') 62 | parser.add_argument('--print-freq', '-p', default=10, type=int, 63 | metavar='N', help='print frequency (default: 10)') 64 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 65 | help='path to latest checkpoint (default: none)') 66 | parser.add_argument('-e', '--evaluate', type=str, metavar='FILE', 67 | help='evaluate model FILE on validation set') 68 | 69 | 70 | def main(): 71 | global args, best_prec1 72 | best_prec1 = 0 73 | args = parser.parse_args() 74 | 75 | if args.evaluate: 76 | args.results_dir = '/tmp' 77 | if args.save is '': 78 | args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 79 | save_path = os.path.join(args.results_dir, args.save) 80 | if not os.path.exists(save_path): 81 | 
os.makedirs(save_path) 82 | 83 | setup_logging(os.path.join(save_path, 'log.txt')) 84 | results_file = os.path.join(save_path, 'results.%s') 85 | results = ResultsLog(results_file % 'csv', results_file % 'html') 86 | 87 | logging.info("saving to %s", save_path) 88 | logging.debug("run arguments: %s", args) 89 | 90 | if 'cuda' in args.type: 91 | args.gpus = [int(i) for i in args.gpus.split(',')] 92 | torch.cuda.set_device(args.gpus[0]) 93 | cudnn.benchmark = True 94 | else: 95 | args.gpus = None 96 | 97 | # create model 98 | logging.info("creating model %s", args.model) 99 | model = models.__dict__[args.model] 100 | model_config = {'input_size': args.input_size, 'dataset': args.dataset} 101 | 102 | if args.model_config is not '': 103 | model_config = dict(model_config, **literal_eval(args.model_config)) 104 | 105 | model = model(**model_config) 106 | logging.info("created model with configuration: %s", model_config) 107 | 108 | # optionally resume from a checkpoint 109 | if args.evaluate: 110 | if not os.path.isfile(args.evaluate): 111 | parser.error('invalid checkpoint: {}'.format(args.evaluate)) 112 | checkpoint = torch.load(args.evaluate) 113 | model.load_state_dict(checkpoint['state_dict']) 114 | logging.info("loaded checkpoint '%s' (epoch %s)", 115 | args.evaluate, checkpoint['epoch']) 116 | elif args.resume: 117 | checkpoint_file = args.resume 118 | if os.path.isdir(checkpoint_file): 119 | results.load(os.path.join(checkpoint_file, 'results.csv')) 120 | checkpoint_file = os.path.join( 121 | checkpoint_file, 'model_best.pth.tar') 122 | if os.path.isfile(checkpoint_file): 123 | logging.info("loading checkpoint '%s'", args.resume) 124 | checkpoint = torch.load(checkpoint_file) 125 | args.start_epoch = checkpoint['epoch'] - 1 126 | best_prec1 = checkpoint['best_prec1'] 127 | model.load_state_dict(checkpoint['state_dict']) 128 | logging.info("loaded checkpoint '%s' (epoch %s)", 129 | checkpoint_file, checkpoint['epoch']) 130 | else: 131 | logging.error("no 
checkpoint found at '%s'", args.resume) 132 | 133 | num_parameters = sum([l.nelement() for l in model.parameters()]) 134 | logging.info("number of parameters: %d", num_parameters) 135 | 136 | # Data loading code 137 | default_transform = { 138 | 'train': get_transform(args.dataset, 139 | input_size=args.input_size, augment=True), 140 | 'eval': get_transform(args.dataset, 141 | input_size=args.input_size, augment=False) 142 | } 143 | transform = getattr(model, 'input_transform', default_transform) 144 | regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer, 145 | 'lr': args.lr, 146 | 'momentum': args.momentum, 147 | 'weight_decay': args.weight_decay}}) 148 | # define loss function (criterion) and optimizer 149 | criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)() 150 | criterion.type(args.type) 151 | model.type(args.type) 152 | 153 | val_data = get_dataset(args.dataset, 'val', transform['eval']) 154 | val_loader = torch.utils.data.DataLoader( 155 | val_data, 156 | batch_size=args.batch_size, shuffle=False, 157 | num_workers=args.workers, pin_memory=True) 158 | 159 | if args.evaluate: 160 | validate(val_loader, model, criterion, 0) 161 | return 162 | 163 | train_data = get_dataset(args.dataset, 'train', transform['train']) 164 | train_loader = torch.utils.data.DataLoader( 165 | train_data, 166 | batch_size=args.batch_size, shuffle=True, 167 | num_workers=args.workers, pin_memory=True) 168 | 169 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 170 | logging.info('training regime: %s', regime) 171 | 172 | 173 | for epoch in range(args.start_epoch, args.epochs): 174 | optimizer = adjust_optimizer(optimizer, epoch, regime) 175 | 176 | # train for one epoch 177 | train_loss, train_prec1, train_prec5 = train( 178 | train_loader, model, criterion, epoch, optimizer) 179 | 180 | # evaluate on validation set 181 | val_loss, val_prec1, val_prec5 = validate( 182 | val_loader, model, criterion, epoch) 183 | 184 | # remember best prec@1 and 
save checkpoint 185 | is_best = val_prec1 > best_prec1 186 | best_prec1 = max(val_prec1, best_prec1) 187 | 188 | save_checkpoint({ 189 | 'epoch': epoch + 1, 190 | 'model': args.model, 191 | 'config': args.model_config, 192 | 'state_dict': model.state_dict(), 193 | 'best_prec1': best_prec1, 194 | 'regime': regime 195 | }, is_best, path=save_path) 196 | logging.info('\n Epoch: {0}\t' 197 | 'Training Loss {train_loss:.4f} \t' 198 | 'Training Prec@1 {train_prec1:.3f} \t' 199 | 'Training Prec@5 {train_prec5:.3f} \t' 200 | 'Validation Loss {val_loss:.4f} \t' 201 | 'Validation Prec@1 {val_prec1:.3f} \t' 202 | 'Validation Prec@5 {val_prec5:.3f} \n' 203 | .format(epoch + 1, train_loss=train_loss, val_loss=val_loss, 204 | train_prec1=train_prec1, val_prec1=val_prec1, 205 | train_prec5=train_prec5, val_prec5=val_prec5)) 206 | 207 | results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss, 208 | train_error1=100 - train_prec1, val_error1=100 - val_prec1, 209 | train_error5=100 - train_prec5, val_error5=100 - val_prec5) 210 | #results.plot(x='epoch', y=['train_loss', 'val_loss'], 211 | # title='Loss', ylabel='loss') 212 | #results.plot(x='epoch', y=['train_error1', 'val_error1'], 213 | # title='Error@1', ylabel='error %') 214 | #results.plot(x='epoch', y=['train_error5', 'val_error5'], 215 | # title='Error@5', ylabel='error %') 216 | results.save() 217 | 218 | 219 | def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None): 220 | if args.gpus and len(args.gpus) > 1: 221 | model = torch.nn.DataParallel(model, args.gpus) 222 | batch_time = AverageMeter() 223 | data_time = AverageMeter() 224 | losses = AverageMeter() 225 | top1 = AverageMeter() 226 | top5 = AverageMeter() 227 | 228 | end = time.time() 229 | for i, (inputs, target) in enumerate(data_loader): 230 | # measure data loading time 231 | data_time.update(time.time() - end) 232 | if args.gpus is not None: 233 | target = target.cuda() 234 | 235 | if not training: 236 | with 
torch.no_grad(): 237 | input_var = Variable(inputs.type(args.type), volatile=not training) 238 | target_var = Variable(target) 239 | # compute output 240 | output = model(input_var) 241 | else: 242 | input_var = Variable(inputs.type(args.type), volatile=not training) 243 | target_var = Variable(target) 244 | # compute output 245 | output = model(input_var) 246 | 247 | 248 | loss = criterion(output, target_var) 249 | if type(output) is list: 250 | output = output[0] 251 | 252 | # measure accuracy and record loss 253 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 254 | losses.update(loss.item(), inputs.size(0)) 255 | top1.update(prec1.item(), inputs.size(0)) 256 | top5.update(prec5.item(), inputs.size(0)) 257 | 258 | if training: 259 | # compute gradient and do SGD step 260 | optimizer.zero_grad() 261 | loss.backward() 262 | optimizer.step() 263 | 264 | 265 | # measure elapsed time 266 | batch_time.update(time.time() - end) 267 | end = time.time() 268 | 269 | if i % args.print_freq == 0: 270 | logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t' 271 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 272 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 273 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 274 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 275 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 276 | epoch, i, len(data_loader), 277 | phase='TRAINING' if training else 'EVALUATING', 278 | batch_time=batch_time, 279 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 280 | 281 | return losses.avg, top1.avg, top5.avg 282 | 283 | 284 | def train(data_loader, model, criterion, epoch, optimizer): 285 | # switch to train mode 286 | model.train() 287 | return forward(data_loader, model, criterion, epoch, 288 | training=True, optimizer=optimizer) 289 | 290 | 291 | def validate(data_loader, model, criterion, epoch): 292 | # switch to evaluate mode 293 | model.eval() 294 | return forward(data_loader, model, criterion, epoch, 295 | training=False, 
optimizer=None) 296 | 297 | 298 | if __name__ == '__main__': 299 | main() 300 | -------------------------------------------------------------------------------- /main_binary_hinge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | from torch.autograd import Variable 13 | from data import get_dataset 14 | from preprocess import get_transform 15 | from utils import * 16 | from datetime import datetime 17 | from ast import literal_eval 18 | from torchvision.utils import save_image 19 | from models.binarized_modules import HingeLoss 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 26 | 27 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='/media/hdd/ihubara/BinaryNet.pytorch/results', 28 | help='results dir') 29 | parser.add_argument('--save', metavar='SAVE', default='', 30 | help='saved folder') 31 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 32 | help='dataset name or folder') 33 | parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: alexnet)') 38 | parser.add_argument('--input_size', type=int, default=None, 39 | help='image input size') 40 | parser.add_argument('--model_config', default='', 41 | help='additional architecture configuration') 42 | parser.add_argument('--type', default='torch.cuda.FloatTensor', 43 | help='type of tensor - e.g torch.cuda.HalfTensor') 44 | parser.add_argument('--gpus', default='0', 45 | 
help='gpus used for training - e.g 0,1,3') 46 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 47 | help='number of data loading workers (default: 8)') 48 | parser.add_argument('--epochs', default=900, type=int, metavar='N', 49 | help='number of total epochs to run') 50 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 51 | help='manual epoch number (useful on restarts)') 52 | parser.add_argument('-b', '--batch-size', default=256, type=int, 53 | metavar='N', help='mini-batch size (default: 256)') 54 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 55 | help='optimizer function used') 56 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 57 | metavar='LR', help='initial learning rate') 58 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 59 | help='momentum') 60 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 61 | metavar='W', help='weight decay (default: 1e-4)') 62 | parser.add_argument('--print-freq', '-p', default=10, type=int, 63 | metavar='N', help='print frequency (default: 10)') 64 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 65 | help='path to latest checkpoint (default: none)') 66 | parser.add_argument('-e', '--evaluate', type=str, metavar='FILE', 67 | help='evaluate model FILE on validation set') 68 | 69 | torch.cuda.random.manual_seed_all(10) 70 | 71 | output_dim = 0 72 | 73 | 74 | def main(): 75 | global args, best_prec1, output_dim 76 | best_prec1 = 0 77 | args = parser.parse_args() 78 | output_dim = {'cifar10': 10, 'cifar100':100, 'imagenet': 1000}[args.dataset] 79 | #import pdb; pdb.set_trace() 80 | #torch.save(args.batch_size/(len(args.gpus)/2+1),'multi_gpu_batch_size') 81 | if args.evaluate: 82 | args.results_dir = '/tmp' 83 | if args.save is '': 84 | args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 85 | save_path = os.path.join(args.results_dir, args.save) 86 | if 
not os.path.exists(save_path): 87 | os.makedirs(save_path) 88 | 89 | setup_logging(os.path.join(save_path, 'log.txt')) 90 | results_file = os.path.join(save_path, 'results.%s') 91 | results = ResultsLog(results_file % 'csv', results_file % 'html') 92 | 93 | logging.info("saving to %s", save_path) 94 | logging.debug("run arguments: %s", args) 95 | 96 | if 'cuda' in args.type: 97 | args.gpus = [int(i) for i in args.gpus.split(',')] 98 | torch.cuda.set_device(args.gpus[0]) 99 | cudnn.benchmark = True 100 | else: 101 | args.gpus = None 102 | 103 | # create model 104 | logging.info("creating model %s", args.model) 105 | model = models.__dict__[args.model] 106 | 107 | 108 | model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': output_dim} 109 | 110 | if args.model_config is not '': 111 | model_config = dict(model_config, **literal_eval(args.model_config)) 112 | model = model(**model_config) 113 | logging.info("created model with configuration: %s", model_config) 114 | 115 | # optionally resume from a checkpoint 116 | if args.evaluate: 117 | if not os.path.isfile(args.evaluate): 118 | parser.error('invalid checkpoint: {}'.format(args.evaluate)) 119 | checkpoint = torch.load(args.evaluate) 120 | model.load_state_dict(checkpoint['state_dict']) 121 | logging.info("loaded checkpoint '%s' (epoch %s)", 122 | args.evaluate, checkpoint['epoch']) 123 | elif args.resume: 124 | checkpoint_file = args.resume 125 | if os.path.isdir(checkpoint_file): 126 | results.load(os.path.join(checkpoint_file, 'results.csv')) 127 | checkpoint_file = os.path.join( 128 | checkpoint_file, 'model_best.pth.tar') 129 | if os.path.isfile(checkpoint_file): 130 | logging.info("loading checkpoint '%s'", args.resume) 131 | checkpoint = torch.load(checkpoint_file) 132 | args.start_epoch = checkpoint['epoch'] - 1 133 | best_prec1 = checkpoint['best_prec1'] 134 | model.load_state_dict(checkpoint['state_dict']) 135 | logging.info("loaded checkpoint '%s' (epoch %s)", 136 | 
checkpoint_file, checkpoint['epoch']) 137 | else: 138 | logging.error("no checkpoint found at '%s'", args.resume) 139 | 140 | num_parameters = sum([l.nelement() for l in model.parameters()]) 141 | logging.info("number of parameters: %d", num_parameters) 142 | 143 | # Data loading code 144 | default_transform = { 145 | 'train': get_transform(args.dataset, 146 | input_size=args.input_size, augment=True), 147 | 'eval': get_transform(args.dataset, 148 | input_size=args.input_size, augment=False) 149 | } 150 | transform = getattr(model, 'input_transform', default_transform) 151 | regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer, 152 | 'lr': args.lr, 153 | 'momentum': args.momentum, 154 | 'weight_decay': args.weight_decay}}) 155 | # define loss function (criterion) and optimizer 156 | #criterion = getattr(model, 'criterion', nn.NLLLoss)() 157 | criterion = getattr(model, 'criterion', HingeLoss)() 158 | #criterion.type(args.type) 159 | model.type(args.type) 160 | 161 | val_data = get_dataset(args.dataset, 'val', transform['eval']) 162 | val_loader = torch.utils.data.DataLoader( 163 | val_data, 164 | batch_size=args.batch_size, shuffle=False, 165 | num_workers=args.workers, pin_memory=True) 166 | 167 | if args.evaluate: 168 | validate(val_loader, model, criterion, 0) 169 | return 170 | 171 | train_data = get_dataset(args.dataset, 'train', transform['train']) 172 | train_loader = torch.utils.data.DataLoader( 173 | train_data, 174 | batch_size=args.batch_size, shuffle=True, 175 | num_workers=args.workers, pin_memory=True) 176 | 177 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 178 | logging.info('training regime: %s', regime) 179 | #import pdb; pdb.set_trace() 180 | #search_binarized_modules(model) 181 | 182 | for epoch in range(args.start_epoch, args.epochs): 183 | optimizer = adjust_optimizer(optimizer, epoch, regime) 184 | 185 | # train for one epoch 186 | train_loss, train_prec1, train_prec5 = train( 187 | train_loader, model, criterion, 
epoch, optimizer) 188 | 189 | # evaluate on validation set 190 | val_loss, val_prec1, val_prec5 = validate( 191 | val_loader, model, criterion, epoch) 192 | 193 | # remember best prec@1 and save checkpoint 194 | is_best = val_prec1 > best_prec1 195 | best_prec1 = max(val_prec1, best_prec1) 196 | save_checkpoint({ 197 | 'epoch': epoch + 1, 198 | 'model': args.model, 199 | 'config': args.model_config, 200 | 'state_dict': model.state_dict(), 201 | 'best_prec1': best_prec1, 202 | 'regime': regime 203 | }, is_best, path=save_path) 204 | logging.info('\n Epoch: {0}\t' 205 | 'Training Loss {train_loss:.4f} \t' 206 | 'Training Prec@1 {train_prec1:.3f} \t' 207 | 'Training Prec@5 {train_prec5:.3f} \t' 208 | 'Validation Loss {val_loss:.4f} \t' 209 | 'Validation Prec@1 {val_prec1:.3f} \t' 210 | 'Validation Prec@5 {val_prec5:.3f} \n' 211 | .format(epoch + 1, train_loss=train_loss, val_loss=val_loss, 212 | train_prec1=train_prec1, val_prec1=val_prec1, 213 | train_prec5=train_prec5, val_prec5=val_prec5)) 214 | 215 | results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss, 216 | train_error1=100 - train_prec1, val_error1=100 - val_prec1, 217 | train_error5=100 - train_prec5, val_error5=100 - val_prec5) 218 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 219 | title='Loss', ylabel='loss') 220 | results.plot(x='epoch', y=['train_error1', 'val_error1'], 221 | title='Error@1', ylabel='error %') 222 | results.plot(x='epoch', y=['train_error5', 'val_error5'], 223 | title='Error@5', ylabel='error %') 224 | results.save() 225 | 226 | def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None): 227 | if args.gpus and len(args.gpus) > 1: 228 | model = torch.nn.DataParallel(model, args.gpus) 229 | batch_time = AverageMeter() 230 | data_time = AverageMeter() 231 | losses = AverageMeter() 232 | top1 = AverageMeter() 233 | top5 = AverageMeter() 234 | 235 | end = time.time() 236 | for i, (inputs, target) in enumerate(data_loader): 237 | # measure data 
loading time 238 | data_time.update(time.time() - end) 239 | if args.gpus is not None: 240 | target = target.cuda() 241 | #import pdb; pdb.set_trace() 242 | if criterion.__class__.__name__=='HingeLoss': 243 | target=target.unsqueeze(1) 244 | target_onehot = torch.cuda.FloatTensor(target.size(0), output_dim) 245 | target_onehot.fill_(-1) 246 | target_onehot.scatter_(1, target, 1) 247 | target=target.squeeze() 248 | if not training: 249 | with torch.no_grad(): 250 | input_var = Variable(inputs.type(args.type)) 251 | target_var = Variable(target_onehot) 252 | 253 | # compute output 254 | output = model(input_var) 255 | else: 256 | input_var = Variable(inputs.type(args.type)) 257 | target_var = Variable(target_onehot) 258 | 259 | # compute output 260 | output = model(input_var) 261 | 262 | #import pdb; pdb.set_trace() 263 | loss = criterion(output, target_onehot) 264 | #import pdb; pdb.set_trace() 265 | if type(output) is list: 266 | output = output[0] 267 | 268 | # measure accuracy and record loss 269 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 270 | losses.update(loss.item(), inputs.size(0)) 271 | top1.update(prec1.item(), inputs.size(0)) 272 | top5.update(prec5.item(), inputs.size(0)) 273 | #import pdb; pdb.set_trace() 274 | #if not training and top1.avg<15: 275 | # import pdb; pdb.set_trace() 276 | if training: 277 | # compute gradient and do SGD step 278 | optimizer.zero_grad() 279 | #add backwoed hook 280 | loss.backward() 281 | for p in list(model.parameters()): 282 | #import pdb; pdb.set_trace() 283 | if hasattr(p,'org'): 284 | #print('before:', p[0][0]) 285 | #gm=max(p.grad.data.max(),-p.grad.data.min()) 286 | #p.grad=p.grad.div(gm+1) 287 | p.data.copy_(p.org) 288 | #print('after:', p[0][0]) 289 | optimizer.step() 290 | for p in list(model.parameters()): 291 | #import pdb; pdb.set_trace() 292 | if hasattr(p,'org'): 293 | #print('before:', p[0][0]) 294 | p.org.copy_(p.data.clamp_(-1,1)) 295 | #if epoch>30: 296 | # import pdb; pdb.set_trace() 297 
| 298 | # measure elapsed time 299 | batch_time.update(time.time() - end) 300 | end = time.time() 301 | 302 | if i % args.print_freq == 0: 303 | logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t' 304 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 305 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 306 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 307 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 308 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 309 | epoch, i, len(data_loader), 310 | phase='TRAINING' if training else 'EVALUATING', 311 | batch_time=batch_time, 312 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 313 | 314 | return losses.avg, top1.avg, top5.avg 315 | 316 | 317 | def train(data_loader, model, criterion, epoch, optimizer): 318 | # switch to train mode 319 | model.train() 320 | return forward(data_loader, model, criterion, epoch, 321 | training=True, optimizer=optimizer) 322 | 323 | 324 | def validate(data_loader, model, criterion, epoch): 325 | # switch to evaluate mode 326 | model.eval() 327 | return forward(data_loader, model, criterion, epoch, 328 | training=False, optimizer=None) 329 | 330 | 331 | if __name__ == '__main__': 332 | main() 333 | -------------------------------------------------------------------------------- /main_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torchvision import datasets, transforms 9 | from torch.autograd import Variable 10 | from models.binarized_modules import BinarizeLinear,BinarizeConv2d 11 | from models.binarized_modules import Binarize,HingeLoss 12 | # Training settings 13 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 14 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 15 | help='input batch size for training 
(default: 256)') 16 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 17 | help='input batch size for testing (default: 1000)') 18 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 19 | help='number of epochs to train (default: 10)') 20 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 21 | help='learning rate (default: 0.001)') 22 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 23 | help='SGD momentum (default: 0.5)') 24 | parser.add_argument('--no-cuda', action='store_true', default=False, 25 | help='disables CUDA training') 26 | parser.add_argument('--seed', type=int, default=1, metavar='S', 27 | help='random seed (default: 1)') 28 | parser.add_argument('--gpus', default=3, 29 | help='gpus used for training - e.g 0,1,3') 30 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 31 | help='how many batches to wait before logging training status') 32 | args = parser.parse_args() 33 | args.cuda = not args.no_cuda and torch.cuda.is_available() 34 | 35 | torch.manual_seed(args.seed) 36 | if args.cuda: 37 | torch.cuda.manual_seed(args.seed) 38 | 39 | 40 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 41 | train_loader = torch.utils.data.DataLoader( 42 | datasets.MNIST('../data', train=True, download=True, 43 | transform=transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.1307,), (0.3081,)) 46 | ])), 47 | batch_size=args.batch_size, shuffle=True, **kwargs) 48 | test_loader = torch.utils.data.DataLoader( 49 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.1307,), (0.3081,)) 52 | ])), 53 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 54 | 55 | 56 | class Net(nn.Module): 57 | def __init__(self): 58 | super(Net, self).__init__() 59 | self.infl_ratio=3 60 | self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio) 61 | 
self.htanh1 = nn.Hardtanh() 62 | self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio) 63 | self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio) 64 | self.htanh2 = nn.Hardtanh() 65 | self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio) 66 | self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio) 67 | self.htanh3 = nn.Hardtanh() 68 | self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio) 69 | self.fc4 = nn.Linear(2048*self.infl_ratio, 10) 70 | self.logsoftmax=nn.LogSoftmax() 71 | self.drop=nn.Dropout(0.5) 72 | 73 | def forward(self, x): 74 | x = x.view(-1, 28*28) 75 | x = self.fc1(x) 76 | x = self.bn1(x) 77 | x = self.htanh1(x) 78 | x = self.fc2(x) 79 | x = self.bn2(x) 80 | x = self.htanh2(x) 81 | x = self.fc3(x) 82 | x = self.drop(x) 83 | x = self.bn3(x) 84 | x = self.htanh3(x) 85 | x = self.fc4(x) 86 | return self.logsoftmax(x) 87 | 88 | model = Net() 89 | if args.cuda: 90 | torch.cuda.set_device(3) 91 | model.cuda() 92 | 93 | 94 | criterion = nn.CrossEntropyLoss() 95 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 96 | 97 | 98 | def train(epoch): 99 | model.train() 100 | for batch_idx, (data, target) in enumerate(train_loader): 101 | if args.cuda: 102 | data, target = data.cuda(), target.cuda() 103 | data, target = Variable(data), Variable(target) 104 | optimizer.zero_grad() 105 | output = model(data) 106 | loss = criterion(output, target) 107 | 108 | if epoch%40==0: 109 | optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1 110 | 111 | optimizer.zero_grad() 112 | loss.backward() 113 | for p in list(model.parameters()): 114 | if hasattr(p,'org'): 115 | p.data.copy_(p.org) 116 | optimizer.step() 117 | for p in list(model.parameters()): 118 | if hasattr(p,'org'): 119 | p.org.copy_(p.data.clamp_(-1,1)) 120 | 121 | if batch_idx % args.log_interval == 0: 122 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 123 | epoch, batch_idx * len(data), len(train_loader.dataset), 124 | 100. 
* batch_idx / len(train_loader), loss.item())) 125 | 126 | def test(): 127 | model.eval() 128 | test_loss = 0 129 | correct = 0 130 | with torch.no_grad(): 131 | for data, target in test_loader: 132 | if args.cuda: 133 | data, target = data.cuda(), target.cuda() 134 | data, target = Variable(data), Variable(target) 135 | output = model(data) 136 | test_loss += criterion(output, target).item() # sum up batch loss 137 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 138 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 139 | 140 | test_loss /= len(test_loader.dataset) 141 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 142 | test_loss, correct, len(test_loader.dataset), 143 | 100. * correct / len(test_loader.dataset))) 144 | 145 | 146 | for epoch in range(1, args.epochs + 1): 147 | train(epoch) 148 | test() 149 | if epoch%40==0: 150 | optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1 151 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .alexnet import * 3 | from .alexnet_binary import * 4 | from .resnet import * 5 | from .resnet_binary import * 6 | from .vgg_cifar10_binary import * 7 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | 4 | __all__ = ['alexnet'] 5 | 6 | class AlexNetOWT_BN(nn.Module): 7 | 8 | def __init__(self, num_classes=1000): 9 | super(AlexNetOWT_BN, self).__init__() 10 | self.features = nn.Sequential( 11 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, 12 | bias=False), 13 | nn.MaxPool2d(kernel_size=3, stride=2), 14 | nn.BatchNorm2d(64), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(64, 192, 
kernel_size=5, padding=2, bias=False), 17 | nn.MaxPool2d(kernel_size=3, stride=2), 18 | nn.ReLU(inplace=True), 19 | nn.BatchNorm2d(192), 20 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False), 21 | nn.ReLU(inplace=True), 22 | nn.BatchNorm2d(384), 23 | nn.Conv2d(384, 256, kernel_size=3, padding=1, bias=False), 24 | nn.ReLU(inplace=True), 25 | nn.BatchNorm2d(256), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 27 | nn.MaxPool2d(kernel_size=3, stride=2), 28 | nn.ReLU(inplace=True), 29 | nn.BatchNorm2d(256) 30 | ) 31 | self.classifier = nn.Sequential( 32 | nn.Linear(256 * 6 * 6, 4096, bias=False), 33 | nn.BatchNorm1d(4096), 34 | nn.ReLU(inplace=True), 35 | nn.Dropout(0.5), 36 | nn.Linear(4096, 4096, bias=False), 37 | nn.BatchNorm1d(4096), 38 | nn.ReLU(inplace=True), 39 | nn.Dropout(0.5), 40 | nn.Linear(4096, num_classes) 41 | ) 42 | 43 | self.regime = { 44 | 0: {'optimizer': 'SGD', 'lr': 1e-2, 45 | 'weight_decay': 5e-4, 'momentum': 0.9}, 46 | 10: {'lr': 5e-3}, 47 | 15: {'lr': 1e-3, 'weight_decay': 0}, 48 | 20: {'lr': 5e-4}, 49 | 25: {'lr': 1e-4} 50 | } 51 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 52 | std=[0.229, 0.224, 0.225]) 53 | self.input_transform = { 54 | 'train': transforms.Compose([ 55 | transforms.Scale(256), 56 | transforms.RandomCrop(224), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | normalize 60 | ]), 61 | 'eval': transforms.Compose([ 62 | transforms.Scale(256), 63 | transforms.CenterCrop(224), 64 | transforms.ToTensor(), 65 | normalize 66 | ]) 67 | } 68 | 69 | def forward(self, x): 70 | x = self.features(x) 71 | x = x.view(-1, 256 * 6 * 6) 72 | x = self.classifier(x) 73 | return x 74 | 75 | 76 | def alexnet(**kwargs): 77 | num_classes = kwargs.get( 'num_classes', 1000) 78 | return AlexNetOWT_BN(num_classes) 79 | -------------------------------------------------------------------------------- /models/alexnet_binary.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | from .binarized_modules import BinarizeLinear,BinarizeConv2d 4 | 5 | __all__ = ['alexnet_binary'] 6 | 7 | class AlexNetOWT_BN(nn.Module): 8 | 9 | def __init__(self, num_classes=1000): 10 | super(AlexNetOWT_BN, self).__init__() 11 | self.ratioInfl=3 12 | self.features = nn.Sequential( 13 | BinarizeConv2d(3, int(64*self.ratioInfl), kernel_size=11, stride=4, padding=2), 14 | nn.MaxPool2d(kernel_size=3, stride=2), 15 | nn.BatchNorm2d(int(64*self.ratioInfl)), 16 | nn.Hardtanh(inplace=True), 17 | BinarizeConv2d(int(64*self.ratioInfl), int(192*self.ratioInfl), kernel_size=5, padding=2), 18 | nn.MaxPool2d(kernel_size=3, stride=2), 19 | nn.BatchNorm2d(int(192*self.ratioInfl)), 20 | nn.Hardtanh(inplace=True), 21 | 22 | BinarizeConv2d(int(192*self.ratioInfl), int(384*self.ratioInfl), kernel_size=3, padding=1), 23 | nn.BatchNorm2d(int(384*self.ratioInfl)), 24 | nn.Hardtanh(inplace=True), 25 | 26 | BinarizeConv2d(int(384*self.ratioInfl), int(256*self.ratioInfl), kernel_size=3, padding=1), 27 | nn.BatchNorm2d(int(256*self.ratioInfl)), 28 | nn.Hardtanh(inplace=True), 29 | 30 | BinarizeConv2d(int(256*self.ratioInfl), 256, kernel_size=3, padding=1), 31 | nn.MaxPool2d(kernel_size=3, stride=2), 32 | nn.BatchNorm2d(256), 33 | nn.Hardtanh(inplace=True) 34 | 35 | ) 36 | self.classifier = nn.Sequential( 37 | BinarizeLinear(256 * 6 * 6, 4096), 38 | nn.BatchNorm1d(4096), 39 | nn.Hardtanh(inplace=True), 40 | #nn.Dropout(0.5), 41 | BinarizeLinear(4096, 4096), 42 | nn.BatchNorm1d(4096), 43 | nn.Hardtanh(inplace=True), 44 | #nn.Dropout(0.5), 45 | BinarizeLinear(4096, num_classes), 46 | nn.BatchNorm1d(1000), 47 | nn.LogSoftmax() 48 | ) 49 | 50 | #self.regime = { 51 | # 0: {'optimizer': 'SGD', 'lr': 1e-2, 52 | # 'weight_decay': 5e-4, 'momentum': 0.9}, 53 | # 10: {'lr': 5e-3}, 54 | # 15: {'lr': 1e-3, 'weight_decay': 0}, 55 | # 20: {'lr': 
5e-4}, 56 | # 25: {'lr': 1e-4} 57 | #} 58 | self.regime = { 59 | 0: {'optimizer': 'Adam', 'lr': 5e-3}, 60 | 20: {'lr': 1e-3}, 61 | 30: {'lr': 5e-4}, 62 | 35: {'lr': 1e-4}, 63 | 40: {'lr': 1e-5} 64 | } 65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 66 | std=[0.229, 0.224, 0.225]) 67 | self.input_transform = { 68 | 'train': transforms.Compose([ 69 | transforms.Scale(256), 70 | transforms.RandomCrop(224), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | normalize 74 | ]), 75 | 'eval': transforms.Compose([ 76 | transforms.Scale(256), 77 | transforms.CenterCrop(224), 78 | transforms.ToTensor(), 79 | normalize 80 | ]) 81 | } 82 | 83 | def forward(self, x): 84 | x = self.features(x) 85 | x = x.view(-1, 256 * 6 * 6) 86 | x = self.classifier(x) 87 | return x 88 | 89 | 90 | def alexnet_binary(**kwargs): 91 | num_classes = kwargs.get( 'num_classes', 1000) 92 | return AlexNetOWT_BN(num_classes) 93 | -------------------------------------------------------------------------------- /models/binarized_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | import torch.nn as nn 4 | import math 5 | from torch.autograd import Variable 6 | from torch.autograd.function import Function, InplaceFunction 7 | 8 | import numpy as np 9 | 10 | 11 | 12 | 13 | class Binarize(InplaceFunction): 14 | 15 | def forward(ctx,input,quant_mode='det',allow_scale=False,inplace=False): 16 | ctx.inplace = inplace 17 | if ctx.inplace: 18 | ctx.mark_dirty(input) 19 | output = input 20 | else: 21 | output = input.clone() 22 | 23 | scale= output.abs().max() if allow_scale else 1 24 | 25 | if quant_mode=='det': 26 | return output.div(scale).sign().mul(scale) 27 | else: 28 | return output.div(scale).add_(1).div_(2).add_(torch.rand(output.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1).mul(scale) 29 | 30 | def backward(ctx,grad_output): 31 | #STE 32 | grad_input=grad_output 33 | return 
grad_input,None,None,None 34 | 35 | 36 | class Quantize(InplaceFunction): 37 | def forward(ctx,input,quant_mode='det',numBits=4,inplace=False): 38 | ctx.inplace = inplace 39 | if ctx.inplace: 40 | ctx.mark_dirty(input) 41 | output = input 42 | else: 43 | output = input.clone() 44 | scale=(2**numBits-1)/(output.max()-output.min()) 45 | output = output.mul(scale).clamp(-2**(numBits-1)+1,2**(numBits-1)) 46 | if quant_mode=='det': 47 | output=output.round().div(scale) 48 | else: 49 | output=output.round().add(torch.rand(output.size()).add(-0.5)).div(scale) 50 | return output 51 | 52 | def backward(grad_output): 53 | #STE 54 | grad_input=grad_output 55 | return grad_input,None,None 56 | 57 | def binarized(input,quant_mode='det'): 58 | return Binarize.apply(input,quant_mode) 59 | 60 | def quantize(input,quant_mode,numBits): 61 | return Quantize.apply(input,quant_mode,numBits) 62 | 63 | class HingeLoss(nn.Module): 64 | def __init__(self): 65 | super(HingeLoss,self).__init__() 66 | self.margin=1.0 67 | 68 | def hinge_loss(self,input,target): 69 | #import pdb; pdb.set_trace() 70 | output=self.margin-input.mul(target) 71 | output[output.le(0)]=0 72 | return output.mean() 73 | 74 | def forward(self, input, target): 75 | return self.hinge_loss(input,target) 76 | 77 | class SqrtHingeLossFunction(Function): 78 | def __init__(self): 79 | super(SqrtHingeLossFunction,self).__init__() 80 | self.margin=1.0 81 | 82 | def forward(self, input, target): 83 | output=self.margin-input.mul(target) 84 | output[output.le(0)]=0 85 | self.save_for_backward(input, target) 86 | loss=output.mul(output).sum(0).sum(1).div(target.numel()) 87 | return loss 88 | 89 | def backward(self,grad_output): 90 | input, target = self.saved_tensors 91 | output=self.margin-input.mul(target) 92 | output[output.le(0)]=0 93 | grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output) 94 | grad_output.mul_(output.ne(0).float()) 95 | grad_output.div_(input.numel()) 96 | return grad_output,grad_output 97 | 98 | 
99 | 100 | class BinarizeLinear(nn.Linear): 101 | 102 | def __init__(self, *kargs, **kwargs): 103 | super(BinarizeLinear, self).__init__(*kargs, **kwargs) 104 | 105 | def forward(self, input): 106 | 107 | if input.size(1) != 784: 108 | input_b=binarized(input) 109 | weight_b=binarized(self.weight) 110 | out = nn.functional.linear(input_b,weight_b) 111 | if not self.bias is None: 112 | self.bias.org=self.bias.data.clone() 113 | out += self.bias.view(1, -1).expand_as(out) 114 | 115 | return out 116 | 117 | class BinarizeConv2d(nn.Conv2d): 118 | 119 | def __init__(self, *kargs, **kwargs): 120 | super(BinarizeConv2d, self).__init__(*kargs, **kwargs) 121 | 122 | 123 | def forward(self, input): 124 | if input.size(1) != 3: 125 | input_b = binarized(input) 126 | else: 127 | input_b=input 128 | weight_b=binarized(self.weight) 129 | 130 | out = nn.functional.conv2d(input_b, weight_b, None, self.stride, 131 | self.padding, self.dilation, self.groups) 132 | 133 | if not self.bias is None: 134 | self.bias.org=self.bias.data.clone() 135 | out += self.bias.view(1, -1, 1, 1).expand_as(out) 136 | 137 | return out 138 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | 5 | __all__ = ['resnet'] 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def init_model(model): 14 | for m in model.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 17 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 18 | elif isinstance(m, nn.BatchNorm2d): 19 | m.weight.data.fill_(1) 20 | m.bias.data.zero_() 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 
| 96 | def __init__(self): 97 | super(ResNet, self).__init__() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | nn.BatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.conv1(x) 118 | x = self.bn1(x) 119 | x = self.relu(x) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | 131 | return x 132 | 133 | 134 | class ResNet_imagenet(ResNet): 135 | 136 | def __init__(self, num_classes=1000, 137 | block=Bottleneck, layers=[3, 4, 23, 3]): 138 | super(ResNet_imagenet, self).__init__() 139 | self.inplanes = 64 140 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 141 | bias=False) 142 | self.bn1 = nn.BatchNorm2d(64) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | self.layer1 = self._make_layer(block, 64, layers[0]) 146 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 148 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 149 | self.avgpool = nn.AvgPool2d(7) 150 | self.fc = nn.Linear(512 * block.expansion, num_classes) 151 | 152 | init_model(self) 153 | self.regime = { 154 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 155 | 'weight_decay': 1e-4, 'momentum': 0.9}, 156 | 30: {'lr': 
1e-2}, 157 | 60: {'lr': 1e-3, 'weight_decay': 0}, 158 | 90: {'lr': 1e-4} 159 | } 160 | 161 | 162 | class ResNet_cifar10(ResNet): 163 | 164 | def __init__(self, num_classes=10, 165 | block=BasicBlock, depth=18): 166 | super(ResNet_cifar10, self).__init__() 167 | self.inplanes = 16 168 | n = int((depth - 2) / 6) 169 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 170 | bias=False) 171 | self.bn1 = nn.BatchNorm2d(16) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.maxpool = lambda x: x 174 | self.layer1 = self._make_layer(block, 16, n) 175 | self.layer2 = self._make_layer(block, 32, n, stride=2) 176 | self.layer3 = self._make_layer(block, 64, n, stride=2) 177 | self.layer4 = lambda x: x 178 | self.avgpool = nn.AvgPool2d(8) 179 | self.fc = nn.Linear(64, num_classes) 180 | 181 | init_model(self) 182 | self.regime = { 183 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 184 | 'weight_decay': 1e-4, 'momentum': 0.9}, 185 | 81: {'lr': 1e-2}, 186 | 122: {'lr': 1e-3, 'weight_decay': 0}, 187 | 164: {'lr': 1e-4} 188 | } 189 | 190 | 191 | def resnet(**kwargs): 192 | num_classes, depth, dataset = map( 193 | kwargs.get, ['num_classes', 'depth', 'dataset']) 194 | if dataset == 'imagenet': 195 | num_classes = num_classes or 1000 196 | depth = depth or 50 197 | if depth == 18: 198 | return ResNet_imagenet(num_classes=num_classes, 199 | block=BasicBlock, layers=[2, 2, 2, 2]) 200 | if depth == 34: 201 | return ResNet_imagenet(num_classes=num_classes, 202 | block=BasicBlock, layers=[3, 4, 6, 3]) 203 | if depth == 50: 204 | return ResNet_imagenet(num_classes=num_classes, 205 | block=Bottleneck, layers=[3, 4, 6, 3]) 206 | if depth == 101: 207 | return ResNet_imagenet(num_classes=num_classes, 208 | block=Bottleneck, layers=[3, 4, 23, 3]) 209 | if depth == 152: 210 | return ResNet_imagenet(num_classes=num_classes, 211 | block=Bottleneck, layers=[3, 8, 36, 3]) 212 | 213 | elif dataset == 'cifar10': 214 | num_classes = num_classes or 10 215 | depth = depth or 18 #56 216 | return 
ResNet_cifar10(num_classes=num_classes, 217 | block=BasicBlock, depth=depth) 218 | -------------------------------------------------------------------------------- /models/resnet_binary.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | from .binarized_modules import BinarizeLinear,BinarizeConv2d 5 | 6 | __all__ = ['resnet_binary'] 7 | 8 | def Binaryconv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return BinarizeConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | def init_model(model): 19 | for m in model.modules(): 20 | if isinstance(m, BinarizeConv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None,do_bntan=True): 32 | super(BasicBlock, self).__init__() 33 | 34 | self.conv1 = Binaryconv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.tanh1 = nn.Hardtanh(inplace=True) 37 | self.conv2 = Binaryconv3x3(planes, planes) 38 | self.tanh2 = nn.Hardtanh(inplace=True) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | 41 | self.downsample = downsample 42 | self.do_bntan=do_bntan; 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | 47 | residual = x.clone() 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.tanh1(out) 52 | 53 | out = self.conv2(out) 54 | 55 | 56 | if self.downsample is not None: 57 | if residual.data.max()>1: 58 | import pdb; pdb.set_trace() 59 | residual = self.downsample(residual) 60 | 61 | out += residual 62 | if self.do_bntan: 63 | out = self.bn2(out) 64 | out = self.tanh2(out) 65 | 66 | return out 67 | 68 | 69 | class Bottleneck(nn.Module): 70 | expansion = 4 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None): 73 | super(Bottleneck, self).__init__() 74 | self.conv1 = BinarizeConv2d(inplanes, planes, kernel_size=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(planes) 76 | self.conv2 = BinarizeConv2d(planes, planes, kernel_size=3, stride=stride, 77 | padding=1, bias=False) 78 | self.bn2 = nn.BatchNorm2d(planes) 79 | self.conv3 = BinarizeConv2d(planes, planes * 4, kernel_size=1, bias=False) 80 | self.bn3 = nn.BatchNorm2d(planes * 4) 81 | self.tanh = nn.Hardtanh(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | residual = x 87 | import pdb; pdb.set_trace() 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.tanh(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 
94 | out = self.tanh(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | residual = self.downsample(x) 101 | 102 | out += residual 103 | if self.do_bntan: 104 | out = self.bn2(out) 105 | out = self.tanh2(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet(nn.Module): 111 | 112 | def __init__(self): 113 | super(ResNet, self).__init__() 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1,do_bntan=True): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | BinarizeConv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks-1): 128 | layers.append(block(self.inplanes, planes)) 129 | layers.append(block(self.inplanes, planes,do_bntan=do_bntan)) 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.maxpool(x) 135 | x = self.bn1(x) 136 | x = self.tanh1(x) 137 | x = self.layer1(x) 138 | x = self.layer2(x) 139 | x = self.layer3(x) 140 | x = self.layer4(x) 141 | 142 | x = self.avgpool(x) 143 | x = x.view(x.size(0), -1) 144 | x = self.bn2(x) 145 | x = self.tanh2(x) 146 | x = self.fc(x) 147 | x = self.bn3(x) 148 | x = self.logsoftmax(x) 149 | 150 | return x 151 | 152 | 153 | class ResNet_imagenet(ResNet): 154 | 155 | def __init__(self, num_classes=1000, 156 | block=Bottleneck, layers=[3, 4, 23, 3]): 157 | super(ResNet_imagenet, self).__init__() 158 | self.inplanes = 64 159 | self.conv1 = BinarizeConv2d(3, 64, kernel_size=7, stride=2, padding=3, 160 | bias=False) 161 | self.bn1 = nn.BatchNorm2d(64) 162 | self.tanh = nn.Hardtanh(inplace=True) 163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 
164 | self.layer1 = self._make_layer(block, 64, layers[0]) 165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 167 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 168 | self.avgpool = nn.AvgPool2d(7) 169 | self.fc = BinarizeLinear(512 * block.expansion, num_classes) 170 | 171 | init_model(self) 172 | self.regime = { 173 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 174 | 'weight_decay': 1e-4, 'momentum': 0.9}, 175 | 30: {'lr': 1e-2}, 176 | 60: {'lr': 1e-3, 'weight_decay': 0}, 177 | 90: {'lr': 1e-4} 178 | } 179 | 180 | 181 | class ResNet_cifar10(ResNet): 182 | 183 | def __init__(self, num_classes=10, 184 | block=BasicBlock, depth=18): 185 | super(ResNet_cifar10, self).__init__() 186 | self.inflate = 5 187 | self.inplanes = 16*self.inflate 188 | n = int((depth - 2) / 6) 189 | self.conv1 = BinarizeConv2d(3, 16*self.inflate, kernel_size=3, stride=1, padding=1, 190 | bias=False) 191 | self.maxpool = lambda x: x 192 | self.bn1 = nn.BatchNorm2d(16*self.inflate) 193 | self.tanh1 = nn.Hardtanh(inplace=True) 194 | self.tanh2 = nn.Hardtanh(inplace=True) 195 | self.layer1 = self._make_layer(block, 16*self.inflate, n) 196 | self.layer2 = self._make_layer(block, 32*self.inflate, n, stride=2) 197 | self.layer3 = self._make_layer(block, 64*self.inflate, n, stride=2,do_bntan=False) 198 | self.layer4 = lambda x: x 199 | self.avgpool = nn.AvgPool2d(8) 200 | self.bn2 = nn.BatchNorm1d(64*self.inflate) 201 | self.bn3 = nn.BatchNorm1d(10) 202 | self.logsoftmax = nn.LogSoftmax() 203 | self.fc = BinarizeLinear(64*self.inflate, num_classes) 204 | 205 | init_model(self) 206 | #self.regime = { 207 | # 0: {'optimizer': 'SGD', 'lr': 1e-1, 208 | # 'weight_decay': 1e-4, 'momentum': 0.9}, 209 | # 81: {'lr': 1e-4}, 210 | # 122: {'lr': 1e-5, 'weight_decay': 0}, 211 | # 164: {'lr': 1e-6} 212 | #} 213 | self.regime = { 214 | 0: {'optimizer': 'Adam', 'lr': 5e-3}, 215 | 101: {'lr': 1e-3}, 216 | 142: {'lr': 
5e-4}, 217 | 184: {'lr': 1e-4}, 218 | 220: {'lr': 1e-5} 219 | } 220 | 221 | 222 | def resnet_binary(**kwargs): 223 | num_classes, depth, dataset = map( 224 | kwargs.get, ['num_classes', 'depth', 'dataset']) 225 | if dataset == 'imagenet': 226 | num_classes = num_classes or 1000 227 | depth = depth or 50 228 | if depth == 18: 229 | return ResNet_imagenet(num_classes=num_classes, 230 | block=BasicBlock, layers=[2, 2, 2, 2]) 231 | if depth == 34: 232 | return ResNet_imagenet(num_classes=num_classes, 233 | block=BasicBlock, layers=[3, 4, 6, 3]) 234 | if depth == 50: 235 | return ResNet_imagenet(num_classes=num_classes, 236 | block=Bottleneck, layers=[3, 4, 6, 3]) 237 | if depth == 101: 238 | return ResNet_imagenet(num_classes=num_classes, 239 | block=Bottleneck, layers=[3, 4, 23, 3]) 240 | if depth == 152: 241 | return ResNet_imagenet(num_classes=num_classes, 242 | block=Bottleneck, layers=[3, 8, 36, 3]) 243 | 244 | elif dataset == 'cifar10': 245 | num_classes = num_classes or 10 246 | depth = depth or 18 247 | return ResNet_cifar10(num_classes=num_classes, 248 | block=BasicBlock, depth=depth) 249 | -------------------------------------------------------------------------------- /models/vgg_cifar10.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | 4 | 5 | class AlexNetOWT_BN(nn.Module): 6 | 7 | def __init__(self, num_classes=1000): 8 | super(AlexNetOWT_BN, self).__init__() 9 | self.features = nn.Sequential( 10 | nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, 11 | bias=False), 12 | nn.BatchNorm2d(128), 13 | nn.ReLU(inplace=True), 14 | 15 | nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), 16 | nn.MaxPool2d(kernel_size=2, stride=2), 17 | nn.ReLU(inplace=True), 18 | nn.BatchNorm2d(128), 19 | 20 | nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False), 21 | nn.ReLU(inplace=True), 22 | nn.BatchNorm2d(256), 23 | 24 | nn.Conv2d(256, 256, 
kernel_size=3, padding=1, bias=False), 25 | nn.MaxPool2d(kernel_size=2, stride=2), 26 | nn.ReLU(inplace=True), 27 | nn.BatchNorm2d(256), 28 | 29 | nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False), 30 | nn.ReLU(inplace=True), 31 | nn.BatchNorm2d(512), 32 | 33 | nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False), 34 | nn.MaxPool2d(kernel_size=2, stride=2), 35 | nn.ReLU(inplace=True), 36 | nn.BatchNorm2d(512), 37 | ) 38 | self.classifier = nn.Sequential( 39 | nn.Linear(512 * 4 * 4, 1024, bias=False), 40 | nn.BatchNorm1d(1024), 41 | nn.ReLU(inplace=True), 42 | nn.Dropout(0.5), 43 | nn.Linear(1024, 1024, bias=False), 44 | nn.BatchNorm1d(1024), 45 | nn.ReLU(inplace=True), 46 | nn.Dropout(0.5), 47 | nn.Linear(1024, num_classes) 48 | nn.LogSoftMax() 49 | ) 50 | 51 | self.regime = { 52 | 0: {'optimizer': 'SGD', 'lr': 1e-2, 53 | 'weight_decay': 5e-4, 'momentum': 0.9}, 54 | 10: {'lr': 5e-3}, 55 | 15: {'lr': 1e-3, 'weight_decay': 0}, 56 | 20: {'lr': 5e-4}, 57 | 25: {'lr': 1e-4} 58 | } 59 | 60 | def forward(self, x): 61 | x = self.features(x) 62 | x = x.view(-1, 512 * 4 * 4) 63 | x = self.classifier(x) 64 | return x 65 | 66 | 67 | def model(**kwargs): 68 | num_classes = kwargs.get( 'num_classes', 1000) 69 | return AlexNetOWT_BN(num_classes) 70 | -------------------------------------------------------------------------------- /models/vgg_cifar10_binary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | from torch.autograd import Function 5 | from .binarized_modules import BinarizeLinear,BinarizeConv2d 6 | 7 | 8 | 9 | class VGG_Cifar10(nn.Module): 10 | 11 | def __init__(self, num_classes=1000): 12 | super(VGG_Cifar10, self).__init__() 13 | self.infl_ratio=3; 14 | self.features = nn.Sequential( 15 | BinarizeConv2d(3, 128*self.infl_ratio, kernel_size=3, stride=1, padding=1, 16 | bias=True), 17 | nn.BatchNorm2d(128*self.infl_ratio), 18 | 
nn.Hardtanh(inplace=True), 19 | 20 | BinarizeConv2d(128*self.infl_ratio, 128*self.infl_ratio, kernel_size=3, padding=1, bias=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | nn.BatchNorm2d(128*self.infl_ratio), 23 | nn.Hardtanh(inplace=True), 24 | 25 | 26 | BinarizeConv2d(128*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=True), 27 | nn.BatchNorm2d(256*self.infl_ratio), 28 | nn.Hardtanh(inplace=True), 29 | 30 | 31 | BinarizeConv2d(256*self.infl_ratio, 256*self.infl_ratio, kernel_size=3, padding=1, bias=True), 32 | nn.MaxPool2d(kernel_size=2, stride=2), 33 | nn.BatchNorm2d(256*self.infl_ratio), 34 | nn.Hardtanh(inplace=True), 35 | 36 | 37 | BinarizeConv2d(256*self.infl_ratio, 512*self.infl_ratio, kernel_size=3, padding=1, bias=True), 38 | nn.BatchNorm2d(512*self.infl_ratio), 39 | nn.Hardtanh(inplace=True), 40 | 41 | 42 | BinarizeConv2d(512*self.infl_ratio, 512, kernel_size=3, padding=1, bias=True), 43 | nn.MaxPool2d(kernel_size=2, stride=2), 44 | nn.BatchNorm2d(512), 45 | nn.Hardtanh(inplace=True) 46 | 47 | ) 48 | self.classifier = nn.Sequential( 49 | BinarizeLinear(512 * 4 * 4, 1024, bias=True), 50 | nn.BatchNorm1d(1024), 51 | nn.Hardtanh(inplace=True), 52 | #nn.Dropout(0.5), 53 | BinarizeLinear(1024, 1024, bias=True), 54 | nn.BatchNorm1d(1024), 55 | nn.Hardtanh(inplace=True), 56 | #nn.Dropout(0.5), 57 | BinarizeLinear(1024, num_classes, bias=True), 58 | nn.BatchNorm1d(num_classes, affine=False), 59 | nn.LogSoftmax() 60 | ) 61 | 62 | self.regime = { 63 | 0: {'optimizer': 'Adam', 'betas': (0.9, 0.999),'lr': 5e-3}, 64 | 40: {'lr': 1e-3}, 65 | 80: {'lr': 5e-4}, 66 | 100: {'lr': 1e-4}, 67 | 120: {'lr': 5e-5}, 68 | 140: {'lr': 1e-5} 69 | } 70 | 71 | def forward(self, x): 72 | x = self.features(x) 73 | x = x.view(-1, 512 * 4 * 4) 74 | x = self.classifier(x) 75 | return x 76 | 77 | 78 | def vgg_cifar10_binary(**kwargs): 79 | num_classes = kwargs.get( 'num_classes', 10) 80 | return VGG_Cifar10(num_classes) 81 | 
# --------------------------------------------------------------------------------
# /preprocess.py:
# --------------------------------------------------------------------------------
import torch
import torchvision.transforms as transforms
import random

__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]}

__imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}

# torchvision renamed Scale -> Resize and RandomSizedCrop -> RandomResizedCrop
# and later dropped the old names. Prefer the new names and fall back to the
# old ones only on torchvision releases that predate the rename.
_Resize = getattr(transforms, 'Resize', None) or transforms.Scale
_RandomResizedCrop = (getattr(transforms, 'RandomResizedCrop', None)
                      or transforms.RandomSizedCrop)


def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Build a deterministic transform: optional resize, center crop, normalize.

    Args:
        input_size: size passed to ``CenterCrop``.
        scale_size: if not None and different from ``input_size``, the image
            is first resized to this size; None means no resize.
        normalize: dict with ``mean`` and ``std`` keys for ``Normalize``.

    Returns:
        A ``transforms.Compose`` producing a normalized tensor.
    """
    t_list = [
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    # A resize to None is invalid, so only prepend one for a real size.
    if scale_size is not None and scale_size != input_size:
        t_list = [_Resize(scale_size)] + t_list

    return transforms.Compose(t_list)


def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Build a transform: optional resize, random crop, normalize.

    Same arguments as ``scale_crop``, but uses ``RandomCrop`` instead of
    ``CenterCrop``.

    Returns:
        A ``transforms.Compose`` producing a normalized tensor.
    """
    t_list = [
        transforms.RandomCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    if scale_size is not None and scale_size != input_size:
        t_list = [_Resize(scale_size)] + t_list

    # The original built this pipeline but never returned it.
    return transforms.Compose(t_list)


def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Build a training transform: pad + random crop, horizontal flip, normalize.

    Each side is zero-padded by half of ``scale_size - input_size`` (floored,
    never negative). ``RandomCrop`` then takes an ``input_size`` window from
    the padded image. ``scale_size=None`` means no padding.
    """
    if scale_size is None:
        scale_size = input_size
    padding = max(0, (scale_size - input_size) // 2)
    return transforms.Compose([
        transforms.RandomCrop(input_size, padding=padding),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ])


def inception_preproccess(input_size, normalize=__imagenet_stats):
    """Build a training transform: random resized crop, horizontal flip, normalize."""
    return transforms.Compose([
        _RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**normalize)
    ])


def inception_color_preproccess(input_size, normalize=__imagenet_stats):
    """Like ``inception_preproccess``, plus color jitter and PCA lighting noise.

    ``ColorJitter`` and ``Lighting`` are the tensor-based classes defined in
    this module, so they are placed after ``ToTensor``.
    """
    return transforms.Compose([
        _RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
        ),
        Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
        transforms.Normalize(**normalize)
    ])


def get_transform(name='imagenet', input_size=None,
                  scale_size=None, normalize=None, augment=True):
    """Return the preprocessing pipeline for a dataset name.

    Args:
        name: 'imagenet', any name containing 'cifar', or 'mnist'.
        input_size: final crop size (defaults: imagenet 224, cifar 32,
            mnist 28).
        scale_size: resize/pad reference size (defaults depend on the
            dataset and ``augment``).
        normalize: mean/std dict; defaults to the ImageNet statistics. For
            'mnist' it is always replaced by mean 0.5 / std 0.5.
        augment: pick the random training pipeline if True, otherwise the
            deterministic evaluation pipeline.

    Returns:
        A ``transforms.Compose``, or None when ``name`` matches no branch.
    """
    normalize = normalize or __imagenet_stats
    if name == 'imagenet':
        scale_size = scale_size or 256
        input_size = input_size or 224
        if augment:
            return inception_preproccess(input_size, normalize=normalize)
        else:
            return scale_crop(input_size=input_size,
                              scale_size=scale_size, normalize=normalize)
    elif 'cifar' in name:
        input_size = input_size or 32
        if augment:
            scale_size = scale_size or 40
            return pad_random_crop(input_size, scale_size=scale_size,
                                   normalize=normalize)
        else:
            scale_size = scale_size or 32
            return scale_crop(input_size=input_size,
                              scale_size=scale_size, normalize=normalize)
    elif name == 'mnist':
        # NOTE: overrides any caller-supplied normalize.
        normalize = {'mean': [0.5], 'std': [0.5]}
        input_size = input_size or 28
        if augment:
            scale_size = scale_size or 32
            return pad_random_crop(input_size, scale_size=scale_size,
                                   normalize=normalize)
        else:
            scale_size = scale_size or 32
            return scale_crop(input_size=input_size,
                              scale_size=scale_size, normalize=normalize)
    # Names not matched above have no recipe here.
    return None


class Lighting(object):
    """Lighting noise (AlexNet-style PCA-based noise).

    Each call samples ``alpha ~ N(0, alphastd)`` (3 values) and adds the
    offset ``rgb[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]`` to channel
    ``i`` of every pixel. Expects a channels-first, 3-channel tensor image.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        # Scale column j of eigvec by alpha[j] * eigval[j], then sum across
        # columns. Both PCA tensors are cast to the image's type so the math
        # stays on the image's dtype/device.
        rgb = (self.eigvec.type_as(img)
               * alpha.view(1, 3)
               * self.eigval.type_as(img).view(1, 3)).sum(1)

        return img + rgb.view(3, 1, 1)


class Grayscale(object):
    """Return a copy of a 3-channel tensor image with every channel set to
    0.299 * c0 + 0.587 * c1 + 0.114 * c2."""

    def __call__(self, img):
        gs = img.clone()
        # Tensor-only arithmetic avoids the deprecated add_(scalar, tensor)
        # overload.
        gs[0].mul_(0.299).add_(gs[1] * 0.587).add_(gs[2] * 0.114)
        gs[1].copy_(gs[0])
        gs[2].copy_(gs[0])
        return gs


class Saturation(object):
    """Blend the image toward its grayscale version by a factor drawn
    uniformly from [0, var]."""

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = Grayscale()(img)
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class Brightness(object):
    """Blend the image toward black (all zeros) by a factor drawn uniformly
    from [0, var]."""

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = img.new().resize_as_(img).zero_()
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class Contrast(object):
    """Blend the image toward a constant image filled with its mean
    grayscale value, by a factor drawn uniformly from [0, var]."""

    def __init__(self, var):
        self.var = var

    def __call__(self, img):
        gs = Grayscale()(img)
        gs.fill_(gs.mean())
        alpha = random.uniform(0, self.var)
        return img.lerp(gs, alpha)


class RandomOrder(object):
    """ Composes several transforms together in random order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        if self.transforms is None:
            return img
        order = torch.randperm(len(self.transforms))
        for i in order:
            img = self.transforms[i](img)
        return img


class ColorJitter(RandomOrder):
    """Apply brightness, contrast and saturation jitter (each skipped when
    its strength is 0) in a random order on every call."""

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
        jitters = []
        if brightness != 0:
            jitters.append(Brightness(brightness))
        if contrast != 0:
            jitters.append(Contrast(contrast))
        if saturation != 0:
            jitters.append(Saturation(saturation))
        super(ColorJitter, self).__init__(jitters)
# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import os
import torch
import logging.config
import shutil
import pandas as pd
from bokeh.io import output_file, save, show
from bokeh.plotting import figure
from bokeh.layouts import column
#from bokeh.charts import Line, defaults
#
#defaults.width = 800
#defaults.height = 400
#defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save'


def setup_logging(log_file='log.txt'):
    """Setup logging configuration

    DEBUG and above go to ``log_file`` (overwritten on each call). INFO and
    above are also echoed to the console without timestamps.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)


class ResultsLog(object):
    """Accumulate result rows in a pandas DataFrame, write them to CSV, and
    optionally render queued bokeh figures to an HTML file."""

    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None  # DataFrame once the first row is added/loaded

    def add(self, **kwargs):
        """Append one row; keyword names become column names."""
        df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; concat is the
            # long-standing equivalent.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        """Write queued figures to ``plot_path`` (clearing the queue) and the
        accumulated rows to ``path`` as CSV."""
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        # Nothing to write before the first add()/load().
        if self.results is not None:
            self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Replace the in-memory rows with the CSV at ``path`` (defaults to
        ``self.path``); does nothing if the file does not exist."""
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def show(self):
        """Display the queued figures (if any) without clearing them."""
        if len(self.figures) > 0:
            plot = column(*self.figures)
            show(plot)

    #def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)

    def image(self, *kargs, **kwargs):
        """Queue a new bokeh figure drawn with ``figure.image(*kargs, **kwargs)``."""
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)


def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
    """Save ``state`` to ``path/filename``.

    If ``is_best``, the file is also copied to ``model_best.pth.tar``. If
    ``save_all``, it is also copied to ``checkpoint_epoch_<state['epoch']>.pth.tar``.
    """
    filename = os.path.join(path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
    if save_all:
        shutil.copyfile(filename, os.path.join(
            path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Optimizer classes selectable by name via a setting's 'optimizer' key.
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}


def adjust_optimizer(optimizer, epoch, config):
    """Reconfigures the optimizer according to epoch and config dict

    ``config`` is either a callable ``epoch -> setting`` or a dict mapping
    epochs to settings. In the dict case, the settings of every epoch
    ``<= epoch`` are applied in order, so values are sticky.

    A setting may name a new optimizer under 'optimizer' (a key of
    ``__optimizers``), built from the current param groups. Any other key
    that already exists in a param group overwrites that group's value.

    Returns the (possibly new) optimizer.
    """
    def modify_optimizer(optimizer, setting):
        if 'optimizer' in setting:
            optimizer = __optimizers[setting['optimizer']](
                optimizer.param_groups)
            logging.debug('OPTIMIZER - setting method = %s' %
                          setting['optimizer'])
        for param_group in optimizer.param_groups:
            for key in param_group.keys():
                if key in setting:
                    logging.debug('OPTIMIZER - setting %s = %s' %
                                  (key, setting[key]))
                    param_group[key] = setting[key]
        return optimizer

    if callable(config):
        optimizer = modify_optimizer(optimizer, config(epoch))
    else:
        for e in range(epoch + 1):  # run over all epochs - sticky setting
            if e in config:
                optimizer = modify_optimizer(optimizer, config[e])

    return optimizer


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: scores with classes along dim 1, i.e. (batch, num_classes).
        target: class indices, one per sample (batch,).
        topk: the k values to evaluate.

    Returns:
        A list with one scalar tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.float().topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row r holds each sample's r-th guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

    # kernel_img = model.features[0][0].kernel.data.clone()
    # kernel_img.add_(-kernel_img.min())
    # kernel_img.mul_(255 / kernel_img.max())
    # save_image(kernel_img, 'kernel%s.jpg' % epoch)
# --------------------------------------------------------------------------------