├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── ema_pytorch
│   ├── __init__.py
│   ├── ema_pytorch.py
│   └── post_hoc_ema.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## EMA - Pytorch
2 |
3 | A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
4 |
5 | ## Install
6 |
7 | ```bash
8 | $ pip install ema-pytorch
9 | ```
10 |
11 | ## Usage
12 |
13 | ```python
14 | import torch
15 | from ema_pytorch import EMA
16 |
17 | # your neural network as a pytorch module
18 |
19 | net = torch.nn.Linear(512, 512)
20 |
21 | # wrap your neural network, specify the decay (beta)
22 |
23 | ema = EMA(
24 | net,
25 | beta = 0.9999, # exponential moving average factor
26 | update_after_step = 100, # only after this number of .update() calls will it start updating
27 | update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
28 | )
29 |
30 | # mutate your network, with SGD or otherwise
31 |
32 | with torch.no_grad():
33 | net.weight.copy_(torch.randn_like(net.weight))
34 | net.bias.copy_(torch.randn_like(net.bias))
35 |
36 | # you will call the update function on your moving average wrapper
37 |
38 | ema.update()
39 |
40 | # then, later on, you can invoke the EMA model the same way as your network
41 |
42 | data = torch.randn(1, 512)
43 |
44 | output = net(data)
45 | ema_output = ema(data)
46 |
47 | # if you want to save your ema model, it is recommended you save the entire wrapper
49 | # as it contains the number of steps taken (there is warmup logic in there, recommended by @crowsonkb, validated across a number of projects now)
49 | # however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
50 | ```
51 |
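Since the wrapper is itself an `nn.Module`, saving and loading can go through the usual `state_dict` machinery. A minimal sketch (the `./ema.pt` path is just an example):

```python
import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)
ema = EMA(net, beta = 0.9999)

# ... train and call ema.update() as above ...

# save the whole wrapper, which includes the step counter the warmup schedule depends on

torch.save(ema.state_dict(), './ema.pt')

# later, rebuild the wrapper around a freshly constructed network and restore it

net = torch.nn.Linear(512, 512)
ema = EMA(net, beta = 0.9999)
ema.load_state_dict(torch.load('./ema.pt'))
```
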
52 | To use the post-hoc synthesized EMA proposed by Karras et al. in "Analyzing and Improving the Training Dynamics of Diffusion Models", follow the example below
53 |
54 | ```python
55 | import torch
56 | from ema_pytorch import PostHocEMA
57 |
58 | # your neural network as a pytorch module
59 |
60 | net = torch.nn.Linear(512, 512)
61 |
62 | # wrap your neural network, specify the sigma_rels or gammas
63 |
64 | emas = PostHocEMA(
65 | net,
66 |     sigma_rels = (0.05, 0.28),           # a tuple of hyperparameters, one per EMA; you need at least 2 here to synthesize a new one
67 | update_every = 10, # how often to actually update, to save on compute (updates every 10th .update() call)
68 | checkpoint_every_num_steps = 10,
69 |     checkpoint_folder = './post-hoc-ema-checkpoints'  # the folder of checkpoints saved across timesteps for each sigma_rel (gamma) above, used to synthesize a new EMA model after training
70 | )
71 |
72 | net.train()
73 |
74 | for _ in range(1000):
75 | # mutate your network, with SGD or otherwise
76 |
77 | with torch.no_grad():
78 | net.weight.copy_(torch.randn_like(net.weight))
79 | net.bias.copy_(torch.randn_like(net.bias))
80 |
81 | # you will call the update function on your moving average wrapper
82 |
83 | emas.update()
84 |
85 | # now that you have a few checkpoints
86 | # you can synthesize an EMA model with a different sigma_rel (say 0.15)
87 |
88 | synthesized_ema = emas.synthesize_ema_model(sigma_rel = 0.15)
89 |
90 | # output with synthesized EMA
91 |
92 | data = torch.randn(1, 512)
93 |
94 | synthesized_ema_output = synthesized_ema(data)
95 |
96 | ```
97 |
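Under the hood, `synthesize_ema_model` reads the checkpoints saved in `checkpoint_folder` and solves for the combination weights as in Algorithm 3 of the paper. The returned object is a `KarrasEMA` wrapper, so if you only want to keep the synthesized weights for inference, one option (the filename is just an example) is:

```python
# the EMA weights themselves live on the wrapper at .ema_model

torch.save(synthesized_ema.ema_model.state_dict(), './synthesized-ema.pt')
```
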
98 | To test out the free-lunch claims of the `Switch EMA` paper, just set `update_model_with_ema_every` like so
99 |
100 | ```python
101 |
102 | ema = EMA(
103 | net,
104 | ...,
105 | update_model_with_ema_every = 10000 # say 10k steps is 1 epoch
106 | )
107 |
108 | # or you can do it manually at the end of each epoch
109 |
110 | ema.update_model_with_ema()
111 |
112 | ```
113 |
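If you would rather not call `.update()` yourself, the wrapper also exposes `add_to_optimizer_post_step_hook` (see `ema_pytorch/ema_pytorch.py` below), which registers the EMA update as an optimizer post-step hook. A minimal sketch:

```python
import torch
from ema_pytorch import EMA

net = torch.nn.Linear(512, 512)
ema = EMA(net)

opt = torch.optim.Adam(net.parameters(), lr = 3e-4)

# ema.update() will now run automatically after every opt.step()

ema.add_to_optimizer_post_step_hook(opt)

for _ in range(100):
    loss = net(torch.randn(1, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```
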
114 | ## Citations
115 |
116 | ```bibtex
117 | @article{Karras2023AnalyzingAI,
118 | title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
119 | author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
120 | journal = {ArXiv},
121 | year = {2023},
122 | volume = {abs/2312.02696},
123 | url = {https://api.semanticscholar.org/CorpusID:265659032}
124 | }
125 | ```
126 |
127 | ```bibtex
128 | @article{Lee2024SlowAS,
129 | title = {Slow and Steady Wins the Race: Maintaining Plasticity with Hare and Tortoise Networks},
130 | author = {Hojoon Lee and Hyeonseo Cho and Hyunseung Kim and Donghu Kim and Dugki Min and Jaegul Choo and Clare Lyle},
131 | journal = {ArXiv},
132 | year = {2024},
133 | volume = {abs/2406.02596},
134 | url = {https://api.semanticscholar.org/CorpusID:270258586}
135 | }
136 | ```
137 |
138 | ```bibtex
139 | @article{Li2024SwitchEA,
140 | title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
141 | author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
142 | journal = {ArXiv},
143 | year = {2024},
144 | volume = {abs/2402.09240},
145 | url = {https://api.semanticscholar.org/CorpusID:267657558}
146 | }
147 | ```
148 |
--------------------------------------------------------------------------------
/ema_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from ema_pytorch.ema_pytorch import EMA
2 |
3 | from ema_pytorch.post_hoc_ema import (
4 | KarrasEMA,
5 | PostHocEMA
6 | )
7 |
--------------------------------------------------------------------------------
/ema_pytorch/ema_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | from copy import deepcopy
5 | from functools import partial
6 |
7 | import torch
8 | from torch import nn, Tensor
9 | from torch.nn import Module
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def divisible_by(num, den):
15 | return (num % den) == 0
16 |
17 | def get_module_device(m: Module):
18 | return next(m.parameters()).device
19 |
20 | def maybe_coerce_dtype(t, dtype):
21 | if t.dtype == dtype:
22 | return t
23 |
24 | return t.to(dtype)
25 |
26 | def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False, coerce_dtype = False):
27 | if auto_move_device:
28 | src = src.to(tgt.device)
29 |
30 | if coerce_dtype:
31 | src = maybe_coerce_dtype(src, tgt.dtype)
32 |
33 | tgt.copy_(src)
34 |
35 | def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False, coerce_dtype = False):
36 | if auto_move_device:
37 | src = src.to(tgt.device)
38 |
39 | if coerce_dtype:
40 | src = maybe_coerce_dtype(src, tgt.dtype)
41 |
42 | tgt.lerp_(src, weight)
43 |
44 | class EMA(Module):
45 | """
46 | Implements exponential moving average shadowing for your model.
47 |
48 | Utilizes an inverse decay schedule to manage longer term training runs.
49 | By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
50 |
51 | @crowsonkb's notes on EMA Warmup:
52 |
53 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
54 | good values for models you plan to train for a million or more steps (reaches decay
55 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
56 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
57 | 215.4k steps).
58 |
59 | Args:
60 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
61 | power (float): Exponential factor of EMA warmup. Default: 2/3.
62 | min_value (float): The minimum EMA decay rate. Default: 0.
63 | """
64 |
65 | def __init__(
66 | self,
67 | model: Module,
68 | ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
69 | beta = 0.9999,
70 | update_after_step = 100,
71 | update_every = 10,
72 | inv_gamma = 1.0,
73 | power = 2 / 3,
74 | min_value = 0.0,
75 | param_or_buffer_names_no_ema: set[str] = set(),
76 | ignore_names: set[str] = set(),
77 | ignore_startswith_names: set[str] = set(),
78 | include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
79 | allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
80 | use_foreach = False,
81 | update_model_with_ema_every = None, # update the model with EMA model weights every number of steps, for better continual learning https://arxiv.org/abs/2406.02596
82 | update_model_with_ema_beta = 0., # amount of model weight to keep when updating to EMA (hare to tortoise)
83 | forward_method_names: tuple[str, ...] = (),
84 | move_ema_to_online_device = False,
85 | coerce_dtype = False,
86 | lazy_init_ema = False,
87 | ):
88 | super().__init__()
89 | self.beta = beta
90 |
91 | self.is_frozen = beta == 1.
92 |
93 | # whether to include the online model within the module tree, so that state_dict also saves it
94 |
95 | self.include_online_model = include_online_model
96 |
97 | if include_online_model:
98 | self.online_model = model
99 | else:
100 | self.online_model = [model] # hack
101 |
102 | # handle callable returning ema module
103 |
104 | if not isinstance(ema_model, Module) and callable(ema_model):
105 | ema_model = ema_model()
106 |
107 | # ema model
108 |
109 | self.ema_model = None
110 | self.forward_method_names = forward_method_names
111 |
112 | if not lazy_init_ema:
113 | self.init_ema(ema_model)
114 | else:
115 | assert not exists(ema_model)
116 |
117 | # tensor update functions
118 |
119 | self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
120 | self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices, coerce_dtype = coerce_dtype)
121 |
122 | # updating hyperparameters
123 |
124 | self.update_every = update_every
125 | self.update_after_step = update_after_step
126 |
127 | self.inv_gamma = inv_gamma
128 | self.power = power
129 | self.min_value = min_value
130 |
131 | assert isinstance(param_or_buffer_names_no_ema, (set, list))
132 | self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
133 |
134 | self.ignore_names = ignore_names
135 | self.ignore_startswith_names = ignore_startswith_names
136 |
137 | # continual learning related
138 |
139 | self.update_model_with_ema_every = update_model_with_ema_every
140 | self.update_model_with_ema_beta = update_model_with_ema_beta
141 |
142 | # whether to manage if EMA model is kept on a different device
143 |
144 | self.allow_different_devices = allow_different_devices
145 |
146 | # whether to coerce dtype when copy or lerp from online to EMA model
147 |
148 | self.coerce_dtype = coerce_dtype
149 |
150 | # whether to move EMA model to online model device automatically
151 |
152 | self.move_ema_to_online_device = move_ema_to_online_device
153 |
154 | # whether to use foreach
155 |
156 | if use_foreach:
157 | assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'
158 |
159 | self.use_foreach = use_foreach
160 |
161 | # init and step states
162 |
163 | self.register_buffer('initted', torch.tensor(False))
164 | self.register_buffer('step', torch.tensor(0))
165 |
166 | def init_ema(
167 | self,
168 | ema_model: Module | None = None
169 | ):
170 | self.ema_model = ema_model
171 |
172 | if not exists(self.ema_model):
173 | try:
174 | self.ema_model = deepcopy(self.model)
175 | except Exception as e:
176 | print(f'Error: While trying to deepcopy model: {e}')
177 | print('Your model was not copyable. Please make sure you are not using any LazyLinear')
178 | exit()
179 |
180 | for p in self.ema_model.parameters():
181 | p.detach_()
182 |
183 | # forwarding methods
184 |
185 | for forward_method_name in self.forward_method_names:
186 | fn = getattr(self.ema_model, forward_method_name)
187 | setattr(self, forward_method_name, fn)
188 |
189 | # parameter and buffer names
190 |
191 | self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
192 | self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
193 |
194 | def add_to_optimizer_post_step_hook(self, optimizer):
195 | assert hasattr(optimizer, 'register_step_post_hook')
196 |
197 | def hook(*_):
198 | self.update()
199 |
200 | return optimizer.register_step_post_hook(hook)
201 |
202 | @property
203 | def model(self):
204 | return self.online_model if self.include_online_model else self.online_model[0]
205 |
206 | def eval(self):
207 | return self.ema_model.eval()
208 |
209 | @torch.no_grad()
210 | def forward_eval(self, *args, **kwargs):
211 | # handy function for invoking ema model with no grad + eval
212 | training = self.ema_model.training
213 | out = self.ema_model(*args, **kwargs)
214 | self.ema_model.train(training)
215 | return out
216 |
217 | def restore_ema_model_device(self):
218 | device = self.initted.device
219 | self.ema_model.to(device)
220 |
221 | def get_params_iter(self, model):
222 | for name, param in model.named_parameters():
223 | if name not in self.parameter_names:
224 | continue
225 | yield name, param
226 |
227 | def get_buffers_iter(self, model):
228 | for name, buffer in model.named_buffers():
229 | if name not in self.buffer_names:
230 | continue
231 | yield name, buffer
232 |
233 | def copy_params_from_model_to_ema(self):
234 | copy = self.inplace_copy
235 |
236 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
237 | copy(ma_params.data, current_params.data)
238 |
239 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
240 | copy(ma_buffers.data, current_buffers.data)
241 |
242 | def copy_params_from_ema_to_model(self):
243 | copy = self.inplace_copy
244 |
245 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
246 | copy(current_params.data, ma_params.data)
247 |
248 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
249 | copy(current_buffers.data, ma_buffers.data)
250 |
251 | def update_model_with_ema(self, decay = None):
252 | if not exists(decay):
253 | decay = self.update_model_with_ema_beta
254 |
255 | if decay == 0.:
256 | return self.copy_params_from_ema_to_model()
257 |
258 | self.update_moving_average(self.model, self.ema_model, decay)
259 |
260 | def get_current_decay(self):
261 | epoch = (self.step - self.update_after_step - 1).clamp(min = 0.)
262 | value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
263 |
264 | if epoch.item() <= 0:
265 | return 0.
266 |
267 | return value.clamp(min = self.min_value, max = self.beta).item()
268 |
269 | def update(self):
270 | step = self.step.item()
271 | self.step += 1
272 |
273 | if not self.initted.item():
274 | if not exists(self.ema_model):
275 | self.init_ema()
276 |
277 | self.copy_params_from_model_to_ema()
278 | self.initted.data.copy_(torch.tensor(True))
279 | return
280 |
281 | should_update = divisible_by(step, self.update_every)
282 |
283 | if should_update and step <= self.update_after_step:
284 | self.copy_params_from_model_to_ema()
285 | return
286 |
287 | if should_update:
288 | self.update_moving_average(self.ema_model, self.model)
289 |
290 | if exists(self.update_model_with_ema_every) and divisible_by(step, self.update_model_with_ema_every):
291 | self.update_model_with_ema()
292 |
293 | @torch.no_grad()
294 | def update_moving_average(self, ma_model, current_model, current_decay = None):
295 | if self.is_frozen:
296 | return
297 |
298 | # move ema model to online model device if not same and needed
299 |
300 | if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
301 | ma_model.to(get_module_device(current_model))
302 |
303 | # get current decay
304 |
305 | if not exists(current_decay):
306 | current_decay = self.get_current_decay()
307 |
308 | # store all source and target tensors to copy or lerp
309 |
310 | tensors_to_copy = []
311 | tensors_to_lerp = []
312 |
313 | # loop through parameters
314 |
315 | for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
316 | if name in self.ignore_names:
317 | continue
318 |
319 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
320 | continue
321 |
322 | if name in self.param_or_buffer_names_no_ema:
323 | tensors_to_copy.append((ma_params.data, current_params.data))
324 | continue
325 |
326 | tensors_to_lerp.append((ma_params.data, current_params.data))
327 |
328 | # loop through buffers
329 |
330 | for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
331 | if name in self.ignore_names:
332 | continue
333 |
334 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
335 | continue
336 |
337 | if name in self.param_or_buffer_names_no_ema:
338 | tensors_to_copy.append((ma_buffer.data, current_buffer.data))
339 | continue
340 |
341 | tensors_to_lerp.append((ma_buffer.data, current_buffer.data))
342 |
343 | # execute inplace copy or lerp
344 |
345 | if not self.use_foreach:
346 |
347 | for tgt, src in tensors_to_copy:
348 | self.inplace_copy(tgt, src)
349 |
350 | for tgt, src in tensors_to_lerp:
351 | self.inplace_lerp(tgt, src, 1. - current_decay)
352 |
353 | else:
354 | # use foreach if available and specified
355 |
356 | if self.allow_different_devices:
357 | tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
358 | tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]
359 |
360 | if self.coerce_dtype:
361 | tensors_to_copy = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_copy]
362 | tensors_to_lerp = [(tgt, maybe_coerce_dtype(src, tgt.dtype)) for tgt, src in tensors_to_lerp]
363 |
364 | if len(tensors_to_copy) > 0:
365 | tgt_copy, src_copy = zip(*tensors_to_copy)
366 | torch._foreach_copy_(tgt_copy, src_copy)
367 |
368 | if len(tensors_to_lerp) > 0:
369 | tgt_lerp, src_lerp = zip(*tensors_to_lerp)
370 | torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)
371 |
372 | def __call__(self, *args, **kwargs):
373 | return self.ema_model(*args, **kwargs)
374 |
--------------------------------------------------------------------------------
/ema_pytorch/post_hoc_ema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable, Literal
3 |
4 | from pathlib import Path
5 | from copy import deepcopy
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn, Tensor
10 | from torch.nn import Module, ModuleList
11 |
12 | import numpy as np
13 |
14 | def exists(val):
15 | return val is not None
16 |
17 | def default(val, d):
18 | return val if exists(val) else d
19 |
20 | def first(arr):
21 | return arr[0]
22 |
23 | def divisible_by(num, den):
24 | return (num % den) == 0
25 |
26 | def get_module_device(m: Module):
27 | return next(m.parameters()).device
28 |
29 | def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
30 | if auto_move_device:
31 | src = src.to(tgt.device)
32 |
33 | tgt.copy_(src)
34 |
35 | def inplace_lerp(tgt: Tensor, src: Tensor, weight, *, auto_move_device = False):
36 | if auto_move_device:
37 | src = src.to(tgt.device)
38 |
39 | tgt.lerp_(src, weight)
40 |
41 | # algorithm 2 in https://arxiv.org/abs/2312.02696
42 |
43 | def sigma_rel_to_gamma(sigma_rel):
44 | t = sigma_rel ** -2
45 | return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
46 |
47 | class KarrasEMA(Module):
48 | """
49 | exponential moving average module that uses hyperparameters from the paper https://arxiv.org/abs/2312.02696
50 | can either use gamma or sigma_rel from paper
51 | """
52 |
53 | def __init__(
54 | self,
55 | model: Module,
56 | sigma_rel: float | None = None,
57 | gamma: float | None = None,
58 | ema_model: Module | Callable[[], Module] | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
59 | update_every: int = 100,
60 | frozen: bool = False,
61 | param_or_buffer_names_no_ema: set[str] = set(),
62 | ignore_names: set[str] = set(),
63 | ignore_startswith_names: set[str] = set(),
64 | allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
65 | move_ema_to_online_device = False # will move entire EMA model to the same device as online model, if different
66 | ):
67 | super().__init__()
68 |
69 |         assert exists(sigma_rel) ^ exists(gamma), 'either sigma_rel or gamma must be given. gamma is derived from sigma_rel as in the paper, then beta is derived from gamma'
70 |
71 | if exists(sigma_rel):
72 | gamma = sigma_rel_to_gamma(sigma_rel)
73 |
74 | self.gamma = gamma
75 | self.frozen = frozen
76 |
77 | self.online_model = [model]
78 |
79 | # handle callable returning ema module
80 |
81 | if not isinstance(ema_model, Module) and callable(ema_model):
82 | ema_model = ema_model()
83 |
84 | # ema model
85 |
86 | self.ema_model = ema_model
87 |
88 | if not exists(self.ema_model):
89 | try:
90 | self.ema_model = deepcopy(model)
91 | except Exception as e:
92 | print(f'Error: While trying to deepcopy model: {e}')
93 | print('Your model was not copyable. Please make sure you are not using any LazyLinear')
94 | exit()
95 |
96 | for p in self.ema_model.parameters():
97 | p.detach_()
98 |
99 | # parameter and buffer names
100 |
101 | self.parameter_names = {name for name, param in self.ema_model.named_parameters() if torch.is_floating_point(param) or torch.is_complex(param)}
102 | self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if torch.is_floating_point(buffer) or torch.is_complex(buffer)}
103 |
104 | # tensor update functions
105 |
106 | self.inplace_copy = partial(inplace_copy, auto_move_device = allow_different_devices)
107 | self.inplace_lerp = partial(inplace_lerp, auto_move_device = allow_different_devices)
108 |
109 | # updating hyperparameters
110 |
111 | self.update_every = update_every
112 |
113 | assert isinstance(param_or_buffer_names_no_ema, (set, list))
114 | self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer
115 |
116 | self.ignore_names = ignore_names
117 | self.ignore_startswith_names = ignore_startswith_names
118 |
119 | # whether to manage if EMA model is kept on a different device
120 |
121 | self.allow_different_devices = allow_different_devices
122 |
123 | # whether to move EMA model to online model device automatically
124 |
125 | self.move_ema_to_online_device = move_ema_to_online_device
126 |
127 | # init and step states
128 |
129 | self.register_buffer('initted', torch.tensor(False))
130 | self.register_buffer('step', torch.tensor(0))
131 |
132 | @property
133 | def model(self):
134 | return first(self.online_model)
135 |
136 | @property
137 | def beta(self):
138 | return (1. - 1. / (self.step.item() + 1.)) ** (1. + self.gamma)
139 |
140 | def eval(self):
141 | return self.ema_model.eval()
142 |
143 | def restore_ema_model_device(self):
144 | device = self.initted.device
145 | self.ema_model.to(device)
146 |
147 | def get_params_iter(self, model):
148 | for name, param in model.named_parameters():
149 | if name not in self.parameter_names:
150 | continue
151 | yield name, param
152 |
153 | def get_buffers_iter(self, model):
154 | for name, buffer in model.named_buffers():
155 | if name not in self.buffer_names:
156 | continue
157 | yield name, buffer
158 |
159 | def copy_params_from_model_to_ema(self):
160 | copy = self.inplace_copy
161 |
162 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
163 | copy(ma_params.data, current_params.data)
164 |
165 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
166 | copy(ma_buffers.data, current_buffers.data)
167 |
168 | def copy_params_from_ema_to_model(self):
169 | copy = self.inplace_copy
170 |
171 | for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.model)):
172 | copy(current_params.data, ma_params.data)
173 |
174 | for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.model)):
175 | copy(current_buffers.data, ma_buffers.data)
176 |
177 | def update(self):
178 | step = self.step.item()
179 | self.step += 1
180 |
181 | if (step % self.update_every) != 0:
182 | return
183 |
184 | if not self.initted.item():
185 | self.copy_params_from_model_to_ema()
186 | self.initted.data.copy_(torch.tensor(True))
187 |
188 | self.update_moving_average(self.ema_model, self.model)
189 |
190 | def iter_all_ema_params_and_buffers(self):
191 | for name, ma_params in self.get_params_iter(self.ema_model):
192 | if name in self.ignore_names:
193 | continue
194 |
195 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
196 | continue
197 |
198 | if name in self.param_or_buffer_names_no_ema:
199 | continue
200 |
201 | yield ma_params
202 |
203 | for name, ma_buffer in self.get_buffers_iter(self.ema_model):
204 | if name in self.ignore_names:
205 | continue
206 |
207 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
208 | continue
209 |
210 | if name in self.param_or_buffer_names_no_ema:
211 | continue
212 |
213 | yield ma_buffer
214 |
215 | @torch.no_grad()
216 | def update_moving_average(self, ma_model, current_model):
217 | if self.frozen:
218 | return
219 |
220 | # move ema model to online model device if not same and needed
221 |
222 | if self.move_ema_to_online_device and get_module_device(ma_model) != get_module_device(current_model):
223 | ma_model.to(get_module_device(current_model))
224 |
225 | # get some functions and current decay
226 |
227 | copy, lerp = self.inplace_copy, self.inplace_lerp
228 | current_decay = self.beta
229 |
230 | for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
231 | if name in self.ignore_names:
232 | continue
233 |
234 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
235 | continue
236 |
237 | if name in self.param_or_buffer_names_no_ema:
238 | copy(ma_params.data, current_params.data)
239 | continue
240 |
241 | lerp(ma_params.data, current_params.data, 1. - current_decay)
242 |
243 | for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
244 | if name in self.ignore_names:
245 | continue
246 |
247 | if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
248 | continue
249 |
250 | if name in self.param_or_buffer_names_no_ema:
251 | copy(ma_buffer.data, current_buffer.data)
252 | continue
253 |
254 | lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
255 |
256 | def __call__(self, *args, **kwargs):
257 | return self.ema_model(*args, **kwargs)
258 |
259 | # post hoc ema wrapper
260 |
261 | # solving of the weights for combining all checkpoints into a newly synthesized EMA at desired gamma
262 | # Algorithm 3 copied from paper, redone in torch
263 |
264 | def p_dot_p(t_a, gamma_a, t_b, gamma_b):
265 | t_ratio = t_a / t_b
266 |     t_exp = torch.where(t_a < t_b, gamma_b, -gamma_a)
267 |     t_max = torch.maximum(t_a, t_b)
268 | num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
269 | den = (gamma_a + gamma_b + 1) * t_max
270 | return num / den
271 |
272 | def solve_weights(t_i, gamma_i, t_r, gamma_r):
273 | rv = lambda x: x.double().reshape(-1, 1)
274 | cv = lambda x: x.double().reshape(1, -1)
275 | A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
276 | b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
277 | return torch.linalg.solve(A, b)
278 |
279 | class PostHocEMA(Module):
280 |
281 | def __init__(
282 | self,
283 | model: Module,
284 | ema_model: Callable[[], Module] | None = None,
285 | sigma_rels: tuple[float, ...] | None = None,
286 | gammas: tuple[float, ...] | None = None,
287 | checkpoint_every_num_steps: int | Literal['manual'] = 1000,
288 | checkpoint_folder: str = './post-hoc-ema-checkpoints',
289 | checkpoint_dtype: torch.dtype = torch.float16,
290 | **kwargs
291 | ):
292 | super().__init__()
293 | assert exists(sigma_rels) ^ exists(gammas)
294 |
295 | if exists(sigma_rels):
296 | gammas = tuple(map(sigma_rel_to_gamma, sigma_rels))
297 |
298 |         assert len(gammas) > 1, 'you need at least 2 ema models with different gammas in order to synthesize new ema models of a different gamma'
299 | assert len(set(gammas)) == len(gammas), 'calculated gammas must be all unique'
300 |
301 | self.maybe_ema_model = ema_model
302 |
303 | self.gammas = gammas
304 | self.num_ema_models = len(gammas)
305 |
306 | self._model = [model]
307 | self.ema_models = ModuleList([KarrasEMA(model, ema_model = ema_model, gamma = gamma, **kwargs) for gamma in gammas])
308 |
309 | self.checkpoint_folder = Path(checkpoint_folder)
310 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
311 | assert self.checkpoint_folder.is_dir()
312 |
313 | self.checkpoint_every_num_steps = checkpoint_every_num_steps
314 | self.checkpoint_dtype = checkpoint_dtype
315 | self.ema_kwargs = kwargs
316 |
317 | @property
318 | def model(self):
319 | return first(self._model)
320 |
321 | @property
322 | def step(self):
323 | return first(self.ema_models).step
324 |
325 | @property
326 | def device(self):
327 | return self.step.device
328 |
329 | def copy_params_from_model_to_ema(self):
330 | for ema_model in self.ema_models:
331 | ema_model.copy_params_from_model_to_ema()
332 |
333 | def copy_params_from_ema_to_model(self):
334 | for ema_model in self.ema_models:
335 | ema_model.copy_params_from_ema_to_model()
336 |
337 | def update(self):
338 | for ema_model in self.ema_models:
339 | ema_model.update()
340 |
341 | if self.checkpoint_every_num_steps == 'manual':
342 | return
343 |
344 | if divisible_by(self.step.item(), self.checkpoint_every_num_steps):
345 | self.checkpoint()
346 |
347 | def checkpoint(self):
348 | step = self.step.item()
349 |
350 | for ind, ema_model in enumerate(self.ema_models):
351 | filename = f'{ind}.{step}.pt'
352 | path = self.checkpoint_folder / filename
353 |
354 | pkg = {
355 | k: v.to(device = 'cpu', dtype = self.checkpoint_dtype, copy = True)
356 | for k, v in ema_model.state_dict().items()
357 | }
358 |
359 | torch.save(pkg, str(path))
360 |
361 | def synthesize_ema_model(
362 | self,
363 | gamma: float | None = None,
364 | sigma_rel: float | None = None,
365 | step: int | None = None,
366 | ) -> KarrasEMA:
367 | assert exists(gamma) ^ exists(sigma_rel)
368 | device = self.device
369 |
370 | if exists(sigma_rel):
371 | gamma = sigma_rel_to_gamma(sigma_rel)
372 |
373 | synthesized_ema_model = KarrasEMA(
374 | model = self.model,
375 | ema_model = self.maybe_ema_model,
376 | gamma = gamma,
377 | **self.ema_kwargs
378 | )
379 |
380 |
381 |
382 | # get all checkpoints
383 |
384 | gammas = []
385 | timesteps = []
386 | checkpoints = [*self.checkpoint_folder.glob('*.pt')]
387 |
388 | for file in checkpoints:
389 | gamma_ind, timestep = map(int, file.stem.split('.'))
390 | gammas.append(self.gammas[gamma_ind])
391 | timesteps.append(timestep)
392 |
393 | step = default(step, max(timesteps))
394 |         assert step <= max(timesteps), f'you can only synthesize for a timestep that is at most the max checkpointed timestep {max(timesteps)}'
395 |
396 | # line up with Algorithm 3
397 |
398 | gamma_i = torch.tensor(gammas, device = device)
399 | t_i = torch.tensor(timesteps, device = device)
400 |
401 | gamma_r = torch.tensor([gamma], device = device)
402 | t_r = torch.tensor([step], device = device)
403 |
404 | # solve for weights for combining all checkpoints into synthesized, using least squares as in paper
405 |
406 | weights = solve_weights(t_i, gamma_i, t_r, gamma_r)
407 | weights = weights.squeeze(-1)
408 |
409 | # now sum up all the checkpoints using the weights one by one
410 |
411 | tmp_ema_model = KarrasEMA(
412 | model = self.model,
413 | ema_model = self.maybe_ema_model,
414 | gamma = gamma,
415 | **self.ema_kwargs
416 | )
417 |
418 | for ind, (checkpoint, weight) in enumerate(zip(checkpoints, weights.tolist())):
419 | is_first = ind == 0
420 |
421 | # load checkpoint into a temporary ema model
422 |
423 | ckpt_state_dict = torch.load(str(checkpoint), weights_only=True)
424 | tmp_ema_model.load_state_dict(ckpt_state_dict)
425 |
426 | # add weighted checkpoint to synthesized
427 |
428 | for ckpt_tensor, synth_tensor in zip(tmp_ema_model.iter_all_ema_params_and_buffers(), synthesized_ema_model.iter_all_ema_params_and_buffers()):
429 | if is_first:
430 | synth_tensor.zero_()
431 |
432 | synth_tensor.add_(ckpt_tensor * weight)
433 |
434 | # return the synthesized model
435 |
436 | return synthesized_ema_model
437 |
438 | def __call__(self, *args, **kwargs):
439 | return tuple(ema_model(*args, **kwargs) for ema_model in self.ema_models)
440 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'ema-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.7.7',
7 | license='MIT',
8 | description = 'Easy way to keep track of exponential moving average version of your pytorch module',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/ema-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'exponential moving average'
17 | ],
18 | install_requires=[
19 | 'torch>=2.0',
20 | ],
21 | classifiers=[
22 | 'Development Status :: 4 - Beta',
23 | 'Intended Audience :: Developers',
24 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 | 'License :: OSI Approved :: MIT License',
26 |     'Programming Language :: Python :: 3.8',
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------