├── .gitignore ├── LICENSE ├── README.md ├── cgoptimizer ├── __init__.py ├── optim.py ├── optim_eff.py └── priority_dict.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | /build/ 3 | .idea/ 4 | 5 | # Compiled python modules. 6 | *.pyc 7 | 8 | # Setuptools distribution folder. 9 | /dist/ 10 | 11 | # Python egg metadata, regenerated from source files by setuptools. 12 | /*.egg-info -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Paul-Aymeric McRae, Prasanna Parthasarathi et al. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Critical Gradient Optimizers 2 | 3 | Critical Gradient Optimizers from the "Memory Augmented Optimizers for Deep Learning" project and [paper](https://arxiv.org/abs/2106.10708), reformatted as package and stripped down to just the necessary components to integrate the optimizers into your code. 4 | 5 | This code is compatible with the following versions: 6 | 7 | ``` 8 | python >= 3.6 9 | pytorch >= 1.7.1 10 | ``` 11 | 12 | ## Installation 13 | 14 | Install these optimizers using 15 | 16 | ``` 17 | pip install cgoptimizer 18 | ``` 19 | 20 | Alternatively, clone [this repository](https://github.com/chandar-lab/CGOptimizer) anywhere on your system. Once cloned, `cd` to the directory and install with: `pip install .` 21 | 22 | ## Colab 23 | 24 | The shared [colab notebook](https://colab.research.google.com/drive/1m8Edr7aAHlBIlAtV2PZRQgnKIcf0VQh5?usp=sharing) shows examples on using the Critical gradient optimizers on toy classification tasks constructed with scikit-learn package. 25 | 26 | ## Importing and Running 27 | 28 | You can import the optimizers as you would any PyTorch optimizer. There are no requirements to run them other than PyTorch and its dependencies. 29 | 30 | When installed, import the optimizers to your training script as needed: 31 | 32 | ``` 33 | from cgoptimizer import SGD_C, RMSprop_C, Adam_C, AdamW_C 34 | ``` 35 | 36 | You can then replace any PyTorch optimizer in your script with their `_C` counterpart. 
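For example, a minimal training-loop sketch (the `model`, `loss_fn`, and `train_loader` names below are placeholders for your own objects, not part of this package) could look like:

```
from cgoptimizer import Adam_C

optimizer = Adam_C(model.parameters(), lr=0.001)  # same call pattern as torch.optim.Adam

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
```

No other changes to the training loop are needed.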
Note that currently only Critical-Gradient variants of Adam, RMSprop, and SGD (with optional momentum but NOT Nesterov) are implemented. 37 | 38 | Here is a sample replacement: 39 | 40 | ``` 41 | optimizer = Adam(model.parameters(), lr=0.001) 42 | ``` 43 | 44 | becomes 45 | 46 | ``` 47 | optimizer = Adam_C(model.parameters(), lr=0.001, **kwargs) 48 | ``` 49 | 50 | Similarly, for efficient GPU-based implementation: 51 | 52 | ``` 53 | from cgoptimizer.optim_eff import SGD_C_eff, RMSprop_C_eff, Adam_C_eff, AdamW_C_eff 54 | optimizer = Adam_C_eff(model.parameters(), lr=0.001, **kwargs) 55 | ``` 56 | 57 | ## Optimizer Usage and Tuning 58 | 59 | The Critical Gradient variants use all the same hyperparameters as their vanilla counterparts, so you may not need to perform any additional tuning. 60 | 61 | The `_C` (and `_C_eff`) optimizers have two additional hyperparameters compared to the vanilla version: `topC` which indicates how many critical gradients to keep and`decay` which indicates how much the norms of critical gradients are decayed each step. These are keyword arguments with default values which we observed to work well. For additional performance, these can be tuned. 62 | 63 | The `_C` (and `_C_eff`) variants perform best using either the same best learning rate as its vanilla counterpart, or 1/10 that learning rate. It is recommended you run both learning rates to compare. 64 | 65 | Hyperparameter `topC` determines how many critical gradients are stored and thus how much memory is used. Higher `topC` usually result in longer training times. Good `topC` values usually fall between 5 and 20. We recommended using values 5, 10, and 20. 66 | 67 | Hyperparameter `decay` indicates the level of decay in the buffer. This modifies how frequently the buffer is refreshed. The `decay` parameter must fall between 0 and 1. We recommended using values 0.7 and 0.9. 68 | 69 | ## Citation 70 | 71 | ``` 72 | @misc{mcrae2021memory, 73 | author = {McRae, Paul-Aymeric and Parthasarathi, Prasanna and Assran, Mahmoud and Chandar, Sarath}, 74 | title = {Memory Augmented Optimizers for Deep Learning}, 75 | year = {2022}, 76 | booktitle = {Proceedings of ICLR} 77 | } 78 | -------------------------------------------------------------------------------- /cgoptimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .optim import SGD_C, RMSprop_C, Adam_C -------------------------------------------------------------------------------- /cgoptimizer/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementations for _C enhanced optimizers as well as their vanilla counterparts. 3 | Vanilla algorthims are sourced from PyTorch source, and _C iterations are largely 4 | based on those as well. 5 | 6 | https://github.com/pytorch/pytorch/tree/master/torch/optim 7 | """ 8 | 9 | import math 10 | import torch 11 | from torch.optim import Optimizer 12 | from .priority_dict import PriorityDict 13 | from copy import deepcopy 14 | 15 | 16 | def aggregate(d_p, crit_buf, func, kappa=1.0): 17 | """ 18 | Reusable aggregation function to join current iteration gradient and critical 19 | gradients 20 | 21 | :param d_p: Current-iteration gradient 22 | :param crit_buf: Buffer of Critical Gradients 23 | :param func: String name of aggregation. 
Should be "sum", "mid", or "mean" 24 | :param kappa: Multiplicative factor for CG buffer 25 | :return: Aggregated total gradient 26 | """ 27 | 28 | if "sum" == func: 29 | crit_buf_ = crit_buf.gradMean() 30 | crit_buf_.mul_(kappa) 31 | return torch.add(d_p, crit_buf_) 32 | elif "mid" == func: 33 | crit_buf_ = crit_buf.gradMean() 34 | crit_buf_.mul_(kappa) 35 | return torch.mul(torch.add(d_p, crit_buf_), 0.5) 36 | elif "mean" == func: 37 | crit_buf_ = crit_buf.gradSum() 38 | crit_buf_.mul_(kappa) 39 | return torch.div(torch.add(d_p, crit_buf_), len(crit_buf) + 1) 40 | else: 41 | raise ValueError("Invalid aggregation function") 42 | 43 | 44 | class SGD(Optimizer): 45 | r"""Implements stochastic gradient descent (optionally with momentum). 46 | Nesterov momentum is based on the formula from 47 | `On the importance of initialization and momentum in deep learning`__. 48 | Args: 49 | params (iterable): iterable of parameters to optimize or dicts defining 50 | parameter groups 51 | lr (float): learning rate 52 | momentum (float, optional): momentum factor (default: 0) 53 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 54 | dampening (float, optional): dampening for momentum (default: 0) 55 | nesterov (bool, optional): enables Nesterov momentum (default: False) 56 | Example: 57 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 58 | >>> optimizer.zero_grad() 59 | >>> loss_fn(model(input), target).backward() 60 | >>> optimizer.step() 61 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 62 | .. note:: 63 | The implementation of SGD with Momentum/Nesterov subtly differs from 64 | Sutskever et. al. and implementations in some other frameworks. 65 | Considering the specific case of Momentum, the update can be written as 66 | .. math:: 67 | v_{t+1} = \mu * v_{t} + g_{t+1} \\ 68 | p_{t+1} = p_{t} - lr * v_{t+1} 69 | where p, g, v and :math:`\mu` denote the parameters, gradient, 70 | velocity, and momentum respectively. 71 | This is in contrast to Sutskever et. al. and 72 | other frameworks which employ an update of the form 73 | .. math:: 74 | v_{t+1} = \mu * v_{t} + lr * g_{t+1} \\ 75 | p_{t+1} = p_{t} - v_{t+1} 76 | The Nesterov version is analogously modified. 77 | """ 78 | 79 | def __init__(self, params, lr=0.001, momentum=0, dampening=0, 80 | weight_decay=0, nesterov=False): 81 | if momentum < 0.0: 82 | raise ValueError("Invalid momentum value: {}".format(momentum)) 83 | if weight_decay < 0.0: 84 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 85 | 86 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 87 | weight_decay=weight_decay, nesterov=nesterov) 88 | if nesterov and (momentum <= 0 or dampening != 0): 89 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 90 | super(SGD, self).__init__(params, defaults) 91 | self.resetOfflineStats() 92 | 93 | def __setstate__(self, state): 94 | super(SGD, self).__setstate__(state) 95 | for group in self.param_groups: 96 | group.setdefault('nesterov', False) 97 | 98 | def getOfflineStats(self): 99 | return self.offline_grad 100 | 101 | def resetOfflineStats(self): 102 | self.offline_grad = {'yes': 0, 'no': 0} 103 | 104 | def step(self, closure=None): 105 | """Performs a single optimization step. 106 | Arguments: 107 | closure (callable, optional): A closure that reevaluates the model 108 | and returns the loss. 
109 | """ 110 | loss = None 111 | if closure is not None: 112 | with torch.enable_grad(): 113 | loss = closure() 114 | 115 | for group in self.param_groups: 116 | weight_decay = group['weight_decay'] 117 | momentum = group['momentum'] 118 | dampening = group['dampening'] 119 | nesterov = group['nesterov'] 120 | 121 | for p in group['params']: 122 | if p.grad is None: 123 | continue 124 | d_p = p.grad.data 125 | if weight_decay != 0: 126 | d_p = d_p.add(weight_decay, p.data) 127 | if momentum != 0: 128 | param_state = self.state[p] 129 | if 'momentum_buffer' not in param_state: 130 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 131 | else: 132 | buf = param_state['momentum_buffer'] 133 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 134 | if nesterov: 135 | d_p = d_p.add(momentum, buf) 136 | else: 137 | d_p = buf 138 | 139 | p.data.add_(d_p, alpha=-group['lr']) 140 | 141 | return loss 142 | 143 | 144 | class SGD_C(Optimizer): 145 | """ 146 | Implementation of SGD (and optionally SGD with momentum) with critical gradients. 147 | Replaces current-iteration gradient in conventional PyTorch implementation with 148 | an aggregation of current gradient and critical gradients. 149 | 150 | Conventional SGD or SGD with momentum can be recovered by setting kappa=0. 151 | 152 | The critical-gradient-specific keyword parameters are tuned for good 153 | off-the-shelf performance, though additional tuning may be required for best results 154 | """ 155 | 156 | def __init__(self, params, lr=0.001, kappa=1.0, dampening=0., 157 | weight_decay=0, momentum=0., 158 | decay=0.7, topC=10, aggr='sum', synced=True): 159 | 160 | if momentum < 0.0: 161 | raise ValueError("Invalid momentum value: {}".format(momentum)) 162 | if weight_decay < 0.0: 163 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 164 | if not 0.0 <= decay and not 1.0 > decay: 165 | raise ValueError("Invalid decay value: {}".format(decay)) 166 | if not 0.0 <= topC: 167 | raise ValueError("Invalid topC value: {}".format(topC)) 168 | 169 | defaults = dict(lr=lr, kappa=kappa, dampening=dampening, 170 | weight_decay=weight_decay, momentum=momentum, 171 | aggr=aggr, decay=decay, topC=topC, synced=synced) 172 | 173 | super(SGD_C, self).__init__(params, defaults) 174 | 175 | def getOfflineStats(self): 176 | return self.offline_grad 177 | 178 | def resetOfflineStats(self): 179 | self.offline_grad = {'yes': 0, 'no': 0} 180 | 181 | def __setstate__(self, state): 182 | super(SGD_C, self).__setstate__(state) 183 | 184 | def step(self, closure=None): 185 | """Performs a single optimization step. 186 | Arguments: 187 | closure (callable, optional): A closure that reevaluates the model 188 | and returns the loss. 
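        A minimal usage sketch (hypothetical `model`, `input`, `target`, and
        `loss_fn`; the critical-gradient keywords keep their defaults,
        `topC=10` and `decay=0.7`):
            >>> optimizer = SGD_C(model.parameters(), lr=0.01, momentum=0.9)
            >>> optimizer.zero_grad()
            >>> loss_fn(model(input), target).backward()
            >>> optimizer.step()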
189 | """ 190 | loss = None 191 | if closure is not None: 192 | with torch.enable_grad(): 193 | loss = closure() 194 | 195 | for group in self.param_groups: 196 | weight_decay = group['weight_decay'] 197 | kappa = group['kappa'] 198 | dampening = group['dampening'] 199 | decay = group['decay'] 200 | momentum = group['momentum'] 201 | topc = group['topC'] 202 | aggr = group['aggr'] 203 | synced = group['synced'] 204 | 205 | d_p_norm = 0.0 206 | 207 | if synced: 208 | for p in group['params']: 209 | if p.grad is None: 210 | continue 211 | d_p = p.grad.data 212 | d_p_norm += torch.sqrt(torch.sum(torch.square(d_p))) 213 | 214 | for p in group['params']: 215 | if p.grad is None: 216 | continue 217 | d_p = p.grad.data 218 | if not synced: 219 | d_p_norm = d_p.norm() 220 | if weight_decay != 0: 221 | d_p = d_p.add(weight_decay, p.data) 222 | if kappa != 0: 223 | param_state = self.state[p] 224 | if 'critical gradients' not in param_state: 225 | crit_buf = param_state['critical gradients'] = PriorityDict() 226 | crit_buf.setHyper(decay_rate=decay, K=topc) 227 | crit_buf[d_p_norm] = deepcopy(d_p) 228 | else: 229 | crit_buf = param_state['critical gradients'] 230 | aggr_grad = aggregate(d_p, crit_buf, aggr, kappa) 231 | if crit_buf.isFull(): 232 | if d_p_norm > crit_buf.pokeSmallest(): 233 | self.offline_grad['yes'] += 1 234 | crit_buf[d_p_norm] = deepcopy(d_p) 235 | else: 236 | self.offline_grad['no'] += 1 237 | else: 238 | crit_buf[d_p_norm] = deepcopy(d_p) 239 | d_p = aggr_grad 240 | 241 | crit_buf.decay() 242 | 243 | if momentum != 0: 244 | param_state = self.state[p] 245 | if 'momentum_buffer' not in param_state: 246 | buf = param_state['momentum_buffer'] = torch.clone( 247 | d_p).detach() 248 | else: 249 | buf = param_state['momentum_buffer'] 250 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 251 | d_p = buf 252 | 253 | p.data.add_(d_p, alpha=-group['lr']) 254 | 255 | return loss 256 | 257 | 258 | class Adam(Optimizer): 259 | r"""Implements Adam algorithm. 260 | 261 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 262 | 263 | Arguments: 264 | params (iterable): iterable of parameters to optimize or dicts defining 265 | parameter groups 266 | lr (float, optional): learning rate (default: 1e-3) 267 | betas (Tuple[float, float], optional): coefficients used for computing 268 | running averages of gradient and its square (default: (0.9, 0.999)) 269 | eps (float, optional): term added to the denominator to improve 270 | numerical stability (default: 1e-8) 271 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 272 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 273 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 274 | (default: False) 275 | 276 | .. _Adam\: A Method for Stochastic Optimization: 277 | https://arxiv.org/abs/1412.6980 278 | .. 
_On the Convergence of Adam and Beyond: 279 | https://openreview.net/forum?id=ryQu7f-RZ 280 | """ 281 | 282 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 283 | weight_decay=0, amsgrad=False): 284 | if not 0.0 <= lr: 285 | raise ValueError("Invalid learning rate: {}".format(lr)) 286 | if not 0.0 <= eps: 287 | raise ValueError("Invalid epsilon value: {}".format(eps)) 288 | if not 0.0 <= betas[0] < 1.0: 289 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 290 | if not 0.0 <= betas[1] < 1.0: 291 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 292 | if not 0.0 <= weight_decay: 293 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 294 | defaults = dict(lr=lr, betas=betas, eps=eps, 295 | weight_decay=weight_decay, amsgrad=amsgrad) 296 | super(Adam, self).__init__(params, defaults) 297 | self.resetOfflineStats() 298 | 299 | def __setstate__(self, state): 300 | super(Adam, self).__setstate__(state) 301 | for group in self.param_groups: 302 | group.setdefault('amsgrad', False) 303 | 304 | def getOfflineStats(self): 305 | return self.offline_grad 306 | 307 | def resetOfflineStats(self): 308 | self.offline_grad = {'yes': 0, 'no': 0} 309 | 310 | @torch.no_grad() 311 | def step(self, closure=None): 312 | """Performs a single optimization step. 313 | 314 | Arguments: 315 | closure (callable, optional): A closure that reevaluates the model 316 | and returns the loss. 317 | """ 318 | loss = None 319 | if closure is not None: 320 | with torch.enable_grad(): 321 | loss = closure() 322 | 323 | for group in self.param_groups: 324 | for p in group['params']: 325 | if p.grad is None: 326 | continue 327 | grad = p.grad 328 | if grad.is_sparse: 329 | raise RuntimeError( 330 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 331 | amsgrad = group['amsgrad'] 332 | 333 | state = self.state[p] 334 | 335 | # State initialization 336 | if len(state) == 0: 337 | state['step'] = 0 338 | # Exponential moving average of gradient values 339 | state['exp_avg'] = torch.zeros_like(p) 340 | # Exponential moving average of squared gradient values 341 | state['exp_avg_sq'] = torch.zeros_like(p) 342 | if amsgrad: 343 | # Maintains max of all exp. moving avg. of sq. grad. values 344 | state['max_exp_avg_sq'] = torch.zeros_like(p) 345 | 346 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 347 | if amsgrad: 348 | max_exp_avg_sq = state['max_exp_avg_sq'] 349 | beta1, beta2 = group['betas'] 350 | 351 | state['step'] += 1 352 | bias_correction1 = 1 - beta1 ** state['step'] 353 | bias_correction2 = 1 - beta2 ** state['step'] 354 | 355 | if group['weight_decay'] != 0: 356 | grad = grad.add(p, alpha=group['weight_decay']) 357 | 358 | # Decay the first and second moment running average coefficient 359 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 360 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 361 | if amsgrad: 362 | # Maintains the maximum of all 2nd moment running avg. till now 363 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 364 | # Use the max. for normalizing running avg. 
of gradient 365 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 366 | group['eps']) 367 | else: 368 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 369 | group['eps']) 370 | 371 | step_size = group['lr'] / bias_correction1 372 | 373 | p.addcdiv_(exp_avg, denom, value=-step_size) 374 | 375 | return loss 376 | 377 | 378 | class AdamW(Optimizer): 379 | r"""Implements AdamW algorithm. 380 | 381 | Arguments: 382 | params (iterable): iterable of parameters to optimize or dicts defining 383 | parameter groups 384 | lr (float, optional): learning rate (default: 1e-3) 385 | betas (Tuple[float, float], optional): coefficients used for computing 386 | running averages of gradient and its square (default: (0.9, 0.999)) 387 | eps (float, optional): term added to the denominator to improve 388 | numerical stability (default: 1e-8) 389 | weight_decay (float, optional): weight decay (default: 0.01) 390 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 391 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 392 | (default: False) 393 | 394 | .. _AdamW\: A Method for Stochastic Optimization: 395 | """ 396 | 397 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 398 | weight_decay=0.01, amsgrad=False): 399 | if not 0.0 <= lr: 400 | raise ValueError("Invalid learning rate: {}".format(lr)) 401 | if not 0.0 <= eps: 402 | raise ValueError("Invalid epsilon value: {}".format(eps)) 403 | if not 0.0 <= betas[0] < 1.0: 404 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 405 | if not 0.0 <= betas[1] < 1.0: 406 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 407 | if not 0.0 <= weight_decay: 408 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 409 | defaults = dict(lr=lr, betas=betas, eps=eps, 410 | weight_decay=weight_decay, amsgrad=amsgrad) 411 | super(AdamW, self).__init__(params, defaults) 412 | self.resetOfflineStats() 413 | 414 | def __setstate__(self, state): 415 | super(AdamW, self).__setstate__(state) 416 | for group in self.param_groups: 417 | group.setdefault('amsgrad', False) 418 | 419 | def getOfflineStats(self): 420 | return self.offline_grad 421 | 422 | def resetOfflineStats(self): 423 | self.offline_grad = {'yes': 0, 'no': 0} 424 | 425 | @torch.no_grad() 426 | def step(self, closure=None): 427 | """Performs a single optimization step. 428 | 429 | Arguments: 430 | closure (callable, optional): A closure that reevaluates the model 431 | and returns the loss. 432 | """ 433 | loss = None 434 | if closure is not None: 435 | with torch.enable_grad(): 436 | loss = closure() 437 | 438 | for group in self.param_groups: 439 | for p in group['params']: 440 | if p.grad is None: 441 | continue 442 | p.mul_(1 - group['lr'] * group['weight_decay']) 443 | grad = p.grad 444 | if grad.is_sparse: 445 | raise RuntimeError( 446 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 447 | amsgrad = group['amsgrad'] 448 | 449 | state = self.state[p] 450 | 451 | # State initialization 452 | if len(state) == 0: 453 | state['step'] = 0 454 | # Exponential moving average of gradient values 455 | state['exp_avg'] = torch.zeros_like(p) 456 | # Exponential moving average of squared gradient values 457 | state['exp_avg_sq'] = torch.zeros_like(p) 458 | if amsgrad: 459 | # Maintains max of all exp. moving avg. of sq. grad. 
values 460 | state['max_exp_avg_sq'] = torch.zeros_like(p) 461 | 462 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 463 | if amsgrad: 464 | max_exp_avg_sq = state['max_exp_avg_sq'] 465 | beta1, beta2 = group['betas'] 466 | 467 | state['step'] += 1 468 | bias_correction1 = 1 - beta1 ** state['step'] 469 | bias_correction2 = 1 - beta2 ** state['step'] 470 | 471 | # if group['weight_decay'] != 0: 472 | # grad = grad.add(p, alpha=group['weight_decay']) 473 | 474 | # Decay the first and second moment running average coefficient 475 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 476 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 477 | if amsgrad: 478 | # Maintains the maximum of all 2nd moment running avg. till now 479 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 480 | # Use the max. for normalizing running avg. of gradient 481 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 482 | group['eps']) 483 | else: 484 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 485 | group['eps']) 486 | 487 | step_size = group['lr'] / bias_correction1 488 | 489 | p.addcdiv_(exp_avg, denom, value=-step_size) 490 | 491 | return loss 492 | 493 | 494 | class Adam_C(Optimizer): 495 | """ 496 | Implementation of Adam with critical gradients. 497 | Replaces current-iteration gradient in conventional PyTorch implementation with 498 | an aggregation of current gradient and critical gradients. 499 | 500 | Conventional Adam can be recovered by setting kappa=0. 501 | 502 | The critical-gradient-specific keyword parameters are tuned for good 503 | off-the-shelf performance, though additional tuning may be required for best results 504 | """ 505 | 506 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 507 | decay=0.7, kappa=1.0, topC=10, 508 | weight_decay=0, amsgrad=False, aggr='mean', synced=True): 509 | if not 0.0 <= lr: 510 | raise ValueError("Invalid learning rate: {}".format(lr)) 511 | if not 0.0 <= eps: 512 | raise ValueError("Invalid epsilon value: {}".format(eps)) 513 | if not 0.0 <= betas[0] < 1.0: 514 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 515 | if not 0.0 <= betas[1] < 1.0: 516 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 517 | if not 0.0 <= decay and not 1.0 > decay: 518 | raise ValueError("Invalid decay value: {}".format(decay)) 519 | if not 0.0 <= topC: 520 | raise ValueError("Invalid topC value: {}".format(topC)) 521 | defaults = dict(lr=lr, betas=betas, eps=eps, 522 | weight_decay=weight_decay, aggr=aggr, amsgrad=amsgrad, 523 | kappa=kappa, topC=topC, decay=decay, synced=synced) 524 | 525 | super(Adam_C, self).__init__(params, defaults) 526 | 527 | def getOfflineStats(self): 528 | return self.offline_grad 529 | 530 | def resetOfflineStats(self): 531 | self.offline_grad = {'yes': 0, 'no': 0} 532 | 533 | def __setstate__(self, state): 534 | super(Adam_C, self).__setstate__(state) 535 | for group in self.param_groups: 536 | group.setdefault('amsgrad', False) 537 | 538 | @torch.no_grad() 539 | def step(self, closure=None): 540 | """Performs a single optimization step. 541 | Arguments: 542 | closure (callable, optional): A closure that reevaluates the model 543 | and returns the loss. 
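        A minimal usage sketch (hypothetical `model`, `input`, `target`, and
        `loss_fn`; `topC` and `decay` are shown at their default values):
            >>> optimizer = Adam_C(model.parameters(), lr=1e-3, topC=10, decay=0.7)
            >>> optimizer.zero_grad()
            >>> loss_fn(model(input), target).backward()
            >>> optimizer.step()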
544 | """ 545 | loss = None 546 | if closure is not None: 547 | with torch.enable_grad(): 548 | loss = closure() 549 | 550 | for group in self.param_groups: 551 | 552 | grad_norm = 0.0 553 | 554 | if group['synced']: 555 | for p in group['params']: 556 | if p.grad is None: 557 | continue 558 | d_p = p.grad.data 559 | grad_norm += torch.sqrt(torch.sum(torch.square(d_p))) 560 | 561 | for p in group['params']: 562 | if p.grad is None: 563 | continue 564 | grad = p.grad.data 565 | if not group['synced']: 566 | grad_norm = grad.norm() 567 | if grad.is_sparse: 568 | raise RuntimeError( 569 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 570 | amsgrad = group['amsgrad'] 571 | kappa = group['kappa'] 572 | decay = group['decay'] 573 | topc = group['topC'] 574 | aggr = group['aggr'] 575 | 576 | state = self.state[p] 577 | 578 | # State initialization 579 | if len(state) == 0: 580 | state['step'] = 0 581 | # Exponential moving average of gradient values 582 | state['exp_avg'] = torch.zeros_like(p.data) 583 | # Exponential moving average of squared gradient values 584 | state['exp_avg_sq'] = torch.zeros_like(p.data) 585 | if kappa > 0.: 586 | state['critical gradients'] = PriorityDict() 587 | state['critical gradients'].setHyper(decay_rate=decay, K=topc) 588 | state['critical gradients'][grad_norm] = deepcopy(grad) 589 | if amsgrad: 590 | # Maintains max of all exp. moving avg. of sq. grad. values 591 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 592 | else: 593 | if kappa > 0.: 594 | aggr_grad = aggregate(grad, state['critical gradients'], aggr) 595 | if state['critical gradients'].isFull(): 596 | if grad_norm > state['critical gradients'].pokeSmallest(): 597 | self.offline_grad['yes'] += 1 598 | state['critical gradients'][grad_norm] = deepcopy(grad) 599 | else: 600 | self.offline_grad['no'] += 1 601 | else: 602 | state['critical gradients'][grad_norm] = deepcopy(grad) 603 | grad = aggr_grad 604 | 605 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 606 | if amsgrad: 607 | max_exp_avg_sq = state['max_exp_avg_sq'] 608 | beta1, beta2 = group['betas'] 609 | 610 | state['step'] += 1 611 | bias_correction1 = 1 - beta1 ** state['step'] 612 | bias_correction2 = 1 - beta2 ** state['step'] 613 | 614 | if group['weight_decay'] != 0: 615 | grad = grad.add(group['weight_decay'], p.data) 616 | 617 | # Decay the first and second moment running average coefficient 618 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 619 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 620 | if amsgrad: 621 | # Maintains the maximum of all 2nd moment running avg. till now 622 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 623 | # Use the max. for normalizing running avg. of gradient 624 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 625 | group['eps']) 626 | else: 627 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 628 | group['eps']) 629 | 630 | step_size = group['lr'] / bias_correction1 631 | 632 | state['critical gradients'].decay() 633 | 634 | p.addcdiv_(exp_avg, denom, value=-step_size) 635 | 636 | return loss 637 | 638 | class AdamW_C(Optimizer): 639 | """ 640 | Implementation of AdamW with critical gradients. 641 | Replaces current-iteration gradient in conventional PyTorch implementation with 642 | an aggregation of current gradient and critical gradients. 643 | 644 | Conventional AdamW can be recovered by setting kappa=0. 
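    A minimal construction sketch (`model` is a placeholder; weight decay is
    applied in the decoupled AdamW style, and the buffer keywords keep their
    defaults, `topC=10` and `decay=0.7`):
        >>> optimizer = AdamW_C(model.parameters(), lr=1e-3, weight_decay=0.01)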
645 | 646 | The critical-gradient-specific keyword parameters are tuned for good 647 | off-the-shelf performance, though additional tuning may be required for best results 648 | """ 649 | 650 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 651 | decay=0.7, kappa=1.0, topC=10, 652 | weight_decay=0.01, amsgrad=False, aggr='mean', synced=True): 653 | if not 0.0 <= lr: 654 | raise ValueError("Invalid learning rate: {}".format(lr)) 655 | if not 0.0 <= eps: 656 | raise ValueError("Invalid epsilon value: {}".format(eps)) 657 | if not 0.0 <= betas[0] < 1.0: 658 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 659 | if not 0.0 <= betas[1] < 1.0: 660 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 661 | if not 0.0 <= decay and not 1.0 > decay: 662 | raise ValueError("Invalid decay value: {}".format(decay)) 663 | if not 0.0 <= topC: 664 | raise ValueError("Invalid topC value: {}".format(topC)) 665 | defaults = dict(lr=lr, betas=betas, eps=eps, 666 | weight_decay=weight_decay, aggr=aggr, amsgrad=amsgrad, 667 | kappa=kappa, topC=topC, decay=decay, synced=synced) 668 | 669 | super(AdamW_C, self).__init__(params, defaults) 670 | self.resetOfflineStats() 671 | self.resetAnalysis() 672 | 673 | def getOfflineStats(self): 674 | return self.offline_grad 675 | 676 | def resetOfflineStats(self): 677 | self.offline_grad = {'yes': 0, 'no': 0} 678 | 679 | def __setstate__(self, state): 680 | super(AdamW_C, self).__setstate__(state) 681 | for group in self.param_groups: 682 | group.setdefault('amsgrad', False) 683 | 684 | def getAnalysis(self): 685 | return self.g_analysis 686 | 687 | def resetAnalysis(self): 688 | self.g_analysis = {'gt': 0., 'gc': 0., 'count': 0, 'gc_aggr': 0} 689 | 690 | @torch.no_grad() 691 | def step(self, closure=None): 692 | """Performs a single optimization step. 693 | Arguments: 694 | closure (callable, optional): A closure that reevaluates the model 695 | and returns the loss. 696 | """ 697 | loss = None 698 | if closure is not None: 699 | with torch.enable_grad(): 700 | loss = closure() 701 | 702 | for group in self.param_groups: 703 | 704 | grad_norm = 0.0 705 | 706 | if group['synced']: 707 | for p in group['params']: 708 | if p.grad is None: 709 | continue 710 | d_p = p.grad.data 711 | grad_norm += torch.sqrt(torch.sum(torch.square(d_p))) 712 | 713 | for p in group['params']: 714 | if p.grad is None: 715 | continue 716 | p.mul_(1 - group['lr'] * group['weight_decay']) 717 | grad = p.grad.data 718 | if not group['synced']: 719 | grad_norm = grad.norm() 720 | if grad.is_sparse: 721 | raise RuntimeError( 722 | 'AdamW does not support sparse gradients, please consider SparseAdamW instead') 723 | amsgrad = group['amsgrad'] 724 | kappa = group['kappa'] 725 | decay = group['decay'] 726 | topc = group['topC'] 727 | aggr = group['aggr'] 728 | 729 | state = self.state[p] 730 | 731 | # State initialization 732 | if len(state) == 0: 733 | state['step'] = 0 734 | # Exponential moving average of gradient values 735 | state['exp_avg'] = torch.zeros_like(p.data) 736 | # Exponential moving average of squared gradient values 737 | state['exp_avg_sq'] = torch.zeros_like(p.data) 738 | if kappa > 0.: 739 | state['critical gradients'] = PriorityDict() 740 | state['critical gradients'].setHyper(decay_rate=decay, K=topc) 741 | state['critical gradients'][grad_norm] = deepcopy(grad) 742 | if amsgrad: 743 | # Maintains max of all exp. moving avg. of sq. grad. 
values 744 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 745 | else: 746 | if kappa > 0.: 747 | aggr_grad = aggregate(grad, state['critical gradients'], aggr) 748 | if state['critical gradients'].isFull(): 749 | if grad_norm > state['critical gradients'].pokeSmallest(): 750 | self.offline_grad['yes'] += 1 751 | state['critical gradients'][grad_norm] = deepcopy(grad) 752 | else: 753 | self.offline_grad['no'] += 1 754 | else: 755 | state['critical gradients'][grad_norm] = deepcopy(grad) 756 | grad = aggr_grad 757 | 758 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 759 | if amsgrad: 760 | max_exp_avg_sq = state['max_exp_avg_sq'] 761 | beta1, beta2 = group['betas'] 762 | 763 | state['step'] += 1 764 | bias_correction1 = 1 - beta1 ** state['step'] 765 | bias_correction2 = 1 - beta2 ** state['step'] 766 | 767 | # if group['weight_decay'] != 0: 768 | # grad = grad.add(group['weight_decay'], p.data) 769 | 770 | # Decay the first and second moment running average coefficient 771 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 772 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 773 | if amsgrad: 774 | # Maintains the maximum of all 2nd moment running avg. till now 775 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 776 | # Use the max. for normalizing running avg. of gradient 777 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 778 | group['eps']) 779 | else: 780 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 781 | group['eps']) 782 | 783 | step_size = group['lr'] / bias_correction1 784 | 785 | state['critical gradients'].decay() 786 | 787 | p.addcdiv_(exp_avg, denom, value=-step_size) 788 | 789 | return loss 790 | 791 | class RMSprop(Optimizer): 792 | r"""Implements RMSprop algorithm. 793 | Proposed by G. Hinton in his 794 | `course `. 795 | The centered version first appears in `Generating Sequences 796 | With Recurrent Neural Networks `_. 797 | The implementation here takes the square root of the gradient average before 798 | adding epsilon (note that TensorFlow interchanges these two operations). The 799 | effective learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where 800 | :math:`\alpha` is the scheduled learning rate and :math:`v` is the weighted 801 | moving average of the squared gradient. 
802 | Args: 803 | params (iterable): iterable of parameters to optimize or dicts defining 804 | parameter groups 805 | lr (float, optional): learning rate (default: 1e-2) 806 | momentum (float, optional): momentum factor (default: 0) 807 | alpha (float, optional): smoothing constant (default: 0.99) 808 | eps (float, optional): term added to the denominator to improve 809 | numerical stability (default: 1e-8) 810 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 811 | the gradient is normalized by an estimation of its variance 812 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 813 | """ 814 | 815 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, 816 | momentum=0, centered=False): 817 | if not 0.0 <= lr: 818 | raise ValueError("Invalid learning rate: {}".format(lr)) 819 | if not 0.0 <= eps: 820 | raise ValueError("Invalid epsilon value: {}".format(eps)) 821 | if not 0.0 <= momentum: 822 | raise ValueError("Invalid momentum value: {}".format(momentum)) 823 | if not 0.0 <= weight_decay: 824 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 825 | if not 0.0 <= alpha: 826 | raise ValueError("Invalid alpha value: {}".format(alpha)) 827 | 828 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, 829 | centered=centered, weight_decay=weight_decay) 830 | super(RMSprop, self).__init__(params, defaults) 831 | self.resetOfflineStats() 832 | 833 | def __setstate__(self, state): 834 | super(RMSprop, self).__setstate__(state) 835 | for group in self.param_groups: 836 | group.setdefault('momentum', 0) 837 | group.setdefault('centered', False) 838 | 839 | @torch.no_grad() 840 | def step(self, closure=None): 841 | """Performs a single optimization step. 842 | Args: 843 | closure (callable, optional): A closure that reevaluates the model 844 | and returns the loss. 
845 | """ 846 | loss = None 847 | if closure is not None: 848 | with torch.enable_grad(): 849 | loss = closure() 850 | 851 | for group in self.param_groups: 852 | for p in group['params']: 853 | if p.grad is None: 854 | continue 855 | grad = p.grad 856 | if grad.is_sparse: 857 | raise RuntimeError('RMSprop does not support sparse gradients') 858 | state = self.state[p] 859 | 860 | # State initialization 861 | if len(state) == 0: 862 | state['step'] = 0 863 | state['square_avg'] = \ 864 | torch.zeros_like(p, memory_format=torch.preserve_format) 865 | if group['momentum'] > 0: 866 | state['momentum_buffer'] = \ 867 | torch.zeros_like(p, memory_format=torch.preserve_format) 868 | if group['centered']: 869 | state['grad_avg'] = \ 870 | torch.zeros_like(p, memory_format=torch.preserve_format) 871 | 872 | square_avg = state['square_avg'] 873 | alpha = group['alpha'] 874 | 875 | state['step'] += 1 876 | 877 | if group['weight_decay'] != 0: 878 | grad = grad.add(p, alpha=group['weight_decay']) 879 | 880 | square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) 881 | 882 | if group['centered']: 883 | grad_avg = state['grad_avg'] 884 | grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) 885 | avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_( 886 | group['eps']) 887 | else: 888 | avg = square_avg.sqrt().add_(group['eps']) 889 | 890 | if group['momentum'] > 0: 891 | buf = state['momentum_buffer'] 892 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 893 | p.add_(buf, alpha=-group['lr']) 894 | else: 895 | p.addcdiv_(grad, avg, value=-group['lr']) 896 | 897 | return loss 898 | 899 | def getOfflineStats(self): 900 | return self.offline_grad 901 | 902 | def resetOfflineStats(self): 903 | self.offline_grad = {'yes': 0, 'no': 0} 904 | 905 | 906 | class RMSprop_C(Optimizer): 907 | """ 908 | Implementation of RMSprop with critical gradients. 909 | Replaces current-iteration gradient in conventional PyTorch implementation with 910 | an aggregation of current gradient and critical gradients. 911 | 912 | Conventional RMSprop can be recovered by setting kappa=0. 
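    A minimal construction sketch (`model` is a placeholder; `aggr` defaults to
    'mean' and the buffer keywords keep their defaults, `topC=10` and `decay=0.7`):
        >>> optimizer = RMSprop_C(model.parameters(), lr=1e-2)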
913 | 914 | The critical-gradient-specific keyword parameters are tuned for good 915 | off-the-shelf performance, though additional tuning may be required for best results 916 | """ 917 | 918 | def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, 919 | momentum=0, centered=False, decay=0.7, kappa=1.0, 920 | topC=10, aggr='mean', synced=True): 921 | if not 0.0 <= lr: 922 | raise ValueError("Invalid learning rate: {}".format(lr)) 923 | if not 0.0 <= eps: 924 | raise ValueError("Invalid epsilon value: {}".format(eps)) 925 | if not 0.0 <= momentum: 926 | raise ValueError("Invalid momentum value: {}".format(momentum)) 927 | if not 0.0 <= weight_decay: 928 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 929 | if not 0.0 <= decay and not 1.0 > decay: 930 | raise ValueError("Invalid decay value: {}".format(decay)) 931 | if not 0.0 <= topC: 932 | raise ValueError("Invalid topC value: {}".format(topC)) 933 | 934 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, 935 | centered=centered, weight_decay=weight_decay, 936 | aggr=aggr, kappa=kappa, topC=topC, decay=decay, synced=synced) 937 | super(RMSprop_C, self).__init__(params, defaults) 938 | self.resetOfflineStats() 939 | 940 | def __setstate__(self, state): 941 | super(RMSprop_C, self).__setstate__(state) 942 | for group in self.param_groups: 943 | group.setdefault('momentum', 0) 944 | group.setdefault('centered', False) 945 | 946 | @torch.no_grad() 947 | def step(self, closure=None): 948 | """Performs a single optimization step. 949 | Args: 950 | closure (callable, optional): A closure that reevaluates the model 951 | and returns the loss. 952 | """ 953 | loss = None 954 | if closure is not None: 955 | with torch.enable_grad(): 956 | loss = closure() 957 | 958 | for group in self.param_groups: 959 | 960 | grad_norm = 0.0 961 | 962 | if group['synced']: 963 | for p in group['params']: 964 | if p.grad is None: 965 | continue 966 | d_p = p.grad.data 967 | grad_norm += torch.sqrt(torch.sum(torch.square(d_p))) 968 | 969 | for p in group['params']: 970 | if p.grad is None: 971 | continue 972 | grad = p.grad.data 973 | if not group['synced']: 974 | grad_norm = grad.norm() 975 | if grad.is_sparse: 976 | raise RuntimeError('RMSprop does not support sparse gradients') 977 | kappa = group['kappa'] 978 | decay = group['decay'] 979 | topc = group['topC'] 980 | aggr = group['aggr'] 981 | state = self.state[p] 982 | 983 | # State initialization 984 | if len(state) == 0: 985 | state['step'] = 0 986 | state['square_avg'] = \ 987 | torch.zeros_like(p, memory_format=torch.preserve_format) 988 | if group['momentum'] > 0: 989 | state['momentum_buffer'] = \ 990 | torch.zeros_like(p, memory_format=torch.preserve_format) 991 | if group['centered']: 992 | state['grad_avg'] = \ 993 | torch.zeros_like(p, memory_format=torch.preserve_format) 994 | if kappa > 0.: 995 | state['critical gradients'] = PriorityDict() 996 | state['critical gradients'].setHyper(decay_rate=decay, K=topc) 997 | state['critical gradients'][grad_norm] = deepcopy(grad) 998 | else: 999 | aggr_grad = aggregate(grad, state['critical gradients'], aggr) 1000 | if kappa > 0.: 1001 | if state['critical gradients'].isFull(): 1002 | if grad_norm > state['critical gradients'].pokeSmallest(): 1003 | self.offline_grad['yes'] += 1 1004 | state['critical gradients'][grad_norm] = deepcopy(grad) 1005 | else: 1006 | self.offline_grad['no'] += 1 1007 | else: 1008 | state['critical gradients'][grad_norm] = deepcopy(grad) 1009 | grad = aggr_grad 1010 | 1011 | 
square_avg = state['square_avg'] 1012 | alpha = group['alpha'] 1013 | 1014 | state['step'] += 1 1015 | 1016 | if group['weight_decay'] != 0: 1017 | grad = grad.add(p, alpha=group['weight_decay']) 1018 | 1019 | square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) 1020 | 1021 | if group['centered']: 1022 | grad_avg = state['grad_avg'] 1023 | grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) 1024 | avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_( 1025 | group['eps']) 1026 | else: 1027 | avg = square_avg.sqrt().add_(group['eps']) 1028 | 1029 | state['critical gradients'].decay() 1030 | 1031 | if group['momentum'] > 0: 1032 | buf = state['momentum_buffer'] 1033 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 1034 | p.add_(buf, alpha=-group['lr']) 1035 | else: 1036 | p.addcdiv_(grad, avg, value=-group['lr']) 1037 | 1038 | return loss 1039 | 1040 | def getOfflineStats(self): 1041 | return self.offline_grad 1042 | 1043 | def resetOfflineStats(self): 1044 | self.offline_grad = {'yes': 0, 'no': 0} 1045 | -------------------------------------------------------------------------------- /cgoptimizer/optim_eff.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementations for _C_eff enhanced optimizers as well as their vanilla counterparts. 3 | Vanilla algorthims are sourced from PyTorch source, and _C_eff iterations are largely 4 | based on those as well. 5 | 6 | https://github.com/pytorch/pytorch/tree/master/torch/optim 7 | """ 8 | 9 | import math 10 | from copy import deepcopy 11 | 12 | import torch 13 | from torch.optim import Optimizer 14 | 15 | from .priority_dict import TensorList 16 | 17 | 18 | class SGD_C_eff(Optimizer): 19 | """ 20 | Efficient GPU-based implementation of SGD (and optionally SGD with momentum) with critical gradients. 21 | Replaces current-iteration gradient in conventional PyTorch implementation with 22 | an aggregation of current gradient and critical gradients. 23 | 24 | Conventional SGD or SGD with momentum can be recovered by setting kappa=0. 
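    A minimal construction sketch (`model` is a placeholder; the _eff defaults
    `topC=20` and `decay=0.99` are kept, and an optional `buffer_dtype` keyword
    is also accepted for the critical-gradient buffer):
        >>> optimizer = SGD_C_eff(model.parameters(), lr=0.001, momentum=0.9)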
25 | 26 | The critical-gradient-specific keyword parameters are tuned for good 27 | off-the-shelf performance, though additional tuning may be required for best results 28 | """ 29 | 30 | def __init__(self, params, lr=0.001, kappa=1.0, dampening=0., 31 | weight_decay=0, momentum=0., 32 | decay=0.99, topC=20, aggr='sum', 33 | synced=True, buffer_dtype=None): 34 | 35 | if momentum < 0.0: 36 | raise ValueError("Invalid momentum value: {}".format(momentum)) 37 | if weight_decay < 0.0: 38 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 39 | if not 0.0 <= decay and not 1.0 > decay: 40 | raise ValueError("Invalid decay value: {}".format(decay)) 41 | if not 0.0 <= topC: 42 | raise ValueError("Invalid topC value: {}".format(topC)) 43 | 44 | defaults = dict(lr=lr, kappa=kappa, dampening=dampening, 45 | weight_decay=weight_decay, momentum=momentum, aggr=aggr, 46 | decay=decay, gradHist={}, topC=topC, 47 | synced=synced, buffer_dtype=buffer_dtype) 48 | 49 | super(SGD_C_eff, self).__init__(params, defaults) 50 | self.resetOfflineStats() 51 | self.resetAnalysis() 52 | 53 | def getOfflineStats(self): 54 | return self.offline_grad 55 | 56 | def getAnalysis(self): 57 | return self.g_analysis 58 | 59 | def resetAnalysis(self): 60 | self.g_analysis = {'gt': 0., 'gc': 0., 'count': 0, 'gc_aggr': 0} 61 | 62 | def resetOfflineStats(self): 63 | self.offline_grad = {'yes': 0, 'no': 0} 64 | 65 | def __setstate__(self, state): 66 | super(SGD_C_eff, self).__setstate__(state) 67 | 68 | def step(self, closure=None): 69 | """Performs a single optimization step. 70 | Arguments: 71 | closure (callable, optional): A closure that reevaluates the model 72 | and returns the loss. 73 | """ 74 | loss = None 75 | if closure is not None: 76 | with torch.enable_grad(): 77 | loss = closure() 78 | 79 | for group in self.param_groups: 80 | weight_decay = group['weight_decay'] 81 | kappa = group['kappa'] 82 | dampening = group['dampening'] 83 | decay = group['decay'] 84 | momentum = group['momentum'] 85 | topc = group['topC'] 86 | aggr = group['aggr'] 87 | 88 | synced = group['synced'] 89 | buffer_dtype = group['buffer_dtype'] 90 | 91 | d_p_norm = 0.0 92 | 93 | if synced: 94 | for p in group['params']: 95 | if p.grad is None: 96 | continue 97 | d_p = p.grad.data 98 | d_p_norm += torch.sqrt(torch.sum(torch.square(d_p))) 99 | 100 | for p in group['params']: 101 | if p.grad is None: 102 | continue 103 | d_p = p.grad.data 104 | if not synced: 105 | d_p_norm = d_p.norm() 106 | if weight_decay != 0: 107 | d_p = d_p.add(weight_decay, p.data) 108 | if kappa != 0: 109 | param_state = self.state[p] 110 | if 'critical gradients' not in param_state or len(param_state['critical gradients'])==0: 111 | crit_buf = param_state['critical gradients'] = TensorList() 112 | crit_buf.setHyper(decay_rate=decay, K=topc, dtype=buffer_dtype) 113 | crit_buf.addItem(d_p_norm, deepcopy(d_p)) 114 | else: 115 | crit_buf = param_state['critical gradients'] 116 | aggr_mean = crit_buf.aggr_sum.div(crit_buf.size()) 117 | aggr_grad = torch.add(d_p, aggr_mean) 118 | if crit_buf.isFull(): 119 | if d_p_norm > crit_buf.pokeSmallest(): 120 | self.offline_grad['yes'] += 1 121 | crit_buf.addItem(d_p_norm, deepcopy(d_p)) 122 | else: 123 | self.offline_grad['no'] += 1 124 | else: 125 | crit_buf.addItem(d_p_norm, deepcopy(d_p)) 126 | d_p = aggr_grad 127 | 128 | crit_buf.decay() 129 | 130 | if momentum != 0: 131 | param_state = self.state[p] 132 | if 'momentum_buffer' not in param_state: 133 | buf = param_state['momentum_buffer'] = torch.clone( 134 | 
d_p).detach() 135 | else: 136 | buf = param_state['momentum_buffer'] 137 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 138 | d_p = buf 139 | 140 | p.data.add_(d_p, alpha=-group['lr']) 141 | 142 | return loss 143 | 144 | 145 | class Adam_C_eff(Optimizer): 146 | """ 147 | Efficient GPU-based implementation of Adam with critical gradients. 148 | Replaces current-iteration gradient in conventional PyTorch implementation with 149 | an aggregation of current gradient and critical gradients. 150 | 151 | Conventional Adam can be recovered by setting kappa=0. 152 | 153 | The critical-gradient-specific keyword parameters are tuned for good 154 | off-the-shelf performance, though additional tuning may be required for best results 155 | """ 156 | 157 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 158 | decay=0.99, kappa=1.0, topC=20, 159 | weight_decay=0, amsgrad=False, aggr='mean', synced=True, buffer_dtype=None): 160 | if not 0.0 <= lr: 161 | raise ValueError("Invalid learning rate: {}".format(lr)) 162 | if not 0.0 <= eps: 163 | raise ValueError("Invalid epsilon value: {}".format(eps)) 164 | if not 0.0 <= betas[0] < 1.0: 165 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 166 | if not 0.0 <= betas[1] < 1.0: 167 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 168 | if not 0.0 <= decay and not 1.0 > decay: 169 | raise ValueError("Invalid decay value: {}".format(decay)) 170 | if not 0.0 <= topC: 171 | raise ValueError("Invalid topC value: {}".format(topC)) 172 | defaults = dict(lr=lr, betas=betas, eps=eps, 173 | weight_decay=weight_decay, aggr=aggr, amsgrad=amsgrad, 174 | kappa=kappa, topC=topC, decay=decay, 175 | synced=synced, buffer_dtype=buffer_dtype) 176 | 177 | super(Adam_C_eff, self).__init__(params, defaults) 178 | self.resetOfflineStats() 179 | self.resetAnalysis() 180 | 181 | def getOfflineStats(self): 182 | return self.offline_grad 183 | 184 | def resetOfflineStats(self): 185 | self.offline_grad = {'yes': 0, 'no': 0} 186 | 187 | def __setstate__(self, state): 188 | super(Adam_C_eff, self).__setstate__(state) 189 | for group in self.param_groups: 190 | group.setdefault('amsgrad', False) 191 | 192 | def getAnalysis(self): 193 | return self.g_analysis 194 | 195 | def resetAnalysis(self): 196 | self.g_analysis = {'gt': 0., 'gc': 0., 'count': 0, 'gc_aggr': 0} 197 | 198 | @torch.no_grad() 199 | def step(self, closure=None): 200 | """Performs a single optimization step. 201 | Arguments: 202 | closure (callable, optional): A closure that reevaluates the model 203 | and returns the loss. 
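        A minimal usage sketch (hypothetical `model`, `input`, `target`, and
        `loss_fn`; the _eff defaults `topC=20` and `decay=0.99` are kept):
            >>> optimizer = Adam_C_eff(model.parameters(), lr=1e-3)
            >>> optimizer.zero_grad()
            >>> loss_fn(model(input), target).backward()
            >>> optimizer.step()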
204 | """ 205 | loss = None 206 | if closure is not None: 207 | with torch.enable_grad(): 208 | loss = closure() 209 | 210 | for group in self.param_groups: 211 | 212 | grad_norm = 0.0 213 | 214 | if group['synced']: 215 | for p in group['params']: 216 | if p.grad is None: 217 | continue 218 | d_p = p.grad.data 219 | grad_norm += torch.sqrt(torch.sum(torch.square(d_p))) 220 | 221 | for p in group['params']: 222 | if p.grad is None: 223 | continue 224 | grad = p.grad.data 225 | if not group['synced']: 226 | grad_norm = grad.norm() 227 | if grad.is_sparse: 228 | raise RuntimeError( 229 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 230 | amsgrad = group['amsgrad'] 231 | kappa = group['kappa'] 232 | decay = group['decay'] 233 | topc = group['topC'] 234 | aggr = group['aggr'] 235 | buffer_dtype = group['buffer_dtype'] 236 | 237 | state = self.state[p] 238 | 239 | # State initialization 240 | if len(state) == 0 or ('critical gradients' not in state or len(state['critical gradients'])==0): 241 | state['step'] = 0 242 | # Exponential moving average of gradient values 243 | state['exp_avg'] = torch.zeros_like(p.data) 244 | # Exponential moving average of squared gradient values 245 | state['exp_avg_sq'] = torch.zeros_like(p.data) 246 | if kappa > 0.: 247 | state['critical gradients'] = TensorList() 248 | state['critical gradients'].setHyper(decay_rate=decay, K=topc, dtype=buffer_dtype) 249 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 250 | if amsgrad: 251 | # Maintains max of all exp. moving avg. of sq. grad. values 252 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 253 | else: 254 | if kappa > 0.: 255 | aggr_mean = state['critical gradients'].aggr_sum.div(state['critical gradients'].size()) 256 | aggr_grad = torch.add(grad, aggr_mean) 257 | if state['critical gradients'].isFull(): 258 | if grad_norm > state['critical gradients'].pokeSmallest(): 259 | self.offline_grad['yes'] += 1 260 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 261 | else: 262 | self.offline_grad['no'] += 1 263 | else: 264 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 265 | grad = aggr_grad 266 | 267 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 268 | if amsgrad: 269 | max_exp_avg_sq = state['max_exp_avg_sq'] 270 | beta1, beta2 = group['betas'] 271 | 272 | state['step'] += 1 273 | bias_correction1 = 1 - beta1 ** state['step'] 274 | bias_correction2 = 1 - beta2 ** state['step'] 275 | 276 | if group['weight_decay'] != 0: 277 | grad = grad.add(group['weight_decay'], p.data) 278 | 279 | # Decay the first and second moment running average coefficient 280 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 281 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 282 | if amsgrad: 283 | # Maintains the maximum of all 2nd moment running avg. till now 284 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 285 | # Use the max. for normalizing running avg. of gradient 286 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 287 | else: 288 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 289 | 290 | step_size = group['lr'] / bias_correction1 291 | 292 | state['critical gradients'].decay() 293 | 294 | p.addcdiv_(exp_avg, denom, value=-step_size) 295 | 296 | return loss 297 | 298 | 299 | class AdamW_C_eff(Optimizer): 300 | """ 301 | Efficient GPU-based implementation of Adam with critical gradients. 
302 | Replaces current-iteration gradient in conventional PyTorch implementation with 303 | an aggregation of current gradient and critical gradients. 304 | 305 | Conventional Adam can be recovered by setting kappa=0. 306 | 307 | The critical-gradient-specific keyword parameters are tuned for good 308 | off-the-shelf performance, though additional tuning may be required for best results 309 | """ 310 | 311 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 312 | decay=0.99, kappa=1.0, topC=20, 313 | weight_decay=0.01, amsgrad=False, aggr='mean', synced=True, buffer_dtype=None): 314 | if not 0.0 <= lr: 315 | raise ValueError("Invalid learning rate: {}".format(lr)) 316 | if not 0.0 <= eps: 317 | raise ValueError("Invalid epsilon value: {}".format(eps)) 318 | if not 0.0 <= betas[0] < 1.0: 319 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 320 | if not 0.0 <= betas[1] < 1.0: 321 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 322 | if not 0.0 <= decay and not 1.0 > decay: 323 | raise ValueError("Invalid decay value: {}".format(decay)) 324 | if not 0.0 <= topC: 325 | raise ValueError("Invalid topC value: {}".format(topC)) 326 | defaults = dict(lr=lr, betas=betas, eps=eps, 327 | weight_decay=weight_decay, aggr=aggr, amsgrad=amsgrad, 328 | kappa=kappa, topC=topC, decay=decay, 329 | synced=synced, buffer_dtype=buffer_dtype) 330 | 331 | super(AdamW_C_eff, self).__init__(params, defaults) 332 | self.resetOfflineStats() 333 | self.resetAnalysis() 334 | 335 | def getOfflineStats(self): 336 | return self.offline_grad 337 | 338 | def resetOfflineStats(self): 339 | self.offline_grad = {'yes': 0, 'no': 0} 340 | 341 | def __setstate__(self, state): 342 | super(AdamW_C_eff, self).__setstate__(state) 343 | for group in self.param_groups: 344 | group.setdefault('amsgrad', False) 345 | 346 | def getAnalysis(self): 347 | return self.g_analysis 348 | 349 | def resetAnalysis(self): 350 | self.g_analysis = {'gt': 0., 'gc': 0., 'count': 0, 'gc_aggr': 0} 351 | 352 | @torch.no_grad() 353 | def step(self, closure=None): 354 | """Performs a single optimization step. 355 | Arguments: 356 | closure (callable, optional): A closure that reevaluates the model 357 | and returns the loss. 
358 | """ 359 | loss = None 360 | if closure is not None: 361 | with torch.enable_grad(): 362 | loss = closure() 363 | 364 | for group in self.param_groups: 365 | 366 | grad_norm = 0.0 367 | 368 | if group['synced']: 369 | for p in group['params']: 370 | if p.grad is None: 371 | continue 372 | d_p = p.grad.data 373 | grad_norm += torch.sqrt(torch.sum(torch.square(d_p))) 374 | 375 | for p in group['params']: 376 | if p.grad is None: 377 | continue 378 | p.mul_(1 - group['lr'] * group['weight_decay']) 379 | grad = p.grad.data 380 | if not group['synced']: 381 | grad_norm = grad.norm() 382 | if grad.is_sparse: 383 | raise RuntimeError( 384 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 385 | amsgrad = group['amsgrad'] 386 | kappa = group['kappa'] 387 | decay = group['decay'] 388 | topc = group['topC'] 389 | aggr = group['aggr'] 390 | buffer_dtype = group['buffer_dtype'] 391 | 392 | state = self.state[p] 393 | 394 | # State initialization 395 | if len(state) == 0 or ('critical gradients' not in state or len(state['critical gradients'])==0): 396 | state['step'] = 0 397 | # Exponential moving average of gradient values 398 | state['exp_avg'] = torch.zeros_like(p.data) 399 | # Exponential moving average of squared gradient values 400 | state['exp_avg_sq'] = torch.zeros_like(p.data) 401 | if kappa > 0.: 402 | state['critical gradients'] = TensorList() 403 | state['critical gradients'].setHyper(decay_rate=decay, K=topc, dtype=buffer_dtype) 404 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 405 | if amsgrad: 406 | # Maintains max of all exp. moving avg. of sq. grad. values 407 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 408 | else: 409 | if kappa > 0.: 410 | aggr_mean = state['critical gradients'].aggr_sum.div(state['critical gradients'].size()) 411 | aggr_grad = torch.add(grad, aggr_mean) 412 | if state['critical gradients'].isFull(): 413 | if grad_norm > state['critical gradients'].pokeSmallest(): 414 | self.offline_grad['yes'] += 1 415 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 416 | else: 417 | self.offline_grad['no'] += 1 418 | else: 419 | state['critical gradients'].addItem(grad_norm, deepcopy(grad)) 420 | grad = aggr_grad 421 | 422 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 423 | if amsgrad: 424 | max_exp_avg_sq = state['max_exp_avg_sq'] 425 | beta1, beta2 = group['betas'] 426 | 427 | state['step'] += 1 428 | bias_correction1 = 1 - beta1 ** state['step'] 429 | bias_correction2 = 1 - beta2 ** state['step'] 430 | 431 | # Decay the first and second moment running average coefficient 432 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 433 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 434 | if amsgrad: 435 | # Maintains the maximum of all 2nd moment running avg. till now 436 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 437 | # Use the max. for normalizing running avg. of gradient 438 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 439 | else: 440 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 441 | 442 | step_size = group['lr'] / bias_correction1 443 | 444 | state['critical gradients'].decay() 445 | 446 | p.addcdiv_(exp_avg, denom, value=-step_size) 447 | 448 | return loss 449 | 450 | 451 | class RMSprop_C_eff(Optimizer): 452 | """ 453 | Efficient GPU-based implementation of RMSprop with critical gradients. 
454 |     Replaces current-iteration gradient in conventional PyTorch implementation with
455 |     an aggregation of current gradient and critical gradients.
456 | 
457 |     Conventional RMSprop can be recovered by setting kappa=0.
458 | 
459 |     The critical-gradient-specific keyword parameters are tuned for good
460 |     off-the-shelf performance, though additional tuning may be required for best results.
461 |     """
462 | 
463 |     def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0,
464 |                  momentum=0, centered=False, decay=0.99, kappa=1.0,
465 |                  topC=20, aggr='mean', synced=True, buffer_dtype=None):
466 |         if not 0.0 <= lr:
467 |             raise ValueError("Invalid learning rate: {}".format(lr))
468 |         if not 0.0 <= eps:
469 |             raise ValueError("Invalid epsilon value: {}".format(eps))
470 |         if not 0.0 <= momentum:
471 |             raise ValueError("Invalid momentum value: {}".format(momentum))
472 |         if not 0.0 <= weight_decay:
473 |             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
474 |         if not 0.0 <= decay < 1.0:
475 |             raise ValueError("Invalid decay value: {}".format(decay))
476 |         if not 0.0 <= topC:
477 |             raise ValueError("Invalid topC value: {}".format(topC))
478 | 
479 |         defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps,
480 |                         centered=centered, weight_decay=weight_decay,
481 |                         aggr=aggr, kappa=kappa, topC=topC, decay=decay,
482 |                         synced=synced, buffer_dtype=buffer_dtype)
483 |         super(RMSprop_C_eff, self).__init__(params, defaults)
484 |         self.resetOfflineStats()
485 | 
486 |     def __setstate__(self, state):
487 |         super(RMSprop_C_eff, self).__setstate__(state)
488 |         for group in self.param_groups:
489 |             group.setdefault('momentum', 0)
490 |             group.setdefault('centered', False)
491 | 
492 |     @torch.no_grad()
493 |     def step(self, closure=None):
494 |         """Performs a single optimization step.
495 |         Args:
496 |             closure (callable, optional): A closure that reevaluates the model
497 |             and returns the loss.
498 |         """
499 |         loss = None
500 |         if closure is not None:
501 |             with torch.enable_grad():
502 |                 loss = closure()
503 | 
504 |         for group in self.param_groups:
505 |             grad_norm = 0.0
506 | 
507 |             if group['synced']:
508 |                 for p in group['params']:
509 |                     if p.grad is None:
510 |                         continue
511 |                     d_p = p.grad.data
512 |                     grad_norm += torch.sqrt(torch.sum(torch.square(d_p)))
513 | 
514 |             for p in group['params']:
515 |                 if p.grad is None:
516 |                     continue
517 |                 grad = p.grad.data
518 |                 if not group['synced']:
519 |                     grad_norm = grad.norm()
520 |                 if grad.is_sparse:
521 |                     raise RuntimeError('RMSprop does not support sparse gradients')
522 |                 kappa = group['kappa']
523 |                 decay = group['decay']
524 |                 topc = group['topC']
525 |                 aggr = group['aggr']
526 |                 buffer_dtype = group['buffer_dtype']
527 | 
528 |                 state = self.state[p]
529 | 
530 |                 # State initialization (re-initialize the buffer only when critical gradients are in use)
531 |                 if len(state) == 0 or (kappa > 0. and ('critical gradients' not in state or len(state['critical gradients']) == 0)):
532 |                     state['step'] = 0
533 |                     state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
534 |                     if group['momentum'] > 0:
535 |                         state['momentum_buffer'] = torch.zeros_like(p, memory_format=torch.preserve_format)
536 |                     if group['centered']:
537 |                         state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
538 |                     if kappa > 0.:
539 |                         state['critical gradients'] = TensorList()
540 |                         state['critical gradients'].setHyper(decay_rate=decay, K=topc, dtype=buffer_dtype)
541 |                         state['critical gradients'].addItem(grad_norm, deepcopy(grad))
542 |                 else:
543 |                     if kappa > 0.:
544 |                         aggr_mean = state['critical gradients'].aggr_sum.div(state['critical gradients'].size())
545 |                         aggr_grad = torch.add(grad, aggr_mean)
546 |                         if state['critical gradients'].isFull():
547 |                             if grad_norm > state['critical gradients'].pokeSmallest():
548 |                                 self.offline_grad['yes'] += 1
549 |                                 state['critical gradients'].addItem(grad_norm, deepcopy(grad))
550 |                             else:
551 |                                 self.offline_grad['no'] += 1
552 |                         else:
553 |                             state['critical gradients'].addItem(grad_norm, deepcopy(grad))
554 |                         grad = aggr_grad
555 | 
556 |                 square_avg = state['square_avg']
557 |                 alpha = group['alpha']
558 | 
559 |                 state['step'] += 1
560 | 
561 |                 if group['weight_decay'] != 0:
562 |                     grad = grad.add(p, alpha=group['weight_decay'])
563 | 
564 |                 square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
565 | 
566 |                 if group['centered']:
567 |                     grad_avg = state['grad_avg']
568 |                     grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
569 |                     avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
570 |                 else:
571 |                     avg = square_avg.sqrt().add_(group['eps'])
572 | 
573 |                 if kappa > 0.:
574 |                     state['critical gradients'].decay()
575 |                 if group['momentum'] > 0:
576 |                     buf = state['momentum_buffer']
577 |                     buf.mul_(group['momentum']).addcdiv_(grad, avg)
578 |                     p.add_(buf, alpha=-group['lr'])
579 |                 else:
580 |                     p.addcdiv_(grad, avg, value=-group['lr'])
581 | 
582 |         return loss
583 | 
584 |     def getOfflineStats(self):
585 |         return self.offline_grad
586 | 
587 |     def resetOfflineStats(self):
588 |         self.offline_grad = {'yes': 0, 'no': 0}
589 | 
--------------------------------------------------------------------------------
/cgoptimizer/priority_dict.py:
--------------------------------------------------------------------------------
1 | from heapq import heapify
2 | import torch
3 | from copy import deepcopy
4 | 
5 | 
6 | class HeapItem:
7 |     def __init__(self, p, t):
8 |         self.p = p
9 |         self.t = t
10 | 
11 |     def __lt__(self, other):
12 |         return self.p < other.p
13 | 
14 | 
15 | class TensorList(dict):
16 |     """Fixed-size tensor buffer that can be used as a priority queue.
17 | 
18 |     The 'pokeSmallest' method can be used to return the lowest stored
19 |     gradient norm.
20 | 
21 |     """
22 | 
23 |     def __init__(self, *args, **kwargs):
24 |         super(TensorList, self).__init__(*args, **kwargs)
25 |         self.aggr_sum = None
26 |         # self.aggr_sq_sum = None
27 |         self.smallest = 0
28 | 
29 |     def getNorms(self):
30 |         return self._heap_key
31 | 
32 |     def size(self):
33 |         return self.curr_k
34 | 
35 |     def setHyper(self, decay_rate=0.5, K=5, dtype=None):
36 |         self.k = K
37 |         self.curr_k = 0
38 |         self.decay_rate = decay_rate
39 |         self.dtype = dtype
40 | 
41 |     def addItem(self, key, val):
42 |         if self.dtype is not None:
43 |             val = val.to(dtype=self.dtype)
44 |         if self.isFull():
45 |             self.aggr_sum.add_(-self._heap[self.smallest])
46 |             self._heap_key[self.smallest] = key
47 |             self._heap[self.smallest] = val
48 |         else:
49 |             if self.curr_k == 0:
50 |                 self._heap_key = torch.zeros(self.k, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=val.dtype)
51 |                 self._heap = torch.zeros(self.k, *val.shape, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=val.dtype)
52 |             self._heap_key[self.curr_k] = key
53 |             self._heap[self.curr_k] = val
54 |             self.curr_k += 1
55 | 
56 |         if self.aggr_sum is None:
57 |             self.aggr_sum = torch.zeros_like(self._heap[0], device='cuda' if torch.cuda.is_available() else 'cpu')
58 |             # self.aggr_sq_sum = torch.zeros_like(val)
59 |         self.aggr_sum.add_(val)
60 |         # self.aggr_sq_sum.addcmul_(val, val)
61 | 
62 |     def pokeSmallest(self):
63 |         """Return the lowest priority.
64 |         Raises IndexError if the object is empty.
65 |         """
66 |         self.smallest = torch.argmin(self._heap_key)
67 |         return self._heap_key[self.smallest]
68 | 
69 |     def isEmpty(self):
70 |         return self.curr_k == 0
71 | 
72 |     def decay(self):
73 |         self._heap_key = torch.mul(self._heap_key, self.decay_rate)
74 | 
75 |     def isFull(self):
76 |         return self.curr_k == self.k  # len(self._heap) >= self.k
77 | 
78 |     def averageTopC(self):
79 |         ave = 0.
80 |         if self.curr_k > 0:
81 |             ave = sum(it.norm() for it in self._heap[:self.curr_k]) / float(self.curr_k)
82 |         return ave
83 | 
84 |     def getMin(self):
85 |         """
86 |         Get smallest gradient
87 |         :return: The smallest gradient
88 |         """
89 |         return self._heap[self.smallest]
90 | 
91 |     def getMax(self):
92 |         "Returns the largest gradient"
93 |         return self._heap[torch.argmax(self._heap_key)]
94 | 
95 |     def __getitem__(self, key):
96 |         return self._heap[self._heap_key == key]
97 | 
98 |     def __len__(self):
99 |         return self.curr_k
100 | 
101 |     def setdefault(self, key, val):
102 |         if key not in self:
103 |             self[key] = val
104 |             return val
105 |         return self[key]
106 | 
107 |     def step(self):
108 |         for item in self._heap: item.step()
109 | 
110 |     def epoch(self):
111 |         ages = []
112 |         for item in self._heap:
113 |             ages.append(item.epoch_age)
114 |             item.resetEpoch()
115 |         return ages
116 | 
117 | 
118 | 
119 | 
120 | class PriorityDict(dict):
121 |     """Dictionary that can be used as a priority queue.
122 | 
123 |     Keys of the dictionary are items to be put into the queue, and values
124 |     are their respective priorities. All dictionary methods work as expected.
125 |     The advantage over a standard heapq-based priority queue is
126 |     that priorities of items can be efficiently updated (amortized O(1))
127 |     using code as 'thedict[item] = new_priority.'
128 | 
129 |     The 'pokeSmallest' method can be used to return the lowest
130 |     priority without removing it.
131 | 
132 |     The 'sorted_iter' method provides a destructive sorted iterator.
133 | """ 134 | 135 | def __init__(self, *args, **kwargs): 136 | super(PriorityDict, self).__init__(*args, **kwargs) 137 | self._heap = [HeapItem(k, v) for k, v in self.items()] 138 | self._rebuild_heap() 139 | 140 | def getNorms(self): 141 | return [item.p for item in self._heap] 142 | 143 | def size(self): 144 | return len(self._heap) 145 | 146 | def setHyper(self, decay_rate=0.5, K=5): 147 | self.k = K 148 | self.decay_rate = decay_rate 149 | 150 | def _reorder(self): 151 | self._heap = deepcopy(self._heap[-self.k:]) 152 | in_heap = [it.p for it in self._heap] 153 | del_ = [k for k in self.keys() if k not in in_heap] 154 | for k in del_: 155 | del self[k] 156 | 157 | def _rebuild_heap(self): 158 | # >= used as fix for errors in some data 159 | self._heap = [it for it in self._heap if it.p >= 0.0] 160 | if len(self._heap) > 0: 161 | heapify(self._heap) 162 | if not self.isEmpty() and self.isFull(): 163 | self._reorder() 164 | 165 | def isEmpty(self): 166 | if len(self._heap) == 0: 167 | return True 168 | return False 169 | 170 | def decay(self): 171 | self._heap = [HeapItem(self.decay_rate * it.p, it.t) for it in self._heap] 172 | 173 | def isFull(self): 174 | if len(self._heap) < self.k: 175 | return False 176 | return True 177 | 178 | def averageTopC(self): 179 | ave = 0. 180 | if len(self._heap) > 0: 181 | ave = sum([it.t.norm() for it in self._heap]) / float(len(self._heap)) 182 | return ave 183 | 184 | def pokeSmallest(self): 185 | """Return the lowest priority. 186 | 187 | Raises IndexError if the object is empty. 188 | """ 189 | 190 | it = self._heap[0] 191 | return it.p 192 | 193 | def gradMean(self): 194 | """Return the sum of top k gradients 195 | """ 196 | 197 | mean = torch.clone(self._heap[0].t) 198 | cnt = 1. 199 | for it in self._heap[1:]: 200 | mean.add_(it.t) 201 | cnt += 1. 202 | return mean.div_(cnt) 203 | 204 | def gradSum(self): 205 | """Return the sum of top k gradients 206 | """ 207 | 208 | sum = torch.clone(self._heap[0].t) 209 | for it in self._heap[1:]: 210 | sum.add_(it.t) 211 | return sum 212 | 213 | def __getitem__(self, key): 214 | return dict(self._heap) 215 | 216 | def __len__(self): 217 | return len(self._heap) 218 | 219 | def __setitem__(self, key, val): 220 | # We are not going to remove the previous value from the heap, 221 | # since this would have a cost O(n). 222 | 223 | self._heap.append(HeapItem(key, val)) 224 | self._rebuild_heap() 225 | 226 | def setdefault(self, key, val): 227 | if key not in self: 228 | self[key] = val 229 | return val 230 | return self[key] 231 | 232 | def update(self, *args, **kwargs): 233 | # Reimplementing dict.update is tricky -- see e.g. 234 | # http://mail.python.org/pipermail/python-ideas/2007-May/000744.html 235 | # We just rebuild the heap from scratch after passing to super. 236 | 237 | super(PriorityDict, self).update(*args, **kwargs) 238 | self._rebuild_heap() 239 | 240 | def sorted_iter(self): 241 | """Sorted iterator of the priority dictionary items. 242 | 243 | Beware: this will destroy elements as they are returned. 
244 |         """
245 | 
246 |         while not self.isEmpty():
247 |             # Pop the current smallest HeapItem and re-heapify the remainder.
248 |             item = self._heap.pop(0)
249 |             heapify(self._heap)
250 |             yield item
251 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dataclasses==0.8
2 | numpy==1.19.1
3 | torch==1.7.1
4 | typing-extensions==3.7.4.3
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from setuptools import setup
3 | 
4 | HERE = pathlib.Path(__file__).parent
5 | 
6 | setup(name='cgoptimizer',
7 |       version='1.0.0',
8 |       description='Critical Gradient Optimizers',
9 |       url='https://github.com/chandar-lab/CGOptimizer',
10 |       author='Paul-Aymeric McRae, Prasanna Parthasarathi',
11 |       author_email='paul-aymeric.mcrae@mail.mcgill.ca',
12 |       license='MIT',
13 |       install_requires=[
14 |           'torch'
15 |       ],
16 |       packages=['cgoptimizer'],
17 |       zip_safe=False)
18 | 
--------------------------------------------------------------------------------
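
The `TensorList` buffer above is what the `_C_eff` optimizers drive from inside `step()`: a new gradient is admitted only if its norm beats the smallest stored norm, stored norms are decayed every step, and the buffered gradients are aggregated by their mean before being added to the current gradient. The following is a minimal sketch that exercises this buffer logic in isolation; it is illustrative only, with made-up hyperparameters (`K=3`, `decay_rate=0.7`) and random stand-in gradients rather than anything produced by a real model.

```
import torch
from cgoptimizer.priority_dict import TensorList

# TensorList allocates its storage on CUDA when available, so keep the toy
# gradients on the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

buf = TensorList()
buf.setHyper(decay_rate=0.7, K=3)  # keep the 3 largest-norm gradients

for _ in range(10):
    g = torch.randn(8, device=device)  # stand-in for a parameter gradient
    g_norm = g.norm()
    if buf.isFull():
        # Admit the new gradient only if it beats the smallest stored norm,
        # mirroring the test in the _C_eff step() methods.
        if g_norm > buf.pokeSmallest():
            buf.addItem(g_norm, g.clone())
    else:
        buf.addItem(g_norm, g.clone())
    buf.decay()  # shrink stored norms so stale gradients age out of the buffer

# Mean of the buffered gradients, the same aggregation the _C_eff optimizers
# add to the current gradient.
aggr_mean = buf.aggr_sum.div(buf.size())
print(buf.getNorms(), aggr_mean.shape)
```

In training code none of this is called directly: passing `topC`, `decay`, `kappa`, and optionally `buffer_dtype` to a `_C_eff` optimizer is enough, and `getOfflineStats()` reports how often incoming gradients displaced a buffered one (`'yes'`) versus were rejected (`'no'`).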