├── .gitignore ├── README.md ├── main.py ├── models ├── __init__.py ├── cifar │ ├── __init__.py │ ├── alexnet.py │ ├── densenet.py │ ├── preresnet.py │ ├── resnet.py │ ├── resnext.py │ ├── vgg.py │ └── wrn.py └── imagenet │ ├── __init__.py │ └── resnext.py ├── optimizers ├── __init__.py ├── ekfac.py └── kfac.py ├── trainer.py └── utils ├── data_utils.py ├── kfac_utils.py └── network_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | led / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | # data 6 | data.cifar10/ 7 | data.cifar100/ 8 | *.gz 9 | shells/ 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | checkpoint/ 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | #*.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | *.tar 90 | 91 | # virtualenv 92 | .venv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | tmp 110 | runs 111 | run 112 | 113 | # PyCharm 114 | .idea/ 115 | 116 | # macOS metadata 117 | .DS_Store 118 | ._.DS_Store 119 | ._* 120 | 121 | # 122 | data/ 123 | log/ 124 | summary/ 125 | data/kernel_toy/*.pth 126 | data/AS/gp-structure-search 127 | #*.data 128 | data/mnist_data 129 | *.npz 130 | *.txt 131 | #*.png 132 | #*.pdf 133 | *.jpeg 134 | *.jpg 135 | #results/ 136 | *.pyc 137 | *__pycache__ 138 | 139 | checkpoint/ 140 | runs/ 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # K-FAC_pytorch 2 | Pytorch implementation of [K-FAC](https://arxiv.org/abs/1503.05671) and [E-KFAC](https://arxiv.org/abs/1806.03884). (Only support single-GPU training, need modifications for multi-GPU.) 
3 | ## Requirements 4 | ``` 5 | pytorch 0.4.0 6 | torchvision 7 | python 3.6.0 8 | tqdm 9 | tensorboardX 10 | tensorflow 11 | ``` 12 | ## How to run 13 | ``` 14 | python main.py --dataset cifar10 --optimizer kfac --network vgg16_bn --epoch 100 --milestone 40,80 --learning_rate 0.01 --damping 0.03 --weight_decay 0.003 15 | ``` 16 | 17 | 18 | ## Performance 19 | #### Note: for better hyperparameters of K-FAC, please refer to the [weight_decay](https://github.com/gd-zhang/Weight-Decay/tree/master/configs) repo. (The hyperparameters below are not good enough! Especially the weight decay is too small!) 20 | For K-FAC and E-KFAC, the search ranges of learning rate, weight decay and damping are:
21 | (1) learning rate = [3e-2, 1e-2, 3e-3]
22 | (2) weight decay = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
23 | (3) damping = [3e-2, 1e-3, 3e-3] 24 | 25 | For SGD:
26 | (1) learning rate = [3e-1, 1e-1, 3e-2]
27 | (2) weight decay = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4] 28 | 29 | #### CIFAR10 30 | 31 | | Optimizer | Model | Acc. | learning rate | weight decay | damping | 32 | |---------- | ---------------------------------- | ----------- | ------------- | -------------| ----------- | 33 | | KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 93.86% | 0.01 | 0.003 | 0.03 | 34 | | E-KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 94.00% | 0.003 | 0.01 | 0.03 | 35 | | SGD | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 94.03% | 0.03 | 0.001 | - | 36 | | KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 93.59% | 0.01 | 0.003 | 0.03 | 37 | | E-KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 93.37% | 0.003 | 0.01 | 0.03 | 38 | | SGD | [ResNet110](https://arxiv.org/abs/1512.03385)| 94.14% | 0.03 | 0.001 | - | 39 | 40 | 41 | 42 | #### CIFAR100 43 | 44 | | Optimizer | Model | Acc. | learning rate | weight decay | damping | 45 | |---------- | ---------------------------------- | ----------- | ------------- | -------------| ----------- | 46 | | KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 74.09% | 0.003 | 0.01 | 0.03 | 47 | | E-KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 73.20% | 0.01 | 0.01 | 0.03 | 48 | | SGD | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 74.56% | 0.03 | 0.003 | - | 49 | | KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.71% | 0.003 | 0.01 | 0.003 | 50 | | E-KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.32% | 0.03 | 0.001 | 0.03 | 51 | | SGD | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.60% | 0.1 | 0.0003 | - | 52 | 53 | ## Others 54 | Please consider cite the following papers for K-FAC: 55 | ``` 56 | @inproceedings{martens2015optimizing, 57 | title={Optimizing neural networks with kronecker-factored approximate curvature}, 58 | author={Martens, James and Grosse, Roger}, 59 | booktitle={International conference on machine learning}, 60 | pages={2408--2417}, 61 | year={2015} 62 | } 63 | 64 | 
@inproceedings{grosse2016kronecker, 65 | title={A kronecker-factored approximate fisher matrix for convolution layers}, 66 | author={Grosse, Roger and Martens, James}, 67 | booktitle={International Conference on Machine Learning}, 68 | pages={573--582}, 69 | year={2016} 70 | } 71 | ``` 72 | 73 | and for E-KFAC: 74 | ``` 75 | @inproceedings{george2018fast, 76 | title={Fast Approximate Natural Gradient Descent in a Kronecker Factored Eigenbasis}, 77 | author={George, Thomas and Laurent, C{\'e}sar and Bouthillier, Xavier and Ballas, Nicolas and Vincent, Pascal}, 78 | booktitle={Advances in Neural Information Processing Systems}, 79 | pages={9550--9560}, 80 | year={2018} 81 | } 82 | ``` 83 | 84 | If you have any questions or suggestions, please feel free to contact me via alecwangcq at gmail , com! 85 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10/CIFAR100 with PyTorch.''' 2 | import argparse 3 | import os 4 | from optimizers import (KFACOptimizer, EKFACOptimizer) 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.optim.lr_scheduler import MultiStepLR 9 | 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | from utils.network_utils import get_network 13 | from utils.data_utils import get_dataloader 14 | 15 | 16 | # fetch args 17 | parser = argparse.ArgumentParser() 18 | 19 | 20 | parser.add_argument('--network', default='vgg16_bn', type=str) 21 | parser.add_argument('--depth', default=19, type=int) 22 | parser.add_argument('--dataset', default='cifar10', type=str) 23 | 24 | # densenet 25 | parser.add_argument('--growthRate', default=12, type=int) 26 | parser.add_argument('--compressionRate', default=2, type=int) 27 | 28 | # wrn, densenet 29 | parser.add_argument('--widen_factor', default=1, type=int) 30 | parser.add_argument('--dropRate', default=0.0, type=float) 
31 | 32 | 33 | parser.add_argument('--device', default='cuda', type=str) 34 | parser.add_argument('--resume', '-r', action='store_true') 35 | parser.add_argument('--load_path', default='', type=str) 36 | parser.add_argument('--log_dir', default='runs/pretrain', type=str) 37 | 38 | 39 | parser.add_argument('--optimizer', default='kfac', type=str) 40 | parser.add_argument('--batch_size', default=64, type=float) 41 | parser.add_argument('--epoch', default=100, type=int) 42 | parser.add_argument('--milestone', default=None, type=str) 43 | parser.add_argument('--learning_rate', default=0.01, type=float) 44 | parser.add_argument('--momentum', default=0.9, type=float) 45 | parser.add_argument('--stat_decay', default=0.95, type=float) 46 | parser.add_argument('--damping', default=1e-3, type=float) 47 | parser.add_argument('--kl_clip', default=1e-2, type=float) 48 | parser.add_argument('--weight_decay', default=3e-3, type=float) 49 | parser.add_argument('--TCov', default=10, type=int) 50 | parser.add_argument('--TScal', default=10, type=int) 51 | parser.add_argument('--TInv', default=100, type=int) 52 | 53 | 54 | parser.add_argument('--prefix', default=None, type=str) 55 | args = parser.parse_args() 56 | 57 | # init model 58 | nc = { 59 | 'cifar10': 10, 60 | 'cifar100': 100 61 | } 62 | num_classes = nc[args.dataset] 63 | net = get_network(args.network, 64 | depth=args.depth, 65 | num_classes=num_classes, 66 | growthRate=args.growthRate, 67 | compressionRate=args.compressionRate, 68 | widen_factor=args.widen_factor, 69 | dropRate=args.dropRate) 70 | net = net.to(args.device) 71 | 72 | # init dataloader 73 | trainloader, testloader = get_dataloader(dataset=args.dataset, 74 | train_batch_size=args.batch_size, 75 | test_batch_size=256) 76 | 77 | # init optimizer and lr scheduler 78 | optim_name = args.optimizer.lower() 79 | tag = optim_name 80 | if optim_name == 'sgd': 81 | optimizer = optim.SGD(net.parameters(), 82 | lr=args.learning_rate, 83 | momentum=args.momentum, 84 | 
weight_decay=args.weight_decay) 85 | elif optim_name == 'kfac': 86 | optimizer = KFACOptimizer(net, 87 | lr=args.learning_rate, 88 | momentum=args.momentum, 89 | stat_decay=args.stat_decay, 90 | damping=args.damping, 91 | kl_clip=args.kl_clip, 92 | weight_decay=args.weight_decay, 93 | TCov=args.TCov, 94 | TInv=args.TInv) 95 | elif optim_name == 'ekfac': 96 | optimizer = EKFACOptimizer(net, 97 | lr=args.learning_rate, 98 | momentum=args.momentum, 99 | stat_decay=args.stat_decay, 100 | damping=args.damping, 101 | kl_clip=args.kl_clip, 102 | weight_decay=args.weight_decay, 103 | TCov=args.TCov, 104 | TScal=args.TScal, 105 | TInv=args.TInv) 106 | else: 107 | raise NotImplementedError 108 | 109 | if args.milestone is None: 110 | lr_scheduler = MultiStepLR(optimizer, milestones=[int(args.epoch*0.5), int(args.epoch*0.75)], gamma=0.1) 111 | else: 112 | milestone = [int(_) for _ in args.milestone.split(',')] 113 | lr_scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=0.1) 114 | 115 | # init criterion 116 | criterion = nn.CrossEntropyLoss() 117 | 118 | start_epoch = 0 119 | best_acc = 0 120 | if args.resume: 121 | print('==> Resuming from checkpoint..') 122 | assert os.path.isfile(args.load_path), 'Error: no checkpoint directory found!' 
123 | checkpoint = torch.load(args.load_path) 124 | net.load_state_dict(checkpoint['net']) 125 | best_acc = checkpoint['acc'] 126 | start_epoch = checkpoint['epoch'] 127 | print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' % (start_epoch, best_acc)) 128 | 129 | # init summary writter 130 | 131 | log_dir = os.path.join(args.log_dir, args.dataset, args.network, args.optimizer, 132 | 'lr%.3f_wd%.4f_damping%.4f' % 133 | (args.learning_rate, args.weight_decay, args.damping)) 134 | if not os.path.isdir(log_dir): 135 | os.makedirs(log_dir) 136 | writer = SummaryWriter(log_dir) 137 | 138 | 139 | def train(epoch): 140 | print('\nEpoch: %d' % epoch) 141 | net.train() 142 | train_loss = 0 143 | correct = 0 144 | total = 0 145 | 146 | lr_scheduler.step() 147 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 148 | (tag, lr_scheduler.get_lr()[0], 0, 0, correct, total)) 149 | 150 | writer.add_scalar('train/lr', lr_scheduler.get_lr()[0], epoch) 151 | 152 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) 153 | for batch_idx, (inputs, targets) in prog_bar: 154 | inputs, targets = inputs.to(args.device), targets.to(args.device) 155 | optimizer.zero_grad() 156 | outputs = net(inputs) 157 | loss = criterion(outputs, targets) 158 | if optim_name in ['kfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0: 159 | # compute true fisher 160 | optimizer.acc_stats = True 161 | with torch.no_grad(): 162 | sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1), 163 | 1).squeeze().cuda() 164 | loss_sample = criterion(outputs, sampled_y) 165 | loss_sample.backward(retain_graph=True) 166 | optimizer.acc_stats = False 167 | optimizer.zero_grad() # clear the gradient for computing true-fisher. 
168 | loss.backward() 169 | optimizer.step() 170 | 171 | train_loss += loss.item() 172 | _, predicted = outputs.max(1) 173 | total += targets.size(0) 174 | correct += predicted.eq(targets).sum().item() 175 | 176 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % 177 | (tag, lr_scheduler.get_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 178 | prog_bar.set_description(desc, refresh=True) 179 | 180 | writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch) 181 | writer.add_scalar('train/acc', 100. * correct / total, epoch) 182 | 183 | 184 | def test(epoch): 185 | global best_acc 186 | net.eval() 187 | test_loss = 0 188 | correct = 0 189 | total = 0 190 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 191 | % (tag,lr_scheduler.get_lr()[0], test_loss/(0+1), 0, correct, total)) 192 | 193 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True) 194 | with torch.no_grad(): 195 | for batch_idx, (inputs, targets) in prog_bar: 196 | inputs, targets = inputs.to(args.device), targets.to(args.device) 197 | outputs = net(inputs) 198 | loss = criterion(outputs, targets) 199 | 200 | test_loss += loss.item() 201 | _, predicted = outputs.max(1) 202 | total += targets.size(0) 203 | correct += predicted.eq(targets).sum().item() 204 | 205 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' 206 | % (tag, lr_scheduler.get_lr()[0], test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) 207 | prog_bar.set_description(desc, refresh=True) 208 | 209 | # Save checkpoint. 210 | acc = 100.*correct/total 211 | 212 | writer.add_scalar('test/loss', test_loss / (batch_idx + 1), epoch) 213 | writer.add_scalar('test/acc', 100. 
* correct / total, epoch) 214 | 215 | if acc > best_acc: 216 | print('Saving..') 217 | state = { 218 | 'net': net.state_dict(), 219 | 'acc': acc, 220 | 'epoch': epoch, 221 | 'loss': test_loss, 222 | 'args': args 223 | } 224 | 225 | torch.save(state, '%s/%s_%s_%s%s_best.t7' % (log_dir, 226 | args.optimizer, 227 | args.dataset, 228 | args.network, 229 | args.depth)) 230 | best_acc = acc 231 | 232 | 233 | def main(): 234 | for epoch in range(start_epoch, args.epoch): 235 | train(epoch) 236 | test(epoch) 237 | return best_acc 238 | 239 | 240 | if __name__ == '__main__': 241 | main() 242 | 243 | 244 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alecwangcq/KFAC-Pytorch/25e6dbe14752348d4f6030697b4b7f553ead2e92/models/__init__.py -------------------------------------------------------------------------------- /models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 4 | architectures: 5 | 6 | - `AlexNet`_ 7 | - `VGG`_ 8 | - `ResNet`_ 9 | - `SqueezeNet`_ 10 | - `DenseNet`_ 11 | 12 | You can construct a model with random weights by calling its constructor: 13 | 14 | .. code:: python 15 | 16 | import torchvision.models as models 17 | resnet18 = models.resnet18() 18 | alexnet = models.alexnet() 19 | squeezenet = models.squeezenet1_0() 20 | densenet = models.densenet_161() 21 | 22 | We provide pre-trained models for the ResNet variants and AlexNet, using the 23 | PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing 24 | ``pretrained=True``: 25 | 26 | .. 
code:: python 27 | 28 | import torchvision.models as models 29 | resnet18 = models.resnet18(pretrained=True) 30 | alexnet = models.alexnet(pretrained=True) 31 | 32 | ImageNet 1-crop error rates (224x224) 33 | 34 | ======================== ============= ============= 35 | Network Top-1 error Top-5 error 36 | ======================== ============= ============= 37 | ResNet-18 30.24 10.92 38 | ResNet-34 26.70 8.58 39 | ResNet-50 23.85 7.13 40 | ResNet-101 22.63 6.44 41 | ResNet-152 21.69 5.94 42 | Inception v3 22.55 6.44 43 | AlexNet 43.45 20.91 44 | VGG-11 30.98 11.37 45 | VGG-13 30.07 10.75 46 | VGG-16 28.41 9.62 47 | VGG-19 27.62 9.12 48 | SqueezeNet 1.0 41.90 19.58 49 | SqueezeNet 1.1 41.81 19.38 50 | Densenet-121 25.35 7.83 51 | Densenet-169 24.00 7.00 52 | Densenet-201 22.80 6.43 53 | Densenet-161 22.35 6.20 54 | ======================== ============= ============= 55 | 56 | 57 | .. _AlexNet: https://arxiv.org/abs/1404.5997 58 | .. _VGG: https://arxiv.org/abs/1409.1556 59 | .. _ResNet: https://arxiv.org/abs/1512.03385 60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360 61 | .. _DenseNet: https://arxiv.org/abs/1608.06993 62 | """ 63 | 64 | from .alexnet import * 65 | from .vgg import * 66 | from .resnet import * 67 | from .resnext import * 68 | from .wrn import * 69 | from .preresnet import * 70 | from .densenet import * 71 | -------------------------------------------------------------------------------- /models/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 
2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = ['alexnet'] 9 | 10 | 11 | class AlexNet(nn.Module): 12 | 13 | def __init__(self, num_classes=10, **kwargs): 14 | super(AlexNet, self).__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(kernel_size=2, stride=2), 19 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2, stride=2), 29 | ) 30 | self.classifier = nn.Linear(256, num_classes) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | 39 | def alexnet(**kwargs): 40 | r"""AlexNet model architecture from the 41 | `"One weird trick..." `_ paper. 
42 | """ 43 | model = AlexNet(**kwargs) 44 | return model 45 | -------------------------------------------------------------------------------- /models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | __all__ = ['densenet'] 8 | 9 | 10 | from torch.autograd import Variable 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 14 | super(Bottleneck, self).__init__() 15 | planes = expansion * growthRate 16 | self.bn1 = nn.BatchNorm2d(inplanes) 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 20 | padding=1, bias=False) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.dropRate = dropRate 23 | 24 | def forward(self, x): 25 | out = self.bn1(x) 26 | out = self.relu(out) 27 | out = self.conv1(out) 28 | out = self.bn2(out) 29 | out = self.relu(out) 30 | out = self.conv2(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 41 | super(BasicBlock, self).__init__() 42 | planes = expansion * growthRate 43 | self.bn1 = nn.BatchNorm2d(inplanes) 44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 45 | padding=1, bias=False) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.dropRate = dropRate 48 | 49 | def forward(self, x): 50 | out = self.bn1(x) 51 | out = self.relu(out) 52 | out = self.conv1(out) 53 | if self.dropRate > 0: 54 | out = F.dropout(out, p=self.dropRate, training=self.training) 55 | 56 | out = torch.cat((x, out), 1) 57 | 58 | return out 59 | 60 | 61 | class Transition(nn.Module): 62 | def 
__init__(self, inplanes, outplanes): 63 | super(Transition, self).__init__() 64 | self.bn1 = nn.BatchNorm2d(inplanes) 65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 66 | bias=False) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, x): 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | out = F.avg_pool2d(out, 2) 74 | return out 75 | 76 | 77 | class DenseNet(nn.Module): 78 | 79 | def __init__(self, depth=22, block=Bottleneck, 80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2, **kwargs): 81 | super(DenseNet, self).__init__() 82 | 83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 84 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 85 | 86 | self.growthRate = growthRate 87 | self.dropRate = dropRate 88 | 89 | # self.inplanes is a global variable used across multiple 90 | # helper functions 91 | self.inplanes = growthRate * 2 92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 93 | bias=False) 94 | self.dense1 = self._make_denseblock(block, n) 95 | self.trans1 = self._make_transition(compressionRate) 96 | self.dense2 = self._make_denseblock(block, n) 97 | self.trans2 = self._make_transition(compressionRate) 98 | self.dense3 = self._make_denseblock(block, n) 99 | self.bn = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(self.inplanes, num_classes) 103 | 104 | # Weight initialization 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_denseblock(self, block, blocks): 114 | layers = [] 115 | for i in range(blocks): 116 | # Currently we fix the expansion ratio as the default value 117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 118 | self.inplanes += self.growthRate 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def _make_transition(self, compressionRate): 123 | inplanes = self.inplanes 124 | outplanes = int(math.floor(self.inplanes // compressionRate)) 125 | self.inplanes = outplanes 126 | return Transition(inplanes, outplanes) 127 | 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | 132 | x = self.trans1(self.dense1(x)) 133 | x = self.trans2(self.dense2(x)) 134 | x = self.dense3(x) 135 | x = self.bn(x) 136 | x = self.relu(x) 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | 145 | def densenet(**kwargs): 146 | """ 147 | Constructs a DenseNet model; all keyword arguments are forwarded to the DenseNet constructor. 148 | """ 149 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /models/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | __all__ = ['preresnet'] 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(inplanes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.bn1(x) 39 | out = self.relu(out) 40 | out = self.conv1(out) 41 | 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | out = self.conv2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.bn1 = nn.BatchNorm2d(inplanes) 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.bn1(x) 74 | out = self.relu(out) 75 | 
out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | 90 | return out 91 | 92 | 93 | class PreResNet(nn.Module): 94 | 95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 96 | super(PreResNet, self).__init__() 97 | # Model type specifies number of layers for CIFAR-10 model 98 | if block_name.lower() == 'basicblock': 99 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 100 | n = (depth - 2) // 6 101 | block = BasicBlock 102 | elif block_name.lower() == 'bottleneck': 103 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 104 | n = (depth - 2) // 9 105 | block = Bottleneck 106 | else: 107 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 108 | 109 | self.inplanes = 16 110 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 111 | bias=False) 112 | self.layer1 = self._make_layer(block, 16, n) 113 | self.layer2 = self._make_layer(block, 32, n, stride=2) 114 | self.layer3 = self._make_layer(block, 64, n, stride=2) 115 | self.bn = nn.BatchNorm2d(64 * block.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(8) 118 | self.fc = nn.Linear(64 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.Conv2d(self.inplanes, planes * block.expansion, 133 | kernel_size=1, stride=stride, bias=False), 134 | ) 135 | 136 | layers = [] 137 | layers.append(block(self.inplanes, planes, stride, downsample)) 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append(block(self.inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | x = self.conv1(x) 146 | 147 | x = self.layer1(x) # 32x32 148 | x = self.layer2(x) # 16x16 149 | x = self.layer3(x) # 8x8 150 | x = self.bn(x) 151 | x = self.relu(x) 152 | 153 | x = self.avgpool(x) 154 | x = x.view(x.size(0), -1) 155 | x = self.fc(x) 156 | 157 | return x 158 | 159 | 160 | def preresnet(**kwargs): 161 | """ 162 | Constructs a ResNet model. 163 | """ 164 | return PreResNet(**kwargs) 165 | -------------------------------------------------------------------------------- /models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
__all__ = ['resnet']


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1).
        stride: stride of the first convolution.
        downsample: optional module applied to the input before it is added
            to the block output; ``None`` keeps the identity shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the raw input unless a projection was supplied.
        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block made of 1x1 -> 3x3 -> 1x1 convolutions.

    The block outputs ``planes * expansion`` channels.

    Args:
        inplanes: number of input channels.
        planes: width of the inner 3x3 convolution.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before it is added
            to the block output; ``None`` keeps the identity shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Use the class attribute rather than a literal so that the output
        # width always agrees with what ResNet._make_layer expects.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/BN stages and add the (projected) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)


