├── hypergrad
│   ├── __init__.py
│   ├── adam_hd.py
│   └── sgd_hd.py
├── plots
│   ├── mlp.pdf
│   ├── mlp.png
│   ├── vgg.pdf
│   ├── vgg.png
│   ├── logreg.pdf
│   └── logreg.png
├── poster
│   ├── iclr_2018_poster.pdf
│   └── iclr_2018_poster.png
├── setup.py
├── LICENSE
├── .gitignore
├── run.sh
├── vgg.py
├── plot.py
├── README.md
└── train.py
/hypergrad/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam_hd import AdamHD
2 | from .sgd_hd import SGDHD
3 |
--------------------------------------------------------------------------------
/plots/mlp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/mlp.pdf
--------------------------------------------------------------------------------
/plots/mlp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/mlp.png
--------------------------------------------------------------------------------
/plots/vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/vgg.pdf
--------------------------------------------------------------------------------
/plots/vgg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/vgg.png
--------------------------------------------------------------------------------
/plots/logreg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/logreg.pdf
--------------------------------------------------------------------------------
/plots/logreg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/logreg.png
--------------------------------------------------------------------------------
/poster/iclr_2018_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/poster/iclr_2018_poster.pdf
--------------------------------------------------------------------------------
/poster/iclr_2018_poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/poster/iclr_2018_poster.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | setuptools.setup(
7 | name="hypergrad",
8 | version="0.1",
9 | author="Atılım Güneş Baydin",
10 | author_email="",
11 | description="Hypergradient descent",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/gbaydin/hypergradient-descent",
15 | packages=setuptools.find_packages(),
16 | classifiers=[
17 | "Programming Language :: Python :: 3",
18 | "License :: OSI Approved :: MIT License",
19 | "Operating System :: OS Independent",
20 | ],
21 | )
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Atılım Güneş Baydin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Training data and results
2 | data/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # dotenv
86 | .env
87 |
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train.py --cuda --model logreg --method sgd --save --epochs 10 --alpha_0 0.001 --beta 0.001
4 | python train.py --cuda --model logreg --method sgd_hd --save --epochs 10 --alpha_0 0.001 --beta 0.001
5 | python train.py --cuda --model logreg --method sgdn --save --epochs 10 --alpha_0 0.001 --beta 0.001
6 | python train.py --cuda --model logreg --method sgdn_hd --save --epochs 10 --alpha_0 0.001 --beta 0.001
7 | python train.py --cuda --model logreg --method adam --save --epochs 10 --alpha_0 0.001 --beta 1e-7
8 | python train.py --cuda --model logreg --method adam_hd --save --epochs 10 --alpha_0 0.001 --beta 1e-7
9 |
10 | python train.py --cuda --model mlp --method sgd --save --epochs 100 --alpha_0 0.001 --beta 0.001
11 | python train.py --cuda --model mlp --method sgd_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001
12 | python train.py --cuda --model mlp --method sgdn --save --epochs 100 --alpha_0 0.001 --beta 0.001
13 | python train.py --cuda --model mlp --method sgdn_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001
14 | python train.py --cuda --model mlp --method adam --save --epochs 100 --alpha_0 0.001 --beta 1e-7
15 | python train.py --cuda --model mlp --method adam_hd --save --epochs 100 --alpha_0 0.001 --beta 1e-7
16 |
17 | python train.py --cuda --model vgg --method sgd --save --epochs 100 --alpha_0 0.001 --beta 0.001
18 | python train.py --cuda --model vgg --method sgd_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001
19 | python train.py --cuda --model vgg --method sgdn --save --epochs 100 --alpha_0 0.001 --beta 0.001
20 | python train.py --cuda --model vgg --method sgdn_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001
21 | python train.py --cuda --model vgg --method adam --save --epochs 100 --alpha_0 0.001 --beta 1e-8
22 | python train.py --cuda --model vgg --method adam_hd --save --epochs 100 --alpha_0 0.001 --beta 1e-8
23 |
--------------------------------------------------------------------------------
/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://github.com/pytorch/vision.git
3 | '''
4 | import math
5 |
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | class VGG(nn.Module):
16 | '''
17 | VGG model
18 | '''
19 | def __init__(self, features):
20 | super(VGG, self).__init__()
21 | self.features = features
22 | self.classifier = nn.Sequential(
23 | nn.Dropout(),
24 | nn.Linear(512, 512),
25 | nn.ReLU(True),
26 | nn.Dropout(),
27 | nn.Linear(512, 512),
28 | nn.ReLU(True),
29 | nn.Linear(512, 10),
30 | )
31 | # Initialize weights
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
35 | m.weight.data.normal_(0, math.sqrt(2. / n))
36 | m.bias.data.zero_()
37 |
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = x.view(x.size(0), -1)
42 | x = self.classifier(x)
43 | return x
44 |
45 |
46 | def make_layers(cfg, batch_norm=False):
47 | layers = []
48 | in_channels = 3
49 | for v in cfg:
50 | if v == 'M':
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | if batch_norm:
55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
56 | else:
57 | layers += [conv2d, nn.ReLU(inplace=True)]
58 | in_channels = v
59 | return nn.Sequential(*layers)
60 |
61 |
62 | cfg = {
63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
67 | 512, 512, 512, 512, 'M'],
68 | }
69 |
70 |
71 | def vgg11():
72 | """VGG 11-layer model (configuration "A")"""
73 | return VGG(make_layers(cfg['A']))
74 |
75 |
76 | def vgg11_bn():
77 | """VGG 11-layer model (configuration "A") with batch normalization"""
78 | return VGG(make_layers(cfg['A'], batch_norm=True))
79 |
80 |
81 | def vgg13():
82 | """VGG 13-layer model (configuration "B")"""
83 | return VGG(make_layers(cfg['B']))
84 |
85 |
86 | def vgg13_bn():
87 | """VGG 13-layer model (configuration "B") with batch normalization"""
88 | return VGG(make_layers(cfg['B'], batch_norm=True))
89 |
90 |
91 | def vgg16():
92 | """VGG 16-layer model (configuration "D")"""
93 | return VGG(make_layers(cfg['D']))
94 |
95 |
96 | def vgg16_bn():
97 | """VGG 16-layer model (configuration "D") with batch normalization"""
98 | return VGG(make_layers(cfg['D'], batch_norm=True))
99 |
100 |
101 | def vgg19():
102 | """VGG 19-layer model (configuration "E")"""
103 | return VGG(make_layers(cfg['E']))
104 |
105 |
106 | def vgg19_bn():
107 | """VGG 19-layer model (configuration "E") with batch normalization"""
108 | return VGG(make_layers(cfg['E'], batch_norm=True))
109 |
--------------------------------------------------------------------------------
/hypergrad/adam_hd.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class AdamHD(Optimizer):
7 | """Implements the Adam algorithm with hypergradient descent of the learning rate (Adam-HD).
8 |
9 | The base algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
10 |
11 | Arguments:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float, optional): learning rate (default: 1e-3)
15 | betas (Tuple[float, float], optional): coefficients used for computing
16 | running averages of gradient and its square (default: (0.9, 0.999))
17 | eps (float, optional): term added to the denominator to improve
18 | numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | hypergrad_lr (float, optional): hypergradient learning rate for the online
21 | tuning of the learning rate, introduced in the paper
22 | `Online Learning Rate Adaptation with Hypergradient Descent`_
23 |
24 | .. _Adam\: A Method for Stochastic Optimization:
25 | https://arxiv.org/abs/1412.6980
26 | .. _Online Learning Rate Adaptation with Hypergradient Descent:
27 | https://openreview.net/forum?id=BkrsAzWAb
28 | """
29 |
30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
31 | weight_decay=0, hypergrad_lr=1e-8):
32 | defaults = dict(lr=lr, betas=betas, eps=eps,
33 | weight_decay=weight_decay, hypergrad_lr=hypergrad_lr)
34 | super(AdamHD, self).__init__(params, defaults)
35 |
36 | def step(self, closure=None):
37 | """Performs a single optimization step.
38 |
39 | Arguments:
40 | closure (callable, optional): A closure that reevaluates the model
41 | and returns the loss.
42 | """
43 | loss = None
44 | if closure is not None:
45 | loss = closure()
46 |
47 | for group in self.param_groups:
48 | for p in group['params']:
49 | if p.grad is None:
50 | continue
51 | grad = p.grad.data
52 | if grad.is_sparse:
53 | raise RuntimeError('AdamHD does not support sparse gradients, please consider SparseAdam instead')
54 |
55 | state = self.state[p]
56 |
57 | # State initialization
58 | if len(state) == 0:
59 | state['step'] = 0
60 | # Exponential moving average of gradient values
61 | state['exp_avg'] = torch.zeros_like(p.data)
62 | # Exponential moving average of squared gradient values
63 | state['exp_avg_sq'] = torch.zeros_like(p.data)
64 |
65 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
66 | beta1, beta2 = group['betas']
67 |
68 | state['step'] += 1
69 |
70 | if group['weight_decay'] != 0:
71 | grad = grad.add(group['weight_decay'], p.data)
72 |
73 | if state['step'] > 1:
74 | prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1)
75 | prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1)
76 | # Hypergradient for Adam:
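# h = grad_t · (exp_avg / (sqrt(exp_avg_sq) + eps)), rescaled with the *previous*
# step's bias corrections, i.e. the dot product of the current gradient with the
# previous Adam update direction. Since dL/d(lr) = -h, the update below,
# lr += hypergrad_lr * h, is gradient descent on the learning rate.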
77 | h = torch.dot(grad.view(-1), torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps'])).view(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1
78 | # Hypergradient descent of the learning rate:
79 | group['lr'] += group['hypergrad_lr'] * h
80 |
81 | # Decay the first and second moment running average coefficient
82 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
83 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
84 | denom = exp_avg_sq.sqrt().add_(group['eps'])
85 |
86 | bias_correction1 = 1 - beta1 ** state['step']
87 | bias_correction2 = 1 - beta2 ** state['step']
88 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
89 |
90 | p.data.addcdiv_(-step_size, exp_avg, denom)
91 |
92 | return loss
93 |
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import argparse
4 | import csv
5 | import os
6 | import glob
7 | import matplotlib
8 | # Force matplotlib to not use any Xwindows backend.
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | from matplotlib.ticker import MaxNLocator
12 | from mpl_toolkits.axes_grid1.inset_locator import inset_axes
13 |
14 | colorblindbright = [(252,145,100),(188,56,119),(114,27,166)]
15 | colorblinddim = [(213,167,103),(163,85,114),(104,59,130)]
16 | for i in range(len(colorblindbright)):
17 | r, g, b = colorblindbright[i]
18 | colorblindbright[i] = (r / 255., g / 255., b / 255.)
19 | for i in range(len(colorblinddim)):
20 | r, g, b = colorblinddim[i]
21 | colorblinddim[i] = (r / 255., g / 255., b / 255.)
22 |
23 | colors = {'sgd':colorblinddim[0], 'sgdn':colorblinddim[1],'adam':colorblinddim[2], \
24 | 'sgd_hd':colorblindbright[0], 'sgdn_hd':colorblindbright[1],'adam_hd':colorblindbright[2]}
25 | names = {'sgd':'SGD','sgdn':'SGDN','adam':'Adam','sgd_hd':'SGD-HD','sgdn_hd':'SGDN-HD','adam_hd':'Adam-HD'}
26 | linestyles = {'sgd':'--','sgdn':'--','adam':'--','sgd_hd':'-','sgdn_hd':'-','adam_hd':'-'}
27 | linedashes = {'sgd':[3,3],'sgdn':[3,3],'adam':[3,3],'sgd_hd':[10,1e-9],'sgdn_hd':[10,1e-9],'adam_hd':[10,1e-9]}
28 |
29 | parser = argparse.ArgumentParser(description='Plotting for hypergradient descent PyTorch tests', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 | parser.add_argument('--dir', help='directory to read the csv files written by train.py', default='results', type=str)
31 | parser.add_argument('--plotDir', help='directory to save the plots', default='plots', type=str)
32 | opt = parser.parse_args()
33 |
34 | os.makedirs(opt.plotDir, exist_ok=True)
35 |
36 | model_titles = {'logreg': 'Logistic regression (on MNIST)', 'mlp': 'Multi-layer neural network (on MNIST)', 'vgg': 'VGG Net (on CIFAR-10)'}
37 | for model in next(os.walk(opt.dir))[1]:
38 | data = {}
39 | data_epoch = {}
40 | selected = []
41 | for fname in glob.glob('{}/{}/**/*.csv'.format(opt.dir, model), recursive=True):
42 | name = os.path.splitext(os.path.basename(fname))[0]
43 | data[name] = pd.read_csv(fname)
44 | data_epoch[name] = data[name][pd.notna(data[name].LossEpoch)]
45 | selected.append(name)
46 |
47 | fig = plt.figure(figsize=(5, 12))
50 | ax = fig.add_subplot(311)
51 | for name in selected:
52 | plt.plot(data_epoch[name].Epoch,data_epoch[name].AlphaEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
53 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
54 | plt.ylabel('Learning rate')
55 | plt.tick_params(labeltop=False, labelbottom=False, bottom=False, top=False, labelright=False)
56 | plt.grid()
57 | plt.title(model_titles[model])
58 | inset_axes(ax, width="50%", height="35%", loc=1)
59 | for name in selected:
60 | plt.plot(data[name].Iteration, data[name].Alpha,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
61 | plt.yticks(np.arange(-0.01, 0.051, 0.01))
62 | plt.xlabel('Iteration')
63 | plt.ylabel('Learning rate')
64 | plt.xscale('log')
65 | plt.xlim([0,9000])
66 | plt.grid()
67 |
68 | ax = fig.add_subplot(312)
69 | for name in selected:
70 | plt.plot(data_epoch[name].Epoch, data_epoch[name].LossEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
71 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
72 | plt.ylabel('Training loss')
73 | plt.yscale('log')
74 | plt.tick_params(labeltop=False, labelbottom=False, bottom=False, top=False, labelright=False)
75 | plt.grid()
76 | inset_axes(ax, width="50%", height="35%", loc=1)
77 | for name in selected:
78 | plt.plot(data[name].Iteration, data[name].Loss,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
79 | plt.yticks(np.arange(0, 2.01, 0.5))
80 | plt.xlabel('Iteration')
81 | plt.ylabel('Training loss')
82 | plt.xscale('log')
83 | plt.xlim([0,9000])
84 | plt.grid()
85 |
86 | ax = fig.add_subplot(313)
87 | for name in selected:
88 | plt.plot(data_epoch[name].Epoch, data_epoch[name].ValidLossEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
89 | plt.xlabel('Epoch')
90 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
91 | plt.ylabel('Validation loss')
92 | plt.yscale('log')
93 | handles, labels = plt.gca().get_legend_handles_labels()
94 | labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
95 | plt.legend(handles,labels,loc='upper right',frameon=1,framealpha=1,edgecolor='black',fancybox=False)
96 | plt.grid()
97 |
98 | plt.tight_layout()
99 | plt.savefig('{}/{}.pdf'.format(opt.plotDir, model), bbox_inches='tight')
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hypergradient-descent
2 | This is the [PyTorch](http://pytorch.org/) code for the paper [_Online Learning Rate Adaptation with Hypergradient Descent_](https://openreview.net/forum?id=BkrsAzWAb) at ICLR 2018.
3 |
4 | A [TensorFlow](https://www.tensorflow.org/) version is also planned and should appear in this repo at a later time.
5 |
6 |
7 |
8 | ## What is a "hypergradient"?
9 |
10 | In gradient-based optimization, one optimizes an objective function by using its derivatives (gradient) with respect to model parameters. In addition to this basic gradient, a _hypergradient_ is the derivative of the same objective function with respect to the optimization procedure's hyperparameters (such as the learning rate, momentum, or regularization parameters). There can be many types of hypergradients, and in this work we're interested in the hypergradient with respect to a scalar learning rate.
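
As a toy illustration (not the repository's optimizer code), consider plain SGD with the update θ_t = θ_{t-1} − α ∇f(θ_{t-1}). The hypergradient of the objective with respect to α is ∂f(θ_t)/∂α = −∇f(θ_t) · ∇f(θ_{t-1}), so hypergradient descent adapts the learning rate using the dot product of two consecutive gradients. A minimal NumPy sketch of this rule, for an assumed gradient function `grad_f`, looks like

```python
import numpy as np

def toy_sgd_hd(theta, grad_f, alpha=0.0, beta=1e-4, steps=100):
    """Toy SGD-HD loop; beta is the hypergradient learning rate."""
    g_prev = np.zeros_like(theta)
    for _ in range(steps):
        g = grad_f(theta)
        alpha += beta * np.dot(g, g_prev)  # hypergradient update of the learning rate
        theta = theta - alpha * g          # ordinary SGD step with the adapted alpha
        g_prev = g
    return theta, alpha

# example: minimize a simple quadratic, starting from a zero learning rate
theta, alpha = toy_sgd_hd(np.array([1.0, -2.0]), grad_f=lambda t: t)
```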
11 |
12 | ## Installation
13 |
14 | ```bash
15 | pip install git+https://github.com/gbaydin/hypergradient-descent.git
16 | ```
17 |
18 | ## How can I use it for my work?
19 |
20 | We provide ready-to-use implementations of the hypergradient versions of the SGD (with or without momentum) and Adam optimizers for PyTorch. They comply with the `torch.optim` API and can be used as drop-in replacements in your code. Install the package as shown above (or simply copy the `sgd_hd.py` and `adam_hd.py` files from this repo) and import the optimizers like
21 |
22 | ```python
23 | from hypergrad import SGDHD, AdamHD
24 | ...
25 |
26 | optimizer = AdamHD(model.parameters(), lr=args.lr, hypergrad_lr=1e-8)
27 | ...
28 | ```
29 |
30 | The optimizers introduce an extra argument `hypergrad_lr` which determines the hypergradient learning rate, that is, the learning rate used to optimize the regular learning rate `lr`. The value you give to `lr` sets the initial value of the regular learning rate, from which it will be adapted in an online fashion by the hypergradient descent procedure. Lower values for `hypergrad_lr` are safer because (1) the resulting updates to `lr` are smaller; and (2) one recovers the non-hypergradient version of the algorithm as `hypergrad_lr` approaches zero.
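
The SGD variants are used analogously; a sketch (assuming `model` exists, and mirroring the constructor arguments in `sgd_hd.py` and the values used in `run.sh`):

```python
from hypergrad import SGDHD

# SGD-HD with Nesterov momentum (the SGDN-HD variant in the paper)
optimizer = SGDHD(model.parameters(), lr=0.001, momentum=0.9,
                  nesterov=True, hypergrad_lr=0.001)
```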
31 |
32 | Don't worry that, instead of having to tune just one learning rate (`lr`), you now have to tune two (`lr` and `hypergrad_lr`); just see the next section.
33 |
34 | ## What is the advantage?
35 | Hypergradient algorithms are much less sensitive to the choice of the initial learning rate (`lr`) than their non-hypergradient counterparts. Given a small `hypergrad_lr`, which can either be left at the recommended default or tuned, a hypergradient algorithm requires significantly less tuning to perform better than, or in the worst case the same as, its non-hypergradient baseline. Please see the paper for guideline values of `hypergrad_lr`.
36 |
37 | In practice, you might be surprised to see that **even starting with a zero learning rate works**: the learning rate is quickly raised to a useful non-zero level as needed, and then decays towards zero as optimization converges:
38 | ```python
39 | optimizer = AdamHD(model.parameters(), lr=0, hypergrad_lr=1e-8)
40 | ```
41 |
42 | If you would like to monitor the evolution of the learning rate during optimization, you can read it from the optimizer's parameter group with code that looks like
43 |
44 | ```python
45 | lr = optimizer.param_groups[0]['lr']
46 | print(lr)
47 | ```
48 |
49 | ## Notes about the code in this repository
50 | * The results in the paper were produced by code written in [(Lua)torch](http://torch.ch/). The code in this repo is a reimplementation in PyTorch, which produces results that are not exactly the same but qualitatively identical in the behavior of the learning rate, training and validation losses, and the relative performance of the algorithms.
51 | * In the .csv result files, `nan` is used as a placeholder for empty entries in epoch losses, which are only computed once per epoch.
52 | * The implementation in this repository doesn't include any heuristics to ensure that your gradients and hypergradients don't "explode". In practice, you might need to apply gradient clipping or safeguards for the updates to the learning rate to prevent bad behavior (a minimal sketch of such a safeguard follows this list). If this happens in your models, we would be interested in hearing about it: please let us know via email or a GitHub issue.
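
As an illustration only (the optimizers above do not do this for you), a minimal safeguard could clamp the adapted learning rate to a non-negative range after each step:

```python
optimizer.step()
# keep the adapted learning rate within [0, 1.0]; the upper bound is arbitrary,
# problem-dependent, and shown only as a sketch rather than a recommendation
for group in optimizer.param_groups:
    group['lr'] = min(max(group['lr'], 0.0), 1.0)
```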
53 |
54 | ## Some other implementations
55 |
56 | * (Lua)Torch optim implementation (not the main implementation for the paper and not well-tested): https://github.com/gbaydin/optim
57 | * C++ implementation in the Livermore Big Artificial Neural Network Toolkit, Lawrence Livermore National Laboratory: https://github.com/LLNL/lbann/blob/a778d2b764ba209042555aac26328cbfb8063802/src/optimizers/hypergradient_adam.cpp
58 | * TensorFlow implementation by Andrii Zadaianchuk, University of Tübingen: https://github.com/zadaianchuk/HyperGradientDescent
59 | * Java implementation in Apache Hivemall: https://github.com/apache/incubator-hivemall/blob/31932fd7c63f9bb21eba8959944d03f280b6deb9/core/src/main/java/hivemall/optimizer/Optimizer.java#L640
60 |
61 | ## Paper
62 | Atılım Güneş Baydin, Robert Cornish, David Martínez Rubio, Mark Schmidt, and Frank Wood. Online learning rate adaptation with hypergradient descent. In _Sixth International
63 | Conference on Learning Representations (ICLR), Vancouver, Canada, April 30 – May 3, 2018_.
64 |
65 | https://openreview.net/forum?id=BkrsAzWAb
66 |
67 | ```
68 | @inproceedings{baydin2018hypergradient,
69 | title = {Online Learning Rate Adaptation with Hypergradient Descent},
70 | author = {Baydin, Atılım Güneş and Cornish, Robert and Rubio, David Martínez and Schmidt, Mark and Wood, Frank},
71 | booktitle = {Sixth International Conference on Learning Representations (ICLR), Vancouver, Canada, April 30 -- May 3, 2018},
72 | year = {2018}
73 | }
74 | ```
75 |
76 | ## Poster
77 |
78 | [The ICLR 2018 poster](https://github.com/gbaydin/hypergradient-descent/raw/master/poster/iclr_2018_poster.pdf)
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/hypergrad/sgd_hd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import reduce
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 |
6 | class SGDHD(Optimizer):
7 | r"""Implements stochastic gradient descent (optionally with momentum), with online tuning of the learning rate via hypergradient descent (SGD-HD).
8 |
9 | Nesterov momentum is based on the formula from
10 | `On the importance of initialization and momentum in deep learning`__.
11 |
12 | Args:
13 | params (iterable): iterable of parameters to optimize or dicts defining
14 | parameter groups
15 | lr (float): learning rate
16 | momentum (float, optional): momentum factor (default: 0)
17 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
18 | dampening (float, optional): dampening for momentum (default: 0)
19 | nesterov (bool, optional): enables Nesterov momentum (default: False)
20 | hypergrad_lr (float, optional): hypergradient learning rate for the online
21 | tuning of the learning rate, introduced in the paper
22 | `Online Learning Rate Adaptation with Hypergradient Descent`_
23 |
24 | Example:
25 | >>> optimizer = SGDHD(model.parameters(), lr=0.1, momentum=0.9, hypergrad_lr=1e-6)
26 | >>> optimizer.zero_grad()
27 | >>> loss_fn(model(input), target).backward()
28 | >>> optimizer.step()
29 |
30 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
31 | .. _Online Learning Rate Adaptation with Hypergradient Descent:
32 | https://openreview.net/forum?id=BkrsAzWAb
33 | .. note::
34 | The implementation of SGD with Momentum/Nesterov subtly differs from
35 | Sutskever et. al. and implementations in some other frameworks.
36 |
37 | Considering the specific case of Momentum, the update can be written as
38 |
39 | .. math::
40 | v = \rho * v + g \\
41 | p = p - lr * v
42 |
43 | where p, g, v and :math:`\rho` denote the parameters, gradient,
44 | velocity, and momentum respectively.
45 |
46 | This is in contrast to Sutskever et. al. and
47 | other frameworks which employ an update of the form
48 |
49 | .. math::
50 | v = \rho * v + lr * g \\
51 | p = p - v
52 |
53 | The Nesterov version is analogously modified.
54 | """
55 |
56 | def __init__(self, params, lr=required, momentum=0, dampening=0,
57 | weight_decay=0, nesterov=False, hypergrad_lr=1e-6):
58 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
59 | weight_decay=weight_decay, nesterov=nesterov, hypergrad_lr=hypergrad_lr)
60 | if nesterov and (momentum <= 0 or dampening != 0):
61 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
62 | super(SGDHD, self).__init__(params, defaults)
63 |
64 | if len(self.param_groups) != 1:
65 | raise ValueError("SGDHD doesn't support per-parameter options (parameter groups)")
66 |
67 | self._params = self.param_groups[0]['params']
68 | self._params_numel = reduce(lambda total, p: total + p.numel(), self._params, 0)
69 |
70 | def _gather_flat_grad_with_weight_decay(self, weight_decay=0):
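# Concatenate the gradients of all parameters into a single flat vector (applying
# L2 weight decay on the fly), so that step() can compute the hypergradient with
# one dot product against the previous update direction.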
71 | views = []
72 | for p in self._params:
73 | if p.grad is None:
74 | view = torch.zeros_like(p.data)
75 | elif p.grad.data.is_sparse:
76 | view = p.grad.data.to_dense().view(-1)
77 | else:
78 | view = p.grad.data.view(-1)
79 | if weight_decay != 0:
80 | view.add_(weight_decay, p.data.view(-1))
81 | views.append(view)
82 | return torch.cat(views, 0)
83 |
84 | def _add_grad(self, step_size, update):
85 | offset = 0
86 | for p in self._params:
87 | numel = p.numel()
88 | # view as to avoid deprecated pointwise semantics
89 | p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
90 | offset += numel
91 | assert offset == self._params_numel
92 |
93 | def step(self, closure=None):
94 | """Performs a single optimization step.
95 |
96 | Arguments:
97 | closure (callable, optional): A closure that reevaluates the model
98 | and returns the loss.
99 | """
100 | assert len(self.param_groups) == 1
101 |
102 | loss = None
103 | if closure is not None:
104 | loss = closure()
105 |
106 | group = self.param_groups[0]
107 | weight_decay = group['weight_decay']
108 | momentum = group['momentum']
109 | dampening = group['dampening']
110 | nesterov = group['nesterov']
111 |
112 | grad = self._gather_flat_grad_with_weight_decay(weight_decay)
113 |
114 | # NOTE: SGDHD has only global state, but we register it as state for
115 | # the first param, because this helps with casting in load_state_dict
116 | state = self.state[self._params[0]]
117 | # State initialization
118 | if len(state) == 0:
119 | state['grad_prev'] = torch.zeros_like(grad)
120 |
121 | grad_prev = state['grad_prev']
122 | # Hypergradient for SGD
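# h = grad_t · u_{t-1}: the dot product of the current flattened gradient with the
# previous update direction stored in grad_prev. A positive h (the last step still
# agrees with the current gradient) increases the learning rate below; a negative
# h decreases it.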
123 | h = torch.dot(grad, grad_prev)
124 | # Hypergradient descent of the learning rate:
125 | group['lr'] += group['hypergrad_lr'] * h
126 |
127 | if momentum != 0:
128 | if 'momentum_buffer' not in state:
129 | buf = state['momentum_buffer'] = torch.zeros_like(grad)
130 | buf.mul_(momentum).add_(grad)
131 | else:
132 | buf = state['momentum_buffer']
133 | buf.mul_(momentum).add_(1 - dampening, grad)
134 | if nesterov:
135 | grad.add_(momentum, buf)
136 | else:
137 | grad = buf
138 |
139 | state['grad_prev'] = grad
140 |
141 | self._add_grad(-group['lr'], grad)
142 |
143 | return loss
144 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | import argparse
3 | import sys
4 | import os
5 | import csv
6 | import time
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | from torchvision import datasets, transforms
13 | import vgg
14 | from torch.utils.data import DataLoader
15 | from torch.optim import SGD, Adam
16 | from hypergrad import SGDHD, AdamHD
17 |
18 |
19 | class LogReg(nn.Module):
20 | def __init__(self, input_dim, output_dim):
21 | super(LogReg, self).__init__()
22 | self._input_dim = input_dim
23 | self.lin1 = nn.Linear(input_dim, output_dim)
24 |
25 | def forward(self, x):
26 | x = x.view(-1, self._input_dim)
27 | x = self.lin1(x)
28 | return x
29 |
30 |
31 | class MLP(nn.Module):
32 | def __init__(self, input_dim, hidden_dim, output_dim):
33 | super(MLP, self).__init__()
34 | self._input_dim = input_dim
35 | self.lin1 = nn.Linear(input_dim, hidden_dim)
36 | self.lin2 = nn.Linear(hidden_dim, hidden_dim)
37 | self.lin3 = nn.Linear(hidden_dim, output_dim)
38 |
39 | def forward(self, x):
40 | x = x.view(-1, self._input_dim)
41 | x = F.relu(self.lin1(x))
42 | x = F.relu(self.lin2(x))
43 | x = self.lin3(x)
44 | return x
45 |
46 |
47 | def train(opt, log_func=None):
48 | torch.manual_seed(opt.seed)
49 | if opt.cuda:
50 | torch.cuda.set_device(opt.device)
51 | torch.cuda.manual_seed(opt.seed)
52 | torch.backends.cudnn.enabled = True
53 |
54 | if opt.model == 'logreg':
55 | model = LogReg(28 * 28, 10)
56 | elif opt.model == 'mlp':
57 | model = MLP(28 * 28, 1000, 10)
58 | elif opt.model == 'vgg':
59 | model = vgg.vgg16_bn()
60 | if opt.parallel:
61 | model.features = torch.nn.DataParallel(model.features)
62 | else:
63 | raise Exception('Unknown model: {}'.format(opt.model))
64 |
65 | if opt.cuda:
66 | model = model.cuda()
67 |
68 | if opt.model == 'logreg' or opt.model == 'mlp':
69 | task = 'MNIST'
70 | train_loader = DataLoader(
71 | datasets.MNIST('./data', train=True, download=True,
72 | transform=transforms.Compose([
73 | transforms.ToTensor(),
74 | transforms.Normalize((0.1307,), (0.3081,))
75 | ])),
76 | batch_size=opt.batchSize, shuffle=True)
77 | valid_loader = DataLoader(
78 | datasets.MNIST('./data', train=False, transform=transforms.Compose([
79 | transforms.ToTensor(),
80 | transforms.Normalize((0.1307,), (0.3081,))
81 | ])),
82 | batch_size=opt.batchSize, shuffle=False)
83 | elif opt.model == 'vgg':
84 | task = 'CIFAR10'
85 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
86 | std=[0.229, 0.224, 0.225])
87 | train_loader = torch.utils.data.DataLoader(
88 | datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
89 | transforms.RandomHorizontalFlip(),
90 | transforms.RandomCrop(32, 4),
91 | transforms.ToTensor(),
92 | normalize,
93 | ]), download=True),
94 | batch_size=opt.batchSize, shuffle=True,
95 | num_workers=opt.workers, pin_memory=True)
96 |
97 | valid_loader = torch.utils.data.DataLoader(
98 | datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
99 | transforms.ToTensor(),
100 | normalize,
101 | ])),
102 | batch_size=opt.batchSize, shuffle=False,
103 | num_workers=opt.workers, pin_memory=True)
104 | else:
105 | raise Exception('Unknown model: {}'.format(opt.model))
106 |
107 | if opt.method == 'sgd':
108 | optimizer = SGD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay)
109 | elif opt.method == 'sgd_hd':
110 | optimizer = SGDHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, hypergrad_lr=opt.beta)
111 | elif opt.method == 'sgdn':
112 | optimizer = SGD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, momentum=opt.mu, nesterov=True)
113 | elif opt.method == 'sgdn_hd':
114 | optimizer = SGDHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, momentum=opt.mu, nesterov=True, hypergrad_lr=opt.beta)
115 | elif opt.method == 'adam':
116 | optimizer = Adam(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay)
117 | elif opt.method == 'adam_hd':
118 | optimizer = AdamHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, hypergrad_lr=opt.beta)
119 | else:
120 | raise Exception('Unknown method: {}'.format(opt.method))
121 |
122 | if not opt.silent:
123 | print('Task: {}, Model: {}, Method: {}'.format(task, opt.model, opt.method))
124 |
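# Log an "epoch 0" baseline before any training: the untrained model's loss on one
# training batch and on the full validation set, together with alpha_0 and beta.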
125 | model.eval()
126 | for batch_id, (data, target) in enumerate(train_loader):
127 | data, target = Variable(data), Variable(target)
128 | if opt.cuda:
129 | data, target = data.cuda(), target.cuda()
130 | output = model(data)
131 | loss = F.cross_entropy(output, target)
132 | loss = loss.data[0]
133 | break
134 | valid_loss = 0
135 | for data, target in valid_loader:
136 | data, target = Variable(data, volatile=True), Variable(target)
137 | if opt.cuda:
138 | data, target = data.cuda(), target.cuda()
139 | output = model(data)
140 | valid_loss += F.cross_entropy(output, target, size_average=False).data[0]
141 | valid_loss /= len(valid_loader.dataset)
142 | if log_func is not None:
143 | log_func(0, 0, 0, loss, loss, valid_loss, opt.alpha_0, opt.alpha_0, opt.beta)
144 |
145 | time_start = time.time()
146 | iteration = 1
147 | epoch = 1
148 | done = False
149 | while not done:
150 | model.train()
151 | loss_epoch = 0
152 | alpha_epoch = 0
153 | for batch_id, (data, target) in enumerate(train_loader):
154 | data, target = Variable(data), Variable(target)
155 | if opt.cuda:
156 | data, target = data.cuda(), target.cuda()
157 | optimizer.zero_grad()
158 | output = model(data)
159 | loss = F.cross_entropy(output, target)
160 | loss.backward()
161 | optimizer.step()
162 | loss = loss.data[0]
163 | loss_epoch += loss
164 | alpha = optimizer.param_groups[0]['lr']
165 | alpha_epoch += alpha
166 | iteration += 1
167 | if opt.iterations != 0:
168 | if iteration > opt.iterations:
169 | print('Early stopping: iteration > {}'.format(opt.iterations))
170 | done = True
171 | break
172 | if opt.lossThreshold > 0:
173 | if loss <= opt.lossThreshold:
174 | print('Early stopping: loss <= {}'.format(opt.lossThreshold))
175 | done = True
176 | break
177 |
178 | if batch_id + 1 >= len(train_loader):
179 | loss_epoch /= len(train_loader)
180 | alpha_epoch /= len(train_loader)
181 | model.eval()
182 | valid_loss = 0
183 | for data, target in valid_loader:
184 | data, target = Variable(data, volatile=True), Variable(target)
185 | if opt.cuda:
186 | data, target = data.cuda(), target.cuda()
187 | output = model(data)
188 | valid_loss += F.cross_entropy(output, target, size_average=False).data[0]
189 | valid_loss /= len(valid_loader.dataset)
190 | if log_func is not None:
191 | log_func(epoch, iteration, time.time() - time_start, loss, loss_epoch, valid_loss, alpha, alpha_epoch, opt.beta)
192 | else:
193 | if log_func is not None:
194 | log_func(epoch, iteration, time.time() - time_start, loss, float('nan'), float('nan'), alpha, float('nan'), opt.beta)
195 |
196 | epoch += 1
197 | if opt.epochs != 0:
198 | if epoch > opt.epochs:
199 | print('Early stopping: epoch > {}'.format(opt.epochs))
200 | done = True
201 | return loss, iteration
202 |
203 |
204 | def main():
205 | try:
206 | parser = argparse.ArgumentParser(description='Hypergradient descent PyTorch tests', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
207 | parser.add_argument('--cuda', help='use CUDA', action='store_true')
208 | parser.add_argument('--device', help='selected CUDA device', default=0, type=int)
209 | parser.add_argument('--seed', help='random seed', default=1, type=int)
210 | parser.add_argument('--dir', help='directory to write the output files', default='results', type=str)
211 | parser.add_argument('--model', help='model (logreg, mlp, vgg)', default='logreg', type=str)
212 | parser.add_argument('--method', help='method (sgd, sgd_hd, sgdn, sgdn_hd, adam, adam_hd)', default='adam', type=str)
213 | parser.add_argument('--alpha_0', help='initial learning rate', default=0.001, type=float)
214 | parser.add_argument('--beta', help='hypergradient learning rate (learning rate of the learning rate)', default=0.000001, type=float)
215 | parser.add_argument('--mu', help='momentum', default=0.9, type=float)
216 | parser.add_argument('--weightDecay', help='regularization', default=0.0001, type=float)
217 | parser.add_argument('--batchSize', help='minibatch size', default=128, type=int)
218 | parser.add_argument('--epochs', help='stop after this many epochs (0: disregard)', default=2, type=int)
219 | parser.add_argument('--iterations', help='stop after this many iterations (0: disregard)', default=0, type=int)
220 | parser.add_argument('--lossThreshold', help='stop after reaching this loss (0: disregard)', default=0, type=float)
221 | parser.add_argument('--silent', help='do not print output', action='store_true')
222 | parser.add_argument('--workers', help='number of data loading workers', default=4, type=int)
223 | parser.add_argument('--parallel', help='parallelize', action='store_true')
224 | parser.add_argument('--save', help='save the output to a csv file in --dir', action='store_true')
225 | opt = parser.parse_args()
226 |
227 | torch.manual_seed(opt.seed)
228 | if opt.cuda:
229 | torch.cuda.set_device(opt.device)
230 | torch.cuda.manual_seed(opt.seed)
231 | torch.backends.cudnn.enabled = True
232 |
233 | file_name = '{}/{}/{:+.0e}_{:+.0e}/{}.csv'.format(opt.dir, opt.model, opt.alpha_0, opt.beta, opt.method)
234 | os.makedirs(os.path.dirname(file_name), exist_ok=True)
235 | if not opt.silent:
236 | print('Output file: {}'.format(file_name))
237 | # if os.path.isfile(file_name):
238 | # print('File with previous results exists, skipping...')
239 | # else:
240 | if not opt.save:
241 | def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
242 | if not opt.silent:
243 | print('{} | {} | Epoch: {} | Iter: {} | Time: {:+.3e} | Loss: {:+.3e} | Valid. loss: {:+.3e} | Alpha: {:+.3e} | Beta: {:+.3e}'.format(opt.model, opt.method, epoch, iteration, time_spent, loss, valid_loss, alpha, beta))
244 | train(opt, log_func)
245 | else:
246 | with open(file_name, 'w') as f:
247 | writer = csv.writer(f)
248 | writer.writerow(['Epoch', 'Iteration', 'Time', 'Loss', 'LossEpoch', 'ValidLossEpoch', 'Alpha', 'AlphaEpoch', 'Beta'])
249 | def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
250 | writer.writerow([epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta])
251 | if not opt.silent:
252 | print('{} | {} | Epoch: {} | Iter: {} | Time: {:+.3e} | Loss: {:+.3e} | Valid. loss: {:+.3e} | Alpha: {:+.3e} | Beta: {:+.3e}'.format(opt.model, opt.method, epoch, iteration, time_spent, loss, valid_loss, alpha, beta))
253 | train(opt, log_func)
254 |
255 | except KeyboardInterrupt:
256 | print('Stopped')
257 | except Exception:
258 | traceback.print_exc(file=sys.stdout)
259 | sys.exit(0)
260 |
261 |
262 | if __name__ == "__main__":
263 | main()
264 |
--------------------------------------------------------------------------------