├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── Understanding_QNNs (25).pdf
├── data.py
├── main.py
├── models
│   ├── __init__.py
│   ├── inception_resnet_v2.py
│   ├── inception_v2.py
│   ├── mnist.py
│   ├── mobilenet.py
│   ├── mobilenet_quantized.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bwn.py
│   │   ├── quantize.py
│   │   └── rnlu.py
│   ├── resnet.py
│   ├── resnet_bwn.py
│   ├── resnet_quantized.py
│   ├── resnet_quantized_float_bn.py
│   └── resnext.py
├── preprocess.py
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "utils"] 2 | path = utils 3 | url = https://github.com/eladhoffer/utils.pytorch 4 | branch = master 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quantized Convolutional networks using PyTorch 2 | **See https://github.com/eladhoffer/convNet.pytorch for an updated version of this code** 3 | 4 | Code to replicate results in [Scalable Methods for 8-bit Training of Neural Networks](https://arxiv.org/abs/1805.11046) 5 | 6 | e.g.: running an 8-bit quantized resnet18 from the paper on ImageNet 7 | 8 | ``` 9 | python main.py --model resnet_quantized --model_config "{'depth': 18}" --save quantized_resnet18 --dataset imagenet --b 128 10 | ``` 11 | 12 | ## Dependencies 13 | 14 | - [pytorch](https://pytorch.org) 15 | - [torchvision](https://github.com/pytorch/vision) to load the datasets and perform image transforms 16 | - [pandas](https://pandas.pydata.org) for logging to csv 17 | - [bokeh](https://bokeh.org) for training visualization 18 | 19 | 20 | ## Data 21 | - Configure your dataset path in **data.py**. 22 | - To get the ILSVRC data, you should register on the ImageNet site for access. 23 | 24 | 25 | ## Model configuration 26 | 27 | A network model is defined by writing a .py file in the models folder and selecting it with the model flag. The model function must be registered in models/\_\_init\_\_.py. 28 | The model function must return a trainable network. It can also specify additional training options such as an optimization regime (either a dictionary or a function) and input transform modifications. 29 | 30 | e.g. for a model definition: 31 | 32 | ```python 33 | class Model(nn.Module): 34 | 35 | def __init__(self, num_classes=1000): 36 | super(Model, self).__init__() 37 | self.model = nn.Sequential(...) 
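        # optional attributes, read by main.py via getattr():
        # 'regime' schedules optimizer settings from a given epoch onward,
        # 'input_transform' overrides the default data transforms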
38 | 39 | self.regime = [ 40 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-2, 41 | 'weight_decay': 5e-4, 'momentum': 0.9}, 42 | {'epoch': 15, 'lr': 1e-3, 'weight_decay': 0} 43 | ] 44 | 45 | self.input_transform = { 46 | 'train': transforms.Compose([...]), 47 | 'eval': transforms.Compose([...]) 48 | } 49 | def forward(self, inputs): 50 | return self.model(inputs) 51 | 52 | def model(**kwargs): 53 | return Model() 54 | ``` 55 | -------------------------------------------------------------------------------- /Understanding_QNNs (25).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladhoffer/quantized.pytorch/e09c447a50a6a4c7dabf6176f20c931422aefd67/Understanding_QNNs (25).pdf -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | 4 | __DATASETS_DEFAULT_PATH = '/media/ssd/Datasets/' 5 | 6 | 7 | def get_dataset(name, split='train', transform=None, 8 | target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH): 9 | train = (split == 'train') 10 | root = os.path.join(datasets_path, name) 11 | if name == 'cifar10': 12 | return datasets.CIFAR10(root=root, 13 | train=train, 14 | transform=transform, 15 | target_transform=target_transform, 16 | download=download) 17 | elif name == 'cifar100': 18 | return datasets.CIFAR100(root=root, 19 | train=train, 20 | transform=transform, 21 | target_transform=target_transform, 22 | download=download) 23 | elif name == 'mnist': 24 | return datasets.MNIST(root=root, 25 | train=train, 26 | transform=transform, 27 | target_transform=target_transform, 28 | download=download) 29 | elif name == 'stl10': 30 | return datasets.STL10(root=root, 31 | split=split, 32 | transform=transform, 33 | target_transform=target_transform, 34 | download=download) 35 | elif name == 'imagenet': 36 | if train: 37 | root = os.path.join(root, 'train') 38 | else: 39 | root = os.path.join(root, 'val') 40 | return datasets.ImageFolder(root=root, 41 | transform=transform, 42 | target_transform=target_transform) 43 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import models 12 | from data import get_dataset 13 | from preprocess import get_transform 14 | from utils.log import setup_logging, ResultsLog, save_checkpoint 15 | from utils.meters import AverageMeter, accuracy 16 | from utils.optim import OptimRegime 17 | from utils.misc import torch_dtypes 18 | from datetime import datetime 19 | from ast import literal_eval 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 26 | 27 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', 28 | help='results dir') 29 | parser.add_argument('--save', metavar='SAVE', default='', 30 | help='saved folder') 31 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 32 | help='dataset name or folder') 33 | 
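# --model must name a callable exported by models/__init__.py (collected into model_names above)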
parser.add_argument('--model', '-a', metavar='MODEL', default='alexnet', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: alexnet)') 38 | parser.add_argument('--input_size', type=int, default=None, 39 | help='image input size') 40 | parser.add_argument('--model_config', default='', 41 | help='additional architecture configuration') 42 | parser.add_argument('--dtype', default='float', 43 | help='type of tensor: ' + 44 | ' | '.join(torch_dtypes.keys()) + 45 | ' (default: float)') 46 | parser.add_argument('--device', default='cuda', 47 | help='device assignment ("cpu" or "cuda")') 48 | parser.add_argument('--device_ids', default=[0], type=int, nargs='+', 49 | help='device ids assignment (e.g. 0 1 2 3)') 50 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 51 | help='number of data loading workers (default: 8)') 52 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 53 | help='number of total epochs to run') 54 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 55 | help='manual epoch number (useful on restarts)') 56 | parser.add_argument('-b', '--batch-size', default=256, type=int, 57 | metavar='N', help='mini-batch size (default: 256)') 58 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 59 | help='optimizer function used') 60 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 61 | metavar='LR', help='initial learning rate') 62 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 63 | help='momentum') 64 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 65 | metavar='W', help='weight decay (default: 1e-4)') 66 | parser.add_argument('--print-freq', '-p', default=10, type=int, 67 | metavar='N', help='print frequency (default: 10)') 68 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 69 | help='path to latest checkpoint (default: none)') 70 | parser.add_argument('-e', '--evaluate', type=str, metavar='FILE', 71 | help='evaluate model FILE on validation set') 72 | parser.add_argument('--seed', default=123, type=int, 73 | help='random seed (default: 123)') 74 | 75 | 76 | def main(): 77 | global args, best_prec1, dtype 78 | best_prec1 = 0 79 | args = parser.parse_args() 80 | dtype = torch_dtypes.get(args.dtype) 81 | torch.manual_seed(args.seed) 82 | time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 83 | if args.evaluate: 84 | args.results_dir = '/tmp' 85 | if args.save == '': 86 | args.save = time_stamp 87 | save_path = os.path.join(args.results_dir, args.save) 88 | if not os.path.exists(save_path): 89 | os.makedirs(save_path) 90 | 91 | setup_logging(os.path.join(save_path, 'log.txt'), 92 | resume=args.resume != '') 93 | results_path = os.path.join(save_path, 'results') 94 | results = ResultsLog( 95 | results_path, title='Training Results - %s' % args.save) 96 | 97 | logging.info("saving to %s", save_path) 98 | logging.debug("run arguments: %s", args) 99 | 100 | if 'cuda' in args.device and torch.cuda.is_available(): 101 | torch.cuda.manual_seed_all(args.seed) 102 | torch.cuda.set_device(args.device_ids[0]) 103 | cudnn.benchmark = True 104 | else: 105 | args.device_ids = None 106 | 107 | # create model 108 | logging.info("creating model %s", args.model) 109 | model = models.__dict__[args.model] 110 | model_config = {'input_size': args.input_size, 'dataset': args.dataset} 111 | 112 | if args.model_config != '': 113 | model_config = 
dict(model_config, **literal_eval(args.model_config)) 114 | 115 | model = model(**model_config) 116 | logging.info("created model with configuration: %s", model_config) 117 | 118 | # optionally resume from a checkpoint 119 | if args.evaluate: 120 | if not os.path.isfile(args.evaluate): 121 | parser.error('invalid checkpoint: {}'.format(args.evaluate)) 122 | checkpoint = torch.load(args.evaluate) 123 | model.load_state_dict(checkpoint['state_dict']) 124 | logging.info("loaded checkpoint '%s' (epoch %s)", 125 | args.evaluate, checkpoint['epoch']) 126 | elif args.resume: 127 | checkpoint_file = args.resume 128 | if os.path.isdir(checkpoint_file): 129 | results.load(os.path.join(checkpoint_file, 'results.csv')) 130 | checkpoint_file = os.path.join( 131 | checkpoint_file, 'model_best.pth.tar') 132 | if os.path.isfile(checkpoint_file): 133 | logging.info("loading checkpoint '%s'", args.resume) 134 | checkpoint = torch.load(checkpoint_file) 135 | args.start_epoch = checkpoint['epoch'] - 1 136 | best_prec1 = checkpoint['best_prec1'] 137 | model.load_state_dict(checkpoint['state_dict']) 138 | logging.info("loaded checkpoint '%s' (epoch %s)", 139 | checkpoint_file, checkpoint['epoch']) 140 | else: 141 | logging.error("no checkpoint found at '%s'", args.resume) 142 | 143 | num_parameters = sum([l.nelement() for l in model.parameters()]) 144 | logging.info("number of parameters: %d", num_parameters) 145 | 146 | # Data loading code 147 | default_transform = { 148 | 'train': get_transform(args.dataset, 149 | input_size=args.input_size, augment=True), 150 | 'eval': get_transform(args.dataset, 151 | input_size=args.input_size, augment=False) 152 | } 153 | transform = getattr(model, 'input_transform', default_transform) 154 | regime = getattr(model, 'regime', [{'epoch': 0, 155 | 'optimizer': args.optimizer, 156 | 'lr': args.lr, 157 | 'momentum': args.momentum, 158 | 'weight_decay': args.weight_decay}]) 159 | 160 | # define loss function (criterion) and optimizer 161 | criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)() 162 | criterion.to(args.device, dtype) 163 | model.to(args.device, dtype) 164 | 165 | val_data = get_dataset(args.dataset, 'val', transform['eval']) 166 | val_loader = torch.utils.data.DataLoader( 167 | val_data, 168 | batch_size=args.batch_size, shuffle=False, 169 | num_workers=args.workers, pin_memory=True) 170 | 171 | if args.evaluate: 172 | validate(val_loader, model, criterion, 0) 173 | return 174 | 175 | train_data = get_dataset(args.dataset, 'train', transform['train']) 176 | train_loader = torch.utils.data.DataLoader( 177 | train_data, 178 | batch_size=args.batch_size, shuffle=True, 179 | num_workers=args.workers, pin_memory=True) 180 | 181 | optimizer = OptimRegime(model.parameters(), regime) 182 | logging.info('training regime: %s', regime) 183 | 184 | for epoch in range(args.start_epoch, args.epochs): 185 | # train for one epoch 186 | train_loss, train_prec1, train_prec5 = train( 187 | train_loader, model, criterion, epoch, optimizer) 188 | 189 | # evaluate on validation set 190 | val_loss, val_prec1, val_prec5 = validate( 191 | val_loader, model, criterion, epoch) 192 | 193 | # remember best prec@1 and save checkpoint 194 | is_best = val_prec1 > best_prec1 195 | best_prec1 = max(val_prec1, best_prec1) 196 | save_checkpoint({ 197 | 'epoch': epoch + 1, 198 | 'model': args.model, 199 | 'config': args.model_config, 200 | 'state_dict': model.state_dict(), 201 | 'best_prec1': best_prec1, 202 | 'regime': regime 203 | }, is_best, path=save_path) 204 | logging.info('\n Epoch: 
{0}\t' 205 | 'Training Loss {train_loss:.4f} \t' 206 | 'Training Prec@1 {train_prec1:.3f} \t' 207 | 'Training Prec@5 {train_prec5:.3f} \t' 208 | 'Validation Loss {val_loss:.4f} \t' 209 | 'Validation Prec@1 {val_prec1:.3f} \t' 210 | 'Validation Prec@5 {val_prec5:.3f} \n' 211 | .format(epoch + 1, train_loss=train_loss, val_loss=val_loss, 212 | train_prec1=train_prec1, val_prec1=val_prec1, 213 | train_prec5=train_prec5, val_prec5=val_prec5)) 214 | 215 | results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss, 216 | train_error1=100 - train_prec1, val_error1=100 - val_prec1, 217 | train_error5=100 - train_prec5, val_error5=100 - val_prec5) 218 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 219 | legend=['training', 'validation'], 220 | title='Loss', ylabel='loss') 221 | results.plot(x='epoch', y=['train_error1', 'val_error1'], 222 | legend=['training', 'validation'], 223 | title='Error@1', ylabel='error %') 224 | results.plot(x='epoch', y=['train_error5', 'val_error5'], 225 | legend=['training', 'validation'], 226 | title='Error@5', ylabel='error %') 227 | results.save() 228 | 229 | 230 | def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None): 231 | regularizer = getattr(model, 'regularization', None) 232 | if args.device_ids and len(args.device_ids) > 1: 233 | model = torch.nn.DataParallel(model, args.device_ids) 234 | 235 | batch_time = AverageMeter() 236 | data_time = AverageMeter() 237 | losses = AverageMeter() 238 | top1 = AverageMeter() 239 | top5 = AverageMeter() 240 | 241 | end = time.time() 242 | for i, (inputs, target) in enumerate(data_loader): 243 | # measure data loading time 244 | data_time.update(time.time() - end) 245 | target = target.to(args.device) 246 | inputs = inputs.to(args.device, dtype=dtype) 247 | 248 | # compute output 249 | output = model(inputs) 250 | loss = criterion(output, target) 251 | if regularizer is not None: 252 | loss += regularizer(model) 253 | 254 | if type(output) is list: 255 | output = output[0] 256 | 257 | # measure accuracy and record loss 258 | prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5)) 259 | losses.update(float(loss), inputs.size(0)) 260 | top1.update(float(prec1), inputs.size(0)) 261 | top5.update(float(prec5), inputs.size(0)) 262 | 263 | if training: 264 | optimizer.update(epoch, epoch * len(data_loader) + i) 265 | # compute gradient and do SGD step 266 | optimizer.zero_grad() 267 | loss.backward() 268 | optimizer.step() 269 | 270 | # measure elapsed time 271 | batch_time.update(time.time() - end) 272 | end = time.time() 273 | 274 | if i % args.print_freq == 0: 275 | logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t' 276 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 277 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 278 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 279 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 280 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 281 | epoch, i, len(data_loader), 282 | phase='TRAINING' if training else 'EVALUATING', 283 | batch_time=batch_time, 284 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 285 | 286 | return losses.avg, top1.avg, top5.avg 287 | 288 | 289 | def train(data_loader, model, criterion, epoch, optimizer): 290 | # switch to train mode 291 | model.train() 292 | return forward(data_loader, model, criterion, epoch, 293 | training=True, optimizer=optimizer) 294 | 295 | 296 | def validate(data_loader, model, criterion, epoch): 297 | # switch to evaluate mode 298 | model.eval() 299 | with torch.no_grad(): 300 
| return forward(data_loader, model, criterion, epoch, 301 | training=False, optimizer=None) 302 | 303 | 304 | if __name__ == '__main__': 305 | main() 306 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .resnet_bwn import * 3 | from .resnet_quantized import * 4 | from .resnet_quantized_float_bn import * 5 | from .resnext import * 6 | from .inception_resnet_v2 import * 7 | from .inception_v2 import * 8 | from .mobilenet import * 9 | from .mobilenet_quantized import * -------------------------------------------------------------------------------- /models/inception_resnet_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | __all__ = ['inception_resnet_v2'] 6 | 7 | """ inception_resnet_v2. 8 | References: 9 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 10 | on Learning 11 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi. 12 | 13 | Links: 14 | http://arxiv.org/abs/1602.07261 15 | 16 | """ 17 | 18 | 19 | def conv_bn(in_planes, out_planes, kernel_size, stride=1, padding=0, bias=False): 20 | "convolution with batchnorm, relu" 21 | return nn.Sequential( 22 | nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, 23 | padding=padding, bias=False), 24 | nn.BatchNorm2d(out_planes, eps=1e-3), 25 | nn.ReLU() 26 | ) 27 | 28 | 29 | class Concat(nn.Sequential): 30 | 31 | def __init__(self, *kargs, **kwargs): 32 | super(Concat, self).__init__(*kargs, **kwargs) 33 | 34 | def forward(self, inputs): 35 | return torch.cat([m(inputs) for m in self._modules.values()], 1) 36 | 37 | 38 | class block(nn.Module): 39 | 40 | def __init__(self, in_planes, scale=1.0, activation=nn.ReLU(True)): 41 | super(block, self).__init__() 42 | self.scale = scale 43 | self.activation = activation or (lambda x: x) 44 | 45 | def forward(self, inputs): 46 | branch0 = self.Branch_0(inputs) 47 | branch1 = self.Branch_1(inputs) 48 | if hasattr(self, 'Branch_2'): 49 | branch2 = self.Branch_2(inputs) 50 | tower_mixed = torch.cat([branch0, branch1, branch2], 1) 51 | else: 52 | tower_mixed = torch.cat([branch0, branch1], 1) 53 | tower_out = self.Conv2d_1x1(tower_mixed) 54 | output = self.activation(self.scale * tower_out + inputs) 55 | return output 56 | 57 | 58 | class block35(block): 59 | 60 | def __init__(self, in_planes, scale=1.0, activation=nn.ReLU(True)): 61 | super(block35, self).__init__(in_planes, scale, activation) 62 | self.Branch_0 = nn.Sequential(OrderedDict([ 63 | ('Conv2d_1x1', conv_bn(in_planes, 32, 1)) 64 | ])) 65 | self.Branch_1 = nn.Sequential(OrderedDict([ 66 | ('Conv2d_0a_1x1', conv_bn(in_planes, 32, 1)), 67 | ('Conv2d_0b_3x3', conv_bn(32, 32, 3, padding=1)) 68 | ])) 69 | self.Branch_2 = nn.Sequential(OrderedDict([ 70 | ('Conv2d_0a_1x1', conv_bn(in_planes, 32, 1)), 71 | ('Conv2d_0b_3x3', conv_bn(32, 48, 3, padding=1)), 72 | ('Conv2d_0c_3x3', conv_bn(48, 64, 3, padding=1)) 73 | ])) 74 | self.Conv2d_1x1 = conv_bn(128, in_planes, 1) 75 | 76 | 77 | class block17(block): 78 | 79 | def __init__(self, in_planes, scale=1.0, activation=nn.ReLU(True)): 80 | super(block17, self).__init__(in_planes, scale, activation) 81 | 82 | self.Branch_0 = nn.Sequential(OrderedDict([ 83 | ('Conv2d_1x1', conv_bn(in_planes, 192, 1)) 84 | ])) 85 | self.Branch_1 = nn.Sequential(OrderedDict([ 86 | 
('Conv2d_0a_1x1', conv_bn(in_planes, 128, 1)), 87 | ('Conv2d_0b_1x7', conv_bn(128, 160, (1, 7), padding=(0, 3))), 88 | ('Conv2d_0c_7x1', conv_bn(160, 192, (7, 1), padding=(3, 0))) 89 | ])) 90 | self.Conv2d_1x1 = conv_bn(384, in_planes, 1) 91 | 92 | 93 | class block8(block): 94 | 95 | def __init__(self, in_planes, scale=1.0, activation=nn.ReLU(True)): 96 | super(block8, self).__init__(in_planes, scale, activation) 97 | 98 | self.Branch_0 = nn.Sequential(OrderedDict([ 99 | ('Conv2d_1x1', conv_bn(in_planes, 192, 1)) 100 | ])) 101 | self.Branch_1 = nn.Sequential(OrderedDict([ 102 | ('Conv2d_0a_1x1', conv_bn(in_planes, 192, 1)), 103 | ('Conv2d_0b_1x7', conv_bn(192, 224, (1, 3), padding=(0, 1))), 104 | ('Conv2d_0c_7x1', conv_bn(224, 256, (3, 1), padding=(1, 0))) 105 | ])) 106 | self.Conv2d_1x1 = conv_bn(448, in_planes, 1) 107 | 108 | 109 | class InceptionResnetV2(nn.Module): 110 | 111 | def __init__(self, num_classes=1000): 112 | super(InceptionResnetV2, self).__init__() 113 | self.end_points = {} 114 | self.num_classes = num_classes 115 | 116 | self.stem = nn.Sequential(OrderedDict([ 117 | ('Conv2d_1a_3x3', conv_bn(3, 32, 3, stride=2, padding=1)), 118 | ('Conv2d_2a_3x3', conv_bn(32, 32, 3, padding=1)), 119 | ('Conv2d_2b_3x3', conv_bn(32, 64, 3)), 120 | ('MaxPool_3a_3x3', nn.MaxPool2d(3, 2)), 121 | ('Conv2d_3b_1x1', conv_bn(64, 80, 1)), 122 | ('Conv2d_4a_3x3', conv_bn(80, 192, 3)), 123 | ('MaxPool_5a_3x3', nn.MaxPool2d(3, 2)) 124 | ])) 125 | 126 | tower_conv = nn.Sequential(OrderedDict([ 127 | ('Conv2d_5b_b0_1x1', conv_bn(192, 96, 1)) 128 | ])) 129 | tower_conv1 = nn.Sequential(OrderedDict([ 130 | ('Conv2d_5b_b1_0a_1x1', conv_bn(192, 48, 1)), 131 | ('Conv2d_5b_b1_0b_5x5', conv_bn(48, 64, 5, padding=2)) 132 | ])) 133 | tower_conv2 = nn.Sequential(OrderedDict([ 134 | ('Conv2d_5b_b2_0a_1x1', conv_bn(192, 64, 1)), 135 | ('Conv2d_5b_b2_0b_3x3', conv_bn(64, 96, 3, padding=1)), 136 | ('Conv2d_5b_b2_0c_3x3', conv_bn(96, 96, 3, padding=1)) 137 | ])) 138 | tower_pool3 = nn.Sequential(OrderedDict([ 139 | ('AvgPool_5b_b3_0a_3x3', nn.AvgPool2d(3, stride=1, padding=1)), 140 | ('Conv2d_5b_b3_0b_1x1', conv_bn(192, 64, 1)) 141 | ])) 142 | 143 | self.mixed_5b = Concat(OrderedDict([ 144 | ('Branch_0', tower_conv), 145 | ('Branch_1', tower_conv1), 146 | ('Branch_2', tower_conv2), 147 | ('Branch_3', tower_pool3) 148 | ])) 149 | 150 | self.blocks35 = nn.Sequential() 151 | for i in range(10): 152 | self.blocks35.add_module('Block35.%s' % 153 | i, block35(320, scale=0.17)) 154 | 155 | tower_conv = nn.Sequential(OrderedDict([ 156 | ('Conv2d_6a_b0_0a_3x3', conv_bn(320, 384, 3, stride=2)) 157 | ])) 158 | tower_conv1 = nn.Sequential(OrderedDict([ 159 | ('Conv2d_6a_b1_0a_1x1', conv_bn(320, 256, 1)), 160 | ('Conv2d_6a_b1_0b_3x3', conv_bn(256, 256, 3, padding=1)), 161 | ('Conv2d_6a_b1_0c_3x3', conv_bn(256, 384, 3, stride=2)) 162 | ])) 163 | tower_pool = nn.Sequential(OrderedDict([ 164 | ('MaxPool_1a_3x3', nn.MaxPool2d(3, stride=2)) 165 | ])) 166 | 167 | self.mixed_6a = Concat(OrderedDict([ 168 | ('Branch_0', tower_conv), 169 | ('Branch_1', tower_conv1), 170 | ('Branch_2', tower_pool) 171 | ])) 172 | 173 | self.blocks17 = nn.Sequential() 174 | for i in range(20): 175 | self.blocks17.add_module('Block17.%s' % 176 | i, block17(1088, scale=0.1)) 177 | 178 | tower_conv = nn.Sequential(OrderedDict([ 179 | ('Conv2d_0a_1x1', conv_bn(1088, 256, 1)), 180 | ('Conv2d_1a_3x3', conv_bn(256, 384, 3, stride=2)), 181 | ])) 182 | tower_conv1 = nn.Sequential(OrderedDict([ 183 | ('Conv2d_0a_1x1', conv_bn(1088, 256, 1)), 184 | ('Conv2d_1a_3x3', 
conv_bn(256, 64, 3, stride=2)) 185 | ])) 186 | tower_conv2 = nn.Sequential(OrderedDict([ 187 | ('Conv2d_0a_1x1', conv_bn(1088, 256, 1)), 188 | ('Conv2d_0b_3x3', conv_bn(256, 288, 3, padding=1)), 189 | ('Conv2d_1a_3x3', conv_bn(288, 320, 3, stride=2)) 190 | ])) 191 | tower_pool3 = nn.Sequential(OrderedDict([ 192 | ('MaxPool_1a_3x3', nn.MaxPool2d(3, stride=2)) 193 | ])) 194 | 195 | self.mixed_7a = Concat(OrderedDict([ 196 | ('Branch_0', tower_conv), 197 | ('Branch_1', tower_conv1), 198 | ('Branch_2', tower_conv2), 199 | ('Branch_3', tower_pool3) 200 | ])) 201 | 202 | self.blocks8 = nn.Sequential() 203 | for i in range(9): 204 | self.blocks8.add_module('Block8.%s' % 205 | i, block8(1856, scale=0.2)) 206 | self.blocks8.add_module('Block8.9', block8( 207 | 1856, scale=0.2, activation=None)) 208 | 209 | self.conv_pool = nn.Sequential(OrderedDict([ 210 | ('Conv2d_7b_1x1', conv_bn(1856, 1536, 1)), 211 | ('AvgPool_1a_8x8', nn.AvgPool2d(8, 1)), 212 | ('Dropout', nn.Dropout(0.2)) 213 | ])) 214 | self.classifier = nn.Linear(1536, num_classes) 215 | 216 | self.aux_classifier = nn.Sequential(OrderedDict([ 217 | ('Conv2d_1a_3x3', nn.AvgPool2d(5, 3)), 218 | ('Conv2d_1b_1x1', conv_bn(1088, 128, 1)), 219 | ('Conv2d_2a_5x5', conv_bn(128, 768, 5)), 220 | ('Dropout', nn.Dropout(0.2)), 221 | ('Logits', conv_bn(768, num_classes, 1)) 222 | ])) 223 | 224 | class aux_loss(nn.Module): 225 | def __init__(self): 226 | super(aux_loss,self).__init__() 227 | self.loss = nn.CrossEntropyLoss() 228 | 229 | def forward(self, outputs, target): 230 | return self.loss(outputs[0], target) +\ 231 | 0.4 * (self.loss(outputs[1], target)) 232 | self.criterion = aux_loss 233 | self.regime = [ 234 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 235 | 'weight_decay': 1e-4, 'momentum': 0.9}, 236 | {'epoch': 30, 'lr': 1e-2}, 237 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 238 | {'epoch': 90, 'lr': 1e-4} 239 | ] 240 | 241 | def forward(self, x): 242 | x = self.stem(x) # (B, 192, 35, 35) 243 | x = self.mixed_5b(x) # (B, 320, 35, 35) 244 | x = self.blocks35(x) # (B, 320, 35, 35) 245 | x = self.mixed_6a(x) # (B, 1088, 17, 17) 246 | branch1 = self.blocks17(x) # (B, 1088, 17, 17) 247 | x = self.mixed_7a(branch1) # (B, 1856, 8, 8) 248 | x = self.blocks8(x) # (B, 1856, 8, 8) 249 | x = self.conv_pool(x) # (B, 1536, 1, 1) 250 | x = x.view(-1, 1536) # (B, 1536) 251 | output = self.classifier(x) # (B, num_classes) 252 | if hasattr(self, 'aux_classifier'): 253 | branch1 = self.aux_classifier(branch1).view(-1, self.num_classes) 254 | output = [output, branch1] 255 | return output 256 | 257 | def inception_resnet_v2(**kwargs): 258 | num_classes = kwargs.get('num_classes', 1000) 259 | return InceptionResnetV2(num_classes=num_classes) 260 | --------------------------------------------------------------------------------
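For reference, a minimal usage sketch (added here, not part of the repo; the `models.inception_resnet_v2` import path and the 299x299 input size are assumptions based on the tree and shape comments above): the model returns `[main_logits, aux_logits]`, and its attached `aux_loss` criterion weights the auxiliary head by 0.4.

```python
import torch
from models.inception_resnet_v2 import inception_resnet_v2  # path assumed from the tree above

model = inception_resnet_v2(num_classes=1000)
criterion = model.criterion()            # aux_loss: CE(main) + 0.4 * CE(aux)
x = torch.randn(2, 3, 299, 299)          # 299x299 input assumed (stem reduces it to 35x35)
target = torch.randint(0, 1000, (2,))
outputs = model(x)                       # [main_logits, aux_logits]
loss = criterion(outputs, target)
loss.backward()
```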
/models/inception_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import math 5 | 6 | __all__ = ['inception_v2'] 7 | 8 | def conv_bn(in_planes, out_planes, kernel_size, stride=1, padding=0): 9 | "convolution with batchnorm, relu" 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, 12 | padding=padding, bias=False), 13 | nn.BatchNorm2d(out_planes), 14 | nn.ReLU() 15 | ) 16 | 17 | 18 | class InceptionModule(nn.Module): 19 | 20 | def __init__(self, in_channels, n1x1_channels, n3x3r_channels, 21 | n3x3_channels, dn3x3r_channels, dn3x3_channels, 22 | pool_proj_channels=None, type_pool='avg', stride=1): 23 | super(InceptionModule, self).__init__() 24 | self.in_channels = in_channels 25 | self.n1x1_channels = n1x1_channels or 0 26 | pool_proj_channels = pool_proj_channels or 0 27 | self.stride = stride 28 | 29 | if n1x1_channels > 0: 30 | self.conv_1x1 = conv_bn(in_channels, n1x1_channels, 1, stride) 31 | else: 32 | self.conv_1x1 = None 33 | 34 | self.conv_3x3 = nn.Sequential( 35 | conv_bn(in_channels, n3x3r_channels, 1), 36 | conv_bn(n3x3r_channels, n3x3_channels, 3, stride, padding=1) 37 | ) 38 | self.conv_d3x3 = nn.Sequential( 39 | conv_bn(in_channels, dn3x3r_channels, 1), 40 | conv_bn(dn3x3r_channels, dn3x3_channels, 3, padding=1), 41 | conv_bn(dn3x3_channels, dn3x3_channels, 3, stride, padding=1) 42 | ) 43 | 44 | if type_pool == 'avg': 45 | self.pool = nn.AvgPool2d(3, stride, padding=1) 46 | elif type_pool == 'max': 47 | self.pool = nn.MaxPool2d(3, stride, padding=1) 48 | 49 | if pool_proj_channels > 0: # Add pool projection 50 | self.pool = nn.Sequential( 51 | self.pool, 52 | conv_bn(in_channels, pool_proj_channels, 1)) 53 | 54 | def forward(self, inputs): 55 | layer_outputs = [] 56 | 57 | if self.conv_1x1 is not None: 58 | layer_outputs.append(self.conv_1x1(inputs)) 59 | 60 | layer_outputs.append(self.conv_3x3(inputs)) 61 | layer_outputs.append(self.conv_d3x3(inputs)) 62 | layer_outputs.append(self.pool(inputs)) 63 | output = torch.cat(layer_outputs, 1) 64 | 65 | return output 66 | 67 | 68 | class Inception_v2(nn.Module): 69 | 70 | def __init__(self, num_classes=1000, aux_classifiers=True): 71 | super(Inception_v2, self).__init__() 72 | self.num_classes = num_classes 73 | self.part1 = nn.Sequential( 74 | nn.Conv2d(3, 64, 7, 2, 3, bias=False), 75 | nn.MaxPool2d(3, 2), 76 | nn.BatchNorm2d(64), 77 | nn.ReLU(), 78 | nn.Conv2d(64, 192, 3, 1, 1, bias=False), 79 | nn.MaxPool2d(3, 2), 80 | nn.BatchNorm2d(192), 81 | nn.ReLU(), 82 | InceptionModule(192, 64, 64, 64, 64, 96, 32, 'avg'), 83 | InceptionModule(256, 64, 64, 96, 64, 96, 64, 'avg'), 84 | InceptionModule(320, 0, 128, 160, 64, 96, 0, 'max', 2) 85 | ) 86 | 87 | self.part2 = nn.Sequential( 88 | InceptionModule(576, 224, 64, 96, 96, 128, 128, 'avg'), 89 | InceptionModule(576, 192, 96, 128, 96, 128, 128, 'avg'), 90 | InceptionModule(576, 160, 128, 160, 128, 160, 96, 'avg') 91 | ) 92 | self.part3 = nn.Sequential( 93 | InceptionModule(576, 96, 128, 192, 160, 192, 96, 'avg'), 94 | InceptionModule(576, 0, 128, 192, 192, 256, 0, 'max', 2), 95 | InceptionModule(1024, 352, 192, 320, 160, 224, 128, 'avg'), 96 | InceptionModule(1024, 352, 192, 320, 192, 224, 128, 'max') 97 | ) 98 | 99 | self.main_classifier = nn.Sequential( 100 | nn.AvgPool2d(7, 1), 101 | nn.Dropout(0.2), 102 | nn.Conv2d(1024, self.num_classes, 1) 103 | ) 104 | if aux_classifiers: 105 | self.aux_classifier1 = nn.Sequential( 106 | nn.AvgPool2d(5, 3), 107 | conv_bn(576, 128, 1), 108 | conv_bn(128, 768, 4), 109 | nn.Dropout(0.2), 110 | nn.Conv2d(768, self.num_classes, 1), 111 | ) 112 | self.aux_classifier2 = nn.Sequential( 113 | nn.AvgPool2d(5, 3), 114 | conv_bn(576, 128, 1), 115 | conv_bn(128, 768, 4), 116 | nn.Dropout(0.2), 117 | nn.Conv2d(768, self.num_classes, 1), 118 | ) 119 | 120 | self.regime = [ 121 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 122 | 'weight_decay': 1e-4, 'momentum': 0.9}, 123 | {'epoch': 30, 'lr': 1e-2}, 124 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 125 | {'epoch': 90, 'lr': 1e-4} 126 | ] 127 | 128 | class aux_loss(nn.Module): 129 | def __init__(self): 130 | super(aux_loss,self).__init__()
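# criterion for the multi-head output: main cross-entropy plus the two auxiliary heads, each weighted 0.4 (see forward below)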
131 | self.loss = nn.CrossEntropyLoss() 132 | 133 | def forward(self, outputs, target): 134 | return self.loss(outputs[0], target) +\ 135 | 0.4 * (self.loss(outputs[1], target) + self.loss(outputs[2], target)) 136 | self.criterion = aux_loss 137 | 138 | def forward(self, inputs): 139 | branch1 = self.part1(inputs) 140 | branch2 = self.part2(branch1) 141 | branch3 = self.part3(branch2) 142 | 143 | output = self.main_classifier(branch3).view(-1, self.num_classes) 144 | if hasattr(self, 'aux_classifier1'): 145 | branch1 = self.aux_classifier1(branch1).view(-1, self.num_classes) 146 | branch2 = self.aux_classifier2(branch2).view(-1, self.num_classes) 147 | output = [output, branch1, branch2] 148 | return output 149 | 150 | 151 | def inception_v2(**kwargs): 152 | num_classes = kwargs.get('num_classes', 1000) 153 | return Inception_v2(num_classes=num_classes) 154 | -------------------------------------------------------------------------------- /models/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class mnist_model(nn.Module): 6 | 7 | def __init__(self): 8 | super(mnist_model, self).__init__() 9 | self.feats = nn.Sequential( 10 | nn.Conv2d(1, 32, 5, 1, 1), 11 | nn.MaxPool2d(2, 2), 12 | nn.ReLU(True), 13 | nn.BatchNorm2d(32), 14 | 15 | nn.Conv2d(32, 64, 3, 1, 1), 16 | nn.ReLU(True), 17 | nn.BatchNorm2d(64), 18 | 19 | nn.Conv2d(64, 64, 3, 1, 1), 20 | nn.MaxPool2d(2, 2), 21 | nn.ReLU(True), 22 | nn.BatchNorm2d(64), 23 | 24 | nn.Conv2d(64, 128, 3, 1, 1), 25 | nn.ReLU(True), 26 | nn.BatchNorm2d(128) 27 | ) 28 | 29 | self.classifier = nn.Conv2d(128, 10, 1) 30 | self.avgpool = nn.AvgPool2d(6, 6) 31 | self.dropout = nn.Dropout(0.5) 32 | 33 | def forward(self, inputs): 34 | out = self.feats(inputs) 35 | out = self.dropout(out) 36 | out = self.classifier(out) 37 | out = self.avgpool(out) 38 | out = out.view(-1, 10) 39 | return out 40 | 41 | 42 | def model(**kwargs): 43 | return mnist_model() 44 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _single, _pair, _triple 4 | import math 5 | import torch.nn.functional as F 6 | 7 | import torchvision.transforms as transforms 8 | __all__ = ['mobilenet'] 9 | 10 | def nearby_int(n): 11 | return int(round(n)) 12 | 13 | 14 | def init_model(model): 15 | for m in model.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class DepthwiseSeparableFusedConv2d(nn.Module): 25 | 26 | def __init__(self, in_channels, out_channels, kernel_size, 27 | stride=1, padding=0): 28 | super(DepthwiseSeparableFusedConv2d, self).__init__() 29 | self.components = nn.Sequential( 30 | nn.Conv2d(in_channels, in_channels, kernel_size, 31 | stride=stride, padding=padding, groups=in_channels), 32 | nn.BatchNorm2d(in_channels), 33 | nn.ReLU(), 34 | 35 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 36 | nn.BatchNorm2d(out_channels), 37 | nn.ReLU() 38 | ) 39 | 40 | def forward(self, x): 41 | return self.components(x) 42 | 43 | 44 | class MobileNet(nn.Module): 45 | 46 | def __init__(self, width=1., shallow=False, num_classes=1000): 47 | super(MobileNet, self).__init__() 48 | num_classes = num_classes or 1000 49 | width = width or 1. 50 | layers = [ 51 | nn.Conv2d(3, nearby_int(width * 32), 52 | kernel_size=3, stride=2, padding=1, bias=False), 53 | nn.BatchNorm2d(nearby_int(width * 32)), 54 | nn.ReLU(inplace=True), 55 | 56 | DepthwiseSeparableFusedConv2d( 57 | nearby_int(width * 32), nearby_int(width * 64), 58 | kernel_size=3, padding=1), 59 | DepthwiseSeparableFusedConv2d( 60 | nearby_int(width * 64), nearby_int(width * 128), 61 | kernel_size=3, stride=2, padding=1), 62 | DepthwiseSeparableFusedConv2d( 63 | nearby_int(width * 128), nearby_int(width * 128), 64 | kernel_size=3, padding=1), 65 | DepthwiseSeparableFusedConv2d( 66 | nearby_int(width * 128), nearby_int(width * 256), 67 | kernel_size=3, stride=2, padding=1), 68 | DepthwiseSeparableFusedConv2d( 69 | nearby_int(width * 256), nearby_int(width * 256), 70 | kernel_size=3, padding=1), 71 | DepthwiseSeparableFusedConv2d( 72 | nearby_int(width * 256), nearby_int(width * 512), 73 | kernel_size=3, stride=2, padding=1) 74 | ] 75 | if not shallow: 76 | # 5x 512->512 DW-separable convolutions 77 | layers += [ 78 | DepthwiseSeparableFusedConv2d( 79 | nearby_int(width * 512), nearby_int(width * 512), 80 | kernel_size=3, padding=1), 81 | DepthwiseSeparableFusedConv2d( 82 | nearby_int(width * 512), nearby_int(width * 512), 83 | kernel_size=3, padding=1), 84 | DepthwiseSeparableFusedConv2d( 85 | nearby_int(width * 512), nearby_int(width * 512), 86 | kernel_size=3, padding=1), 87 | DepthwiseSeparableFusedConv2d( 88 | nearby_int(width * 512), nearby_int(width * 512), 89 | kernel_size=3, padding=1), 90 | DepthwiseSeparableFusedConv2d( 91 | nearby_int(width * 512), nearby_int(width * 512), 92 | kernel_size=3, padding=1), 93 | ] 94 | layers += [ 95 | DepthwiseSeparableFusedConv2d( 96 | nearby_int(width * 512), nearby_int(width * 1024), 97 | kernel_size=3, stride=2, padding=1), 98 | # Paper specifies stride-2, but unchanged size. 
99 | # Assume it's a typo and use a stride-1 convolution 100 | DepthwiseSeparableFusedConv2d( 101 | nearby_int(width * 1024), nearby_int(width * 1024), 102 | kernel_size=3, stride=1, padding=1) 103 | ] 104 | self.features = nn.Sequential(*layers) 105 | self.avg_pool = nn.AvgPool2d(7) 106 | self.fc = nn.Linear(nearby_int(width * 1024), num_classes) 107 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 108 | std=[0.229, 0.224, 0.225]) 109 | self.input_transform = { 110 | 'train': transforms.Compose([ 111 | transforms.RandomResizedCrop(224, scale=(0.3, 1.0)), 112 | transforms.RandomHorizontalFlip(), 113 | transforms.ToTensor(), 114 | normalize 115 | ]), 116 | 'eval': transforms.Compose([ 117 | transforms.Resize(256), 118 | transforms.CenterCrop(224), 119 | transforms.ToTensor(), 120 | normalize 121 | ]) 122 | } 123 | self.regime = [ 124 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9}, 125 | {'epoch': 30, 'lr': 1e-2}, 126 | {'epoch': 60, 'lr': 1e-3}, 127 | {'epoch': 80, 'lr': 1e-4} 128 | ] 129 | 130 | 131 | @staticmethod 132 | def regularization(model, weight_decay=4e-5): 133 | l2_params = 0 134 | for m in model.modules(): 135 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 136 | l2_params += m.weight.pow(2).sum() 137 | if m.bias is not None: 138 | l2_params += m.bias.pow(2).sum() 139 | return weight_decay * 0.5 * l2_params 140 | 141 | def forward(self, x): 142 | x = self.features(x) 143 | x = self.avg_pool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | return x 147 | 148 | 149 | def mobilenet(**kwargs): 150 | r"""MobileNet model architecture from the `"MobileNets: 151 | Efficient Convolutional Neural Networks for Mobile Vision Applications" 152 | <https://arxiv.org/abs/1704.04861>`_ paper. 153 | 154 | Args: 155 | width (float), shallow (bool), num_classes (int): optional kwargs; width scales the channel counts, shallow omits the five middle blocks 156 | """ 157 | num_classes, width, shallow = map( 158 | kwargs.get, ['num_classes', 'width', 'shallow']) 159 | return MobileNet(width=width, shallow=shallow, num_classes=num_classes) 160 | -------------------------------------------------------------------------------- /models/mobilenet_quantized.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.utils import _single, _pair, _triple 4 | import math 5 | import torch.nn.functional as F 6 | 7 | import torchvision.transforms as transforms 8 | from .modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 9 | __all__ = ['mobilenet_quantized'] 10 | 11 | NUM_BITS = 8 12 | NUM_BITS_WEIGHT = 8 13 | NUM_BITS_GRAD = 8 14 | BIPRECISION = True 15 | 16 | 17 | def nearby_int(n): 18 | return int(round(n)) 19 | 20 | 21 | def init_model(model): 22 | for m in model.modules(): 23 | if isinstance(m, QConv2d): 24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 25 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 26 | elif isinstance(m, RangeBN): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | model.fc.weight.data.normal_(0, 0.01) 30 | model.fc.bias.data.zero_() 31 | 32 | 33 | class DepthwiseSeparableFusedConv2d(nn.Module): 34 | 35 | def __init__(self, in_channels, out_channels, kernel_size, 36 | stride=1, padding=0): 37 | super(DepthwiseSeparableFusedConv2d, self).__init__() 38 | self.components = nn.Sequential( 39 | QConv2d(in_channels, in_channels, kernel_size, 40 | stride=stride, padding=padding, groups=in_channels, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION), 41 | RangeBN(in_channels, num_bits=NUM_BITS, 42 | num_bits_grad=NUM_BITS_GRAD), 43 | nn.ReLU(), 44 | 45 | QConv2d(in_channels, out_channels, 1, bias=False, num_bits=NUM_BITS, 46 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION), 47 | RangeBN(out_channels, num_bits=NUM_BITS, 48 | num_bits_grad=NUM_BITS_GRAD), 49 | nn.ReLU() 50 | ) 51 | 52 | def forward(self, x): 53 | return self.components(x) 54 | 55 | 56 | class MobileNet(nn.Module): 57 | 58 | def __init__(self, width=1., shallow=False, num_classes=1000): 59 | super(MobileNet, self).__init__() 60 | num_classes = num_classes or 1000 61 | width = width or 1. 62 | layers = [ 63 | QConv2d(3, nearby_int(width * 32), 64 | kernel_size=3, stride=2, padding=1, bias=False, num_bits=NUM_BITS, 65 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION), 66 | RangeBN(nearby_int(width * 32), num_bits=NUM_BITS, 67 | num_bits_grad=NUM_BITS_GRAD), 68 | nn.ReLU(inplace=True), 69 | 70 | DepthwiseSeparableFusedConv2d( 71 | nearby_int(width * 32), nearby_int(width * 64), 72 | kernel_size=3, padding=1), 73 | DepthwiseSeparableFusedConv2d( 74 | nearby_int(width * 64), nearby_int(width * 128), 75 | kernel_size=3, stride=2, padding=1), 76 | DepthwiseSeparableFusedConv2d( 77 | nearby_int(width * 128), nearby_int(width * 128), 78 | kernel_size=3, padding=1), 79 | DepthwiseSeparableFusedConv2d( 80 | nearby_int(width * 128), nearby_int(width * 256), 81 | kernel_size=3, stride=2, padding=1), 82 | DepthwiseSeparableFusedConv2d( 83 | nearby_int(width * 256), nearby_int(width * 256), 84 | kernel_size=3, padding=1), 85 | DepthwiseSeparableFusedConv2d( 86 | nearby_int(width * 256), nearby_int(width * 512), 87 | kernel_size=3, stride=2, padding=1) 88 | ] 89 | if not shallow: 90 | # 5x 512->512 DW-separable convolutions 91 | layers += [ 92 | DepthwiseSeparableFusedConv2d( 93 | nearby_int(width * 512), nearby_int(width * 512), 94 | kernel_size=3, padding=1), 95 | DepthwiseSeparableFusedConv2d( 96 | nearby_int(width * 512), nearby_int(width * 512), 97 | kernel_size=3, padding=1), 98 | DepthwiseSeparableFusedConv2d( 99 | nearby_int(width * 512), nearby_int(width * 512), 100 | kernel_size=3, padding=1), 101 | DepthwiseSeparableFusedConv2d( 102 | nearby_int(width * 512), nearby_int(width * 512), 103 | kernel_size=3, padding=1), 104 | DepthwiseSeparableFusedConv2d( 105 | nearby_int(width * 512), nearby_int(width * 512), 106 | kernel_size=3, padding=1), 107 | ] 108 | layers += [ 109 | DepthwiseSeparableFusedConv2d( 110 | nearby_int(width * 512), nearby_int(width * 1024), 111 | kernel_size=3, stride=2, padding=1), 112 | # Paper specifies stride-2, but unchanged size. 
113 | # Assume it's a typo and use a stride-1 convolution 114 | DepthwiseSeparableFusedConv2d( 115 | nearby_int(width * 1024), nearby_int(width * 1024), 116 | kernel_size=3, stride=1, padding=1) 117 | ] 118 | self.features = nn.Sequential(*layers) 119 | self.avg_pool = nn.AvgPool2d(7) 120 | self.fc = QLinear(nearby_int(width * 1024), num_classes, num_bits=NUM_BITS, 121 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 122 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 123 | std=[0.229, 0.224, 0.225]) 124 | self.input_transform = { 125 | 'train': transforms.Compose([ 126 | transforms.RandomResizedCrop(224, scale=(0.3, 1.0)), 127 | transforms.RandomHorizontalFlip(), 128 | transforms.ToTensor(), 129 | normalize 130 | ]), 131 | 'eval': transforms.Compose([ 132 | transforms.Resize(256), 133 | transforms.CenterCrop(224), 134 | transforms.ToTensor(), 135 | normalize 136 | ]) 137 | } 138 | self.regime = [ 139 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 'momentum': 0.9}, 140 | {'epoch': 30, 'lr': 1e-2}, 141 | {'epoch': 60, 'lr': 1e-3}, 142 | {'epoch': 80, 'lr': 1e-4} 143 | ] 144 | 145 | @staticmethod 146 | def regularization(model, weight_decay=4e-5): 147 | l2_params = 0 148 | for m in model.modules(): 149 | if isinstance(m, QConv2d) or isinstance(m, nn.Linear): 150 | l2_params += m.weight.pow(2).sum() 151 | if m.bias is not None: 152 | l2_params += m.bias.pow(2).sum() 153 | return weight_decay * 0.5 * l2_params 154 | 155 | def forward(self, x): 156 | x = self.features(x) 157 | x = self.avg_pool(x) 158 | x = x.view(x.size(0), -1) 159 | x = self.fc(x) 160 | return x 161 | 162 | 163 | def mobilenet_quantized(**kwargs): 164 | r"""MobileNet model architecture from the `"MobileNets: 165 | Efficient Convolutional Neural Networks for Mobile Vision Applications" 166 | <https://arxiv.org/abs/1704.04861>`_ paper. 167 | 168 | Args: 169 | width (float), shallow (bool), num_classes (int): optional kwargs; width scales the channel counts, shallow omits the five middle blocks 170 | """ 171 | num_classes, width, shallow = map( 172 | kwargs.get, ['num_classes', 'width', 'shallow']) 173 | return MobileNet(width=width, shallow=shallow, num_classes=num_classes) 174 | --------------------------------------------------------------------------------
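A hypothetical smoke test for the quantized model (added here, not part of the repo; the import path is assumed from the tree above, and it relies on the corrected `quantize()` argument order in models/modules/quantize.py below). The batch size of 16 matters: `RangeBN` splits each channel's statistics into `num_chunks=16` chunks, so B*H*W must divide evenly by 16.

```python
import torch
from models.mobilenet_quantized import mobilenet_quantized  # path assumed

model = mobilenet_quantized(num_classes=1000, width=1.0, shallow=False)
x = torch.randn(16, 3, 224, 224)  # batch of 16 so RangeBN's 16-way chunking divides evenly
y = model(x)                      # QConv2d/QLinear/RangeBN fake-quantize to 8 bits end to end
print(y.shape)                    # torch.Size([16, 1000])
```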
/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladhoffer/quantized.pytorch/e09c447a50a6a4c7dabf6176f20c931422aefd67/models/modules/__init__.py -------------------------------------------------------------------------------- /models/modules/bwn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bounded weight norm 3 | Weight Normalization from https://arxiv.org/abs/1602.07868 4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 5 | """ 6 | import torch 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable, Function 9 | import torch.nn as nn 10 | 11 | 12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()): 13 | if memo is None: 14 | memo = set() 15 | for p in param_func(self): 16 | if p is not None and p not in memo: 17 | memo.add(p) 18 | yield p 19 | for m in self.children(): 20 | for p in gather_params(m, memo, param_func): 21 | yield p 22 | 23 | nn.Module.gather_params = gather_params 24 | 25 | 26 | def _norm(x, dim, p=2): 27 | """Computes the norm over all dimensions except dim""" 28 | if p == float('inf'): # infinity norm 29 | func = lambda x, dim: x.abs().max(dim=dim)[0] 30 | else: 31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p) 32 | if dim is None: 33 | return x.norm(p=p) 34 | elif dim == 0: 35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 37 | elif dim == x.dim() - 1: 38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 40 | else: 41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 42 | 43 | 44 | def _mean(p, dim): 45 | """Computes the mean over all dimensions except dim""" 46 | if dim is None: 47 | return p.mean() 48 | elif dim == 0: 49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 51 | elif dim == p.dim() - 1: 52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 54 | else: 55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 56 | 57 | 58 | class BoundedWeighNorm(object): 59 | 60 | def __init__(self, name, dim, p): 61 | self.name = name 62 | self.dim = dim 63 | self.p = p 64 | 65 | def compute_weight(self, module): 66 | v = getattr(module, self.name + '_v') 67 | pre_norm = getattr(module, self.name + '_prenorm') 68 | return v * (pre_norm / _norm(v, self.dim, p=self.p)) 69 | 70 | @staticmethod 71 | def apply(module, name, dim, p): 72 | fn = BoundedWeighNorm(name, dim, p) 73 | 74 | weight = getattr(module, name) 75 | 76 | # remove w from parameter list 77 | del module._parameters[name] 78 | 79 | prenorm = _norm(weight, dim, p=p).mean() 80 | module.register_buffer(name + '_prenorm', prenorm.detach()) 81 | # the initial norm is kept as a fixed buffer; v is rescaled to it on every forward 82 | module.register_parameter(name + '_v', 
Parameter(weight.data)) 84 | setattr(module, name, fn.compute_weight(module)) 85 | 86 | # recompute weight before every forward() 87 | module.register_forward_pre_hook(fn) 88 | 89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)): 90 | return gather_params(self, memo, param_func) 91 | module.gather_params = gather_normed_params 92 | return fn 93 | 94 | def remove(self, module): 95 | weight = self.compute_weight(module) 96 | delattr(module, self.name) 97 | del module._parameters[self.name + '_prenorm'] 98 | del module._parameters[self.name + '_v'] 99 | module.register_parameter(self.name, Parameter(weight.data)) 100 | 101 | def __call__(self, module, inputs): 102 | setattr(module, self.name, self.compute_weight(module)) 103 | 104 | 105 | def weight_norm(module, name='weight', dim=0, p=2): 106 | r"""Applies weight normalization to a parameter in the given module. 107 | 108 | .. math:: 109 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 110 | 111 | Weight normalization is a reparameterization that decouples the magnitude 112 | of a weight tensor from its direction. This replaces the parameter specified 113 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude 114 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). 115 | Weight normalization is implemented via a hook that recomputes the weight 116 | tensor from the magnitude and direction before every :meth:`~Module.forward` 117 | call. 118 | 119 | By default, with `dim=0`, the norm is computed independently per output 120 | channel/plane. To compute a norm over the entire weight tensor, use 121 | `dim=None`. 122 | 123 | See https://arxiv.org/abs/1602.07868 124 | 125 | Args: 126 | module (nn.Module): containing module 127 | name (str, optional): name of weight parameter 128 | dim (int, optional): dimension over which to compute the norm 129 | 130 | Returns: 131 | The original module with the weight norm hook 132 | 133 | Example:: 134 | 135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 136 | Linear (20 -> 40) 137 | >>> m.weight_g.size() 138 | torch.Size([40, 1]) 139 | >>> m.weight_v.size() 140 | torch.Size([40, 20]) 141 | 142 | """ 143 | BoundedWeighNorm.apply(module, name, dim, p) 144 | return module 145 | 146 | 147 | def remove_weight_norm(module, name='weight'): 148 | r"""Removes the weight normalization reparameterization from a module. 
149 | 150 | Args: 151 | module (nn.Module): containing module 152 | name (str, optional): name of weight parameter 153 | 154 | Example: 155 | >>> m = weight_norm(nn.Linear(20, 40)) 156 | >>> remove_weight_norm(m) 157 | """ 158 | for k, hook in module._forward_pre_hooks.items(): 159 | if isinstance(hook, BoundedWeighNorm) and hook.name == name: 160 | hook.remove(module) 161 | del module._forward_pre_hooks[k] 162 | return module 163 | 164 | raise ValueError("weight_norm of '{}' not found in {}" 165 | .format(name, module)) 166 | -------------------------------------------------------------------------------- /models/modules/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import InplaceFunction, Function 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | def _mean(p, dim): 9 | """Computes the mean over all dimensions except dim""" 10 | if dim is None: 11 | return p.mean() 12 | elif dim == 0: 13 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 14 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 15 | elif dim == p.dim() - 1: 16 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 17 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 18 | else: 19 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 20 | 21 | 22 | class UniformQuantize(InplaceFunction): 23 | 24 | @classmethod 25 | def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None, 26 | stochastic=False, inplace=False, enforce_true_zero=False, num_chunks=None, out_half=False): 27 | 28 | num_chunks = num_chunks = input.shape[ 29 | 0] if num_chunks is None else num_chunks 30 | if min_value is None or max_value is None: 31 | B = input.shape[0] 32 | y = input.view(B // num_chunks, -1) 33 | if min_value is None: 34 | min_value = y.min(-1)[0].mean(-1) # C 35 | #min_value = float(input.view(input.size(0), -1).min(-1)[0].mean()) 36 | if max_value is None: 37 | #max_value = float(input.view(input.size(0), -1).max(-1)[0].mean()) 38 | max_value = y.max(-1)[0].mean(-1) # C 39 | ctx.inplace = inplace 40 | ctx.num_bits = num_bits 41 | ctx.min_value = min_value 42 | ctx.max_value = max_value 43 | ctx.stochastic = stochastic 44 | 45 | if ctx.inplace: 46 | ctx.mark_dirty(input) 47 | output = input 48 | else: 49 | output = input.clone() 50 | 51 | qmin = 0. 52 | qmax = 2.**num_bits - 1. 53 | #import pdb; pdb.set_trace() 54 | scale = (max_value - min_value) / (qmax - qmin) 55 | 56 | scale = max(scale, 1e-8) 57 | 58 | if enforce_true_zero: 59 | initial_zero_point = qmin - min_value / scale 60 | zero_point = 0. 
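# affine scheme: choose an integer zero_point so that real zero is exactly representable after rounding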
61 | # make zero exactly represented 62 | if initial_zero_point < qmin: 63 | zero_point = qmin 64 | elif initial_zero_point > qmax: 65 | zero_point = qmax 66 | else: 67 | zero_point = initial_zero_point 68 | zero_point = int(zero_point) 69 | output.div_(scale).add_(zero_point) 70 | else: 71 | output.add_(-min_value).div_(scale).add_(qmin) 72 | 73 | if ctx.stochastic: 74 | noise = output.new(output.shape).uniform_(-0.5, 0.5) 75 | output.add_(noise) 76 | output.clamp_(qmin, qmax).round_() # quantize 77 | 78 | if enforce_true_zero: 79 | output.add_(-zero_point).mul_(scale) # dequantize 80 | else: 81 | output.add_(-qmin).mul_(scale).add_(min_value) # dequantize 82 | if out_half and num_bits <= 16: 83 | output = output.half() 84 | return output 85 | 86 | @staticmethod 87 | def backward(ctx, grad_output): 88 | # straight-through estimator: pass gradients unchanged, one grad slot per forward argument 89 | grad_input = grad_output 90 | return grad_input, None, None, None, None, None, None, None, None 91 | 92 | 93 | class UniformQuantizeGrad(InplaceFunction): 94 | 95 | @classmethod 96 | def forward(cls, ctx, input, num_bits=8, min_value=None, max_value=None, stochastic=True, inplace=False): 97 | ctx.inplace = inplace 98 | ctx.num_bits = num_bits 99 | ctx.min_value = min_value 100 | ctx.max_value = max_value 101 | ctx.stochastic = stochastic 102 | return input 103 | 104 | @staticmethod 105 | def backward(ctx, grad_output): 106 | if ctx.min_value is None: 107 | min_value = float(grad_output.min()) 108 | # min_value = float(grad_output.view( 109 | # grad_output.size(0), -1).min(-1)[0].mean()) 110 | else: 111 | min_value = ctx.min_value 112 | if ctx.max_value is None: 113 | max_value = float(grad_output.max()) 114 | # max_value = float(grad_output.view( 115 | # grad_output.size(0), -1).max(-1)[0].mean()) 116 | else: 117 | max_value = ctx.max_value 118 | grad_input = UniformQuantize().apply(grad_output, ctx.num_bits, min_value, 119 | max_value, ctx.stochastic, ctx.inplace, False, None, False) 120 | return grad_input, None, None, None, None, None 121 | 122 | 123 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None): 124 | out1 = F.conv2d(input.detach(), weight, bias, 125 | stride, padding, dilation, groups) 126 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None, 127 | stride, padding, dilation, groups) 128 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 129 | return out1 + out2 - out1.detach() 130 | 131 | 132 | def linear_biprec(input, weight, bias=None, num_bits_grad=None): 133 | out1 = F.linear(input.detach(), weight, bias) 134 | out2 = F.linear(input, weight.detach(), bias.detach() 135 | if bias is not None else None) 136 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 137 | return out1 + out2 - out1.detach() 138 | 139 | 140 | def quantize(x, num_bits=8, min_value=None, max_value=None, num_chunks=None, stochastic=False, inplace=False): 141 | return UniformQuantize().apply(x, num_bits, min_value, max_value, stochastic, inplace, False, num_chunks, False) 142 | 143 | 144 | def quantize_grad(x, num_bits=8, min_value=None, max_value=None, stochastic=True, inplace=False): 145 | return UniformQuantizeGrad().apply(x, num_bits, min_value, max_value, stochastic, inplace) 146 | 147 | 
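# --- illustrative round-trip (added sketch; not in the original file) ---
# quantize() maps x onto 2**num_bits uniform levels spanning [min_value, max_value]
# and dequantizes back: the result is float-valued but takes at most 2**num_bits
# distinct values, while the straight-through backward passes gradients unchanged.
#
#   x = torch.randn(4, 8)
#   xq = quantize(x, num_bits=8, min_value=float(x.min()), max_value=float(x.max()))
#   assert (x - xq).abs().max() <= (x.max() - x.min()) / (2 ** 8 - 1)
#   assert torch.unique(xq).numel() <= 2 ** 8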
self.num_bits = num_bits
157 | 
158 |     def forward(self, input):
159 |         if self.training:
160 |             min_value = input.detach().view(
161 |                 input.size(0), -1).min(-1)[0].mean()
162 |             max_value = input.detach().view(
163 |                 input.size(0), -1).max(-1)[0].mean()
164 |             self.running_min.mul_(self.momentum).add_(
165 |                 min_value * (1 - self.momentum))
166 |             self.running_max.mul_(self.momentum).add_(
167 |                 max_value * (1 - self.momentum))
168 |         else:
169 |             min_value = self.running_min
170 |             max_value = self.running_max
171 |         return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16)
172 | 
173 | 
174 | class QConv2d(nn.Conv2d):
175 |     """Conv2d that quantizes its input, weights and (optionally) gradients."""
176 | 
177 |     def __init__(self, in_channels, out_channels, kernel_size,
178 |                  stride=1, padding=0, dilation=1, groups=1, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False):
179 |         super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
180 |                                       stride, padding, dilation, groups, bias)
181 |         self.num_bits = num_bits
182 |         self.num_bits_weight = num_bits_weight or num_bits
183 |         self.num_bits_grad = num_bits_grad
184 |         self.quantize_input = QuantMeasure(self.num_bits)
185 |         self.biprecision = biprecision
186 | 
187 |     def forward(self, input):
188 |         qinput = self.quantize_input(input)
189 |         qweight = quantize(self.weight, num_bits=self.num_bits_weight,
190 |                            min_value=float(self.weight.min()),
191 |                            max_value=float(self.weight.max()))
192 |         if self.bias is not None:
193 |             qbias = quantize(self.bias, num_bits=self.num_bits_weight)
194 |         else:
195 |             qbias = None
196 |         if not self.biprecision or self.num_bits_grad is None:
197 |             output = F.conv2d(qinput, qweight, qbias, self.stride,
198 |                               self.padding, self.dilation, self.groups)
199 |             if self.num_bits_grad is not None:
200 |                 output = quantize_grad(output, num_bits=self.num_bits_grad)
201 |         else:
202 |             output = conv2d_biprec(qinput, qweight, qbias, self.stride,
203 |                                    self.padding, self.dilation, self.groups, num_bits_grad=self.num_bits_grad)
204 | 
205 |         return output
206 | 
207 | 
208 | class QLinear(nn.Linear):
209 |     """Linear layer that quantizes its input, weights and (optionally) gradients."""
210 | 
211 |     def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False):
212 |         super(QLinear, self).__init__(in_features, out_features, bias)
213 |         self.num_bits = num_bits
214 |         self.num_bits_weight = num_bits_weight or num_bits
215 |         self.num_bits_grad = num_bits_grad
216 |         self.biprecision = biprecision
217 |         self.quantize_input = QuantMeasure(self.num_bits)
218 | 
219 |     def forward(self, input):
220 |         qinput = self.quantize_input(input)
221 |         qweight = quantize(self.weight, num_bits=self.num_bits_weight,
222 |                            min_value=float(self.weight.min()),
223 |                            max_value=float(self.weight.max()))
224 |         if self.bias is not None:
225 |             qbias = quantize(self.bias, num_bits=self.num_bits_weight)
226 |         else:
227 |             qbias = None
228 | 
229 |         if not self.biprecision or self.num_bits_grad is None:
230 |             output = F.linear(qinput, qweight, qbias)
231 |             if self.num_bits_grad is not None:
232 |                 output = quantize_grad(output, num_bits=self.num_bits_grad)
233 |         else:
234 |             output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad)
235 |         return output
236 | 
237 | 
238 | class RangeBN(nn.Module):
239 |     # Range Batch-Norm: normalizes with a scaled (max - min) range estimate instead of the standard deviation
240 | 
241 |     def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
242 |         super(RangeBN, self).__init__()
243 | 
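        # for roughly Gaussian activations the expected maximum of n samples
        # grows like sqrt(2 * ln(n)), so forward() rescales the measured
        # (max - min) range by such a factor to approximate a standard deviation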
self.register_buffer('running_mean', torch.zeros(num_features))
244 |         self.register_buffer('running_var', torch.zeros(num_features))  # despite the name, stores the running scale (reciprocal range)
245 | 
246 |         self.momentum = momentum
247 |         self.dim = dim
248 |         # keep weight/bias as None when affine=False so forward() can skip them
249 |         self.weight = nn.Parameter(torch.Tensor(num_features)) if affine else None
250 |         self.bias = nn.Parameter(torch.Tensor(num_features)) if affine else None
251 |         self.num_bits = num_bits
252 |         self.num_bits_grad = num_bits_grad
253 |         self.quantize_input = QuantMeasure(self.num_bits)
254 |         self.eps = eps
255 |         self.num_chunks = num_chunks
256 |         self.reset_params()
257 | 
258 |     def reset_params(self):
259 |         if self.weight is not None:
260 |             self.weight.data.uniform_()
261 |         if self.bias is not None:
262 |             self.bias.data.zero_()
263 | 
264 |     def forward(self, x):
265 |         x = self.quantize_input(x)
266 |         if x.dim() == 2:  # 1d
267 |             x = x.unsqueeze(-1).unsqueeze(-1)
268 | 
269 |         if self.training:
270 |             B, C, H, W = x.shape
271 |             y = x.transpose(0, 1).contiguous()  # C x B x H x W
272 |             y = y.view(C, self.num_chunks, B * H * W // self.num_chunks)
273 |             mean_max = y.max(-1)[0].mean(-1)  # C
274 |             mean_min = y.min(-1)[0].mean(-1)  # C
275 |             mean = y.view(C, -1).mean(-1)  # C
276 |             scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) **
277 |                                         0.5) / ((2 * math.log(y.size(-1))) ** 0.5)
278 | 
279 |             scale = 1 / ((mean_max - mean_min) * scale_fix + self.eps)
280 | 
281 |             self.running_mean.detach().mul_(self.momentum).add_(
282 |                 mean * (1 - self.momentum))
283 | 
284 |             self.running_var.detach().mul_(self.momentum).add_(
285 |                 scale * (1 - self.momentum))
286 |         else:
287 |             mean = self.running_mean
288 |             scale = self.running_var
289 |         scale = quantize(scale, num_bits=self.num_bits, min_value=float(
290 |             scale.min()), max_value=float(scale.max()))
291 |         out = (x - mean.view(1, mean.size(0), 1, 1)) * \
292 |             scale.view(1, scale.size(0), 1, 1)
293 | 
294 |         if self.weight is not None:
295 |             qweight = quantize(self.weight, num_bits=self.num_bits,
296 |                                min_value=float(self.weight.min()),
297 |                                max_value=float(self.weight.max()))
298 |             out = out * qweight.view(1, qweight.size(0), 1, 1)
299 | 
300 |         if self.bias is not None:
301 |             qbias = quantize(self.bias, num_bits=self.num_bits)
302 |             out = out + qbias.view(1, qbias.size(0), 1, 1)
303 |         if self.num_bits_grad is not None:
304 |             out = quantize_grad(out, num_bits=self.num_bits_grad)
305 | 
306 |         if out.size(3) == 1 and out.size(2) == 1:
307 |             out = out.squeeze(-1).squeeze(-1)
308 |         return out
309 | 
--------------------------------------------------------------------------------
/models/modules/rnlu.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | from torch.autograd.function import InplaceFunction
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | 
8 | 
9 | class BiReLUFunction(InplaceFunction):
10 | 
11 |     @classmethod
12 |     def forward(cls, ctx, input, inplace=False):
13 |         if input.size(1) % 2 != 0:
14 |             raise RuntimeError("dimension 1 of input must be multiple of 2, "
15 |                                "but got {}".format(input.size(1)))
16 |         ctx.inplace = inplace
17 | 
18 |         if ctx.inplace:
19 |             ctx.mark_dirty(input)
20 |             output = input
21 |         else:
22 |             output = input.clone()
23 | 
24 |         pos, neg = output.chunk(2, dim=1)
25 |         pos.clamp_(min=0)
26 |         neg.clamp_(max=0)
27 |         # the first half of the channels is rectified from below, the second half from above
28 | 
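        # save the clamped output: backward() zeroes the gradient exactly where
        # the activation was clamped (output == 0), a ReLU-style derivative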
29 |         ctx.save_for_backward(output)
30 |         return output
31 | 
32 |     @staticmethod
33 |     def backward(ctx, grad_output):
34 |         output, = ctx.saved_tensors
35 |         grad_input = grad_output.masked_fill(output.eq(0), 0)
36 |         return grad_input, None
37 | 
38 | 
39 | def birelu(x, inplace=False):
40 |     return BiReLUFunction().apply(x, inplace)
41 | 
42 | 
43 | class BiReLU(nn.Module):
44 |     """Applies ReLU to the first half of the channels and -ReLU(-x) to the second."""
45 | 
46 |     def __init__(self, inplace=False):
47 |         super(BiReLU, self).__init__()
48 |         self.inplace = inplace
49 | 
50 |     def forward(self, inputs):
51 |         return birelu(inputs, inplace=self.inplace)
52 | 
53 | 
54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5):
55 |     pos, neg = (x + shift).chunk(2, dim=1)
56 |     scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix
57 |     return x / scale
58 | 
59 | 
60 | def _mean(p, dim):
61 |     """Computes the mean over all dimensions except dim"""
62 |     if dim is None:
63 |         return p.mean()
64 |     elif dim == 0:
65 |         output_size = (p.size(0),) + (1,) * (p.dim() - 1)
66 |         return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
67 |     elif dim == p.dim() - 1:
68 |         output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
69 |         return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
70 |     else:
71 |         return _mean(p.transpose(0, dim), 0).transpose(0, dim)
72 | 
73 | 
74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5):
75 |     x = birelu(x, inplace=inplace)
76 |     pos, neg = (x + shift).chunk(2, dim=1)
77 |     # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5
78 |     scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8
79 |     return x / scale.view(scale.size(0), *([1] * (x.dim() - 1)))
80 | 
81 | 
82 | class RnLU(nn.Module):
83 |     """BiReLU activation followed by range-based normalization."""
84 | 
85 |     def __init__(self, inplace=False):
86 |         super(RnLU, self).__init__()
87 |         self.inplace = inplace
88 | 
89 |     def forward(self, x):
90 |         return rnlu(x, inplace=self.inplace)
91 | 
92 | 
93 | if __name__ == "__main__":
94 |     x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True)  # smoke test (expects a CUDA device)
95 |     output = rnlu(x)
96 | 
97 |     output.sum().backward()
98 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision.transforms as transforms
3 | import math
4 | 
5 | __all__ = ['resnet']
6 | 
7 | 
8 | def conv3x3(in_planes, out_planes, stride=1):
9 |     "3x3 convolution with padding"
10 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 |                      padding=1, bias=False)
12 | 
13 | 
14 | def init_model(model):
15 |     for m in model.modules():
16 |         if isinstance(m, nn.Conv2d):
17 |             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
18 |             m.weight.data.normal_(0, math.sqrt(2.
/ n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self): 98 | super(ResNet, self).__init__() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes * block.expansion), 107 | ) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.conv1(x) 119 | x = self.bn1(x) 120 | x = self.relu(x) 121 | x = self.maxpool(x) 122 | 123 | x = self.layer1(x) 124 | x = self.layer2(x) 125 | x = self.layer3(x) 126 | x = self.layer4(x) 127 | 128 | x = self.avgpool(x) 129 | x = x.view(x.size(0), -1) 130 | x = self.fc(x) 131 | 132 | return x 133 | 134 | 135 | class ResNet_imagenet(ResNet): 136 | 137 | def __init__(self, num_classes=1000, 138 | block=Bottleneck, layers=[3, 4, 23, 3]): 139 | super(ResNet_imagenet, self).__init__() 140 | self.inplanes = 64 141 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 142 | bias=False) 143 | self.bn1 = nn.BatchNorm2d(64) 144 | self.relu = nn.ReLU(inplace=True) 145 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 146 | self.layer1 = self._make_layer(block, 64, layers[0]) 147 | self.layer2 = 
self._make_layer(block, 128, layers[1], stride=2) 148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 150 | self.avgpool = nn.AvgPool2d(7) 151 | self.fc = nn.Linear(512 * block.expansion, num_classes) 152 | 153 | init_model(self) 154 | self.regime = [ 155 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 156 | 'weight_decay': 1e-4, 'momentum': 0.9}, 157 | {'epoch': 30, 'lr': 1e-2}, 158 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 159 | {'epoch': 90, 'lr': 1e-4} 160 | ] 161 | 162 | 163 | class ResNet_cifar10(ResNet): 164 | 165 | def __init__(self, num_classes=10, 166 | block=BasicBlock, depth=18): 167 | super(ResNet_cifar10, self).__init__() 168 | self.inplanes = 16 169 | n = int((depth - 2) / 6) 170 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(16) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.maxpool = lambda x: x 175 | self.layer1 = self._make_layer(block, 16, n) 176 | self.layer2 = self._make_layer(block, 32, n, stride=2) 177 | self.layer3 = self._make_layer(block, 64, n, stride=2) 178 | self.layer4 = lambda x: x 179 | self.avgpool = nn.AvgPool2d(8) 180 | self.fc = nn.Linear(64, num_classes) 181 | 182 | init_model(self) 183 | self.regime = [ 184 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 185 | 'weight_decay': 1e-4, 'momentum': 0.9}, 186 | {'epoch': 81, 'lr': 1e-2}, 187 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 188 | {'epoch': 164, 'lr': 1e-4} 189 | ] 190 | 191 | 192 | def resnet(**kwargs): 193 | num_classes, depth, dataset = map( 194 | kwargs.get, ['num_classes', 'depth', 'dataset']) 195 | if dataset == 'imagenet': 196 | num_classes = num_classes or 1000 197 | depth = depth or 50 198 | if depth == 18: 199 | return ResNet_imagenet(num_classes=num_classes, 200 | block=BasicBlock, layers=[2, 2, 2, 2]) 201 | if depth == 34: 202 | return ResNet_imagenet(num_classes=num_classes, 203 | block=BasicBlock, layers=[3, 4, 6, 3]) 204 | if depth == 50: 205 | return ResNet_imagenet(num_classes=num_classes, 206 | block=Bottleneck, layers=[3, 4, 6, 3]) 207 | if depth == 101: 208 | return ResNet_imagenet(num_classes=num_classes, 209 | block=Bottleneck, layers=[3, 4, 23, 3]) 210 | if depth == 152: 211 | return ResNet_imagenet(num_classes=num_classes, 212 | block=Bottleneck, layers=[3, 8, 36, 3]) 213 | 214 | elif dataset == 'cifar10': 215 | num_classes = num_classes or 10 216 | depth = depth or 56 217 | return ResNet_cifar10(num_classes=num_classes, 218 | block=BasicBlock, depth=depth) 219 | -------------------------------------------------------------------------------- /models/resnet_bwn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | from .modules.rnlu import BiReLU 5 | from .modules.bwn import weight_norm as wn 6 | import math 7 | 8 | __all__ = ['resnet_bwn'] 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | "3x3 convolution with padding" 13 | return wn(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=True)) 15 | 16 | 17 | def init_model(model): 18 | for m in model.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | nn.init.constant_(m.bias, 0) 21 | nn.init.kaiming_normal_( 22 | m.weight, mode='fan_out', nonlinearity='relu') 23 | 24 | # if model.fc.weight.size(0) == 1000: 25 | nn.init.normal_(model.fc.weight, std=0.01) 26 | 27 | 28 | class 
BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.relu = BiReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = wn(nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)) 62 | self.conv2 = wn(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=True)) 64 | self.conv3 = wn(nn.Conv2d(planes, planes * 4, kernel_size=1, bias=True)) 65 | self.relu = BiReLU(inplace=True) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | 91 | def __init__(self): 92 | super(ResNet, self).__init__() 93 | 94 | def _make_layer(self, block, planes, blocks, stride=1): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | wn(nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=True)), 100 | ) 101 | 102 | layers = [] 103 | layers.append(block(self.inplanes, planes, stride, downsample)) 104 | self.inplanes = planes * block.expansion 105 | for i in range(1, blocks): 106 | layers.append(block(self.inplanes, planes)) 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.conv1(x) 112 | x = self.relu(x) 113 | x = self.maxpool(x) 114 | 115 | x = self.layer1(x) 116 | x = self.layer2(x) 117 | x = self.layer3(x) 118 | x = self.layer4(x) 119 | 120 | x = self.avgpool(x) 121 | x = x.view(x.size(0), -1) 122 | x = self.fc(x) 123 | 124 | return x 125 | 126 | 127 | class ResNet_imagenet(ResNet): 128 | 129 | def __init__(self, num_classes=1000, 130 | block=Bottleneck, layers=[3, 4, 23, 3]): 131 | super(ResNet_imagenet, self).__init__() 132 | self.inplanes = 64 133 | self.conv1 = wn(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 134 | bias=True)) 135 | self.relu = BiReLU(inplace=True) 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0]) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 139 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 140 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 141 | self.avgpool = nn.AvgPool2d(7) 142 | self.fc = nn.Linear(512 * block.expansion, num_classes) 143 | 144 | init_model(self) 145 | self.regime = [ 146 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 147 | 'weight_decay': 1e-4, 'momentum': 0.9}, 148 | {'epoch': 30, 'lr': 1e-2}, 149 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 150 | {'epoch': 90, 
'lr': 1e-4} 151 | ] 152 | 153 | 154 | class ResNet_cifar10(ResNet): 155 | 156 | def __init__(self, num_classes=10, 157 | block=BasicBlock, depth=18): 158 | super(ResNet_cifar10, self).__init__() 159 | self.inplanes = 16 160 | n = int((depth - 2) / 6) 161 | self.conv1 = wn(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 162 | bias=True)) 163 | self.relu = BiReLU(inplace=True) 164 | self.maxpool = lambda x: x 165 | self.layer1 = self._make_layer(block, 16, n) 166 | self.layer2 = self._make_layer(block, 32, n, stride=2) 167 | self.layer3 = self._make_layer(block, 64, n, stride=2) 168 | self.layer4 = lambda x: x 169 | self.avgpool = nn.AvgPool2d(8) 170 | self.fc = nn.Linear(64, num_classes) 171 | 172 | init_model(self) 173 | self.regime = [ 174 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 175 | 'weight_decay': 1e-4, 'momentum': 0.9}, 176 | {'epoch': 81, 'lr': 1e-2}, 177 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 178 | {'epoch': 164, 'lr': 1e-4} 179 | ] 180 | 181 | 182 | def resnet_bwn(**kwargs): 183 | num_classes, depth, dataset = map( 184 | kwargs.get, ['num_classes', 'depth', 'dataset']) 185 | if dataset == 'imagenet': 186 | num_classes = num_classes or 1000 187 | depth = depth or 50 188 | if depth == 18: 189 | return ResNet_imagenet(num_classes=num_classes, 190 | block=BasicBlock, layers=[2, 2, 2, 2]) 191 | if depth == 34: 192 | return ResNet_imagenet(num_classes=num_classes, 193 | block=BasicBlock, layers=[3, 4, 6, 3]) 194 | if depth == 50: 195 | return ResNet_imagenet(num_classes=num_classes, 196 | block=Bottleneck, layers=[3, 4, 6, 3]) 197 | if depth == 101: 198 | return ResNet_imagenet(num_classes=num_classes, 199 | block=Bottleneck, layers=[3, 4, 23, 3]) 200 | if depth == 152: 201 | return ResNet_imagenet(num_classes=num_classes, 202 | block=Bottleneck, layers=[3, 8, 36, 3]) 203 | 204 | elif dataset == 'cifar10': 205 | num_classes = num_classes or 10 206 | depth = depth or 56 207 | return ResNet_cifar10(num_classes=num_classes, 208 | block=BasicBlock, depth=depth) 209 | -------------------------------------------------------------------------------- /models/resnet_quantized.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | from .modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 5 | __all__ = ['resnet_quantized'] 6 | 7 | NUM_BITS = 8 8 | NUM_BITS_WEIGHT = 8 9 | NUM_BITS_GRAD = 8 10 | BIPRECISION = True 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 17 | 18 | 19 | def init_model(model): 20 | for m in model.modules(): 21 | if isinstance(m, QConv2d): 22 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 23 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 24 | elif isinstance(m, RangeBN): 25 | m.weight.data.fill_(1) 26 | m.bias.data.zero_() 27 | for m in model.modules(): 28 | if isinstance(m, Bottleneck): 29 | nn.init.constant_(m.bn3.weight, 0) 30 | elif isinstance(m, BasicBlock): 31 | nn.init.constant_(m.bn2.weight, 0) 32 | 33 | model.fc.weight.data.normal_(0, 0.01) 34 | model.fc.bias.data.zero_() 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None): 41 | super(BasicBlock, self).__init__() 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = RangeBN(planes, num_bits=NUM_BITS, 44 | num_bits_grad=NUM_BITS_GRAD) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = RangeBN(planes, num_bits=NUM_BITS, 48 | num_bits_grad=NUM_BITS_GRAD) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = QConv2d(inplanes, planes, kernel_size=1, bias=False, 77 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 78 | self.bn1 = RangeBN(planes, num_bits=NUM_BITS, 79 | num_bits_grad=NUM_BITS_GRAD) 80 | self.conv2 = QConv2d(planes, planes, kernel_size=3, stride=stride, 81 | padding=1, bias=False, num_bits=NUM_BITS, 82 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 83 | self.bn2 = RangeBN(planes, num_bits=NUM_BITS, 84 | num_bits_grad=NUM_BITS_GRAD) 85 | self.conv3 = QConv2d(planes, planes * 4, kernel_size=1, bias=False, 86 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 87 | self.bn3 = RangeBN(planes * 4, num_bits=NUM_BITS, 88 | num_bits_grad=NUM_BITS_GRAD) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | def forward(self, x): 94 | residual = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | residual = self.downsample(x) 109 | 110 | out += residual 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self): 119 | super(ResNet, self).__init__() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | QConv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False, 127 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION), 128 | RangeBN(planes * block.expansion, num_bits=NUM_BITS, 129 | num_bits_grad=NUM_BITS_GRAD) 130 | ) 131 | 132 | layers = [] 133 | layers.append(block(self.inplanes, planes, stride, downsample)) 134 | self.inplanes = planes * block.expansion 135 | 
for i in range(1, blocks): 136 | layers.append(block(self.inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | x = self.conv1(x) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | x = self.layer1(x) 147 | x = self.layer2(x) 148 | x = self.layer3(x) 149 | x = self.layer4(x) 150 | 151 | x = self.avgpool(x) 152 | x = x.view(x.size(0), -1) 153 | x = self.fc(x) 154 | 155 | return x 156 | 157 | @staticmethod 158 | def regularization(model, weight_decay=1e-4): 159 | l2_params = 0 160 | for m in model.modules(): 161 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 162 | l2_params += m.weight.pow(2).sum() 163 | if m.bias is not None: 164 | l2_params += m.bias.pow(2).sum() 165 | return weight_decay * 0.5 * l2_params 166 | 167 | 168 | class ResNet_imagenet(ResNet): 169 | 170 | def __init__(self, num_classes=1000, 171 | block=Bottleneck, layers=[3, 4, 23, 3]): 172 | super(ResNet_imagenet, self).__init__() 173 | self.inplanes = 64 174 | self.conv1 = QConv2d(3, 64, kernel_size=7, stride=2, padding=3, 175 | bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 176 | self.bn1 = RangeBN(64, num_bits=NUM_BITS, num_bits_grad=NUM_BITS_GRAD) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.layer1 = self._make_layer(block, 64, layers[0]) 180 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 182 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 183 | self.avgpool = nn.AvgPool2d(7) 184 | self.fc = QLinear(512 * block.expansion, num_classes, num_bits=NUM_BITS, 185 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 186 | 187 | init_model(self) 188 | batch_size = 256. 189 | 190 | scale = batch_size / 256. 
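        # linear learning-rate scaling: the base LR and warmup length scale with
        # the batch size relative to 256; 5004 is roughly the number of ImageNet
        # iterations per epoch at batch size 256, so the LR ramps up over the
        # first 5 epochs. ramp_up_lr returns the step rule as a string, which
        # the training loop is expected to evaluate via the 'step_lambda' entry.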
191 | 192 | def ramp_up_lr(lr0, lrT, T): 193 | rate = (lrT - lr0) / T 194 | return "lambda t: {'lr': %s + t * %s}" % (lr0, rate) 195 | self.regime = [ 196 | {'epoch': 0, 'optimizer': 'SGD', 'momentum': 0.9, 197 | 'step_lambda': ramp_up_lr(0, 0.1 * scale, 5004 * 5 / scale)}, 198 | {'epoch': 5, 'lr': scale * 1e-1}, 199 | {'epoch': 30, 'lr': scale * 1e-2}, 200 | {'epoch': 60, 'lr': scale * 1e-3}, 201 | {'epoch': 80, 'lr': scale * 1e-4} 202 | ] 203 | 204 | 205 | class ResNet_cifar10(ResNet): 206 | 207 | def __init__(self, num_classes=10, 208 | block=BasicBlock, depth=18): 209 | super(ResNet_cifar10, self).__init__() 210 | self.inplanes = 16 211 | n = int((depth - 2) / 6) 212 | self.conv1 = QConv2d(3, 16, kernel_size=3, stride=1, padding=1, 213 | bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 214 | self.bn1 = RangeBN(16, num_bits=NUM_BITS, num_bits_grad=NUM_BITS_GRAD) 215 | self.relu = nn.ReLU(inplace=True) 216 | self.maxpool = lambda x: x 217 | self.layer1 = self._make_layer(block, 16, n) 218 | self.layer2 = self._make_layer(block, 32, n, stride=2) 219 | self.layer3 = self._make_layer(block, 64, n, stride=2) 220 | self.layer4 = lambda x: x 221 | self.avgpool = nn.AvgPool2d(8) 222 | self.fc = QLinear(64, num_classes, num_bits=NUM_BITS, 223 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD, biprecision=BIPRECISION) 224 | 225 | init_model(self) 226 | self.regime = [ 227 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 228 | 'weight_decay': 1e-4, 'momentum': 0.9}, 229 | {'epoch': 81, 'lr': 1e-2}, 230 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 231 | {'epoch': 164, 'lr': 1e-4} 232 | ] 233 | 234 | 235 | def resnet_quantized(**kwargs): 236 | num_classes, depth, dataset = map( 237 | kwargs.get, ['num_classes', 'depth', 'dataset']) 238 | if dataset == 'imagenet': 239 | num_classes = num_classes or 1000 240 | depth = depth or 50 241 | if depth == 18: 242 | return ResNet_imagenet(num_classes=num_classes, 243 | block=BasicBlock, layers=[2, 2, 2, 2]) 244 | if depth == 34: 245 | return ResNet_imagenet(num_classes=num_classes, 246 | block=BasicBlock, layers=[3, 4, 6, 3]) 247 | if depth == 50: 248 | return ResNet_imagenet(num_classes=num_classes, 249 | block=Bottleneck, layers=[3, 4, 6, 3]) 250 | if depth == 101: 251 | return ResNet_imagenet(num_classes=num_classes, 252 | block=Bottleneck, layers=[3, 4, 23, 3]) 253 | if depth == 152: 254 | return ResNet_imagenet(num_classes=num_classes, 255 | block=Bottleneck, layers=[3, 8, 36, 3]) 256 | 257 | elif dataset == 'cifar10': 258 | num_classes = num_classes or 10 259 | depth = depth or 56 260 | return ResNet_cifar10(num_classes=num_classes, 261 | block=BasicBlock, depth=depth) 262 | -------------------------------------------------------------------------------- /models/resnet_quantized_float_bn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | from .modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 5 | __all__ = ['resnet_quantized_float_bn'] 6 | 7 | NUM_BITS = 8 8 | NUM_BITS_WEIGHT = 8 9 | NUM_BITS_GRAD = 8 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 16 | 17 | 18 | def init_model(model): 19 | for 
m in model.modules(): 20 | if isinstance(m, QConv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. / n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = QConv2d(inplanes, planes, kernel_size=1, bias=False, 66 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = QConv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False, num_bits=NUM_BITS, 70 | num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = QConv2d(planes, planes * 4, kernel_size=1, bias=False, 73 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 74 | self.bn3 = nn.BatchNorm2d(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self): 105 | super(ResNet, self).__init__() 106 | 107 | def _make_layer(self, block, planes, blocks, stride=1): 108 | downsample = None 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = nn.Sequential( 111 | QConv2d(self.inplanes, planes * block.expansion, 112 | kernel_size=1, stride=stride, bias=False, 113 | num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD), 114 | nn.BatchNorm2d(planes * block.expansion), 115 | ) 116 | 117 | layers = [] 118 | layers.append(block(self.inplanes, planes, stride, downsample)) 119 | self.inplanes = planes * block.expansion 120 | for i in range(1, blocks): 121 | layers.append(block(self.inplanes, planes)) 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | x = self.bn1(x) 128 | x = self.relu(x) 129 | x = self.maxpool(x) 130 | 131 | x = self.layer1(x) 132 | x = self.layer2(x) 133 | x = self.layer3(x) 134 | x = self.layer4(x) 135 | 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | x = self.fc(x) 139 | 140 | return x 141 | 142 | 143 | class 
ResNet_imagenet(ResNet): 144 | 145 | def __init__(self, num_classes=1000, 146 | block=Bottleneck, layers=[3, 4, 23, 3]): 147 | super(ResNet_imagenet, self).__init__() 148 | self.inplanes = 64 149 | self.conv1 = QConv2d(3, 64, kernel_size=7, stride=2, padding=3, 150 | bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 151 | self.bn1 = nn.BatchNorm2d(64) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 154 | self.layer1 = self._make_layer(block, 64, layers[0]) 155 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 158 | self.avgpool = nn.AvgPool2d(7) 159 | self.fc = QLinear(512 * block.expansion, num_classes, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 160 | 161 | init_model(self) 162 | self.regime = [ 163 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 164 | 'weight_decay': 1e-4, 'momentum': 0.9}, 165 | {'epoch': 30, 'lr': 1e-2}, 166 | {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0}, 167 | {'epoch': 90, 'lr': 1e-4} 168 | ] 169 | 170 | 171 | class ResNet_cifar10(ResNet): 172 | 173 | def __init__(self, num_classes=10, 174 | block=BasicBlock, depth=18): 175 | super(ResNet_cifar10, self).__init__() 176 | self.inplanes = 16 177 | n = int((depth - 2) / 6) 178 | self.conv1 = QConv2d(3, 16, kernel_size=3, stride=1, padding=1, 179 | bias=False, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 180 | self.bn1 = nn.BatchNorm2d(16) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.maxpool = lambda x: x 183 | self.layer1 = self._make_layer(block, 16, n) 184 | self.layer2 = self._make_layer(block, 32, n, stride=2) 185 | self.layer3 = self._make_layer(block, 64, n, stride=2) 186 | self.layer4 = lambda x: x 187 | self.avgpool = nn.AvgPool2d(8) 188 | self.fc = QLinear(64, num_classes, num_bits=NUM_BITS, num_bits_weight=NUM_BITS_WEIGHT, num_bits_grad=NUM_BITS_GRAD) 189 | 190 | init_model(self) 191 | self.regime = [ 192 | {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1, 193 | 'weight_decay': 1e-4, 'momentum': 0.9}, 194 | {'epoch': 81, 'lr': 1e-2}, 195 | {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0}, 196 | {'epoch': 164, 'lr': 1e-4} 197 | ] 198 | 199 | 200 | def resnet_quantized_float_bn(**kwargs): 201 | num_classes, depth, dataset = map( 202 | kwargs.get, ['num_classes', 'depth', 'dataset']) 203 | if dataset == 'imagenet': 204 | num_classes = num_classes or 1000 205 | depth = depth or 50 206 | if depth == 18: 207 | return ResNet_imagenet(num_classes=num_classes, 208 | block=BasicBlock, layers=[2, 2, 2, 2]) 209 | if depth == 34: 210 | return ResNet_imagenet(num_classes=num_classes, 211 | block=BasicBlock, layers=[3, 4, 6, 3]) 212 | if depth == 50: 213 | return ResNet_imagenet(num_classes=num_classes, 214 | block=Bottleneck, layers=[3, 4, 6, 3]) 215 | if depth == 101: 216 | return ResNet_imagenet(num_classes=num_classes, 217 | block=Bottleneck, layers=[3, 4, 23, 3]) 218 | if depth == 152: 219 | return ResNet_imagenet(num_classes=num_classes, 220 | block=Bottleneck, layers=[3, 8, 36, 3]) 221 | 222 | elif dataset == 'cifar10': 223 | num_classes = num_classes or 10 224 | depth = depth or 56 225 | return ResNet_cifar10(num_classes=num_classes, 226 | block=BasicBlock, depth=depth) 227 | -------------------------------------------------------------------------------- /models/resnext.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from torch.autograd import Variable 5 | 6 | __all__ = ['resnext'] 7 | 8 | def depBatchNorm2d(exists, *kargs, **kwargs): 9 | if exists: 10 | return nn.BatchNorm2d(*kargs, **kwargs) 11 | else: 12 | return lambda x: x 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, bias=False): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=bias) 19 | 20 | 21 | def init_model(model): 22 | for m in model.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 25 | m.weight.data.normal_(0, math.sqrt(2. / n)) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | batch_norm=True): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride, bias=not batch_norm) 38 | self.bn1 = depBatchNorm2d(batch_norm, planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes, bias=not batch_norm) 41 | self.bn2 = depBatchNorm2d(batch_norm, planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 2 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None, batch_norm=True): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d( 70 | inplanes, planes, kernel_size=1, bias=not batch_norm) 71 | self.bn1 = depBatchNorm2d(batch_norm, planes) 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 73 | padding=1, bias=not batch_norm, groups=32) 74 | self.bn2 = depBatchNorm2d(batch_norm, planes) 75 | self.conv3 = nn.Conv2d( 76 | planes, planes * 2, kernel_size=1, bias=not batch_norm) 77 | self.bn3 = depBatchNorm2d(batch_norm, planes * 2) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class PlainDownSample(nn.Module): 106 | 107 | def __init__(self, input_dims, output_dims, stride): 108 | super(PlainDownSample, self).__init__() 109 | self.input_dims = input_dims 110 | self.output_dims = output_dims 111 | self.stride = stride 112 | self.downsample = nn.AvgPool2d(stride) 113 | self.zero = Variable(torch.Tensor(1,1,1,1).cuda(), requires_grad=False) 114 | 115 | def forward(self, inputs): 116 | ds = self.downsample(inputs) 117 | zeros_size = [ds.size(0), self.output_dims - 118 | ds.size(1), ds.size(2), ds.size(3)] 119 | return torch.cat([ds, self.zero.expand(*zeros_size)], 1) 120 | 121 | 
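# PlainDownSample is the parameter-free shortcut (option 'A' in the ResNet
# paper): spatial average pooling followed by zero-padding of the channel
# dimension. Note the zero tensor is allocated with .cuda(), so this path
# assumes GPU execution.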
122 | class ResNeXt(nn.Module):
123 | 
124 |     def __init__(self, shortcut='B'):
125 |         super(ResNeXt, self).__init__()
126 |         self.shortcut = shortcut
127 | 
128 |     def _make_layer(self, block, planes, blocks, stride=1,
129 |                     batch_norm=True):
130 |         downsample = None
131 |         if self.shortcut == 'C' or \
132 |                 (self.shortcut == 'B' and
133 |                  (stride != 1 or self.inplanes != planes * block.expansion)):
134 |             downsample = [nn.Conv2d(self.inplanes, planes * block.expansion,
135 |                                     kernel_size=1, stride=stride, bias=not batch_norm)]
136 |             if batch_norm:
137 |                 downsample.append(nn.BatchNorm2d(planes * block.expansion))
138 |             downsample = nn.Sequential(*downsample)
139 |         else:
140 |             downsample = PlainDownSample(
141 |                 self.inplanes, planes * block.expansion, stride)
142 | 
143 |         layers = []
144 |         layers.append(block(self.inplanes, planes,
145 |                             stride, downsample, batch_norm))
146 |         self.inplanes = planes * block.expansion
147 |         for i in range(1, blocks):
148 |             layers.append(block(self.inplanes, planes, batch_norm=batch_norm))
149 | 
150 |         return nn.Sequential(*layers)
151 | 
152 |     def forward(self, x):
153 |         x = self.conv1(x)
154 |         x = self.bn1(x)
155 |         x = self.relu(x)
156 |         x = self.maxpool(x)
157 | 
158 |         x = self.layer1(x)
159 |         x = self.layer2(x)
160 |         x = self.layer3(x)
161 |         x = self.layer4(x)
162 | 
163 |         x = self.avgpool(x)
164 |         x = x.view(x.size(0), -1)
165 |         x = self.fc(x)
166 | 
167 |         return x
168 | 
169 | 
170 | class ResNeXt_imagenet(ResNeXt):
171 | 
172 |     def __init__(self, num_classes=1000,
173 |                  block=Bottleneck, layers=[3, 4, 23, 3], batch_norm=True, shortcut='B'):
174 |         super(ResNeXt_imagenet, self).__init__(shortcut=shortcut)
175 |         self.inplanes = 64
176 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
177 |                                bias=not batch_norm)
178 |         self.bn1 = depBatchNorm2d(batch_norm, 64)
179 |         self.relu = nn.ReLU(inplace=True)
180 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
181 |         self.layer1 = self._make_layer(block, 128, layers[0],
182 |                                        batch_norm=batch_norm)
183 |         self.layer2 = self._make_layer(block, 256, layers[1], stride=2,
184 |                                        batch_norm=batch_norm)
185 |         self.layer3 = self._make_layer(block, 512, layers[2], stride=2,
186 |                                        batch_norm=batch_norm)
187 |         self.layer4 = self._make_layer(block, 1024, layers[3], stride=2,
188 |                                        batch_norm=batch_norm)
189 |         self.avgpool = nn.AvgPool2d(7)
190 |         self.fc = nn.Linear(1024 * block.expansion, num_classes)
191 | 
192 |         init_model(self)
193 |         self.regime = [
194 |             {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1,
195 |              'weight_decay': 1e-4, 'momentum': 0.9},
196 |             {'epoch': 30, 'lr': 1e-2},
197 |             {'epoch': 60, 'lr': 1e-3, 'weight_decay': 0},
198 |             {'epoch': 90, 'lr': 1e-4}
199 |         ]
200 | 
201 | 
202 | class ResNeXt_cifar10(ResNeXt):
203 | 
204 |     def __init__(self, num_classes=10,
205 |                  block=BasicBlock, depth=18, batch_norm=True):
206 |         super(ResNeXt_cifar10, self).__init__()
207 |         self.inplanes = 16
208 |         n = int((depth - 2) / 6)
209 |         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
210 |                                bias=not batch_norm)
211 |         self.bn1 = depBatchNorm2d(batch_norm, 16)
212 |         self.relu = nn.ReLU(inplace=True)
213 |         self.maxpool = lambda x: x
214 |         self.layer1 = self._make_layer(block, 16, n,
215 |                                        batch_norm=batch_norm)
216 |         self.layer2 = self._make_layer(block, 32, n, stride=2,
217 |                                        batch_norm=batch_norm)
218 |         self.layer3 = self._make_layer(block, 64, n, stride=2,
219 |                                        batch_norm=batch_norm)
220 |         self.layer4 = lambda x: x
221 |         self.avgpool = nn.AvgPool2d(8)
222 |         self.fc = nn.Linear(64, num_classes)
223 | 
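        # maxpool and layer4 above are identity passthroughs: the CIFAR variant
        # keeps the 3-stage (16/32/64-channel) layout designed for 32x32 inputs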
224 |         init_model(self)
225 |         self.regime = [
226 |             {'epoch': 0, 'optimizer': 'SGD', 'lr': 1e-1,
227 |              'weight_decay': 1e-4, 'momentum': 0.9},
228 |             {'epoch': 81, 'lr': 1e-2},
229 |             {'epoch': 122, 'lr': 1e-3, 'weight_decay': 0},
230 |             {'epoch': 164, 'lr': 1e-4}
231 |         ]
232 | 
233 | 
234 | def resnext(**kwargs):
235 |     num_classes, depth, dataset, batch_norm, shortcut = map(
236 |         kwargs.get, ['num_classes', 'depth', 'dataset', 'batch_norm', 'shortcut'])
237 |     dataset = dataset or 'imagenet'
238 |     shortcut = shortcut or 'B'
239 |     if batch_norm is None:
240 |         batch_norm = True
241 |     if dataset == 'imagenet':
242 |         num_classes = num_classes or 1000
243 |         depth = depth or 50
244 |         if depth == 18:
245 |             return ResNeXt_imagenet(num_classes=num_classes,
246 |                                     block=BasicBlock, layers=[2, 2, 2, 2],
247 |                                     batch_norm=batch_norm, shortcut=shortcut)
248 |         if depth == 34:
249 |             return ResNeXt_imagenet(num_classes=num_classes,
250 |                                     block=BasicBlock, layers=[3, 4, 6, 3],
251 |                                     batch_norm=batch_norm, shortcut=shortcut)
252 |         if depth == 50:
253 |             return ResNeXt_imagenet(num_classes=num_classes,
254 |                                     block=Bottleneck, layers=[3, 4, 6, 3],
255 |                                     batch_norm=batch_norm, shortcut=shortcut)
256 |         if depth == 101:
257 |             return ResNeXt_imagenet(num_classes=num_classes,
258 |                                     block=Bottleneck, layers=[3, 4, 23, 3],
259 |                                     batch_norm=batch_norm, shortcut=shortcut)
260 |         if depth == 152:
261 |             return ResNeXt_imagenet(num_classes=num_classes,
262 |                                     block=Bottleneck, layers=[3, 8, 36, 3],
263 |                                     batch_norm=batch_norm, shortcut=shortcut)
264 | 
265 |     elif dataset == 'cifar10':
266 |         num_classes = num_classes or 10
267 |         depth = depth or 56
268 |         return ResNeXt_cifar10(num_classes=num_classes,
269 |                                block=BasicBlock, depth=depth, batch_norm=batch_norm)
270 | 
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 | 
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 |                     'std': [0.229, 0.224, 0.225]}
7 | 
8 | __imagenet_pca = {
9 |     'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
10 |     'eigvec': torch.Tensor([
11 |         [-0.5675, 0.7192, 0.4009],
12 |         [-0.5808, -0.0045, -0.8140],
13 |         [-0.5836, -0.6948, 0.4203],
14 |     ])
15 | }
16 | 
17 | 
18 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
19 |     t_list = [
20 |         transforms.CenterCrop(input_size),
21 |         transforms.ToTensor(),
22 |         transforms.Normalize(**normalize),
23 |     ]
24 |     if scale_size != input_size:
25 |         t_list = [transforms.Resize(scale_size)] + t_list
26 | 
27 |     return transforms.Compose(t_list)
28 | 
29 | 
30 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
31 |     t_list = [
32 |         transforms.RandomCrop(input_size),
33 |         transforms.ToTensor(),
34 |         transforms.Normalize(**normalize),
35 |     ]
36 |     if scale_size != input_size:
37 |         t_list = [transforms.Resize(scale_size)] + t_list
38 | 
39 |     return transforms.Compose(t_list)
40 | 
41 | 
42 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
43 |     padding = int((scale_size - input_size) / 2)
44 |     return transforms.Compose([
45 |         transforms.RandomCrop(input_size, padding=padding),
46 |         transforms.RandomHorizontalFlip(),
47 |         transforms.ToTensor(),
48 |         transforms.Normalize(**normalize),
49 |     ])
50 | 
51 | 
52 | def inception_preproccess(input_size, normalize=__imagenet_stats):
53 |     return transforms.Compose([
54 |         transforms.RandomResizedCrop(input_size),
55 | 
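        # RandomResizedCrop applies the Inception-style random-scale /
        # random-aspect-ratio crop before resizing to input_size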
transforms.RandomHorizontalFlip(),
56 |         transforms.ToTensor(),
57 |         transforms.Normalize(**normalize)
58 |     ])
59 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
60 |     return transforms.Compose([
61 |         transforms.RandomResizedCrop(input_size),
62 |         transforms.RandomHorizontalFlip(),
63 |         transforms.ToTensor(),
64 |         ColorJitter(
65 |             brightness=0.4,
66 |             contrast=0.4,
67 |             saturation=0.4,
68 |         ),
69 |         Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
70 |         transforms.Normalize(**normalize)
71 |     ])
72 | 
73 | 
74 | def get_transform(name='imagenet', input_size=None,
75 |                   scale_size=None, normalize=None, augment=True):
76 |     normalize = normalize or __imagenet_stats
77 |     if name == 'imagenet':
78 |         scale_size = scale_size or 256
79 |         input_size = input_size or 224
80 |         if augment:
81 |             return inception_preproccess(input_size, normalize=normalize)
82 |         else:
83 |             return scale_crop(input_size=input_size,
84 |                               scale_size=scale_size, normalize=normalize)
85 |     elif 'cifar' in name:
86 |         input_size = input_size or 32
87 |         if augment:
88 |             scale_size = scale_size or 40
89 |             return pad_random_crop(input_size, scale_size=scale_size,
90 |                                    normalize=normalize)
91 |         else:
92 |             scale_size = scale_size or 32
93 |             return scale_crop(input_size=input_size,
94 |                               scale_size=scale_size, normalize=normalize)
95 |     elif name == 'mnist':
96 |         normalize = {'mean': [0.5], 'std': [0.5]}
97 |         input_size = input_size or 28
98 |         if augment:
99 |             scale_size = scale_size or 32
100 |             return pad_random_crop(input_size, scale_size=scale_size,
101 |                                    normalize=normalize)
102 |         else:
103 |             scale_size = scale_size or 32
104 |             return scale_crop(input_size=input_size,
105 |                               scale_size=scale_size, normalize=normalize)
106 | 
107 | 
108 | class Lighting(object):
109 |     """Lighting noise (AlexNet-style, PCA-based noise)"""
110 | 
111 |     def __init__(self, alphastd, eigval, eigvec):
112 |         self.alphastd = alphastd
113 |         self.eigval = eigval
114 |         self.eigvec = eigvec
115 | 
116 |     def __call__(self, img):
117 |         if self.alphastd == 0:
118 |             return img
119 | 
120 |         alpha = img.new().resize_(3).normal_(0, self.alphastd)
121 |         rgb = self.eigvec.type_as(img).clone()\
122 |             .mul(alpha.view(1, 3).expand(3, 3))\
123 |             .mul(self.eigval.view(1, 3).expand(3, 3))\
124 |             .sum(1).squeeze()
125 | 
126 |         return img.add(rgb.view(3, 1, 1).expand_as(img))
127 | 
128 | 
129 | class Grayscale(object):
130 | 
131 |     def __call__(self, img):
132 |         gs = img.clone()
133 |         gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
134 |         gs[1].copy_(gs[0])
135 |         gs[2].copy_(gs[0])
136 |         return gs
137 | 
138 | 
139 | class Saturation(object):
140 | 
141 |     def __init__(self, var):
142 |         self.var = var
143 | 
144 |     def __call__(self, img):
145 |         gs = Grayscale()(img)
146 |         alpha = random.uniform(0, self.var)
147 |         return img.lerp(gs, alpha)
148 | 
149 | 
150 | class Brightness(object):
151 | 
152 |     def __init__(self, var):
153 |         self.var = var
154 | 
155 |     def __call__(self, img):
156 |         gs = img.new().resize_as_(img).zero_()
157 |         alpha = random.uniform(0, self.var)
158 |         return img.lerp(gs, alpha)
159 | 
160 | 
161 | class Contrast(object):
162 | 
163 |     def __init__(self, var):
164 |         self.var = var
165 | 
166 |     def __call__(self, img):
167 |         gs = Grayscale()(img)
168 |         gs.fill_(gs.mean())
169 |         alpha = random.uniform(0, self.var)
170 |         return img.lerp(gs, alpha)
171 | 
172 | 
173 | class RandomOrder(object):
174 |     """ Composes several transforms together in random order.
175 | """ 176 | 177 | def __init__(self, transforms): 178 | self.transforms = transforms 179 | 180 | def __call__(self, img): 181 | if self.transforms is None: 182 | return img 183 | order = torch.randperm(len(self.transforms)) 184 | for i in order: 185 | img = self.transforms[i](img) 186 | return img 187 | 188 | 189 | class ColorJitter(RandomOrder): 190 | 191 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 192 | self.transforms = [] 193 | if brightness != 0: 194 | self.transforms.append(Brightness(brightness)) 195 | if contrast != 0: 196 | self.transforms.append(Contrast(contrast)) 197 | if saturation != 0: 198 | self.transforms.append(Saturation(saturation)) 199 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | --------------------------------------------------------------------------------