├── .gitignore ├── README.md ├── curveball.py └── examples ├── cifar.py ├── mnist.py ├── models ├── __init__.py ├── basic.py ├── densenet.py ├── dpn.py ├── googlenet.py ├── lenet.py ├── mobilenet.py ├── mobilenetv2.py ├── pnasnet.py ├── preact_resnet.py ├── resnet.py ├── resnext.py ├── senet.py ├── shufflenet.py ├── shufflenetv2.py └── vgg.py └── test_fmad.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch CurveBall - A second-order optimizer for deep networks 2 | 3 | This is a PyTorch re-implementation of the CurveBall algorithm, presented in: 4 | 5 | > João F. Henriques, Sebastien Ehrhardt, Samuel Albanie, Andrea Vedaldi, "Small Steps and Giant Leaps: Minimal Newton Solvers for Deep Learning", ICCV 2019 ([arXiv](https://arxiv.org/abs/1805.08095)) 6 | 7 | It follows closely the [original](https://github.com/jotaf98/curveball) Matlab code, although it does not implement all the experiments in that paper. 8 | 9 | ### Warning: 10 | 11 | Unfortunately, the PyTorch operations used for forward-mode automatic differentiation (FMAD) are somewhat slow (refer to [this issue](https://github.com/pytorch/pytorch/issues/22577)). 12 | 13 | For this reason, it is not as fast as the original Matlab implementation or this [TensorFlow](https://github.com/hyenal/curveball-tf) port. 14 | 15 | You can find an experimental version in the `interleave` branch that achieves much higher speed despite this problem (by interleaving the CurveBall steps with SGD). Other suggested fixes are very welcome. 
16 | 17 | 18 | ## Requirements 19 | 20 | Although it may work with older versions, this has mainly been tested with: 21 | 22 | - PyTorch 1.3 23 | - Python 3.7 24 | 25 | Plots are produced with [OverBoard](https://pypi.org/project/overboard/) (optional). 26 | 27 | 28 | ## Usage 29 | 30 | The `curveball.py` file contains the full code of the optimizer and that's all you need for it to work. Everything else is just example code. 31 | 32 | The optimizer does not need any hyper-parameters (the learning rate and momentum are estimated automatically at each step): 33 | 34 | ``` 35 | from curveball import CurveBall 36 | net = ... # your network goes here 37 | optimizer = CurveBall(net.parameters()) 38 | ``` 39 | 40 | Inside the training loop, CurveBall needs to know what loss you're using (or losses, as a sum). Provide them as a closure (function), for example (given some `labels`): 41 | 42 | ``` 43 | loss_fn = lambda pred: F.cross_entropy(pred, labels) 44 | ``` 45 | 46 | Similarly, provide a closure for the forward operation of the model (given a model `net` and input `images`): 47 | 48 | ``` 49 | model_fn = lambda: net(images) 50 | ``` 51 | 52 | Now the optimizer step is just: 53 | 54 | ``` 55 | (loss, predictions) = optimizer.step(model_fn, loss_fn) 56 | ``` 57 | 58 | Note that, unlike gradient-based optimizers like SGD, there's no need to run the model forward, call `backward()`, zero the gradients, or perform any other operations -- the optimizer's step will do all those things (by calling the closures you defined), and update the network's parameters. You can define more complex loss functions or models by using full functions (`def f(): ...`) instead of lambda functions. 59 | 60 | 61 | ## Full example 62 | 63 | See `examples/cifar.py` or `examples/mnist.py`. 64 | 65 | 66 | ## Author 67 | 68 | [João F.
Henriques](http://www.robots.ox.ac.uk/~joao/) 69 | 70 | -------------------------------------------------------------------------------- /curveball.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as t 3 | from torch.optim.optimizer import Optimizer 4 | from torch.autograd import grad 5 | 6 | 7 | class CurveBall(Optimizer): 8 | """CurveBall optimizer""" 9 | 10 | def __init__(self, params, lr=None, momentum=None, auto_lambda=True, lambd=10.0, 11 | lambda_factor=0.999, lambda_low=0.5, lambda_high=1.5, lambda_interval=5): 12 | 13 | defaults = dict(lr=lr, momentum=momentum, auto_lambda=auto_lambda, 14 | lambd=lambd, lambda_factor=lambda_factor, lambda_low=lambda_low, 15 | lambda_high=lambda_high, lambda_interval=lambda_interval) 16 | super().__init__(params, defaults) 17 | 18 | 19 | def step(self, model_fn, loss_fn): 20 | """Performs a single optimization step""" 21 | 22 | # only support one parameter group 23 | if len(self.param_groups) != 1: 24 | raise ValueError('Since the hyper-parameters are set automatically, only one parameter group (with the same hyper-parameters) is supported.') 25 | group = self.param_groups[0] 26 | parameters = group['params'] 27 | 28 | # initialize state to 0 if needed 29 | state = self.state 30 | for p in parameters: 31 | if p not in state: 32 | state[p] = {'z': t.zeros_like(p)} 33 | 34 | # linear list of state tensors z 35 | zs = [state[p]['z'] for p in parameters] 36 | 37 | # store global state (step count, lambda estimate) with first parameter 38 | global_state = state[parameters[0]] 39 | global_state.setdefault('count', 0) 40 | 41 | # get lambda estimate, or initial lambda (user hyper-parameter) if it's not set 42 | lambd = global_state.get('lambd', group['lambd']) 43 | 44 | 45 | # 46 | # compute CurveBall step (delta_zs) 47 | # 48 | 49 | # run forward pass, cutting off gradient propagation between model and loss function for efficiency 50 | predictions = model_fn() 51 | 
predictions_d = predictions.detach().requires_grad_(True) 52 | loss = loss_fn(predictions_d) 53 | 54 | # compute J^T * z using FMAD (where z are the state variables) 55 | (Jz,) = fmad(predictions, parameters, zs) # equivalent but slower 56 | 57 | # compute loss gradient Jl, retaining the graph to allow 2nd-order gradients 58 | (Jl,) = grad(loss, predictions_d, create_graph=True) 59 | Jl_d = Jl.detach() # detached version, without requiring gradients 60 | 61 | # compute loss Hessian (projected by Jz) using 2nd-order gradients 62 | (Hl_Jz,) = grad(Jl, predictions_d, grad_outputs=Jz, retain_graph=True) 63 | 64 | # compute J * (Hl_Jz + Jl) using RMAD (back-propagation). 65 | # note this is still missing the lambda * z term. 66 | delta_zs = grad(predictions, parameters, Hl_Jz + Jl_d, retain_graph=True) 67 | 68 | # add lambda * z term to the result, obtaining the final steps delta_zs 69 | for (z, dz) in zip(zs, delta_zs): 70 | dz.data.add_(lambd, z) 71 | 72 | 73 | # 74 | # automatic hyper-parameters: momentum (rho) and learning rate (beta) 75 | # 76 | 77 | lr = group['lr'] 78 | momentum = group['momentum'] 79 | 80 | if momentum < 0 or lr < 0 or group['auto_lambda']: # required by auto-lambda 81 | # compute J^T * delta_zs 82 | (Jdeltaz,) = fmad(predictions, parameters, delta_zs) # equivalent but slower 83 | 84 | # project result by loss hessian (using 2nd-order gradients) 85 | (Hl_Jdeltaz,) = grad(Jl, predictions_d, grad_outputs=Jdeltaz) 86 | 87 | # solve 2x2 linear system: [rho, -beta]^T = [a11, a12; a12, a22]^-1 [b1, b2]^T. 88 | # accumulate components of dot-product from all parameters, by first aggregating them into a vector. 
89 | z_vec = t.cat([z.flatten() for z in zs]) 90 | dz_vec = t.cat([dz.flatten() for dz in delta_zs]) 91 | 92 | a11 = lambd * (dz_vec * dz_vec).sum() + (Jdeltaz * Hl_Jdeltaz).sum() 93 | a12 = lambd * (dz_vec * z_vec).sum() + (Jz * Hl_Jdeltaz).sum() 94 | a22 = lambd * (z_vec * z_vec).sum() + (Jz * Hl_Jz).sum() 95 | 96 | b1 = (Jl_d * Jdeltaz).sum() 97 | b2 = (Jl_d * Jz).sum() 98 | 99 | # item() implicitly moves to the CPU 100 | A = t.tensor([[a11.item(), a12.item()], [a12.item(), a22.item()]]) 101 | b = t.tensor([[b1.item()], [b2.item()]]) 102 | auto_params = A.pinverse() @ b 103 | 104 | lr = auto_params[0].item() 105 | momentum = -auto_params[1].item() 106 | 107 | 108 | # 109 | # update parameters and state in-place: z = momentum * z + lr * delta_z; p = p + z 110 | # 111 | 112 | for (p, z, dz) in zip(parameters, zs, delta_zs): 113 | z.data.mul_(momentum).add_(-lr, dz) # update state 114 | p.data.add_(z) # update parameter 115 | 116 | 117 | # 118 | # automatic lambda hyper-parameter (trust region adaptation) 119 | # 120 | 121 | if group['auto_lambda']: 122 | # only adapt once every few batches 123 | if global_state['count'] % group['lambda_interval'] == 0: 124 | with t.no_grad(): 125 | # evaluate the loss with the updated parameters 126 | new_loss = loss_fn(model_fn()) 127 | 128 | # objective function change predicted by quadratic fit 129 | quadratic_change = -0.5 * (auto_params * b).sum() 130 | 131 | # ratio between predicted and actual change 132 | ratio = (new_loss - loss) / quadratic_change 133 | 134 | # increase or decrease lambda based on ratio 135 | factor = group['lambda_factor'] ** group['lambda_interval'] 136 | 137 | if ratio < group['lambda_low']: lambd /= factor 138 | if ratio > group['lambda_high']: lambd *= factor 139 | 140 | global_state['lambd'] = lambd 141 | 142 | global_state['count'] += 1 143 | return (loss, predictions) 144 | 145 | 146 | def fmad(ys, xs, dxs): 147 | """Forward-mode automatic differentiation.""" 148 | v = t.zeros_like(ys, 
requires_grad=True) 149 | g = grad(ys, xs, grad_outputs=v, create_graph=True) 150 | return grad(g, v, grad_outputs=dxs) 151 | 152 | -------------------------------------------------------------------------------- /examples/cifar.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | 11 | import os, argparse, shutil, sys 12 | from time import time 13 | 14 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/..') # import from parent directory 15 | from curveball import CurveBall 16 | 17 | import models 18 | 19 | try: 20 | from overboard import Logger 21 | except ImportError: 22 | print('Warning: OverBoard not installed, no logging/plotting will be performed. See https://pypi.org/project/overboard/') 23 | Logger = None 24 | 25 | 26 | def train(args, net, device, train_loader, optimizer, epoch, logger): 27 | net.train() 28 | for batch_idx, (data, target) in enumerate(train_loader): 29 | start = time() 30 | data, target = data.to(device), target.to(device) 31 | 32 | # create closures to compute the forward pass, and the loss 33 | model_fn = lambda: net(data) 34 | loss_fn = lambda pred: F.cross_entropy(pred, target) 35 | 36 | if isinstance(optimizer, CurveBall): 37 | (loss, predictions) = optimizer.step(model_fn, loss_fn) 38 | else: 39 | # standard optimizer 40 | optimizer.zero_grad() 41 | predictions = model_fn() 42 | loss = loss_fn(predictions) 43 | loss.backward() 44 | optimizer.step() 45 | 46 | pred = predictions.max(1, keepdim=True)[1] # get the index of the max log-probability 47 | accuracy = pred.eq(target.view_as(pred)).double().mean() 48 | 49 | # log the loss and accuracy 50 | stats = {'train.loss': loss.item(), 'train.accuracy': accuracy.item()} 51 | if logger: 52 | 
logger.update_average(stats) 53 | if logger.avg_count['train.loss'] > 3: # skip first 3 iterations (warm-up time) 54 | logger.update_average({'train.time': time() - start}) 55 | logger.print(line_prefix='ep %i ' % epoch, prefix='train') 56 | else: 57 | print(stats) 58 | 59 | 60 | def test(args, net, device, test_loader, logger): 61 | net.eval() 62 | with torch.no_grad(): 63 | for data, target in test_loader: 64 | start = time() 65 | data, target = data.to(device), target.to(device) 66 | predictions = net(data) 67 | 68 | loss = F.cross_entropy(predictions, target) 69 | 70 | pred = predictions.max(1, keepdim=True)[1] # get the index of the max log-probability 71 | accuracy = pred.eq(target.view_as(pred)).double().mean() 72 | 73 | # log the loss and accuracy 74 | stats = {'val.loss': loss.item(), 'val.accuracy': accuracy.item()} 75 | if logger: 76 | logger.update_average(stats) 77 | if logger.avg_count['val.loss'] > 3: # skip first 3 iterations (warm-up time) 78 | logger.update_average({'val.time': time() - start}) 79 | logger.print(prefix='val') 80 | else: 81 | print(stats) 82 | 83 | 84 | def main(): 85 | all_models = [name for name in dir(models) if callable(getattr(models, name))] 86 | 87 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 88 | parser.add_argument("experiment", nargs='?', default="test") 89 | parser.add_argument('-model', choices=all_models, default='BasicNetBN') #ResNet18 90 | parser.add_argument('-optimizer', choices=['sgd', 'adam', 'curveball'], default='curveball') 91 | parser.add_argument('-lr', default=-1, type=float, help='learning rate') 92 | parser.add_argument('-momentum', type=float, default=-1, metavar='M') 93 | parser.add_argument('-lambda', type=float, default=1.0) 94 | parser.add_argument('--no-auto-lambda', action='store_true', default=False, help='disables automatic lambda estimation') 95 | parser.add_argument('-batch-size', default=128, type=int) 96 | parser.add_argument('-epochs', default=200, type=int) 97 
| parser.add_argument('-save-interval', default=10, type=int) 98 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 99 | parser.add_argument('-outputdir', default='data/cifar-experiments', type=str) 100 | parser.add_argument('-datadir', default='data/cifar', type=str) 101 | parser.add_argument('-device', default='cuda', type=str) 102 | parser.add_argument('--parallel', action='store_true', default=False) 103 | args = parser.parse_args() 104 | 105 | args.outputdir += ('/' + args.model + '/' + args.optimizer + '/' + args.experiment) 106 | 107 | if os.path.isdir(args.outputdir): 108 | input('Directory already exists. Press Enter to overwrite or Ctrl+C to cancel.') 109 | 110 | if not torch.cuda.is_available(): args.device = 'cpu' 111 | best_acc = 0 # best test accuracy 112 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 113 | 114 | # data 115 | transform_train = transforms.Compose([ 116 | transforms.RandomCrop(32, padding=2, fill=(128, 128, 128)), 117 | transforms.RandomHorizontalFlip(), 118 | transforms.ToTensor(), 119 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 120 | ]) 121 | 122 | transform_test = transforms.Compose([ 123 | transforms.ToTensor(), 124 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 125 | ]) 126 | 127 | train_set = torchvision.datasets.CIFAR10(root=args.datadir, train=True, download=True, transform=transform_train) 128 | 129 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, num_workers=2, shuffle=True) 130 | 131 | test_set = torchvision.datasets.CIFAR10(root=args.datadir, train=False, download=True, transform=transform_test) 132 | 133 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=2, shuffle=False) 134 | 135 | # model 136 | net = getattr(models, args.model)() 137 | 138 | net = net.to(args.device) 139 | if args.device != 'cpu' and args.parallel: 140 | net = 
torch.nn.DataParallel(net) 141 | torch.backends.cudnn.benchmark = True # slightly faster for fixed batch/input sizes 142 | 143 | if args.resume: 144 | # load checkpoint 145 | print('Resuming from checkpoint..') 146 | assert os.path.isdir(args.outputdir), 'Error: no checkpoint directory found!' 147 | checkpoint = torch.load(args.outputdir + '/last.t7') 148 | net.load_state_dict(checkpoint['net']) 149 | best_acc = checkpoint['acc'] 150 | start_epoch = checkpoint['epoch'] 151 | 152 | # optimizer 153 | if args.optimizer == 'sgd': 154 | if args.lr < 0: args.lr = 0.1 155 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9) 156 | 157 | elif args.optimizer == 'adam': 158 | if args.lr < 0: args.lr = 0.001 159 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 160 | 161 | elif args.optimizer == 'curveball': 162 | #if args.lr < 0: args.lr = 0.01 163 | #if args.momentum < 0: args.momentum = 0.9 164 | lambd = getattr(args, 'lambda') 165 | 166 | optimizer = CurveBall(net.parameters(), lr=args.lr, momentum=args.momentum, lambd=lambd, auto_lambda=not args.no_auto_lambda) 167 | 168 | logger = None 169 | if Logger: logger = Logger(args.outputdir, meta=args, resume=args.resume) 170 | 171 | for epoch in range(start_epoch, args.epochs): 172 | train(args, net, args.device, train_loader, optimizer, epoch, logger) 173 | test(args, net, args.device, test_loader, logger) 174 | 175 | if logger: 176 | acc = logger.average()['val.accuracy'] 177 | logger.append() 178 | 179 | # save checkpoint 180 | if epoch % args.save_interval == 0: 181 | print('Saving..') 182 | state = {'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'acc': acc, 'epoch': epoch} 183 | if not os.path.isdir(args.outputdir): 184 | os.mkdir(args.outputdir) 185 | torch.save(state, args.outputdir + '/last.t7') 186 | if logger and acc > best_acc: 187 | shutil.copyfile(args.outputdir + '/last.t7', args.outputdir + '/best.t7') 188 | best_acc = acc 189 | 190 | if __name__ == '__main__': 191 | main() 192 
| -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | 2 | # Modified version of the PyTorch MNIST example to log outputs for OverBoard 3 | 4 | from __future__ import print_function 5 | import argparse, os, sys 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from time import time 12 | 13 | from models.basic import insert_bnorm 14 | 15 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/..') # import from parent directory 16 | from curveball import CurveBall 17 | 18 | try: 19 | from overboard import Logger 20 | except ImportError: 21 | print('Warning: OverBoard not installed, no logging/plotting will be performed. See https://pypi.org/project/overboard/') 22 | Logger = None 23 | 24 | 25 | class Flatten(nn.Module): 26 | def forward(self, input): 27 | return input.view(input.size(0), -1) 28 | 29 | def onehot(target, like): 30 | """Transforms numeric labels into one-hot regression targets.""" 31 | out = torch.zeros_like(like) 32 | out.scatter_(1, target.unsqueeze(1), 1.0) 33 | return out 34 | 35 | 36 | def train(args, model, device, train_loader, optimizer, epoch, logger): 37 | model.train() 38 | for batch_idx, (data, target) in enumerate(train_loader): 39 | start = time() 40 | data, target = data.to(device), target.to(device) 41 | 42 | # create closures to compute the forward pass, and the loss 43 | model_fn = lambda: model(data) 44 | loss_fn = lambda pred: F.cross_entropy(pred, target) 45 | 46 | if isinstance(optimizer, CurveBall): 47 | (loss, predictions) = optimizer.step(model_fn, loss_fn) 48 | else: 49 | # standard optimizer 50 | optimizer.zero_grad() 51 | predictions = model_fn() 52 | loss = loss_fn(predictions) 53 | loss.backward() 54 | optimizer.step() 55 | 56 | pred = predictions.max(1, 
keepdim=True)[1] # get the index of the max log-probability 57 | accuracy = pred.eq(target.view_as(pred)).double().mean() 58 | 59 | # log the loss and accuracy 60 | stats = {'train.loss': loss.item(), 'train.accuracy': accuracy.item()} 61 | if logger: 62 | logger.update_average(stats) 63 | if logger.avg_count['train.loss'] > 3: # skip first 3 iterations (warm-up time) 64 | logger.update_average({'train.time': time() - start}) 65 | logger.print(line_prefix='ep %i ' % epoch, prefix='train') 66 | else: 67 | print(stats) 68 | 69 | 70 | def test(args, model, device, test_loader, logger): 71 | model.eval() 72 | with torch.no_grad(): 73 | for data, target in test_loader: 74 | start = time() 75 | data, target = data.to(device), target.to(device) 76 | predictions = model(data) 77 | 78 | loss = F.cross_entropy(predictions, target) 79 | 80 | pred = predictions.max(1, keepdim=True)[1] # get the index of the max log-probability 81 | accuracy = pred.eq(target.view_as(pred)).double().mean() 82 | 83 | # log the loss and accuracy 84 | stats = {'val.loss': loss.item(), 'val.accuracy': accuracy.item()} 85 | if logger: 86 | logger.update_average(stats) 87 | if logger.avg_count['val.loss'] > 3: # skip first 3 iterations (warm-up time) 88 | logger.update_average({'val.time': time() - start}) 89 | else: 90 | print(stats) 91 | 92 | # display final values in console 93 | if logger: logger.print(prefix='val') 94 | 95 | 96 | def main(): 97 | # Training settings 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("experiment", nargs='?', default="test") 100 | parser.add_argument('-batch-size', type=int, default=64, metavar='N', 101 | help='input batch size for training (default: 64)') 102 | parser.add_argument('-test-batch-size', type=int, default=1000, 103 | help='input batch size for testing (default: 1000)') 104 | parser.add_argument('-epochs', type=int, default=10, 105 | help='number of epochs to train (default: 10)') 106 | parser.add_argument('-optimizer', choices=['sgd', 
'adam', 'curveball'], default='curveball', 107 | help='optimizer (sgd, adam, or curveball)') 108 | parser.add_argument('-lr', type=float, default=-1, metavar='LR', 109 | help='learning rate (default: 0.01 for SGD, 0.001 for Adam, 1 for CurveBall)') 110 | parser.add_argument('-momentum', type=float, default=-1, metavar='M', 111 | help='momentum (default: 0.5)') 112 | parser.add_argument('-lambda', type=float, default=1.0, 113 | help='lambda') 114 | parser.add_argument('--no-auto-lambda', action='store_true', default=False, 115 | help='disables automatic lambda estimation') 116 | parser.add_argument('--no-batch-norm', action='store_true', default=False) 117 | parser.add_argument('--no-cuda', action='store_true', default=False, 118 | help='disables CUDA training') 119 | parser.add_argument('-seed', type=int, default=1, metavar='S', 120 | help='random seed (default: 1)') 121 | parser.add_argument('-datadir', type=str, default='data/mnist', 122 | help='MNIST data directory') 123 | parser.add_argument('-outputdir', type=str, default='data/mnist-experiments', 124 | help='output directory') 125 | args = parser.parse_args() 126 | use_cuda = not args.no_cuda and torch.cuda.is_available() 127 | args.outputdir += '/' + args.optimizer + '/' + args.experiment 128 | 129 | if os.path.isdir(args.outputdir): 130 | input('Directory already exists. 
Press Enter to overwrite or Ctrl+C to cancel.') 131 | 132 | torch.manual_seed(args.seed) 133 | 134 | device = torch.device("cuda" if use_cuda else "cpu") 135 | 136 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 137 | train_loader = torch.utils.data.DataLoader( 138 | datasets.MNIST(args.datadir, train=True, download=True, 139 | transform=transforms.Compose([ 140 | transforms.ToTensor(), 141 | transforms.Normalize((0.1307,), (0.3081,)) 142 | ])), 143 | batch_size=args.batch_size, shuffle=True, **kwargs) 144 | test_loader = torch.utils.data.DataLoader( 145 | datasets.MNIST(args.datadir, train=False, transform=transforms.Compose([ 146 | transforms.ToTensor(), 147 | transforms.Normalize((0.1307,), (0.3081,)) 148 | ])), 149 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 150 | 151 | # same network as in the tutorial, in sequential form, and with optional batch-norm 152 | layers = [ 153 | nn.Conv2d(1, 10, kernel_size=5), 154 | nn.MaxPool2d(kernel_size=2), 155 | nn.ReLU(), 156 | nn.Conv2d(10, 20, kernel_size=5), 157 | nn.MaxPool2d(kernel_size=2), 158 | nn.ReLU(), 159 | nn.Dropout2d(), 160 | Flatten(), 161 | nn.Linear(320, 50), 162 | nn.ReLU(), 163 | nn.Dropout(), 164 | nn.Linear(50, 10) 165 | ] 166 | 167 | # insert batch norm layers 168 | if not args.no_batch_norm: insert_bnorm(layers) 169 | 170 | model = nn.Sequential(*layers) 171 | model.to(device) 172 | 173 | if args.optimizer == 'sgd': 174 | if args.lr < 0: args.lr = 0.01 175 | if args.momentum < 0: args.momentum = 0.5 176 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 177 | 178 | elif args.optimizer == 'adam': 179 | if args.lr < 0: args.lr = 0.001 180 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 181 | 182 | elif args.optimizer == 'curveball': 183 | #if args.lr < 0: args.lr = 0.01 184 | #if args.momentum < 0: args.momentum = 0.9 185 | lambd = getattr(args, 'lambda') 186 | 187 | optimizer = CurveBall(model.parameters(), lr=args.lr, 
momentum=args.momentum, lambd=lambd, auto_lambda=not args.no_auto_lambda) 188 | 189 | # open logging stream 190 | with Logger(args.outputdir, meta=args) as logger: 191 | # do training 192 | for epoch in range(1, args.epochs + 1): 193 | train(args, model, device, train_loader, optimizer, epoch, logger) 194 | test(args, model, device, test_loader, logger) 195 | 196 | # record average statistics collected over this epoch (with logger.update_average) 197 | logger.append() 198 | 199 | 200 | if __name__ == '__main__': 201 | main() 202 | 203 | -------------------------------------------------------------------------------- /examples/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import * 2 | from .vgg import * 3 | from .dpn import * 4 | from .lenet import * 5 | from .senet import * 6 | from .pnasnet import * 7 | from .densenet import * 8 | from .googlenet import * 9 | from .shufflenet import * 10 | from .shufflenetv2 import * 11 | from .resnet import * 12 | from .resnext import * 13 | from .preact_resnet import * 14 | from .mobilenet import * 15 | from .mobilenetv2 import * 16 | -------------------------------------------------------------------------------- /examples/models/basic.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | 5 | class Flatten(nn.Module): 6 | def forward(self, input): 7 | return input.view(input.size(0), -1) 8 | 9 | def BasicNetBN(): 10 | return BasicNet(batch_norm=True) 11 | 12 | def BasicNet(batch_norm=False): 13 | """Basic network for CIFAR.""" 14 | layers = [ 15 | nn.Conv2d(3, 32, kernel_size=5, padding=2), 16 | nn.ReLU(), 17 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 18 | 19 | nn.Conv2d(32, 32, kernel_size=5, padding=2), 20 | nn.ReLU(), 21 | nn.AvgPool2d(kernel_size=3, stride=2, padding=1), 22 | 23 | nn.Conv2d(32, 64, kernel_size=5, padding=2), 24 | nn.ReLU(), 25 | nn.AvgPool2d(kernel_size=3, 
stride=2, padding=1), 26 | 27 | Flatten(), 28 | 29 | nn.Linear(4 * 4 * 64, 64), 30 | nn.ReLU(), 31 | 32 | nn.Linear(64, 10) 33 | ] 34 | 35 | # insert batch norm layers 36 | if batch_norm: 37 | insert_bnorm(layers, init_gain=True, eps=1e-4) 38 | 39 | return nn.Sequential(*layers) 40 | 41 | 42 | def insert_bnorm(layers, init_gain=False, eps=1e-5, ignore_last_layer=True): 43 | """Inserts batch-norm layers after each convolution/linear layer in a list of layers. The list is modified in place and also returned; if ignore_last_layer is True, the final conv/linear layer gets no batch-norm.""" 44 | last = True 45 | for (idx, layer) in reversed(list(enumerate(layers))): # walk backwards, so inserting at idx + 1 never shifts indices that are still to be visited 46 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 47 | if ignore_last_layer and last: 48 | last = False # do not insert batch-norm after last linear/conv layer 49 | else: 50 | if isinstance(layer, nn.Conv2d): 51 | bnorm = nn.BatchNorm2d(layer.out_channels, eps=eps) 52 | elif isinstance(layer, nn.Linear): 53 | bnorm = nn.BatchNorm1d(layer.out_features, eps=eps) 54 | 55 | if init_gain: 56 | bnorm.weight.data[:] = 1.0 # instead of uniform sampling (presumably the BN weight init of older PyTorch versions; TODO confirm) 57 | 58 | layers.insert(idx + 1, bnorm) 59 | return layers 60 | -------------------------------------------------------------------------------- /examples/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) # dense connectivity: the growth_rate new channels are concatenated with the input channels 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def
__init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) # halves the spatial resolution 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) # transition layers compress the channel count by the reduction factor 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate # each block outputs growth_rate more channels than it receives 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out =
self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) # presumably expects a 4x4 map here (32x32 CIFAR input); TODO confirm for other input sizes 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | def test(): # smoke test: runs one random 32x32 RGB image through the network 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /examples/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) # grouped conv: in_planes must be divisible by 32 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride,
bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | 
cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /examples/models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | 
nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /examples/models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /examples/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, 
kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /examples/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = 
nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /examples/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 
| # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 
120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /examples/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, 
bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | 
def PreActResNet50(): 104 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /examples/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = 
nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 
94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(): 101 | return ResNet(BasicBlock, [2,2,2,2]) 102 | 103 | def ResNet34(): 104 | return ResNet(BasicBlock, [3,4,6,3]) 105 | 106 | def ResNet50(): 107 | return ResNet(Bottleneck, [3,4,6,3]) 108 | 109 | def ResNet101(): 110 | return ResNet(Bottleneck, [3,4,23,3]) 111 | 112 | def ResNet152(): 113 | return ResNet(Bottleneck, [3,8,36,3]) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /examples/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | 
nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 
62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /examples/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | 
layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /examples/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N,C,H,W = x.size() 18 | g = self.groups 19 | return x.view(N,g,C/g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W) 20 | 21 | 22 | class Bottleneck(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride, groups): 24 | super(Bottleneck, self).__init__() 25 | self.stride = stride 26 | 27 | mid_planes = out_planes/4 28 | g = 1 if in_planes==24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, 
bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 48 | return out 49 | 50 | 51 | class ShuffleNet(nn.Module): 52 | def __init__(self, cfg): 53 | super(ShuffleNet, self).__init__() 54 | out_planes = cfg['out_planes'] 55 | num_blocks = cfg['num_blocks'] 56 | groups = cfg['groups'] 57 | 58 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 
90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /examples/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C/g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5): 34 | super(BasicBlock, self).__init__() 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = 
nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | out = F.relu(self.bn3(self.conv3(out))) 53 | out = torch.cat([x1, out], 1) 54 | out = self.shuffle(out) 55 | return out 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super(DownBlock, self).__init__() 61 | mid_channels = out_channels // 2 62 | # left 63 | self.conv1 = nn.Conv2d(in_channels, in_channels, 64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 65 | self.bn1 = nn.BatchNorm2d(in_channels) 66 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 67 | kernel_size=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(mid_channels) 69 | # right 70 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(mid_channels) 73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 75 | self.bn4 = nn.BatchNorm2d(mid_channels) 76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn5 = nn.BatchNorm2d(mid_channels) 79 | 80 | self.shuffle = ShuffleBlock() 81 | 82 | def forward(self, x): 83 | # left 84 | out1 = self.bn1(self.conv1(x)) 85 | out1 = F.relu(self.bn2(self.conv2(out1))) 86 | # right 87 | out2 = F.relu(self.bn3(self.conv3(x))) 88 | out2 = self.bn4(self.conv4(out2)) 89 | out2 = F.relu(self.bn5(self.conv5(out2))) 90 | # concat 91 | out = torch.cat([out1, out2], 1) 92 | out = self.shuffle(out) 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, net_size): 98 | super(ShuffleNetV2, self).__init__() 99 | out_channels = configs[net_size]['out_channels'] 100 | num_blocks = configs[net_size]['num_blocks'] 101 | 102 | 
self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 103 | stride=1, padding=1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(24) 105 | self.in_channels = 24 106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 110 | kernel_size=1, stride=1, padding=0, bias=False) 111 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 112 | self.linear = nn.Linear(out_channels[3], 10) 113 | 114 | def _make_layer(self, out_channels, num_blocks): 115 | layers = [DownBlock(self.in_channels, out_channels)] 116 | for i in range(num_blocks): 117 | layers.append(BasicBlock(out_channels)) 118 | self.in_channels = out_channels 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = F.relu(self.bn2(self.conv2(out))) 128 | out = F.avg_pool2d(out, 4) 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | configs = { 135 | 0.5: { 136 | 'out_channels': (48, 96, 192, 1024), 137 | 'num_blocks': (3, 7, 3) 138 | }, 139 | 140 | 1: { 141 | 'out_channels': (116, 232, 464, 1024), 142 | 'num_blocks': (3, 7, 3) 143 | }, 144 | 1.5: { 145 | 'out_channels': (176, 352, 704, 1024), 146 | 'num_blocks': (3, 7, 3) 147 | }, 148 | 2: { 149 | 'out_channels': (224, 488, 976, 2048), 150 | 'num_blocks': (3, 7, 3) 151 | } 152 | } 153 | 154 | 155 | def test(): 156 | net = ShuffleNetV2(net_size=0.5) 157 | x = torch.randn(3, 3, 32, 32) 158 | y = net(x) 159 | print(y.shape) 160 | 161 | 162 | # test() 163 | -------------------------------------------------------------------------------- /examples/models/vgg.py: 
-------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | def VGG19(): 15 | return VGG('VGG19') 16 | 17 | class VGG(nn.Module): 18 | def __init__(self, vgg_name): 19 | super(VGG, self).__init__() 20 | self.features = self._make_layers(cfg[vgg_name]) 21 | self.classifier = nn.Linear(512, 10) 22 | 23 | def forward(self, x): 24 | out = self.features(x) 25 | out = out.view(out.size(0), -1) 26 | out = self.classifier(out) 27 | return out 28 | 29 | def _make_layers(self, cfg): 30 | layers = [] 31 | in_channels = 3 32 | for x in cfg: 33 | if x == 'M': 34 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 35 | else: 36 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(x), 38 | nn.ReLU(inplace=True)] 39 | in_channels = x 40 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 41 | return nn.Sequential(*layers) 42 | 43 | 44 | def test(): 45 | net = VGG('VGG11') 46 | x = torch.randn(2,3,32,32) 47 | y = net(x) 48 | print(y.size()) 49 | 50 | # test() 51 | -------------------------------------------------------------------------------- /examples/test_fmad.py: -------------------------------------------------------------------------------- 1 | 2 | import torch as t 3 | from torch.autograd import grad 4 | 5 | def fmad(ys, xs, dxs): 6 | # inspired by: https://github.com/renmengye/tensorflow-forward-ad/issues/2 -- forward-mode AD: returns the Jacobian-vector products J @ dxs (one tensor per entry of ys) using two reverse-mode grad() calls 7 | v = [t.zeros_like(y, requires_grad=True) for y in ys] # dummy cotangents for ys; their (zero) values never matter because g below is linear in v 8 | g = grad(ys, xs, grad_outputs=v, create_graph=True) # g = J^T v, kept differentiable; differentiating g w.r.t. v with grad_outputs=dxs (line 9) yields J @ dxs 
9 | return grad(g, v, grad_outputs=dxs) 10 | 11 | 12 | # linear function test: y = A x, so the Jacobian of y w.r.t. x is A itself 13 | x = t.tensor([[0.1], [0.2]], requires_grad=True) 14 | A = t.tensor([[1, 2], [3, 4], [5, 6]], dtype=t.float) 15 | y = A @ x 16 | 17 | print('Linear function output:\n', y) 18 | 19 | bwd_der = grad([y], [x], [t.ones_like(y)], retain_graph=True) # retain_graph: y's graph is used again by fmad on line 24 20 | print('RMAD:\n', bwd_der) 21 | 22 | print('Closed-form backward gradient:\n', A.t() @ t.ones_like(y)) 23 | 24 | fwd_der = fmad([y], [x], [t.ones_like(x)]) # should match the closed form A @ ones printed on line 27 25 | print('FMAD:\n', fwd_der) 26 | 27 | print('Closed-form forward gradient:\n', A @ t.ones_like(x)) 28 | 29 | 30 | # MLP test 31 | net = t.nn.Sequential( 32 | t.nn.Linear(3, 4), 33 | t.nn.ReLU(), 34 | t.nn.Linear(4, 2) 35 | ) 36 | parameters = list(net.parameters()) 37 | output = net(t.ones((1, 3))) 38 | 39 | fwd_der = fmad([output], parameters, [t.ones_like(p) for p in parameters]) 40 | print('MLP FMAD:\n', fwd_der) 41 | 42 | # numerical check: rebuild the reverse-mode gradient entry by entry from one-hot FMAD products (exact, not finite differences) 43 | loss_der = t.randn_like(output) # arbitrary 44 | param_der = [t.zeros_like(p) for p in parameters] 45 | result_der = [t.zeros_like(p) for p in parameters] 46 | 47 | for (param_idx, param) in enumerate(parameters): 48 | for elem in range(param.numel()): 49 | # set one-hot tensor 50 | param_der[param_idx].flatten()[elem] = 1 # assumes flatten() returns a view here (zeros_like gives contiguous tensors) so the write reaches param_der 51 | 52 | fmad_result = fmad([output], parameters, param_der) 53 | assert len(fmad_result) == 1 # fmad returns one tensor per entry of ys 54 | 55 | result_der[param_idx].flatten()[elem] = (fmad_result[0] * loss_der).sum() # <J e_i, loss_der> equals entry i of the VJP J^T loss_der 56 | 57 | param_der[param_idx].flatten()[elem] = 0 # reset it 58 | 59 | print('MLP backward gradient using FMAD (numerical):', result_der) 60 | 61 | print('MLP backward gradient using RMAD (standard):', grad([output], parameters, loss_der)) 62 | 63 | --------------------------------------------------------------------------------