├── README.md
├── ema_eq.png
├── ema_eq.py
└── image.png

/README.md:
--------------------------------------------------------------------------------
# Karras Power Function EMA (Post-training EMA synthesis)

This tutorial-repo implements Karras's Power function EMA, a quite incredible trick introduced in the paper [Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696v1) by Tero Karras, Miika Aittala, Jaakko Lehtinen, Janne Hellsten, Timo Aila, Samuli Laine.

# So What is Karras's Power function EMA?

I recommend you read the paper for the full details, but here is the big picture.

Recall that EMA'ing checkpoints is about keeping track of a smoothed version of the model parameters, $\theta_\beta$, where $\theta_\beta(t) = \beta \theta_\beta(t-1) + (1-\beta) \theta(t)$ and $\beta$ is a decay factor close to 1. Using EMA typically makes the model more robust, and it is common practice in training deep neural networks.

You want to use EMA, but...

1. You *don't* want the EMA to decay too slowly, because then the random initialization contributes too much to the final model.
2. You definitely want the decay schedule to be self-similar, because you should be able to *scale the training time* up or down.
3. You want to set the decay factor *post hoc*, because you don't want to retrain the model from scratch for every decay factor.

Karras's Power function EMA answers all of these problems. Instead of keeping $\beta$ constant, he uses a power-function schedule $\beta(t) = (1 - 1/t)^{1 + \gamma}$, where $\gamma$ is a hyperparameter. This makes the contribution of historical parameters self-similar: you can increase or decrease the training time and the EMA will still behave the way you expect (i.e., if you want the first 10% of training to contribute x% of the final model, changing the training length will not change that).

# Overall Algorithm and Implementation

There are two main parts to the algorithm.

1. Saving two copies of the EMA model, each with a different width.
2. Recovering an arbitrary-width EMA after training.

Think of the width as the decay factor: a larger width means a *smoother* average.

![Alt text](image.png)

## First, save two copies of the EMA, with different width

This is the easy part. You just need to save two copies of the EMA, each with a different width (different $\gamma$).

```python
import copy

import torch

gamma_1 = 5
gamma_2 = 10
save_freq = 1000  # how often (in steps) to snapshot the EMA weights

model = Model()
model_ema_1 = copy.deepcopy(model).cpu()
model_ema_2 = copy.deepcopy(model).cpu()


for i, batch in enumerate(data_loader):
    beta_1 = (1 - 1 / (i + 1)) ** (1 + gamma_1)
    beta_2 = (1 - 1 / (i + 1)) ** (1 + gamma_2)

    # train model (schematic: your forward pass and loss go here)
    loss = ...
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # update both EMA copies, which live on CPU
    for p, p_ema_1, p_ema_2 in zip(model.parameters(), model_ema_1.parameters(), model_ema_2.parameters()):
        p_cpu = p.data.detach().cpu()
        p_ema_1.data = p_ema_1.data * beta_1 + p_cpu * (1 - beta_1)
        p_ema_2.data = p_ema_2.data * beta_2 + p_cpu * (1 - beta_2)

    if i % save_freq == 0:
        torch.save(model_ema_1.state_dict(), f'./model_ema_1_{i}.pth')
        torch.save(model_ema_2.state_dict(), f'./model_ema_2_{i}.pth')
```
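As a quick aside before the recovery step: here is a small numeric check of the self-similarity claim from the previous section (this check is my own sketch, not from the paper or this repo). With the $\beta(t)$ schedule used in the loop above, the share of the final EMA contributed by the first 10% of steps works out to $0.1^{1+\gamma}$ no matter how long you train:

```python
import numpy as np

def first_fraction_weight(T, gamma, frac=0.1):
    """Total weight that the first `frac` of steps carries in the final power-function EMA."""
    t = np.arange(1, T + 1)
    beta = (1 - 1 / t) ** (1 + gamma)   # same schedule as the training loop above
    trailing = (t / T) ** (1 + gamma)   # closed form of prod_{s > t} beta_s
    w = (1 - beta) * trailing           # weight of step t in the final average
    return w[: int(frac * T)].sum()

for T in (1_000, 10_000):
    print(T, first_fraction_weight(T, gamma=5))
# Both runs print ~1e-6 = 0.1 ** (1 + 5): the contribution of the first 10% of
# training depends only on the fraction of the run, not on its absolute length.
```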
## Second, recover an arbitrary-decay EMA after training.

Now what if you want to recover the EMA with some other $\gamma_3$? Incredibly, you can do this with just the saved checkpoints. The math behind this in the paper is a bit... *not straightforward*, so here is my version of the explanation.

An EMA, by definition, can be seen as a weighted integral of the trajectory of the model parameters: for some weighting function $w(t)$,

$$\theta_e(T) = \int_0^T w(t) \theta(t) dt$$

For a fixed training run, $t \in [0, T]$, we saved two EMA copies (one per $\gamma$) at each of, say, $n$ checkpoints. This means we know the values

$$\theta_{i,j} = \int_0^T w_{i, j}(t) \theta(t) dt$$

for $i = 1, 2$ and $j = 1, 2, \cdots, n$, where $i$ corresponds to the width and $j$ to the $j$-th checkpoint, saved at time $t_j$. Here

$$
w_{i, j}(t) = \begin{cases}
t^{\gamma_i} / g_{i,j} & \text{if } t \le t_j \\
0 & \text{otherwise}
\end{cases}
$$

where $g_{i,j}$ is simply the normalization constant that makes $\int_0^T w_{i, j}(t) dt = 1$.

Our goal is then to

1. find an approximation $\hat{w}_3(t)$ of the weighting function that corresponds to $\gamma_3$, and

2. find the corresponding $\theta_{3,T}$.

See where this is going? Our goal is to approximate $w_3(t)$ as a linear combination of the $w_{1, j}(t)$ and $w_{2, j}(t)$, i.e.,

$$w_3(t) \approx \sum_{j=1}^n \left( \alpha_j w_{1, j}(t) + \beta_j w_{2, j}(t) \right)$$

where $\alpha_j$ and $\beta_j$ are the coefficients we need to find. This way,

$$\theta_{3,T} = \int_0^T w_3(t) \theta(t) dt \approx \sum_{j=1}^n \left( \alpha_j \theta_{1, j} + \beta_j \theta_{2, j} \right)$$

Aha! Now we can find $\alpha_j$ and $\beta_j$ by solving a linear system of equations. Let's take this one step further.

## Goal: project $w_3(t)$ onto the subspace spanned by the $w_{1, j}(t)$ and $w_{2, j}(t)$

We have $K$ functions $f_i(t)$ and a target $g(t)$, and we want to find $K$ coefficients $\alpha_i$ that minimize

$$\min \int_0^T \left( g(t) - \sum_{i=1}^K \alpha_i f_i(t) \right)^2 dt$$

How would you solve this?

Define the inner product as

$$\langle f, g \rangle = \int_0^T f(t) g(t) dt$$

Then we can rewrite the problem as

$$\min \left \| g - \sum_{i=1}^K \alpha_i f_i \right \|_2^2$$

with $\| f \|_2 = \sqrt{\langle f, f \rangle}$. Expanding the norm, we get

```math
\min {\left \| g \right \|_2}^2 - 2 \sum_{i=1}^K \alpha_i \langle g, f_i \rangle + \sum_{i=1}^K \sum_{j=1}^K \alpha_i \alpha_j \langle f_i, f_j \rangle
```

Ha, so substituting $A_{i,j} = \langle f_i, f_j \rangle$ and $b_i = \langle g, f_i \rangle$, we actually just have a linear least-squares problem!

$$\min \left \| g \right \|_2^2 - 2 \alpha^T b + \alpha^T A \alpha$$

where $\alpha = (\alpha_1, \cdots, \alpha_K)^T$.

So the solution is simply

$$\alpha = A^{+} b$$

where $A^{+}$ is the pseudo-inverse of $A$. We are left to just use the $\alpha$ to get $\theta_{3,T}$.

> Note: If you have ever studied functional analysis, you will recognize that *a unique solution exists* for this problem, via the *Hilbert projection theorem*. The above is simply the projection of $g$ onto the subspace spanned by the $f_i$, in $L^2$.
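By the way, for the truncated power-function profiles defined above, every entry of $A$ and $b$ has a closed form, so no numerical integration is needed. Writing the normalized profile truncated at checkpoint time $t_c$ as $w(t) = (\gamma + 1) t^{\gamma} / t_c^{\gamma + 1}$ for $t \le t_c$, a short calculation (my own, but it is exactly the quantity that `p_dot_p` in `ema_eq.py` evaluates) gives

$$\langle w_a, w_b \rangle = \int_0^{\min(t_a, t_b)} \frac{(\gamma_a + 1) t^{\gamma_a}}{t_a^{\gamma_a + 1}} \cdot \frac{(\gamma_b + 1) t^{\gamma_b}}{t_b^{\gamma_b + 1}} \, dt = \frac{(\gamma_a + 1)(\gamma_b + 1) \min(t_a, t_b)^{\gamma_a + \gamma_b + 1}}{(\gamma_a + \gamma_b + 1) \, t_a^{\gamma_a + 1} \, t_b^{\gamma_b + 1}}$$

so assembling $A$ and $b$ just means evaluating this expression for every pair of stored profiles, and for each stored profile against the target profile.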
So, things you learned:

1. The quality of the approximation is determined by the number of checkpoints you saved. More checkpoints, better approximation.
2. This doesn't have to be a power-function EMA. You can use any weighting function $w(t)$, as long as you can track the corresponding integral of the parameter trajectory during training.

# I don't care about the math, just give me the code?

OK, but as a reminder: the snippet below is for the power-function EMA, and the same recipe works for any weighting function $w(t)$.

In the above code, you saved two copies of the EMA, each with a different $\gamma$. Now you want to recover the EMA with $\gamma_3$. Suppose you saved $n$ checkpoints, at iterations $i_1, i_2, \cdots, i_n$. Then you can do the following (the variables `t`, `checkpoint_index`, `y_t_ema1`, `y_t_ema2`, and `last_index` are defined in `ema_eq.py`, which this snippet is excerpted from).

```python
t_checkpoint = t[checkpoint_index]

ts = np.concatenate((t_checkpoint, t_checkpoint))
gammas = np.concatenate(
    (
        np.ones_like(checkpoint_index) * gamma_1,
        np.ones_like(checkpoint_index) * gamma_2,
    )
)

x = solve_weights(ts, gammas, last_index, gamma_3)
emapoints = np.concatenate((y_t_ema1[checkpoint_index], y_t_ema2[checkpoint_index]))

y_t_ema3 = np.dot(x.reshape(-1), emapoints.reshape(-1))
```

where `solve_weights` is the function that solves the linear least-squares problem. You can find the implementation in `ema_eq.py`.

The result is the EMA with $\gamma_3$.

![Alt text](ema_eq.png)
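The snippet above works on the 1-D toy signal from `ema_eq.py`. If you want to apply the same recipe to actual model checkpoints, the solved weights are applied per parameter tensor. Here is a rough sketch (my own, not part of this repo): the file names follow the training loop earlier in this README, while the iteration list, the $\gamma$ values, and the state-dict averaging details are illustrative assumptions.

```python
import numpy as np
import torch

from ema_eq import solve_weights

checkpoint_iters = [1000, 2000, 3000, 4000]  # iterations at which the EMAs were saved (example values)
gamma_1, gamma_2, gamma_3 = 5, 10, 7
T = checkpoint_iters[-1]  # synthesize the EMA as of the last saved iteration

ts = np.array(checkpoint_iters * 2, dtype=np.float64)
gammas = np.concatenate(
    (
        np.full(len(checkpoint_iters), gamma_1, dtype=np.float64),
        np.full(len(checkpoint_iters), gamma_2, dtype=np.float64),
    )
)
coeffs = solve_weights(ts, gammas, T, gamma_3).reshape(-1)

# Weighted sum of the saved EMA state dicts, one parameter tensor at a time
# (assumes every entry in the state dicts is a floating-point tensor).
paths = [f"./model_ema_1_{i}.pth" for i in checkpoint_iters] + [
    f"./model_ema_2_{i}.pth" for i in checkpoint_iters
]

synth = None
for c, path in zip(coeffs, paths):
    sd = torch.load(path, map_location="cpu")
    if synth is None:
        synth = {k: c * v.double() for k, v in sd.items()}
    else:
        for k, v in sd.items():
            synth[k] += c * v.double()

torch.save({k: v.float() for k, v in synth.items()}, f"./model_ema_gamma{gamma_3}_synth.pth")
```

The solved coefficients can be negative; that is fine, since they are least-squares weights rather than a convex combination.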
--------------------------------------------------------------------------------
/ema_eq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/karras-power-ema-tutorial/d850b390096893cd1bd050c2d253bd94ef5611a0/ema_eq.png
--------------------------------------------------------------------------------
/ema_eq.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import matplotlib.pyplot as plt


def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    # Closed-form inner product between two normalized power-function EMA
    # profiles truncated at times t_a and t_b (see the README derivation).
    t_ratio = t_a / t_b
    t_exp = np.where(t_a < t_b, gamma_b, -gamma_a)
    t_max = np.maximum(t_a, t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio**t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den


def solve_weights(t_i, gamma_i, t_r, gamma_r):
    # Least-squares coefficients that combine the stored profiles (t_i, gamma_i)
    # into the target profile (t_r, gamma_r).
    rv = lambda x: np.float64(x).reshape(-1, 1)
    cv = lambda x: np.float64(x).reshape(1, -1)
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
    B = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
    X = np.linalg.solve(A, B)
    return X


def power_ema(y, gamma, t):
    # Reference implementation of the power-function EMA recursion.
    ema_y = np.zeros_like(y)
    ema_y[0] = y[0]
    for i in range(1, len(y)):
        beta_t = (1 - 1 / t[i]) ** (gamma + 1)
        ema_y[i] = beta_t * ema_y[i - 1] + (1 - beta_t) * y[i]
    return ema_y


if __name__ == "__main__":
    N = 1000
    t = np.arange(1, N + 1)

    checkpoint_freq = 100
    checkpoint_index = np.arange(checkpoint_freq - 1, N, checkpoint_freq)
    # [99, 199, 299, 399, 499, 599, 699, 799, 899, 999]

    print(checkpoint_index)

    y_t_1d = (
        10 + math.sqrt(N) * np.sin(t / 300) + np.cumsum(np.random.normal(0, 0.5, N))
    )

    gamma_1 = 3
    gamma_2 = 16
    gamma_3 = 8

    y_t_ema1 = power_ema(y_t_1d, gamma_1, t)
    y_t_ema2 = power_ema(y_t_1d, gamma_2, t)

    y_t_ema3_ground_truth = power_ema(y_t_1d, gamma_3, t)

    ema3_last_ground_truth = y_t_ema3_ground_truth[-1]
    last_index = t[-1]

    t_checkpoint = t[checkpoint_index]

    ts = np.concatenate((t_checkpoint, t_checkpoint))
    gammas = np.concatenate(
        (
            np.ones_like(checkpoint_index) * gamma_1,
            np.ones_like(checkpoint_index) * gamma_2,
        )
    )

    x = solve_weights(ts, gammas, last_index, gamma_3)
    emapoints = np.concatenate((y_t_ema1[checkpoint_index], y_t_ema2[checkpoint_index]))
    print(x)

    ema3_last_approximated = np.dot(x.reshape(-1), emapoints.reshape(-1))

    print(f"EMA3 Last Ground Truth: {ema3_last_ground_truth}")
    print(f"EMA3 Last Approximated: {ema3_last_approximated}")

    # Plotting
    plt.figure(figsize=(12, 6))

    plt.plot(t, y_t_1d, label="Original Data", color="gray", alpha=0.7)
    plt.plot(t, y_t_ema1, label=f"EMA Gamma={gamma_1}", color="blue")
    plt.plot(t, y_t_ema2, label=f"EMA Gamma={gamma_2}", color="green")
    plt.plot(
        t,
        y_t_ema3_ground_truth,
        label=f"EMA Gamma={gamma_3} (Ground Truth)",
        color="red",
    )

    plt.scatter(
        t[checkpoint_index],
        y_t_1d[checkpoint_index],
        color="black",
        marker="x",
        label="Checkpoints",
    )

    plt.scatter(
        last_index,
        ema3_last_ground_truth,
        color="red",
        marker="x",
        label="EMA3 Last Ground Truth",
    )
    plt.scatter(
        last_index,
        ema3_last_approximated,
        color="orange",
        marker="x",
        label="EMA3 Last Approximated",
    )

    plt.title("Power Exponential Moving Average (EMA) Comparison per Gamma and its approximation")
    plt.xlabel("Time")
    plt.ylabel("Values")
    plt.legend()
    plt.grid(True)

    # Save the figure before show(), which may close it in some backends.
    save_path = "ema_eq.png"
    plt.savefig(save_path, dpi=300)
    plt.show()
--------------------------------------------------------------------------------
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cloneofsimo/karras-power-ema-tutorial/d850b390096893cd1bd050c2d253bd94ef5611a0/image.png
--------------------------------------------------------------------------------