class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem followed by three residual stages.

    Args:
        depth: total depth; must be ``6n + 2`` for ``BasicBlock`` and
            ``9n + 2`` for ``Bottleneck``.
        num_classes: number of classifier outputs.
        block_name: ``'BasicBlock'`` or ``'Bottleneck'`` (case-insensitive).
        **kwargs: accepted and ignored.

    Raises:
        AssertionError: if ``depth`` does not fit the selected block type.
        ValueError: if ``block_name`` is not recognised.
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock', **kwargs):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        kind = block_name.lower()
        if kind == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            blocks_per_stage = (depth - 2) // 6
            block = BasicBlock
        elif kind == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            blocks_per_stage = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be BasicBlock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        # Global average pooling: identical to AvgPool2d(8) on the 8x8 maps
        # produced by 32x32 inputs, but also valid for other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialisation based on fan-out.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape.

        Updates ``self.inplanes`` to the channel count produced by the stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Projection shortcut so the residual matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.relu(self.bn1(self.conv1(x)))  # 32x32 for CIFAR inputs

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def resnet(**kwargs):
    """
    Constructs a ResNet model; all keyword arguments go to ``ResNet``.
    """
    return ResNet(**kwargs)
__all__ = ['resnext']


class ResNeXtBottleneck(nn.Module):
    """
    RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """

    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
        """ Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super(ResNeXtBottleneck, self).__init__()
        D = cardinality * out_channels // widen_factor
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)

        # An empty Sequential acts as the identity shortcut.  A projection is
        # needed whenever the residual changes shape, i.e. when the channel
        # count OR the spatial resolution (stride) changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
            self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Return ``relu(shortcut(x) + bottleneck(x))``.

        Sub-modules are invoked through ``module(x)`` instead of
        ``module.forward(x)`` so that hooks registered on them (for example
        forward pre-hooks) are executed.
        """
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)


