├── .gitignore
├── LICENSE
├── README.md
├── pyproject.toml
└── src
    └── gradient_descent_the_ultimate_optimizer
        ├── __init__.py
        └── gdtuo.py

/.gitignore:
--------------------------------------------------------------------------------
/data
/__pycache__
/dist
/src/gradient_descent_the_ultimate_optimizer.egg-info
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2022 Kartik Chandra, Audrey Xie.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gradient Descent: The Ultimate Optimizer

![gdtuo_turtles](https://user-images.githubusercontent.com/31300675/193727211-bff82331-998c-4d44-b675-03d1fd222e0e.png)
# Abstract
Working with any gradient-based machine learning algorithm involves the tedious task of tuning the optimizer's hyperparameters, such as the step size. Recent work has shown how the step size can itself be "learned" on-line by gradient descent, by manually deriving expressions for "hypergradients" ahead of time.

We show how to *automatically* compute hypergradients with a simple and elegant modification to backpropagation. This allows us to apply the method to other hyperparameters besides the step size, such as the momentum coefficient. We can even recursively apply the method to its own *hyper*-hyperparameters, and so on *ad infinitum*. As these towers of optimizers grow taller, they become less sensitive to the initial choice of hyperparameters. We present experiments validating this for MLPs, CNNs, and RNNs.

*This repository contains an implementation of the algorithm in our paper.*

# Citation
```
@article{chandra2022gradient,
  title = {Gradient Descent: The Ultimate Optimizer},
  author = {Chandra, Kartik and Xie, Audrey and Ragan-Kelley, Jonathan and Meijer, Erik},
  journal = {NeurIPS},
  year = {2022},
  url = {https://arxiv.org/abs/1909.13371}
}
```

# Install
```
# install pytorch for your specific machine
...

# install our package
pip install gradient-descent-the-ultimate-optimizer
```
# Example
First, build the MLP and initialize data loaders as you would normally in PyTorch.
```python
import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class MNIST_FullyConnected(nn.Module):
    """
    A fully-connected NN for the MNIST task. This is Optimizable but not itself
    an optimizer.
    """
    def __init__(self, num_inp, num_hid, num_out):
        super(MNIST_FullyConnected, self).__init__()
        self.layer1 = nn.Linear(num_inp, num_hid)
        self.layer2 = nn.Linear(num_hid, num_out)

    def initialize(self):
        nn.init.kaiming_uniform_(self.layer1.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.layer2.weight, a=math.sqrt(5))

    def forward(self, x):
        """Compute a prediction."""
        x = self.layer1(x)
        x = torch.tanh(x)
        x = self.layer2(x)
        x = torch.tanh(x)
        x = F.log_softmax(x, dim=1)
        return x

BATCH_SIZE = 256
EPOCHS = 5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=10000, shuffle=False)

model = MNIST_FullyConnected(28 * 28, 128, 10).to(DEVICE)
```
Next, import our package and initialize a stack of hyperoptimizers. This example uses the stack `Adam/SGD`, i.e. Adam optimizes the model's weights, and SGD optimizes Adam's hyperparameters.
```python
from gradient_descent_the_ultimate_optimizer import gdtuo

optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))
```
`gdtuo.ModuleWrapper` allows any `nn.Module` to be optimized by hyperoptimizers.
```python
mw = gdtuo.ModuleWrapper(model, optimizer=optim)
mw.initialize()
```
Lastly, use `mw` instead of a PyTorch optimizer to optimize the model. The train loop is nearly identical to what you would typically implement in PyTorch (differences are marked by comments).
```python
for i in range(1, EPOCHS+1):
    running_loss = 0.0
    for j, (features_, labels_) in enumerate(dl_train):
        mw.begin() # call this before each step, enables gradient tracking on desired params
        features, labels = torch.reshape(features_, (-1, 28 * 28)).to(DEVICE), labels_.to(DEVICE)
        pred = mw.forward(features)
        loss = F.nll_loss(pred, labels)
        mw.zero_grad()
        loss.backward(create_graph=True) # important! use create_graph=True
        mw.step()
        running_loss += loss.item() * features_.size(0)
    train_loss = running_loss / len(dl_train.dataset)
    print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))
```
Note that on the first step of the training loop, PyTorch will emit the following warning:
```
UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak.
```
This is normal and to be expected.
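
Since the hyperoptimizer's hyperparameters are ordinary tensors stored in `optim.parameters`, you can read off how they were adapted once training finishes, and you can build a taller tower simply by nesting another hyperoptimizer. The snippet below is a sketch of both ideas, reusing `model` and `optim` from the example above (`optim_tall` and `mw_tall` are just illustrative names); note that Adam stores `beta1`/`beta2` in an unconstrained form, so map them back through `gdtuo.Adam.clamp` before reading them.
```python
# inspect the hyperparameters that Adam adapted while training the model
print("adapted step size:", optim.parameters['alpha'].item())
print("adapted beta1:", gdtuo.Adam.clamp(optim.parameters['beta1']).item())
print("adapted beta2:", gdtuo.Adam.clamp(optim.parameters['beta2']).item())

# a taller stack, Adam/Adam/SGD: the first Adam updates the weights, the second
# Adam updates the first Adam's hyperparameters, and SGD updates the second's
optim_tall = gdtuo.Adam(optimizer=gdtuo.Adam(optimizer=gdtuo.SGD(1e-5)))
mw_tall = gdtuo.ModuleWrapper(model, optimizer=optim_tall)
mw_tall.initialize()
```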
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "gradient_descent_the_ultimate_optimizer"
version = "1.0"
authors = [
    { name="Kartik Chandra" },
    { name="Audrey Xie" }
]
description = "Code for paper, Gradient Descent: The Ultimate Optimizer"
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

[project.urls]
"Homepage" = "https://github.com/kach/gradient-descent-the-ultimate-optimizer"
--------------------------------------------------------------------------------
/src/gradient_descent_the_ultimate_optimizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kach/gradient-descent-the-ultimate-optimizer/b3b047e02ca6d32e0e61e34a0ca6e0bc57e55bdf/src/gradient_descent_the_ultimate_optimizer/__init__.py
--------------------------------------------------------------------------------
/src/gradient_descent_the_ultimate_optimizer/gdtuo.py:
--------------------------------------------------------------------------------
import torch

class Optimizable:
    '''
    This is the interface for anything that has parameters that need to be
    optimized, somewhat like torch.nn.Module but with the right plumbing for
    hyperoptimizability. (Specifically, torch.nn.Module uses the Parameter
    interface, which does not give us enough control over detaching.)
    Nominal operation of an Optimizable at the lowest level is as follows:
        o = MyOptimizable(...)
        o.initialize()
        loop {
            o.begin()
            o.zero_grad()
            loss = --compute loss function from parameters--
            loss.backward()
            o.step()
        }
    Optimizables recursively handle updates to their optimiz*ers*.
    '''
    def __init__(self, parameters, optimizer):
        self.parameters = parameters # a dict mapping names to tensors
        self.optimizer = optimizer # which must itself be Optimizable!
        self.all_params_with_gradients = []

    def initialize(self):
        ''' Initialize parameters, e.g. with a Kaiming initializer. '''
        pass

    def begin(self):
        ''' Enable gradient tracking on current parameters. '''
        for param in self.all_params_with_gradients:
            param.grad = None
        self.all_params_with_gradients.clear()
        for name, param in self.parameters.items():
            param.requires_grad_() # keep gradient information...
            param.retain_grad() # even if not a leaf...
            self.all_params_with_gradients.append(param)
        self.optimizer.begin()

    def zero_grad(self):
        ''' Set all gradients to zero. '''
        for param in self.all_params_with_gradients:
            param.grad = torch.zeros_like(param)
        self.optimizer.zero_grad()

    ''' Note: at this point you would probably call .backward() on the loss
    function. '''

    def step(self):
        ''' Update parameters '''
        pass
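
# How the hyperoptimizer recursion plays out in the classes below: each
# optimizer keeps its own hyperparameters (step size, momentum, decay rates,
# etc.) as tensors in self.parameters, and holds a parent `optimizer` that is
# responsible for optimizing those tensors. A call to step(params) first
# recurses upward via self.optimizer.step(self.parameters), so the parent can
# apply the hypergradients accumulated on this optimizer's hyperparameters,
# and only then updates `params` using the freshly adjusted hyperparameters.
# NoOpOptimizer terminates the recursion at the top of the tower.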

class NoOpOptimizer(Optimizable):
    '''
    NoOpOptimizer sits on top of a stack, and does not affect what lies below.
    '''
    def __init__(self):
        pass

    def initialize(self):
        pass

    def begin(self):
        pass

    def zero_grad(self):
        pass

    def step(self, params):
        pass

    def __str__(self):
        return ''

class SGD(Optimizable):
    '''
    A hyperoptimizable SGD.
    '''
    def __init__(self, alpha=0.01, mu=0.0, optimizer=NoOpOptimizer()):
        self.mu = mu
        self.state = {}
        parameters = {
            'alpha': torch.tensor(alpha),
            'mu': torch.tensor(mu)
        }
        super().__init__(parameters, optimizer)

    def step(self, params):
        self.optimizer.step(self.parameters)
        for name, param in params.items():
            g = param.grad.detach()
            p = param.detach()
            if self.mu != 0.0:
                if name not in self.state:
                    buf = self.state[name] = g
                else:
                    buf = self.state[name].detach()
                    buf = buf * self.parameters['mu'] + g
                g = self.state[name] = buf
            params[name] = p - g * self.parameters['alpha']

    def __str__(self):
        return 'sgd / '+ str(self.optimizer)

class SGDPerParam(Optimizable):
    '''
    Optimizes parameters individually with SGD.
    '''
    def __init__(self, params, optimizer=NoOpOptimizer()):
        parameters = {k + '_alpha' : torch.tensor(v) for k, v in params}
        super().__init__(parameters, optimizer)

    def step(self, params):
        self.optimizer.step(self.parameters)
        for name, param in params.items():
            g = param.grad.detach()
            p = param.detach()
            if name + '_alpha' not in self.parameters: params[name] = p
            else: params[name] = p - g * self.parameters[name + '_alpha']

    def __str__(self):
        return 'sgdPerParam / ' + str(self.optimizer)

class AdaGrad(Optimizable):
    '''
    A hyperoptimizable AdaGrad.
    '''
    def __init__(self, alpha=0.01, optimizer=NoOpOptimizer()):
        self.eps = 1e-10
        self.cache = {}
        parameters = {
            'alpha': torch.tensor(alpha)
        }
        super().__init__(parameters, optimizer)

    def step(self, params):
        self.optimizer.step(self.parameters)
        for name, param in params.items():
            if name not in self.cache:
                self.cache[name] = {
                    'G': torch.zeros_like(param) + 1e-1
                }
            g = param.grad.detach()
            self.cache[name]['G'] = G = self.cache[name]['G'].detach() + torch.square(g)
            params[name] = param.detach() - self.parameters['alpha'] * g / torch.sqrt(G + self.eps).detach()

    def __str__(self):
        return 'adagrad / ' + str(self.optimizer)

class RMSProp(Optimizable):
    '''
    A hyperoptimizable RMSProp.
    '''
    def clamp(x):
        return (x.tanh() + 1.) / 2.

    def unclamp(y):
        z = y * 2. - 1.
        return ((1. + z) / (1. - z)).log() / 2.
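
    # Note: clamp maps an unconstrained real number into (0, 1) via a scaled
    # tanh, and unclamp is its inverse (artanh of 2y - 1). gamma is stored in
    # unclamped form and alpha is stored as its square root, so that whatever
    # updates the hyperoptimizer makes keep the effective gamma in (0, 1) and
    # the effective alpha nonnegative; step() maps them back before use.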

    def __init__(self, alpha=0.01, gamma=0.99, optimizer=NoOpOptimizer()):
        self.eps = 1e-8
        parameters = {
            'alpha': torch.sqrt(torch.tensor(alpha)),
            'gamma': RMSProp.unclamp(torch.tensor(gamma))
        }
        super().__init__(parameters, optimizer)
        self.cache = {}

    def step(self, params):
        self.optimizer.step(self.parameters)
        gamma = RMSProp.clamp(self.parameters['gamma'])
        alpha = torch.square(self.parameters['alpha'])
        for name, param in params.items():
            if name not in self.cache:
                self.cache[name] = {
                    's': torch.zeros_like(param)
                }
            g = param.grad.detach()
            self.cache[name]['s'] = s = gamma * self.cache[name]['s'].detach() + (1. - gamma) * torch.square(g)
            self.all_params_with_gradients.append(s)
            params[name] = param.detach() - alpha * g / torch.sqrt(s + self.eps)

    def __str__(self):
        return 'rmsprop / ' + str(self.optimizer)

class RMSPropAlpha(Optimizable):
    '''
    A hyperoptimizable RMSProp for only alpha.
    '''
    def __init__(self, alpha=0.01, gamma=0.99, optimizer=NoOpOptimizer()):
        self.eps = 1e-8
        self.gamma = gamma
        parameters = {
            'alpha': torch.sqrt(torch.tensor(alpha)),
        }
        super().__init__(parameters, optimizer)
        self.cache = {}

    def step(self, params):
        self.optimizer.step(self.parameters)
        alpha = torch.square(self.parameters['alpha'])
        for name, param in params.items():
            if name not in self.cache:
                self.cache[name] = {
                    's': torch.zeros_like(param)
                }
            g = param.grad.detach()
            self.cache[name]['s'] = s = self.gamma * self.cache[name]['s'].detach() + (1. - self.gamma) * torch.square(g)
            self.all_params_with_gradients.append(s)
            params[name] = param.detach() - alpha * g / torch.sqrt(s + self.eps)

    def __str__(self):
        return 'rmspropAlpha / ' + str(self.optimizer)

class Adam(Optimizable):
    '''
    A hyperoptimizable Adam optimizer.
    '''
    def clamp(x):
        return (x.tanh() + 1.) / 2.

    def unclamp(y):
        z = y * 2. - 1.
        return ((1. + z) / (1. - z)).log() / 2.

    def __init__(self, alpha=0.001, beta1=0.9, beta2=0.999, log_eps=-8., optimizer=NoOpOptimizer()):
        self.eps = 10. ** log_eps
        parameters = {
            'alpha': torch.tensor(alpha),
            'beta1': Adam.unclamp(torch.tensor(beta1)),
            'beta2': Adam.unclamp(torch.tensor(beta2)),
        }
        super().__init__(parameters, optimizer)
        self.num_stepments = 0
        self.cache = {}

    def step(self, params):
        self.num_stepments += 1
        self.optimizer.step(self.parameters)
        t = self.num_stepments
        beta1 = Adam.clamp(self.parameters['beta1'])
        beta2 = Adam.clamp(self.parameters['beta2'])
        for name, param in params.items():
            if name not in self.cache:
                self.cache[name] = {
                    'm': torch.zeros_like(param),
                    'v': torch.zeros_like(param) + self.eps
                    # NOTE that we add a little `fudge factor' here because sqrt is not
                    # differentiable at exactly zero
                }
            g = param.grad.detach()
            self.cache[name]['m'] = m = beta1 * self.cache[name]['m'].detach() + (1. - beta1) * g
            self.cache[name]['v'] = v = beta2 * self.cache[name]['v'].detach() + (1. - beta2) * g * g
            self.all_params_with_gradients.append(m)
            self.all_params_with_gradients.append(v)

            m_hat = m / (1. - beta1 ** float(t))
            v_hat = v / (1. - beta2 ** float(t))

            dparam = m_hat / (v_hat ** 0.5 + self.eps)
            params[name] = param.detach() - self.parameters['alpha'] * dparam

    def __str__(self):
        return 'adam / ' + str(self.optimizer)

class AdamBaydin(Optimizable):
    ''' Same as above, but only optimizes the learning rate, treating the
    remaining hyperparameters as constants. '''

    def __init__(
        self,
        alpha=0.001, beta1=0.9, beta2=0.999, log_eps=-8.,
        optimizer=NoOpOptimizer()
    ):
        parameters = {
            'alpha': torch.tensor(alpha),
        }
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.log_eps = log_eps
        super().__init__(parameters, optimizer)
        self.num_stepments = 0
        self.cache = {}

    def step(self, params):
        self.num_stepments += 1
        self.optimizer.step(self.parameters)
        t = self.num_stepments
        beta1 = self.beta1
        beta2 = self.beta2
        for name, param in params.items():
            if name not in self.cache:
                self.cache[name] = {
                    'm': torch.zeros_like(param),
                    'v': torch.zeros_like(param) + 10.**self.log_eps
                    # NOTE that we add a little `fudge factor' here because sqrt is not
                    # differentiable at exactly zero
                }

            g = param.grad.detach()
            self.cache[name]['m'] = m = beta1 * self.cache[name]['m'].detach() + (1. - beta1) * g
            self.cache[name]['v'] = v = beta2 * self.cache[name]['v'].detach() + (1. - beta2) * g * g

            self.all_params_with_gradients.append(m)
            self.all_params_with_gradients.append(v)

            m_hat = m / (1. - beta1 ** float(t))
            v_hat = v / (1. - beta2 ** float(t))

            dparam = m_hat / (v_hat ** 0.5 + 10. ** self.log_eps)
            params[name] = param.detach() - self.parameters['alpha'] * dparam

    def __str__(self):
        return 'adamBaydin / ' + str(self.optimizer)


class ModuleWrapper(Optimizable):
    '''
    This class tries to convert a torch.nn.Module to an Optimizable, handling
    the internal plumbing needed to update parameters correctly.
    '''
    def __init__(self, module, optimizer=NoOpOptimizer()):
        self.module = module
        parameters = {k:v for k, v in module.named_parameters(recurse=True)}
        super().__init__(parameters, optimizer)

    def initialize(self):
        self.optimizer.initialize()

    def zero_grad(self):
        """ Set all gradients to zero. """
        self.module.zero_grad()
        for param in self.all_params_with_gradients:
            param.grad = torch.zeros_like(param)
        self.optimizer.zero_grad()

    def forward(self, *xyz):
        return self.module(*xyz)

    def train(self):
        self.module.train()

    def eval(self):
        self.module.eval()

    def step(self):
        self.optimizer.step(self.parameters)
        def set_param(m, k, v):
            # walk the dotted parameter name down to the submodule that owns
            # it, then swap in the freshly updated (possibly non-leaf) tensor
            kk = k
            while '.' in k:
                sm = k[:k.index('.')]
                k = k[k.index('.') + 1:]
                m = m._modules[sm]

            m._parameters[k] = None
            m._parameters[k] = self.parameters[kk]

        for k, v in self.module.named_parameters(recurse=True):
            set_param(self.module, k, v)
--------------------------------------------------------------------------------