├── .gitignore
├── .pylintrc
├── LICENSE
├── README.md
├── images
│   ├── histogram.png
│   └── loss.png
├── pytorch_lamb
│   ├── __init__.py
│   └── lamb.py
├── setup.py
└── test_lamb.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode
runs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don’t work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
[TYPECHECK]

# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 cybertronai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Implementation of https://arxiv.org/abs/1904.00962 for large batch, large learning rate training.

The paper doesn't specify clamp values for ϕ, so I use 10.

Bonus: TensorboardX logging (example below).

## Try the sample
```
git clone git@github.com:cybertronai/pytorch-lamb.git
cd pytorch-lamb
pip install -e .
python test_lamb.py
tensorboard --logdir=runs
```
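
## Use in your own code
A minimal sketch of dropping `Lamb` into an existing training loop. `MyModel`, `train_loader`, and `loss_fn` are placeholders for your own model, data loader, and loss function:
```
import torch
from pytorch_lamb import Lamb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)  # placeholder: any nn.Module
optimizer = Lamb(model.parameters(), lr=0.02, weight_decay=0.01, betas=(.9, .999))

for data, target in train_loader:  # placeholder: your DataLoader
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)  # placeholder: your loss function
    loss.backward()
    optimizer.step()
```
Pass `adam=True` to force the trust ratio to 1 (plain Adam, handy for comparison runs), or set `min_trust` to put a lower bound on the trust ratio.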

## Sample results
At `--lr=.02`, the Adam optimizer is unable to train.

Red: `python test_lamb.py --batch-size=512 --lr=.02 --wd=.01 --log-interval=30 --optimizer=adam`

Blue: `python test_lamb.py --batch-size=512 --lr=.02 --wd=.01 --log-interval=30 --optimizer=lamb`
![](images/loss.png)

![](images/histogram.png)
--------------------------------------------------------------------------------
/images/histogram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Smerity/pytorch-lamb/704f733c83c18fc5f3c01f085b5beb38043b38af/images/histogram.png
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Smerity/pytorch-lamb/704f733c83c18fc5f3c01f085b5beb38043b38af/images/loss.png
--------------------------------------------------------------------------------
/pytorch_lamb/__init__.py:
--------------------------------------------------------------------------------
from .lamb import Lamb, log_lamb_rs
--------------------------------------------------------------------------------
/pytorch_lamb/lamb.py:
--------------------------------------------------------------------------------
"""Lamb optimizer."""

import collections
import math

import torch
from tensorboardX import SummaryWriter
from torch.optim import Optimizer


def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Log a histogram of the trust ratio scalars across layers."""
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
                if i in state:
                    results[i].append(state[i])

    for k, v in results.items():
        event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)

class Lamb(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
        min_trust (float, optional): lower bound on the trust ratio; smaller
            computed ratios are clamped up to this value (default: None)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False, min_trust=None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if min_trust and not 0.0 <= min_trust < 1.0:
            raise ValueError("min_trust must be in the range [0, 1): {}".format(min_trust))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        self.min_trust = min_trust
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # If debiasing were enabled, it would be folded into the scalar lr
                # below to avoid a broadcast.
                step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1

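                # LAMB rescales the Adam-style step with a per-layer trust ratio:
                #   weight_norm = ||p||_2, clamped to [0, 10] (the paper gives no
                #                 explicit clamp value; see the README)
                #   adam_norm   = ||adam_step||_2
                #   trust_ratio = weight_norm / adam_norm  (1 if either norm is 0,
                #                 floored at min_trust, forced to 1 when adam=True)
                # The update applied below is p <- p - lr * trust_ratio * adam_step.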
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(group['weight_decay'], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                if self.min_trust:
                    trust_ratio = max(trust_ratio, self.min_trust)
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name="pytorch_lamb",
    author="Ben Mann",
    version='1.0.0',
    author_email="me@benjmann.net",
    packages=find_packages(exclude=["*.tests", "*.tests.*",
                                    "tests.*", "tests"]),
    long_description=open("README.md", "r", encoding='utf-8').read(),
    long_description_content_type="text/markdown",
    license='MIT',
    url="https://github.com/cybertronai/pytorch-lamb",
    install_requires=[
        'torch>=0.4.1',
        'tqdm',
        'tensorboardX',
        'torchvision',
    ],
)
--------------------------------------------------------------------------------
/test_lamb.py:
--------------------------------------------------------------------------------
"""MNIST example.

Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
"""

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from tensorboardX import SummaryWriter
from torchvision import datasets, transforms
from pytorch_lamb import Lamb, log_lamb_rs

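# Shape check for Net.forward (MNIST inputs are 1x28x28):
#   conv1 (5x5, stride 1, no padding) -> 20x24x24, max_pool2d(2, 2) -> 20x12x12,
#   conv2 (5x5)                       -> 50x8x8,   max_pool2d(2, 2) -> 50x4x4,
# so fc1 sees 4*4*50 = 800 flattened features.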
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch, event_writer):
    model.train()
    tqdm_bar = tqdm.tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(tqdm_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            step = batch_idx * len(data) + (epoch-1) * len(train_loader.dataset)
            log_lamb_rs(optimizer, event_writer, step)
            event_writer.add_scalar('loss', loss.item(), step)
            tqdm_bar.set_description(
                f'Train epoch {epoch} Loss: {loss.item():.6f}')

def test(args, model, device, test_loader, event_writer: SummaryWriter, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    event_writer.add_scalar('loss/test_loss', test_loss, epoch - 1)
    event_writer.add_scalar('loss/test_acc', acc, epoch - 1)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * acc))

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--optimizer', type=str, default='lamb', choices=['lamb', 'adam'],
                        help='which optimizer to use')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=6, metavar='N',
                        help='number of epochs to train (default: 6)')
    parser.add_argument('--lr', type=float, default=0.0025, metavar='LR',
                        help='learning rate (default: 0.0025)')
    parser.add_argument('--wd', type=float, default=0.01, metavar='WD',
                        help='weight decay (default: 0.01)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)


    model = Net().to(device)
    optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd, betas=(.9, .999), adam=(args.optimizer == 'adam'))
    writer = SummaryWriter()
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, writer)
        test(args, model, device, test_loader, writer, epoch)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------