class CifarResNeXt(nn.Module):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            widen_factor: factor to adjust the channel dimensionality
            dropRate: accepted for interface compatibility; not used.
        """
        super(CifarResNeXt, self).__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.widen_factor = widen_factor
        self.num_classes = num_classes
        self.output_size = 64
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
        # The classifier width follows the last stage (1024 for the default
        # widen_factor=4) instead of being hard-coded.
        self.classifier = nn.Linear(self.stages[3], num_classes)
        init.kaiming_normal_(self.classifier.weight)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.constant_(m.bias, 0)

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        block = nn.Sequential()
        for index in range(self.block_depth):
            name_ = '%s_bottleneck_%d' % (name, index)
            if index == 0:
                # Only the first bottleneck changes channels / resolution.
                layer = ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
                                          self.widen_factor)
            else:
                layer = ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality,
                                          self.widen_factor)
            block.add_module(name_, layer)
        return block

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        # Global average pooling; equals avg_pool2d(x, 8, 1) on 8x8 maps.
        x = F.adaptive_avg_pool2d(x, 1)
        # Keep the batch dimension explicit so a width mismatch cannot be
        # silently folded into the batch axis.
        x = x.view(x.size(0), -1)
        return self.classifier(x)


def resnext(**kwargs):
    """Constructs a ResNeXt; all keyword arguments go to ``CifarResNeXt``.
    """
    model = CifarResNeXt(**kwargs)
    return model
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}


class VGG(nn.Module):
    """VGG feature extractor followed by a single linear classifier.

    Args:
        features: module producing the convolutional features; its flattened
            output must have 512 values per sample.
        num_classes: number of classifier outputs.
        **kwargs: accepted and ignored.
    """

    def __init__(self, features, num_classes=1000, **kwargs):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """He-initialise convolutions, reset batch norms, small-normal the linear layer."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                kernel_h, kernel_w = layer.kernel_size
                fan_out = kernel_h * kernel_w * layer.out_channels
                layer.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                layer.weight.data.normal_(0, 0.01)
                layer.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    """Translate a layer specification into an ``nn.Sequential``.

    Args:
        cfg: sequence whose items are either ``'M'`` (2x2 max pooling with
            stride 2) or an int (output channels of a 3x3 convolution).
        batch_norm: insert ``BatchNorm2d`` between each convolution and its ReLU.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg11(**kwargs):
    """VGG 11-layer model (configuration "A"); kwargs are passed to ``VGG``."""
    return VGG(make_layers(cfg['A']), **kwargs)


def vgg11_bn(**kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization."""
    return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)


