├── .gitignore
├── README.md
├── main.py
├── models
│   ├── __init__.py
│   ├── cifar
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── densenet.py
│   │   ├── preresnet.py
│   │   ├── resnet.py
│   │   ├── resnext.py
│   │   ├── vgg.py
│   │   └── wrn.py
│   └── imagenet
│       ├── __init__.py
│       └── resnext.py
├── optimizers
│   ├── __init__.py
│   ├── ekfac.py
│   └── kfac.py
├── trainer.py
└── utils
    ├── data_utils.py
    ├── kfac_utils.py
    └── network_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | # data
6 | data.cifar10/
7 | data.cifar100/
8 | *.gz
9 | shells/
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | checkpoint/
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | #*.log
60 | local_settings.py
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # dotenv
88 | .env
89 | *.tar
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
109 | tmp
110 | runs
111 | run
112 |
113 | # PyCharm
114 | .idea/
115 |
116 | # macOS metadata
117 | .DS_Store
118 | ._.DS_Store
119 | ._*
120 |
121 | #
122 | data/
123 | log/
124 | summary/
125 | data/kernel_toy/*.pth
126 | data/AS/gp-structure-search
127 | #*.data
128 | data/mnist_data
129 | *.npz
130 | *.txt
131 | #*.png
132 | #*.pdf
133 | *.jpeg
134 | *.jpg
135 | #results/
136 | *.pyc
137 | *__pycache__
138 |
139 | checkpoint/
140 | runs/
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # K-FAC_pytorch
2 | PyTorch implementation of [K-FAC](https://arxiv.org/abs/1503.05671) and [E-KFAC](https://arxiv.org/abs/1806.03884). (Only single-GPU training is supported; multi-GPU training requires modifications.)
3 | ## Requirements
4 | ```
5 | pytorch 0.4.0
6 | torchvision
7 | python 3.6.0
8 | tqdm
9 | tensorboardX
10 | tensorflow
11 | ```
12 | ## How to run
13 | ```
14 | python main.py --dataset cifar10 --optimizer kfac --network vgg16_bn --epoch 100 --milestone 40,80 --learning_rate 0.01 --damping 0.03 --weight_decay 0.003
15 | ```
16 |
17 |
18 | ## Performance
19 | #### Note: for better K-FAC hyperparameters, please refer to the [weight_decay](https://github.com/gd-zhang/Weight-Decay/tree/master/configs) repo. (The hyperparameters below are not well tuned; in particular, the weight decay is too small.)
20 | For K-FAC and E-KFAC, the search ranges for learning rate, weight decay, and damping are:
21 | (1) learning rate = [3e-2, 1e-2, 3e-3]
22 | (2) weight decay = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
23 | (3) damping = [3e-2, 1e-3, 3e-3]
24 |
25 | For SGD:
26 | (1) learning rate = [3e-1, 1e-1, 3e-2]
27 | (2) weight decay = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
28 |
29 | #### CIFAR10
30 |
31 | | Optimizer | Model | Acc. | learning rate | weight decay | damping |
32 | |---------- | ---------------------------------- | ----------- | ------------- | -------------| ----------- |
33 | | KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 93.86% | 0.01 | 0.003 | 0.03 |
34 | | E-KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 94.00% | 0.003 | 0.01 | 0.03 |
35 | | SGD | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 94.03% | 0.03 | 0.001 | - |
36 | | KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 93.59% | 0.01 | 0.003 | 0.03 |
37 | | E-KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 93.37% | 0.003 | 0.01 | 0.03 |
38 | | SGD | [ResNet110](https://arxiv.org/abs/1512.03385)| 94.14% | 0.03 | 0.001 | - |
39 |
40 |
41 |
42 | #### CIFAR100
43 |
44 | | Optimizer | Model | Acc. | learning rate | weight decay | damping |
45 | |---------- | ---------------------------------- | ----------- | ------------- | -------------| ----------- |
46 | | KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 74.09% | 0.003 | 0.01 | 0.03 |
47 | | E-KFAC | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 73.20% | 0.01 | 0.01 | 0.03 |
48 | | SGD | [VGG16_BN](https://arxiv.org/abs/1409.1556) | 74.56% | 0.03 | 0.003 | - |
49 | | KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.71% | 0.003 | 0.01 | 0.003 |
50 | | E-KFAC | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.32% | 0.03 | 0.001 | 0.03 |
51 | | SGD | [ResNet110](https://arxiv.org/abs/1512.03385)| 72.60% | 0.1 | 0.0003 | - |
52 |
53 | ## Others
54 | Please consider citing the following papers for K-FAC:
55 | ```
56 | @inproceedings{martens2015optimizing,
57 | title={Optimizing neural networks with kronecker-factored approximate curvature},
58 | author={Martens, James and Grosse, Roger},
59 | booktitle={International conference on machine learning},
60 | pages={2408--2417},
61 | year={2015}
62 | }
63 |
64 | @inproceedings{grosse2016kronecker,
65 | title={A kronecker-factored approximate fisher matrix for convolution layers},
66 | author={Grosse, Roger and Martens, James},
67 | booktitle={International Conference on Machine Learning},
68 | pages={573--582},
69 | year={2016}
70 | }
71 | ```
72 |
73 | and for E-KFAC:
74 | ```
75 | @inproceedings{george2018fast,
76 | title={Fast Approximate Natural Gradient Descent in a Kronecker Factored Eigenbasis},
77 | author={George, Thomas and Laurent, C{\'e}sar and Bouthillier, Xavier and Ballas, Nicolas and Vincent, Pascal},
78 | booktitle={Advances in Neural Information Processing Systems},
79 | pages={9550--9560},
80 | year={2018}
81 | }
82 | ```
83 |
84 | If you have any questions or suggestions, please feel free to contact me via alecwangcq at gmail , com!
85 |
--------------------------------------------------------------------------------
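The README tunes a `damping` hyperparameter for K-FAC and E-KFAC, but the optimizer sources (`optimizers/kfac.py`, `optimizers/ekfac.py`) are not included in this dump. As a rough illustration of where damping enters, the NumPy sketch below preconditions the weight gradient of one fully connected layer with the Kronecker-factored Fisher block `A ⊗ G`, where `A = E[a a^T]` is taken over layer inputs and `G = E[g g^T]` over gradients with respect to the pre-activations. It assumes damping is added to the eigenvalues of the Kronecker product, which inverts `A ⊗ G + λI` exactly; the original K-FAC paper instead applies factored Tikhonov damping to each factor separately. The function name `kfac_precondition` is hypothetical and is not taken from this repository.

```python
import numpy as np


def kfac_precondition(grad_w, a, g, damping):
    """Hypothetical sketch: damped K-FAC step for one linear layer.

    grad_w  : (out, in) gradient of the loss w.r.t. the weight matrix
    a       : (batch, in) layer inputs
    g       : (batch, out) gradients w.r.t. the layer's pre-activations
    damping : scalar added to every eigenvalue of A kron G

    Returns (A kron G + damping * I)^{-1} vec(grad_w), reshaped to (out, in).
    """
    batch = a.shape[0]
    # Kronecker factors of the Fisher block for this layer.
    A = a.T @ a / batch  # (in, in)
    G = g.T @ g / batch  # (out, out)
    # Symmetric eigendecompositions: A = Qa diag(da) Qa^T, G = Qg diag(dg) Qg^T.
    da, Qa = np.linalg.eigh(A)
    dg, Qg = np.linalg.eigh(G)
    # Rotate the gradient into the Kronecker eigenbasis ...
    v = Qg.T @ grad_w @ Qa
    # ... divide by the damped eigenvalues of A kron G (entry (i, j) -> dg_i * da_j) ...
    v = v / (np.outer(dg, da) + damping)
    # ... and rotate back to the parameter basis.
    return Qg @ v @ Qa.T
```

Working in the eigenbasis means only the two small factors (`in x in` and `out x out`) are decomposed; the full `(in*out) x (in*out)` Fisher block is never formed.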
/main.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10/CIFAR100 with PyTorch.'''
2 | import argparse
3 | import os
4 | from optimizers import (KFACOptimizer, EKFACOptimizer)
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torch.optim.lr_scheduler import MultiStepLR
9 |
10 | from tqdm import tqdm
11 | from tensorboardX import SummaryWriter
12 | from utils.network_utils import get_network
13 | from utils.data_utils import get_dataloader
14 |
15 |
16 | # fetch args
17 | parser = argparse.ArgumentParser()
18 |
19 |
20 | parser.add_argument('--network', default='vgg16_bn', type=str)
21 | parser.add_argument('--depth', default=19, type=int)
22 | parser.add_argument('--dataset', default='cifar10', type=str)
23 |
24 | # densenet
25 | parser.add_argument('--growthRate', default=12, type=int)
26 | parser.add_argument('--compressionRate', default=2, type=int)
27 |
28 | # wrn, densenet
29 | parser.add_argument('--widen_factor', default=1, type=int)
30 | parser.add_argument('--dropRate', default=0.0, type=float)
31 |
32 |
33 | parser.add_argument('--device', default='cuda', type=str)
34 | parser.add_argument('--resume', '-r', action='store_true')
35 | parser.add_argument('--load_path', default='', type=str)
36 | parser.add_argument('--log_dir', default='runs/pretrain', type=str)
37 |
38 |
39 | parser.add_argument('--optimizer', default='kfac', type=str)
40 | parser.add_argument('--batch_size', default=64, type=int)
41 | parser.add_argument('--epoch', default=100, type=int)
42 | parser.add_argument('--milestone', default=None, type=str)
43 | parser.add_argument('--learning_rate', default=0.01, type=float)
44 | parser.add_argument('--momentum', default=0.9, type=float)
45 | parser.add_argument('--stat_decay', default=0.95, type=float)
46 | parser.add_argument('--damping', default=1e-3, type=float)
47 | parser.add_argument('--kl_clip', default=1e-2, type=float)
48 | parser.add_argument('--weight_decay', default=3e-3, type=float)
49 | parser.add_argument('--TCov', default=10, type=int)
50 | parser.add_argument('--TScal', default=10, type=int)
51 | parser.add_argument('--TInv', default=100, type=int)
52 |
53 |
54 | parser.add_argument('--prefix', default=None, type=str)
55 | args = parser.parse_args()
56 |
57 | # init model
58 | nc = {
59 | 'cifar10': 10,
60 | 'cifar100': 100
61 | }
62 | num_classes = nc[args.dataset]
63 | net = get_network(args.network,
64 | depth=args.depth,
65 | num_classes=num_classes,
66 | growthRate=args.growthRate,
67 | compressionRate=args.compressionRate,
68 | widen_factor=args.widen_factor,
69 | dropRate=args.dropRate)
70 | net = net.to(args.device)
71 |
72 | # init dataloader
73 | trainloader, testloader = get_dataloader(dataset=args.dataset,
74 | train_batch_size=args.batch_size,
75 | test_batch_size=256)
76 |
77 | # init optimizer and lr scheduler
78 | optim_name = args.optimizer.lower()
79 | tag = optim_name
80 | if optim_name == 'sgd':
81 | optimizer = optim.SGD(net.parameters(),
82 | lr=args.learning_rate,
83 | momentum=args.momentum,
84 | weight_decay=args.weight_decay)
85 | elif optim_name == 'kfac':
86 | optimizer = KFACOptimizer(net,
87 | lr=args.learning_rate,
88 | momentum=args.momentum,
89 | stat_decay=args.stat_decay,
90 | damping=args.damping,
91 | kl_clip=args.kl_clip,
92 | weight_decay=args.weight_decay,
93 | TCov=args.TCov,
94 | TInv=args.TInv)
95 | elif optim_name == 'ekfac':
96 | optimizer = EKFACOptimizer(net,
97 | lr=args.learning_rate,
98 | momentum=args.momentum,
99 | stat_decay=args.stat_decay,
100 | damping=args.damping,
101 | kl_clip=args.kl_clip,
102 | weight_decay=args.weight_decay,
103 | TCov=args.TCov,
104 | TScal=args.TScal,
105 | TInv=args.TInv)
106 | else:
107 | raise NotImplementedError
108 |
109 | if args.milestone is None:
110 | lr_scheduler = MultiStepLR(optimizer, milestones=[int(args.epoch*0.5), int(args.epoch*0.75)], gamma=0.1)
111 | else:
112 | milestone = [int(_) for _ in args.milestone.split(',')]
113 | lr_scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=0.1)
114 |
115 | # init criterion
116 | criterion = nn.CrossEntropyLoss()
117 |
118 | start_epoch = 0
119 | best_acc = 0
120 | if args.resume:
121 | print('==> Resuming from checkpoint..')
122 | assert os.path.isfile(args.load_path), 'Error: no checkpoint file found!'
123 | checkpoint = torch.load(args.load_path)
124 | net.load_state_dict(checkpoint['net'])
125 | best_acc = checkpoint['acc']
126 | start_epoch = checkpoint['epoch']
127 | print('==> Loaded checkpoint at epoch: %d, acc: %.2f%%' % (start_epoch, best_acc))
128 |
129 | # init summary writer
130 |
131 | log_dir = os.path.join(args.log_dir, args.dataset, args.network, args.optimizer,
132 | 'lr%.3f_wd%.4f_damping%.4f' %
133 | (args.learning_rate, args.weight_decay, args.damping))
134 | if not os.path.isdir(log_dir):
135 | os.makedirs(log_dir)
136 | writer = SummaryWriter(log_dir)
137 |
138 |
139 | def train(epoch):
140 | print('\nEpoch: %d' % epoch)
141 | net.train()
142 | train_loss = 0
143 | correct = 0
144 | total = 0
145 |
146 | lr_scheduler.step()
147 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
148 | (tag, lr_scheduler.get_lr()[0], 0, 0, correct, total))
149 |
150 | writer.add_scalar('train/lr', lr_scheduler.get_lr()[0], epoch)
151 |
152 | prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
153 | for batch_idx, (inputs, targets) in prog_bar:
154 | inputs, targets = inputs.to(args.device), targets.to(args.device)
155 | optimizer.zero_grad()
156 | outputs = net(inputs)
157 | loss = criterion(outputs, targets)
158 | if optim_name in ['kfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
159 | # compute true fisher
160 | optimizer.acc_stats = True
161 | with torch.no_grad():
162 | sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),
163 | 1).squeeze(1).to(args.device)
164 | loss_sample = criterion(outputs, sampled_y)
165 | loss_sample.backward(retain_graph=True)
166 | optimizer.acc_stats = False
167 | optimizer.zero_grad() # clear the gradient for computing true-fisher.
168 | loss.backward()
169 | optimizer.step()
170 |
171 | train_loss += loss.item()
172 | _, predicted = outputs.max(1)
173 | total += targets.size(0)
174 | correct += predicted.eq(targets).sum().item()
175 |
176 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
177 | (tag, lr_scheduler.get_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
178 | prog_bar.set_description(desc, refresh=True)
179 |
180 | writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
181 | writer.add_scalar('train/acc', 100. * correct / total, epoch)
182 |
183 |
184 | def test(epoch):
185 | global best_acc
186 | net.eval()
187 | test_loss = 0
188 | correct = 0
189 | total = 0
190 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)'
191 | % (tag,lr_scheduler.get_lr()[0], test_loss/(0+1), 0, correct, total))
192 |
193 | prog_bar = tqdm(enumerate(testloader), total=len(testloader), desc=desc, leave=True)
194 | with torch.no_grad():
195 | for batch_idx, (inputs, targets) in prog_bar:
196 | inputs, targets = inputs.to(args.device), targets.to(args.device)
197 | outputs = net(inputs)
198 | loss = criterion(outputs, targets)
199 |
200 | test_loss += loss.item()
201 | _, predicted = outputs.max(1)
202 | total += targets.size(0)
203 | correct += predicted.eq(targets).sum().item()
204 |
205 | desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)'
206 | % (tag, lr_scheduler.get_lr()[0], test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
207 | prog_bar.set_description(desc, refresh=True)
208 |
209 | # Save checkpoint.
210 | acc = 100.*correct/total
211 |
212 | writer.add_scalar('test/loss', test_loss / (batch_idx + 1), epoch)
213 | writer.add_scalar('test/acc', 100. * correct / total, epoch)
214 |
215 | if acc > best_acc:
216 | print('Saving..')
217 | state = {
218 | 'net': net.state_dict(),
219 | 'acc': acc,
220 | 'epoch': epoch,
221 | 'loss': test_loss,
222 | 'args': args
223 | }
224 |
225 | torch.save(state, '%s/%s_%s_%s%s_best.t7' % (log_dir,
226 | args.optimizer,
227 | args.dataset,
228 | args.network,
229 | args.depth))
230 | best_acc = acc
231 |
232 |
233 | def main():
234 | for epoch in range(start_epoch, args.epoch):
235 | train(epoch)
236 | test(epoch)
237 | return best_acc
238 |
239 |
240 | if __name__ == '__main__':
241 | main()
242 |
243 |
244 |
--------------------------------------------------------------------------------
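In `train()`, every `TCov` steps `main.py` computes the K-FAC statistics from a loss on labels drawn with `torch.multinomial` from the softmax of the network's outputs, rather than from the ground-truth targets (the "true Fisher" in the code comment). The NumPy sketch below shows the same sampling step using inverse-CDF sampling, assuming a `numpy.random.Generator` as the randomness source; the function name `sample_fisher_labels` is hypothetical and not part of this repository.

```python
import numpy as np


def sample_fisher_labels(logits, rng):
    """Hypothetical sketch: draw one class label per row from softmax(logits).

    logits : (batch, num_classes) array of network outputs
    rng    : a numpy.random.Generator
    Returns an int array of shape (batch,).
    """
    # Numerically stable softmax over the class dimension.
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z)
    probs /= probs.sum(axis=1, keepdims=True)
    # Inverse-CDF sampling: the label is the first class whose
    # cumulative probability reaches the uniform draw u.
    cdf = np.cumsum(probs, axis=1)
    u = rng.random((logits.shape[0], 1))
    labels = (u > cdf).sum(axis=1)
    # Clamp in case rounding leaves cdf[:, -1] slightly below u.
    return np.minimum(labels, logits.shape[1] - 1)
```

Each row gets its own uniform draw, so the labels are sampled independently per example, as `torch.multinomial(probs, 1)` does.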
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alecwangcq/KFAC-Pytorch/25e6dbe14752348d4f6030697b4b7f553ead2e92/models/__init__.py
--------------------------------------------------------------------------------
/models/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | """The models subpackage contains definitions for the following models for CIFAR10/CIFAR100
4 | architectures:
5 |
6 | - `AlexNet`_
7 | - `VGG`_
8 | - `ResNet`_
9 | - `SqueezeNet`_
10 | - `DenseNet`_
11 |
12 | You can construct a model with random weights by calling its constructor:
13 |
14 | .. code:: python
15 |
16 | import torchvision.models as models
17 | resnet18 = models.resnet18()
18 | alexnet = models.alexnet()
19 | squeezenet = models.squeezenet1_0()
20 | densenet = models.densenet161()
21 |
22 | We provide pre-trained models for the ResNet variants and AlexNet, using the
23 | PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
24 | ``pretrained=True``:
25 |
26 | .. code:: python
27 |
28 | import torchvision.models as models
29 | resnet18 = models.resnet18(pretrained=True)
30 | alexnet = models.alexnet(pretrained=True)
31 |
32 | ImageNet 1-crop error rates (224x224)
33 |
34 | ======================== ============= =============
35 | Network Top-1 error Top-5 error
36 | ======================== ============= =============
37 | ResNet-18 30.24 10.92
38 | ResNet-34 26.70 8.58
39 | ResNet-50 23.85 7.13
40 | ResNet-101 22.63 6.44
41 | ResNet-152 21.69 5.94
42 | Inception v3 22.55 6.44
43 | AlexNet 43.45 20.91
44 | VGG-11 30.98 11.37
45 | VGG-13 30.07 10.75
46 | VGG-16 28.41 9.62
47 | VGG-19 27.62 9.12
48 | SqueezeNet 1.0 41.90 19.58
49 | SqueezeNet 1.1 41.81 19.38
50 | Densenet-121 25.35 7.83
51 | Densenet-169 24.00 7.00
52 | Densenet-201 22.80 6.43
53 | Densenet-161 22.35 6.20
54 | ======================== ============= =============
55 |
56 |
57 | .. _AlexNet: https://arxiv.org/abs/1404.5997
58 | .. _VGG: https://arxiv.org/abs/1409.1556
59 | .. _ResNet: https://arxiv.org/abs/1512.03385
60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360
61 | .. _DenseNet: https://arxiv.org/abs/1608.06993
62 | """
63 |
64 | from .alexnet import *
65 | from .vgg import *
66 | from .resnet import *
67 | from .resnext import *
68 | from .wrn import *
69 | from .preresnet import *
70 | from .densenet import *
71 |
--------------------------------------------------------------------------------
/models/cifar/alexnet.py:
--------------------------------------------------------------------------------
1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted.
2 | Without BN, the start learning rate should be 0.01
3 | (c) YANG, Wei
4 | '''
5 | import torch.nn as nn
6 |
7 |
8 | __all__ = ['alexnet']
9 |
10 |
11 | class AlexNet(nn.Module):
12 |
13 | def __init__(self, num_classes=10, **kwargs):
14 | super(AlexNet, self).__init__()
15 | self.features = nn.Sequential(
16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
17 | nn.ReLU(inplace=True),
18 | nn.MaxPool2d(kernel_size=2, stride=2),
19 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
20 | nn.ReLU(inplace=True),
21 | nn.MaxPool2d(kernel_size=2, stride=2),
22 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.MaxPool2d(kernel_size=2, stride=2),
29 | )
30 | self.classifier = nn.Linear(256, num_classes)
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = x.view(x.size(0), -1)
35 | x = self.classifier(x)
36 | return x
37 |
38 |
39 | def alexnet(**kwargs):
40 | r"""AlexNet model architecture from the
41 | `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
42 | """
43 | model = AlexNet(**kwargs)
44 | return model
45 |
--------------------------------------------------------------------------------
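The CIFAR `AlexNet` above ends in `nn.Linear(256, num_classes)`, which only works if `self.features` reduces a 32x32 image to a 1x1 map with 256 channels. The pure-Python helpers below (hypothetical, not part of the repository) check this with the usual output-size formula for convolution and pooling without dilation or ceil mode, `floor((size + 2*padding - kernel) / stride) + 1`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def alexnet_cifar_feature_size(size=32):
    """Hypothetical helper: side length after AlexNet.features for a square input."""
    # (kernel, stride, padding) of each layer that changes the spatial size.
    layers = [
        (11, 4, 5),  # Conv2d(3, 64)
        (2, 2, 0),   # MaxPool2d
        (5, 1, 2),   # Conv2d(64, 192)
        (2, 2, 0),   # MaxPool2d
        (3, 1, 1),   # Conv2d(192, 384)
        (3, 1, 1),   # Conv2d(384, 256)
        (3, 1, 1),   # Conv2d(256, 256)
        (2, 2, 0),   # MaxPool2d
    ]
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size
```

For a 32x32 input the sizes go 8, 4, 4, 2, 2, 2, 2, 1, so the flattened feature vector has 256 * 1 * 1 = 256 entries, matching the classifier.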
/models/cifar/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | __all__ = ['densenet']
8 |
9 |
10 | from torch.autograd import Variable
11 |
12 | class Bottleneck(nn.Module):
13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
14 | super(Bottleneck, self).__init__()
15 | planes = expansion * growthRate
16 | self.bn1 = nn.BatchNorm2d(inplanes)
17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3,
20 | padding=1, bias=False)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.dropRate = dropRate
23 |
24 | def forward(self, x):
25 | out = self.bn1(x)
26 | out = self.relu(out)
27 | out = self.conv1(out)
28 | out = self.bn2(out)
29 | out = self.relu(out)
30 | out = self.conv2(out)
31 | if self.dropRate > 0:
32 | out = F.dropout(out, p=self.dropRate, training=self.training)
33 |
34 | out = torch.cat((x, out), 1)
35 |
36 | return out
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
41 | super(BasicBlock, self).__init__()
42 | planes = expansion * growthRate
43 | self.bn1 = nn.BatchNorm2d(inplanes)
44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
45 | padding=1, bias=False)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.dropRate = dropRate
48 |
49 | def forward(self, x):
50 | out = self.bn1(x)
51 | out = self.relu(out)
52 | out = self.conv1(out)
53 | if self.dropRate > 0:
54 | out = F.dropout(out, p=self.dropRate, training=self.training)
55 |
56 | out = torch.cat((x, out), 1)
57 |
58 | return out
59 |
60 |
61 | class Transition(nn.Module):
62 | def __init__(self, inplanes, outplanes):
63 | super(Transition, self).__init__()
64 | self.bn1 = nn.BatchNorm2d(inplanes)
65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1,
66 | bias=False)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | def forward(self, x):
70 | out = self.bn1(x)
71 | out = self.relu(out)
72 | out = self.conv1(out)
73 | out = F.avg_pool2d(out, 2)
74 | return out
75 |
76 |
77 | class DenseNet(nn.Module):
78 |
79 | def __init__(self, depth=22, block=Bottleneck,
80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2, **kwargs):
81 | super(DenseNet, self).__init__()
82 |
83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
84 | n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6
85 |
86 | self.growthRate = growthRate
87 | self.dropRate = dropRate
88 |
89 | # self.inplanes is an instance attribute updated by the
90 | # helper methods below
91 | self.inplanes = growthRate * 2
92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
93 | bias=False)
94 | self.dense1 = self._make_denseblock(block, n)
95 | self.trans1 = self._make_transition(compressionRate)
96 | self.dense2 = self._make_denseblock(block, n)
97 | self.trans2 = self._make_transition(compressionRate)
98 | self.dense3 = self._make_denseblock(block, n)
99 | self.bn = nn.BatchNorm2d(self.inplanes)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.avgpool = nn.AvgPool2d(8)
102 | self.fc = nn.Linear(self.inplanes, num_classes)
103 |
104 | # Weight initialization
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
108 | m.weight.data.normal_(0, math.sqrt(2. / n))
109 | elif isinstance(m, nn.BatchNorm2d):
110 | m.weight.data.fill_(1)
111 | m.bias.data.zero_()
112 |
113 | def _make_denseblock(self, block, blocks):
114 | layers = []
115 | for i in range(blocks):
116 | # Currently we fix the expansion ratio as the default value
117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate))
118 | self.inplanes += self.growthRate
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def _make_transition(self, compressionRate):
123 | inplanes = self.inplanes
124 | outplanes = int(math.floor(self.inplanes // compressionRate))
125 | self.inplanes = outplanes
126 | return Transition(inplanes, outplanes)
127 |
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 |
132 | x = self.trans1(self.dense1(x))
133 | x = self.trans2(self.dense2(x))
134 | x = self.dense3(x)
135 | x = self.bn(x)
136 | x = self.relu(x)
137 |
138 | x = self.avgpool(x)
139 | x = x.view(x.size(0), -1)
140 | x = self.fc(x)
141 |
142 | return x
143 |
144 |
145 | def densenet(**kwargs):
146 | """
147 | Constructs a DenseNet model.
148 | """
149 | return DenseNet(**kwargs)
--------------------------------------------------------------------------------
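The input width of the final `nn.Linear` in `DenseNet` is whatever `self.inplanes` has grown to after the stem convolution, three dense blocks, and two transitions. The "+4" in the depth formula counts the stem conv, the two transition convs, and the fc layer; a `Bottleneck` layer holds two convolutions and a `BasicBlock` one, so the sketch requires depth = 6n+4 for bottlenecks, which is stricter than the `(depth - 4) % 3` assertion in the file. The helper `densenet_channels` is a hypothetical pure-Python illustration, not part of the repository.

```python
def densenet_channels(depth, growth_rate=12, compression_rate=2, bottleneck=True):
    """Hypothetical helper: channels entering the final fc layer of the CIFAR DenseNet.

    Mirrors DenseNet.__init__: a stem conv with 2 * growth_rate outputs,
    three dense blocks of n layers each, and a transition after the first two.
    """
    # A Bottleneck layer holds two convs, a BasicBlock layer one.
    convs_per_layer = 2 if bottleneck else 1
    assert (depth - 4) % (3 * convs_per_layer) == 0, \
        "depth must be 6n+4 (Bottleneck) or 3n+4 (BasicBlock)"
    n = (depth - 4) // (3 * convs_per_layer)

    channels = 2 * growth_rate
    for stage in range(3):
        # Every layer in a dense block concatenates growth_rate new feature maps.
        channels += n * growth_rate
        if stage < 2:
            # Transition: a 1x1 conv divides the channel count by compression_rate.
            channels //= compression_rate
    return channels
```

For example, a depth-100 bottleneck DenseNet with growth rate 12 goes 24 → 216 → 108 → 300 → 150 → 342 channels.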
/models/cifar/preresnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Pre-activation ResNet for CIFAR datasets.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['preresnet']
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | "3x3 convolution with padding"
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.bn1 = nn.BatchNorm2d(inplanes)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn2 = nn.BatchNorm2d(planes)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.bn1(x)
39 | out = self.relu(out)
40 | out = self.conv1(out)
41 |
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 | out = self.conv2(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(Bottleneck, self).__init__()
59 | self.bn1 = nn.BatchNorm2d(inplanes)
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn2 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn3 = nn.BatchNorm2d(planes)
65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.bn1(x)
74 | out = self.relu(out)
75 | out = self.conv1(out)
76 |
77 | out = self.bn2(out)
78 | out = self.relu(out)
79 | out = self.conv2(out)
80 |
81 | out = self.bn3(out)
82 | out = self.relu(out)
83 | out = self.conv3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 |
90 | return out
91 |
92 |
93 | class PreResNet(nn.Module):
94 |
95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock', **kwargs):
96 | super(PreResNet, self).__init__()
97 | # Model type specifies number of layers for CIFAR-10 model
98 | if block_name.lower() == 'basicblock':
99 | assert (depth - 2) % 6 == 0, 'When using BasicBlock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100 | n = (depth - 2) // 6
101 | block = BasicBlock
102 | elif block_name.lower() == 'bottleneck':
103 | assert (depth - 2) % 9 == 0, 'When using Bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104 | n = (depth - 2) // 9
105 | block = Bottleneck
106 | else:
107 | raise ValueError('block_name should be BasicBlock or Bottleneck')
108 |
109 | self.inplanes = 16
110 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
111 | bias=False)
112 | self.layer1 = self._make_layer(block, 16, n)
113 | self.layer2 = self._make_layer(block, 32, n, stride=2)
114 | self.layer3 = self._make_layer(block, 64, n, stride=2)
115 | self.bn = nn.BatchNorm2d(64 * block.expansion)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(8)
118 | self.fc = nn.Linear(64 * block.expansion, num_classes)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2. / n))
124 | elif isinstance(m, nn.BatchNorm2d):
125 | m.weight.data.fill_(1)
126 | m.bias.data.zero_()
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.Conv2d(self.inplanes, planes * block.expansion,
133 | kernel_size=1, stride=stride, bias=False),
134 | )
135 |
136 | layers = []
137 | layers.append(block(self.inplanes, planes, stride, downsample))
138 | self.inplanes = planes * block.expansion
139 | for i in range(1, blocks):
140 | layers.append(block(self.inplanes, planes))
141 |
142 | return nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | x = self.conv1(x)
146 |
147 | x = self.layer1(x) # 32x32
148 | x = self.layer2(x) # 16x16
149 | x = self.layer3(x) # 8x8
150 | x = self.bn(x)
151 | x = self.relu(x)
152 |
153 | x = self.avgpool(x)
154 | x = x.view(x.size(0), -1)
155 | x = self.fc(x)
156 |
157 | return x
158 |
159 |
160 | def preresnet(**kwargs):
161 | """
162 | Constructs a PreResNet model.
163 | """
164 | return PreResNet(**kwargs)
165 |
--------------------------------------------------------------------------------
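`PreResNet` (and `ResNet`) derive the number of residual blocks per stage from `depth`: a `BasicBlock` has two 3x3 convolutions and a `Bottleneck` three, and the "+2" counts the stem conv and the fc layer, giving depth = 6n+2 or 9n+2. The pure-Python helper `preresnet_layout` below is a hypothetical illustration, not part of the repository, of the three stages that `_make_layer` builds for 32x32 CIFAR inputs; it returns `(blocks, output_channels, spatial_size)` per stage.

```python
def preresnet_layout(depth, block_name="BasicBlock"):
    """Hypothetical helper: per-stage (blocks, channels, spatial size) for CIFAR."""
    name = block_name.lower()
    if name == "basicblock":
        assert (depth - 2) % 6 == 0, "depth should be 6n+2"
        n, expansion = (depth - 2) // 6, 1
    elif name == "bottleneck":
        assert (depth - 2) % 9 == 0, "depth should be 9n+2"
        n, expansion = (depth - 2) // 9, 4
    else:
        raise ValueError("block_name should be BasicBlock or Bottleneck")

    stages = []
    size = 32
    # Same (planes, stride) pairs as layer1, layer2, layer3.
    for planes, stride in [(16, 1), (32, 2), (64, 2)]:
        # A strided 3x3 conv with padding 1 halves an even side length.
        size //= stride
        stages.append((n, planes * expansion, size))
    return stages
```

The last stage ends at 8x8, which is why both networks use `nn.AvgPool2d(8)` before the classifier.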
/models/cifar/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''ResNet for CIFAR datasets.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['resnet']
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | "3x3 convolution with padding"
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
62 | padding=1, bias=False)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
65 | self.bn3 = nn.BatchNorm2d(planes * 4)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv3(out)
82 | out = self.bn3(out)
83 |
84 | if self.downsample is not None:
85 | residual = self.downsample(x)
86 |
87 | out += residual
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class ResNet(nn.Module):
94 |
95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock', **kwargs):
96 | super(ResNet, self).__init__()
97 | # Model type specifies number of layers for CIFAR-10 model
98 | if block_name.lower() == 'basicblock':
99 | assert (depth - 2) % 6 == 0, 'When using BasicBlock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100 | n = (depth - 2) // 6
101 | block = BasicBlock
102 | elif block_name.lower() == 'bottleneck':
103 | assert (depth - 2) % 9 == 0, 'When using Bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104 | n = (depth - 2) // 9
105 | block = Bottleneck
106 | else:
107 | raise ValueError('block_name should be BasicBlock or Bottleneck')
108 |
109 |
110 | self.inplanes = 16
111 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
112 | bias=False)
113 | self.bn1 = nn.BatchNorm2d(16)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.layer1 = self._make_layer(block, 16, n)
116 | self.layer2 = self._make_layer(block, 32, n, stride=2)
117 | self.layer3 = self._make_layer(block, 64, n, stride=2)
118 | self.avgpool = nn.AvgPool2d(8)
119 | self.fc = nn.Linear(64 * block.expansion, num_classes)
120 |
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
124 | m.weight.data.normal_(0, math.sqrt(2. / n))
125 | elif isinstance(m, nn.BatchNorm2d):
126 | m.weight.data.fill_(1)
127 | m.bias.data.zero_()
128 |
129 | def _make_layer(self, block, planes, blocks, stride=1):
130 | downsample = None
131 | if stride != 1 or self.inplanes != planes * block.expansion:
132 | downsample = nn.Sequential(
133 | nn.Conv2d(self.inplanes, planes * block.expansion,
134 | kernel_size=1, stride=stride, bias=False),
135 | nn.BatchNorm2d(planes * block.expansion),
136 | )
137 |
138 | layers = []
139 | layers.append(block(self.inplanes, planes, stride, downsample))
140 | self.inplanes = planes * block.expansion
141 | for i in range(1, blocks):
142 | layers.append(block(self.inplanes, planes))
143 |
144 | return nn.Sequential(*layers)
145 |
146 | def forward(self, x):
147 | x = self.conv1(x)
148 | x = self.bn1(x)
149 | x = self.relu(x) # 32x32
150 |
151 | x = self.layer1(x) # 32x32
152 | x = self.layer2(x) # 16x16
153 | x = self.layer3(x) # 8x8
154 |
155 | x = self.avgpool(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 | def resnet(**kwargs):
163 | """
164 | Constructs a ResNet model.
165 | """
166 | return ResNet(**kwargs)
167 |
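The depth checks in `ResNet.__init__` encode how many residual blocks each of the three stages receives: a BasicBlock holds two 3x3 convolutions, a Bottleneck holds three convolutions, and the stem convolution plus the final `fc` layer account for the `+2`. The standalone sketch below reproduces that arithmetic so a depth can be validated before building the model; the helper name `blocks_per_stage` is hypothetical and not part of this repository.

```python
def blocks_per_stage(depth, block_name='BasicBlock'):
    """Hypothetical helper mirroring the depth check in ResNet.__init__ (CIFAR variant)."""
    name = block_name.lower()
    if name == 'basicblock':
        layers_per_block = 2  # two 3x3 convs per BasicBlock
    elif name == 'bottleneck':
        layers_per_block = 3  # 1x1, 3x3, 1x1 convs per Bottleneck
    else:
        raise ValueError('block_name should be BasicBlock or Bottleneck')
    # 3 stages * n blocks * layers_per_block, plus the stem conv and the fc layer
    divisor = 3 * layers_per_block
    if (depth - 2) % divisor != 0:
        raise ValueError('depth should be {}n+2 for {}'.format(divisor, block_name))
    return (depth - 2) // divisor


# Spatial size after layer1/layer2/layer3 for a 32x32 CIFAR input: layer2 and layer3
# use stride 2, which is why the model can finish with nn.AvgPool2d(8).
sizes = [32, 32 // 2, 32 // 4]

for depth in (20, 56, 110):
    print(depth, blocks_per_stage(depth))
print(164, blocks_per_stage(164, 'Bottleneck'))
print(sizes)
```

The same `n` is used for all three stages, so the depth alone fixes the whole layout; only the channel widths (16, 32, 64 times `block.expansion`) change between stages.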
--------------------------------------------------------------------------------
/models/cifar/resnext.py:
--------------------------------------------------------------------------------
1 | """
2 | Creates a ResNeXt Model as defined in:
3 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
4 | Aggregated residual transformations for deep neural networks.
5 | arXiv preprint arXiv:1611.05431.
6 | Adapted from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
7 | """
8 | from __future__ import division
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn import init
12 |
13 | __all__ = ['resnext']
14 |
15 | class ResNeXtBottleneck(nn.Module):
16 | """
17 | ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
18 | """
19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
20 | """ Constructor
21 | Args:
22 | in_channels: input channel dimensionality
23 | out_channels: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | cardinality: num of convolution groups.
26 | widen_factor: divides cardinality * out_channels to give D, the width of the grouped 3x3 convolution.
27 | """
28 | super(ResNeXtBottleneck, self).__init__()
29 | D = cardinality * out_channels // widen_factor
30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
31 | self.bn_reduce = nn.BatchNorm2d(D)
32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
33 | self.bn = nn.BatchNorm2d(D)
34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
35 | self.bn_expand = nn.BatchNorm2d(out_channels)
36 |
37 | self.shortcut = nn.Sequential()
38 | if in_channels != out_channels:
39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))
41 |
42 | def forward(self, x):
43 | bottleneck = self.conv_reduce.forward(x)
44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True)
45 | bottleneck = self.conv_conv.forward(bottleneck)
46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True)
47 | bottleneck = self.conv_expand.forward(bottleneck)
48 | bottleneck = self.bn_expand.forward(bottleneck)
49 | residual = self.shortcut.forward(x)
50 | return F.relu(residual + bottleneck, inplace=True)
51 |
52 |
53 | class CifarResNeXt(nn.Module):
54 | """
55 | ResNeXt configured for the CIFAR datasets, as specified in
56 | https://arxiv.org/pdf/1611.05431.pdf
57 | """
58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
59 | """ Constructor
60 | Args:
61 | cardinality: number of convolution groups.
62 | depth: number of layers.
63 | num_classes: number of classes
64 | widen_factor: factor to adjust the channel dimensionality
65 | """
66 | super(CifarResNeXt, self).__init__()
67 | self.cardinality = cardinality
68 | self.depth = depth
69 | self.block_depth = (self.depth - 2) // 9
70 | self.widen_factor = widen_factor
71 | self.num_classes = num_classes
72 | self.output_size = 64
73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]
74 |
75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
76 | self.bn_1 = nn.BatchNorm2d(64)
77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
80 | self.classifier = nn.Linear(self.stages[3], num_classes)
81 | init.kaiming_normal_(self.classifier.weight)
82 |
83 | for key in self.state_dict():
84 | if key.split('.')[-1] == 'weight':
85 | if 'conv' in key:
86 | init.kaiming_normal_(self.state_dict()[key], mode='fan_out')
87 | if 'bn' in key:
88 | self.state_dict()[key][...] = 1
89 | elif key.split('.')[-1] == 'bias':
90 | self.state_dict()[key][...] = 0
91 |
92 | def block(self, name, in_channels, out_channels, pool_stride=2):
93 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
94 | Args:
95 | name: string name of the current block.
96 | in_channels: number of input channels
97 | out_channels: number of output channels
98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
99 | Returns: a Module consisting of n sequential bottlenecks.
100 | """
101 | block = nn.Sequential()
102 | for bottleneck in range(self.block_depth):
103 | name_ = '%s_bottleneck_%d' % (name, bottleneck)
104 | if bottleneck == 0:
105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
106 | self.widen_factor))
107 | else:
108 | block.add_module(name_,
109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor))
110 | return block
111 |
112 | def forward(self, x):
113 | x = self.conv_1_3x3.forward(x)
114 | x = F.relu(self.bn_1.forward(x), inplace=True)
115 | x = self.stage_1.forward(x)
116 | x = self.stage_2.forward(x)
117 | x = self.stage_3.forward(x)
118 | x = F.avg_pool2d(x, 8, 1)
119 | x = x.view(x.size(0), -1)
120 | return self.classifier(x)
121 |
122 | def resnext(**kwargs):
123 | """Constructs a ResNeXt.
124 | """
125 | model = CifarResNeXt(**kwargs)
126 | return model
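In `CifarResNeXt`, `stages` fixes the output width of each stage, and every `ResNeXtBottleneck` computes the width of its grouped 3x3 convolution as `D = cardinality * out_channels // widen_factor`, split into `cardinality` groups. The sketch below works through those numbers; the helper name `resnext_cifar_widths` is hypothetical. With `widen_factor=4` the last stage has 1024 channels, which (after the 8x8 average pool on a 32x32 input) is the feature size fed to the classifier; other `widen_factor` values change that size.

```python
def resnext_cifar_widths(depth, cardinality, widen_factor=4):
    """Hypothetical helper reproducing the channel arithmetic of CifarResNeXt."""
    block_depth = (depth - 2) // 9  # bottlenecks per stage (3 convs each, 3 stages)
    stages = [64, 64 * widen_factor, 128 * widen_factor, 256 * widen_factor]
    widths = []
    for out_channels in stages[1:]:
        D = cardinality * out_channels // widen_factor  # width of the grouped 3x3 conv
        # (stage output width, grouped conv width D, width of each group)
        widths.append((out_channels, D, D // cardinality))
    return block_depth, stages, widths


block_depth, stages, widths = resnext_cifar_widths(depth=29, cardinality=8)
print(block_depth)  # bottlenecks per stage
print(stages)       # stem width followed by the three stage widths
print(widths)
```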
--------------------------------------------------------------------------------
/models/cifar/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR-10. The three FC layers are replaced by a single linear classifier.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | import math
7 |
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | model_urls = {
16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
20 | }
21 |
22 |
23 | class VGG(nn.Module):
24 |
25 | def __init__(self, features, num_classes=1000, **kwargs):
26 | super(VGG, self).__init__()
27 | self.features = features
28 | self.classifier = nn.Linear(512, num_classes)
29 | self._initialize_weights()
30 |
31 | def forward(self, x):
32 | x = self.features(x)
33 | x = x.view(x.size(0), -1)
34 | x = self.classifier(x)
35 | return x
36 |
37 | def _initialize_weights(self):
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, math.sqrt(2. / n))
42 | if m.bias is not None:
43 | m.bias.data.zero_()
44 | elif isinstance(m, nn.BatchNorm2d):
45 | m.weight.data.fill_(1)
46 | m.bias.data.zero_()
47 | elif isinstance(m, nn.Linear):
48 | n = m.weight.size(1)
49 | m.weight.data.normal_(0, 0.01)
50 | m.bias.data.zero_()
51 |
52 |
53 | def make_layers(cfg, batch_norm=False):
54 | layers = []
55 | in_channels = 3
56 | for v in cfg:
57 | if v == 'M':
58 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
59 | else:
60 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
61 | if batch_norm:
62 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
63 | else:
64 | layers += [conv2d, nn.ReLU(inplace=True)]
65 | in_channels = v
66 | return nn.Sequential(*layers)
67 |
68 |
69 | cfg = {
70 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
71 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
72 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
73 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
74 | }
75 |
76 |
77 | def vgg11(**kwargs):
78 | """VGG 11-layer model (configuration "A")
79 |
80 | Args:
81 | **kwargs: forwarded to VGG (e.g. num_classes); pretrained weights are not loaded.
82 | """
83 | model = VGG(make_layers(cfg['A']), **kwargs)
84 | return model
85 |
86 |
87 | def vgg11_bn(**kwargs):
88 | """VGG 11-layer model (configuration "A") with batch normalization"""
89 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
90 | return model
91 |
92 |
93 | def vgg13(**kwargs):
94 | """VGG 13-layer model (configuration "B")
95 |
96 | Args:
97 | **kwargs: forwarded to VGG (e.g. num_classes); pretrained weights are not loaded.
98 | """
99 | model = VGG(make_layers(cfg['B']), **kwargs)
100 | return model
101 |
102 |
103 | def vgg13_bn(**kwargs):
104 | """VGG 13-layer model (configuration "B") with batch normalization"""
105 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
106 | return model
107 |
108 |
109 | def vgg16(**kwargs):
110 | """VGG 16-layer model (configuration "D")
111 |
112 | Args:
113 | **kwargs: forwarded to VGG (e.g. num_classes); pretrained weights are not loaded.
114 | """
115 | model = VGG(make_layers(cfg['D']), **kwargs)
116 | return model
117 |
118 |
119 | def vgg16_bn(**kwargs):
120 | """VGG 16-layer model (configuration "D") with batch normalization"""
121 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
122 | return model
123 |
124 |
125 | def vgg19(**kwargs):
126 | """VGG 19-layer model (configuration "E")
127 |
128 | Args:
129 | **kwargs: forwarded to VGG (e.g. num_classes); pretrained weights are not loaded.
130 | """
131 | model = VGG(make_layers(cfg['E']), **kwargs)
132 | return model
133 |
134 |
135 | def vgg19_bn(**kwargs):
136 | """VGG 19-layer model (configuration "E") with batch normalization"""
137 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
138 | return model
139 |
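Each `cfg` entry lists convolution widths, with `'M'` marking a 2x2 max-pool. Every configuration contains five pools, so a 32x32 CIFAR image is reduced to 1x1 with 512 channels, which is why `VGG` can flatten straight into `nn.Linear(512, num_classes)`. The sketch below counts layers per configuration; `CFG` copies the `cfg` dict above and the helper name `summarize_cfg` is hypothetical. The 11/13/16/19 names follow the original ImageNet VGG, which added three FC layers to these convolutions; here only one linear layer remains.

```python
CFG = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


def summarize_cfg(layer_cfg, input_size=32):
    """Hypothetical helper: (number of conv layers, final spatial size, flattened feature size)."""
    convs = [v for v in layer_cfg if v != 'M']
    pools = layer_cfg.count('M')
    spatial = input_size // (2 ** pools)  # each MaxPool2d(kernel_size=2, stride=2) halves H and W
    return len(convs), spatial, convs[-1] * spatial * spatial


for name, layer_cfg in CFG.items():
    print(name, summarize_cfg(layer_cfg))
```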
--------------------------------------------------------------------------------
/models/cifar/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | __all__ = ['wrn']
7 |
8 | class BasicBlock(nn.Module):
9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
10 | super(BasicBlock, self).__init__()
11 | self.bn1 = nn.BatchNorm2d(in_planes)
12 | self.relu1 = nn.ReLU(inplace=True)
13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(out_planes)
16 | self.relu2 = nn.ReLU(inplace=True)
17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
18 | padding=1, bias=False)
19 | self.droprate = dropRate
20 | self.equalInOut = (in_planes == out_planes)
21 | self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
22 | padding=0, bias=False) if not self.equalInOut else None
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 | class NetworkBlock(nn.Module):
35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
36 | super(NetworkBlock, self).__init__()
37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
39 | layers = []
40 | for i in range(nb_layers):
41 | layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropRate))
42 | return nn.Sequential(*layers)
43 | def forward(self, x):
44 | return self.layer(x)
45 |
46 | class WideResNet(nn.Module):
47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, **kwargs):
48 | super(WideResNet, self).__init__()
49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
51 | n = (depth - 4) // 6
52 | block = BasicBlock
53 | # 1st conv before any network block
54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
55 | padding=1, bias=False)
56 | # 1st block
57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
58 | # 2nd block
59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
60 | # 3rd block
61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
62 | # global average pooling and classifier
63 | self.bn1 = nn.BatchNorm2d(nChannels[3])
64 | self.relu = nn.ReLU(inplace=True)
65 | self.fc = nn.Linear(nChannels[3], num_classes)
66 | self.nChannels = nChannels[3]
67 |
68 | for m in self.modules():
69 | if isinstance(m, nn.Conv2d):
70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
71 | m.weight.data.normal_(0, math.sqrt(2. / n))
72 | elif isinstance(m, nn.BatchNorm2d):
73 | m.weight.data.fill_(1)
74 | m.bias.data.zero_()
75 | elif isinstance(m, nn.Linear):
76 | m.bias.data.zero_()
77 |
78 | def forward(self, x):
79 | out = self.conv1(x)
80 | out = self.block1(out)
81 | out = self.block2(out)
82 | out = self.block3(out)
83 | out = self.relu(self.bn1(out))
84 | out = F.avg_pool2d(out, 8)
85 | out = out.view(-1, self.nChannels)
86 | return self.fc(out)
87 |
88 | def wrn(**kwargs):
89 | """
90 | Constructs a Wide Residual Network.
91 | """
92 | model = WideResNet(**kwargs)
93 | return model
94 |
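`WideResNet` requires `depth = 6n + 4` and builds three `NetworkBlock`s of `n` `BasicBlock`s each, whose widths are the base widths 16/32/64 multiplied by `widen_factor`. Only the first block of each group receives the stride and, when the channel count changes, a 1x1 `convShortcut`. A minimal sketch of that bookkeeping follows, using the common WRN-28-10 setting as an example; the helper name `wrn_layout` is hypothetical.

```python
def wrn_layout(depth, widen_factor):
    """Hypothetical helper reproducing WideResNet's group sizes and strides."""
    if (depth - 4) % 6 != 0:
        raise ValueError('depth should be 6n+4')
    n = (depth - 4) // 6
    channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    groups = []
    for g, stride in enumerate((1, 2, 2)):
        in_planes, out_planes = channels[g], channels[g + 1]
        # First BasicBlock: (in_planes, out_planes, stride); the rest: (out_planes, out_planes, 1)
        blocks = [(in_planes, out_planes, stride)] + [(out_planes, out_planes, 1)] * (n - 1)
        groups.append(blocks)
    return n, channels, groups


n, channels, groups = wrn_layout(28, 10)
print(n, channels)
for blocks in groups:
    print(blocks[0], len(blocks))
```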
--------------------------------------------------------------------------------
/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnext import *
4 |
--------------------------------------------------------------------------------
/models/imagenet/resnext.py:
--------------------------------------------------------------------------------
1 | """
2 | Creates a ResNeXt Model as defined in:
3 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
4 | Aggregated residual transformations for deep neural networks.
5 | arXiv preprint arXiv:1611.05431.
6 | Adapted from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
7 | """
8 | from __future__ import division
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | import torch
14 |
15 | __all__ = ['resnext50', 'resnext101', 'resnext152']
16 |
17 | class Bottleneck(nn.Module):
18 | """
19 | ResNeXt bottleneck type C
20 | """
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
24 | """ Constructor
25 | Args:
26 | inplanes: input channel dimensionality
27 | planes: output channel dimensionality
28 | baseWidth: base width.
29 | cardinality: num of convolution groups.
30 | stride: conv stride. Replaces pooling layer.
31 | """
32 | super(Bottleneck, self).__init__()
33 |
34 | D = int(math.floor(planes * (baseWidth / 64)))
35 | C = cardinality
36 |
37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
38 | self.bn1 = nn.BatchNorm2d(D*C)
39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
40 | self.bn2 = nn.BatchNorm2d(D*C)
41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
42 | self.bn3 = nn.BatchNorm2d(planes * 4)
43 | self.relu = nn.ReLU(inplace=True)
44 |
45 | self.downsample = downsample
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv3(out)
59 | out = self.bn3(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class ResNeXt(nn.Module):
71 | """
72 | ResNext optimized for the ImageNet dataset, as specified in
73 | https://arxiv.org/pdf/1611.05431.pdf
74 | """
75 | def __init__(self, baseWidth, cardinality, layers, num_classes):
76 | """ Constructor
77 | Args:
78 | baseWidth: per-group width of the grouped 3x3 conv in the first stage (scaled up in later stages).
79 | cardinality: number of convolution groups.
80 | layers: config of layers, e.g., [3, 4, 6, 3]
81 | num_classes: number of classes
82 | """
83 | super(ResNeXt, self).__init__()
84 | block = Bottleneck
85 |
86 | self.cardinality = cardinality
87 | self.baseWidth = baseWidth
88 | self.num_classes = num_classes
89 | self.inplanes = 64
90 | self.output_size = 64
91 |
92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(block, 64, layers[0])
97 | self.layer2 = self._make_layer(block, 128, layers[1], 2)
98 | self.layer3 = self._make_layer(block, 256, layers[2], 2)
99 | self.layer4 = self._make_layer(block, 512, layers[3], 2)
100 | self.avgpool = nn.AvgPool2d(7)
101 | self.fc = nn.Linear(512 * block.expansion, num_classes)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | """ Stack `blocks` bottleneck modules; only the first one applies the stride and the downsample shortcut.
113 | Args:
114 | block: block type used to construct ResNeXt
115 | planes: number of output channels (need to multiply by block.expansion)
116 | blocks: number of blocks to be built
117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
118 | Returns: a Module consisting of `blocks` sequential bottlenecks.
119 | """
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool1(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | x = self.fc(x)
148 |
149 | return x
150 |
151 |
152 | def resnext50(baseWidth, cardinality):
153 | """
154 | Construct ResNeXt-50.
155 | """
156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000)
157 | return model
158 |
159 |
160 | def resnext101(baseWidth, cardinality):
161 | """
162 | Construct ResNeXt-101.
163 | """
164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000)
165 | return model
166 |
167 |
168 | def resnext152(baseWidth, cardinality):
169 | """
170 | Construct ResNeXt-152.
171 | """
172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000)
173 | return model
174 |
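For the ImageNet `Bottleneck`, the grouped 3x3 convolution has `D * C` channels with `D = floor(planes * baseWidth / 64)` and `C = cardinality`, while the block outputs `planes * 4` channels. Each bottleneck holds three convolutions, so the `layers` lists in `resnext50`/`resnext101`/`resnext152` give depths of `3 * sum(layers) + 2` (the 7x7 stem conv plus `fc`). The sketch below checks this for the 32x4d setting, i.e. `baseWidth=4, cardinality=32`; the helper names `bottleneck_widths` and `resnext_depth` are hypothetical.

```python
import math


def bottleneck_widths(planes, base_width, cardinality, expansion=4):
    """Hypothetical helper: (per-group width D, grouped conv width D*C, block output width)."""
    D = int(math.floor(planes * (base_width / 64)))
    return D, D * cardinality, planes * expansion


def resnext_depth(layers):
    # three convs per bottleneck, plus the 7x7 stem conv and the final fc layer
    return 3 * sum(layers) + 2


for planes in (64, 128, 256, 512):
    print(planes, bottleneck_widths(planes, base_width=4, cardinality=32))

for name, layers in (('resnext50', [3, 4, 6, 3]),
                     ('resnext101', [3, 4, 23, 3]),
                     ('resnext152', [3, 8, 36, 3])):
    print(name, resnext_depth(layers))
```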
--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .kfac import KFACOptimizer
2 | from .ekfac import EKFACOptimizer
3 |
4 |
5 | def get_optimizer(name):
6 | if name == 'kfac':
7 | return KFACOptimizer
8 | elif name == 'ekfac':
9 | return EKFACOptimizer
10 | else:
11 | raise NotImplementedError('Unknown optimizer: {}'.format(name))
--------------------------------------------------------------------------------
/optimizers/ekfac.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from utils.kfac_utils import (ComputeCovA, ComputeCovG, ComputeMatGrad)
7 | from utils.kfac_utils import update_running_stat
8 |
9 |
10 | class EKFACOptimizer(optim.Optimizer):
11 | def __init__(self,
12 | model,
13 | lr=0.001,
14 | momentum=0.9,
15 | stat_decay=0.95,
16 | damping=0.001,
17 | kl_clip=0.001,
18 | weight_decay=0,
19 | TCov=10,
20 | TScal=10,
21 | TInv=100,
22 | batch_averaged=True):
23 | if lr < 0.0:
24 | raise ValueError("Invalid learning rate: {}".format(lr))
25 | if momentum < 0.0:
26 | raise ValueError("Invalid momentum value: {}".format(momentum))
27 | if weight_decay < 0.0:
28 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
29 | defaults = dict(lr=lr, momentum=momentum, damping=damping,
30 | weight_decay=weight_decay)
31 | # TODO (CW): the EKFAC optimizer currently only supports a model (not a parameter list) as input
32 | super(EKFACOptimizer, self).__init__(model.parameters(), defaults)
33 | self.CovAHandler = ComputeCovA()
34 | self.CovGHandler = ComputeCovG()
35 | self.MatGradHandler = ComputeMatGrad()
36 | self.batch_averaged = batch_averaged
37 |
38 | self.known_modules = {'Linear', 'Conv2d'}
39 | self.acc_stats = True # read by _save_grad_output; the caller may toggle it around backward passes
40 | self.modules = []
41 | self.grad_outputs = {}
42 |
43 | self.model = model
44 | self._prepare_model()
45 |
46 | self.steps = 0
47 |
48 | self.m_aa, self.m_gg = {}, {}
49 | self.Q_a, self.Q_g = {}, {}
50 | self.d_a, self.d_g = {}, {}
51 | self.S_l = {}
52 | self.A, self.DS = {}, {}
53 | self.stat_decay = stat_decay
54 |
55 | self.kl_clip = kl_clip
56 | self.TCov = TCov
57 | self.TScal = TScal
58 | self.TInv = TInv
59 |
60 | def _save_input(self, module, input):
61 | if torch.is_grad_enabled() and self.steps % self.TCov == 0:
62 | aa = self.CovAHandler(input[0].data, module)
63 | # Initialize buffers
64 | if self.steps == 0:
65 | self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
66 | update_running_stat(aa, self.m_aa[module], self.stat_decay)
67 | if torch.is_grad_enabled() and self.steps % self.TScal == 0 and self.steps > 0:
68 | self.A[module] = input[0].data
69 |
70 | def _save_grad_output(self, module, grad_input, grad_output):
71 | # Accumulate statistics for Fisher matrices
72 | if self.acc_stats and self.steps % self.TCov == 0:
73 | gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
74 | # Initialize buffers
75 | if self.steps == 0:
76 | self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
77 | update_running_stat(gg, self.m_gg[module], self.stat_decay)
78 |
79 | # if self.steps % self.TInv == 0:
80 | # self._update_inv(module)
81 |
82 | if self.acc_stats and self.steps % self.TScal == 0 and self.steps > 0:
83 | self.DS[module] = grad_output[0].data
84 | # self._update_scale(module)
85 |
86 | def _prepare_model(self):
87 | count = 0
88 | print(self.model)
89 | print("=> We keep the following layers in EKFAC.")
90 | for module in self.model.modules():
91 | classname = module.__class__.__name__
92 | if classname in self.known_modules:
93 | self.modules.append(module)
94 | module.register_forward_pre_hook(self._save_input)
95 | module.register_backward_hook(self._save_grad_output)
96 | print('(%s): %s' % (count, module))
97 | count += 1
98 |
99 | def _update_inv(self, m):
100 | """Eigendecompose the Kronecker factors to compute the inverse of the approximate Fisher.
101 | :param m: The layer
102 | :return: no returns.
103 | """
104 | eps = 1e-10 # for numerical stability
105 | self.d_a[m], self.Q_a[m] = torch.linalg.eigh(
106 | self.m_aa[m])
107 | self.d_g[m], self.Q_g[m] = torch.linalg.eigh(
108 | self.m_gg[m])
109 |
110 | self.d_a[m].mul_((self.d_a[m] > eps).float())
111 | self.d_g[m].mul_((self.d_g[m] > eps).float())
112 | # if self.steps != 0:
113 | self.S_l[m] = self.d_g[m].unsqueeze(1) @ self.d_a[m].unsqueeze(0)
114 |
115 | @staticmethod
116 | def _get_matrix_form_grad(m, classname):
117 | """
118 | :param m: the layer
119 | :param classname: the class name of the layer
120 | :return: the gradient in matrix form, an [output_dim, input_dim] matrix (input_dim includes a bias column if the layer has a bias).
121 | """
122 | if classname == 'Conv2d':
123 | p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1) # n_filters * (in_c * kw * kh)
124 | else:
125 | p_grad_mat = m.weight.grad.data
126 | if m.bias is not None:
127 | p_grad_mat = torch.cat([p_grad_mat, m.bias.grad.data.view(-1, 1)], 1)
128 | return p_grad_mat
129 |
130 | def _get_natural_grad(self, m, p_grad_mat, damping):
131 | """
132 | :param m: the layer
133 | :param p_grad_mat: the gradients in matrix form
134 | :return: a list of preconditioned gradients w.r.t. the parameters in `m`
135 | """
136 | # p_grad_mat is of output_dim * input_dim
137 | # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T]
138 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
139 | v2 = v1 / (self.S_l[m] + damping)
140 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
141 | if m.bias is not None:
142 | # we always put gradient w.r.t weight in [0]
143 | # and w.r.t bias in [1]
144 | v = [v[:, :-1], v[:, -1:]]
145 | v[0] = v[0].view(m.weight.grad.data.size())
146 | v[1] = v[1].view(m.bias.grad.data.size())
147 | else:
148 | v = [v.view(m.weight.grad.data.size())]
149 |
150 | return v
151 |
152 | def _kl_clip_and_update_grad(self, updates, lr):
153 | # do kl clip
154 | vg_sum = 0
155 | for m in self.modules:
156 | v = updates[m]
157 | vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item()
158 | if m.bias is not None:
159 | vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item()
160 | nu = min(1.0, math.sqrt(self.kl_clip / vg_sum))
161 |
162 | for m in self.modules:
163 | v = updates[m]
164 | m.weight.grad.data.copy_(v[0])
165 | m.weight.grad.data.mul_(nu)
166 | if m.bias is not None:
167 | m.bias.grad.data.copy_(v[1])
168 | m.bias.grad.data.mul_(nu)
169 |
170 | def _step(self, closure):
171 | # FIXME (CW): Adapted from SGD, with Nesterov momentum and momentum dampening removed.
172 | # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p)
173 | for group in self.param_groups:
174 | weight_decay = group['weight_decay']
175 | momentum = group['momentum']
176 |
177 | for p in group['params']:
178 | if p.grad is None:
179 | continue
180 | d_p = p.grad.data
181 | if weight_decay != 0 and self.steps >= 20 * self.TCov:
182 | d_p.add_(p.data, alpha=weight_decay)
183 | if momentum != 0:
184 | param_state = self.state[p]
185 | if 'momentum_buffer' not in param_state:
186 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
187 | buf.mul_(momentum).add_(d_p)
188 | else:
189 | buf = param_state['momentum_buffer']
190 | buf.mul_(momentum).add_(d_p)
191 | d_p = buf
192 |
193 | p.data.add_(d_p, alpha=-group['lr'])
194 |
195 | def _update_scale(self, m):
196 | with torch.no_grad():
197 | A, S = self.A[m], self.DS[m]
198 | grad_mat = self.MatGradHandler(A, S, m) # batch_size * out_dim * in_dim
199 | if self.batch_averaged:
200 | grad_mat *= S.size(0)
201 |
202 | s_l = (self.Q_g[m].t() @ grad_mat @ self.Q_a[m]) ** 2 # same eigenbasis as _get_natural_grad; this consumes too much memory!
203 | s_l = s_l.mean(dim=0)
204 | if self.steps == 0:
205 | self.S_l[m] = s_l.new(s_l.size()).fill_(1)
206 | # s_ls = self.Q_g[m] @ grad_s
207 | # s_la = in_a @ self.Q_a[m].t()
208 | # s_l = 0
209 | # for i in range(0, s_ls.size(0), S.size(0)): # tradeoff between time and memory
210 | # start = i
211 | # end = min(s_ls.size(0), i + S.size(0))
212 | # s_l += (torch.bmm(s_ls[start:end,:], s_la[start:end,:]) ** 2).sum(0)
213 | # s_l /= s_ls.size(0)
214 | # if self.steps == 0:
215 | # self.S_l[m] = s_l.new(s_l.size()).fill_(1)
216 | update_running_stat(s_l, self.S_l[m], self.stat_decay)
217 | # remove reference for reducing memory cost.
218 | self.A[m] = None
219 | self.DS[m] = None
220 |
221 | def step(self, closure=None):
222 | # FIXME(CW): temporary fix for compatibility with the official LR scheduler.
223 | group = self.param_groups[0]
224 | lr = group['lr']
225 | damping = group['damping']
226 | updates = {}
227 | for m in self.modules:
228 | classname = m.__class__.__name__
229 | if self.steps % self.TInv == 0:
230 | self._update_inv(m)
231 |
232 | if self.steps % self.TScal == 0 and self.steps > 0:
233 | self._update_scale(m)
234 |
235 | p_grad_mat = self._get_matrix_form_grad(m, classname)
236 | v = self._get_natural_grad(m, p_grad_mat, damping)
237 | updates[m] = v
238 | self._kl_clip_and_update_grad(updates, lr)
239 |
240 | self._step(closure)
241 | self.steps += 1
242 |
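`_get_natural_grad` preconditions each layer's matrix-form gradient in the eigenbasis of the two Kronecker factors: with `m_aa = Q_a diag(d_a) Q_a^T` and `m_gg = Q_g diag(d_g) Q_g^T`, it returns `Q_g [(Q_g^T grad Q_a) / (S_l + damping)] Q_a^T`. Immediately after `_update_inv`, `S_l` is the outer product `d_g d_a^T`, and the result then equals solving against the damped Kronecker-factored Fisher `kron(G, A) + damping * I` (with the gradient flattened row-major); EKFAC later replaces `S_l` with the running averages from `_update_scale`. The NumPy sketch below checks that equivalence on random data. It is an illustration under assumptions: it uses `numpy.linalg.eigh` in place of the PyTorch eigensolver, omits the small-eigenvalue clipping done in `_update_inv`, and the function name `eigenbasis_precondition` is hypothetical.

```python
import numpy as np


def eigenbasis_precondition(grad, A, G, damping, S=None):
    """Hypothetical re-implementation of the eigenbasis update in _get_natural_grad.

    grad: (out_dim, in_dim) gradient matrix; A: (in_dim, in_dim) activation factor;
    G: (out_dim, out_dim) gradient factor. If S is None, use the K-FAC scales d_g d_a^T.
    """
    d_a, Q_a = np.linalg.eigh(A)  # A = Q_a diag(d_a) Q_a^T
    d_g, Q_g = np.linalg.eigh(G)  # G = Q_g diag(d_g) Q_g^T
    if S is None:
        S = np.outer(d_g, d_a)  # the scales _update_inv stores in S_l
    v1 = Q_g.T @ grad @ Q_a     # project the gradient into the eigenbasis
    v2 = v1 / (S + damping)     # rescale each eigen-direction
    return Q_g @ v2 @ Q_a.T     # rotate back to the parameter basis


rng = np.random.default_rng(0)
out_dim, in_dim, batch = 3, 4, 32
a = rng.standard_normal((batch, in_dim))   # stand-in for layer inputs
g = rng.standard_normal((batch, out_dim))  # stand-in for back-propagated output gradients
A = a.T @ a / batch
G = g.T @ g / batch
grad = rng.standard_normal((out_dim, in_dim))
damping = 1e-3

v = eigenbasis_precondition(grad, A, G, damping)

# Reference: solve against the explicit damped Kronecker product (row-major flattening).
F = np.kron(G, A) + damping * np.eye(out_dim * in_dim)
v_direct = np.linalg.solve(F, grad.reshape(-1)).reshape(out_dim, in_dim)
print(np.allclose(v, v_direct))
```

The explicit `kron` matrix has `(out_dim * in_dim) ** 2` entries, which is why the optimizer works with the two much smaller eigendecompositions instead.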
--------------------------------------------------------------------------------
/optimizers/kfac.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.optim as optim
5 |
6 | from utils.kfac_utils import (ComputeCovA, ComputeCovG)
7 | from utils.kfac_utils import update_running_stat
8 |
9 |
10 | class KFACOptimizer(optim.Optimizer):
11 | def __init__(self,
12 | model,
13 | lr=0.001,
14 | momentum=0.9,
15 | stat_decay=0.95,
16 | damping=0.001,
17 | kl_clip=0.001,
18 | weight_decay=0,
19 | TCov=10,
20 | TInv=100,
21 | batch_averaged=True):
22 | if lr < 0.0:
23 | raise ValueError("Invalid learning rate: {}".format(lr))
24 | if momentum < 0.0:
25 | raise ValueError("Invalid momentum value: {}".format(momentum))
26 | if weight_decay < 0.0:
27 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
28 | defaults = dict(lr=lr, momentum=momentum, damping=damping,
29 | weight_decay=weight_decay)
30 | # TODO (CW): KFAC optimizer currently only supports a model as input
31 | super(KFACOptimizer, self).__init__(model.parameters(), defaults)
32 | self.CovAHandler = ComputeCovA()
33 | self.CovGHandler = ComputeCovG()
34 | self.batch_averaged = batch_averaged
35 |
36 | self.known_modules = {'Linear', 'Conv2d'}
37 |
38 | self.modules = []
39 | self.grad_outputs = {}
40 |
41 | self.model = model
42 | self._prepare_model()
43 |
44 | self.steps = 0
45 |
46 | self.m_aa, self.m_gg = {}, {}
47 | self.Q_a, self.Q_g = {}, {}
48 | self.d_a, self.d_g = {}, {}
49 | self.stat_decay = stat_decay
50 |
51 | self.kl_clip = kl_clip
52 | self.TCov = TCov
53 | self.TInv = TInv
54 |
55 | def _save_input(self, module, input):
56 | if torch.is_grad_enabled() and self.steps % self.TCov == 0:
57 | aa = self.CovAHandler(input[0].data, module)
58 | # Initialize buffers
59 | if self.steps == 0:
60 | self.m_aa[module] = torch.diag(aa.new(aa.size(0)).fill_(1))
61 | update_running_stat(aa, self.m_aa[module], self.stat_decay)
62 |
63 | def _save_grad_output(self, module, grad_input, grad_output):
64 | # Accumulate statistics for Fisher matrices
65 | if self.acc_stats and self.steps % self.TCov == 0:
66 | gg = self.CovGHandler(grad_output[0].data, module, self.batch_averaged)
67 | # Initialize buffers
68 | if self.steps == 0:
69 | self.m_gg[module] = torch.diag(gg.new(gg.size(0)).fill_(1))
70 | update_running_stat(gg, self.m_gg[module], self.stat_decay)
71 |
72 | def _prepare_model(self):
73 | count = 0
74 | print(self.model)
75 | print("=> We keep following layers in KFAC. ")
76 | for module in self.model.modules():
77 | classname = module.__class__.__name__
78 | # print('=> We keep following layers in KFAC. <=')
79 | if classname in self.known_modules:
80 | self.modules.append(module)
81 | module.register_forward_pre_hook(self._save_input)
82 | module.register_backward_hook(self._save_grad_output)
83 | print('(%s): %s' % (count, module))
84 | count += 1
85 |
86 | def _update_inv(self, m):
87 | """Eigendecompose the Kronecker factors to compute the inverse of the approximate Fisher.
88 | :param m: The layer
89 | :return: no returns.
90 | """
91 | eps = 1e-10 # for numerical stability
92 | self.d_a[m], self.Q_a[m] = torch.linalg.eigh(
93 | self.m_aa[m])
94 | self.d_g[m], self.Q_g[m] = torch.linalg.eigh(
95 | self.m_gg[m])
96 |
97 | self.d_a[m].mul_((self.d_a[m] > eps).float())
98 | self.d_g[m].mul_((self.d_g[m] > eps).float())
99 |
100 | @staticmethod
101 | def _get_matrix_form_grad(m, classname):
102 | """
103 | :param m: the layer
104 | :param classname: the class name of the layer
105 | :return: the gradient in matrix form, of shape [output_dim, input_dim] ([output_dim, input_dim + 1] if the layer has a bias).
106 | """
107 | if classname == 'Conv2d':
108 | p_grad_mat = m.weight.grad.data.view(m.weight.grad.data.size(0), -1) # n_filters * (in_c * kw * kh)
109 | else:
110 | p_grad_mat = m.weight.grad.data
111 | if m.bias is not None:
112 | p_grad_mat = torch.cat([p_grad_mat, m.bias.grad.data.view(-1, 1)], 1)
113 | return p_grad_mat
114 |
115 | def _get_natural_grad(self, m, p_grad_mat, damping):
116 | """
117 | :param m: the layer
118 | :param p_grad_mat: the gradients in matrix form
119 | :return: a list of gradients w.r.t. the parameters in `m`
120 | """
121 | # p_grad_mat is of output_dim * input_dim
122 | # inv((ss')) p_grad_mat inv(aa') = [ Q_g (1/R_g) Q_g^T ] @ p_grad_mat @ [Q_a (1/R_a) Q_a^T]
123 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
124 | v2 = v1 / (self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + damping)
125 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
126 | if m.bias is not None:
127 | # we always put gradient w.r.t weight in [0]
128 | # and w.r.t bias in [1]
129 | v = [v[:, :-1], v[:, -1:]]
130 | v[0] = v[0].view(m.weight.grad.data.size())
131 | v[1] = v[1].view(m.bias.grad.data.size())
132 | else:
133 | v = [v.view(m.weight.grad.data.size())]
134 |
135 | return v
136 |
137 | def _kl_clip_and_update_grad(self, updates, lr):
138 | # do kl clip
139 | vg_sum = 0
140 | for m in self.modules:
141 | v = updates[m]
142 | vg_sum += (v[0] * m.weight.grad.data * lr ** 2).sum().item()
143 | if m.bias is not None:
144 | vg_sum += (v[1] * m.bias.grad.data * lr ** 2).sum().item()
145 | nu = min(1.0, math.sqrt(self.kl_clip / vg_sum))
146 |
147 | for m in self.modules:
148 | v = updates[m]
149 | m.weight.grad.data.copy_(v[0])
150 | m.weight.grad.data.mul_(nu)
151 | if m.bias is not None:
152 | m.bias.grad.data.copy_(v[1])
153 | m.bias.grad.data.mul_(nu)
154 |
155 | def _step(self, closure):
156 | # FIXME (CW): Modified based on SGD (removed nesterov and dampening in momentum.)
157 | # FIXME (CW): 1. no nesterov, 2. buf.mul_(momentum).add_(1 - dampening , d_p)
158 | for group in self.param_groups:
159 | weight_decay = group['weight_decay']
160 | momentum = group['momentum']
161 |
162 | for p in group['params']:
163 | if p.grad is None:
164 | continue
165 | d_p = p.grad.data
166 | if weight_decay != 0 and self.steps >= 20 * self.TCov:
167 | d_p.add_(p.data, alpha=weight_decay)
168 | if momentum != 0:
169 | param_state = self.state[p]
170 | if 'momentum_buffer' not in param_state:
171 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
172 | buf.mul_(momentum).add_(d_p)
173 | else:
174 | buf = param_state['momentum_buffer']
175 | buf.mul_(momentum).add_(d_p)
176 | d_p = buf
177 |
178 | p.data.add_(d_p, alpha=-group['lr'])
179 |
180 | def step(self, closure=None):
181 | # FIXME(CW): temporary fix for compatibility with the official LR scheduler.
182 | group = self.param_groups[0]
183 | lr = group['lr']
184 | damping = group['damping']
185 | updates = {}
186 | for m in self.modules:
187 | classname = m.__class__.__name__
188 | if self.steps % self.TInv == 0:
189 | self._update_inv(m)
190 | p_grad_mat = self._get_matrix_form_grad(m, classname)
191 | v = self._get_natural_grad(m, p_grad_mat, damping)
192 | updates[m] = v
193 | self._kl_clip_and_update_grad(updates, lr)
194 |
195 | self._step(closure)
196 | self.steps += 1
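`_get_natural_grad` never forms the Fisher block explicitly. The eigenvectors of a Kronecker product are Kronecker products of the factors' eigenvectors. So the method rotates the gradient into that eigenbasis, divides entry `(i, j)` by `d_g[i] * d_a[j] + damping`, and rotates back. The NumPy sketch below checks on small random factors that this equals a dense solve against `kron(A, G) + damping * I`. The function and variable names are illustrative, not from this repo. The sketch also omits the small-eigenvalue zeroing that `_update_inv` applies with `eps`.

```python
import numpy as np


def kfac_natural_grad(grad, A, G, damping):
    """(A kron G + damping * I)^-1 applied to a [out_dim, in_dim] gradient, via eigendecomposition."""
    d_a, Q_a = np.linalg.eigh(A)              # A = Q_a diag(d_a) Q_a^T
    d_g, Q_g = np.linalg.eigh(G)              # G = Q_g diag(d_g) Q_g^T
    v1 = Q_g.T @ grad @ Q_a                   # rotate into the Kronecker eigenbasis
    v2 = v1 / (np.outer(d_g, d_a) + damping)  # eigenvalue for entry (i, j) is d_g[i] * d_a[j]
    return Q_g @ v2 @ Q_a.T                   # rotate back


def random_spd(n, rng):
    """A well-conditioned symmetric positive-definite matrix for testing."""
    X = rng.standard_normal((n, n))
    return X @ X.T + n * np.eye(n)


rng = np.random.default_rng(0)
out_dim, in_dim, damping = 3, 4, 1e-3
A, G = random_spd(in_dim, rng), random_spd(out_dim, rng)
grad = rng.standard_normal((out_dim, in_dim))

v = kfac_natural_grad(grad, A, G, damping)

# Dense reference: with column-major vec, vec(G @ W @ A) = kron(A, G) @ vec(W).
F = np.kron(A, G) + damping * np.eye(out_dim * in_dim)
v_ref = np.linalg.solve(F, grad.flatten(order="F")).reshape((out_dim, in_dim), order="F")
print(np.allclose(v, v_ref))
```

The eigendecomposition costs O(in_dim^3 + out_dim^3) instead of O((in_dim * out_dim)^3) for the dense solve. Its result can also be reused until the next `TInv` refresh, while only `damping` changes.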
197 |
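`_kl_clip_and_update_grad` rescales every preconditioned update by `nu = min(1, sqrt(kl_clip / (lr**2 * sum(v * g))))`. This keeps the quadratic estimate `lr**2 * nu**2 * sum(v * g)` at or below `kl_clip`. `_step` then applies SGD with heavy-ball momentum (no dampening, no Nesterov). Below is a plain-Python sketch on flat lists of scalars; `kl_clip_factor` and `momentum_step` are hypothetical helpers. Like the original, it assumes `sum(v * g) > 0`. Weight decay is passed in as an argument, so a caller can mimic the original by passing 0 for the first `20 * TCov` steps.

```python
import math


def kl_clip_factor(updates, grads, lr, kl_clip):
    """Scale factor nu (updates: preconditioned v, grads: raw gradients g)."""
    vg_sum = sum(v * g * lr ** 2 for v, g in zip(updates, grads))
    return min(1.0, math.sqrt(kl_clip / vg_sum))


def momentum_step(params, grads, bufs, lr, momentum, weight_decay=0.0):
    """One step: d_p = g + wd * p; buf = momentum * buf + d_p; p -= lr * buf."""
    new_params, new_bufs = [], []
    for p, d_p, buf in zip(params, grads, bufs):
        d_p = d_p + weight_decay * p
        buf = momentum * buf + d_p          # a zero buffer makes the first step buf = d_p
        new_params.append(p - lr * buf)
        new_bufs.append(buf)
    return new_params, new_bufs


lr, kl_clip = 0.1, 1e-3
updates, grads = [2.0, 1.0], [1.0, 3.0]   # lr**2 * sum(v * g) = 0.05 > kl_clip, so nu < 1
nu = kl_clip_factor(updates, grads, lr, kl_clip)
clipped = [nu * v for v in updates]       # what gets copied into .grad

params, bufs = [1.0, -1.0], [0.0, 0.0]
params, bufs = momentum_step(params, clipped, bufs, lr, momentum=0.9)
print(nu, params)
```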
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--network', type=str, default='vgg16_bn')
7 | parser.add_argument('--dataset', type=str, default='cifar10')
8 | parser.add_argument('--optimizer', type=str, default='kfac')
9 | parser.add_argument('--machine', type=int, default=10)
10 |
11 | args = parser.parse_args()
12 |
13 | vgg16_bn = ''
14 | vgg19_bn = ''
15 | resnet = '--depth 110'
16 | wrn = '--depth 28 --widen_factor 10 --dropRate 0.3'
17 | densenet = '--depth 100 --growthRate 12'
18 |
19 | apps = {
20 | 'vgg16_bn': vgg16_bn,
21 | 'vgg19_bn': vgg19_bn,
22 | 'resnet': resnet,
23 | 'wrn': wrn,
24 | 'densenet': densenet
25 | }
26 |
27 |
28 | def grid_search(args):
29 | scripts = []
30 | if args.optimizer in ['kfac', 'ekfac']:
31 | template = 'python main.py ' \
32 | '--dataset %s ' \
33 | '--optimizer %s ' \
34 | '--network %s ' \
35 | ' --epoch 100 ' \
36 | '--milestone 40,80 ' \
37 | '--learning_rate %f ' \
38 | '--damping %f ' \
39 | '--weight_decay %f %s'
40 |
41 | lrs = [3e-2, 1e-2, 3e-3]
42 | dampings = [3e-2, 1e-3, 3e-3]
43 | wds = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
44 | app = apps[args.network]
45 | for lr in lrs:
46 | for dmp in dampings:
47 | for wd in wds:
48 | scripts.append(template % (args.dataset, args.optimizer, args.network, lr, dmp, wd, app))
49 | elif args.optimizer == 'sgd':
50 | template = 'python main.py ' \
51 | '--dataset %s ' \
52 | '--optimizer %s ' \
53 | '--network %s ' \
54 | ' --epoch 200 ' \
55 | '--milestone 60,120,180 ' \
56 | '--learning_rate %f ' \
57 | '--weight_decay %f %s'
58 | app = apps[args.network]
59 | lrs = [3e-1, 1e-1, 3e-2]
60 | wds = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
61 |
62 | for lr in lrs:
63 | for wd in wds:
64 | scripts.append(template % (args.dataset, args.optimizer, args.network, lr, wd, app))
65 |
66 | return scripts
67 |
68 |
69 | def gen_script(scripts, machine, args):
70 | with open('run_%s_%s_%s.sh' % (args.dataset, args.optimizer, args.network), 'w') as f:
71 | for s in scripts:
72 | f.write('srun --gres=gpu:1 -c 6 -w guppy%d --mem=16G -p gpu %s &\n' % (machine, s))
73 |
74 |
75 | if __name__ == '__main__':
76 | scripts = grid_search(args)
77 | gen_script(scripts, args.machine, args)
78 |
79 |
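For K-FAC/EK-FAC, `grid_search` expands a full Cartesian product: 3 learning rates × 3 dampings × 5 weight decays = 45 commands. The SGD branch gives 3 × 5 = 15. The same expansion can be written with `itertools.product`, which iterates in the same order as the nested loops (learning rate slowest, weight decay fastest). This is a sketch with a hypothetical `kfac_commands` helper. It also drops the stray double space that the original template puts before `--epoch`.

```python
import itertools

LRS = [3e-2, 1e-2, 3e-3]
DAMPINGS = [3e-2, 1e-3, 3e-3]
WDS = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]

TEMPLATE = ('python main.py --dataset %s --optimizer %s --network %s '
            '--epoch 100 --milestone 40,80 --learning_rate %f --damping %f '
            '--weight_decay %f %s')


def kfac_commands(dataset, optimizer, network, extra_args):
    """Expand the lr x damping x weight-decay grid into main.py command lines."""
    return [TEMPLATE % (dataset, optimizer, network, lr, dmp, wd, extra_args)
            for lr, dmp, wd in itertools.product(LRS, DAMPINGS, WDS)]


cmds = kfac_commands('cifar10', 'kfac', 'resnet', '--depth 110')
print(len(cmds))
print(cmds[0])
```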
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 |
5 |
6 | def get_transforms(dataset):
7 | transform_train = None
8 | transform_test = None
9 | if dataset == 'cifar10':
10 | transform_train = transforms.Compose([
11 | transforms.RandomCrop(32, padding=4),
12 | transforms.RandomHorizontalFlip(),
13 | transforms.ToTensor(),
14 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
15 | ])
16 |
17 | transform_test = transforms.Compose([
18 | transforms.ToTensor(),
19 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
20 | ])
21 |
22 | if dataset == 'cifar100':
23 | transform_train = transforms.Compose([
24 | transforms.RandomCrop(32, padding=4),
25 | transforms.RandomHorizontalFlip(),
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
28 | ])
29 |
30 | transform_test = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
33 | ])
34 |
35 | assert transform_test is not None and transform_train is not None, 'Error, no dataset %s' % dataset
36 | return transform_train, transform_test
37 |
38 |
39 | def get_dataloader(dataset, train_batch_size, test_batch_size, num_workers=2, root='../data'):
40 | transform_train, transform_test = get_transforms(dataset)
41 | trainset, testset = None, None
42 | if dataset == 'cifar10':
43 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
44 | testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
45 |
46 | if dataset == 'cifar100':
47 | trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=transform_train)
48 | testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=transform_test)
49 |
50 |
51 | assert trainset is not None and testset is not None, 'Error, no dataset %s' % dataset
52 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True,
53 | num_workers=num_workers)
54 | testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False,
55 | num_workers=num_workers)
56 |
57 | return trainloader, testloader
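`transforms.ToTensor()` maps an HWC `uint8` image to a CHW float tensor in [0, 1]. `transforms.Normalize(mean, std)` then computes `(x - mean[c]) / std[c]` for each channel `c`. Below is a NumPy sketch of the CIFAR-10 test-time pipeline, using the constants from `get_transforms`. The helper `to_tensor_and_normalize` is hypothetical and only mirrors what the two torchvision transforms do.

```python
import numpy as np

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def to_tensor_and_normalize(img_hwc_uint8, mean, std):
    """NumPy stand-in for Compose([ToTensor(), Normalize(mean, std)])."""
    x = img_hwc_uint8.astype(np.float32) / 255.0       # ToTensor: scale to [0, 1]
    x = x.transpose(2, 0, 1)                           # ToTensor: HWC -> CHW
    mean = np.asarray(mean, dtype=np.float32)[:, None, None]
    std = np.asarray(std, dtype=np.float32)[:, None, None]
    return (x - mean) / std                            # Normalize: per channel


white = np.full((32, 32, 3), 255, dtype=np.uint8)
out = to_tensor_and_normalize(white, CIFAR10_MEAN, CIFAR10_STD)
print(out.shape, out[:, 0, 0])
```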
--------------------------------------------------------------------------------
/utils/kfac_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def try_contiguous(x):
7 | if not x.is_contiguous():
8 | x = x.contiguous()
9 |
10 | return x
11 |
12 |
13 | def _extract_patches(x, kernel_size, stride, padding):
14 | """
15 | :param x: The input feature maps. (batch_size, in_c, h, w)
16 | :param kernel_size: the kernel size of the conv filter (tuple of two elements)
17 | :param stride: the stride of conv operation (tuple of two elements)
18 | :param padding: number of paddings. be a tuple of two elements
19 | :return: (batch_size, out_h, out_w, in_c*kh*kw)
20 | """
21 | if padding[0] + padding[1] > 0:
22 | x = F.pad(x, (padding[1], padding[1], padding[0],
23 | padding[0])).data # Actually check dims
24 | x = x.unfold(2, kernel_size[0], stride[0])
25 | x = x.unfold(3, kernel_size[1], stride[1])
26 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
27 | x = x.view(
28 | x.size(0), x.size(1), x.size(2),
29 | x.size(3) * x.size(4) * x.size(5))
30 | return x
31 |
32 |
33 | def update_running_stat(aa, m_aa, stat_decay):
34 | # using inplace operation to save memory!
35 | m_aa *= stat_decay / (1 - stat_decay)
36 | m_aa += aa
37 | m_aa *= (1 - stat_decay)
38 |
39 |
40 | class ComputeMatGrad:
41 |
42 | @classmethod
43 | def __call__(cls, input, grad_output, layer):
44 | if isinstance(layer, nn.Linear):
45 | grad = cls.linear(input, grad_output, layer)
46 | elif isinstance(layer, nn.Conv2d):
47 | grad = cls.conv2d(input, grad_output, layer)
48 | else:
49 | raise NotImplementedError
50 | return grad
51 |
52 | @staticmethod
53 | def linear(input, grad_output, layer):
54 | """
55 | :param input: batch_size * input_dim
56 | :param grad_output: batch_size * output_dim
57 | :param layer: [nn.module] output_dim * input_dim
58 | :return: batch_size * output_dim * (input_dim + [1 if with bias])
59 | """
60 | with torch.no_grad():
61 | if layer.bias is not None:
62 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1)
63 | input = input.unsqueeze(1)
64 | grad_output = grad_output.unsqueeze(2)
65 | grad = torch.bmm(grad_output, input)
66 | return grad
67 |
68 | @staticmethod
69 | def conv2d(input, grad_output, layer):
70 | """
71 | :param input: batch_size * in_c * in_h * in_w
72 | :param grad_output: batch_size * out_c * h * w
73 | :param layer: the nn.Conv2d module
74 | :return: batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
75 | """
76 | with torch.no_grad():
77 | input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding)
78 | input = input.view(-1, input.size(-1)) # (b*hw) * (in_c*kh*kw)
79 | grad_output = grad_output.transpose(1, 2).transpose(2, 3)
80 | grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1))
81 | # b * hw * out_c
82 | if layer.bias is not None:
83 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1)
84 | input = input.view(grad_output.size(0), -1, input.size(-1)) # b * hw * in_c*kh*kw
85 | grad = torch.einsum('abm,abn->amn', grad_output, input)
86 | return grad
87 |
88 |
89 | class ComputeCovA:
90 |
91 | @classmethod
92 | def compute_cov_a(cls, a, layer):
93 | return cls.__call__(a, layer)
94 |
95 | @classmethod
96 | def __call__(cls, a, layer):
97 | if isinstance(layer, nn.Linear):
98 | cov_a = cls.linear(a, layer)
99 | elif isinstance(layer, nn.Conv2d):
100 | cov_a = cls.conv2d(a, layer)
101 | else:
102 | # FIXME(CW): for extension to other layers.
103 | # raise NotImplementedError
104 | cov_a = None
105 |
106 | return cov_a
107 |
108 | @staticmethod
109 | def conv2d(a, layer):
110 | batch_size = a.size(0)
111 | a = _extract_patches(a, layer.kernel_size, layer.stride, layer.padding)
112 | spatial_size = a.size(1) * a.size(2)
113 | a = a.view(-1, a.size(-1))
114 | if layer.bias is not None:
115 | a = torch.cat([a, a.new(a.size(0), 1).fill_(1)], 1)
116 | a = a/spatial_size
117 | # FIXME(CW): do we need to divide the output feature map's size?
118 | return a.t() @ (a / batch_size)
119 |
120 | @staticmethod
121 | def linear(a, layer):
122 | # a: batch_size * in_dim
123 | batch_size = a.size(0)
124 | if layer.bias is not None:
125 | a = torch.cat([a, a.new(a.size(0), 1).fill_(1)], 1)
126 | return a.t() @ (a / batch_size)
127 |
128 |
129 | class ComputeCovG:
130 |
131 | @classmethod
132 | def compute_cov_g(cls, g, layer, batch_averaged=False):
133 | """
134 | :param g: gradient
135 | :param layer: the corresponding layer
136 | :param batch_averaged: whether the gradient has already been averaged over the batch
137 | :return: the covariance of the output gradients (the G factor)
138 | """
139 | # batch_size = g.size(0)
140 | return cls.__call__(g, layer, batch_averaged)
141 |
142 | @classmethod
143 | def __call__(cls, g, layer, batch_averaged):
144 | if isinstance(layer, nn.Conv2d):
145 | cov_g = cls.conv2d(g, layer, batch_averaged)
146 | elif isinstance(layer, nn.Linear):
147 | cov_g = cls.linear(g, layer, batch_averaged)
148 | else:
149 | cov_g = None
150 |
151 | return cov_g
152 |
153 | @staticmethod
154 | def conv2d(g, layer, batch_averaged):
155 | # g: batch_size * n_filters * out_h * out_w
156 | # n_filters is actually the output dimension (analogous to Linear layer)
157 | spatial_size = g.size(2) * g.size(3)
158 | batch_size = g.shape[0]
159 | g = g.transpose(1, 2).transpose(2, 3)
160 | g = try_contiguous(g)
161 | g = g.view(-1, g.size(-1))
162 |
163 | if batch_averaged:
164 | g = g * batch_size
165 | g = g * spatial_size
166 | cov_g = g.t() @ (g / g.size(0))
167 |
168 | return cov_g
169 |
170 | @staticmethod
171 | def linear(g, layer, batch_averaged):
172 | # g: batch_size * out_dim
173 | batch_size = g.size(0)
174 |
175 | if batch_averaged:
176 | cov_g = g.t() @ (g * batch_size)
177 | else:
178 | cov_g = g.t() @ (g / batch_size)
179 | return cov_g
180 |
181 |
182 |
183 | if __name__ == '__main__':
184 | def test_ComputeCovA():
185 | pass
186 |
187 | def test_ComputeCovG():
188 | pass
189 |
190 |
191 |
192 |
193 |
194 |
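`_extract_patches` turns a convolution into a matrix product. Each output location gets one row of `in_c*kh*kw` input values, in the same order as `weight.view(out_c, -1)`. `ComputeMatGrad.conv2d` then forms per-example weight gradients with `einsum('abm,abn->amn', ...)`. The NumPy sketch below mirrors both (without the bias column), using `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20+) in place of `Tensor.unfold`. It checks the batch sum against a direct loop over output positions. The function names are illustrative.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(x, kernel_size, stride, padding):
    """(batch, in_c, h, w) -> (batch, out_h, out_w, in_c*kh*kw), like _extract_patches."""
    (kh, kw), (sh, sw), (ph, pw) = kernel_size, stride, padding
    x = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))    # zero padding
    win = sliding_window_view(x, (kh, kw), axis=(2, 3))    # (b, c, h', w', kh, kw)
    win = win[:, :, ::sh, ::sw]                             # keep strided positions
    win = win.transpose(0, 2, 3, 1, 4, 5)                  # (b, out_h, out_w, c, kh, kw)
    return win.reshape(*win.shape[:3], -1)


def per_example_conv_grads(x, grad_out, kernel_size, stride, padding):
    """(batch, out_c, in_c*kh*kw) weight gradients, like ComputeMatGrad.conv2d without bias."""
    b, out_c = grad_out.shape[:2]
    patches = extract_patches(x, kernel_size, stride, padding)
    patches = patches.reshape(b, -1, patches.shape[-1])         # (b, hw, in_c*kh*kw)
    g = grad_out.transpose(0, 2, 3, 1).reshape(b, -1, out_c)    # (b, hw, out_c)
    return np.einsum("abm,abn->amn", g, patches)


def conv_weight_grad_loop(x, grad_out, kernel_size, stride, padding):
    """Reference dL/dW summed over the batch, by looping over output positions."""
    (kh, kw), (sh, sw), (ph, pw) = kernel_size, stride, padding
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    b, out_c, out_h, out_w = grad_out.shape
    dW = np.zeros((out_c, x.shape[1], kh, kw))
    for n in range(b):
        for i in range(out_h):
            for j in range(out_w):
                window = xp[n, :, i * sh:i * sh + kh, j * sw:j * sw + kw]  # (in_c, kh, kw)
                dW += grad_out[n, :, i, j][:, None, None, None] * window
    return dW


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5, 5))
k, s, p = (3, 3), (2, 2), (1, 1)
grad_out = rng.standard_normal((2, 4, 3, 3))   # out_h = out_w = (5 + 2 - 3) // 2 + 1 = 3

per_example = per_example_conv_grads(x, grad_out, k, s, p)
reference = conv_weight_grad_loop(x, grad_out, k, s, p).reshape(4, -1)
print(per_example.shape, np.allclose(per_example.sum(axis=0), reference))
```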
195 |
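For a `Linear` layer, `ComputeCovA.linear` appends a constant 1 to each input, so the bias is handled as an extra weight column. `ComputeCovG.linear` with `batch_averaged=True` undoes the `1/batch_size` factor that a mean-reduced loss puts into each sample's output gradient. `update_running_stat` then blends the new estimate into the running factor in place. Below is a standalone NumPy re-implementation of the three pieces; the names mirror the repo's, but this is not the repo's code. The in-place update equals `stat_decay * running + (1 - stat_decay) * new`. Because it divides by `1 - stat_decay`, it requires `stat_decay < 1`.

```python
import numpy as np


def cov_a_linear(a, has_bias=True):
    """A factor: E[a a^T] over the batch, with a ones column appended for the bias."""
    batch_size = a.shape[0]
    if has_bias:
        a = np.concatenate([a, np.ones((batch_size, 1))], axis=1)
    return a.T @ (a / batch_size)


def cov_g_linear(g, batch_averaged=True):
    """G factor from output gradients g of shape [batch, out_dim]."""
    batch_size = g.shape[0]
    if batch_averaged:
        # each row of g carries a 1/batch_size factor from the mean loss; undo it
        return g.T @ (g * batch_size)
    return g.T @ (g / batch_size)


def update_running_stat(new, running, stat_decay):
    """In place: running <- stat_decay * running + (1 - stat_decay) * new (needs stat_decay < 1)."""
    running *= stat_decay / (1 - stat_decay)
    running += new
    running *= (1 - stat_decay)


rng = np.random.default_rng(0)
batch = 8
a = rng.standard_normal((batch, 3))              # layer inputs
g_per_sample = rng.standard_normal((batch, 2))   # d(loss_i)/d(output_i)
g_mean_loss = g_per_sample / batch               # what backward yields for a mean-reduced loss

A = cov_a_linear(a)
G = cov_g_linear(g_mean_loss, batch_averaged=True)

running = np.eye(2)
expected = 0.95 * np.eye(2) + 0.05 * G
update_running_stat(G, running, stat_decay=0.95)
print(np.allclose(running, expected))
```

The last row and column of `A` hold the mean input (and a 1 in the corner), because of the appended ones column.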
--------------------------------------------------------------------------------
/utils/network_utils.py:
--------------------------------------------------------------------------------
1 | from models.cifar import (alexnet, densenet, resnet,
2 | vgg16_bn, vgg19_bn,
3 | wrn)
4 |
5 |
6 | def get_network(network, **kwargs):
7 | networks = {
8 | 'alexnet': alexnet,
9 | 'densenet': densenet,
10 | 'resnet': resnet,
11 | 'vgg16_bn': vgg16_bn,
12 | 'vgg19_bn': vgg19_bn,
13 | 'wrn': wrn
14 |
15 | }
16 |
17 | return networks[network](**kwargs)
18 |
19 |
--------------------------------------------------------------------------------