├── hypergrad
│   ├── __init__.py
│   ├── adam_hd.py
│   └── sgd_hd.py
├── plots
│   ├── mlp.pdf
│   ├── mlp.png
│   ├── vgg.pdf
│   ├── vgg.png
│   ├── logreg.pdf
│   └── logreg.png
├── poster
│   ├── iclr_2018_poster.pdf
│   └── iclr_2018_poster.png
├── setup.py
├── LICENSE
├── .gitignore
├── run.sh
├── vgg.py
├── plot.py
├── README.md
└── train.py
/hypergrad/__init__.py:
--------------------------------------------------------------------------------
1 | from .adam_hd import AdamHD
2 | from .sgd_hd import SGDHD
3 | 
--------------------------------------------------------------------------------
/plots/mlp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/mlp.pdf
--------------------------------------------------------------------------------
/plots/mlp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/mlp.png
--------------------------------------------------------------------------------
/plots/vgg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/vgg.pdf
--------------------------------------------------------------------------------
/plots/vgg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/vgg.png
--------------------------------------------------------------------------------
/plots/logreg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/logreg.pdf
--------------------------------------------------------------------------------
/plots/logreg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/plots/logreg.png
--------------------------------------------------------------------------------
/poster/iclr_2018_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/poster/iclr_2018_poster.pdf
--------------------------------------------------------------------------------
/poster/iclr_2018_poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbaydin/hypergradient-descent/HEAD/poster/iclr_2018_poster.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("README.md", "r") as f:
4 |     long_description = f.read()
5 | 
6 | setuptools.setup(
7 |     name="hypergrad",
8 |     version="0.1",
9 |     author="Atılım Güneş Baydin",
10 |     author_email="",
11 |     description="Hypergradient descent",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/gbaydin/hypergradient-descent",
15 |     packages=setuptools.find_packages(),
16 |     classifiers=[
17 |         "Programming Language :: Python :: 3",
18 |         "License :: OSI Approved :: MIT License",
19 |         "Operating System :: OS Independent",
20 |     ],
21 | )
22 | 
--------------------------------------------------------------------------------
/LICENSE:
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Atılım Güneş Baydin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Training data and results 2 | data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python train.py --cuda --model logreg --method sgd --save --epochs 10 --alpha_0 0.001 --beta 0.001 4 | python train.py --cuda --model logreg --method sgd_hd --save --epochs 10 --alpha_0 0.001 --beta 0.001 5 | python train.py --cuda --model logreg --method sgdn --save --epochs 10 --alpha_0 0.001 --beta 0.001 6 | python train.py --cuda --model logreg --method sgdn_hd --save --epochs 10 --alpha_0 0.001 --beta 0.001 7 | python train.py --cuda --model logreg --method adam --save --epochs 10 --alpha_0 0.001 --beta 1e-7 8 | python train.py --cuda --model logreg --method adam_hd --save --epochs 10 --alpha_0 0.001 --beta 1e-7 9 | 10 | python train.py --cuda --model mlp --method sgd --save --epochs 100 --alpha_0 0.001 --beta 0.001 11 | python train.py --cuda --model mlp --method sgd_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001 12 | python train.py --cuda --model mlp --method sgdn --save --epochs 100 --alpha_0 0.001 --beta 0.001 13 | python train.py --cuda --model mlp --method sgdn_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001 14 | python train.py --cuda --model mlp --method adam --save --epochs 100 --alpha_0 0.001 --beta 1e-7 15 | python train.py --cuda --model mlp --method adam_hd --save --epochs 100 --alpha_0 0.001 --beta 1e-7 16 | 17 | python train.py --cuda --model vgg --method sgd --save --epochs 100 --alpha_0 0.001 --beta 0.001 18 | python train.py --cuda --model vgg --method sgd_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001 19 | python train.py --cuda --model vgg --method sgdn --save --epochs 100 --alpha_0 0.001 --beta 0.001 20 | python train.py --cuda --model vgg --method sgdn_hd --save --epochs 100 --alpha_0 0.001 --beta 0.001 21 | python train.py --cuda --model vgg --method adam --save --epochs 100 --alpha_0 0.001 --beta 1e-8 22 | python train.py --cuda --model vgg --method adam_hd --save --epochs 100 --alpha_0 0.001 --beta 1e-8 23 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 
'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model 18 | ''' 19 | def __init__(self, features): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, 10), 30 | ) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. / n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 67 | 512, 512, 512, 512, 'M'], 68 | } 69 | 70 | 71 | def vgg11(): 72 | """VGG 11-layer model (configuration "A")""" 73 | return VGG(make_layers(cfg['A'])) 74 | 75 | 76 | def vgg11_bn(): 77 | """VGG 11-layer model (configuration "A") with batch normalization""" 78 | return VGG(make_layers(cfg['A'], batch_norm=True)) 79 | 80 | 81 | def vgg13(): 82 | """VGG 13-layer model (configuration "B")""" 83 | return VGG(make_layers(cfg['B'])) 84 | 85 | 86 | def vgg13_bn(): 87 | """VGG 13-layer model (configuration "B") with batch normalization""" 88 | return VGG(make_layers(cfg['B'], batch_norm=True)) 89 | 90 | 91 | def vgg16(): 92 | """VGG 16-layer model (configuration "D")""" 93 | return VGG(make_layers(cfg['D'])) 94 | 95 | 96 | def vgg16_bn(): 97 | """VGG 16-layer model (configuration "D") with batch normalization""" 98 | return VGG(make_layers(cfg['D'], batch_norm=True)) 99 | 100 | 101 | def vgg19(): 102 | """VGG 19-layer model (configuration "E")""" 103 | return VGG(make_layers(cfg['E'])) 104 | 105 | 106 | def vgg19_bn(): 107 | """VGG 19-layer model (configuration 'E') with batch normalization""" 108 | return VGG(make_layers(cfg['E'], batch_norm=True)) 109 | -------------------------------------------------------------------------------- /hypergrad/adam_hd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class AdamHD(Optimizer): 7 | """Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
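The variant implemented here additionally adapts the learning rate online using hypergradient descent, controlled by the hypergrad_lr argument described below.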
10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | hypergrad_lr (float, optional): hypergradient learning rate for the online 21 | tuning of the learning rate, introduced in the paper 22 | `Online Learning Rate Adaptation with Hypergradient Descent`_ 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _Online Learning Rate Adaptation with Hypergradient Descent: 27 | https://openreview.net/forum?id=BkrsAzWAb 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, hypergrad_lr=1e-8): 32 | defaults = dict(lr=lr, betas=betas, eps=eps, 33 | weight_decay=weight_decay, hypergrad_lr=hypergrad_lr) 34 | super(AdamHD, self).__init__(params, defaults) 35 | 36 | def step(self, closure=None): 37 | """Performs a single optimization step. 38 | 39 | Arguments: 40 | closure (callable, optional): A closure that reevaluates the model 41 | and returns the loss. 42 | """ 43 | loss = None 44 | if closure is not None: 45 | loss = closure() 46 | 47 | for group in self.param_groups: 48 | for p in group['params']: 49 | if p.grad is None: 50 | continue 51 | grad = p.grad.data 52 | if grad.is_sparse: 53 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 54 | 55 | state = self.state[p] 56 | 57 | # State initialization 58 | if len(state) == 0: 59 | state['step'] = 0 60 | # Exponential moving average of gradient values 61 | state['exp_avg'] = torch.zeros_like(p.data) 62 | # Exponential moving average of squared gradient values 63 | state['exp_avg_sq'] = torch.zeros_like(p.data) 64 | 65 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 66 | beta1, beta2 = group['betas'] 67 | 68 | state['step'] += 1 69 | 70 | if group['weight_decay'] != 0: 71 | grad = grad.add(group['weight_decay'], p.data) 72 | 73 | if state['step'] > 1: 74 | prev_bias_correction1 = 1 - beta1 ** (state['step'] - 1) 75 | prev_bias_correction2 = 1 - beta2 ** (state['step'] - 1) 76 | # Hypergradient for Adam: 77 | h = torch.dot(grad.view(-1), torch.div(exp_avg, exp_avg_sq.sqrt().add_(group['eps'])).view(-1)) * math.sqrt(prev_bias_correction2) / prev_bias_correction1 78 | # Hypergradient descent of the learning rate: 79 | group['lr'] += group['hypergrad_lr'] * h 80 | 81 | # Decay the first and second moment running average coefficient 82 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 83 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 84 | denom = exp_avg_sq.sqrt().add_(group['eps']) 85 | 86 | bias_correction1 = 1 - beta1 ** state['step'] 87 | bias_correction2 = 1 - beta2 ** state['step'] 88 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 89 | 90 | p.data.addcdiv_(-step_size, exp_avg, denom) 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | import csv 5 | import os 6 | import glob 7 | import 
matplotlib 8 | # Force matplotlib to not use any Xwindows backend. 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | from matplotlib.ticker import MaxNLocator 12 | from mpl_toolkits.axes_grid.inset_locator import inset_axes 13 | 14 | colorblindbright = [(252,145,100),(188,56,119),(114,27,166)] 15 | colorblinddim = [(213,167,103),(163,85,114),(104,59,130)] 16 | for i in range(len(colorblindbright)): 17 | r, g, b = colorblindbright[i] 18 | colorblindbright[i] = (r / 255., g / 255., b / 255.) 19 | for i in range(len(colorblinddim)): 20 | r, g, b = colorblinddim[i] 21 | colorblinddim[i] = (r / 255., g / 255., b / 255.) 22 | 23 | colors = {'sgd':colorblinddim[0], 'sgdn':colorblinddim[1],'adam':colorblinddim[2], \ 24 | 'sgd_hd':colorblindbright[0], 'sgdn_hd':colorblindbright[1],'adam_hd':colorblindbright[2]} 25 | names = {'sgd':'SGD','sgdn':'SGDN','adam':'Adam','sgd_hd':'SGD-HD','sgdn_hd':'SGDN-HD','adam_hd':'Adam-HD'} 26 | linestyles = {'sgd':'--','sgdn':'--','adam':'--','sgd_hd':'-','sgdn_hd':'-','adam_hd':'-'} 27 | linedashes = {'sgd':[3,3],'sgdn':[3,3],'adam':[3,3],'sgd_hd':[10,1e-9],'sgdn_hd':[10,1e-9],'adam_hd':[10,1e-9]} 28 | 29 | parser = argparse.ArgumentParser(description='Plotting for hypergradient descent PyTorch tests', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | parser.add_argument('--dir', help='directory to read the csv files written by train.py', default='results', type=str) 31 | parser.add_argument('--plotDir', help='directory to save the plots', default='plots', type=str) 32 | opt = parser.parse_args() 33 | 34 | os.makedirs(opt.plotDir, exist_ok=True) 35 | 36 | model_titles = {'logreg': 'Logistic regression (on MNIST)', 'mlp': 'Multi-layer neural network (on MNIST)', 'vgg': 'VGG Net (on CIFAR-10)'} 37 | for model in next(os.walk(opt.dir))[1]: 38 | data = {} 39 | data_epoch = {} 40 | selected = [] 41 | for fname in glob.glob('{}/{}/**/*.csv'.format(opt.dir, model), recursive=True): 42 | name = os.path.splitext(os.path.basename(fname))[0] 43 | data[name] = pd.read_csv(fname) 44 | data_epoch[name] = data[name][pd.notna(data[name].LossEpoch)] 45 | selected.append(name) 46 | 47 | plt.figure(figsize=(5,12)) 48 | 49 | fig = plt.figure(figsize=(5, 12)) 50 | ax = fig.add_subplot(311) 51 | for name in selected: 52 | plt.plot(data_epoch[name].Epoch,data_epoch[name].AlphaEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name]) 53 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 54 | plt.ylabel('Learning rate') 55 | plt.tick_params(labeltop=False, labelbottom=False, bottom=False, top=False, labelright=False) 56 | plt.grid() 57 | plt.title(model_titles[model]) 58 | inset_axes(ax, width="50%", height="35%", loc=1) 59 | for name in selected: 60 | plt.plot(data[name].Iteration, data[name].Alpha,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name]) 61 | plt.yticks(np.arange(-0.01, 0.051, 0.01)) 62 | plt.xlabel('Iteration') 63 | plt.ylabel('Learning rate') 64 | plt.xscale('log') 65 | plt.xlim([0,9000]) 66 | plt.grid() 67 | 68 | ax = fig.add_subplot(312) 69 | for name in selected: 70 | plt.plot(data_epoch[name].Epoch, data_epoch[name].LossEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name]) 71 | ax.xaxis.set_major_locator(MaxNLocator(integer=True)) 72 | plt.ylabel('Training loss') 73 | plt.yscale('log') 74 | plt.tick_params(labeltop=False, labelbottom=False, bottom=False, top=False, labelright=False) 75 | plt.grid() 76 | inset_axes(ax, 
width="50%", height="35%", loc=1)
77 |     for name in selected:
78 |         plt.plot(data[name].Iteration, data[name].Loss,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
79 |     plt.yticks(np.arange(0, 2.01, 0.5))
80 |     plt.xlabel('Iteration')
81 |     plt.ylabel('Training loss')
82 |     plt.xscale('log')
83 |     plt.xlim([0,9000])
84 |     plt.grid()
85 | 
86 |     ax = fig.add_subplot(313)
87 |     for name in selected:
88 |         plt.plot(data_epoch[name].Epoch, data_epoch[name].ValidLossEpoch,label=names[name],color=colors[name],linestyle=linestyles[name],dashes=linedashes[name])
89 |     plt.xlabel('Epoch')
90 |     ax.xaxis.set_major_locator(MaxNLocator(integer=True))
91 |     plt.ylabel('Validation loss')
92 |     plt.yscale('log')
93 |     handles, labels = plt.gca().get_legend_handles_labels()
94 |     labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
95 |     plt.legend(handles,labels,loc='upper right',frameon=1,framealpha=1,edgecolor='black',fancybox=False)
96 |     plt.grid()
97 | 
98 |     plt.tight_layout()
99 |     plt.savefig('{}/{}.pdf'.format(opt.plotDir, model), bbox_inches='tight')
100 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hypergradient-descent
2 | This is the [PyTorch](http://pytorch.org/) code for the paper [_Online Learning Rate Adaptation with Hypergradient Descent_](https://openreview.net/forum?id=BkrsAzWAb) at ICLR 2018.
3 | 
4 | A [TensorFlow](https://www.tensorflow.org/) version is also planned and should appear in this repo at a later time.
5 | 
6 | 
7 | 
8 | ## What is a "hypergradient"?
9 | 
10 | In gradient-based optimization, one optimizes an objective function by using its derivatives (gradient) with respect to model parameters. In addition to this basic gradient, a _hypergradient_ is the derivative of the same objective function with respect to the optimization procedure's hyperparameters (such as the learning rate, momentum, or regularization parameters). There can be many types of hypergradients, and in this work we're interested in the hypergradient with respect to a scalar learning rate.
11 | 
12 | ## Installation
13 | 
14 | ```bash
15 | pip install git+https://github.com/gbaydin/hypergradient-descent.git
16 | ```
17 | 
18 | ## How can I use it for my work?
19 | 
20 | We provide ready-to-use implementations of the hypergradient versions of the SGD (with or without momentum) and Adam optimizers for PyTorch. These comply with the `torch.optim` API and can be used as drop-in replacements in your code. Install the package as shown above, or just take the `sgd_hd.py` and `adam_hd.py` files from this repo, and import the optimizers like this:
21 | 
22 | ```python
23 | from hypergrad import SGDHD, AdamHD
24 | ...
25 | 
26 | optimizer = AdamHD(model.parameters(), lr=args.lr, hypergrad_lr=1e-8)
27 | ...
28 | ```
29 | 
30 | The optimizers introduce an extra argument, `hypergrad_lr`, which sets the hypergradient learning rate, that is, the learning rate used to optimize the regular learning rate `lr`. The value you give to `lr` sets the initial value of the regular learning rate, which is then adapted online by the hypergradient descent procedure. Lower values of `hypergrad_lr` are safer because (1) the resulting updates to `lr` are smaller, and (2) you recover the non-hypergradient version of the algorithm as `hypergrad_lr` approaches zero.
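To build intuition for what these optimizers do on every `step()`, here is a minimal, self-contained sketch of the SGD-HD update rule on a toy quadratic objective. Everything in the snippet (the objective, the tensor size, the `hypergrad_lr` value) is illustrative only; the actual, general implementations are in `hypergrad/sgd_hd.py` and `hypergrad/adam_hd.py`.

```python
import torch

# A toy sketch of the SGD-HD update rule on the objective f(w) = 0.5 * ||w||^2.
# The SGDHD/AdamHD optimizers in this repo do the equivalent bookkeeping for you.
w = torch.randn(10)
lr, hypergrad_lr = 0.0, 1e-3           # even an initial learning rate of 0 works here
grad_prev = torch.zeros_like(w)
for t in range(100):
    grad = w.clone()                   # gradient of 0.5 * ||w||^2 is w itself
    h = torch.dot(grad, grad_prev)     # hypergradient: dot product of current and previous gradients
    lr = lr + hypergrad_lr * h         # hypergradient descent step on the learning rate
    w = w - lr * grad                  # regular SGD step, using the freshly updated learning rate
    grad_prev = grad                   # remembered for the next iteration
```

Note how the learning rate can start at zero and still grow on its own while successive gradients keep pointing in similar directions; this is the behaviour discussed in the next section.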
31 | 
32 | Don't be worried that, instead of having to tune just one learning rate (`lr`), you now have to tune two (`lr` and `hypergrad_lr`); just see the next section.
33 | 
34 | ## What is the advantage?
35 | Hypergradient algorithms are much less sensitive to the choice of the initial learning rate (`lr`) than their non-hypergradient counterparts. Given a small `hypergrad_lr`, which can either be left at the recommended default or tuned, a hypergradient algorithm requires significantly less tuning to perform better than, or in the worst case the same as, its non-hypergradient baseline. Please see the paper for guideline values of `hypergrad_lr`.
36 | 
37 | In practice, you might be surprised to see that **even starting with a zero learning rate works**: the learning rate is quickly raised to a useful non-zero level as needed, and then decays towards zero as optimization converges:
38 | ```python
39 | optimizer = AdamHD(model.parameters(), lr=0, hypergrad_lr=1e-8)
40 | ```
41 | 
42 | If you would like to monitor the evolution of the learning rate during optimization, you can read it from the optimizer's parameter group with code that looks like
43 | 
44 | ```python
45 | lr = optimizer.param_groups[0]['lr']
46 | print(lr)
47 | ```
48 | 
49 | ## Notes about the code in this repository
50 | * The results in the paper were produced by code written in [(Lua)torch](http://torch.ch/). The code in this repo is a reimplementation in PyTorch, which produces results that are not exactly the same but qualitatively identical in the behavior of the learning rate, training and validation losses, and the relative performance of the algorithms.
51 | * In the .csv result files, `nan` is used as a placeholder for empty entries in epoch losses, which are only computed once per epoch.
52 | * The implementation in this repository doesn't include any heuristics to ensure that your gradients and hypergradients don't "explode". In practice, you might need to apply gradient clipping or safeguards for the updates to the learning rate to prevent bad behavior. If this happens in your models, we would be interested in hearing about it: please let us know via email or a GitHub issue.
53 | 
54 | ## Some other implementations
55 | 
56 | * (Lua)Torch optim implementation (not the main implementation for the paper and not well-tested): https://github.com/gbaydin/optim
57 | * C++ implementation in the Livermore Big Artificial Neural Network Toolkit, Lawrence Livermore National Laboratory: https://github.com/LLNL/lbann/blob/a778d2b764ba209042555aac26328cbfb8063802/src/optimizers/hypergradient_adam.cpp
58 | * TensorFlow implementation by Andrii Zadaianchuk, University of Tübingen: https://github.com/zadaianchuk/HyperGradientDescent
59 | * Java implementation in Apache Hivemall: https://github.com/apache/incubator-hivemall/blob/31932fd7c63f9bb21eba8959944d03f280b6deb9/core/src/main/java/hivemall/optimizer/Optimizer.java#L640
60 | 
61 | ## Paper
62 | Atılım Güneş Baydin, Robert Cornish, David Martínez Rubio, Mark Schmidt, and Frank Wood. Online learning rate adaptation with hypergradient descent. In _Sixth International
63 | Conference on Learning Representations (ICLR), Vancouver, Canada, April 30 – May 3, 2018,_ 2018.
64 | 65 | https://openreview.net/forum?id=BkrsAzWAb 66 | 67 | ``` 68 | @inproceedings{baydin2018hypergradient, 69 | title = {Online Learning Rate Adaptation with Hypergradient Descent}, 70 | author = {Baydin, Atılım Güneş and Cornish, Robert and Rubio, David Martínez and Schmidt, Mark and Wood, Frank}, 71 | booktitle = {Sixth International Conference on Learning Representations (ICLR), Vancouver, Canada, April 30 -- May 3, 2018}, 72 | year = {2018} 73 | } 74 | ``` 75 | 76 | ## Poster 77 | 78 | [The ICLR 2018 poster](https://github.com/gbaydin/hypergradient-descent/raw/master/poster/iclr_2018_poster.pdf) 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /hypergrad/sgd_hd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import reduce 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class SGDHD(Optimizer): 7 | r"""Implements stochastic gradient descent (optionally with momentum). 8 | 9 | Nesterov momentum is based on the formula from 10 | `On the importance of initialization and momentum in deep learning`__. 11 | 12 | Args: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float): learning rate 16 | momentum (float, optional): momentum factor (default: 0) 17 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 18 | dampening (float, optional): dampening for momentum (default: 0) 19 | nesterov (bool, optional): enables Nesterov momentum (default: False) 20 | hypergrad_lr (float, optional): hypergradient learning rate for the online 21 | tuning of the learning rate, introduced in the paper 22 | `Online Learning Rate Adaptation with Hypergradient Descent`_ 23 | 24 | Example: 25 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 26 | >>> optimizer.zero_grad() 27 | >>> loss_fn(model(input), target).backward() 28 | >>> optimizer.step() 29 | 30 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 31 | .. _Online Learning Rate Adaptation with Hypergradient Descent: 32 | https://openreview.net/forum?id=BkrsAzWAb 33 | .. note:: 34 | The implementation of SGD with Momentum/Nesterov subtly differs from 35 | Sutskever et. al. and implementations in some other frameworks. 36 | 37 | Considering the specific case of Momentum, the update can be written as 38 | 39 | .. math:: 40 | v = \rho * v + g \\ 41 | p = p - lr * v 42 | 43 | where p, g, v and :math:`\rho` denote the parameters, gradient, 44 | velocity, and momentum respectively. 45 | 46 | This is in contrast to Sutskever et. al. and 47 | other frameworks which employ an update of the form 48 | 49 | .. math:: 50 | v = \rho * v + lr * g \\ 51 | p = p - v 52 | 53 | The Nesterov version is analogously modified. 
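    Example of the hypergradient variant itself (the values of lr and hypergrad_lr
    are illustrative; they match the SGD-HD settings used in run.sh):

        >>> optimizer = SGDHD(model.parameters(), lr=0.001, hypergrad_lr=0.001)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()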
54 | """ 55 | 56 | def __init__(self, params, lr=required, momentum=0, dampening=0, 57 | weight_decay=0, nesterov=False, hypergrad_lr=1e-6): 58 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 59 | weight_decay=weight_decay, nesterov=nesterov, hypergrad_lr=hypergrad_lr) 60 | if nesterov and (momentum <= 0 or dampening != 0): 61 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 62 | super(SGDHD, self).__init__(params, defaults) 63 | 64 | if len(self.param_groups) != 1: 65 | raise ValueError("SGDHD doesn't support per-parameter options (parameter groups)") 66 | 67 | self._params = self.param_groups[0]['params'] 68 | self._params_numel = reduce(lambda total, p: total + p.numel(), self._params, 0) 69 | 70 | def _gather_flat_grad_with_weight_decay(self, weight_decay=0): 71 | views = [] 72 | for p in self._params: 73 | if p.grad is None: 74 | view = torch.zeros_like(p.data) 75 | elif p.grad.data.is_sparse: 76 | view = p.grad.data.to_dense().view(-1) 77 | else: 78 | view = p.grad.data.view(-1) 79 | if weight_decay != 0: 80 | view.add_(weight_decay, p.data.view(-1)) 81 | views.append(view) 82 | return torch.cat(views, 0) 83 | 84 | def _add_grad(self, step_size, update): 85 | offset = 0 86 | for p in self._params: 87 | numel = p.numel() 88 | # view as to avoid deprecated pointwise semantics 89 | p.data.add_(step_size, update[offset:offset + numel].view_as(p.data)) 90 | offset += numel 91 | assert offset == self._params_numel 92 | 93 | def step(self, closure=None): 94 | """Performs a single optimization step. 95 | 96 | Arguments: 97 | closure (callable, optional): A closure that reevaluates the model 98 | and returns the loss. 99 | """ 100 | assert len(self.param_groups) == 1 101 | 102 | loss = None 103 | if closure is not None: 104 | loss = closure() 105 | 106 | group = self.param_groups[0] 107 | weight_decay = group['weight_decay'] 108 | momentum = group['momentum'] 109 | dampening = group['dampening'] 110 | nesterov = group['nesterov'] 111 | 112 | grad = self._gather_flat_grad_with_weight_decay(weight_decay) 113 | 114 | # NOTE: SGDHD has only global state, but we register it as state for 115 | # the first param, because this helps with casting in load_state_dict 116 | state = self.state[self._params[0]] 117 | # State initialization 118 | if len(state) == 0: 119 | state['grad_prev'] = torch.zeros_like(grad) 120 | 121 | grad_prev = state['grad_prev'] 122 | # Hypergradient for SGD 123 | h = torch.dot(grad, grad_prev) 124 | # Hypergradient descent of the learning rate: 125 | group['lr'] += group['hypergrad_lr'] * h 126 | 127 | if momentum != 0: 128 | if 'momentum_buffer' not in state: 129 | buf = state['momentum_buffer'] = torch.zeros_like(grad) 130 | buf.mul_(momentum).add_(grad) 131 | else: 132 | buf = state['momentum_buffer'] 133 | buf.mul_(momentum).add_(1 - dampening, grad) 134 | if nesterov: 135 | grad.add_(momentum, buf) 136 | else: 137 | grad = buf 138 | 139 | state['grad_prev'] = grad 140 | 141 | self._add_grad(-group['lr'], grad) 142 | 143 | return loss 144 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import argparse 3 | import sys 4 | import os 5 | import csv 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torchvision import datasets, transforms 13 | import vgg 14 | from torch.utils.data 
import DataLoader 15 | from torch.optim import SGD, Adam 16 | from hypergrad import SGDHD, AdamHD 17 | 18 | 19 | class LogReg(nn.Module): 20 | def __init__(self, input_dim, output_dim): 21 | super(LogReg, self).__init__() 22 | self._input_dim = input_dim 23 | self.lin1 = nn.Linear(input_dim, output_dim) 24 | 25 | def forward(self, x): 26 | x = x.view(-1, self._input_dim) 27 | x = self.lin1(x) 28 | return x 29 | 30 | 31 | class MLP(nn.Module): 32 | def __init__(self, input_dim, hidden_dim, output_dim): 33 | super(MLP, self).__init__() 34 | self._input_dim = input_dim 35 | self.lin1 = nn.Linear(input_dim, hidden_dim) 36 | self.lin2 = nn.Linear(hidden_dim, hidden_dim) 37 | self.lin3 = nn.Linear(hidden_dim, output_dim) 38 | 39 | def forward(self, x): 40 | x = x.view(-1, self._input_dim) 41 | x = F.relu(self.lin1(x)) 42 | x = F.relu(self.lin2(x)) 43 | x = self.lin3(x) 44 | return x 45 | 46 | 47 | def train(opt, log_func=None): 48 | torch.manual_seed(opt.seed) 49 | if opt.cuda: 50 | torch.cuda.set_device(opt.device) 51 | torch.cuda.manual_seed(opt.seed) 52 | torch.backends.cudnn.enabled = True 53 | 54 | if opt.model == 'logreg': 55 | model = LogReg(28 * 28, 10) 56 | elif opt.model == 'mlp': 57 | model = MLP(28 * 28, 1000, 10) 58 | elif opt.model == 'vgg': 59 | model = vgg.vgg16_bn() 60 | if opt.parallel: 61 | model.features = torch.nn.DataParallel(model.features) 62 | else: 63 | raise Exception('Unknown model: {}'.format(opt.model)) 64 | 65 | if opt.cuda: 66 | model = model.cuda() 67 | 68 | if opt.model == 'logreg' or opt.model == 'mlp': 69 | task = 'MNIST' 70 | train_loader = DataLoader( 71 | datasets.MNIST('./data', train=True, download=True, 72 | transform=transforms.Compose([ 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.1307,), (0.3081,)) 75 | ])), 76 | batch_size=opt.batchSize, shuffle=True) 77 | valid_loader = DataLoader( 78 | datasets.MNIST('./data', train=False, transform=transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.1307,), (0.3081,)) 81 | ])), 82 | batch_size=opt.batchSize, shuffle=False) 83 | elif opt.model == 'vgg': 84 | task = 'CIFAR10' 85 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 86 | std=[0.229, 0.224, 0.225]) 87 | train_loader = torch.utils.data.DataLoader( 88 | datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ 89 | transforms.RandomHorizontalFlip(), 90 | transforms.RandomCrop(32, 4), 91 | transforms.ToTensor(), 92 | normalize, 93 | ]), download=True), 94 | batch_size=opt.batchSize, shuffle=True, 95 | num_workers=opt.workers, pin_memory=True) 96 | 97 | valid_loader = torch.utils.data.DataLoader( 98 | datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([ 99 | transforms.ToTensor(), 100 | normalize, 101 | ])), 102 | batch_size=opt.batchSize, shuffle=False, 103 | num_workers=opt.workers, pin_memory=True) 104 | else: 105 | raise Exception('Unknown model: {}'.format(opt.model)) 106 | 107 | if opt.method == 'sgd': 108 | optimizer = SGD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay) 109 | elif opt.method == 'sgd_hd': 110 | optimizer = SGDHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, hypergrad_lr=opt.beta) 111 | elif opt.method == 'sgdn': 112 | optimizer = SGD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, momentum=opt.mu, nesterov=True) 113 | elif opt.method == 'sgdn_hd': 114 | optimizer = SGDHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, momentum=opt.mu, nesterov=True, hypergrad_lr=opt.beta) 
115 | elif opt.method == 'adam': 116 | optimizer = Adam(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay) 117 | elif opt.method == 'adam_hd': 118 | optimizer = AdamHD(model.parameters(), lr=opt.alpha_0, weight_decay=opt.weightDecay, hypergrad_lr=opt.beta) 119 | else: 120 | raise Exception('Unknown method: {}'.format(opt.method)) 121 | 122 | if not opt.silent: 123 | print('Task: {}, Model: {}, Method: {}'.format(task, opt.model, opt.method)) 124 | 125 | model.eval() 126 | for batch_id, (data, target) in enumerate(train_loader): 127 | data, target = Variable(data), Variable(target) 128 | if opt.cuda: 129 | data, target = data.cuda(), target.cuda() 130 | output = model(data) 131 | loss = F.cross_entropy(output, target) 132 | loss = loss.data[0] 133 | break 134 | valid_loss = 0 135 | for data, target in valid_loader: 136 | data, target = Variable(data, volatile=True), Variable(target) 137 | if opt.cuda: 138 | data, target = data.cuda(), target.cuda() 139 | output = model(data) 140 | valid_loss += F.cross_entropy(output, target, size_average=False).data[0] 141 | valid_loss /= len(valid_loader.dataset) 142 | if log_func is not None: 143 | log_func(0, 0, 0, loss, loss, valid_loss, opt.alpha_0, opt.alpha_0, opt.beta) 144 | 145 | time_start = time.time() 146 | iteration = 1 147 | epoch = 1 148 | done = False 149 | while not done: 150 | model.train() 151 | loss_epoch = 0 152 | alpha_epoch = 0 153 | for batch_id, (data, target) in enumerate(train_loader): 154 | data, target = Variable(data), Variable(target) 155 | if opt.cuda: 156 | data, target = data.cuda(), target.cuda() 157 | optimizer.zero_grad() 158 | output = model(data) 159 | loss = F.cross_entropy(output, target) 160 | loss.backward() 161 | optimizer.step() 162 | loss = loss.data[0] 163 | loss_epoch += loss 164 | alpha = optimizer.param_groups[0]['lr'] 165 | alpha_epoch += alpha 166 | iteration += 1 167 | if opt.iterations != 0: 168 | if iteration > opt.iterations: 169 | print('Early stopping: iteration > {}'.format(opt.iterations)) 170 | done = True 171 | break 172 | if opt.lossThreshold >= 0: 173 | if loss <= opt.lossThreshold: 174 | print('Early stopping: loss <= {}'.format(opt.lossThreshold)) 175 | done = True 176 | break 177 | 178 | if batch_id + 1 >= len(train_loader): 179 | loss_epoch /= len(train_loader) 180 | alpha_epoch /= len(train_loader) 181 | model.eval() 182 | valid_loss = 0 183 | for data, target in valid_loader: 184 | data, target = Variable(data, volatile=True), Variable(target) 185 | if opt.cuda: 186 | data, target = data.cuda(), target.cuda() 187 | output = model(data) 188 | valid_loss += F.cross_entropy(output, target, size_average=False).data[0] 189 | valid_loss /= len(valid_loader.dataset) 190 | if log_func is not None: 191 | log_func(epoch, iteration, time.time() - time_start, loss, loss_epoch, valid_loss, alpha, alpha_epoch, opt.beta) 192 | else: 193 | if log_func is not None: 194 | log_func(epoch, iteration, time.time() - time_start, loss, float('nan'), float('nan'), alpha, float('nan'), opt.beta) 195 | 196 | epoch += 1 197 | if opt.epochs != 0: 198 | if epoch > opt.epochs: 199 | print('Early stopping: epoch > {}'.format(opt.epochs)) 200 | done = True 201 | return loss, iteration 202 | 203 | 204 | def main(): 205 | try: 206 | parser = argparse.ArgumentParser(description='Hypergradient descent PyTorch tests', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 207 | parser.add_argument('--cuda', help='use CUDA', action='store_true') 208 | parser.add_argument('--device', help='selected CUDA 
device', default=0, type=int)
209 |         parser.add_argument('--seed', help='random seed', default=1, type=int)
210 |         parser.add_argument('--dir', help='directory to write the output files', default='results', type=str)
211 |         parser.add_argument('--model', help='model (logreg, mlp, vgg)', default='logreg', type=str)
212 |         parser.add_argument('--method', help='method (sgd, sgd_hd, sgdn, sgdn_hd, adam, adam_hd)', default='adam', type=str)
213 |         parser.add_argument('--alpha_0', help='initial learning rate', default=0.001, type=float)
214 |         parser.add_argument('--beta', help='hypergradient learning rate', default=0.000001, type=float)
215 |         parser.add_argument('--mu', help='momentum', default=0.9, type=float)
216 |         parser.add_argument('--weightDecay', help='regularization', default=0.0001, type=float)
217 |         parser.add_argument('--batchSize', help='minibatch size', default=128, type=int)
218 |         parser.add_argument('--epochs', help='stop after this many epochs (0: disregard)', default=2, type=int)
219 |         parser.add_argument('--iterations', help='stop after this many iterations (0: disregard)', default=0, type=int)
220 |         parser.add_argument('--lossThreshold', help='stop after reaching this loss (0: disregard)', default=0, type=float)
221 |         parser.add_argument('--silent', help='do not print output', action='store_true')
222 |         parser.add_argument('--workers', help='number of data loading workers', default=4, type=int)
223 |         parser.add_argument('--parallel', help='parallelize', action='store_true')
224 |         parser.add_argument('--save', help='save output to a csv file', action='store_true')
225 |         opt = parser.parse_args()
226 | 
227 |         torch.manual_seed(opt.seed)
228 |         if opt.cuda:
229 |             torch.cuda.set_device(opt.device)
230 |             torch.cuda.manual_seed(opt.seed)
231 |             torch.backends.cudnn.enabled = True
232 | 
233 |         file_name = '{}/{}/{:+.0e}_{:+.0e}/{}.csv'.format(opt.dir, opt.model, opt.alpha_0, opt.beta, opt.method)
234 |         os.makedirs(os.path.dirname(file_name), exist_ok=True)
235 |         if not opt.silent:
236 |             print('Output file: {}'.format(file_name))
237 |         # if os.path.isfile(file_name):
238 |         #     print('File with previous results exists, skipping...')
239 |         # else:
240 |         if not opt.save:
241 |             def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
242 |                 if not opt.silent:
243 |                     print('{} | {} | Epoch: {} | Iter: {} | Time: {:+.3e} | Loss: {:+.3e} | Valid. loss: {:+.3e} | Alpha: {:+.3e} | Beta: {:+.3e}'.format(opt.model, opt.method, epoch, iteration, time_spent, loss, valid_loss, alpha, beta))
244 |             train(opt, log_func)
245 |         else:
246 |             with open(file_name, 'w') as f:
247 |                 writer = csv.writer(f)
248 |                 writer.writerow(['Epoch', 'Iteration', 'Time', 'Loss', 'LossEpoch', 'ValidLossEpoch', 'Alpha', 'AlphaEpoch', 'Beta'])
249 |                 def log_func(epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta):
250 |                     writer.writerow([epoch, iteration, time_spent, loss, loss_epoch, valid_loss, alpha, alpha_epoch, beta])
251 |                     if not opt.silent:
252 |                         print('{} | {} | Epoch: {} | Iter: {} | Time: {:+.3e} | Loss: {:+.3e} | Valid.
loss: {:+.3e} | Alpha: {:+.3e} | Beta: {:+.3e}'.format(opt.model, opt.method, epoch, iteration, time_spent, loss, valid_loss, alpha, beta)) 253 | train(opt, log_func) 254 | 255 | except KeyboardInterrupt: 256 | print('Stopped') 257 | except Exception: 258 | traceback.print_exc(file=sys.stdout) 259 | sys.exit(0) 260 | 261 | 262 | if __name__ == "__main__": 263 | main() 264 | --------------------------------------------------------------------------------