def vgg13(**kwargs):
    """VGG 13-layer model (configuration "B"); kwargs are passed to ``VGG``."""
    return VGG(make_layers(cfg['B']), **kwargs)


def vgg13_bn(**kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization."""
    return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)


def vgg16(**kwargs):
    """VGG 16-layer model (configuration "D"); kwargs are passed to ``VGG``."""
    return VGG(make_layers(cfg['D']), **kwargs)


def vgg16_bn(**kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization."""
    return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)


def vgg19(**kwargs):
    """VGG 19-layer model (configuration "E"); kwargs are passed to ``VGG``."""
    return VGG(make_layers(cfg['E']), **kwargs)


def vgg19_bn(**kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization."""
    return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
class BasicBlock(nn.Module):
    """Pre-activation wide residual block (BN-ReLU-conv, twice).

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first convolution (and of the shortcut).
        dropRate: dropout probability applied before the second convolution.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut + conv2(dropout(relu(bn(conv1(relu(bn(x)))))))``.

        With equal channel counts the shortcut is the raw input; otherwise it
        is the 1x1 projection of the pre-activated input.
        """
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(activated)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` blocks; only the first changes channels/stride."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the stage as an ``nn.Sequential`` of ``block`` instances."""
        layers = []
        for i in range(nb_layers):
            first = (i == 0)
            layers.append(block(in_planes if first else out_planes,
                                out_planes,
                                stride if first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network with three stages.

    Args:
        depth: total depth; must be ``6n + 4``.
        num_classes: number of classifier outputs.
        widen_factor: channel multiplier for the three stages.
        dropRate: dropout probability inside each block.
        **kwargs: accepted and ignored.

    Raises:
        AssertionError: if ``depth`` is not of the form ``6n + 4``.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, **kwargs):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        blocks_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(blocks_per_stage, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(blocks_per_stage, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(blocks_per_stage, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialisation based on fan-out.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Global average pooling; equals avg_pool2d(out, 8) on 8x8 maps.
        out = F.adaptive_avg_pool2d(out, 1)
        # Flatten per sample.  The previous ``view(-1, nChannels)`` silently
        # moved spatial positions into the batch axis for larger inputs.
        out = out.view(out.size(0), -1)
        return self.fc(out)


def wrn(**kwargs):
    """
    Constructs a Wide Residual Networks; kwargs are passed to ``WideResNet``.
    """
    model = WideResNet(**kwargs)
    return model
__all__ = ['resnext50', 'resnext101', 'resnext152']


class Bottleneck(nn.Module):
    """
    RexNeXt bottleneck type C
    """
    expansion = 4

    def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
            downsample: optional module applied to the input before it is
                added to the block output; ``None`` keeps the identity.
        """
        super(Bottleneck, self).__init__()

        # Explicit float division keeps the width correct even without
        # ``from __future__ import division``.
        D = int(math.floor(planes * (baseWidth / 64.0)))
        C = cardinality
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(D*C)
        self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = nn.BatchNorm2d(D*C)
        self.conv3 = nn.Conv2d(D*C, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        """Apply the three conv/BN stages and add the (projected) input."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)


class ResNeXt(nn.Module):
    """
    ResNext optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, baseWidth, cardinality, layers, num_classes):
        """ Constructor
        Args:
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
            layers: config of layers, e.g., [3, 4, 6, 3]
            num_classes: number of classes
        """
        super(ResNeXt, self).__init__()
        block = Bottleneck

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.num_classes = num_classes
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], 2)
        self.layer3 = self._make_layer(block, 256, layers[2], 2)
        self.layer4 = self._make_layer(block, 512, layers[3], 2)
        # Global average pooling: identical to AvgPool2d(7) on the 7x7 maps
        # produced by 224x224 inputs, but also valid for other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He initialisation based on fan-out.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Projection shortcut so the residual matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def resnext50(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-50 (stage depths [3, 4, 6, 3]).
    """
    return ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], num_classes)


def resnext101(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-101 (stage depths [3, 4, 23, 3]).
    """
    return ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], num_classes)


def resnext152(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-152 (stage depths [3, 8, 36, 3]).
    """
    return ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], num_classes)
class EKFACOptimizer(optim.Optimizer):
    """Eigenvalue-corrected K-FAC (EKFAC) preconditioning on top of momentum SGD.

    Forward pre-hooks and backward hooks are registered on every ``Linear``
    and ``Conv2d`` module of ``model``.  They accumulate running statistics
    of the layer inputs and of the gradients w.r.t. the layer outputs.
    ``step`` rewrites each layer's gradient with its preconditioned version
    and then applies an SGD-with-momentum update.

    Args:
        model: the network to optimise; all of its parameters are optimised.
        lr: learning rate.
        momentum: momentum factor of the final SGD update.
        stat_decay: decay passed to ``update_running_stat`` for all running
            statistics.
        damping: value added to the scalings before dividing.
        kl_clip: bound used to rescale the preconditioned update.
        weight_decay: L2 coefficient; only applied once
            ``steps >= 20 * TCov``.
        TCov: the covariance factors are updated every ``TCov`` steps.
        TScal: the scalings are re-estimated every ``TScal`` steps.
        TInv: the eigendecompositions are recomputed every ``TInv`` steps.
        batch_averaged: forwarded to the gradient-covariance helper; when
            true the per-example gradients are also rescaled by the batch
            size in ``_update_scale``.

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative.

    Attributes:
        acc_stats: when true, the backward hook accumulates gradient
            statistics.  Callers may toggle it around specific backward
            passes; it defaults to ``True``.
    """

    def __init__(self,
                 model,
                 lr=0.001,
                 momentum=0.9,
                 stat_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 weight_decay=0,
                 TCov=10,
                 TScal=10,
                 TInv=100,
                 batch_averaged=True):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, momentum=momentum, damping=damping,
                        weight_decay=weight_decay)
        # TODO (CW): EKFAC optimizer now only support model as input
        super(EKFACOptimizer, self).__init__(model.parameters(), defaults)
        self.CovAHandler = ComputeCovA()
        self.CovGHandler = ComputeCovG()
        self.MatGradHandler = ComputeMatGrad()
        self.batch_averaged = batch_averaged

        self.known_modules = {'Linear', 'Conv2d'}

        self.modules = []
        self.grad_outputs = {}

        # The backward hook reads ``acc_stats``; give it a value up front so
        # a plain ``loss.backward()`` does not fail with AttributeError.
        self.acc_stats = True

        self.model = model
        self._prepare_model()

        self.steps = 0

        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}
        self.S_l = {}
        self.A, self.DS = {}, {}
        self.stat_decay = stat_decay

        self.kl_clip = kl_clip
        self.TCov = TCov
        self.TScal = TScal
        self.TInv = TInv

    def _save_input(self, module, input):
        """Forward pre-hook: update the input covariance and stash the raw input."""
        if torch.is_grad_enabled() and self.steps % self.TCov == 0:
            aa = self.CovAHandler(input[0].data, module)
            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
            update_running_stat(aa, self.m_aa[module], self.stat_decay)
        if torch.is_grad_enabled() and self.steps % self.TScal == 0 and self.steps > 0:
            # Kept for the scaling re-estimation done in ``_update_scale``.
            self.A[module] = input[0].data

    def _save_grad_output(self, module, grad_input, grad_output):
        """Backward hook: update the gradient covariance and stash the raw gradient."""
        # Accumulate statistics for Fisher matrices
        if self.acc_stats and self.steps % self.TCov == 0:
            gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
            # Initialize buffers
            if self.steps == 0:
                self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
            update_running_stat(gg, self.m_gg[module], self.stat_decay)

        if self.acc_stats and self.steps % self.TScal == 0 and self.steps > 0:
            self.DS[module] = grad_output[0].data

    def _prepare_model(self):
        """Register the statistic hooks on every supported layer of the model."""
        count = 0
        print(self.model)
        print("=> We keep following layers in EKFAC. ")
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_backward_hook(self._save_grad_output)
                print('(%s): %s' % (count, module))
                count += 1

    @staticmethod
    def _symmetric_eig(mat):
        """Return ``(eigenvalues, eigenvectors)`` of a symmetric matrix.

        ``torch.symeig`` is not available in current PyTorch releases, so
        ``torch.linalg.eigh`` is preferred and ``torch.symeig`` is only used
        when the former does not exist.
        """
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            return torch.linalg.eigh(mat)
        return torch.symeig(mat, eigenvectors=True)

    def _update_inv(self, m):
        """Do eigen decomposition for computing inverse of the ~ fisher.
        :param m: The layer
        :return: no returns.
        """
        eps = 1e-10  # for numerical stability
        self.d_a[m], self.Q_a[m] = self._symmetric_eig(self.m_aa[m])
        self.d_g[m], self.Q_g[m] = self._symmetric_eig(self.m_gg[m])

        # Zero out eigenvalues that are not safely positive.
        self.d_a[m].mul_((self.d_a[m] > eps).type_as(self.d_a[m]))
        self.d_g[m].mul_((self.d_g[m] > eps).type_as(self.d_g[m]))
        # Reset the scalings to the Kronecker-factored eigenvalues.
        self.S_l[m] = self.d_g[m].unsqueeze(1) @ self.d_a[m].unsqueeze(0)

    @staticmethod
    def _get_matrix_form_grad(m, classname):
        """
        :param m: the layer
        :param classname: the class name of the layer
        :return: a matrix form of the gradient. it should be a [output_dim, input_dim] matrix.
        """
        if classname == 'Conv2d':
            p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1)  # n_filters * (in_c * kw * kh)
        else:
            p_grad_mat = m.weight.grad.data
        if m.bias is not None:
            # The bias gradient becomes the last column.
            p_grad_mat = torch.cat([p_grad_mat, m.bias.grad.data.view(-1, 1)], 1)
        return p_grad_mat

    def _get_natural_grad(self, m, p_grad_mat, damping):
        """
        :param m: the layer
        :param p_grad_mat: the gradients in matrix form
        :param damping: value added to the scalings before dividing
        :return: a list of gradients w.r.t to the parameters in `m`
        """
        # p_grad_mat is of output_dim * input_dim.  Project into the
        # eigenbasis, rescale element-wise, and project back.
        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (self.S_l[m] + damping)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        if m.bias is not None:
            # we always put gradient w.r.t weight in [0]
            # and w.r.t bias in [1]
            v = [v[:, :-1], v[:, -1:]]
            v[0] = v[0].view(m.weight.grad.data.size())
            v[1] = v[1].view(m.bias.grad.data.size())
        else:
            v = [v.view(m.weight.grad.data.size())]

        return v

    def _kl_clip_and_update_grad(self, updates, lr):
        """Rescale the preconditioned updates and write them into ``.grad``.

        :param updates: dict mapping each layer to its list of updates
        :param lr: current learning rate
        """
        # do kl clip
        vg_sum = 0
        for m in self.modules:
            v = updates[m]
            vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item()
            if m.bias is not None:
                vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item()
        # With all-zero gradients (or round-off producing a non-positive
        # sum) there is nothing to clip; avoid dividing by zero / sqrt of a
        # negative number.
        if vg_sum > 0:
            nu = min(1.0, math.sqrt(self.kl_clip / vg_sum))
        else:
            nu = 1.0

        for m in self.modules:
            v = updates[m]
            m.weight.grad.data.copy_(v[0])
            m.weight.grad.data.mul_(nu)
            if m.bias is not None:
                m.bias.grad.data.copy_(v[1])
                m.bias.grad.data.mul_(nu)

    def _step(self, closure):
        """Apply SGD with momentum (no nesterov, no dampening) to all parameters.

        ``closure`` is accepted for signature compatibility and is not used.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Weight decay only starts after 20 covariance updates.
                if weight_decay != 0 and self.steps >= 20 * self.TCov:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        param_state['momentum_buffer'] = torch.zeros_like(p.data)
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p)
                    d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

    def _update_scale(self, m):
        """Re-estimate the scalings of layer ``m`` from per-example gradients."""
        with torch.no_grad():
            A, S = self.A[m], self.DS[m]
            grad_mat = self.MatGradHandler(A, S, m)  # batch_size * out_dim * in_dim
            if self.batch_averaged:
                grad_mat *= S.size(0)

            # Second moment of the gradient in the eigenbasis.  The
            # projection must match ``_get_natural_grad`` (Q_g^T . Q_a);
            # the columns of Q_* are the eigenvectors.
            # NOTE: this consumes a lot of memory (one matrix per example).
            s_l = (self.Q_g[m].t() @ grad_mat @ self.Q_a[m]) ** 2
            s_l = s_l.mean(dim=0)
            if m not in self.S_l:
                self.S_l[m] = torch.ones_like(s_l)
            update_running_stat(s_l, self.S_l[m], self.stat_decay)
            # remove reference for reducing memory cost.
            self.A[m] = None
            self.DS[m] = None

    def step(self, closure=None):
        """Precondition the current gradients and take one optimisation step.

        Only the first parameter group supplies ``lr`` and ``damping``
        (temporary fix for compatibility with the official LR schedulers).
        ``closure`` is not evaluated.
        """
        group = self.param_groups[0]
        lr = group['lr']
        damping = group['damping']
        updates = {}
        for m in self.modules:
            classname = m.__class__.__name__
            if self.steps % self.TInv == 0:
                self._update_inv(m)

            if self.steps % self.TScal == 0 and self.steps > 0:
                self._update_scale(m)

            p_grad_mat = self._get_matrix_form_grad(m, classname)
            v = self._get_natural_grad(m, p_grad_mat, damping)
            updates[m] = v
        self._kl_clip_and_update_grad(updates, lr)

        self._step(closure)
        self.steps += 1
self._prepare_model() 43 | 44 | self.steps = 0 45 | 46 | self.m_aa, self.m_gg = {}, {} 47 | self.Q_a, self.Q_g = {}, {} 48 | self.d_a, self.d_g = {}, {} 49 | self.stat_decay = stat_decay 50 | 51 | self.kl_clip = kl_clip 52 | self.TCov = TCov 53 | self.TInv = TInv 54 | 55 | def _save_input(self, module, input): 56 | if torch.is_grad_enabled() and self.steps % self.TCov == 0: 57 | aa = self.CovAHandler(input[0].data, module) 58 | # Initialize buffers 59 | if self.steps == 0: 60 | self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1)) 61 | update_running_stat(aa, self.m_aa[module], self.stat_decay) 62 | 63 | def _save_grad_output(self, module, grad_input, grad_output): 64 | # Accumulate statistics for Fisher matrices 65 | if self.acc_stats and self.steps % self.TCov == 0: 66 | gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged) 67 | # Initialize buffers 68 | if self.steps == 0: 69 | self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1)) 70 | update_running_stat(gg, self.m_gg[module], self.stat_decay) 71 | 72 | def _prepare_model(self): 73 | count = 0 74 | print(self.model) 75 | print("=> We keep following layers in KFAC. ") 76 | for module in self.model.modules(): 77 | classname = module.__class__.__name__ 78 | # print('=> We keep following layers in KFAC. <=') 79 | if classname in self.known_modules: 80 | self.modules.append(module) 81 | module.register_forward_pre_hook(self._save_input) 82 | module.register_backward_hook(self._save_grad_output) 83 | print('(%s): %s' % (count, module)) 84 | count += 1 85 | 86 | def _update_inv(self, m): 87 | """Do eigen decomposition for computing inverse of the ~ fisher. 88 | :param m: The layer 89 | :return: no returns. 
90 | """ 91 | eps = 1e-10 # for numerical stability 92 | self.d_a[m], self.Q_a[m] = torch.symeig( 93 | self.m_aa[m], eigenvectors=True) 94 | self.d_g[m], self.Q_g[m] = torch.symeig( 95 | self.m_gg[m], eigenvectors=True) 96 | 97 | self.d_a[m].mul_((self.d_a[m] > eps).float()) 98 | self.d_g[m].mul_((self.d_g[m] > eps).float()) 99 | 100 | @staticmethod 101 | def _get_matrix_form_grad(m, classname): 102 | """ 103 | :param m: the layer 104 | :param classname: the class name of the layer 105 | :return: a matrix form of the gradient. it should be a [output_dim, input_dim] matrix. 106 | """ 107 | if classname == 'Conv2d': 108 | p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1) # n_filters * (in_c * kw * kh) 109 | else: 110 | p_grad_mat = m.weight.grad.data 111 | if m.bias is not None: 112 | p_grad_mat = torch.cat([p_grad_mat, m.bias.grad.data.view(-1, 1)], 1) 113 | return p_grad_mat 114 | 115 | def _get_natural_grad(self, m, p_grad_mat, damping): 116 | """ 117 | :param m: the layer 118 | :param p_grad_mat: the gradients in matrix form 119 | :return: a list of gradients w.r.t to the parameters in `m` 120 | """ 121 | # p_grad_mat is of output_dim * input_dim 122 | # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T] 123 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m] 124 | v2 = v1 / (self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + damping) 125 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t() 126 | if m.bias is not None: 127 | # we always put gradient w.r.t weight in [0] 128 | # and w.r.t bias in [1] 129 | v = [v[:, :-1], v[:, -1:]] 130 | v[0] = v[0].view(m.weight.grad.data.size()) 131 | v[1] = v[1].view(m.bias.grad.data.size()) 132 | else: 133 | v = [v.view(m.weight.grad.data.size())] 134 | 135 | return v 136 | 137 | def _kl_clip_and_update_grad(self, updates, lr): 138 | # do kl clip 139 | vg_sum = 0 140 | for m in self.modules: 141 | v = updates[m] 142 | vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item() 
143 | if m.bias is not None: 144 | vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item() 145 | nu = min(1.0, math.sqrt(self.kl_clip / vg_sum)) 146 | 147 | for m in self.modules: 148 | v = updates[m] 149 | m.weight.grad.data.copy_(v[0]) 150 | m.weight.grad.data.mul_(nu) 151 | if m.bias is not None: 152 | m.bias.grad.data.copy_(v[1]) 153 | m.bias.grad.data.mul_(nu) 154 | 155 | def _step(self, closure): 156 | # FIXME (CW): Modified based on SGD (removed nestrov and dampening in momentum.) 157 | # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p) 158 | for group in self.param_groups: 159 | weight_decay = group['weight_decay'] 160 | momentum = group['momentum'] 161 | 162 | for p in group['params']: 163 | if p.grad is None: 164 | continue 165 | d_p = p.grad.data 166 | if weight_decay != 0 and self.steps >= 20 * self.TCov: 167 | d_p.add_(weight_decay, p.data) 168 | if momentum != 0: 169 | param_state = self.state[p] 170 | if 'momentum_buffer' not in param_state: 171 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 172 | buf.mul_(momentum).add_(d_p) 173 | else: 174 | buf = param_state['momentum_buffer'] 175 | buf.mul_(momentum).add_(1, d_p) 176 | d_p = buf 177 | 178 | p.data.add_(-group['lr'], d_p) 179 | 180 | def step(self, closure=None): 181 | # FIXME(CW): temporal fix for compatibility with Official LR scheduler. 
182 | group = self.param_groups[0] 183 | lr = group['lr'] 184 | damping = group['damping'] 185 | updates = {} 186 | for m in self.modules: 187 | classname = m.__class__.__name__ 188 | if self.steps % self.TInv == 0: 189 | self._update_inv(m) 190 | p_grad_mat = self._get_matrix_form_grad(m, classname) 191 | v = self._get_natural_grad(m, p_grad_mat, damping) 192 | updates[m] = v 193 | self._kl_clip_and_update_grad(updates, lr) 194 | 195 | self._step(closure) 196 | self.steps += 1 197 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--network', type=str, default='vgg16_bn') 7 | parser.add_argument('--dataset', type=str, default='cifar10') 8 | parser.add_argument('--optimizer', type=str, default='kfac') 9 | parser.add_argument('--machine', type=int, default=10) 10 | 11 | args = parser.parse_args() 12 | 13 | vgg16_bn = '' 14 | vgg19_bn = '' 15 | resnet = '--depth 110' 16 | wrn = '--depth 28 --widen_factor 10 --dropRate 0.3' 17 | densenet = '--depth 100 --growthRate 12' 18 | 19 | apps = { 20 | 'vgg16_bn': vgg16_bn, 21 | 'vgg19_bn': vgg19_bn, 22 | 'resnet': resnet, 23 | 'wrn': wrn, 24 | 'densenet': densenet 25 | } 26 | 27 | 28 | def grid_search(args): 29 | scripts = [] 30 | if args.optimizer in ['kfac', 'ekfac']: 31 | template = 'python main.py ' \ 32 | '--dataset %s ' \ 33 | '--optimizer %s ' \ 34 | '--network %s ' \ 35 | ' --epoch 100 ' \ 36 | '--milestone 40,80 ' \ 37 | '--learning_rate %f ' \ 38 | '--damping %f ' \ 39 | '--weight_decay %f %s' 40 | 41 | lrs = [3e-2, 1e-2, 3e-3] 42 | dampings = [3e-2, 1e-3, 3e-3] 43 | wds = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4] 44 | app = apps[args.network] 45 | for lr in lrs: 46 | for dmp in dampings: 47 | for wd in wds: 48 | scripts.append(template % (args.dataset, args.optimizer, args.network, lr, dmp, wd, 
def gen_script(scripts, machine, args):
    """Write one backgrounded ``srun`` launch line per command to a shell script.

    The script is created in the current working directory and is named after
    ``args.dataset``, ``args.optimizer`` and ``args.network``.

    :param scripts: iterable of command strings to wrap
    :param machine: integer suffix of the ``guppy`` node passed to ``-w``
    :param args: namespace providing ``dataset``, ``optimizer`` and ``network``
    """
    target = 'run_%s_%s_%s.sh' % (args.dataset, args.optimizer, args.network)
    launch_line = 'srun --gres=gpu:1 -c 6 -w guppy%d --mem=16G -p gpu \"%s\" &\n'
    with open(target, 'w') as handle:
        # Lines are formatted lazily so the file is opened before any
        # formatting happens, exactly as in a write-per-command loop.
        handle.writelines(launch_line % (machine, command) for command in scripts)
def get_dataloader(dataset, train_batch_size, test_batch_size, num_workers=2, root='../data'):
    """Build the train and test ``DataLoader`` pair for a CIFAR dataset.

    :param dataset: ``'cifar10'`` or ``'cifar100'``
    :param train_batch_size: batch size of the (shuffled) training loader
    :param test_batch_size: batch size of the (unshuffled) test loader
    :param num_workers: worker processes used by both loaders
    :param root: directory the dataset is stored in (downloaded if missing)
    :return: ``(trainloader, testloader)``
    """
    train_tf, test_tf = get_transforms(dataset)

    # Resolve the torchvision dataset class for the requested name.
    dataset_cls = None
    if dataset == 'cifar10':
        dataset_cls = torchvision.datasets.CIFAR10
    if dataset == 'cifar100':
        dataset_cls = torchvision.datasets.CIFAR100

    trainset = testset = None
    if dataset_cls is not None:
        trainset = dataset_cls(root=root, train=True, download=True, transform=train_tf)
        testset = dataset_cls(root=root, train=False, download=True, transform=test_tf)

    assert trainset is not None and testset is not None, 'Error, no dataset %s' % dataset

    def _make_loader(data, batch_size, shuffle):
        # Both loaders differ only in their data, batch size and shuffling.
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers)

    trainloader = _make_loader(trainset, train_batch_size, True)
    testloader = _make_loader(testset, test_batch_size, False)
    return trainloader, testloader
def update_running_stat(aa, m_aa, stat_decay):
    """Update the running average ``m_aa`` in place with the new statistic ``aa``.

    Computes ``m_aa <- stat_decay * m_aa + (1 - stat_decay) * aa`` without
    allocating a temporary the size of ``m_aa``.

    The update is applied directly instead of scaling by
    ``stat_decay / (1 - stat_decay)`` first: that ratio divides by zero when
    ``stat_decay == 1`` and amplifies rounding error as the decay approaches 1.

    :param aa: freshly computed statistic (left unmodified)
    :param m_aa: running average, modified in place
    :param stat_decay: weight kept from the previous running average
    :return: None
    """
    # using inplace operation to save memory!
    m_aa.mul_(stat_decay).add_(aa, alpha=1 - stat_decay)
class ComputeCovA:
    """Compute the activation (input) covariance factor of a layer.

    Supports ``nn.Linear`` and ``nn.Conv2d``; any other layer yields ``None``.
    When the layer has a bias, a column of ones is appended to the
    activations so the factor has one extra row and column.
    """

    @classmethod
    def compute_cov_a(cls, a, layer):
        """Alias of calling the handler: ``ComputeCovA()(a, layer)``."""
        return cls.__call__(a, layer)

    @classmethod
    def __call__(cls, a, layer):
        """Dispatch on the layer type.

        :param a: input activations of ``layer``
        :param layer: the layer that receives ``a``
        :return: the covariance matrix, or ``None`` for unsupported layers
        """
        if isinstance(layer, nn.Linear):
            return cls.linear(a, layer)
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(a, layer)
        # FIXME(CW): for extension to other layers.
        return None

    @staticmethod
    def conv2d(a, layer):
        """Covariance of the receptive-field patches of a ``Conv2d`` input.

        Patches are extracted with ``F.unfold`` using the layer's kernel size,
        stride, padding and dilation, so dilated convolutions see the same
        receptive fields as the layer itself (the dilation used to be ignored).

        :param a: input of shape (batch_size, in_c, h, w)
        :param layer: the ``nn.Conv2d`` receiving ``a``
        :return: square matrix of size in_c*kh*kw (+1 with bias)
        """
        batch_size = a.size(0)
        # (batch_size, in_c*kh*kw, out_h*out_w)
        columns = F.unfold(a, layer.kernel_size, dilation=layer.dilation,
                           padding=layer.padding, stride=layer.stride)
        spatial_size = columns.size(2)
        # One row per (sample, output position).
        a = columns.transpose(1, 2).reshape(-1, columns.size(1))
        if layer.bias is not None:
            a = torch.cat([a, a.new_ones(a.size(0), 1)], 1)
        a = a / spatial_size
        # FIXME(CW): do we need to divide the output feature map's size?
        return a.t() @ (a / batch_size)

    @staticmethod
    def linear(a, layer):
        """Covariance of the inputs of a ``Linear`` layer.

        :param a: input of shape (batch_size, in_dim)
        :param layer: the ``nn.Linear`` receiving ``a``
        :return: square matrix of size in_dim (+1 with bias)
        """
        # a: batch_size * in_dim
        batch_size = a.size(0)
        if layer.bias is not None:
            a = torch.cat([a, a.new_ones(a.size(0), 1)], 1)
        return a.t() @ (a / batch_size)
def get_network(network, **kwargs):
    """Instantiate one of the CIFAR models by name.

    :param network: one of ``'alexnet'``, ``'densenet'``, ``'resnet'``,
        ``'vgg16_bn'``, ``'vgg19_bn'`` or ``'wrn'``
    :param kwargs: forwarded unchanged to the selected constructor
    :return: whatever the selected constructor returns
    :raises KeyError: if ``network`` is not a known name
    """
    registry = dict(
        alexnet=alexnet,
        densenet=densenet,
        resnet=resnet,
        vgg16_bn=vgg16_bn,
        vgg19_bn=vgg19_bn,
        wrn=wrn,
    )
    constructor = registry[network]
    return constructor(**kwargs)