├── .gitignore
├── __init__.py
├── LICENSE
├── README.md
└── scheduler.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .scheduler import *
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Tim Esler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# lr-momentum-scheduler

This repo contains PyTorch scheduler classes that implement the following:

* Arbitrary LR and momentum schedulers
* A lambda-function-based scheduler modeled on `lr_scheduler.LambdaLR`
* A list-based scheduler that accepts explicitly defined schedule lists for LR and momentum
* A learning rate range finder for preparing the 1cycle policy
* The 1cycle policy scheduler

These classes inherit from, and are based on, the core learning rate schedulers included in PyTorch, and can be used in the same way, with the added ability to schedule momentum.

## Schedulers

See detailed documentation and implementation by running:

```python
import scheduler
help(scheduler.LambdaScheduler)
help(scheduler.ListScheduler)
help(scheduler.RangeFinder)
help(scheduler.OneCyclePolicy)
```

1. `LambdaScheduler`: based on PyTorch's `LambdaLR`, but can also (optionally) schedule momentum in the same way. Note that, like `LambdaLR`, individual schedules can be defined for each parameter group in the optimizer by passing a list of lambdas/functions/callables for LR and momentum.
1. `ListScheduler`: similar to `LambdaScheduler`, but defines LR and momentum using explicit lists (see the example below). Per-parameter-group schedules are specified using lists of lists or 2D numpy arrays.
1. `RangeFinder`: a simple predefined schedule that increases LR logarithmically from 1e-7 to 1 over a given number of epochs. This is a preparatory step for the One Cycle Policy.
1. `OneCyclePolicy`: the One Cycle Policy scheduler for LR and momentum; see [References](#references).
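
As a minimal sketch of the list-based interface (the model, data, and single-tensor "training" step below are placeholders, not part of this repo), a `ListScheduler` can ramp the LR up while annealing momentum down:

```python
import numpy as np
import torch
from torch import nn, optim
from scheduler import ListScheduler

epochs = 100
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

# The optimizer's lr/momentum values are overridden by the schedules below
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

lrs = np.linspace(1e-3, 1e-1, epochs)        # LR warms up linearly
momentums = np.linspace(0.95, 0.85, epochs)  # momentum anneals down
scheduler = ListScheduler(optimizer, lrs, momentums)

for epoch in range(epochs):
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
```

Passing a list of such arrays (or a 2D array) instead defines one schedule per parameter group, as described in the `ListScheduler` docstring.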

## The One Cycle Policy

1. Import modules and define some test data:
    ```python
    import torch
    from torch import nn
    from torch import optim
    from scheduler import *

    epochs = 50
    x = torch.randn(100, 10)
    ```
1. Instantiate model:
    ```python
    mdl = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 1),
        nn.Sigmoid()
    )
    ```
1. Run range test to find a suitable LR:
    ```python
    optimizer = optim.SGD(mdl.parameters(), lr=1.23e-4)  # optimizer LR is ignored
    range_finder = RangeFinder(optimizer, epochs)

    losses = []
    for epoch in range(epochs):
        # Print achieved schedule
        current_lr = [g['lr'] for g in optimizer.param_groups]
        current_mom = [g['momentum'] for g in optimizer.param_groups]
        print('LR: {}, Momentum: {}'.format(current_lr, current_mom))

        loss = mdl(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        range_finder.step()
        losses.append(loss.item())
    ```
    Based on the results above (one way to inspect them is sketched after these steps), let's say the max LR is 1e-2.
1. Re-instantiate model:
    ```python
    mdl = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 1),
        nn.Sigmoid()
    )
    ```
1. Define the optimizer and 1cycle policy scheduler:
    ```python
    optimizer = optim.SGD(mdl.parameters(), lr=1.23e-4)  # optimizer LR is ignored
    one_cycle = OneCyclePolicy(optimizer, 1e-2, epochs, momentum_rng=[0.85, 0.95])
    ```
1. Train model:
    ```python
    losses = []
    for epoch in range(epochs):
        # Print achieved schedule
        current_lr = [g['lr'] for g in optimizer.param_groups]
        current_mom = [g['momentum'] for g in optimizer.param_groups]
        print('LR: {}, Momentum: {}'.format(current_lr, current_mom))

        loss = mdl(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        one_cycle.step()
        losses.append(loss.item())
    ```
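
To pick the maximum LR from the range test, it can help to plot the `losses` recorded in the range-test loop against the learning rates that `RangeFinder` swept. This is only a sketch; `matplotlib` is assumed to be installed and is not a dependency of this repo:

```python
import numpy as np
import matplotlib.pyplot as plt  # assumed extra dependency

lr_grid = np.logspace(-7, 0, epochs)  # the LR values RangeFinder steps through
plt.semilogx(lr_grid, losses)
plt.xlabel('Learning rate')
plt.ylabel('Loss')
plt.show()
```

A common rule of thumb is to choose the maximum LR a little below the point where the loss stops decreasing.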

## References

* _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay_. Leslie N. Smith, 2018, arXiv:1803.09820.

--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
import types

import torch
from torch.optim import SGD, lr_scheduler
import numpy as np


class _LRMomentumScheduler(lr_scheduler._LRScheduler):
    """Base class for schedulers that set both the learning rate and the momentum of each
    parameter group. Subclasses must implement get_lr() and get_momentum()."""

    def __init__(self, optimizer, last_epoch=-1):
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_momentum', group['momentum'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_momentum' not in group:
                    raise KeyError("param 'initial_momentum' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_momentums = list(map(lambda group: group['initial_momentum'], optimizer.param_groups))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        raise NotImplementedError

    def get_momentum(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr, momentum in zip(self.optimizer.param_groups, self.get_lr(), self.get_momentum()):
            param_group['lr'] = lr
            param_group['momentum'] = momentum


def apply_lambda(last_epoch, bases, lambdas):
    """Multiply each base value by its schedule function evaluated at the current epoch."""
    return [base * lmbda(last_epoch) for lmbda, base in zip(lambdas, bases)]


class LambdaScheduler(_LRMomentumScheduler):
    """Sets the learning rate and momentum of each parameter group to the initial lr and momentum
    times a given function. When last_epoch=-1, sets initial lr and momentum to the optimizer
    values.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
            Default: lambda x: 1 (lr is left at the optimizer value).
        momentum_lambda (function or list): As for lr_lambda but applied to momentum.
            Default: lambda x: 1 (momentum is left at the optimizer value).
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lr_lambda = [
        ...     lambda epoch: epoch // 30,
        ...     lambda epoch: 0.95 ** epoch
        ... ]
        >>> mom_lambda = [
        ...     lambda epoch: max(0, (50 - epoch) // 50),
        ...     lambda epoch: 0.99 ** epoch
        ... ]
        >>> scheduler = LambdaScheduler(optimizer, lr_lambda, mom_lambda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda=lambda x: 1, momentum_lambda=lambda x: 1, last_epoch=-1):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)

        if not isinstance(momentum_lambda, (list, tuple)):
            self.momentum_lambdas = [momentum_lambda] * len(optimizer.param_groups)
        else:
            if len(momentum_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} momentum_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(momentum_lambda)))
            self.momentum_lambdas = list(momentum_lambda)

        self.last_epoch = last_epoch
        super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate and momentum lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas', 'momentum_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
        state_dict['momentum_lambdas'] = [None] * len(self.momentum_lambdas)

        for idx, (lr_fn, mom_fn) in enumerate(zip(self.lr_lambdas, self.momentum_lambdas)):
            if not isinstance(lr_fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = lr_fn.__dict__.copy()
            if not isinstance(mom_fn, types.FunctionType):
                state_dict['momentum_lambdas'][idx] = mom_fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        momentum_lambdas = state_dict.pop('momentum_lambdas')
        self.__dict__.update(state_dict)

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

        for idx, fn in enumerate(momentum_lambdas):
            if fn is not None:
                self.momentum_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        return apply_lambda(self.last_epoch, self.base_lrs, self.lr_lambdas)

    def get_momentum(self):
        return apply_lambda(self.last_epoch, self.base_momentums, self.momentum_lambdas)


class ParameterUpdate(object):
    """A callable class used to define an arbitrary schedule defined by a list.
    This object is designed to be passed to the LambdaLR or LambdaScheduler scheduler to apply
    the given schedule. If base_param is zero, the schedule is ignored and the parameter is
    left unchanged.

    Arguments:
        params {list or numpy.array} -- List or numpy array defining parameter schedule.
        base_param {float} -- Parameter value used to initialize the optimizer.
    """

    def __init__(self, params, base_param):
        # Append a trailing 0 so that a final step() past the end of the schedule
        # does not raise an IndexError.
        self.params = np.hstack([params, 0])
        self.base_param = base_param

        if base_param < 1e-12:
            # Zero base parameter: fall back to a constant multiplier of 1 so the
            # parameter is left at its initial value.
            self.base_param = 1
            self.params = self.params * 0.0 + 1.0

    def __call__(self, epoch):
        return self.params[epoch] / self.base_param


class ListScheduler(LambdaScheduler):
    """Sets the learning rate and momentum of each parameter group to values defined by lists.
    When last_epoch=-1, sets initial lr and momentum to the optimizer values. One or both of the
    lr and momentum schedules may be specified.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lrs (list or numpy.ndarray): A list of learning rates, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        momentums (list or numpy.ndarray): A list of momentums, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lrs = [
        ...     np.linspace(0.01, 0.1, 100),
        ...     np.logspace(-2, 0, 100)
        ... ]
        >>> momentums = [
        ...     np.linspace(0.85, 0.95, 100),
        ...     np.linspace(0.8, 0.99, 100)
        ... ]
        >>> scheduler = ListScheduler(optimizer, lrs, momentums)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lrs=None, momentums=None, last_epoch=-1):
        groups = optimizer.param_groups
        if lrs is None:
            lr_lambda = lambda x: 1
        else:
            lrs = np.array(lrs) if isinstance(lrs, (list, tuple)) else lrs
            if len(lrs.shape) == 1:
                lr_lambda = [ParameterUpdate(lrs, g['lr']) for g in groups]
            else:
                lr_lambda = [ParameterUpdate(l, g['lr']) for l, g in zip(lrs, groups)]

        if momentums is None:
            momentum_lambda = lambda x: 1
        else:
            momentums = np.array(momentums) if isinstance(momentums, (list, tuple)) else momentums
            if len(momentums.shape) == 1:
                momentum_lambda = [ParameterUpdate(momentums, g['momentum']) for g in groups]
            else:
                momentum_lambda = [ParameterUpdate(l, g['momentum']) for l, g in zip(momentums, groups)]
        super().__init__(optimizer, lr_lambda, momentum_lambda)


class RangeFinder(ListScheduler):
    """Scheduler class that implements the LR range search specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Logarithmically spaced learning rates from 1e-7 to 1 are swept. The number of increments in
    that range is determined by 'epochs'.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        epochs (int): Number of epochs over which to run the test.

    Example:
        >>> scheduler = RangeFinder(optimizer, 100)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, epochs):
        lrs = np.logspace(-7, 0, epochs)
        super().__init__(optimizer, lrs)


class OneCyclePolicy(ListScheduler):
    """Scheduler class that implements the 1cycle policy specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr (float or list): Maximum learning rate in range. If a list of values is passed, they
            should correspond to parameter groups.
        epochs (int): The number of epochs over which to apply the schedule.
        momentum_rng (list): Optional lower and upper momentum values (the two may be equal). Set
            to None to leave momentum unscheduled. If a list of lists is passed, they should
            correspond to parameter groups. Default: [0.85, 0.95].
        phase_ratio (float): Fraction of epochs used for the increasing and decreasing phases of
            the schedule. For example, if phase_ratio=0.45 and epochs=100, the learning rate will
            increase from lr/10 to lr over 45 epochs, then decrease back to lr/10 over 45 epochs,
            then decrease to lr/100 over the remaining 10 epochs. Default: 0.45.
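
    Example:
        >>> # Usage mirrors the other schedulers in this module (assuming a
        >>> # momentum-based optimizer such as SGD with momentum).
        >>> scheduler = OneCyclePolicy(optimizer, 1e-2, 100)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()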
    """

    def __init__(self, optimizer, lr, epochs, momentum_rng=[0.85, 0.95], phase_ratio=0.45):
        phase_epochs = int(phase_ratio * epochs)
        if isinstance(lr, (list, tuple)):
            lrs = [
                np.hstack([
                    np.linspace(l * 1e-1, l, phase_epochs),
                    np.linspace(l, l * 1e-1, phase_epochs),
                    np.linspace(l * 1e-1, l * 1e-2, epochs - 2 * phase_epochs),
                ]) for l in lr
            ]
        else:
            lrs = np.hstack([
                np.linspace(lr * 1e-1, lr, phase_epochs),
                np.linspace(lr, lr * 1e-1, phase_epochs),
                np.linspace(lr * 1e-1, lr * 1e-2, epochs - 2 * phase_epochs),
            ])

        if momentum_rng is not None:
            momentum_rng = np.array(momentum_rng)
            if len(momentum_rng.shape) == 2:
                for i, g in enumerate(optimizer.param_groups):
                    g['momentum'] = momentum_rng[i][1]
                momentums = [
                    np.hstack([
                        np.linspace(m[1], m[0], phase_epochs),
                        np.linspace(m[0], m[1], phase_epochs),
                        np.linspace(m[1], m[1], epochs - 2 * phase_epochs),
                    ]) for m in momentum_rng
                ]
            else:
                for i, g in enumerate(optimizer.param_groups):
                    g['momentum'] = momentum_rng[1]
                momentums = np.hstack([
                    np.linspace(momentum_rng[1], momentum_rng[0], phase_epochs),
                    np.linspace(momentum_rng[0], momentum_rng[1], phase_epochs),
                    np.linspace(momentum_rng[1], momentum_rng[1], epochs - 2 * phase_epochs),
                ])
        else:
            momentums = None

        super().__init__(optimizer, lrs, momentums)
--------------------------------------------------------------------------------