├── .gitignore
├── LICENSE
├── README.md
├── anyprecision.py
└── fadam.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Less Wright

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FAdam_PyTorch
An implementation of FAdam (Fisher Adam) in PyTorch.

Please see the official arXiv paper:
[FAdam: Adam is a natural gradient optimizer using diagonal empirical Fisher information](https://arxiv.org/abs/2405.12807)

Schedule:
1 - Implementation in eager PyTorch --> **Complete and working.**
2 - Add adaptive epsilon --> **Complete and working.**
3 - (If torch.compile is not performant) update to a fused CUDA kernel (C++ extension).

--------------------------------------------------------------------------------
/anyprecision.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# AnyPrecisionAdamW: a flexible precision AdamW optimizer
# with optional Kahan summation for high precision weight updates.
# Allows direct control over momentum, variance and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset precision reduction for
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) due to high precision weight updates.

import torch
from torch.optim.optimizer import Optimizer


class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        eps2=1e-5,
        weight_decay=0.0,
        use_kahan_summation=False,
        use_numerical_guarantee: bool = True,
        momentum_dtype=torch.float16,
        variance_dtype=torch.float16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
        """
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve
                numerical stability (default: 1e-8)
            eps2 (float, optional): secondary epsilon; currently unused by this
                implementation (default: 1e-5)
            weight_decay (float, optional): weight decay coefficient (default: 0.0)

            # Any Precision specific
            use_kahan_summation = creates an auxiliary buffer to ensure high precision
                model param updates (default: False)
            use_numerical_guarantee = clamps the update denominator to a small
                minimum (1e-7) to guard against division by near-zero values
                (default: True)
            momentum_dtype = dtype for momentum (default: torch.float16)
            variance_dtype = dtype for uncentered variance (default: torch.float16)
            compensation_buffer_dtype = dtype for the Kahan summation
                buffer (default: torch.bfloat16). Only used if
                ``use_kahan_summation=True``.

            # Usage
            This optimizer keeps its optimizer states, and the optional Kahan
            summation buffer, in user-controlled dtypes.
            Defaults are momentum and variance in FP16, with the compensation
            buffer in BF16.
            It can be run in FSDP mixed precision, AMP, or full precision,
            depending on which training pipeline you wish to work with.

            Setting use_kahan_summation=False and changing the momentum and
            variance dtypes to FP32 reverts this to a standard AdamW optimizer.
        """
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_kahan_summation=use_kahan_summation,
            momentum_dtype=momentum_dtype,
            variance_dtype=variance_dtype,
            compensation_buffer_dtype=compensation_buffer_dtype,
            use_numerical_guarantee=use_numerical_guarantee,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        if closure is not None:
            with torch.enable_grad():
                # we do not keep the returned loss for use at the moment
                closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]
            use_numerical_guarantee = group["use_numerical_guarantee"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError(
                        "AnyPrecisionAdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )

                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad

                # weight decay, AdamW style
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # adjust using bias1
                bias_correction1 = 1 - beta1**step

                step_size = lr / bias_correction1

                # adjust using bias2
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                # bias-corrected denominator (sqrt of the uncentered variance)
                centered_variance = exp_avg_sq.sqrt() / denom_correction.add_(
                    1e-12, alpha=1
                )

                if use_numerical_guarantee:
                    # denom_max = max(denom_correction, eps)
                    # centered_variance = exp_avg_sq.sqrt() / denom_max

                    # clamp the denominator away from zero before dividing below
                    safe_variance = torch.clamp(centered_variance, min=1e-7)

                # lr update to compensation
                if use_kahan_summation:
                    compensation = state["compensation"]
                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
                    # update weights with compensation (Kahan summation)
                    # save error back to compensation for next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))

                elif use_numerical_guarantee:
                    p.data.addcdiv_(exp_avg, safe_variance, value=-step_size)
                else:
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)

--------------------------------------------------------------------------------
/fadam.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT-style license found in the
# LICENSE file in the root directory of this source tree.

# FAdam (Fisher Adam): an implementation in PyTorch of the paper:
# "FAdam: Adam is a natural gradient optimizer using diagonal empirical Fisher information"
# https://www.arxiv.org/abs/2405.12807


import torch
from torch.optim.optimizer import Optimizer
from typing import Callable, Optional, Tuple


class FAdam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        weight_decay: float = 0.1,
        betas: Tuple[float, float] = (0.9, 0.999),
        clip: float = 1.0,
        p: float = 0.5,
        eps: float = 1e-8,
        momentum_dtype: torch.dtype = torch.float32,
        fim_dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            weight_decay (float, optional): weight decay coefficient (default: 0.1)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve
                numerical stability; scaled adaptively by the gradient RMS
                (default: 1e-8)
            clip (float, optional): caps the RMS of the natural gradient and of the
                weight-decay term (default: 1.0)
            p (float, optional): exponent applied to the diagonal FIM when forming
                the natural-gradient denominator; 0.5 recovers the Adam-style
                square root (default: 0.5)
            momentum_dtype (torch.dtype, optional): dtype for the momentum state
                (default: torch.float32)
            fim_dtype (torch.dtype, optional): dtype for the diagonal FIM state
                (default: torch.float32)

        # Usage
        Construct and step like any torch.optim optimizer; a minimal usage
        sketch follows at the end of this listing.
        """
        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            eps=eps,
            momentum_dtype=momentum_dtype,
            fim_dtype=fim_dtype,
            clip=clip,
            p=p,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                # re-evaluate the model and keep the returned loss
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            clip = group["clip"]
            pval = group["p"]
            momentum_dtype = group["momentum_dtype"]
            fim_dtype = group["fim_dtype"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("FAdam does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state.setdefault("step", torch.tensor(0.0))
                    state.setdefault(
                        "momentum", torch.zeros_like(p, dtype=momentum_dtype)
                    )
                    state.setdefault("fim", torch.ones_like(p, dtype=fim_dtype))

                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                momentum = state["momentum"]
                fim = state["fim"]
                grad = p.grad

                # begin FAdam algo -------------------------
                # 6 - beta2 bias correction per Section 3.4.4
                curr_beta2 = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2**step)

                # 7 - update fim
                fim.mul_(curr_beta2).add_(grad * grad, alpha=1 - curr_beta2)

                # 8 - adaptive epsilon
                rms_grad = torch.sqrt(torch.mean(grad * grad))
                curr_eps = eps * min(1, rms_grad)

                # 9 - compute natural gradient
                fim_base = fim**pval + curr_eps  # **(2*pval)

                grad_nat = grad / fim_base

                # 10 - clip the natural gradient
                rms = torch.sqrt(torch.mean(grad_nat**2))
                divisor = max(1, rms)
                divisor = divisor / clip
                grad_nat = grad_nat / divisor

                # 11 - update momentum
                momentum.mul_(beta1).add_(grad_nat, alpha=1 - beta1)

                # 12 - weight decay
                grad_weights = p / fim_base

                # 13 - clip weight decay
                rms = torch.sqrt(torch.mean(grad_weights**2))
                divisor = max(1, rms)
                divisor /= clip
                grad_weights = grad_weights / divisor

                # 14 - compute update
                full_step = momentum + (weight_decay * grad_weights)
                lr_step = lr * full_step

                # 15 - update weights
                p.sub_(lr_step)

        return loss

--------------------------------------------------------------------------------
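
A minimal usage sketch for the FAdam optimizer above (not part of the repository files). The model, batch, and loss function are illustrative placeholders; only the `FAdam` constructor arguments mirror the defaults shown in fadam.py, and the sketch assumes fadam.py is importable from the working directory.

```python
# Hypothetical training-loop sketch; model/data/loss are placeholders.
import torch
import torch.nn as nn
from fadam import FAdam

model = nn.Linear(16, 4)                      # placeholder model
optimizer = FAdam(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.1,
    betas=(0.9, 0.999),
    clip=1.0,
    p=0.5,
    eps=1e-8,
)

inputs = torch.randn(8, 16)                   # placeholder batch
targets = torch.randint(0, 4, (8,))
loss_fn = nn.CrossEntropyLoss()

for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()                          # applies the FAdam update
```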