├── LICENSE ├── README.md └── soap.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nikhil Vyas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOAP 2 | 3 | This is the official (preliminary) implementation of the SOAP optimizer from [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). 
# /README.md (continued):
# To use, copy the soap.py file to your codebase and use the SOAP optimizer as follows:
#
# ```
# from soap import SOAP
#
# optimizer = SOAP(model.parameters(), lr=3e-3, betas=(.95, .95), weight_decay=.01, precondition_frequency=10)
# ```
#
# We recommend trying it with as large a batch size as possible: as expected from second-order
# optimizers, the benefits are larger at larger batch sizes.
#
# While in the paper our experiments are restricted to Transformers, which only have 2D layers,
# the code supports nD layers. If you are using the optimizer for nD layers with n > 2, please see
# the additional hyperparameters in soap.py.
#
# We will release an improved version of the optimizer with support for lower precision and
# distributed training.
#
# Haydn Jones has implemented a JAX version at https://github.com/haydn-jones/SOAP_JAX, though we
# have not yet verified the implementation.
#
# /soap.py:
import torch
import torch.nn as nn
import torch.optim as optim

from itertools import chain

# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py


class SOAP(optim.Optimizer):
    """
    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

    SOAP runs Adam inside the eigenbasis of Shampoo's preconditioner. For each
    parameter it keeps, per (possibly merged) dimension, a moving average of
    gradient outer products (``state['GG']``; L and R in the paper) and the
    matching eigenvector matrices (``state['Q']``). Gradients are projected into
    that basis, Adam's moments are accumulated there, and the resulting update
    is projected back before being applied. The basis is refreshed every
    ``precondition_frequency`` steps with one power-iteration + QR round.

    The first call to ``step()`` for a parameter only initializes its
    preconditioner; no update is applied to that parameter on that call.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.003):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
            Adam's betas parameters (b1, b2).
        shampoo_beta (`float`, *optional*, defaults to -1):
            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below)
            moving average instead of betas[1].
        eps (`float`, *optional*, defaults to 1e-08):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
        precondition_frequency (`int`, *optional*, defaults to 10):
            How often (in steps) to refresh the preconditioner's eigenbasis.
        max_precond_dim (`int`, *optional*, defaults to 10000):
            Maximum dimension of the preconditioner. Dimensions larger than this are
            not preconditioned (the identity is used for them).
            Set to 10000, so that we exclude most common vocab sizes while including layers.
        merge_dims (`bool`, *optional*, defaults to `False`):
            Whether or not to merge adjacent dimensions of the gradient (while their product
            stays <= max_precond_dim) before building the preconditioner.
        precondition_1d (`bool`, *optional*, defaults to `False`):
            Whether or not to precondition 1D gradients (only if their length is <= max_precond_dim).
        normalize_grads (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize gradients per layer.
            Helps at large precondition_frequency (~100 in our experiments),
            but hurts performance at small precondition_frequency (~10 in our experiments).
        data_format (`str`, *optional*, defaults to `channels_first`):
            Data format of the input for convolutional layers. Only used for 4D
            parameters when merge_dims is True.
            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias correction in Adam.
    """

    def __init__(
        self,
        params,
        lr: float = 3e-3,
        betas=(0.95, 0.95),
        shampoo_beta: float = -1,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        precondition_frequency: int = 10,
        max_precond_dim: int = 10000,
        merge_dims: bool = False,  # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
        precondition_1d: bool = False,
        normalize_grads: bool = False,
        data_format: str = "channels_first",
        correct_bias: bool = True,
    ):
        defaults = {
            "lr": lr,
            "betas": betas,
            "shampoo_beta": shampoo_beta,
            "eps": eps,
            "weight_decay": weight_decay,
            "precondition_frequency": precondition_frequency,
            "max_precond_dim": max_precond_dim,
            "merge_dims": merge_dims,
            "precondition_1d": precondition_1d,
            "normalize_grads": normalize_grads,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)
        self._data_format = data_format

    def merge_dims(self, grad, max_precond_dim):
        """
        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.

        For 4D tensors in "channels_last" format the tensor is first permuted to
        NCHW order. Size-1 dimensions are absorbed into their neighbours; a single
        dimension larger than max_precond_dim is kept on its own.
        """
        assert self._data_format in ["channels_first", "channels_last"]
        if self._data_format == "channels_last" and grad.dim() == 4:
            grad = grad.permute(0, 3, 1, 2)
        shape = grad.shape
        new_shape = []

        curr_shape = 1
        for sh in shape:
            temp_shape = curr_shape * sh
            if temp_shape > max_precond_dim:
                if curr_shape > 1:
                    # Close the current merged group and start a new one at `sh`.
                    new_shape.append(curr_shape)
                    curr_shape = sh
                else:
                    # `sh` alone exceeds the limit: keep it as its own dimension.
                    new_shape.append(sh)
                    curr_shape = 1
            else:
                curr_shape = temp_shape

        if curr_shape > 1 or len(new_shape) == 0:
            new_shape.append(curr_shape)

        new_grad = grad.reshape(new_shape)
        return new_grad

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The closure runs forward + backward, which needs autograd; the
            # @torch.no_grad() decorator would otherwise make backward() fail.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            merge_dims = group["merge_dims"]
            max_precond_dim = group["max_precond_dim"]
            precondition_1d = group["precondition_1d"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("SOAP does not support sparse gradients")

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # State initialization
                if "exp_avg" not in state:
                    # Exponential moving average of (projected) gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared (projected) gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                if "Q" not in state:
                    self.init_preconditioner(
                        grad,
                        state,
                        precondition_frequency=group["precondition_frequency"],
                        precondition_1d=precondition_1d,
                        shampoo_beta=(group["shampoo_beta"] if group["shampoo_beta"] >= 0 else beta2),
                        max_precond_dim=max_precond_dim,
                        merge_dims=merge_dims,
                    )
                    self.update_preconditioner(grad, state,
                                               max_precond_dim=max_precond_dim,
                                               merge_dims=merge_dims,
                                               precondition_1d=precondition_1d)
                    # First step is skipped so that we never use the current gradients in the projection.
                    continue

                # Project the gradient to the eigenbases of Shampoo's preconditioner,
                # i.e. to the eigenbases of the matrices in state['GG'].
                grad_projected = self.project(grad, state, merge_dims=merge_dims,
                                              max_precond_dim=max_precond_dim)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1

                # Decay the first and second moment running averages (in the projected basis).
                exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))

                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * (bias_correction2 ** 0.5) / bias_correction1

                # Project the Adam-preconditioned first moment back to the original space.
                norm_grad = self.project_back(exp_avg / denom, state, merge_dims=merge_dims,
                                              max_precond_dim=max_precond_dim)

                if group["normalize_grads"]:
                    norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad ** 2) ** 0.5)

                p.add_(norm_grad, alpha=-step_size)

                # From AdamW code: Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

                # Update is done after the gradient step to avoid using current gradients in the projection.
                self.update_preconditioner(grad, state,
                                           max_precond_dim=max_precond_dim,
                                           merge_dims=merge_dims,
                                           precondition_1d=precondition_1d)

        return loss

    def init_preconditioner(self, grad, state, precondition_frequency=10,
                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                            merge_dims=False):
        """
        Initializes the preconditioner matrices (L and R in the paper).

        ``state['GG']`` gets one entry per (possibly merged) dimension of ``grad``:
        a zero ``(d, d)`` matrix with the gradient's dtype and device when that
        dimension is preconditioned, or ``[]`` when it is not (size above
        ``max_precond_dim``, or a 1D tensor with ``precondition_1d`` disabled).
        """
        state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0],
                                               device=grad.device, dtype=grad.dtype))
        else:
            if merge_dims:
                grad = self.merge_dims(grad, max_precond_dim)

            for sh in grad.shape:
                if sh > max_precond_dim:
                    state['GG'].append([])
                else:
                    # Match the gradient dtype so lerp_/tensordot never mix dtypes.
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))

        state['Q'] = None  # Will hold all the eigenbases of the preconditioner.
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta

    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient to the eigenbases of the preconditioner (applies Q^T per dimension).
        """
        return self._rotate(grad, state, merge_dims, max_precond_dim, contract_dim=0)

    def update_preconditioner(self, grad, state,
                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):
        """
        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).

        ``state['exp_avg']`` is stored in the eigenbasis, so whenever the basis is
        refreshed it is moved back to the original space with the old Q and then
        re-projected with the new Q. Between refreshes Q is unchanged and that
        round trip would be an identity, so it is skipped.
        """
        refresh_basis = state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0

        if state['Q'] is not None and refresh_basis:
            state['exp_avg'] = self.project_back(state['exp_avg'], state, merge_dims=merge_dims,
                                                 max_precond_dim=max_precond_dim)

        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precond_dim:
                state['GG'][0].lerp_(torch.outer(grad, grad), 1 - state['shampoo_beta'])
        else:
            g = self.merge_dims(grad, max_precond_dim) if merge_dims else grad
            for idx, sh in enumerate(g.shape):
                if sh <= max_precond_dim:
                    # Contracts across all dimensions except for idx.
                    other_dims = [*chain(range(idx), range(idx + 1, g.dim()))]
                    outer_product = torch.tensordot(g, g, dims=[other_dims, other_dims])
                    state['GG'][idx].lerp_(outer_product, 1 - state['shampoo_beta'])

        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])
        if refresh_basis:
            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
            state['exp_avg'] = self.project(state['exp_avg'], state, merge_dims=merge_dims,
                                            max_precond_dim=max_precond_dim)

    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
        """
        Projects the gradient back to the original space (applies Q per dimension).
        """
        return self._rotate(grad, state, merge_dims, max_precond_dim, contract_dim=1)

    def _rotate(self, grad, state, merge_dims, max_precond_dim, contract_dim):
        """
        Applies ``state['Q']`` along every (possibly merged) dimension of ``grad``.

        Dimension 0 of ``grad`` is contracted with dimension ``contract_dim`` of
        each Q matrix: ``contract_dim=0`` maps into the eigenbasis (Q^T x) and
        ``contract_dim=1`` maps back (Q x). Each tensordot/permute moves the
        processed dimension to the end, so after one pass over all dimensions the
        original dimension order is restored.
        """
        original_shape = grad.shape
        if merge_dims:
            grad = self.merge_dims(grad, max_precond_dim)

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(grad, mat, dims=[[0], [contract_dim]])
            else:
                # Unpreconditioned dimension: only rotate it to the back.
                grad = grad.permute(list(range(1, grad.dim())) + [0])

        if merge_dims:
            if self._data_format == 'channels_last' and len(original_shape) == 4:
                # merge_dims() worked on the NCHW view; undo that permutation.
                n, h, w, c = original_shape
                grad = grad.reshape(n, c, h, w).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)
        return grad

    def get_orthogonal_matrix(self, mat):
        """
        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.

        Each non-empty matrix is decomposed in float32 (falling back to float64 if
        eigh fails), its eigenvectors are ordered by descending eigenvalue, and the
        result is returned in that matrix's own dtype and device. ``[]`` entries are
        passed through unchanged.
        """
        final = []
        for m in mat:
            if len(m) == 0:
                final.append([])
                continue
            # Track dtype/device per matrix so mixed-dtype lists are restored correctly.
            original_dtype, original_device = m.dtype, m.device
            m_f = m.data.float()
            # Tiny diagonal jitter, as in the original implementation.
            jitter = 1e-30 * torch.eye(m_f.shape[0], device=m_f.device)
            try:
                _, Q = torch.linalg.eigh(m_f + jitter)
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                _, Q = torch.linalg.eigh(m_f.to(torch.float64) + jitter)
                Q = Q.to(m_f.dtype)
            # eigh returns ascending eigenvalues; flip to descending order.
            Q = torch.flip(Q, [1])
            final.append(Q.to(device=original_device, dtype=original_dtype))
        return final

    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
        """
        Computes the eigenbases of the preconditioner using one round of power iteration
        followed by torch.linalg.qr decomposition.

        Before the power iteration the current eigenvectors are re-sorted by their
        estimated eigenvalues (diag(Q^T GG Q), descending), and ``state['exp_avg_sq']``
        (which lives in the eigenbasis) is permuted along the matching dimension so
        it stays aligned with the new basis. Returned matrices keep each
        preconditioner matrix's own dtype and device.
        """
        orig_shape = state['exp_avg_sq'].shape
        exp_avg_sq = state['exp_avg_sq']
        if merge_dims:
            exp_avg_sq = self.merge_dims(exp_avg_sq, max_precond_dim)

        final = []
        for ind, (m, o) in enumerate(zip(state['GG'], state['Q'])):
            if len(m) == 0:
                final.append([])
                continue
            original_dtype, original_device = m.dtype, m.device
            m_f = m.data.float()
            o_f = o.data.float()

            est_eig = torch.diag(o_f.T @ m_f @ o_f)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
            o_f = o_f[:, sort_idx]
            power_iter = m_f @ o_f
            Q, _ = torch.linalg.qr(power_iter)
            final.append(Q.to(device=original_device, dtype=original_dtype))

        if merge_dims:
            if self._data_format == 'channels_last' and len(orig_shape) == 4:
                # merge_dims() worked on the NCHW view; undo that permutation.
                n, h, w, c = orig_shape
                exp_avg_sq = exp_avg_sq.reshape(n, c, h, w).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)

        state['exp_avg_sq'] = exp_avg_sq
        return final