├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dadaptation ├── __init__.py ├── dadapt_adagrad.py ├── dadapt_adam.py ├── dadapt_adan.py ├── dadapt_lion.py ├── dadapt_sgd.py └── experimental │ ├── __init__.py │ ├── dadapt_adam_preprint.py │ └── dadapt_adan_ip.py ├── figures ├── dadapt_cifar.png ├── dadapt_cifar100.png ├── dadapt_convex.png ├── dadapt_detectron.png ├── dadapt_dlrm.png ├── dadapt_fastmri.png ├── dadapt_gpt.png ├── dadapt_imagenet.png ├── dadapt_lstm.png ├── dadapt_roberta.png └── dadapt_vit.png ├── pyproject.toml ├── requirements.txt └── setup.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to D-Adaptation 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to D-Adaptation, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # D-Adaptation 2 | [![Downloads](https://static.pepy.tech/badge/dadaptation)](https://pepy.tech/project/dadaptation) [![Downloads](https://static.pepy.tech/badge/dadaptation/month)](https://pepy.tech/project/dadaptation) 3 | 4 | Learning rate free learning for SGD, AdaGrad and Adam! 5 | 6 | *by Aaron Defazio and Konstantin Mishchenko [(Arxiv)](https://arxiv.org/abs/2301.07733)* 7 | 8 | ``` pip install dadaptation ``` 9 | 10 | **The NEW V3.0 release uses an improved algorithm that may give different results from past versions. The old version is still available under experimental/d_adapt_adam_preprint.** 11 | 12 | ## NEW: Prodigy 13 | We have recently released the [Prodigy](https://github.com/konstmish/prodigy) method, which grows the adapted learning rate faster than D-Adaptation in theory and practice. Try it out if D-Adaptation is under-estimating the learning rate. 14 | 15 | ## How To Cite 16 | If you use D-Adaptation in a publication, please cite our work as 17 | ``` 18 | @ARTICLE{defazio2023dadapt, 19 | author = {Aaron Defazio and Konstantin Mishchenko}, 20 | title = {Learning-Rate-Free Learning by D-Adaptation}, 21 | journal = {The 40th International Conference on Machine Learning (ICML 2023)}, 22 | year = {2023} 23 | } 24 | ``` 25 | 26 | ## Details 27 | 28 | The provided PyTorch Optimizer classes are drop-in replacements: either copy them into your project, or install via pip and use dadaptation.DAdaptSGD, dadaptation.DAdaptAdam or dadaptation.DAdaptAdaGrad. A minimal usage sketch follows the list below. 29 | 30 | - **Set the LR parameter to 1.0**. This parameter is not ignored. Setting it larger or smaller will directly scale up or down the D-Adapted learning rate estimate. 31 | - Different per-layer learning rates can be achieved by setting the layer_scale value in each parameter group. It defaults to 1.0, and scales each layer's learning rate relative to the other layers. 32 | - **Use the same learning rate scheduler you would normally use on the problem.** 33 | - The Adam variant supports AdamW-style weight decay: just set decouple=True. It is not turned on by default, so if you are replacing an implementation that used decoupled decay (AdamW), make sure to enable it. 34 | - It may be necessary to use a larger weight decay than you normally would; try a factor of 2 or 4 larger if you see overfitting. D-Adaptation uses larger learning rates than people typically hand-choose, which in some cases requires more decay. 35 | - Use the log_every setting to see the learning rate being used (d*lr) and the current D bound. 36 | - Only the AdaGrad version supports sparse gradients. It does not adapt as efficiently as the other variants and should be considered experimental. 37 | 
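The sketch below ties these points together. It is a minimal, hypothetical example: the model, data, scheduler choice, and hyperparameter values (weight_decay=0.01, log_every=100, the cosine schedule) are placeholders for your own training loop, not recommendations.

```python
import torch
import dadaptation

# Placeholder model and data; substitute your own.
model = torch.nn.Linear(10, 2)

# Keep lr at 1.0 and let D-Adaptation estimate the step size.
# decouple=True enables AdamW-style weight decay; log_every logs d*lr periodically.
optimizer = dadaptation.DAdaptAdam(
    model.parameters(), lr=1.0, weight_decay=0.01, decouple=True, log_every=100
)

# Use the same LR scheduler you would normally use; it scales the lr parameter as usual.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

for step in range(1000):
    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()
```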
38 | ## Change Log 39 | 40 | ### Version 3.2 41 | - Added support for layer-wise scaling to DAdaptAdam. 42 | 43 | ### Version 3.0 44 | - Major improvements to DAdaptAdam, particularly improving performance on Transformer models. This variant may behave differently in practice. The old version is available under experimental/d_adapt_adam_preprint if you wish to continue to use it. 45 | - The IP variant is now the main variant of the method. 46 | - Added Lion. This is highly experimental. Feedback on its performance is welcome. 47 | 48 | ### Version 2.0 49 | - Added Adan - should still be considered experimental. 50 | - Added support for PyTorch's Fully Sharded Data Parallel. 51 | - Improved support of edge cases such as learning rate zero. 52 | - Improved logging - uses Python logging rather than print statements. 53 | 54 | # Experimental results 55 | 56 | ![vision](figures/dadapt_cifar.png) 57 | ![vision](figures/dadapt_cifar100.png) 58 | ![vision](figures/dadapt_imagenet.png) 59 | ![vision](figures/dadapt_vit.png) 60 | ![vision](figures/dadapt_lstm.png) 61 | ![vision](figures/dadapt_roberta.png) 62 | ![vision](figures/dadapt_gpt.png) 63 | ![vision](figures/dadapt_fastmri.png) 64 | ![vision](figures/dadapt_detectron.png) 65 | ![vision](figures/dadapt_dlrm.png) 66 | 67 | # License 68 | See the [License file](/LICENSE). 69 | -------------------------------------------------------------------------------- /dadaptation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dadapt_adagrad import DAdaptAdaGrad 8 | from .dadapt_adam import DAdaptAdam 9 | from .dadapt_sgd import DAdaptSGD 10 | from .dadapt_adan import DAdaptAdan 11 | from .dadapt_lion import DAdaptLion 12 | -------------------------------------------------------------------------------- /dadaptation/dadapt_adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import TYPE_CHECKING, Any, Callable, Optional 9 | 10 | import torch 11 | import torch.optim 12 | import pdb 13 | import logging 14 | 15 | if TYPE_CHECKING: 16 | from torch.optim.optimizer import _params_t 17 | else: 18 | _params_t = Any 19 | 20 | 21 | class DAdaptAdaGrad(torch.optim.Optimizer): 22 | """ 23 | Adagrad with D-Adaptation. We recommend Adam or SGD be used instead in most situations, 24 | as D-Adaptation on top of AdaGrad does not adapt the learning rate as quickly in 25 | practice as the other variants. 26 | 27 | Leave LR set to 1 unless you encounter instability. 28 | 29 | Arguments: 30 | params (iterable): 31 | Iterable of parameters to optimize or dicts defining parameter groups. 32 | lr (float): 33 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 34 | log_every (int): 35 | Log using print every k steps, default 0 (no logging). 36 | weight_decay (float): 37 | Weight decay, i.e. a L2 penalty (default: 0). 38 | eps (float): 39 | Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6). 40 | d0 (float): 41 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 
42 | growth_rate (float): 43 | prevent the D estimate from growing faster than this multiplicative rate. 44 | Default is inf, for unrestricted. 45 | """ 46 | 47 | def __init__( 48 | self, params: _params_t, 49 | lr: float = 1.0, 50 | momentum: float = 0, 51 | log_every: int = 0, 52 | weight_decay: float = 0.0, 53 | eps: float = 1e-6, 54 | d0 = 1e-6, growth_rate=float('inf') 55 | ): 56 | if d0 <= 0: 57 | raise ValueError("Invalid d0 value: {}".format(d0)) 58 | if lr <= 0: 59 | raise ValueError(f"Learning rate {lr} must be positive") 60 | if momentum < 0: 61 | raise ValueError(f"Momentum {momentum} must be non-negative") 62 | if eps <= 0: 63 | raise ValueError("Invalid epsilon value: {}".format(eps)) 64 | 65 | defaults = dict(lr=lr, 66 | momentum=momentum, 67 | eps=eps, 68 | weight_decay=weight_decay, 69 | gsq_weighted=0.0, 70 | log_every=log_every, 71 | d=d0, 72 | growth_rate=growth_rate, 73 | k = 0, 74 | sksq_weighted=0.0, 75 | skl1=0.0) 76 | self.d0 = d0 77 | super().__init__(params, defaults) 78 | 79 | def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: 80 | """Performs a single optimization step. 81 | 82 | Arguments: 83 | closure (callable, optional): A closure that reevaluates the model 84 | and returns the loss. 85 | """ 86 | loss = None 87 | if closure is not None: 88 | loss = closure() 89 | 90 | group = self.param_groups[0] 91 | lr = group["lr"] 92 | momentum = group['momentum'] 93 | ck = 1 - momentum 94 | 95 | log_every = group['log_every'] 96 | growth_rate = group['growth_rate'] 97 | 98 | gsq_weighted = group['gsq_weighted'] 99 | sksq_weighted = group['sksq_weighted'] 100 | skl1 = group['skl1'] 101 | d = group['d'] 102 | 103 | dlr = d*lr 104 | 105 | g_sq = 0.0 106 | sksq_weighted_change = 0.0 107 | skl1_change = 0.0 108 | 109 | for group in self.param_groups: 110 | eps = group["eps"] 111 | k = group['k'] 112 | decay = group['weight_decay'] 113 | 114 | ###### 115 | for p in group['params']: 116 | if p.grad is None: 117 | continue 118 | if hasattr(p, "_fsdp_flattened"): 119 | raise RuntimeError("D-Adapt AdaGrad doesn't currently support fully-sharded data parallel. 
Use D-Adapt Adam instead") 120 | grad = p.grad.data 121 | 122 | state = self.state[p] 123 | 124 | if "alphak" not in state: 125 | state["alphak"] = torch.full_like(p.data, fill_value=1e-6).detach() 126 | state['sk'] = torch.zeros_like(p.data).detach() 127 | state["x0"] = torch.clone(p.data).detach() 128 | 129 | if grad.is_sparse: 130 | state['weighted_sk'] = torch.zeros_like(p.data).detach() 131 | 132 | sk = state['sk'] 133 | alphak = state['alphak'] 134 | 135 | grad_sq = 0.0 136 | if grad.is_sparse: 137 | weighted_sk = state['weighted_sk'] 138 | 139 | grad = grad.coalesce() 140 | grad_vals = grad._values() 141 | vk_vals = grad_vals*grad_vals 142 | 143 | sk_vals = sk.sparse_mask(grad).coalesce()._values() 144 | 145 | old_skl1_vals = sk_vals.abs().sum().item() 146 | 147 | sk.data.add_(grad, alpha=dlr) 148 | 149 | sk_vals = sk.sparse_mask(grad).coalesce()._values() 150 | alphak_vals = alphak.sparse_mask(grad).coalesce()._values() 151 | weighted_sk_vals = weighted_sk.sparse_mask(grad).coalesce()._values() 152 | 153 | ### Update alpha before step 154 | alphak_vals = alphak.sparse_mask(grad).coalesce()._values() 155 | alphakp1_vals = alphak_vals + vk_vals 156 | 157 | alphak_delta_vals = alphakp1_vals - alphak_vals 158 | alphak_delta = torch.sparse_coo_tensor(grad.indices(), alphak_delta_vals, grad.shape) 159 | alphak.add_(alphak_delta) 160 | 161 | #### 162 | denominator = torch.sqrt(alphakp1_vals + eps) 163 | 164 | grad_sq = (grad_vals * grad_vals).div(denominator).sum().item() 165 | g_sq += grad_sq 166 | 167 | ### Update weighted sk sq tracking 168 | weighted_skp1_vals = (sk_vals * sk_vals).div(denominator) 169 | 170 | sksq_weighted_change += weighted_skp1_vals.sum().item() - weighted_sk_vals.sum().item() 171 | 172 | weighted_skp1_delta_vals = weighted_skp1_vals - weighted_sk_vals 173 | weighted_skp1_delta = torch.sparse_coo_tensor(grad.indices(), weighted_skp1_delta_vals, grad.shape) 174 | weighted_sk.add_(weighted_skp1_delta) 175 | 176 | skl1_vals = sk_vals.abs().sum().item() 177 | 178 | skl1_change += skl1_vals - old_skl1_vals 179 | 180 | else: 181 | if decay != 0: 182 | grad.add_(p.data, alpha=decay) 183 | 184 | old_sksq_weighted_param = (sk * sk).div(torch.sqrt(alphak) + eps).sum().item() 185 | old_skl1_param = sk.abs().sum().item() 186 | 187 | alphak.data.add_(grad * grad) 188 | grad_sq = (grad * grad).div(torch.sqrt(alphak) + eps).sum().item() 189 | g_sq += grad_sq 190 | 191 | sk.data.add_(grad, alpha=dlr) 192 | 193 | sksq_weighted_param = (sk * sk).div(torch.sqrt(alphak) + eps).sum().item() 194 | skl1_param = sk.abs().sum().item() 195 | 196 | sksq_weighted_change += sksq_weighted_param - old_sksq_weighted_param 197 | skl1_change += skl1_param - old_skl1_param 198 | ###### 199 | 200 | sksq_weighted = sksq_weighted + sksq_weighted_change 201 | skl1 = skl1 + skl1_change 202 | 203 | # if we have not done any progres, return 204 | # if we have any gradients available, will have skl1 > 0 (unless \|g\|=0) 205 | if skl1 == 0: 206 | return loss 207 | 208 | gsq_weighted = gsq_weighted + dlr*dlr*g_sq 209 | d_hat = d 210 | 211 | if lr > 0.0: 212 | d_hat = (sksq_weighted - gsq_weighted)/skl1 213 | d = group['d'] = max(d, min(d_hat, d*growth_rate)) 214 | 215 | if log_every > 0 and k % log_every == 0: 216 | logging.info(f"d_hat: {d_hat}, d: {d}. 
sksq_weighted={sksq_weighted:1.1e} skl1={skl1:1.1e} gsq_weighted={gsq_weighted:1.1e} lr={lr}") 217 | 218 | for group in self.param_groups: 219 | group['gsq_weighted'] = gsq_weighted 220 | group['skl1'] = skl1 221 | group['sksq_weighted'] = sksq_weighted 222 | 223 | group['d'] = d 224 | 225 | decay = group['weight_decay'] 226 | k = group['k'] 227 | eps = group['eps'] 228 | 229 | for p in group["params"]: 230 | if p.grad is None: 231 | continue 232 | grad = p.grad.data 233 | state = self.state[p] 234 | 235 | alphak = state["alphak"] 236 | sk = state["sk"] 237 | x0 = state["x0"] 238 | 239 | if grad.is_sparse: 240 | grad = grad.coalesce() 241 | grad_vals = grad._values() 242 | 243 | sk_vals = sk.sparse_mask(grad).coalesce()._values() 244 | alphak_vals = alphak.sparse_mask(grad).coalesce()._values() 245 | x0_vals = x0.sparse_mask(grad).coalesce()._values() 246 | p_vals = p.data.sparse_mask(grad).coalesce()._values() 247 | 248 | loc_vals = x0_vals - sk_vals.div(torch.sqrt(alphak_vals + eps)) 249 | 250 | loc_delta_vals = loc_vals - p_vals 251 | loc_delta = torch.sparse_coo_tensor(grad.indices(), loc_delta_vals, grad.shape) 252 | p.data.add_(loc_delta) 253 | 254 | else: 255 | z = x0 - sk.div(torch.sqrt(alphak) + eps) 256 | 257 | if momentum != 0: 258 | p.data.mul_(1-ck).add_(z, alpha=ck) 259 | else: 260 | p.data.copy_(z) 261 | group['k'] = k + 1 262 | return loss 263 | -------------------------------------------------------------------------------- /dadaptation/dadapt_adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import TYPE_CHECKING, Any, Callable, Optional 9 | 10 | import torch 11 | import torch.optim 12 | import pdb 13 | import logging 14 | import os 15 | import torch.distributed as dist 16 | 17 | if TYPE_CHECKING: 18 | from torch.optim.optimizer import _params_t 19 | else: 20 | _params_t = Any 21 | 22 | class DAdaptAdam(torch.optim.Optimizer): 23 | r""" 24 | Implements Adam with D-Adaptation automatic step-sizes. 25 | Leave LR set to 1 unless you encounter instability. 26 | 27 | To scale the learning rate differently for each layer, set the 'layer_scale' 28 | for each parameter group. Increase (or decrease) from its default value of 1.0 29 | to increase (or decrease) the learning rate for that layer relative to the 30 | other layers. 31 | 32 | Arguments: 33 | params (iterable): 34 | Iterable of parameters to optimize or dicts defining parameter groups. 35 | lr (float): 36 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 37 | betas (Tuple[float, float], optional): coefficients used for computing 38 | running averages of gradient and its square (default: (0.9, 0.999)) 39 | eps (float): 40 | Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). 41 | weight_decay (float): 42 | Weight decay, i.e. a L2 penalty (default: 0). 43 | log_every (int): 44 | Log using print every k steps, default 0 (no logging). 45 | decouple (boolean): 46 | Use AdamW style decoupled weight decay 47 | use_bias_correction (boolean): 48 | Turn on Adam's bias correction. Off by default. 49 | d0 (float): 50 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 
51 | growth_rate (float): 52 | prevent the D estimate from growing faster than this multiplicative rate. 53 | Default is inf, for unrestricted. Values like 1.02 give a kind of learning 54 | rate warmup effect. 55 | fsdp_in_use (bool): 56 | If you're using sharded parameters, this should be set to True. The optimizer 57 | will attempt to auto-detect this, but if you're using an implementation other 58 | than PyTorch's builtin version, the auto-detection won't work. 59 | """ 60 | def __init__(self, params, lr=1.0, 61 | betas=(0.9, 0.999), eps=1e-8, 62 | weight_decay=0, log_every=0, 63 | decouple=False, 64 | use_bias_correction=False, 65 | d0=1e-6, growth_rate=float('inf'), 66 | fsdp_in_use=False): 67 | if not 0.0 < d0: 68 | raise ValueError("Invalid d0 value: {}".format(d0)) 69 | if not 0.0 < lr: 70 | raise ValueError("Invalid learning rate: {}".format(lr)) 71 | if not 0.0 < eps: 72 | raise ValueError("Invalid epsilon value: {}".format(eps)) 73 | if not 0.0 <= betas[0] < 1.0: 74 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 75 | if not 0.0 <= betas[1] < 1.0: 76 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 77 | 78 | if decouple: 79 | print(f"Using decoupled weight decay") 80 | 81 | 82 | defaults = dict(lr=lr, betas=betas, eps=eps, 83 | weight_decay=weight_decay, 84 | d = d0, 85 | k=0, 86 | layer_scale=1.0, 87 | numerator_weighted=0.0, 88 | log_every=log_every, 89 | growth_rate=growth_rate, 90 | use_bias_correction=use_bias_correction, 91 | decouple=decouple, 92 | fsdp_in_use=fsdp_in_use) 93 | self.d0 = d0 94 | super().__init__(params, defaults) 95 | 96 | @property 97 | def supports_memory_efficient_fp16(self): 98 | return False 99 | 100 | @property 101 | def supports_flat_params(self): 102 | return True 103 | 104 | def step(self, closure=None): 105 | """Performs a single optimization step. 106 | 107 | Arguments: 108 | closure (callable, optional): A closure that reevaluates the model 109 | and returns the loss. 110 | """ 111 | loss = None 112 | if closure is not None: 113 | loss = closure() 114 | 115 | sk_l1 = 0.0 116 | 117 | group = self.param_groups[0] 118 | use_bias_correction = group['use_bias_correction'] 119 | numerator_weighted = group['numerator_weighted'] 120 | beta1, beta2 = group['betas'] 121 | k = group['k'] 122 | 123 | d = group['d'] 124 | lr = max(group['lr'] for group in self.param_groups) 125 | 126 | if use_bias_correction: 127 | bias_correction = ((1-beta2**(k+1))**0.5)/(1-beta1**(k+1)) 128 | else: 129 | bias_correction = 1 130 | 131 | dlr = d*lr*bias_correction 132 | 133 | growth_rate = group['growth_rate'] 134 | decouple = group['decouple'] 135 | log_every = group['log_every'] 136 | fsdp_in_use = group['fsdp_in_use'] 137 | 138 | 139 | sqrt_beta2 = beta2**(0.5) 140 | 141 | numerator_acum = 0.0 142 | 143 | for group in self.param_groups: 144 | decay = group['weight_decay'] 145 | k = group['k'] 146 | eps = group['eps'] 147 | group_lr = group['lr'] 148 | r = group['layer_scale'] 149 | 150 | if group_lr not in [lr, 0.0]: 151 | raise RuntimeError(f"Setting different lr values in different parameter groups " 152 | "is only supported for values of 0. 
To scale the learning " 153 | "rate differently for each layer, set the 'layer_scale' value instead.") 154 | 155 | for p in group['params']: 156 | if p.grad is None: 157 | continue 158 | if hasattr(p, "_fsdp_flattened"): 159 | fsdp_in_use = True 160 | 161 | grad = p.grad.data 162 | 163 | # Apply weight decay (coupled variant) 164 | if decay != 0 and not decouple: 165 | grad.add_(p.data, alpha=decay) 166 | 167 | state = self.state[p] 168 | 169 | # State initialization 170 | if 'step' not in state: 171 | state['step'] = 0 172 | state['s'] = torch.zeros_like(p.data).detach() 173 | # Exponential moving average of gradient values 174 | state['exp_avg'] = torch.zeros_like(p.data).detach() 175 | # Exponential moving average of squared gradient values 176 | state['exp_avg_sq'] = torch.zeros_like(p.data).detach() 177 | 178 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 179 | 180 | s = state['s'] 181 | 182 | if group_lr > 0.0: 183 | denom = exp_avg_sq.sqrt().add_(eps) 184 | numerator_acum += r * dlr * torch.dot(grad.flatten(), s.div(denom).flatten()).item() 185 | 186 | # Adam EMA updates 187 | exp_avg.mul_(beta1).add_(grad, alpha=r*dlr*(1-beta1)) 188 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) 189 | 190 | s.mul_(sqrt_beta2).add_(grad, alpha=dlr*(1-sqrt_beta2)) 191 | sk_l1 += r * s.abs().sum().item() 192 | 193 | ###### 194 | 195 | d_hat = d 196 | 197 | # if we have not done any progres, return 198 | # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0) 199 | if sk_l1 == 0: 200 | return loss 201 | 202 | if fsdp_in_use: 203 | dist_tensor = torch.zeros(2).cuda() 204 | dist_tensor[0] = numerator_acum 205 | dist_tensor[1] = sk_l1 206 | dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) 207 | global_numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*dist_tensor[0] 208 | global_sk_l1 = dist_tensor[1] 209 | else: 210 | global_numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*numerator_acum 211 | global_sk_l1 = sk_l1 212 | 213 | if lr > 0.0: 214 | d_hat = global_numerator_weighted/((1-sqrt_beta2)*global_sk_l1) 215 | d = max(d, min(d_hat, d*growth_rate)) 216 | 217 | if log_every > 0 and k % log_every == 0: 218 | logging.info(f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}") 219 | 220 | for group in self.param_groups: 221 | group['numerator_weighted'] = global_numerator_weighted 222 | group['d'] = d 223 | 224 | decay = group['weight_decay'] 225 | k = group['k'] 226 | eps = group['eps'] 227 | 228 | for p in group['params']: 229 | if p.grad is None: 230 | continue 231 | grad = p.grad.data 232 | 233 | state = self.state[p] 234 | 235 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 236 | 237 | state['step'] += 1 238 | 239 | denom = exp_avg_sq.sqrt().add_(eps) 240 | 241 | # Apply weight decay (decoupled variant) 242 | if decay != 0 and decouple: 243 | p.data.add_(p.data, alpha=-decay * dlr) 244 | 245 | 246 | ### Take step 247 | p.data.addcdiv_(exp_avg, denom, value=-1) 248 | 249 | group['k'] = k + 1 250 | 251 | return loss 252 | -------------------------------------------------------------------------------- /dadaptation/dadapt_adan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import TYPE_CHECKING, Any 8 | 9 | import torch 10 | import torch.optim 11 | 12 | if TYPE_CHECKING: 13 | from torch.optim.optimizer import _params_t 14 | else: 15 | _params_t = Any 16 | 17 | 18 | def to_real(x): 19 | if torch.is_complex(x): 20 | return x.real 21 | else: 22 | return x 23 | 24 | 25 | class DAdaptAdan(torch.optim.Optimizer): 26 | r""" 27 | Implements Adan with D-Adaptation automatic step-sizes. 28 | Has not been as heavily tested as DAdaptAdam and should be considered experimental. 29 | 30 | Leave LR set to 1 unless you encounter instability. 31 | Adan was proposed in 32 | Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. 33 | https://arxiv.org/abs/2208.06677 34 | 35 | Arguments: 36 | params (iterable): 37 | Iterable of parameters to optimize or dicts defining parameter groups. 38 | lr (float): 39 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 40 | betas (Tuple[float, float, flot], optional): coefficients used for computing 41 | running averages of gradient and its norm. (default: (0.98, 0.92, 0.99)) 42 | eps (float): 43 | Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). 44 | weight_decay (float): 45 | Weight decay, i.e. a L2 penalty (default: 0.02). 46 | no_prox (boolean): 47 | how to perform the decoupled weight decay (default: False) 48 | log_every (int): 49 | Log using print every k steps, default 0 (no logging). 50 | d0 (float): 51 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 52 | growth_rate (float): 53 | prevent the D estimate from growing faster than this multiplicative rate. 54 | Default is inf, for unrestricted. Values like 1.02 give a kind of learning 55 | rate warmup effect. 
56 | """ 57 | def __init__(self, params, lr=1.0, 58 | betas=(0.98, 0.92, 0.99), 59 | eps=1e-8, weight_decay=0.02, 60 | no_prox=False, 61 | log_every=0, d0=1e-6, 62 | growth_rate=float('inf')): 63 | if not 0.0 < d0: 64 | raise ValueError("Invalid d0 value: {}".format(d0)) 65 | if not 0.0 < lr: 66 | raise ValueError("Invalid learning rate: {}".format(lr)) 67 | if not 0.0 < eps: 68 | raise ValueError("Invalid epsilon value: {}".format(eps)) 69 | if not 0.0 <= betas[0] < 1.0: 70 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 71 | if not 0.0 <= betas[1] < 1.0: 72 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 73 | if not 0.0 <= betas[2] < 1.0: 74 | raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) 75 | 76 | defaults = dict(lr=lr, betas=betas, eps=eps, 77 | weight_decay=weight_decay, 78 | no_prox=no_prox, 79 | d = d0, 80 | k=0, 81 | gsq_weighted=0.0, 82 | log_every=log_every, 83 | growth_rate=growth_rate) 84 | super().__init__(params, defaults) 85 | 86 | @property 87 | def supports_memory_efficient_fp16(self): 88 | return False 89 | 90 | @property 91 | def supports_flat_params(self): 92 | return True 93 | 94 | # Experimental implementation of Adan's restart strategy 95 | @torch.no_grad() 96 | def restart_opt(self): 97 | for group in self.param_groups: 98 | group['gsq_weighted'] = 0.0 99 | for p in group['params']: 100 | if p.requires_grad: 101 | state = self.state[p] 102 | # State initialization 103 | 104 | state['step'] = 0 105 | state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 106 | # Exponential moving average of gradient values 107 | state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 108 | # Exponential moving average of gradient difference 109 | state['exp_avg_diff'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach() 110 | # Exponential moving average of squared gradient values 111 | state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 112 | 113 | @torch.no_grad() 114 | def step(self, closure=None): 115 | """Performs a single optimization step. 116 | Arguments: 117 | closure (callable, optional): A closure that reevaluates the model 118 | and returns the loss. 
119 | """ 120 | loss = None 121 | if closure is not None: 122 | loss = closure() 123 | 124 | 125 | g_sq = 0.0 126 | sksq_weighted = 0.0 127 | sk_l1 = 0.0 128 | 129 | ngroups = len(self.param_groups) 130 | 131 | group = self.param_groups[0] 132 | gsq_weighted = group['gsq_weighted'] 133 | d = group['d'] 134 | lr = group['lr'] 135 | dlr = d*lr 136 | 137 | no_prox = group['no_prox'] 138 | growth_rate = group['growth_rate'] 139 | log_every = group['log_every'] 140 | 141 | beta1, beta2, beta3 = group['betas'] 142 | 143 | for group in self.param_groups: 144 | decay = group['weight_decay'] 145 | k = group['k'] 146 | eps = group['eps'] 147 | 148 | for p in group['params']: 149 | if p.grad is None: 150 | continue 151 | grad = p.grad.data 152 | 153 | state = self.state[p] 154 | 155 | # State initialization 156 | if 'step' not in state: 157 | state['step'] = 0 158 | state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 159 | # Exponential moving average of gradient values 160 | state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 161 | # Exponential moving average of gradient difference 162 | state['exp_avg_diff'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() 163 | # Exponential moving average of squared gradient values 164 | state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach() 165 | 166 | if state['step'] == 0: 167 | # Previous gradient values 168 | state['pre_grad'] = grad.clone() 169 | 170 | exp_avg, exp_avg_diff, exp_avg_sq = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq'] 171 | grad_diff = grad - state['pre_grad'] 172 | 173 | grad_grad = to_real(grad * grad.conj()) 174 | update = grad + beta2 * grad_diff 175 | update_update = to_real(update * update.conj()) 176 | 177 | exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1. - beta1)) 178 | exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=dlr*(1. - beta2)) 179 | exp_avg_sq.mul_(beta3).add_(update_update, alpha=1. - beta3) 180 | 181 | denom = exp_avg_sq.sqrt().add_(eps) 182 | 183 | g_sq += grad_grad.div_(denom).sum().item() 184 | 185 | s = state['s'] 186 | s.mul_(beta3).add_(grad, alpha=dlr*(1. - beta3)) 187 | sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item() 188 | sk_l1 += s.abs().sum().item() 189 | 190 | ###### 191 | 192 | gsq_weighted = beta3*gsq_weighted + g_sq*(dlr**2)*(1-beta3) 193 | d_hat = d 194 | 195 | # if we have not done any progres, return 196 | # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0) 197 | if sk_l1 == 0: 198 | return loss 199 | 200 | if lr > 0.0: 201 | d_hat = (sksq_weighted/(1-beta3) - gsq_weighted)/sk_l1 202 | d = max(d, min(d_hat, d*growth_rate)) 203 | 204 | if log_every > 0 and k % log_every == 0: 205 | print(f"ng: {ngroups} lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. 
sksq_weighted={sksq_weighted:1.1e} sk_l1={sk_l1:1.1e} gsq_weighted={gsq_weighted:1.1e}") 206 | 207 | for group in self.param_groups: 208 | group['gsq_weighted'] = gsq_weighted 209 | group['d'] = d 210 | 211 | decay = group['weight_decay'] 212 | k = group['k'] 213 | eps = group['eps'] 214 | 215 | for p in group['params']: 216 | if p.grad is None: 217 | continue 218 | grad = p.grad.data 219 | 220 | state = self.state[p] 221 | 222 | exp_avg, exp_avg_diff, exp_avg_sq = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq'] 223 | 224 | state['step'] += 1 225 | 226 | denom = exp_avg_sq.sqrt().add_(eps) 227 | denom = denom.type(p.type()) 228 | 229 | update = (exp_avg + beta2 * exp_avg_diff).div_(denom) 230 | 231 | ### Take step 232 | if no_prox: 233 | p.data.mul_(1 - dlr * decay) 234 | p.add_(update, alpha=-1) 235 | else: 236 | p.add_(update, alpha=-1) 237 | p.data.div_(1 + dlr * decay) 238 | 239 | state['pre_grad'].copy_(grad) 240 | 241 | group['k'] = k + 1 242 | 243 | return loss 244 | -------------------------------------------------------------------------------- /dadaptation/dadapt_lion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple, Optional, Callable 8 | 9 | import torch 10 | from torch.optim.optimizer import Optimizer 11 | import torch.distributed as dist 12 | import logging 13 | import pdb 14 | 15 | class DAdaptLion(Optimizer): 16 | r""" 17 | Implements Lion with D-Adaptation automatic step-sizes. 18 | Has not been as heavily tested as DAdaptAdam and should be considered experimental. 19 | 20 | 21 | Leave LR set to 1 unless you encounter instability. 22 | Arguments: 23 | params (iterable): 24 | Iterable of parameters to optimize or dicts defining parameter groups. 25 | lr (float): 26 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 27 | betas (Tuple[float, float], optional): coefficients used for computing 28 | running averages of gradient and its square (default: (0.9, 0.999)) 29 | weight_decay (float): 30 | Weight decay, i.e. a L2 penalty (default: 0). 31 | log_every (int): 32 | Log using print every k steps, default 0 (no logging). 33 | d0 (float): 34 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 35 | fsdp_in_use (bool): 36 | If you're using sharded parameters, this should be set to True. The optimizer 37 | will attempt to auto-detect this, but if you're using an implementation other 38 | than PyTorch's builtin version, the auto-detection won't work. 
39 | """ 40 | def __init__( 41 | self, 42 | params, 43 | lr: float = 1.0, 44 | betas: Tuple[float, float] = (0.9, 0.99), 45 | weight_decay: float = 0.0, log_every=0, 46 | d0=1e-6, fsdp_in_use=False): 47 | 48 | if not 0.0 <= lr: 49 | raise ValueError("Invalid learning rate: {}".format(lr)) 50 | if not 0.0 <= betas[0] < 1.0: 51 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 52 | if not 0.0 <= betas[1] < 1.0: 53 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 54 | 55 | defaults = dict( 56 | lr=lr, 57 | betas=betas, 58 | weight_decay=weight_decay, 59 | d=d0, k=0, 60 | log_every=log_every, 61 | numerator_weighted=0.0, 62 | fsdp_in_use=fsdp_in_use) 63 | 64 | super().__init__(params, defaults) 65 | 66 | def step(self, closure: Optional[Callable] = None): 67 | 68 | loss = None 69 | if closure is not None: 70 | with torch.enable_grad(): 71 | loss = closure() 72 | 73 | group = self.param_groups[0] 74 | numerator_weighted = group['numerator_weighted'] 75 | d = group['d'] 76 | lr = max(group['lr'] for group in self.param_groups) 77 | 78 | dlr = d*lr 79 | 80 | log_every = group['log_every'] 81 | fsdp_in_use = group['fsdp_in_use'] 82 | 83 | beta1, beta2 = group['betas'] 84 | sqrt_beta2 = beta2**0.5 85 | 86 | numerator_acum = 0.0 87 | sk_l1 = 0.0 88 | 89 | for group in self.param_groups: 90 | k = group['k'] 91 | group_lr = group['lr'] 92 | wd = group['weight_decay'] 93 | 94 | if group_lr not in [lr, 0.0]: 95 | raise RuntimeError(f"Setting different lr values in different parameter groups is only supported for values of 0") 96 | 97 | for p in group['params']: 98 | if p.grad is None: 99 | continue 100 | 101 | grad = p.grad 102 | state = self.state[p] 103 | 104 | if 'exp_avg' not in state: 105 | state['exp_avg'] = torch.zeros_like(p).detach() 106 | state['s'] = torch.zeros_like(p).detach() 107 | 108 | exp_avg = state['exp_avg'] 109 | s = state['s'] 110 | 111 | #AdamW style weight decay 112 | p.data.mul_(1-dlr*wd) 113 | 114 | update = exp_avg.clone().mul_(beta1).add_(grad, alpha=(1-beta1)).sign_() 115 | 116 | p.data.add_(update, alpha=-dlr) 117 | 118 | exp_avg.mul_(beta2).add_(grad, alpha=(1-beta2)*dlr) 119 | 120 | numerator_acum += dlr * torch.dot(update.flatten(), s.flatten()).item() 121 | 122 | s.mul_(sqrt_beta2).add_(update, alpha=(1-sqrt_beta2)*dlr) 123 | 124 | sk_l1 += s.abs().sum().item() 125 | 126 | d_hat = d 127 | 128 | # if we have not done any progres, return 129 | # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0) 130 | if sk_l1 == 0: 131 | return loss 132 | 133 | if fsdp_in_use: 134 | dist_tensor = torch.zeros(2).cuda() 135 | dist_tensor[0] = numerator_acum 136 | dist_tensor[1] = sk_l1 137 | dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) 138 | global_numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*dist_tensor[0] 139 | global_sk_l1 = dist_tensor[1] 140 | else: 141 | global_numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*numerator_acum 142 | global_sk_l1 = sk_l1 143 | 144 | if lr > 0.0: 145 | d_hat = global_numerator_weighted/((1-sqrt_beta2)*global_sk_l1) 146 | d = max(d, d_hat) 147 | 148 | if log_every > 0 and k % log_every == 0: 149 | logging.info(f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. 
sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}") 150 | 151 | for group in self.param_groups: 152 | group['numerator_weighted'] = global_numerator_weighted 153 | group['d'] = d 154 | group['k'] = group['k'] + 1 155 | 156 | return loss -------------------------------------------------------------------------------- /dadaptation/dadapt_sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.optim 9 | import pdb 10 | import math 11 | import logging 12 | import torch.distributed as dist 13 | 14 | class DAdaptSGD(torch.optim.Optimizer): 15 | r""" 16 | Implements SGD with D-Adaptation automatic step-sizes. Leave LR set to 1 unless you encounter instability. 17 | 18 | Arguments: 19 | params (iterable): 20 | Iterable of parameters to optimize or dicts defining parameter groups. 21 | lr (float): 22 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 23 | momentum (float): 24 | Momentum value in the range [0,1) (default: 0). 25 | weight_decay (float): 26 | Weight decay, i.e. a L2 penalty (default: 0). 27 | log_every (int): 28 | Log using print every k steps, default 0 (no logging). 29 | d0 (float): 30 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 31 | growth_rate (float): 32 | prevent the D estimate from growing faster than this multiplicative rate. 33 | Default is inf, for unrestricted. More conservative values like 1.02 may 34 | help if training is unstable. 35 | fsdp_in_use (bool): 36 | If you're using sharded parameters, this should be set to True. The optimizer 37 | will attempt to auto-detect this, but if you're using an implementation other 38 | than PyTorch's builtin version, the auto-detection won't work. 
39 | """ 40 | def __init__(self, params, 41 | lr=1.0, 42 | momentum=0.0, 43 | weight_decay=0, 44 | log_every=0, 45 | d0=1e-6, growth_rate=float('inf'), 46 | fsdp_in_use=False): 47 | 48 | if not 0.0 < d0: 49 | raise ValueError("Invalid d0 value: {}".format(d0)) 50 | if not 0.0 < lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | 53 | defaults = dict(lr=lr, 54 | momentum=momentum, 55 | weight_decay=weight_decay, k=0, 56 | log_every=log_every, 57 | numerator_weighted=0.0, 58 | d=d0, 59 | growth_rate=growth_rate, 60 | fsdp_in_use=fsdp_in_use) 61 | self.loggables = {} 62 | 63 | try: 64 | self.rank = torch.distributed.get_rank() 65 | except: 66 | self.rank = 0 67 | 68 | super().__init__(params, defaults) 69 | 70 | def step(self, closure=None): 71 | loss = None 72 | if closure is not None: 73 | loss = closure() 74 | 75 | group = self.param_groups[0] 76 | lr = max(group['lr'] for group in self.param_groups) 77 | 78 | decay = group['weight_decay'] 79 | momentum = group['momentum'] 80 | log_every = group['log_every'] 81 | ck = 1 - momentum 82 | k = group['k'] 83 | 84 | numerator_weighted = group['numerator_weighted'] 85 | growth_rate = group['growth_rate'] 86 | d = group['d'] 87 | fsdp_in_use = group['fsdp_in_use'] 88 | 89 | group = self.param_groups[0] 90 | 91 | sk_sq = 0.0 92 | delta_numerator_weighted = 0.0 93 | 94 | if k == 0: 95 | g_sq = 0.0 96 | for group in self.param_groups: 97 | group_lr = group['lr'] 98 | for p in group['params']: 99 | if p.grad is None: 100 | continue 101 | if hasattr(p, "_fsdp_flattened"): 102 | fsdp_in_use = True 103 | grad = p.grad.data 104 | 105 | # Apply weight decay 106 | if decay != 0: 107 | grad.add(p.data, alpha=decay) 108 | 109 | state = self.state[p] 110 | 111 | if group_lr > 0.0: 112 | g_sq += (grad * grad).sum().item() 113 | 114 | if fsdp_in_use: 115 | dist_tensor = torch.zeros(1).cuda() 116 | dist_tensor[0] = g_sq 117 | dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) 118 | global_gsq = dist_tensor[0] 119 | else: 120 | global_gsq = g_sq 121 | group['g0_norm'] = g0_norm = math.sqrt(global_gsq) 122 | 123 | g0_norm = group['g0_norm'] 124 | 125 | dlr = d*lr/g0_norm 126 | 127 | for group in self.param_groups: 128 | group_lr = group['lr'] 129 | if group_lr not in [lr, 0.0]: 130 | raise RuntimeError(f"Setting different lr values in different parameter groups is only supported for values of 0") 131 | 132 | for p in group['params']: 133 | if p.grad is None: 134 | continue 135 | grad = p.grad.data 136 | state = self.state[p] 137 | 138 | if 'z' not in state: 139 | z = state['z'] = torch.clone(p.data).detach() 140 | s = state['s'] = torch.zeros_like(p.data).detach() 141 | x0 = state['x0'] = torch.clone(p.data).detach() 142 | 143 | # Apply weight decay 144 | if decay != 0: 145 | grad.add_(p.data, alpha=decay) 146 | 147 | s = state['s'] 148 | 149 | if group_lr > 0.0: 150 | delta_numerator_weighted += dlr * torch.dot(grad.flatten(), s.flatten()).item() 151 | 152 | s.data.add_(grad, alpha=dlr) 153 | sk_sq += (s * s).sum().item() 154 | ###### 155 | 156 | d_hat = d 157 | 158 | if fsdp_in_use: 159 | dist_tensor = torch.zeros(2).cuda() 160 | dist_tensor[0] = sk_sq 161 | dist_tensor[1] = delta_numerator_weighted 162 | dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) 163 | global_sk_sq = dist_tensor[0] 164 | global_numerator_weighted = numerator_weighted + dist_tensor[1] 165 | else: 166 | global_sk_sq = sk_sq 167 | global_numerator_weighted = numerator_weighted + delta_numerator_weighted 168 | 169 | if lr > 0.0: 170 | d_hat = 
2*global_numerator_weighted/math.sqrt(global_sk_sq) 171 | d = max(d, min(d_hat, d*growth_rate)) 172 | 173 | 174 | # if we have not done any updates 175 | # if we have any gradients available, will have sk_sq > 0 (unless \|g\|=0) 176 | if global_sk_sq == 0: 177 | return loss 178 | 179 | if log_every > 0 and k % log_every == 0: 180 | logging.info(f"(r={self.rank},k={k}) dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_norm={math.sqrt(global_sk_sq)} numerator_weighted={global_numerator_weighted} g0_norm={g0_norm}") 181 | 182 | for group in self.param_groups: 183 | group['numerator_weighted'] = global_numerator_weighted 184 | group['d'] = d 185 | group['g0_norm'] = g0_norm 186 | ###################################### 187 | for p in group['params']: 188 | if p.grad is None: 189 | continue 190 | grad = p.grad.data 191 | state = self.state[p] 192 | 193 | s = state['s'] 194 | x0 = state['x0'] 195 | z = state['z'] 196 | 197 | # z step 198 | z.data.copy_(x0 - s) 199 | 200 | # x step 201 | p.data.mul_(1-ck).add_(z, alpha=ck) 202 | 203 | group['k'] = k + 1 204 | 205 | return loss -------------------------------------------------------------------------------- /dadaptation/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dadapt_adan_ip import DAdaptAdanIP 8 | from .dadapt_adam_preprint import DAdaptAdamPreprint 9 | -------------------------------------------------------------------------------- /dadaptation/experimental/dadapt_adam_preprint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import TYPE_CHECKING, Any, Callable, Optional 9 | 10 | import torch 11 | import torch.optim 12 | import pdb 13 | import logging 14 | import os 15 | import torch.distributed as dist 16 | 17 | if TYPE_CHECKING: 18 | from torch.optim.optimizer import _params_t 19 | else: 20 | _params_t = Any 21 | 22 | def to_real(x): 23 | if torch.is_complex(x): 24 | return x.real 25 | else: 26 | return x 27 | 28 | class DAdaptAdamPreprint(torch.optim.Optimizer): 29 | r""" 30 | 31 | This is an earlier variant of D-Adapt Adam used in early preprints of the paper, and source 32 | code releases V1 and V2. Use this if you encounter performance regressions after the latest update. 33 | 34 | Arguments: 35 | params (iterable): 36 | Iterable of parameters to optimize or dicts defining parameter groups. 37 | lr (float): 38 | Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 39 | betas (Tuple[float, float], optional): coefficients used for computing 40 | running averages of gradient and its square (default: (0.9, 0.999)) 41 | momentum (float): 42 | Momentum value in the range [0,1) (default: 0.9). 43 | eps (float): 44 | Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). 45 | weight_decay (float): 46 | Weight decay, i.e. a L2 penalty (default: 0). 47 | log_every (int): 48 | Log using print every k steps, default 0 (no logging). 
49 | decouple (boolean): 50 | Use AdamW style decoupled weight decay 51 | d0 (float): 52 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. 53 | growth_rate (float): 54 | prevent the D estimate from growing faster than this multiplicative rate. 55 | Default is inf, for unrestricted. Values like 1.02 give a kind of learning 56 | rate warmup effect. 57 | fsdp_in_use (bool): 58 | If you're using sharded parameters, this should be set to True. The optimizer 59 | will attempt to auto-detect this, but if you're using an implementation other 60 | than PyTorch's builtin version, the auto-detection won't work. 61 | """ 62 | def __init__(self, params, lr=1.0, 63 | betas=(0.9, 0.999), 64 | eps=1e-8, 65 | weight_decay=0, 66 | log_every=0, 67 | decouple=False, 68 | d0=1e-6, 69 | growth_rate=float('inf'), 70 | fsdp_in_use=False): 71 | if not 0.0 < d0: 72 | raise ValueError("Invalid d0 value: {}".format(d0)) 73 | if not 0.0 < lr: 74 | raise ValueError("Invalid learning rate: {}".format(lr)) 75 | if not 0.0 < eps: 76 | raise ValueError("Invalid epsilon value: {}".format(eps)) 77 | if not 0.0 <= betas[0] < 1.0: 78 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 79 | if not 0.0 <= betas[1] < 1.0: 80 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 81 | 82 | if decouple: 83 | print(f"Using decoupled weight decay") 84 | 85 | defaults = dict(lr=lr, betas=betas, eps=eps, 86 | weight_decay=weight_decay, 87 | d = d0, 88 | k=0, 89 | gsq_weighted=0.0, 90 | log_every=log_every, 91 | decouple=decouple, 92 | growth_rate=growth_rate, 93 | fsdp_in_use=fsdp_in_use) 94 | 95 | super().__init__(params, defaults) 96 | 97 | @property 98 | def supports_memory_efficient_fp16(self): 99 | return False 100 | 101 | @property 102 | def supports_flat_params(self): 103 | return True 104 | 105 | def step(self, closure=None): 106 | """Performs a single optimization step. 107 | 108 | Arguments: 109 | closure (callable, optional): A closure that reevaluates the model 110 | and returns the loss. 
111 |         """
112 |         loss = None
113 |         if closure is not None:
114 |             loss = closure()
115 | 
116 | 
117 |         g_sq = 0.0
118 |         sksq_weighted = 0.0
119 |         sk_l1 = 0.0
120 | 
121 |         lr = max(group['lr'] for group in self.param_groups)
122 | 
123 |         group = self.param_groups[0]
124 |         gsq_weighted = group['gsq_weighted']
125 |         d = group['d']
126 |         dlr = d*lr
127 | 
128 |         growth_rate = group['growth_rate']
129 |         decouple = group['decouple']
130 |         fsdp_in_use = group['fsdp_in_use']
131 |         log_every = group['log_every']
132 | 
133 |         beta1, beta2 = group['betas']
134 | 
135 |         for group in self.param_groups:
136 |             group_lr = group['lr']
137 |             decay = group['weight_decay']
138 |             k = group['k']
139 |             eps = group['eps']
140 | 
141 |             if group_lr not in [lr, 0.0]:
142 |                 raise RuntimeError("Setting different lr values in different parameter groups is only supported when the additional groups use a lr of 0")
143 | 
144 |             for p in group['params']:
145 |                 if p.grad is None:
146 |                     continue
147 |                 if hasattr(p, "_fsdp_flattened"):
148 |                     fsdp_in_use = True
149 |                 grad = p.grad.data
150 | 
151 |                 # Apply weight decay (coupled variant)
152 |                 if decay != 0 and not decouple:
153 |                     grad.add_(p.data, alpha=decay)
154 | 
155 |                 state = self.state[p]
156 | 
157 |                 # State initialization
158 |                 if 'step' not in state:
159 |                     state['step'] = 0
160 |                     state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
161 |                     # Exponential moving average of gradient values
162 |                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
163 |                     # Exponential moving average of squared gradient values
164 |                     state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach()
165 | 
166 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
167 | 
168 |                 grad_grad = to_real(grad * grad.conj())
169 | 
170 |                 # Adam EMA updates
171 |                 if group_lr > 0:
172 |                     exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
173 |                     exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)
174 | 
175 |                     denom = exp_avg_sq.sqrt().add_(eps)
176 | 
177 |                     g_sq += grad_grad.div_(denom).sum().item()
178 | 
179 |                     s = state['s']
180 |                     s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
181 |                     sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
182 |                     sk_l1 += s.abs().sum().item()
183 | 
184 |         ######
185 | 
186 |         gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
187 |         d_hat = d
188 | 
189 |         # If we have not made any progress yet, return early.
190 |         # If any gradients are available, we will have sk_l1 > 0 (unless \|g\| = 0).
191 |         if sk_l1 == 0:
192 |             return loss
193 | 
194 |         if lr > 0.0:
195 |             if fsdp_in_use:
196 |                 dist_tensor = torch.zeros(3).cuda()
197 |                 dist_tensor[0] = sksq_weighted
198 |                 dist_tensor[1] = gsq_weighted
199 |                 dist_tensor[2] = sk_l1
200 |                 dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
201 |                 global_sksq_weighted = dist_tensor[0]
202 |                 global_gsq_weighted = dist_tensor[1]
203 |                 global_sk_l1 = dist_tensor[2]
204 |             else:
205 |                 global_sksq_weighted = sksq_weighted
206 |                 global_gsq_weighted = gsq_weighted
207 |                 global_sk_l1 = sk_l1
208 | 
209 |             d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
210 |             d = max(d, min(d_hat, d*growth_rate))
211 | 
212 |             if log_every > 0 and k % log_every == 0:
213 |                 logging.info(
214 |                     f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
215 |                     f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
216 |                     f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")
217 | 
218 |         for group in self.param_groups:
219 |             group['gsq_weighted'] = gsq_weighted
220 |             group['d'] = d
221 | 
222 |             group_lr = group['lr']
223 |             decay = group['weight_decay']
224 |             k = group['k']
225 |             eps = group['eps']
226 | 
227 |             for p in group['params']:
228 |                 if p.grad is None:
229 |                     continue
230 |                 grad = p.grad.data
231 | 
232 |                 state = self.state[p]
233 | 
234 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
235 | 
236 |                 state['step'] += 1
237 | 
238 |                 denom = exp_avg_sq.sqrt().add_(eps)
239 |                 denom = denom.type(p.type())
240 | 
241 |                 # Apply weight decay (decoupled variant)
242 |                 if decay != 0 and decouple and group_lr > 0:
243 |                     p.data.add_(p.data, alpha=-decay * dlr)
244 | 
245 | 
246 |                 ### Take step
247 |                 p.data.addcdiv_(exp_avg, denom, value=-1)
248 | 
249 |             group['k'] = k + 1
250 | 
251 |         return loss
--------------------------------------------------------------------------------
/dadaptation/experimental/dadapt_adan_ip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import math
8 | from typing import TYPE_CHECKING, Any, Callable, Optional
9 | 
10 | import torch
11 | import torch.optim
12 | import pdb
13 | import logging
14 | import os
15 | 
16 | if TYPE_CHECKING:
17 |     from torch.optim.optimizer import _params_t
18 | else:
19 |     _params_t = Any
20 | 
21 | def to_real(x):
22 |     if torch.is_complex(x):
23 |         return x.real
24 |     else:
25 |         return x
26 | 
27 | class DAdaptAdanIP(torch.optim.Optimizer):
28 |     r"""
29 |     Implements Adan with D-Adaptation automatic step-sizes. Leave LR set to 1 unless you encounter instability.
30 |     Adan was proposed in
31 |     Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. arXiv preprint arXiv:2208.06677, 2022.
32 |     https://arxiv.org/abs/2208.06677
33 | 
34 |     This IP variant uses a tighter bound than the non-IP version,
35 |     and so will typically choose larger step sizes. It has not
36 |     been as extensively tested.
37 | 
38 |     Arguments:
39 |         params (iterable):
40 |             Iterable of parameters to optimize or dicts defining parameter groups.
41 |         lr (float):
42 |             Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
43 |         betas (Tuple[float, float, float], optional): coefficients used for computing
44 |             running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
45 |         eps (float):
46 |             Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
47 |         weight_decay (float):
48 |             Weight decay, i.e. an L2 penalty (default: 0.02).
49 |         no_prox (boolean):
50 |             How the decoupled weight decay is applied. If True, the weights are decayed multiplicatively before the update; if False, a proximal-style rescaling is applied after the update (default: False).
51 |         log_every (int):
52 |             Log using print every k steps, default 0 (no logging).
53 |         d0 (float):
54 |             Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
55 |         growth_rate (float):
56 |             Prevent the D estimate from growing faster than this multiplicative rate.
57 |             Default is inf, for unrestricted. Values like 1.02 give a kind of learning
58 |             rate warmup effect.
59 |     """
60 |     def __init__(self, params, lr=1.0,
61 |                  betas=(0.98, 0.92, 0.99),
62 |                  eps=1e-8, weight_decay=0.02,
63 |                  no_prox=False,
64 |                  log_every=0, d0=1e-6,
65 |                  growth_rate=float('inf')):
66 |         if not 0.0 < d0:
67 |             raise ValueError("Invalid d0 value: {}".format(d0))
68 |         if not 0.0 < lr:
69 |             raise ValueError("Invalid learning rate: {}".format(lr))
70 |         if not 0.0 < eps:
71 |             raise ValueError("Invalid epsilon value: {}".format(eps))
72 |         if not 0.0 <= betas[0] < 1.0:
73 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
74 |         if not 0.0 <= betas[1] < 1.0:
75 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
76 |         if not 0.0 <= betas[2] < 1.0:
77 |             raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
78 | 
79 |         defaults = dict(lr=lr, betas=betas, eps=eps,
80 |                         weight_decay=weight_decay,
81 |                         no_prox=no_prox,
82 |                         d=d0,
83 |                         k=0,
84 |                         numerator_weighted=0.0,
85 |                         log_every=log_every,
86 |                         growth_rate=growth_rate)
87 |         self.d0 = d0
88 |         super().__init__(params, defaults)
89 | 
90 |     @property
91 |     def supports_memory_efficient_fp16(self):
92 |         return False
93 | 
94 |     @property
95 |     def supports_flat_params(self):
96 |         return True
97 | 
98 |     # Experimental implementation of Adan's restart strategy
99 |     @torch.no_grad()
100 |     def restart_opt(self):
101 |         for group in self.param_groups:
102 |             group['numerator_weighted'] = 0.0
103 |             for p in group['params']:
104 |                 if p.requires_grad:
105 |                     state = self.state[p]
106 |                     # State initialization
107 | 
108 |                     state['step'] = 0
109 |                     state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
110 |                     # Exponential moving average of gradient values
111 |                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
112 |                     # Exponential moving average of gradient difference
113 |                     state['exp_avg_diff'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
114 |                     # Exponential moving average of squared gradient values
115 |                     state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach()
116 | 
117 |     @torch.no_grad()
118 |     def step(self, closure=None):
119 |         """Performs a single optimization step.
120 |         Arguments:
121 |             closure (callable, optional): A closure that reevaluates the model
122 |                 and returns the loss.
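        Example (illustrative sketch only; ``model``, ``criterion`` and ``loader``
        are assumed to exist in user code):

            optimizer = DAdaptAdanIP(model.parameters(), lr=1.0, weight_decay=0.02)
            for inputs, targets in loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()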
123 |         """
124 |         loss = None
125 |         if closure is not None:
126 |             loss = closure()
127 | 
128 | 
129 |         g_sq = 0.0
130 |         sksq_weighted = 0.0
131 |         sk_l1 = 0.0
132 | 
133 |         ngroups = len(self.param_groups)
134 | 
135 |         group = self.param_groups[0]
136 |         numerator_weighted = group['numerator_weighted']
137 |         d = group['d']
138 |         lr = group['lr']
139 |         dlr = d*lr
140 | 
141 |         no_prox = group['no_prox']
142 |         growth_rate = group['growth_rate']
143 |         log_every = group['log_every']
144 | 
145 |         beta1, beta2, beta3 = group['betas']
146 | 
147 |         numerator_accum = 0.0
148 | 
149 |         for group in self.param_groups:
150 |             decay = group['weight_decay']
151 |             k = group['k']
152 |             eps = group['eps']
153 | 
154 |             for p in group['params']:
155 |                 if p.grad is None:
156 |                     continue
157 |                 grad = p.grad.data
158 | 
159 |                 state = self.state[p]
160 | 
161 |                 # State initialization
162 |                 if 'step' not in state:
163 |                     state['step'] = 0
164 |                     state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
165 |                     # Exponential moving average of gradient values
166 |                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
167 |                     # Exponential moving average of gradient difference
168 |                     state['exp_avg_diff'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
169 |                     # Exponential moving average of squared gradient values
170 |                     state['exp_avg_sq'] = torch.zeros_like(to_real(p.data), memory_format=torch.preserve_format).detach()
171 | 
172 |                 if state['step'] == 0:
173 |                     # Previous gradient values
174 |                     state['pre_grad'] = grad.clone()
175 | 
176 |                 exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
177 |                 grad_diff = grad - state['pre_grad']
178 | 
179 |                 update = grad + beta2 * grad_diff
180 |                 update_update = to_real(update * update.conj())
181 | 
182 |                 s = state['s']
183 | 
184 |                 denom = exp_avg_sq.sqrt().add_(eps)
185 |                 numerator_accum += dlr * torch.dot(grad.flatten(), s.div(denom).flatten())
186 | 
187 |                 exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1. - beta1))
188 |                 exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=dlr*(1. - beta2))
189 |                 exp_avg_sq.mul_(beta3).add_(update_update, alpha=1. - beta3)
190 | 
191 |                 s.mul_(beta3).add_(grad, alpha=dlr*(1. - beta3))
192 |                 sk_l1 += s.abs().sum().item()
193 | 
194 |         ######
195 | 
196 |         numerator_weighted = beta3*numerator_weighted + (1-beta3)*numerator_accum
197 |         d_hat = d
198 | 
199 |         # If we have not made any progress yet, return early.
200 |         # If any gradients are available, we will have sk_l1 > 0 (unless \|g\| = 0).
201 |         if sk_l1 == 0:
202 |             return loss
203 | 
204 |         if lr > 0.0:
205 |             d_hat = 2*(beta3/(1-beta3))*numerator_weighted/sk_l1
206 |             d = max(d, min(d_hat, d*growth_rate))
207 | 
208 |             if log_every > 0 and k % log_every == 0:
209 |                 print(f"ng: {ngroups} lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={sk_l1:1.1e} numerator_weighted={numerator_weighted:1.1e}")
210 | 
211 |         for group in self.param_groups:
212 |             group['numerator_weighted'] = numerator_weighted
213 |             group['d'] = d
214 | 
215 |             decay = group['weight_decay']
216 |             k = group['k']
217 |             eps = group['eps']
218 | 
219 |             for p in group['params']:
220 |                 if p.grad is None:
221 |                     continue
222 |                 grad = p.grad.data
223 | 
224 |                 state = self.state[p]
225 | 
226 |                 exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
227 | 
228 |                 state['step'] += 1
229 | 
230 |                 denom = exp_avg_sq.sqrt().add_(eps)
231 |                 denom = denom.type(p.type())
232 | 
233 |                 update = (exp_avg + beta2 * exp_avg_diff).div_(denom)
234 | 
235 |                 ### Take step
236 |                 if no_prox:
237 |                     p.data.mul_(1 - dlr * decay)
238 |                     p.add_(update, alpha=-1)
239 |                 else:
240 |                     p.add_(update, alpha=-1)
241 |                     p.data.div_(1 + dlr * decay)
242 | 
243 |                 state['pre_grad'].copy_(grad)
244 | 
245 |             group['k'] = k + 1
246 | 
247 |         return loss
248 | 
--------------------------------------------------------------------------------
/figures/dadapt_cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_cifar.png
--------------------------------------------------------------------------------
/figures/dadapt_cifar100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_cifar100.png
--------------------------------------------------------------------------------
/figures/dadapt_convex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_convex.png
--------------------------------------------------------------------------------
/figures/dadapt_detectron.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_detectron.png
--------------------------------------------------------------------------------
/figures/dadapt_dlrm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_dlrm.png
--------------------------------------------------------------------------------
/figures/dadapt_fastmri.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_fastmri.png
--------------------------------------------------------------------------------
/figures/dadapt_gpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_gpt.png
--------------------------------------------------------------------------------
/figures/dadapt_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_imagenet.png
-------------------------------------------------------------------------------- /figures/dadapt_lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_lstm.png -------------------------------------------------------------------------------- /figures/dadapt_roberta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_roberta.png -------------------------------------------------------------------------------- /figures/dadapt_vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dadaptation/c984980213a989ff7cac1e3d68c19a2257bcd571/figures/dadapt_vit.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | [build-system] 8 | requires = [ 9 | "setuptools>=42", 10 | "wheel" 11 | ] 12 | build-backend = "setuptools.build_meta" 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.5.1 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import setuptools 8 | 9 | with open("README.md", "r", encoding="utf-8") as fh: 10 | long_description = fh.read() 11 | 12 | setuptools.setup( 13 | name="dadaptation", 14 | version="3.2", 15 | author="Aaron Defazio", 16 | author_email="adefazio@meta.com", 17 | description="Learning Rate Free Learning for Adam, SGD and AdaGrad", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | url="https://github.com/facebookresearch/dadaptation", 21 | packages=setuptools.find_packages(), 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | ], 27 | python_requires='>=3.6', 28 | ) 29 | --------------------------------------------------------------------------------
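A minimal usage sketch for the package defined by the setup.py above. It assumes `pip install dadaptation` has been run and that `dadaptation/__init__.py` re-exports `DAdaptAdam`; both are assumptions based on the file tree, not guarantees from the files shown here.

    # Illustrative only: a tiny regression problem optimized with D-Adaptation.
    import torch
    from dadaptation import DAdaptAdam  # assumed re-export from dadaptation/__init__.py

    model = torch.nn.Linear(10, 1)
    # lr stays at 1.0: D-Adaptation chooses the step-size scale automatically.
    optimizer = DAdaptAdam(model.parameters(), lr=1.0, weight_decay=0.0)

    x, y = torch.randn(64, 10), torch.randn(64, 1)
    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()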