├── src
│   └── dog
│       ├── __init__.py
│       ├── averager.py
│       └── dog.py
├── pyproject.toml
├── LICENSE
├── README.md
└── example.py

/src/dog/__init__.py:
--------------------------------------------------------------------------------
from .dog import DoG, LDoG
from .averager import PolynomialDecayAverager
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "dog-optimizer"
version = "1.0.3"
authors = [
  { name="Maor Ivgi" },
  { name="Oliver Hinder" },
  { name="Yair Carmon" }
]
description = "implementation of the algorithms in the paper DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size Schedule"
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

[project.urls]
"Homepage" = "https://github.com/formll/dog"
"Paper" = "https://arxiv.org/abs/2302.12022"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Foundations of Robust Machine Learning Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/dog/averager.py:
--------------------------------------------------------------------------------
"""
PyTorch implementation of polynomial decay averaging (Shamir & Zhang, 2013)
"""
import logging

import torch
from torch import nn

from copy import deepcopy

logger = logging.getLogger(__name__)


class PolynomialDecayAverager:
    r"""
    Averages model weights using a polynomial decay, as described in Shamir & Zhang, 2013.

    Given parameters x_t at iteration t, the averaged parameters are updated as follows:

    .. math::
        \begin{aligned}
        \bar{x}_t = (1 - \frac{1+\gamma}{t+\gamma}) \cdot \bar{x}_{t-1} + \frac{1+\gamma}{t+\gamma} \cdot x_t
        \end{aligned}
    """

    def __init__(self, model: nn.Module, gamma: float = 8.):
        self.t = 1
        self.model = model
        self.av_model = deepcopy(model)
        self.gamma = gamma
        self._matched_devices = False

    def step(self):
        self._match_models_devices()

        t = self.t
        model_sd = self.model.state_dict()
        av_sd = self.av_model.state_dict()

        for k in model_sd.keys():
            if isinstance(av_sd[k], (torch.LongTensor, torch.cuda.LongTensor)):
                # these are buffers that store how many batches batch norm has seen so far
                av_sd[k].copy_(model_sd[k])
                continue
            av_sd[k].mul_(1 - ((self.gamma + 1) / (self.gamma + t))).add_(
                model_sd[k], alpha=(self.gamma + 1) / (self.gamma + t)
            )

        self.t += 1

    def _match_models_devices(self):
        if self._matched_devices:
            return

        # nn.Module does not always have a device attribute, so we check if the model has one and use it to match
        # the device where the av_model is stored
        if hasattr(self.model, 'device'):
            if self.model.device != self.av_model.device:
                self.av_model = self.av_model.to(self.model.device)
        else:
            # This could be a problem if the model is split across multiple devices in a distributed manner
            model_device, av_device = next(self.model.parameters()).device, next(self.av_model.parameters()).device
            if model_device != av_device:
                self.av_model = self.av_model.to(model_device)

        self._matched_devices = True

    def reset(self):
        self.t = 1

    @property
    def averaged_model(self):
        """
        @return: the averaged model (the polynomial decay averaged model)
        """
        return self.av_model

    @property
    def base_model(self):
        """
        @return: the base model (the one that is being trained)
        """
        return self.model

    def state_dict(self):
        """
        @return: the state dict of the averager.
            Note that if you wish to save the averaged model itself, as a loadable weights checkpoint,
            you should use averager.averaged_model.state_dict().
        """
        return {
            't': self.t,
            'av_model': self.av_model.state_dict()
        }

    def load_state_dict(self, state_dict):
        """
        Loads the state dict of the averager.
        @param state_dict: A state dict as returned by averager.state_dict()
        """
        self.t = state_dict['t']
        self.av_model.load_state_dict(state_dict['av_model'])
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DoG Optimizer

This repository contains the implementation of the algorithms in the paper
[DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size Schedule](https://arxiv.org/abs/2302.12022)
by Maor Ivgi, Oliver Hinder and Yair Carmon.

**IMPORTANT:** For best performance (and for a fair comparison to other methods), **DoG/L-DoG must be combined with iterate averaging!** This package includes an easy-to-use [averager class](#iterate-averaging) - its default configuration should work well out of the box.

## Algorithm
DoG ("Distance over Gradients") is a parameter-free stochastic optimizer.
DoG updates parameters $x_t$ with stochastic gradients $g_t$ according to:
```math
\begin{aligned}
\eta_t & = \frac{ \bar{r}_t }{ \sqrt{\sum_{i \le t}{\lVert g_i \rVert^2} + \epsilon} } \\
x_{t+1} & = x_{t} - \eta_t \cdot g_t
\end{aligned}
```
where
```math
\begin{equation*}
\bar{r}_t = \begin{cases}
\max_{i \le t}{\lVert x_i - x_0 \rVert} & t \ge 1 \\
r_{\epsilon} & t = 0.
\end{cases}
\end{equation*}
```
The initial movement parameter $r_{\epsilon}$ should be chosen small relative to the distance between $x_0$ and the nearest optimum $x^\star$ (see additional discussion below).

LDoG (layerwise DoG) is a variant of DoG that applies the above update rule separately to every element in the list of parameters provided to the optimizer object.
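For intuition, here is a minimal, self-contained sketch of the update rule on a single flat parameter vector, written in plain PyTorch and independent of the packaged optimizer (the quadratic `loss_fn` is just a stand-in for a real objective):
```python
import torch

def loss_fn(x):
    # toy quadratic objective; replace with a real loss
    return 0.5 * (x ** 2).sum()

x = torch.randn(10, requires_grad=True)
x0 = x.detach().clone()                # initial point x_0
rbar = 1e-6 * (1 + x0.norm())          # r_epsilon, chosen relative to ||x_0||
G, eps = 0.0, 1e-8

for t in range(100):
    loss_fn(x).backward()
    with torch.no_grad():
        G = G + x.grad.norm() ** 2            # running sum of squared gradient norms
        rbar = max(rbar, (x - x0).norm())     # max distance travelled from x_0 (equals r_epsilon at t=0)
        eta = rbar / torch.sqrt(G + eps)      # the DoG step size
        x -= eta * x.grad                     # SGD-like step with eta_t
        x.grad = None
```
The packaged `DoG`/`LDoG` optimizers implement this logic per parameter group (and per layer, in the case of LDoG); the sketch above is only meant to illustrate the formula.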
## Installation
To install the package, simply run `pip install dog-optimizer`.

## Usage
DoG and LDoG are implemented using the standard PyTorch optimizer interface. After installing the package with `pip install dog-optimizer`,
all you need to do is replace the line that creates your optimizer with
```python
from dog import DoG
optimizer = DoG(optimizer args)
```
for DoG, or
```python
from dog import LDoG
optimizer = LDoG(optimizer args)
```
for LDoG,
where `optimizer args` follows the standard PyTorch optimizer syntax.
To see the list of all available parameters, run `help(DoG)` or `help(LDoG)`.

### Iterate averaging
We provide an implementation of the polynomial decay averaging used throughout our experiments. To use it, simply create a `PolynomialDecayAverager` with
```python
from dog import PolynomialDecayAverager
averager = PolynomialDecayAverager(model)
```
then, after each `optimizer.step()`, call `averager.step()` as well.
You can then get both the current model and the averaged model with `averager.base_model` and `averager.averaged_model`, respectively.
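Putting the two together, a typical epoch looks like the following minimal sketch (here `model`, `train_loader` and `loss_fn` are assumed to be defined elsewhere):
```python
from dog import DoG, PolynomialDecayAverager

optimizer = DoG(model.parameters(), reps_rel=1e-6)
averager = PolynomialDecayAverager(model)

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    averager.step()  # keep the polynomial decay average in sync with the iterates

# evaluate and/or save averager.averaged_model - this is the model DoG is meant to be used with
```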
### Example script
An example of how to use the above to train a simple CNN on MNIST can be found in `example.py`
(based on this [pytorch example](https://github.com/pytorch/examples/blob/main/mnist/main.py)).

### Choosing `reps_rel`
DoG is parameter-free by design, so there is no need to tune a learning rate parameter.
However, as discussed in the paper, DoG has an initial movement parameter
$r_{\epsilon}$ that must be small enough to avoid destructive updates that cause divergence,
but an extremely small value of $r_{\epsilon}$ would slow down training.
We recommend choosing $r_{\epsilon}$ relative to the norm of the initial weights $x_0$. In particular, we set
$r_{\epsilon}$ to `reps_rel` $\times (1+\lVert x_0 \rVert)$, where `reps_rel` is a configurable parameter of the optimizer. The default value
of `reps_rel` is 1e-6, and we have found it to work well most of the time. However, in our experiments we did encounter
some situations that required different values of `reps_rel`:
- If optimization diverges early, it is likely that `reps_rel` (and hence $r_{\epsilon}$) is too large:
try decreasing it by factors of 100 until divergence no longer occurs. This happened when applying LDoG to fine-tune T5,
which has large pre-trained weights; setting `reps_rel` to 1e-8 eliminated the divergence.
- If the DoG step size (`eta`) does not substantially increase from its initial value within the first few hundred steps, it could be that `reps_rel` is too small:
try increasing it by factors of 100 until you see `eta` start to increase within the first few steps.
This happened when training models with batch normalization; setting `reps_rel` to 1e-4 eliminated the problem.
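For example, when training a model with batch normalization, one might pass the larger suggested value explicitly (this is only the starting point recommended above, not a universal constant):
```python
from dog import DoG

optimizer = DoG(model.parameters(), reps_rel=1e-4)
```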
## Citation
```
@article{ivgi2023dog,
  title={{D}o{G} is {SGD}'s Best Friend: A Parameter-Free Dynamic Step Size Schedule},
  author={Maor Ivgi and Oliver Hinder and Yair Carmon},
  journal={arXiv:2302.12022},
  year={2023},
}
```
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
"""
Minimal example to show how to use the DoG optimizer.
Based on https://github.com/pytorch/examples/blob/main/mnist/main.py
"""

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from dog import DoG, LDoG, PolynomialDecayAverager


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def format_pgroup_dog_state(dog_param_group_state):
    """
    A helper function to format the state of a DoG parameter group into a loggable string,
    describing the distance from the initial point, the sum of squared gradient norms, and the step size.

    Note: for LDoG, these values are the mean across layers.
    @param dog_param_group_state: A state dict of a single param group of a DoG optimizer
    @return: A printable string
    """
    # among other things, the state dict contains the following keys:
    # 'rbar' is a tensor holding the distance (maximum distance observed so far) from the initial point.
    #   For DoG this is a single value, while for LDoG this is a vector of size equal to the number of layers.
    # 'G' is a tensor holding the sum of squared gradient norms.
    #   For DoG this is a single value, while for LDoG this is a vector of size equal to the number of layers.
    # 'eta' is a list of scalar tensors holding the step size for each layer (i.e., pytorch Parameter).

    rbar = torch.mean(dog_param_group_state['rbar'].detach()).item()
    G = torch.mean(dog_param_group_state['G'].detach()).item()
    # in DoG, eta has the same value for all layers
    eta = torch.mean(torch.stack(dog_param_group_state['eta'])).detach().item()
    return f'rbar={rbar:E}, G={G:E}, eta={eta:E}'


def train(args, model, averager, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        averager.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.log_state and isinstance(optimizer, DoG):
                opt_state = optimizer.state_dict()
                for i, p in enumerate(opt_state['param_groups']):
                    prefix = f"DoG's state for param group {i}" if not args.ldog else \
                        f"LDoG's state for param group {i} (mean values across layers)"
                    print(f'\t - {prefix}: {format_pgroup_dog_state(p)}')


def test(model, device, test_loader, model_name):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set ({}): Loss = {:.4f}, Accuracy = {:.2f}% ({}/{})\n'.format(
        model_name, test_loss, 100. * correct / len(test_loader.dataset),
        correct, len(test_loader.dataset),
    ))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--data-root', type=str, default='../data', metavar='N',
                        help='data root (default: "../data")')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--ldog', action='store_true', default=False,
                        help='if set, use LDoG rather than DoG')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='base learning rate (default: 1.0) - should not be changed!')
    parser.add_argument('--reps_rel', type=float, default=1e-6, metavar='M',
                        help='normalized version of the r_epsilon parameter (default: 1e-6)')
    parser.add_argument('--init_eta', type=float, default=0, metavar='M',
                        help='if above 0, will use this value as the initial eta instead of the result of '
                             'reps_rel (default: 0)')
    parser.add_argument('--avg_gamma', type=float, default=8, metavar='M',
                        help='polynomial decay averager gamma (default: 8)')
    parser.add_argument('--weight_decay', type=float, default=0, metavar='M',
                        help='weight decay coefficient (default: 0)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='save the trained (and averaged) model')
    parser.add_argument('--no-log-state', action='store_false', default=True, dest='log_state',
                        help='suppress logging the state of the optimizer at each log interval')
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST(args.data_root, train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST(args.data_root, train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)

    opt_class = LDoG if args.ldog else DoG
    # Creating the optimizer
    optimizer = opt_class(model.parameters(), reps_rel=args.reps_rel, lr=args.lr,
                          init_eta=(args.init_eta if args.init_eta > 0 else None), weight_decay=args.weight_decay)
    # Creating the averager
    averager = PolynomialDecayAverager(model, gamma=args.avg_gamma)
    # Note - there is no lr scheduler

    for epoch in range(1, args.epochs + 1):
        train(args, model, averager, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, 'base model')  # get test results for the base model
        test(averager.averaged_model, device, test_loader, 'averaged model')  # get test results for the averaged model

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        torch.save(averager.averaged_model.state_dict(), "mnist_cnn_averaged.pt")


if __name__ == '__main__':
    main()
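
# Example invocations (illustrative; the flags correspond to the argparse options defined above):
#   python example.py                       # DoG with the default reps_rel=1e-6, plus iterate averaging
#   python example.py --ldog --epochs 5     # layer-wise DoG (LDoG) for 5 epochs
#   python example.py --reps_rel 1e-4       # larger initial movement, e.g. for batch-norm-heavy models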
--------------------------------------------------------------------------------
/src/dog/dog.py:
--------------------------------------------------------------------------------
"""
PyTorch implementation of the DoG/LDoG optimizers (Ivgi et al., 2023)
"""
import logging
from typing import Optional

import torch
from torch.optim import Optimizer

logger = logging.getLogger(__name__)


class DoG(Optimizer):
    """
    DoG (Distance over Gradients) is a parameter-free adaptive optimizer, proposed in
    `DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size Schedule` (Ivgi et al., 2023).
    IMPORTANT: for best performance, DoG must be combined with iterate averaging.
    """

    __version__ = '1.0.3'

    def __init__(self, params, reps_rel: float = 1e-6, lr: float = 1.0,
                 weight_decay: float = 0.0, eps: float = 1e-8, init_eta: Optional[float] = None):
        r"""Distance over Gradients - an adaptive stochastic optimizer.

        DoG updates parameters x_t with stochastic gradients g_t according to:

        .. math::
            \begin{aligned}
            eta_t & = \frac{ max_{i \le t}{\|x_i - x_0\|} }{ \sqrt{\sum_{i \le t}{\|g_i\|^2} + eps} }, \\
            x_{t+1} & = x_{t} - eta_t * g_t,
            \end{aligned}

        IMPORTANT: Since we do not employ a step-size decay scheme, ITERATE AVERAGING IS CRUCIAL to obtain
        the best performance. This package provides an implementation of the polynomial decay averaging
        that is effective and does not require tuning.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            reps_rel (float): value used to compute the initial distance (r_epsilon in the paper).
                Namely, the first step size is given by:
                (reps_rel * (1 + \|x_0\|)) / (\|g_0\|^2 + eps)^{1/2}, where x_0 are the initial
                weights of the model (or the parameter group), and g_0 is the gradient of the
                first step.
                As discussed in the paper, this value should be small enough to ensure that the
                first update step is small enough not to cause the model to diverge.

                The suggested value is 1e-6, unless the model uses batch normalization,
                in which case the suggested value is 1e-4. (default: 1e-6)

            lr (float, optional): learning rate (referred to as c in the paper). The default value is 1.0 and
                changing it is not recommended.
            weight_decay (float, optional): weight decay (L2 penalty). weight_decay * x_t is added directly
                to the gradient (default: 0)
            eps (float, optional): epsilon used for numerical stability - added to the sum of gradients (default: 1e-8)
            init_eta (float, optional): if specified, this value will be used as the initial eta (i.e. the
                first step size), and will override the value of reps_rel (default: None)

        Example:
            >>> optimizer = DoG(model.parameters(), reps_rel=1e-6)
            >>> optimizer.zero_grad()
            >>> loss_fn(model(input), target).backward()
            >>> optimizer.step()

        __ https://arxiv.org/pdf/2302.12022.pdf
        """

        if lr <= 0.0:
            raise ValueError(f'Invalid learning rate ({lr}). Suggested value is 1.')
        if lr != 1.0:
            logger.warning('We do not recommend changing the lr parameter from its default value of 1')
        if init_eta is not None:
            if init_eta <= 0:
                raise ValueError(f'Invalid value for init_eta ({init_eta})')
            logger.info(f'Ignoring reps_rel since init_eta was explicitly set to {init_eta} (the first step size)')
            reps_rel = 0
        else:
            if reps_rel <= 0.0:
                raise ValueError(f'Invalid reps_rel value ({reps_rel}). Suggested value is 1e-6 '
                                 '(unless the model uses batch-normalization, in which case the suggested value is 1e-4)')

        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay value: {weight_decay}')

        self._first_step = True

        defaults = dict(reps_rel=reps_rel, lr=lr, weight_decay=weight_decay, eps=eps, init_eta=init_eta)
        super(DoG, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(DoG, self).__setstate__(state)

    def state_dict(self) -> dict:
        state_dict = super(DoG, self).state_dict()
        logger.info('retrieving DoG state dict')
        state_dict['state']['_first_step'] = self._first_step
        return state_dict

    def load_state_dict(self, state_dict: dict) -> None:
        super(DoG, self).load_state_dict(state_dict)
        self._first_step = state_dict['state']['_first_step']
        logger.info('loaded DoG state dict')
        cuda = self.param_groups[0]['params'][0].device
        for group in self.param_groups:
            cuda_buffers = {'init_buffer'}
            for tgroup in group.keys():
                # this could cast all the tensors to the params' device; however, as it turns out,
                # ONLY the init_buffer needs to be on the params' device
                if tgroup != 'params':
                    device = cuda if tgroup in cuda_buffers else 'cpu'
                    if isinstance(group[tgroup], list) and len(group[tgroup]) > 0 and \
                            isinstance(group[tgroup][0], torch.Tensor):
                        group[tgroup] = [i.to(device) for i in group[tgroup]]
                    elif isinstance(group[tgroup], torch.Tensor):
                        group[tgroup] = group[tgroup].to(device)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that re-evaluates the model
                and returns the loss.
        """
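        # Outline of the update below, per parameter group: on the first call, snapshot the initial
        # parameters into init_buffer (and, if given, override the first eta with init_eta); on every
        # call, add weight decay to the gradients, update rbar (the maximal distance from the initial
        # point) and G (the running sum of squared gradient norms), set eta = lr * rbar / sqrt(G),
        # and take the SGD-like step p <- p - eta * p.grad.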
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        first_step = self._first_step

        for group in self.param_groups:
            weight_decay = group['weight_decay']

            if first_step:
                init = group['init_buffer'] = [torch.clone(p).detach() for p in group['params']]
            else:
                init = group['init_buffer']

            if weight_decay > 0:
                for p in group['params']:
                    p.grad.add_(p, alpha=weight_decay)

            self._update_group_state(group, init)
            self._override_init_eta_if_needed(group)

            for p, eta in zip(group['params'], group['eta']):
                if p.grad is None:
                    continue
                else:
                    p.add_(p.grad, alpha=-eta)

        self._first_step = False

        return loss

    def _update_group_state(self, group, init):
        # treat all layers as one long vector
        if self._first_step:
            group['rbar'] = group['reps_rel'] * (1 + torch.stack([p.norm() for p in group['params']]).norm())
            group['G'] = torch.stack([(p.grad.detach() ** 2).sum() for p in group['params']]).sum() + group['eps']
        else:
            curr_d = torch.stack([torch.norm(p.detach() - pi) for p, pi in zip(group['params'], init)]).norm()
            group['rbar'] = torch.maximum(group['rbar'], curr_d)
            group['G'] += torch.stack([(p.grad.detach() ** 2).sum() for p in group['params']]).sum()
        assert group['G'] > 0, \
            f'DoG cannot work when G is not strictly positive. got: {group["G"]}'
        group['eta'] = [group['lr'] * group['rbar'] / torch.sqrt(group['G'])] * len(group['params'])

    def _override_init_eta_if_needed(self, group):
        # Override init_eta if needed
        if self._first_step and group['init_eta'] is not None:
            init_eta = group['init_eta']
            logger.info(f'Explicitly setting init_eta value to {init_eta}')
            group['eta'] = [eta * 0 + init_eta for eta in group['eta']]


class LDoG(DoG):
    """
    Layer-wise DoG, as described in:
    `DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size Schedule` (Ivgi et al., 2023).
    LDoG applies the DoG formula defined in the DoG class, but to each layer separately.
    IMPORTANT: for best performance, L-DoG must be combined with iterate averaging.
    """
    def _update_group_state(self, group, init):
        # treat each layer in the group as a separate block
        if self._first_step:
            group['rbar'] = group['reps_rel'] * (1 + torch.stack([p.norm() for p in group['params']]))
            group['G'] = torch.stack([(p.grad ** 2).sum() for p in group['params']]) + group['eps']
        else:
            curr_d = torch.stack([torch.norm(p - pi) for p, pi in zip(group['params'], init)])
            group['rbar'] = torch.maximum(group['rbar'], curr_d)
            group['G'] += torch.stack([(p.grad ** 2).sum() for p in group['params']])
        assert torch.all(group['G'] > 0).item(), \
            f'DoG cannot work when G is not strictly positive. got: {group["G"]}'
        group['eta'] = list(group['lr'] * group['rbar'] / torch.sqrt(group['G']))
--------------------------------------------------------------------------------