├── LICENSE
├── README.md
├── nn.py
├── optimizer.py
├── png
│   ├── loss.png
│   ├── valid_acc.png
│   └── valid_loss.png
├── progressbar.py
├── run.py
├── tools.py
└── trainingmonitor.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 lonePatient

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Lookahead PyTorch

This repository contains a PyTorch implementation of the Lookahead Optimizer from the paper

[Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610)

by Michael R. Zhang, James Lucas, Geoffrey Hinton and Jimmy Ba.

## Dependencies

* PyTorch
* torchvision
* matplotlib

## Usage

The code in this repository implements both Lookahead and plain Adam training, with an example on the CIFAR-10 dataset.

To use Lookahead, wrap your inner optimizer:

```python
from optimizer import Lookahead
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = Lookahead(optimizer=optimizer, k=5, alpha=0.5)
```

We found that evaluation performance is typically better using the slow weights. In PyTorch this can be done with something like the following in your eval loop:

```python
if args.lookahead:
    optimizer._backup_and_load_cache()
    val_loss = eval_func(model)
    optimizer._clear_and_load_backup()
```

## Example

To reproduce the results below, we train ResNet18 on the CIFAR-10 dataset:

```bash
# use Adam
python run.py --optimizer=adam

# use Lookahead
python run.py --optimizer=lookahead
```

## Results

Training loss of Adam and Lookahead with ResNet18 on CIFAR-10.

![](./png/loss.png)

Validation loss of Adam and Lookahead with ResNet18 on CIFAR-10.

![](./png/valid_loss.png)

Validation accuracy of Adam and Lookahead with ResNet18 on CIFAR-10.

![](./png/valid_acc.png)
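The pieces above combine into a minimal end-to-end sketch (the `nn.Linear` toy model and the random data here are placeholders, not part of the repository):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from optimizer import Lookahead

model = nn.Linear(10, 2)                      # stand-in for a real model
loss_fn = nn.CrossEntropyLoss()
inner = optim.Adam(model.parameters(), lr=0.001)
optimizer = Lookahead(optimizer=inner, k=5, alpha=0.5)

x = torch.randn(32, 10)
y = torch.randint(0, 2, (32,))
for step in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()                          # every k-th call also updates the slow weights

# evaluate on the slow weights, then restore the fast weights
optimizer._backup_and_load_cache()
with torch.no_grad():
    val_loss = loss_fn(model(x), y)
optimizer._clear_and_load_backup()
```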
--------------------------------------------------------------------------------
/nn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        # projection shortcut when the shape changes, identity otherwise
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        # the first block may downsample; the remaining blocks keep stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def ResNet18():
    return ResNet(ResidualBlock)
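A quick shape sanity check for the network above (a sketch; the batch size of 2 is arbitrary):

```python
import torch
from nn import ResNet18

model = ResNet18()
x = torch.randn(2, 3, 32, 32)    # CIFAR-10-sized input
logits = model(x)
assert logits.shape == (2, 10)   # one logit per class
print(sum(p.numel() for p in model.parameters()))  # roughly 11M parameters
```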
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
import torch
from torch.optim import Optimizer
from collections import defaultdict

class Lookahead(Optimizer):
    '''
    PyTorch implementation of the Lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    '''
    def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"):
        '''
        :param optimizer: inner optimizer
        :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        :param k (int): number of lookahead steps
        :param pullback_momentum (str): change to inner optimizer momentum on interpolation update
        '''
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if k < 1:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        self.step_counter = 0
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum
        self.state = defaultdict(dict)

        # Cache the current optimizer parameters as the initial slow weights
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'step_counter': self.step_counter,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        # delegates to the inner optimizer; the slow-weight cache is not serialized
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self.step_counter += 1

        if self.step_counter >= self.k:
            self.step_counter = 0
            # Lookahead: interpolate toward the slow weights and cache the result
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    # slow <- slow + alpha * (fast - slow)
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        # note: assumes the inner optimizer keeps a 'momentum_buffer'
                        # (e.g. SGD with momentum)
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss
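The k-step interpolation and the slow-weight swap can be verified on a toy parameter. A sketch, with SGD as the inner optimizer and constants chosen so the arithmetic is easy to follow:

```python
import torch
from optimizer import Lookahead

p = torch.nn.Parameter(torch.zeros(1))
inner = torch.optim.SGD([p], lr=1.0)
opt = Lookahead(inner, alpha=0.5, k=2)

for _ in range(3):                    # constant gradient of -1, so each fast step adds +1
    opt.zero_grad()
    p.grad = torch.tensor([-1.0])
    opt.step()

# after step 2 the slow weights synced: slow = 0 + 0.5 * (2 - 0) = 1,
# and step 3 moved only the fast weights: fast = 2
print(p.data)                         # tensor([2.])
opt._backup_and_load_cache()          # evaluate on the slow weights
print(p.data)                         # tensor([1.])
opt._clear_and_load_backup()          # restore the fast weights
print(p.data)                         # tensor([2.])
```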
--------------------------------------------------------------------------------
/png/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/lookahead_pytorch/1055128057408fe8533ffa30654551a317f07f0a/png/loss.png

--------------------------------------------------------------------------------
/png/valid_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/lookahead_pytorch/1055128057408fe8533ffa30654551a317f07f0a/png/valid_acc.png

--------------------------------------------------------------------------------
/png/valid_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lonePatient/lookahead_pytorch/1055128057408fe8533ffa30654551a317f07f0a/png/valid_loss.png

--------------------------------------------------------------------------------
/progressbar.py:
--------------------------------------------------------------------------------
import time

class ProgressBar(object):
    '''
    custom progress bar
    Example:
        >>> pbar = ProgressBar(n_total=30, desc='training')
        >>> step = 2
        >>> pbar(step=step)
    '''
    def __init__(self, n_total, width=30, desc='Training'):
        self.width = width
        self.n_total = n_total
        self.start_time = time.time()
        self.desc = desc

    def __call__(self, step, info=None):
        info = info or {}
        now = time.time()
        current = step + 1
        recv_per = current / self.n_total
        bar = f'[{self.desc}] {current}/{self.n_total} ['
        if recv_per >= 1:
            recv_per = 1
        prog_width = int(self.width * recv_per)
        if prog_width > 0:
            bar += '=' * (prog_width - 1)
            if current < self.n_total:
                bar += ">"
            else:
                bar += '='
        bar += '.' * (self.width - prog_width)
        bar += ']'
        show_bar = f"\r{bar}"
        time_per_unit = (now - self.start_time) / current
        if current < self.n_total:
            eta = time_per_unit * (self.n_total - current)
            if eta > 3600:
                eta_format = ('%d:%02d:%02d' %
                              (eta // 3600, (eta % 3600) // 60, eta % 60))
            elif eta > 60:
                eta_format = '%d:%02d' % (eta // 60, eta % 60)
            else:
                eta_format = '%ds' % eta
            time_info = f' - ETA: {eta_format}'
        else:
            if time_per_unit >= 1:
                time_info = f' {time_per_unit:.1f}s/step'
            elif time_per_unit >= 1e-3:
                time_info = f' {time_per_unit * 1e3:.1f}ms/step'
            else:
                time_info = f' {time_per_unit * 1e6:.1f}us/step'

        show_bar += time_info
        if len(info) != 0:
            show_info = f'{show_bar} ' + \
                "-".join([f' {key}: {value:.4f} ' for key, value in info.items()])
            print(show_info, end='')
        else:
            print(show_bar, end='')
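A minimal usage sketch for the bar above (the `time.sleep` stands in for real work):

```python
import time
from progressbar import ProgressBar

pbar = ProgressBar(n_total=30, desc='Training')
for step in range(30):
    time.sleep(0.05)                            # stand-in for a training step
    pbar(step=step, info={'loss': 1.0 / (step + 1)})
print()                                         # newline after the \r-based bar
```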
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import torch
import argparse
import torch.nn as nn
from nn import ResNet18
from tools import AverageMeter
from progressbar import ProgressBar
from tools import seed_everything
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from trainingmonitor import TrainingMonitor
from optimizer import Lookahead

epochs = 30
batch_size = 128
seed = 42

seed_everything(seed)
model = ResNet18()
loss_fn = nn.CrossEntropyLoss()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

parser = argparse.ArgumentParser(description='CIFAR10')
parser.add_argument("--model", type=str, default='ResNet18')
parser.add_argument("--task", type=str, default='image')
parser.add_argument("--optimizer", default='lookahead', type=str, choices=['lookahead', 'adam'])
args = parser.parse_args()

if args.optimizer == 'lookahead':
    arch = 'ResNet18_Lookahead_adam'
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer = Lookahead(optimizer=optimizer, k=5, alpha=0.5)
else:
    arch = 'ResNet18_Adam'
    optimizer = optim.Adam(model.parameters(), lr=0.001)

train_monitor = TrainingMonitor(file_dir='./', arch=arch)

def train(train_loader):
    pbar = ProgressBar(n_total=len(train_loader), desc='Training')
    train_loss = AverageMeter()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        pbar(step=batch_idx, info={'loss': loss.item()})
        train_loss.update(loss.item(), n=1)
    return {'loss': train_loss.avg}

def test(test_loader):
    pbar = ProgressBar(n_total=len(test_loader), desc='Testing')
    valid_loss = AverageMeter()
    valid_acc = AverageMeter()
    model.eval()
    count = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_fn(output, target).item()  # mean batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
            correct = pred.eq(target.view_as(pred)).sum().item()
            valid_loss.update(loss, n=data.size(0))
            valid_acc.update(correct, n=1)
            count += data.size(0)
            pbar(step=batch_idx)
    return {'valid_loss': valid_loss.avg,
            'valid_acc': valid_acc.sum / count}

data = {
    'train': datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomCrop((32, 32), padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
        )
    ),
    'valid': datasets.CIFAR10(
        root='./data', train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
        )
    )
}

loaders = {
    'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True,
                        num_workers=10, pin_memory=True,
                        drop_last=True),
    'valid': DataLoader(data['valid'], batch_size=batch_size,
                        num_workers=10, pin_memory=True,
                        drop_last=False)
}

for epoch in range(1, epochs + 1):
    train_log = train(loaders['train'])
    if args.optimizer == 'lookahead':
        # evaluate on the slow weights, then restore the fast weights
        optimizer._backup_and_load_cache()
        valid_log = test(loaders['valid'])
        optimizer._clear_and_load_backup()
    else:
        valid_log = test(loaders['valid'])
    logs = dict(train_log, **valid_log)
    show_info = f'\nEpoch: {epoch} - ' + "-".join([f' {key}: {value:.4f} ' for key, value in logs.items()])
    print(show_info)
    train_monitor.epoch_step(logs)
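Each run writes its history to `./{arch}_training_monitor.json` (see trainingmonitor.py below), so the two optimizers can be compared after the fact. A sketch using `load_json` from tools.py; the output file name here is hypothetical:

```python
import matplotlib.pyplot as plt
from tools import load_json

adam = load_json('./ResNet18_Adam_training_monitor.json')
looka = load_json('./ResNet18_Lookahead_adam_training_monitor.json')

plt.plot(adam['valid_acc'], label='Adam')
plt.plot(looka['valid_acc'], label='Lookahead(Adam)')
plt.xlabel('Epoch #')
plt.ylabel('valid_acc')
plt.legend()
plt.savefig('./png/compare_valid_acc.png')   # hypothetical output path
```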
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
import numpy as np
from pathlib import Path
import json
import random
import torch
import os


def save_json(data, file_path):
    '''
    save json
    :param data: serializable object to save
    :param file_path: destination path
    :return:
    '''
    if not isinstance(file_path, Path):
        file_path = Path(file_path)
    with open(str(file_path), 'w') as f:
        json.dump(data, f)


def load_json(file_path):
    '''
    load json
    :param file_path: path of the json file
    :return: the deserialized object
    '''
    if not isinstance(file_path, Path):
        file_path = Path(file_path)
    with open(str(file_path), 'r') as f:
        data = json.load(f)
    return data

class AverageMeter(object):
    '''
    computes and stores the average and current value
    Example:
        >>> loss = AverageMeter()
        >>> for step, batch in enumerate(train_data):
        >>>     pred = self.model(batch)
        >>>     raw_loss = self.metrics(pred, target)
        >>>     loss.update(raw_loss.item(), n=1)
        >>> cur_loss = loss.avg
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def seed_everything(seed=1029):
    '''
    fix all random seeds for reproducibility
    :param seed: the seed value
    :return:
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed,
    # unless you tell cudnn to be deterministic
    torch.backends.cudnn.deterministic = True
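A small sketch of the weighted-update semantics of `AverageMeter` above (the values are made up):

```python
from tools import AverageMeter

m = AverageMeter()
m.update(0.5, n=128)   # mean loss of a 128-sample batch
m.update(0.9, n=64)    # mean loss of a smaller final batch
print(m.avg)           # (0.5*128 + 0.9*64) / 192 = 0.6333...
```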
--------------------------------------------------------------------------------
/trainingmonitor.py:
--------------------------------------------------------------------------------
# encoding:utf-8
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tools import load_json
from tools import save_json
plt.switch_backend('agg')  # avoid display problems when plotting over ssh

class TrainingMonitor():
    def __init__(self, file_dir, arch, add_test=False):
        '''
        :param file_dir: directory the json log and plots are written to
        :param arch: architecture name, used as a file prefix
        :param add_test: also plot test_* metrics if present
        '''
        if not isinstance(file_dir, Path):
            file_dir = Path(file_dir)
        file_dir.mkdir(parents=True, exist_ok=True)

        self.arch = arch
        self.file_dir = file_dir
        self.H = {}
        self.add_test = add_test
        self.json_path = file_dir / (arch + "_training_monitor.json")

    def reset(self, start_at):
        # resume training from a given epoch: truncate the stored history
        if start_at > 0:
            if self.json_path is not None:
                if self.json_path.exists():
                    self.H = load_json(self.json_path)
                    for k in self.H.keys():
                        self.H[k] = self.H[k][:start_at]

    def epoch_step(self, logs=None):
        logs = logs or {}
        for (k, v) in logs.items():
            l = self.H.get(k, [])
            # numpy floats (e.g. np.float32) break json serialization,
            # so cast to a Python float
            if not isinstance(v, float):
                v = round(float(v), 4)
            l.append(v)
            self.H[k] = l

        # write the history to disk
        if self.json_path is not None:
            save_json(data=self.H, file_path=self.json_path)

        # build the plot paths once, on the first epoch
        if len(self.H["loss"]) == 1:
            self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()}

        if len(self.H["loss"]) > 1:
            # plot each metric that appears as a train/valid pair
            keys = [key for key, _ in self.H.items() if '_' not in key]
            for key in keys:
                N = np.arange(0, len(self.H[key]))
                plt.style.use("ggplot")
                plt.figure()
                plt.plot(N, self.H[key], label=f"train_{key}")
                plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}")
                if self.add_test:
                    plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}")
                plt.legend()
                plt.xlabel("Epoch #")
                plt.ylabel(key)
                plt.title(f"Training {key} [Epoch {len(self.H[key])}]")
                plt.savefig(str(self.paths[key]))
                plt.close()
--------------------------------------------------------------------------------
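A minimal usage sketch for the monitor above (metric values are made up; plots are rendered from the second epoch onward):

```python
from trainingmonitor import TrainingMonitor

monitor = TrainingMonitor(file_dir='./logs', arch='demo')
for epoch, (tr, va, acc) in enumerate([(2.1, 2.0, 0.31), (1.7, 1.8, 0.42)], 1):
    monitor.epoch_step({'loss': tr, 'valid_loss': va, 'valid_acc': acc})
# -> ./logs/demo_training_monitor.json and, after epoch 2, ./logs/demo_LOSS.png
```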