├── README.md
├── __init__.py
├── l4.py
└── mnist_example.py

/README.md:
--------------------------------------------------------------------------------
# l4-pytorch

This is a PyTorch implementation of
["L4: Practical loss-based stepsize adaptation for deep learning"](https://arxiv.org/abs/1802.05074) by [Michal Rolínek](https://scholar.google.de/citations?user=DVdSTFQAAAAJ&hl=en) and [Georg Martius](http://georg.playfulmachines.com/).

To install, put `l4.py` in your working directory.

```python
from l4 import L4

# ...

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
# wrap the original optimizer with L4
l4opt = L4(optimizer)

# ...

loss = F.nll_loss(output, target)
loss.backward()

# comment out the original optimizer step
# optimizer.step()

# make a step with the L4 optimizer; don't forget to pass the loss value
l4opt.step(loss)
```
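## How the step size is chosen

Instead of using a fixed learning rate, L4 takes the update direction `v` proposed by the wrapped
optimizer and rescales it so that a single (linearized) step would remove a fraction `alpha` of the
gap between the current loss and a running estimate of the minimal loss. A minimal sketch of the rule
implemented in `L4.step` (the helper name `l4_step_size` is only for illustration and is not part of
the module):

```python
def l4_step_size(loss, l_min, g_dot_v, alpha=0.15, gamma=0.9, eps=1e-12):
    """Step size eta = alpha * (L - gamma * L_min) / (<g, v> + eps).

    loss    -- current loss value L
    l_min   -- running estimate of the minimal loss L_min
    g_dot_v -- inner product of the bias-corrected gradient average g with the
               update direction v proposed by the wrapped optimizer
    """
    return alpha * (loss - gamma * l_min) / (g_dot_v + eps)
```

The parameters are then updated as `theta <- theta - eta * v`. The minimal loss estimate is slowly
forgotten (`L_min <- (1 + tau) * L_min` after every step), so the step size can grow again after
plateaus.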
A TensorFlow implementation can be found [here](https://github.com/martius-lab/l4-optimizer).

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iovdin/l4-pytorch/dcf71bb256ff78652c5701a790e5f38fff8fb179/__init__.py
--------------------------------------------------------------------------------
/l4.py:
--------------------------------------------------------------------------------
import math

import torch


class L4:
    """Implements L4: Practical loss-based stepsize adaptation for deep learning.

    Proposed by Michal Rolinek & Georg Martius in the
    `paper <https://arxiv.org/abs/1802.05074>`_.

    Arguments:
        optimizer: the optimizer to wrap with L4
        alpha (float, optional): scales the step size, recommended values are
            in the range (0.1, 0.3) (default: 0.15)
        gamma (float, optional): scales the minimal loss estimate (default: 0.9)
        tau (float, optional): forget rate of the minimal loss estimate (default: 1e-3)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-12)
    """

    def __init__(self, optimizer, alpha=0.15, gamma=0.9, tau=1e-3, eps=1e-12):
        # TODO: save and load state
        self.optimizer = optimizer
        self.state = dict(alpha=alpha, gamma=gamma, tau=tau, eps=eps)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self, loss):
        if loss is None:
            raise RuntimeError('L4: loss is required to step')

        if loss.item() < 0:
            raise RuntimeError('L4: loss must be non-negative')

        if math.isnan(loss.item()):
            return

        # original parameter values, restored after the trial step below
        originals = {}
        # decay rate of the running gradient average
        decay = 0.9

        state = self.state
        if 'step' not in state:
            state['step'] = 0

        state['step'] += 1
        # bias correction for the exponential moving average of the gradients
        correction_term = 1 - decay ** state['step']

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if p not in state:
                    state[p] = torch.zeros_like(p.grad.data)

                # exponential moving average of the gradients
                state[p].mul_(decay).add_(p.grad.data, alpha=1 - decay)

                if p not in originals:
                    originals[p] = torch.zeros_like(p.data)
                originals[p].copy_(p.data)
                # zero the parameters so that after optimizer.step() they hold
                # (minus) the update proposed by the wrapped optimizer
                p.data.zero_()

        if 'lmin' not in state:
            state['lmin'] = loss.item() * 0.75

        lmin = min(state['lmin'], loss.item())

        gamma = state['gamma']
        tau = state['tau']
        alpha = state['alpha']
        eps = state['eps']

        # trial step of the wrapped optimizer; each p.data now holds -v
        self.optimizer.step()

        # inner product <g, v> of the averaged gradient with the proposed update
        inner_prod = 0.0
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = state[p].div(correction_term)
                inner_prod += torch.dot(grad.view(-1), -p.data.view(-1)).item()

        # L4 step size: eta = alpha * (L - gamma * L_min) / (<g, v> + eps)
        lr = alpha * (loss.item() - gamma * lmin) / (inner_prod + eps)
        state['lr'] = lr
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                v = -p.data.clone()
                p.data.copy_(originals[p])
                p.data.add_(v, alpha=-lr)

        # slowly forget the minimal loss so the step size can grow again
        state['lmin'] = (1 + tau) * lmin

        return loss
--------------------------------------------------------------------------------
/mnist_example.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from l4 import L4

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
l4opt = L4(optimizer)
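# Note: L4 rescales the trial step of the wrapped optimizer by
# alpha * (loss - gamma * L_min) / <g, v>, so the size of the final parameter
# update is (up to eps) independent of the --lr passed to the inner SGD;
# the inner learning rate only sets the scale of the trial step that
# L4.step() then renormalizes.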

iteration = 0
def train(epoch):
    global iteration
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        iteration += 1
        # optimizer.step()
        l4opt.step(loss)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

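# L4.step() stores the most recently computed adaptive step size in
# l4opt.state['lr']; a minimal sketch of how it could be monitored during
# training (print_l4_lr is only an illustrative helper, not part of l4.py):
def print_l4_lr():
    if 'lr' in l4opt.state:
        print('L4 adaptive step size: {:.3e}'.format(l4opt.state['lr']))
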
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = model(data)
            loss = F.nll_loss(output, target, reduction='sum')
            test_loss += loss.item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()
--------------------------------------------------------------------------------