├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── ema_pytorch ├── __init__.py ├── ema_pytorch.py └── post_hoc_ema.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## EMA - Pytorch 2 | 3 | A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model 4 | 5 | ## Install 6 | 7 | ```bash 8 | $ pip install ema-pytorch 9 | ``` 10 | 11 | ## Usage 12 | 13 | ```python 14 | import torch 15 | from ema_pytorch import EMA 16 | 17 | # your neural network as a pytorch module 18 | 19 | net = torch.nn.Linear(512, 512) 20 | 21 | # wrap your neural network, specify the decay (beta) 22 | 23 | ema = EMA( 24 | net, 25 | beta = 0.9999, # exponential moving average factor 26 | update_after_step = 100, # only after this number of .update() calls will it start updating 27 | update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call) 28 | ) 29 | 30 | # mutate your network, with SGD or otherwise 31 | 32 | with torch.no_grad(): 33 | net.weight.copy_(torch.randn_like(net.weight)) 34 | net.bias.copy_(torch.randn_like(net.bias)) 35 | 36 | # you will call the update function on your moving average wrapper 37 | 38 | ema.update() 39 | 40 | # then, later on, you can invoke the EMA model the same way as your network 41 | 42 | data = torch.randn(1, 512) 43 | 44 | output = net(data) 45 | ema_output = ema(data) 46 | 47 | # if you want to save your ema model, it is recommended you save the entire wrapper 48 | # as it contains the number of steps taken (there is a warmup logic in there, recommended by @crowsonkb, validated for a number of projects now) 49 | # however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model 50 | ``` 51 | 52 | In order to use the post-hoc synthesized EMA, proposed by Karras et al. in a recent paper, follow the example below 53 | 54 | ```python 55 | import torch 56 | from ema_pytorch import PostHocEMA 57 | 58 | # your neural network as a pytorch module 59 | 60 | net = torch.nn.Linear(512, 512) 61 | 62 | # wrap your neural network, specify the sigma_rels or gammas 63 | 64 | emas = PostHocEMA( 65 | net, 66 | sigma_rels = (0.05, 0.28), # a tuple with the hyperparameter for the multiple EMAs. 
you need at least 2 here to synthesize a new one 67 | update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call) 68 | checkpoint_every_num_steps = 10, 69 | checkpoint_folder = './post-hoc-ema-checkpoints' # the folder of saved checkpoints for each sigma_rel (gamma) across timesteps, used for synthesizing a new EMA model after training 70 | ) 71 | 72 | net.train() 73 | 74 | for _ in range(1000): 75 | # mutate your network, with SGD or otherwise 76 | 77 | with torch.no_grad(): 78 | net.weight.copy_(torch.randn_like(net.weight)) 79 | net.bias.copy_(torch.randn_like(net.bias)) 80 | 81 | # you will call the update function on your moving average wrapper 82 | 83 | emas.update() 84 | 85 | # now that you have a few checkpoints 86 | # you can synthesize an EMA model with a different sigma_rel (say 0.15) 87 | 88 | synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15) 89 | 90 | # output with synthesized EMA 91 | 92 | data = torch.randn(1, 512) 93 | 94 | synthesized_ema_output = synthesized_ema(data) 95 | 96 | ``` 97 | 98 | To test out the free-lunch claims of the `Switch EMA` paper, just set `update_model_with_ema_every` like so 99 | 100 | ```python 101 | 102 | ema = EMA( 103 | net, 104 | ..., 105 | update_model_with_ema_every = 10000 # say 10k steps is 1 epoch 106 | ) 107 | 108 | # or you can do it manually at the end of each epoch 109 | 110 | ema.update_model_with_ema() 111 | 112 | ``` 113 | 114 | ## Citations 115 | 116 | ```bibtex 117 | @article{Karras2023AnalyzingAI, 118 | title = {Analyzing and Improving the Training Dynamics of Diffusion Models}, 119 | author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine}, 120 | journal = {ArXiv}, 121 | year = {2023}, 122 | volume = {abs/2312.02696}, 123 | url = {https://api.semanticscholar.org/CorpusID:265659032} 124 | } 125 | ``` 126 | 127 | ```bibtex 128 | @article{Lee2024SlowAS, 129 | title = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks}, 130 | author = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle}, 131 | journal = {ArXiv}, 132 | year = {2024}, 133 | volume = {abs/2406.02596}, 134 | url = {https://api.semanticscholar.org/CorpusID:270258586} 135 | } 136 | ``` 137 | 138 | ```bibtex 139 | @article{Li2024SwitchEA, 140 | title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness}, 141 | author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z.
Li}, 142 | journal = {ArXiv}, 143 | year = {2024}, 144 | volume = {abs/2402.09240}, 145 | url = {https://api.semanticscholar.org/CorpusID:267657558} 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /ema_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from ema_pytorch.ema_pytorch import EMA 2 | 3 | from ema_pytorch.post_hoc_ema import ( 4 | KarrasEMA, 5 | PostHocEMA 6 | ) 7 | -------------------------------------------------------------------------------- /ema_pytorch/ema_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | from copy import deepcopy 5 | from functools import partial 6 | 7 | import torch 8 | from torch import nn, Tensor 9 | from torch.nn import Module 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def divisible_by(num, den): 15 | return (num % den) == 0 16 | 17 | def get_module_device(m: Module): 18 | return next(m.parameters()).device 19 | 20 | def maybe_coerce_dtype(t, dtype): 21 | if t.dtype == dtype: 22 | return t 23 | 24 | return t.to(dtype) 25 | 26 | def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False): 27 | if auto_move_device: 28 | src = src.to(tgt.device) 29 | 30 | if coerce_dtype: 31 | src = maybe_coerce_dtype(src, tgt.dtype) 32 | 33 | tgt.copy_(src) 34 | 35 | def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False): 36 | if auto_move_device: 37 | src = src.to(tgt.device) 38 | 39 | if coerce_dtype: 40 | src = maybe_coerce_dtype(src, tgt.dtype) 41 | 42 | tgt.lerp_(src, weight) 43 | 44 | class EMA(Module): 45 | """ 46 | Implements exponential moving average shadowing for your model. 47 | 48 | Utilizes an inverse decay schedule to manage longer term training runs. 49 | By adjusting the power, you can control how fast EMA will ramp up to your specified beta. 50 | 51 | @crowsonkb's notes on EMA Warmup: 52 | 53 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are 54 | good values for models you plan to train for a million or more steps (reaches decay 55 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models 56 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 57 | 215.4k steps). 58 | 59 | Args: 60 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 61 | power (float): Exponential factor of EMA warmup. Default: 2/3. 62 | min_value (float): The minimum EMA decay rate. Default: 0. 
63 | """ 64 | 65 | def __init__( 66 | self, 67 | model: Module, 68 | ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model 69 | beta = 0.9999, 70 | update_after_step = 100, 71 | update_every = 10, 72 | inv_gamma = 1.0, 73 | power = 2 / 3, 74 | min_value = 0.0, 75 | param_or_buffer_names_no_ema: set[str] = set(), 76 | ignore_names: set[str] = set(), 77 | ignore_startswith_names: set[str] = set(), 78 | include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally) 79 | allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor 80 | use_foreach = False, 81 | update_model_with_ema_every = None, # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596 82 | update_model_with_ema_beta = 0., # amount of model weight to keep when updating to EMA (hare to tortoise) 83 | forward_method_names: tuple[str, ...] = (), 84 | move_ema_to_online_device = False, 85 | coerce_dtype = False, 86 | lazy_init_ema = False, 87 | ): 88 | super().__init__() 89 | self.beta = beta 90 | 91 | self.is_frozen = beta == 1. 92 | 93 | # whether to include the online model within the module tree, so that state_dict also saves it 94 | 95 | self.include_online_model = include_online_model 96 | 97 | if include_online_model: 98 | self.online_model = model 99 | else: 100 | self.online_model = [model] # hack 101 | 102 | # handle callable returning ema module 103 | 104 | if not isinstance(ema_model, Module) and callable(ema_model): 105 | ema_model = ema_model() 106 | 107 | # ema model 108 | 109 | self.ema_model = None 110 | self.forward_method_names = forward_method_names 111 | 112 | if not lazy_init_ema: 113 | self.init_ema(ema_model) 114 | else: 115 | assert not exists(ema_model) 116 | 117 | # tensor update functions 118 | 119 | self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype) 120 | self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype) 121 | 122 | # updating hyperparameters 123 | 124 | self.update_every = update_every 125 | self.update_after_step = update_after_step 126 | 127 | self.inv_gamma = inv_gamma 128 | self.power = power 129 | self.min_value = min_value 130 | 131 | assert isinstance(param_or_buffer_names_no_ema, (set, list)) 132 | self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer 133 | 134 | self.ignore_names = ignore_names 135 | self.ignore_startswith_names = ignore_startswith_names 136 | 137 | # continual learning related 138 | 139 | self.update_model_with_ema_every = update_model_with_ema_every 140 | self.update_model_with_ema_beta = update_model_with_ema_beta 141 | 142 | # whether to manage if EMA model is kept on a different device 143 | 144 | self.allow_different_devices = allow_different_devices 145 | 146 | # whether to coerce dtype when copy or lerp from online to EMA model 147 | 148 | self.coerce_dtype = coerce_dtype 149 | 150 | # whether to move EMA model to online model device automatically 151 | 152 | self.move_ema_to_online_device = move_ema_to_online_device 153 | 154 | # whether to use foreach 155 | 156 | if use_foreach: 157 | assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does 
not have the prerequisite foreach functions' 158 | 159 | self.use_foreach = use_foreach 160 | 161 | # init and step states 162 | 163 | self.register_buffer('initted', torch.tensor(False)) 164 | self.register_buffer('step', torch.tensor(0)) 165 | 166 | def init_ema( 167 | self, 168 | ema_model: Module | None = None 169 | ): 170 | self.ema_model = ema_model 171 | 172 | if not exists(self.ema_model): 173 | try: 174 | self.ema_model = deepcopy(self.model) 175 | except Exception as e: 176 | print(f'Error: While trying to deepcopy model: {e}') 177 | print('Your model was not copyable. Please make sure you are not using any LazyLinear') 178 | exit() 179 | 180 | for p in self.ema_model.parameters(): 181 | p.detach_() 182 | 183 | # forwarding methods 184 | 185 | for forward_method_name in self.forward_method_names: 186 | fn = getattr(self.ema_model, forward_method_name) 187 | setattr(self, forward_method_name, fn) 188 | 189 | # parameter and buffer names 190 | 191 | self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)} 192 | self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)} 193 | 194 | def add_to_optimizer_post_step_hook(self, optimizer): 195 | assert hasattr(optimizer, 'register_step_post_hook') 196 | 197 | def hook(*_): 198 | self.update() 199 | 200 | return optimizer.register_step_post_hook(hook) 201 | 202 | @property 203 | def model(self): 204 | return self.online_model if self.include_online_model else self.online_model[0] 205 | 206 | def eval(self): 207 | return self.ema_model.eval() 208 | 209 | @torch.no_grad() 210 | def forward_eval(self, *args, **kwargs): 211 | # handy function for invoking ema model with no grad + eval 212 | training = self.ema_model.training 213 | out = self.ema_model(*args, **kwargs) 214 | self.ema_model.train(training) 215 | return out 216 | 217 | def restore_ema_model_device(self): 218 | device = self.initted.device 219 | self.ema_model.to(device) 220 | 221 | def get_params_iter(self, model): 222 | for name, param in model.named_parameters(): 223 | if name not in self.parameter_names: 224 | continue 225 | yield name, param 226 | 227 | def get_buffers_iter(self, model): 228 | for name, buffer in model.named_buffers(): 229 | if name not in self.buffer_names: 230 | continue 231 | yield name, buffer 232 | 233 | def copy_params_from_model_to_ema(self): 234 | copy = self.inplace_copy 235 | 236 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)): 237 | copy(ma_params.data, current_params.data) 238 | 239 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)): 240 | copy(ma_buffers.data, current_buffers.data) 241 | 242 | def copy_params_from_ema_to_model(self): 243 | copy = self.inplace_copy 244 | 245 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)): 246 | copy(current_params.data, ma_params.data) 247 | 248 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)): 249 | copy(current_buffers.data, ma_buffers.data) 250 | 251 | def update_model_with_ema(self, decay = None): 252 | if not exists(decay): 253 | decay = self.update_model_with_ema_beta 254 | 255 | if decay == 0.: 256 | return self.copy_params_from_ema_to_model() 257 | 
258 | self.update_moving_average(self.model, self.ema_model, decay) 259 | 260 | def get_current_decay(self): 261 | epoch = (self.step - self.update_after_step - 1).clamp(min = 0.) 262 | value = 1 - (1 + epoch / self.inv_gamma) ** - self.power 263 | 264 | if epoch.item() <= 0: 265 | return 0. 266 | 267 | return value.clamp(min = self.min_value, max = self.beta).item() 268 | 269 | def update(self): 270 | step = self.step.item() 271 | self.step += 1 272 | 273 | if not self.initted.item(): 274 | if not exists(self.ema_model): 275 | self.init_ema() 276 | 277 | self.copy_params_from_model_to_ema() 278 | self.initted.data.copy_(torch.tensor(True)) 279 | return 280 | 281 | should_update = divisible_by(step, self.update_every) 282 | 283 | if should_update and step <= self.update_after_step: 284 | self.copy_params_from_model_to_ema() 285 | return 286 | 287 | if should_update: 288 | self.update_moving_average(self.ema_model, self.model) 289 | 290 | if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every): 291 | self.update_model_with_ema() 292 | 293 | @torch.no_grad() 294 | def update_moving_average(self, ma_model, current_model, current_decay = None): 295 | if self.is_frozen: 296 | return 297 | 298 | # move ema model to online model device if not same and needed 299 | 300 | if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model): 301 | ma_model.to(get_module_device(current_model)) 302 | 303 | # get current decay 304 | 305 | if not exists(current_decay): 306 | current_decay = self.get_current_decay() 307 | 308 | # store all source and target tensors to copy or lerp 309 | 310 | tensors_to_copy = [] 311 | tensors_to_lerp = [] 312 | 313 | # loop through parameters 314 | 315 | for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)): 316 | if name in self.ignore_names: 317 | continue 318 | 319 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 320 | continue 321 | 322 | if name in self.param_or_buffer_names_no_ema: 323 | tensors_to_copy.append((ma_params.data, current_params.data)) 324 | continue 325 | 326 | tensors_to_lerp.append((ma_params.data, current_params.data)) 327 | 328 | # loop through buffers 329 | 330 | for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)): 331 | if name in self.ignore_names: 332 | continue 333 | 334 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 335 | continue 336 | 337 | if name in self.param_or_buffer_names_no_ema: 338 | tensors_to_copy.append((ma_buffer.data, current_buffer.data)) 339 | continue 340 | 341 | tensors_to_lerp.append((ma_buffer.data, current_buffer.data)) 342 | 343 | # execute inplace copy or lerp 344 | 345 | if not self.use_foreach: 346 | 347 | for tgt, src in tensors_to_copy: 348 | self.inplace_copy(tgt, src) 349 | 350 | for tgt, src in tensors_to_lerp: 351 | self.inplace_lerp(tgt, src, 1. 
- current_decay) 352 | 353 | else: 354 | # use foreach if available and specified 355 | 356 | if self.allow_different_devices: 357 | tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy] 358 | tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp] 359 | 360 | if self.coerce_dtype: 361 | tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy] 362 | tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp] 363 | 364 | if len(tensors_to_copy) > 0: 365 | tgt_copy, src_copy = zip(*tensors_to_copy) 366 | torch._foreach_copy_(tgt_copy, src_copy) 367 | 368 | if len(tensors_to_lerp) > 0: 369 | tgt_lerp, src_lerp = zip(*tensors_to_lerp) 370 | torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay) 371 | 372 | def __call__(self, *args, **kwargs): 373 | return self.ema_model(*args, **kwargs) 374 | -------------------------------------------------------------------------------- /ema_pytorch/post_hoc_ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable, Literal 3 | 4 | from pathlib import Path 5 | from copy import deepcopy 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn, Tensor 10 | from torch.nn import Module, ModuleList 11 | 12 | import numpy as np 13 | 14 | def exists(val): 15 | return val is not None 16 | 17 | def default(val, d): 18 | return val if exists(val) else d 19 | 20 | def first(arr): 21 | return arr[0] 22 | 23 | def divisible_by(num, den): 24 | return (num % den) == 0 25 | 26 | def get_module_device(m: Module): 27 | return next(m.parameters()).device 28 | 29 | def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False): 30 | if auto_move_device: 31 | src = src.to(tgt.device) 32 | 33 | tgt.copy_(src) 34 | 35 | def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False): 36 | if auto_move_device: 37 | src = src.to(tgt.device) 38 | 39 | tgt.lerp_(src, weight) 40 | 41 | # algorithm 2 in https://arxiv.org/abs/2312.02696 42 | 43 | def sigma_rel_to_gamma(sigma_rel): 44 | t = sigma_rel ** -2 45 | return np.roots([1, 7, 16 - t, 12 - t]).real.max().item() 46 | 47 | class KarrasEMA(Module): 48 | """ 49 | exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696 50 | can either use gamma or sigma_rel from paper 51 | """ 52 | 53 | def __init__( 54 | self, 55 | model: Module, 56 | sigma_rel: float | None = None, 57 | gamma: float | None = None, 58 | ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model 59 | update_every: int = 100, 60 | frozen: bool = False, 61 | param_or_buffer_names_no_ema: set[str] = set(), 62 | ignore_names: set[str] = set(), 63 | ignore_startswith_names: set[str] = set(), 64 | allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor 65 | move_ema_to_online_device = False # will move entire EMA model to the same device as online model, if different 66 | ): 67 | super().__init__() 68 | 69 | assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma is given. 
gamma is derived from sigma_rel as in the paper, then beta is derived from gamma' 70 | 71 | if exists(sigma_rel): 72 | gamma = sigma_rel_to_gamma(sigma_rel) 73 | 74 | self.gamma = gamma 75 | self.frozen = frozen 76 | 77 | self.online_model = [model] 78 | 79 | # handle callable returning ema module 80 | 81 | if not isinstance(ema_model, Module) and callable(ema_model): 82 | ema_model = ema_model() 83 | 84 | # ema model 85 | 86 | self.ema_model = ema_model 87 | 88 | if not exists(self.ema_model): 89 | try: 90 | self.ema_model = deepcopy(model) 91 | except Exception as e: 92 | print(f'Error: While trying to deepcopy model: {e}') 93 | print('Your model was not copyable. Please make sure you are not using any LazyLinear') 94 | exit() 95 | 96 | for p in self.ema_model.parameters(): 97 | p.detach_() 98 | 99 | # parameter and buffer names 100 | 101 | self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)} 102 | self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)} 103 | 104 | # tensor update functions 105 | 106 | self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices) 107 | self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices) 108 | 109 | # updating hyperparameters 110 | 111 | self.update_every = update_every 112 | 113 | assert isinstance(param_or_buffer_names_no_ema, (set, list)) 114 | self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer 115 | 116 | self.ignore_names = ignore_names 117 | self.ignore_startswith_names = ignore_startswith_names 118 | 119 | # whether to manage if EMA model is kept on a different device 120 | 121 | self.allow_different_devices = allow_different_devices 122 | 123 | # whether to move EMA model to online model device automatically 124 | 125 | self.move_ema_to_online_device = move_ema_to_online_device 126 | 127 | # init and step states 128 | 129 | self.register_buffer('initted', torch.tensor(False)) 130 | self.register_buffer('step', torch.tensor(0)) 131 | 132 | @property 133 | def model(self): 134 | return first(self.online_model) 135 | 136 | @property 137 | def beta(self): 138 | return (1. - 1. / (self.step.item() + 1.)) ** (1.
+ self.gamma) 139 | 140 | def eval(self): 141 | return self.ema_model.eval() 142 | 143 | def restore_ema_model_device(self): 144 | device = self.initted.device 145 | self.ema_model.to(device) 146 | 147 | def get_params_iter(self, model): 148 | for name, param in model.named_parameters(): 149 | if name not in self.parameter_names: 150 | continue 151 | yield name, param 152 | 153 | def get_buffers_iter(self, model): 154 | for name, buffer in model.named_buffers(): 155 | if name not in self.buffer_names: 156 | continue 157 | yield name, buffer 158 | 159 | def copy_params_from_model_to_ema(self): 160 | copy = self.inplace_copy 161 | 162 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)): 163 | copy(ma_params.data, current_params.data) 164 | 165 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)): 166 | copy(ma_buffers.data, current_buffers.data) 167 | 168 | def copy_params_from_ema_to_model(self): 169 | copy = self.inplace_copy 170 | 171 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)): 172 | copy(current_params.data, ma_params.data) 173 | 174 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)): 175 | copy(current_buffers.data, ma_buffers.data) 176 | 177 | def update(self): 178 | step = self.step.item() 179 | self.step += 1 180 | 181 | if (step % self.update_every) != 0: 182 | return 183 | 184 | if not self.initted.item(): 185 | self.copy_params_from_model_to_ema() 186 | self.initted.data.copy_(torch.tensor(True)) 187 | 188 | self.update_moving_average(self.ema_model, self.model) 189 | 190 | def iter_all_ema_params_and_buffers(self): 191 | for name, ma_params in self.get_params_iter(self.ema_model): 192 | if name in self.ignore_names: 193 | continue 194 | 195 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 196 | continue 197 | 198 | if name in self.param_or_buffer_names_no_ema: 199 | continue 200 | 201 | yield ma_params 202 | 203 | for name, ma_buffer in self.get_buffers_iter(self.ema_model): 204 | if name in self.ignore_names: 205 | continue 206 | 207 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 208 | continue 209 | 210 | if name in self.param_or_buffer_names_no_ema: 211 | continue 212 | 213 | yield ma_buffer 214 | 215 | @torch.no_grad() 216 | def update_moving_average(self, ma_model, current_model): 217 | if self.frozen: 218 | return 219 | 220 | # move ema model to online model device if not same and needed 221 | 222 | if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model): 223 | ma_model.to(get_module_device(current_model)) 224 | 225 | # get some functions and current decay 226 | 227 | copy, lerp = self.inplace_copy, self.inplace_lerp 228 | current_decay = self.beta 229 | 230 | for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)): 231 | if name in self.ignore_names: 232 | continue 233 | 234 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 235 | continue 236 | 237 | if name in self.param_or_buffer_names_no_ema: 238 | copy(ma_params.data, current_params.data) 239 | continue 240 | 241 | lerp(ma_params.data, current_params.data, 1. 
- current_decay) 242 | 243 | for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)): 244 | if name in self.ignore_names: 245 | continue 246 | 247 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): 248 | continue 249 | 250 | if name in self.param_or_buffer_names_no_ema: 251 | copy(ma_buffer.data, current_buffer.data) 252 | continue 253 | 254 | lerp(ma_buffer.data, current_buffer.data, 1. - current_decay) 255 | 256 | def __call__(self, *args, **kwargs): 257 | return self.ema_model(*args, **kwargs) 258 | 259 | # post hoc ema wrapper 260 | 261 | # solving of the weights for combining all checkpoints into a newly synthesized EMA at desired gamma 262 | # Algorithm 3 copied from paper, redone in torch 263 | 264 | def p_dot_p(t_a, gamma_a, t_b, gamma_b): 265 | t_ratio = t_a / t_b 266 | t_exp = torch.where(t_a < t_b , gamma_b , -gamma_a) 267 | t_max = torch.maximum(t_a , t_b) 268 | num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp 269 | den = (gamma_a + gamma_b + 1) * t_max 270 | return num / den 271 | 272 | def solve_weights(t_i, gamma_i, t_r, gamma_r): 273 | rv = lambda x: x.double().reshape(-1, 1) 274 | cv = lambda x: x.double().reshape(1, -1) 275 | A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i)) 276 | b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r)) 277 | return torch.linalg.solve(A, b) 278 | 279 | class PostHocEMA(Module): 280 | 281 | def __init__( 282 | self, 283 | model: Module, 284 | ema_model: Callable[[], Module] | None = None, 285 | sigma_rels: tuple[float, ...] | None = None, 286 | gammas: tuple[float, ...] | None = None, 287 | checkpoint_every_num_steps: int | Literal['manual'] = 1000, 288 | checkpoint_folder: str = './post-hoc-ema-checkpoints', 289 | checkpoint_dtype: torch.dtype = torch.float16, 290 | **kwargs 291 | ): 292 | super().__init__() 293 | assert exists(sigma_rels) ^ exists(gammas) 294 | 295 | if exists(sigma_rels): 296 | gammas = tuple(map(sigma_rel_to_gamma, sigma_rels)) 297 | 298 | assert len(gammas) > 1, 'at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma' 299 | assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique' 300 | 301 | self.maybe_ema_model = ema_model 302 | 303 | self.gammas = gammas 304 | self.num_ema_models = len(gammas) 305 | 306 | self._model = [model] 307 | self.ema_models = ModuleList([KarrasEMA(model, ema_model = ema_model, gamma = gamma, **kwargs) for gamma in gammas]) 308 | 309 | self.checkpoint_folder = Path(checkpoint_folder) 310 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True) 311 | assert self.checkpoint_folder.is_dir() 312 | 313 | self.checkpoint_every_num_steps = checkpoint_every_num_steps 314 | self.checkpoint_dtype = checkpoint_dtype 315 | self.ema_kwargs = kwargs 316 | 317 | @property 318 | def model(self): 319 | return first(self._model) 320 | 321 | @property 322 | def step(self): 323 | return first(self.ema_models).step 324 | 325 | @property 326 | def device(self): 327 | return self.step.device 328 | 329 | def copy_params_from_model_to_ema(self): 330 | for ema_model in self.ema_models: 331 | ema_model.copy_params_from_model_to_ema() 332 | 333 | def copy_params_from_ema_to_model(self): 334 | for ema_model in self.ema_models: 335 | ema_model.copy_params_from_ema_to_model() 336 | 337 | def update(self): 338 | for ema_model in self.ema_models: 339 | ema_model.update() 340 | 341 | if self.checkpoint_every_num_steps == 'manual': 342 | return 343 
| 344 | if divisible_by(self.step.item(), self.checkpoint_every_num_steps): 345 | self.checkpoint() 346 | 347 | def checkpoint(self): 348 | step = self.step.item() 349 | 350 | for ind, ema_model in enumerate(self.ema_models): 351 | filename = f'{ind}.{step}.pt' 352 | path = self.checkpoint_folder / filename 353 | 354 | pkg = { 355 | k: v.to(device = 'cpu', dtype = self.checkpoint_dtype, copy = True) 356 | for k, v in ema_model.state_dict().items() 357 | } 358 | 359 | torch.save(pkg, str(path)) 360 | 361 | def synthesize_ema_model( 362 | self, 363 | gamma: float | None = None, 364 | sigma_rel: float | None = None, 365 | step: int | None = None, 366 | ) -> KarrasEMA: 367 | assert exists(gamma) ^ exists(sigma_rel) 368 | device = self.device 369 | 370 | if exists(sigma_rel): 371 | gamma = sigma_rel_to_gamma(sigma_rel) 372 | 373 | synthesized_ema_model = KarrasEMA( 374 | model = self.model, 375 | ema_model = self.maybe_ema_model, 376 | gamma = gamma, 377 | **self.ema_kwargs 378 | ) 379 | 380 | synthesized_ema_model 381 | 382 | # get all checkpoints 383 | 384 | gammas = [] 385 | timesteps = [] 386 | checkpoints = [*self.checkpoint_folder.glob('*.pt')] 387 | 388 | for file in checkpoints: 389 | gamma_ind, timestep = map(int, file.stem.split('.')) 390 | gammas.append(self.gammas[gamma_ind]) 391 | timesteps.append(timestep) 392 | 393 | step = default(step, max(timesteps)) 394 | assert step <= max(timesteps), f'you can only synthesize for a timestep that is less than the max timestep {max(timesteps)}' 395 | 396 | # line up with Algorithm 3 397 | 398 | gamma_i = torch.tensor(gammas, device = device) 399 | t_i = torch.tensor(timesteps, device = device) 400 | 401 | gamma_r = torch.tensor([gamma], device = device) 402 | t_r = torch.tensor([step], device = device) 403 | 404 | # solve for weights for combining all checkpoints into synthesized, using least squares as in paper 405 | 406 | weights = solve_weights(t_i, gamma_i, t_r, gamma_r) 407 | weights = weights.squeeze(-1) 408 | 409 | # now sum up all the checkpoints using the weights one by one 410 | 411 | tmp_ema_model = KarrasEMA( 412 | model = self.model, 413 | ema_model = self.maybe_ema_model, 414 | gamma = gamma, 415 | **self.ema_kwargs 416 | ) 417 | 418 | for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())): 419 | is_first = ind == 0 420 | 421 | # load checkpoint into a temporary ema model 422 | 423 | ckpt_state_dict = torch.load(str(checkpoint), weights_only=True) 424 | tmp_ema_model.load_state_dict(ckpt_state_dict) 425 | 426 | # add weighted checkpoint to synthesized 427 | 428 | for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()): 429 | if is_first: 430 | synth_tensor.zero_() 431 | 432 | synth_tensor.add_(ckpt_tensor * weight) 433 | 434 | # return the synthesized model 435 | 436 | return synthesized_ema_model 437 | 438 | def __call__(self, *args, **kwargs): 439 | return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models) 440 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'ema-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.7.7', 7 | license='MIT', 8 | description = 'Easy way to keep track of exponential moving average version of your pytorch module', 9 | author = 'Phil Wang', 10 | author_email = 
'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/ema-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'exponential moving average' 17 | ], 18 | install_requires=[ 19 | 'torch>=2.0', 20 | ], 21 | classifiers=[ 22 | 'Development Status :: 4 - Beta', 23 | 'Intended Audience :: Developers', 24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Programming Language :: Python :: 3.8', 27 | ], 28 | ) 29 | --------------------------------------------------------------------------------
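One usage pattern not shown in the README above: `EMA` in `ema_pytorch/ema_pytorch.py` also exposes `add_to_optimizer_post_step_hook` (which registers `ema.update()` as an optimizer post-step hook) and `forward_eval` (which invokes the EMA copy under `torch.no_grad()`). Below is a minimal sketch of how these fit into a training loop, assuming PyTorch >= 2.0 as required by `setup.py`; the toy network, the Adam optimizer, and all hyperparameters are illustrative assumptions, not part of the library.

```python
import torch
from ema_pytorch import EMA

# toy online model (illustrative only)

net = torch.nn.Linear(512, 512)

ema = EMA(
    net,
    beta = 0.9999,           # exponential moving average factor
    update_after_step = 100, # warmup: weights are copied, not lerped, up to this step
    update_every = 10        # only every 10th .update() call does work
)

opt = torch.optim.Adam(net.parameters(), lr = 3e-4)

# register ema.update() to run automatically after every opt.step()

ema.add_to_optimizer_post_step_hook(opt)

for _ in range(1000):
    loss = net(torch.randn(4, 512)).pow(2).mean()
    loss.backward()
    opt.step()       # the post-step hook calls ema.update() here
    opt.zero_grad()

# invoke the EMA copy without tracking gradients

ema_output = ema.forward_eval(torch.randn(1, 512))

# save the whole wrapper, so the step counter / warmup state travels with the weights

torch.save(ema.state_dict(), './ema.pt')
```

With `update_every = 10`, only every tenth hook invocation actually moves the EMA weights, matching the compute-saving behavior described in the README.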