├── __init__.py ├── Training-process-with-TensorBoard.jpg ├── .idea ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── lars-imagenet-pytorch.iml ├── deployment.xml └── workspace.xml ├── LICENSE ├── utils.py ├── README.md ├── lars.py ├── lamb.py ├── pytorch_imagenet_resnet.py └── pytorch_imagenet_resnet_dali.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Training-process-with-TensorBoard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NUS-HPC-AI-Lab/LARS-ImageNet-PyTorch/HEAD/Training-process-with-TensorBoard.jpg -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/lars-imagenet-pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 binmakeswell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 12 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 1614257743553 28 | 33 | 34 | 35 | 36 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import 
horovod.torch as hvd 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | def accuracy(output, target): 8 | # get the index of the max log-probability 9 | pred = output.max(1, keepdim=True)[1] 10 | return pred.eq(target.view_as(pred)).cpu().float().mean() 11 | 12 | def save_checkpoint(model, optimizer, checkpoint_format, epoch): 13 | if hvd.rank() == 0: 14 | filepath = checkpoint_format.format(epoch=epoch + 1) 15 | state = { 16 | 'model': model.state_dict(), 17 | 'optimizer': optimizer.state_dict(), 18 | } 19 | torch.save(state, filepath) 20 | 21 | class LabelSmoothLoss(torch.nn.Module): 22 | 23 | def __init__(self, smoothing=0.0): 24 | super(LabelSmoothLoss, self).__init__() 25 | self.smoothing = smoothing 26 | 27 | def forward(self, input, target): 28 | log_prob = F.log_softmax(input, dim=-1) 29 | weight = input.new_ones(input.size()) * \ 30 | self.smoothing / (input.size(-1) - 1.) 31 | weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing)) 32 | loss = (-weight * log_prob).sum(dim=-1).mean() 33 | return loss 34 | 35 | def metric_average(val_tensor): 36 | avg_tensor = hvd.allreduce(val_tensor) 37 | return avg_tensor.item() 38 | 39 | # Horovod: average metrics from distributed training. 40 | class Metric(object): 41 | def __init__(self, name): 42 | self.name = name 43 | self.sum = torch.tensor(0.) 44 | self.n = torch.tensor(0.) 45 | 46 | def update(self, val, n=1): 47 | self.sum += float(hvd.allreduce(val.detach().cpu(), name=self.name)) 48 | self.n += n 49 | 50 | @property 51 | def avg(self): 52 | return self.sum / self.n 53 | 54 | def create_lr_schedule(workers, warmup_epochs, decay_schedule, alpha=0.1): 55 | def lr_schedule(epoch): 56 | lr_adj = 1. 57 | if epoch < warmup_epochs: 58 | lr_adj = 1. 
/ workers * (epoch * (workers - 1) / warmup_epochs + 1) 59 | else: 60 | decay_schedule.sort(reverse=True) 61 | for e in decay_schedule: 62 | if epoch >= e: 63 | lr_adj *= alpha 64 | return lr_adj 65 | return lr_schedule 66 | 67 | 68 | 69 | class PolynomialDecay(_LRScheduler): 70 | def __init__(self, optimizer, decay_steps, end_lr=0.0001, power=1.0, last_epoch=-1): 71 | self.decay_steps = decay_steps 72 | self.end_lr = end_lr 73 | self.power = power 74 | super().__init__(optimizer, last_epoch) 75 | 76 | def get_lr(self): 77 | return self._get_closed_form_lr() 78 | 79 | def _get_closed_form_lr(self): 80 | return [ 81 | (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.decay_steps) / 82 | self.decay_steps) ** self.power) + self.end_lr 83 | for base_lr in self.base_lrs 84 | ] 85 | 86 | class WarmupScheduler(_LRScheduler): 87 | def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): 88 | self.warmup_epochs = warmup_epochs 89 | self.after_scheduler = after_scheduler 90 | self.finished = False 91 | super().__init__(optimizer, last_epoch) 92 | 93 | def get_lr(self): 94 | if self.last_epoch >= self.warmup_epochs: 95 | if not self.finished: 96 | self.after_scheduler.base_lrs = self.base_lrs 97 | self.finished = True 98 | return self.after_scheduler.get_lr() 99 | 100 | return [self.last_epoch / self.warmup_epochs * lr for lr in self.base_lrs] 101 | 102 | def step(self, epoch=None): 103 | if self.finished: 104 | if epoch is None: 105 | self.after_scheduler.step(None) 106 | else: 107 | self.after_scheduler.step(epoch - self.warmup_epochs) 108 | else: 109 | return super().step(epoch) 110 | 111 | 112 | class PolynomialWarmup(WarmupScheduler): 113 | def __init__(self, optimizer, decay_steps, warmup_steps=0, end_lr=0.0001, power=1.0, last_epoch=-1): 114 | base_scheduler = PolynomialDecay( 115 | optimizer, decay_steps - warmup_steps, end_lr=end_lr, power=power, last_epoch=last_epoch) 116 | super().__init__(optimizer, warmup_steps, base_scheduler, 
last_epoch=last_epoch) 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation of LARS for ImageNet with PyTorch 2 | 3 | This is the code for the paper "[Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888)", which implements a large batch deep learning optimizer called LARS using PyTorch. Although the optimizer has been released for some time and has an official TensorFlow version implementation, as far as we know, there is no reliable PyTorch version implementation, so we try to complete this work. We use [Horovod](https://github.com/horovod/horovod) to implement distributed data parallel training and provide accumulated gradient and NVIDIA DALI dataloader as options. 4 | 5 | ## Requirements 6 | 7 | This code is validated to run with Python 3.6.10, PyTorch 1.5.0, Horovod 0.21.1, CUDA 10.0/1, CUDNN 7.6.4, and NCCL 2.4.7. 8 | 9 | ## Performance on ImageNet 10 | 11 | We verified the implementation on the complete ImageNet-1K (ILSVRC2012) data set. The parameters and performance as follows. 
12 | 13 | | Effective Batchsize | Batchsize | Base LR | Warmup Epochs | Epsilon | Val Accuracy | TensorBoard Color | 14 | | :-----------------: | :-------: | :-------------: | :--------------: | :-----: | :----------: | :---------------: | 15 | | 512 | 64 | $2^{2}$ | $10/2^{6}$ | 1e-5 | **77.02%** | Light blue | 16 | | 1024 | 128 | $2^{2.5}$ | $10/2^{5}$ | 1e-5 | **76.96%** | Brown | 17 | | 4096 | 128 | $2^{3.5}$ | $10/2^{3}$ | 1e-5 | **77.38%** | Orange | 18 | | 8192 | 128 | $2^{4}$ | $10/2^{2}$ | 1e-5 | **77.14%** | Deep Blue | 19 | | 16384 | 128 | $2^{4.5}$ | 5 | 1e-5 | **76.96%** | Pink | 20 | | 32768 | 64 | $2^{5}$ | 14 | 0.0 | **76.75%** | Green | 21 | 22 | Training process with TensorBoard 23 | 24 | ![Training process with TensorBoard](Training-process-with-TensorBoard.jpg) 25 | 26 | We set epochs = 90, weight decay = 0.0001 and model = resnet50, and used NVIDIA Tesla V100/P100 GPUs for all experiments. We did not fine-tune the hyperparameters, so you may be able to get better performance with other settings. 27 | 28 | We thank the National Supercomputing Centre Singapore (NSCC), the Texas Advanced Computing Center (TACC) and the Swiss National Supercomputing Centre (CSCS) for providing computing resources. 29 | 30 | ## Usage 31 | 32 | ``` 33 | from lars import * 34 | ... 35 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon, 36 | momentum=args.momentum, weight_decay=args.wd, 37 | bn_bias_separately=args.bn_bias_separately) 38 | ... 39 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch, 40 | warmup_steps=args.warmup_epochs * num_steps_per_epoch, 41 | end_lr=0.0, power=lr_power, last_epoch=-1) 42 | ... 43 | ``` 44 | 45 | Note that we recommend using create_optimizer_lars with bn_bias_separately=True instead of using the Lars class directly: this lets LARS skip the BatchNorm and bias parameters, which generally gives better performance. Polynomial warmup learning-rate decay also generally improves performance.
46 | 47 | ## Example Scripts 48 | 49 | An example script for training on ImageNet-1k with 8 GPUs and an effective batch size of 1024 is provided below. 50 | 51 | ``` 52 | $ mpirun -np 8 \ 53 | python pytorch_imagenet_resnet.py \ 54 | --batch-size 128 \ 55 | --warmup-epochs 0.3125 \ 56 | --train-dir=/path/to/ImageNet/train/ \ 57 | --val-dir=/path/to/ImageNet/val \ 58 | --base-lr 5.6568542494924 \ 59 | --base-op lars \ 60 | --bn-bias-separately \ 61 | --wd 0.0001 \ 62 | --lr-scaling keep 63 | ``` 64 | 65 | ## Additional Options 66 | 67 | **Accumulated gradient** When there are not enough GPUs, gradient accumulation can simulate a larger effective batch size with a limited number of GPUs, although it may increase the running time to some extent. To use it, add --batches-per-allreduce N to the above command, where N is the scale factor. For example, setting N = 4 here simulates an effective batch size of 4096 using only 8 GPUs. 68 | 69 | **DALI dataloader** NVIDIA DALI can accelerate data loading and pre-processing by using the GPU rather than the CPU, at the cost of additional GPU memory. It can also avoid some potential conflicts between MPI libraries and Horovod on some GPU clusters. To use it, run 'pytorch_imagenet_resnet_dali.py' with '--data-dir' instead of '--train-dir'/'--val-dir'.
For '--data-dir', it requires ImageNet-1k data in **TFRecord format** in the following structure: 70 | 71 | ``` 72 | train-recs 'path/train/*' 73 | val-recs 'path/validation/*' 74 | train-idx 'path/idx_files/train/*' 75 | val-idx 'path/idx_files/validation/*' 76 | ``` 77 | 78 | ## 79 | 80 | ## Reference 81 | 82 | [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888) 83 | 84 | [Large-Batch Training for LSTM and Beyond](https://arxiv.org/abs/1901.08256) 85 | 86 | https://www.comp.nus.edu.sg/~youy/lars_optimizer.py 87 | 88 | https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/lars_optimizer.py 89 | -------------------------------------------------------------------------------- /lars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | from typing import Dict, Iterable, Optional, Callable, Tuple 4 | from torch import nn 5 | 6 | """ 7 | We recommend using create_optimizer_lars and setting bn_bias_separately=True 8 | instead of using class Lars directly, which helps LARS skip parameters 9 | in BatchNormalization and bias, and has better performance in general. 10 | Polynomial Warmup learning rate decay is also helpful for better performance in general. 
11 | """ 12 | 13 | 14 | def create_optimizer_lars(model, lr, momentum, weight_decay, bn_bias_separately, epsilon): 15 | if bn_bias_separately: 16 | optimizer = Lars([ 17 | dict(params=get_common_parameters(model, exclude_func=get_norm_bias_parameters)), 18 | dict(params=get_norm_bias_parameters(model), weight_decay=0, lars=False)], 19 | lr=lr, 20 | momentum=momentum, 21 | weight_decay=weight_decay, 22 | epsilon=epsilon) 23 | else: 24 | optimizer = Lars(model.parameters(), 25 | lr=lr, 26 | momentum=momentum, 27 | weight_decay=weight_decay, 28 | epsilon=epsilon) 29 | return optimizer 30 | 31 | 32 | class Lars(Optimizer): 33 | r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" 34 | `_. 35 | 36 | Args: 37 | params (iterable): iterable of parameters to optimize or dicts defining 38 | parameter groups 39 | lr (float, optional): learning rate 40 | momentum (float, optional): momentum factor (default: 0) 41 | eeta (float, optional): LARS coefficient as used in the paper (default: 1e-3) 42 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 43 | """ 44 | 45 | def __init__( 46 | self, 47 | params: Iterable[torch.nn.Parameter], 48 | lr=1e-3, 49 | momentum=0, 50 | eeta=1e-3, 51 | weight_decay=0, 52 | epsilon=0.0 53 | ) -> None: 54 | if not isinstance(lr, float) or lr < 0.0: 55 | raise ValueError("Invalid learning rate: {}".format(lr)) 56 | if momentum < 0.0: 57 | raise ValueError("Invalid momentum value: {}".format(momentum)) 58 | if weight_decay < 0.0: 59 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 60 | if eeta <= 0 or eeta > 1: 61 | raise ValueError("Invalid eeta value: {}".format(eeta)) 62 | if epsilon < 0: 63 | raise ValueError("Invalid epsilon value: {}".format(epsilon)) 64 | defaults = dict(lr=lr, momentum=momentum, 65 | weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) 66 | 67 | super().__init__(params, defaults) 68 | 69 | @torch.no_grad() 70 | def step(self, 
closure=None): 71 | """Performs a single optimization step. 72 | Arguments: 73 | closure (callable, optional): A closure that reevaluates the model 74 | and returns the loss. 75 | """ 76 | loss = None 77 | if closure is not None: 78 | with torch.enable_grad(): 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | weight_decay = group['weight_decay'] 83 | momentum = group['momentum'] 84 | eeta = group['eeta'] 85 | lr = group['lr'] 86 | lars = group['lars'] 87 | eps = group['epsilon'] 88 | 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | decayed_grad = p.grad 93 | scaled_lr = lr 94 | if lars: 95 | w_norm = torch.norm(p) 96 | g_norm = torch.norm(p.grad) 97 | trust_ratio = torch.where( 98 | w_norm > 0 and g_norm > 0, 99 | eeta * w_norm / (g_norm + weight_decay * w_norm + eps), 100 | torch.ones_like(w_norm) 101 | ) 102 | trust_ratio.clamp_(0.0, 50) 103 | scaled_lr *= trust_ratio.item() 104 | if weight_decay != 0: 105 | decayed_grad = decayed_grad.add(p, alpha=weight_decay) 106 | decayed_grad = torch.clamp(decayed_grad, -10.0, 10.0) 107 | 108 | if momentum != 0: 109 | param_state = self.state[p] 110 | if 'momentum_buffer' not in param_state: 111 | buf = param_state['momentum_buffer'] = torch.clone( 112 | decayed_grad).detach() 113 | else: 114 | buf = param_state['momentum_buffer'] 115 | buf.mul_(momentum).add_(decayed_grad) 116 | decayed_grad = buf 117 | 118 | p.add_(decayed_grad, alpha=-scaled_lr) 119 | 120 | return loss 121 | 122 | 123 | """ 124 | Functions which help to skip bias and BatchNorm 125 | """ 126 | BN_CLS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) 127 | 128 | 129 | def get_parameters_from_cls(module, cls_): 130 | def get_members_fn(m): 131 | if isinstance(m, cls_): 132 | return m._parameters.items() 133 | else: 134 | return dict() 135 | 136 | named_parameters = module._named_members(get_members_fn=get_members_fn) 137 | for name, param in named_parameters: 138 | yield param 139 | 140 | 141 | def 
get_norm_parameters(module): 142 | return get_parameters_from_cls(module, (nn.LayerNorm, *BN_CLS)) 143 | 144 | 145 | def get_bias_parameters(module, exclude_func=None): 146 | excluded_parameters = set() 147 | if exclude_func is not None: 148 | for param in exclude_func(module): 149 | excluded_parameters.add(param) 150 | for name, param in module.named_parameters(): 151 | if param not in excluded_parameters and 'bias' in name: 152 | yield param 153 | 154 | 155 | def get_norm_bias_parameters(module): 156 | for param in get_norm_parameters(module): 157 | yield param 158 | for param in get_bias_parameters(module, exclude_func=get_norm_parameters): 159 | yield param 160 | 161 | 162 | def get_common_parameters(module, exclude_func=None): 163 | excluded_parameters = set() 164 | if exclude_func is not None: 165 | for param in exclude_func(module): 166 | excluded_parameters.add(param) 167 | for name, param in module.named_parameters(): 168 | if param not in excluded_parameters: 169 | yield param 170 | -------------------------------------------------------------------------------- /lamb.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | from torch import nn 5 | 6 | 7 | class LAMB(Optimizer): 8 | r"""Implements Lamb algorithm. 9 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 
10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 1e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square (default: (0.9, 0.999)) 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | adam (bool, optional): always use trust ratio = 1, which turns this into 20 | Adam. Useful for comparison purposes. 21 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 22 | https://arxiv.org/abs/1904.00962 23 | """ 24 | 25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 26 | weight_decay=0): 27 | if not 0.0 <= lr: 28 | raise ValueError("Invalid learning rate: {}".format(lr)) 29 | if not 0.0 <= eps: 30 | raise ValueError("Invalid epsilon value: {}".format(eps)) 31 | if not 0.0 <= betas[0] < 1.0: 32 | raise ValueError( 33 | "Invalid beta parameter at index 0: {}".format(betas[0])) 34 | if not 0.0 <= betas[1] < 1.0: 35 | raise ValueError( 36 | "Invalid beta parameter at index 1: {}".format(betas[1])) 37 | defaults = dict(lr=lr, betas=betas, eps=eps, 38 | weight_decay=weight_decay) 39 | super().__init__(params, defaults) 40 | 41 | @torch.no_grad() 42 | def step(self, closure=None): 43 | """Performs a single optimization step. 44 | Arguments: 45 | closure (callable, optional): A closure that reevaluates the model 46 | and returns the loss. 
47 | """ 48 | loss = None 49 | if closure is not None: 50 | loss = closure() 51 | 52 | torch.nn.utils.clip_grad_norm_( 53 | parameters=[ 54 | p for group in self.param_groups for p in group['params']], 55 | max_norm=1.0, 56 | norm_type=2 57 | ) 58 | 59 | for group in self.param_groups: 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data 64 | if grad.is_sparse: 65 | raise RuntimeError( 66 | 'Lamb does not support sparse gradients, consider SparseAdam instad.') 67 | 68 | state = self.state[p] 69 | 70 | # State initialization 71 | if len(state) == 0: 72 | state['step'] = 0 73 | # Exponential moving average of gradient values 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | # Exponential moving average of squared gradient values 76 | state['exp_avg_sq'] = torch.zeros_like(p.data) 77 | 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | beta1, beta2 = group['betas'] 80 | 81 | state['step'] += 1 82 | 83 | # Decay the first and second moment running average coefficient 84 | # m_t 85 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 86 | # v_t 87 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 88 | 89 | # Paper v3 does not use debiasing. 90 | # bias_correction1 = 1 - beta1 ** state['step'] 91 | # bias_correction2 = 1 - beta2 ** state['step'] 92 | # Apply bias to lr to avoid broadcast. 
93 | # * math.sqrt(bias_correction2) / bias_correction1 94 | scaled_lr = group['lr'] 95 | update = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 96 | if group['weight_decay'] != 0: 97 | update.add_(p.data, alpha=group['weight_decay']) 98 | w_norm = torch.norm(p) 99 | g_norm = torch.norm(update) 100 | trust_ratio = torch.where( 101 | w_norm > 0 and g_norm > 0, 102 | w_norm / g_norm, 103 | torch.ones_like(w_norm) 104 | ) 105 | scaled_lr *= trust_ratio.item() 106 | 107 | p.data.add_(update, alpha=-scaled_lr) 108 | 109 | return loss 110 | 111 | 112 | def create_lamb_optimizer(model, lr, betas=(0.9, 0.999), eps=1e-6, 113 | weight_decay=0, exclude_layers=['bn', 'ln', 'bias']): 114 | # can only exclude BatchNorm, LayerNorm, bias layers 115 | # ['bn', 'ln'] will exclude BatchNorm, LayerNorm layers 116 | # ['bn', 'ln', 'bias'] will exclude BatchNorm, LayerNorm, bias layers 117 | # [] will not exclude any layers 118 | if 'bias' in exclude_layers: 119 | params = [ 120 | dict(params=get_common_parameters( 121 | model, exclude_func=get_norm_bias_parameters)), 122 | dict(params=get_norm_bias_parameters(model), weight_decay=0) 123 | ] 124 | elif len(exclude_layers) > 0: 125 | params = [ 126 | dict(params=get_common_parameters( 127 | model, exclude_func=get_norm_parameters)), 128 | dict(params=get_norm_parameters(model), weight_decay=0) 129 | ] 130 | else: 131 | params = model.parameters() 132 | optimizer = LAMB(params, lr, betas=betas, eps=eps, 133 | weight_decay=weight_decay) 134 | return optimizer 135 | 136 | 137 | BN_CLS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) 138 | 139 | 140 | def get_parameters_from_cls(module, cls_): 141 | def get_members_fn(m): 142 | if isinstance(m, cls_): 143 | return m._parameters.items() 144 | else: 145 | return dict() 146 | named_parameters = module._named_members(get_members_fn=get_members_fn) 147 | for name, param in named_parameters: 148 | yield param 149 | 150 | 151 | def get_bn_parameters(module): 152 | return 
get_parameters_from_cls(module, BN_CLS) 153 | 154 | 155 | def get_ln_parameters(module): 156 | return get_parameters_from_cls(module, nn.LayerNorm) 157 | 158 | 159 | def get_norm_parameters(module): 160 | return get_parameters_from_cls(module, (nn.LayerNorm, *BN_CLS)) 161 | 162 | 163 | def get_bias_parameters(module, exclude_func=None): 164 | excluded_parameters = set() 165 | if exclude_func is not None: 166 | for param in exclude_func(module): 167 | excluded_parameters.add(param) 168 | for name, param in module.named_parameters(): 169 | if param not in excluded_parameters and 'bias' in name: 170 | yield param 171 | 172 | 173 | def get_norm_bias_parameters(module): 174 | for param in get_norm_parameters(module): 175 | yield param 176 | for param in get_bias_parameters(module, exclude_func=get_norm_parameters): 177 | yield param 178 | 179 | 180 | def get_common_parameters(module, exclude_func=None): 181 | excluded_parameters = set() 182 | if exclude_func is not None: 183 | for param in exclude_func(module): 184 | excluded_parameters.add(param) 185 | for name, param in module.named_parameters(): 186 | if param not in excluded_parameters: 187 | yield param -------------------------------------------------------------------------------- /pytorch_imagenet_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | from datetime import datetime, timedelta 5 | import argparse 6 | import os 7 | import math 8 | import warnings 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.utils.data.distributed 14 | from torchvision import datasets, transforms, models 15 | import horovod.torch as hvd 16 | from tqdm import tqdm 17 | from distutils.version import LooseVersion 18 | 19 | from utils import * 20 | from lars import * 21 | from lamb import * 22 | 23 | warnings.filterwarnings("ignore", 
"(Possibly )?corrupt EXIF data", UserWarning) 24 | 25 | 26 | def initialize(): 27 | # Training settings 28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Example', 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | parser.add_argument('--train-dir', default='/tmp/imagenet/ILSVRC2012_img_train/', 31 | help='path to training data') 32 | parser.add_argument('--val-dir', default='/tmp/imagenet/ILSVRC2012_img_val/', 33 | help='path to validation data') 34 | parser.add_argument('--log-dir', default='./logs/imagenet', 35 | help='tensorboard/checkpoint log directory') 36 | parser.add_argument('--checkpoint-format', default='checkpoint-{epoch}.pth.tar', 37 | help='checkpoint file format') 38 | parser.add_argument('--fp16-allreduce', action='store_true', default=False, 39 | help='use fp16 compression during allreduce') 40 | parser.add_argument('--batches-per-allreduce', type=int, default=1, 41 | help='number of batches processed locally before ' 42 | 'executing allreduce across workers; it multiplies ' 43 | 'total batch size.') 44 | 45 | # Default settings from https://arxiv.org/abs/1706.02677. 
46 | parser.add_argument('--model', default='resnet50', 47 | help='Model (resnet35, resnet50, resnet101, resnet152, resnext50, resnext101)') 48 | parser.add_argument('--batch-size', type=int, default=32, 49 | help='input batch size for training') 50 | parser.add_argument('--val-batch-size', type=int, default=32, 51 | help='input batch size for validation') 52 | parser.add_argument('--epochs', type=int, default=90, 53 | help='number of epochs to train') 54 | parser.add_argument('--base-lr', type=float, default=0.0125, 55 | help='learning rate for a single GPU') 56 | parser.add_argument('--lr-decay', nargs='+', type=int, default=[30, 60, 80], 57 | help='epoch intervals to decay lr') 58 | parser.add_argument('--warmup-epochs', type=float, default=5, 59 | help='number of warmup epochs') 60 | parser.add_argument('--momentum', type=float, default=0.9, 61 | help='SGD momentum') 62 | parser.add_argument('--wd', type=float, default=0.00005, 63 | help='weight decay') 64 | parser.add_argument('--epsilon', type=float, default=1e-5, 65 | help='epsilon for optimizer') 66 | parser.add_argument('--label-smoothing', type=float, default=0.1, 67 | help='label smoothing (default 0.1)') 68 | parser.add_argument('--base-op', type=str, default='sgd', 69 | help='base optimizer name') 70 | parser.add_argument('--bn-bias-separately', action='store_true', default=False, 71 | help='skip bn and bias') 72 | parser.add_argument('--lr-scaling', type=str, default='keep', 73 | help='lr scaling method') 74 | 75 | parser.add_argument('--no-cuda', action='store_true', default=False, 76 | help='disables CUDA training') 77 | parser.add_argument('--single-threaded', action='store_true', default=False, 78 | help='disables multi-threaded dataloading') 79 | parser.add_argument('--seed', type=int, default=42, 80 | help='random seed') 81 | 82 | args = parser.parse_args() 83 | args.cuda = not args.no_cuda and torch.cuda.is_available() 84 | 85 | hvd.init() 86 | torch.manual_seed(args.seed) 87 | 88 | 
args.verbose = 1 if hvd.rank() == 0 else 0 89 | 90 | print('hvd.rank() ', hvd.rank()) 91 | 92 | if args.verbose: 93 | print(args) 94 | 95 | if args.cuda: 96 | torch.cuda.set_device(hvd.local_rank()) 97 | torch.cuda.manual_seed(args.seed) 98 | 99 | cudnn.benchmark = True 100 | 101 | if args.bn_bias_separately: 102 | skip = "bn_bias" 103 | elif args.bn_separately: 104 | skip = "bn" 105 | else: 106 | skip = "no" 107 | 108 | args.log_dir = os.path.join(args.log_dir, 109 | "imagenet_{}_gpu_{}_{}_ebs{}_blr_{}_skip_{}_{}".format( 110 | args.model, hvd.size(), args.base_op, 111 | args.batch_size * hvd.size() * args.batches_per_allreduce, args.base_lr, skip, 112 | datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))) 113 | args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format) 114 | os.makedirs(args.log_dir, exist_ok=True) 115 | 116 | # If set > 0, will resume training from a given checkpoint. 117 | args.resume_from_epoch = 0 118 | for try_epoch in range(args.epochs, 0, -1): 119 | if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)): 120 | args.resume_from_epoch = try_epoch 121 | break 122 | 123 | # Horovod: broadcast resume_from_epoch from rank 0 (which will have 124 | # checkpoints) to other ranks. 125 | args.resume_from_epoch = hvd.broadcast(torch.tensor(args.resume_from_epoch), 126 | root_rank=0, 127 | name='resume_from_epoch').item() 128 | 129 | # Horovod: write TensorBoard logs on first worker. 130 | try: 131 | if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'): 132 | from torch.utils.tensorboard import SummaryWriter 133 | else: 134 | from tensorboardX import SummaryWriter 135 | args.log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None 136 | except ImportError: 137 | args.log_writer = None 138 | 139 | return args 140 | 141 | 142 | def get_datasets(args): 143 | # Horovod: limit # of CPU threads to be used per worker. 
144 | if args.single_threaded: 145 | torch.set_num_threads(4) 146 | kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {} 147 | else: 148 | torch.set_num_threads(4) 149 | num_workers_input = hvd.size() 150 | # num_workers_input = 4 151 | kwargs = {'num_workers': num_workers_input, 'pin_memory': True} if args.cuda else {} 152 | if args.verbose: 153 | print('actual num_workers ', num_workers_input) 154 | 155 | train_dataset = datasets.ImageFolder( 156 | args.train_dir, 157 | transform=transforms.Compose([ 158 | transforms.RandomResizedCrop(224), 159 | transforms.RandomHorizontalFlip(), 160 | transforms.ToTensor(), 161 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 162 | std=[0.229, 0.224, 0.225])])) 163 | val_dataset = datasets.ImageFolder( 164 | args.val_dir, 165 | transform=transforms.Compose([ 166 | transforms.Resize(256), 167 | transforms.CenterCrop(224), 168 | transforms.ToTensor(), 169 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 170 | std=[0.229, 0.224, 0.225])])) 171 | 172 | # Horovod: use DistributedSampler to partition data among workers. Manually specify 173 | # `num_replicas=hvd.size()` and `rank=hvd.rank()`. 
174 | train_sampler = torch.utils.data.distributed.DistributedSampler( 175 | train_dataset, num_replicas=hvd.size(), rank=hvd.rank()) 176 | train_loader = torch.utils.data.DataLoader( 177 | train_dataset, batch_size=args.batch_size * args.batches_per_allreduce, 178 | sampler=train_sampler, **kwargs) 179 | val_sampler = torch.utils.data.distributed.DistributedSampler( 180 | val_dataset, num_replicas=hvd.size(), rank=hvd.rank()) 181 | val_loader = torch.utils.data.DataLoader( 182 | val_dataset, batch_size=args.val_batch_size, 183 | sampler=val_sampler, **kwargs) 184 | 185 | if args.verbose: 186 | print('actual batch_size ', args.batch_size * args.batches_per_allreduce * hvd.size()) 187 | 188 | return train_sampler, train_loader, val_sampler, val_loader 189 | 190 | 191 | def get_model(args, num_steps_per_epoch): 192 | if args.model.lower() == 'resnet50': 193 | # model = models_local.resnet50() 194 | model = models.resnet50() 195 | else: 196 | raise ValueError('Unknown model \'{}\''.format(args.model)) 197 | 198 | if args.cuda: 199 | model.cuda() 200 | 201 | # Horovod: scale learning rate by the number of GPUs. 
202 | if args.lr_scaling.lower() == "linear": 203 | args.base_lr = args.base_lr * hvd.size() * args.batches_per_allreduce 204 | if args.lr_scaling.lower() == "sqrt": 205 | args.base_lr = math.sqrt(args.base_lr * hvd.size() * args.batches_per_allreduce) 206 | if args.lr_scaling.lower() == "keep": 207 | args.base_lr = args.base_lr 208 | if args.verbose: 209 | print('actual base_lr ', args.base_lr) 210 | 211 | if args.base_op.lower() == "lars": 212 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon, 213 | momentum=args.momentum, weight_decay=args.wd, 214 | bn_bias_separately=args.bn_bias_separately) 215 | elif args.base_op.lower() == "lamb": 216 | optimizer = create_lamb_optimizer(model=model, lr=args.base_lr, 217 | weight_decay=args.wd) 218 | else: 219 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr, 220 | momentum=args.momentum, weight_decay=args.wd) 221 | 222 | compression = hvd.Compression.fp16 if args.fp16_allreduce \ 223 | else hvd.Compression.none 224 | optimizer = hvd.DistributedOptimizer( 225 | optimizer, named_parameters=model.named_parameters(), 226 | compression=compression, op=hvd.Average, 227 | backward_passes_per_step=args.batches_per_allreduce) 228 | 229 | # Restore from a previous checkpoint, if initial_epoch is specified. 230 | # Horovod: restore on the first worker which will broadcast weights 231 | # to other workers. 232 | if args.resume_from_epoch > 0 and hvd.rank() == 0: 233 | filepath = args.checkpoint_format.format(epoch=args.resume_from_epoch) 234 | checkpoint = torch.load(filepath) 235 | model.load_state_dict(checkpoint['model']) 236 | optimizer.load_state_dict(checkpoint['optimizer']) 237 | 238 | # Horovod: broadcast parameters & optimizer state. 
239 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 240 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 241 | 242 | # lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay) 243 | # lr_scheduler = [LambdaLR(optimizer, lrs)] 244 | if args.base_op.lower() == "lars": 245 | lr_power = 2.0 246 | else: 247 | lr_power = 1.0 248 | 249 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch, 250 | warmup_steps=args.warmup_epochs * num_steps_per_epoch, 251 | end_lr=0.0, power=lr_power, last_epoch=-1) 252 | 253 | loss_func = LabelSmoothLoss(args.label_smoothing) 254 | 255 | return model, optimizer, lr_scheduler, loss_func 256 | 257 | 258 | def train(epoch, model, optimizer, lr_schedules, 259 | loss_func, train_sampler, train_loader, args): 260 | model.train() 261 | train_sampler.set_epoch(epoch) 262 | train_loss = Metric('train_loss') 263 | train_accuracy = Metric('train_accuracy') 264 | 265 | with tqdm(total=len(train_loader), 266 | desc='Epoch {:3d}/{:3d}'.format(epoch + 1, args.epochs), 267 | disable=not args.verbose) as t: 268 | for batch_idx, (data, target) in enumerate(train_loader): 269 | if args.cuda: 270 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 271 | optimizer.zero_grad() 272 | 273 | for i in range(0, len(data), args.batch_size): 274 | data_batch = data[i:i + args.batch_size] 275 | target_batch = target[i:i + args.batch_size] 276 | output = model(data_batch) 277 | 278 | loss = loss_func(output, target_batch) 279 | 280 | with torch.no_grad(): 281 | train_loss.update(loss) 282 | train_accuracy.update(accuracy(output, target_batch)) 283 | 284 | loss.div_(math.ceil(float(len(data)) / args.batch_size)) 285 | loss.backward() 286 | 287 | optimizer.synchronize() 288 | 289 | with optimizer.skip_synchronize(): 290 | optimizer.step() 291 | 292 | t.set_postfix_str("loss: {:.4f}, acc: {:.2f}%".format( 293 | train_loss.avg.item(), 100 * train_accuracy.avg.item())) 294 | 
t.update(1) 295 | 296 | lr_schedules.step() 297 | 298 | if args.verbose: 299 | print('') 300 | print('epoch ', epoch + 1, '/', args.epochs) 301 | print('train/loss ', train_loss.avg) 302 | print('train/accuracy ', train_accuracy.avg, '%') 303 | print('train/lr ', optimizer.param_groups[0]['lr']) 304 | 305 | if args.log_writer is not None: 306 | args.log_writer.add_scalar('train/loss', train_loss.avg, epoch) 307 | args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch) 308 | args.log_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch) 309 | 310 | 311 | def validate(epoch, model, loss_func, val_loader, args): 312 | model.eval() 313 | val_loss = Metric('val_loss') 314 | val_accuracy = Metric('val_accuracy') 315 | 316 | with tqdm(total=len(val_loader), 317 | # bar_format='{l_bar}{bar}|{postfix}', 318 | desc=' '.format(epoch + 1, args.epochs), 319 | disable=not args.verbose) as t: 320 | with torch.no_grad(): 321 | for i, (data, target) in enumerate(val_loader): 322 | if args.cuda: 323 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 324 | output = model(data) 325 | val_loss.update(loss_func(output, target)) 326 | val_accuracy.update(accuracy(output, target)) 327 | 328 | t.update(1) 329 | if i + 1 == len(val_loader): 330 | t.set_postfix_str("\b\b val_loss: {:.4f}, val_acc: {:.2f}%".format( 331 | val_loss.avg.item(), 100 * val_accuracy.avg.item()), 332 | refresh=False) 333 | if args.verbose: 334 | print('') 335 | print('val/loss ', val_loss.avg) 336 | print('val/accuracy ', val_accuracy.avg, '%') 337 | 338 | if args.log_writer is not None: 339 | args.log_writer.add_scalar('val/loss', val_loss.avg, epoch) 340 | args.log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch) 341 | 342 | 343 | if __name__ == '__main__': 344 | torch.multiprocessing.set_start_method('spawn') 345 | 346 | args = initialize() 347 | 348 | train_sampler, train_loader, _, val_loader = get_datasets(args) 349 | 350 | 
num_steps_per_epoch = len(train_loader) 351 | 352 | model, opt, lr_schedules, loss_func = get_model(args, num_steps_per_epoch) 353 | 354 | if args.verbose: 355 | print("MODEL:", args.model) 356 | 357 | start = time.time() 358 | 359 | for epoch in range(args.resume_from_epoch, args.epochs): 360 | train(epoch, model, opt, lr_schedules, 361 | loss_func, train_sampler, train_loader, args) 362 | validate(epoch, model, loss_func, val_loader, args) 363 | save_checkpoint(model, opt, args.checkpoint_format, epoch) 364 | 365 | if args.verbose: 366 | print("\nTraining time:", str(timedelta(seconds=time.time() - start))) 367 | -------------------------------------------------------------------------------- /pytorch_imagenet_resnet_dali.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import time 4 | from datetime import datetime, timedelta 5 | import argparse 6 | import os 7 | import math 8 | import warnings 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.utils.data.distributed 14 | from torchvision import datasets, transforms, models 15 | import horovod.torch as hvd 16 | from tqdm import tqdm 17 | from distutils.version import LooseVersion 18 | 19 | from nvidia.dali.pipeline import Pipeline 20 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy 21 | import nvidia.dali.fn as fn 22 | import nvidia.dali.types as types 23 | import nvidia.dali.tfrecord as tfrec 24 | import glob 25 | 26 | from utils import * 27 | from lars import * 28 | from lamb import * 29 | 30 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 31 | 32 | 33 | def initialize(): 34 | # Training settings 35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Example', 36 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 37 | 
parser.add_argument('--train-dir', default='/tmp/imagenet/ILSVRC2012_img_train/', 38 | help='path to training data') 39 | parser.add_argument('--val-dir', default='/tmp/imagenet/ILSVRC2012_img_val/', 40 | help='path to validation data') 41 | parser.add_argument('--data-dir', default='/tmp/imagenet/', 42 | help='path to data data') 43 | parser.add_argument('--log-dir', default='./logs/imagenet', 44 | help='tensorboard/checkpoint log directory') 45 | parser.add_argument('--checkpoint-format', default='checkpoint-{epoch}.pth.tar', 46 | help='checkpoint file format') 47 | parser.add_argument('--fp16-allreduce', action='store_true', default=False, 48 | help='use fp16 compression during allreduce') 49 | parser.add_argument('--batches-per-allreduce', type=int, default=1, 50 | help='number of batches processed locally before ' 51 | 'executing allreduce across workers; it multiplies ' 52 | 'total batch size.') 53 | 54 | # Default settings from https://arxiv.org/abs/1706.02677. 55 | parser.add_argument('--model', default='resnet50', 56 | help='Model (resnet35, resnet50, resnet101, resnet152, resnext50, resnext101)') 57 | parser.add_argument('--batch-size', type=int, default=32, 58 | help='input batch size for training') 59 | parser.add_argument('--val-batch-size', type=int, default=32, 60 | help='input batch size for validation') 61 | parser.add_argument('--epochs', type=int, default=90, 62 | help='number of epochs to train') 63 | parser.add_argument('--base-lr', type=float, default=0.0125, 64 | help='learning rate for a single GPU') 65 | parser.add_argument('--lr-decay', nargs='+', type=int, default=[30, 60, 80], 66 | help='epoch intervals to decay lr') 67 | parser.add_argument('--warmup-epochs', type=float, default=5, 68 | help='number of warmup epochs') 69 | parser.add_argument('--momentum', type=float, default=0.9, 70 | help='SGD momentum') 71 | parser.add_argument('--wd', type=float, default=0.00005, 72 | help='weight decay') 73 | parser.add_argument('--epsilon', 
type=float, default=1e-5, 74 | help='epsilon for optimizer') 75 | parser.add_argument('--label-smoothing', type=float, default=0.1, 76 | help='label smoothing (default 0.1)') 77 | parser.add_argument('--base-op', type=str, default='sgd', 78 | help='base optimizer name') 79 | parser.add_argument('--bn-bias-separately', action='store_true', default=False, 80 | help='skip bn and bias') 81 | parser.add_argument('--lr-scaling', type=str, default='keep', 82 | help='lr scaling method') 83 | 84 | parser.add_argument('--no-cuda', action='store_true', default=False, 85 | help='disables CUDA training') 86 | parser.add_argument('--single-threaded', action='store_true', default=False, 87 | help='disables multi-threaded dataloading') 88 | parser.add_argument('--seed', type=int, default=42, 89 | help='random seed') 90 | 91 | args = parser.parse_args() 92 | args.cuda = not args.no_cuda and torch.cuda.is_available() 93 | 94 | hvd.init() 95 | torch.manual_seed(args.seed) 96 | 97 | args.verbose = 1 if hvd.rank() == 0 else 0 98 | 99 | print('hvd.rank() ', hvd.rank()) 100 | 101 | if args.verbose: 102 | print(args) 103 | 104 | if args.cuda: 105 | torch.cuda.set_device(hvd.local_rank()) 106 | torch.cuda.manual_seed(args.seed) 107 | cudnn.deterministic = True 108 | cudnn.benchmark = False 109 | 110 | if args.bn_bias_separately: 111 | skip = "bn_bias" 112 | elif args.bn_separately: 113 | skip = "bn" 114 | else: 115 | skip = "no" 116 | 117 | args.log_dir = os.path.join(args.log_dir, 118 | "imagenet_{}_gpu_{}_{}_ebs{}_blr_{}_skip_{}_{}".format( 119 | args.model, hvd.size(), args.base_op, 120 | args.batch_size * hvd.size() * args.batches_per_allreduce, args.base_lr, skip, 121 | datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))) 122 | args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format) 123 | os.makedirs(args.log_dir, exist_ok=True) 124 | 125 | # If set > 0, will resume training from a given checkpoint. 
126 | args.resume_from_epoch = 0 127 | for try_epoch in range(args.epochs, 0, -1): 128 | if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)): 129 | args.resume_from_epoch = try_epoch 130 | break 131 | 132 | # Horovod: broadcast resume_from_epoch from rank 0 (which will have 133 | # checkpoints) to other ranks. 134 | args.resume_from_epoch = hvd.broadcast(torch.tensor(args.resume_from_epoch), 135 | root_rank=0, 136 | name='resume_from_epoch').item() 137 | 138 | # Horovod: write TensorBoard logs on first worker. 139 | try: 140 | if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'): 141 | from torch.utils.tensorboard import SummaryWriter 142 | else: 143 | from tensorboardX import SummaryWriter 144 | args.log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None 145 | except ImportError: 146 | args.log_writer = None 147 | 148 | return args 149 | 150 | 151 | def dali_dataloader( 152 | tfrec_filenames, 153 | tfrec_idx_filenames, 154 | shard_id=0, num_shards=1, 155 | batch_size=64, num_threads=2, 156 | image_size=224, num_workers=1, training=True): 157 | pipe = Pipeline(batch_size=batch_size, 158 | num_threads=num_threads, device_id=hvd.local_rank()) 159 | with pipe: 160 | inputs = fn.readers.tfrecord( 161 | path=tfrec_filenames, 162 | index_path=tfrec_idx_filenames, 163 | random_shuffle=training, 164 | shard_id=shard_id, 165 | num_shards=num_shards, 166 | initial_fill=10000, 167 | read_ahead=True, 168 | pad_last_batch=True, 169 | prefetch_queue_depth=num_workers, 170 | name='Reader', 171 | features={ 172 | 'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""), 173 | 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1), 174 | }) 175 | jpegs = inputs["image/encoded"] 176 | if training: 177 | images = fn.decoders.image_random_crop( 178 | jpegs, 179 | device="mixed", 180 | output_type=types.RGB, 181 | random_aspect_ratio=[0.8, 1.25], 182 | random_area=[0.1, 1.0], 183 | num_attempts=100) 184 | images = fn.resize(images, 185 
| device='gpu', 186 | resize_x=image_size, 187 | resize_y=image_size, 188 | interp_type=types.INTERP_TRIANGULAR) 189 | mirror = fn.random.coin_flip(probability=0.5) 190 | else: 191 | images = fn.decoders.image(jpegs, 192 | device='mixed', 193 | output_type=types.RGB) 194 | images = fn.resize(images, 195 | device='gpu', 196 | size=int(image_size / 0.875), 197 | mode="not_smaller", 198 | interp_type=types.INTERP_TRIANGULAR) 199 | mirror = False 200 | 201 | images = fn.crop_mirror_normalize( 202 | images.gpu(), 203 | dtype=types.FLOAT, 204 | crop=(image_size, image_size), 205 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 206 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255], 207 | mirror=mirror) 208 | label = inputs["image/class/label"] - 1 # 0-999 209 | label = fn.element_extract(label, element_map=0) # Flatten 210 | label = label.gpu() 211 | pipe.set_outputs(images, label) 212 | 213 | pipe.build() 214 | last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL 215 | loader = DALIClassificationIterator( 216 | pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy) 217 | return loader 218 | 219 | 220 | def get_datasets(args): 221 | num_shards = hvd.size() 222 | shard_id = hvd.rank() 223 | num_workers = 1 224 | num_threads = 2 225 | root = args.data_dir 226 | 227 | train_pat = os.path.join(root, 'train/*') 228 | train_idx_pat = os.path.join(root, 'idx_files/train/*') 229 | train_loader = dali_dataloader(sorted(glob.glob(train_pat)), 230 | sorted(glob.glob(train_idx_pat)), 231 | shard_id=shard_id, 232 | num_shards=num_shards, 233 | batch_size=args.batch_size * args.batches_per_allreduce, 234 | num_workers=num_workers, 235 | num_threads=num_threads, 236 | training=True) 237 | test_pat = os.path.join(root, 'validation/*') 238 | test_idx_pat = os.path.join(root, 'idx_files/validation/*') 239 | val_loader = dali_dataloader(sorted(glob.glob(test_pat)), 240 | sorted(glob.glob(test_idx_pat)), 241 | shard_id=shard_id, 242 | 
num_shards=num_shards, 243 | batch_size=args.val_batch_size, 244 | num_workers=num_workers, 245 | num_threads=num_threads, 246 | training=False) 247 | if args.verbose: 248 | print('actual batch_size ', args.batch_size * args.batches_per_allreduce * hvd.size()) 249 | 250 | return train_loader, val_loader 251 | 252 | 253 | def get_model(args, num_steps_per_epoch): 254 | if args.model.lower() == 'resnet50': 255 | # model = models_local.resnet50() 256 | model = models.resnet50() 257 | else: 258 | raise ValueError('Unknown model \'{}\''.format(args.model)) 259 | 260 | if args.cuda: 261 | model.cuda() 262 | 263 | # Horovod: scale learning rate by the number of GPUs. 264 | if args.lr_scaling.lower() == "linear": 265 | args.base_lr = args.base_lr * hvd.size() * args.batches_per_allreduce 266 | if args.lr_scaling.lower() == "sqrt": 267 | args.base_lr = math.sqrt(args.base_lr * hvd.size() * args.batches_per_allreduce) 268 | if args.lr_scaling.lower() == "keep": 269 | args.base_lr = args.base_lr 270 | if args.verbose: 271 | print('actual base_lr ', args.base_lr) 272 | 273 | if args.base_op.lower() == "lars": 274 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon, 275 | momentum=args.momentum, weight_decay=args.wd, 276 | bn_bias_separately=args.bn_bias_separately) 277 | elif args.base_op.lower() == "lamb": 278 | optimizer = create_lamb_optimizer(model=model, lr=args.base_lr, 279 | weight_decay=args.wd) 280 | else: 281 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr, 282 | momentum=args.momentum, weight_decay=args.wd) 283 | 284 | compression = hvd.Compression.fp16 if args.fp16_allreduce \ 285 | else hvd.Compression.none 286 | optimizer = hvd.DistributedOptimizer( 287 | optimizer, named_parameters=model.named_parameters(), 288 | compression=compression, op=hvd.Average, 289 | backward_passes_per_step=args.batches_per_allreduce) 290 | 291 | # Restore from a previous checkpoint, if initial_epoch is specified. 
292 | # Horovod: restore on the first worker which will broadcast weights 293 | # to other workers. 294 | if args.resume_from_epoch > 0 and hvd.rank() == 0: 295 | filepath = args.checkpoint_format.format(epoch=args.resume_from_epoch) 296 | checkpoint = torch.load(filepath) 297 | model.load_state_dict(checkpoint['model']) 298 | optimizer.load_state_dict(checkpoint['optimizer']) 299 | 300 | # Horovod: broadcast parameters & optimizer state. 301 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 302 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 303 | 304 | # lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay) 305 | # lr_scheduler = [LambdaLR(optimizer, lrs)] 306 | if args.base_op.lower() == "lars": 307 | lr_power = 2.0 308 | else: 309 | lr_power = 1.0 310 | 311 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch, 312 | warmup_steps=args.warmup_epochs * num_steps_per_epoch, 313 | end_lr=0.0, power=lr_power, last_epoch=-1) 314 | 315 | loss_func = LabelSmoothLoss(args.label_smoothing) 316 | 317 | return model, optimizer, lr_scheduler, loss_func 318 | 319 | 320 | def train(epoch, model, optimizer, lr_schedules, 321 | loss_func, train_loader, args): 322 | model.train() 323 | # train_sampler.set_epoch(epoch) 324 | train_loss = Metric('train_loss') 325 | train_accuracy = Metric('train_accuracy') 326 | 327 | with tqdm(total=len(train_loader), 328 | desc='Epoch {:3d}/{:3d}'.format(epoch + 1, args.epochs), 329 | disable=not args.verbose) as t: 330 | for batch_idx, data in enumerate(train_loader): 331 | input, target = data[0]['data'], data[0]['label'] 332 | # if args.cuda: 333 | # data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 334 | optimizer.zero_grad() 335 | 336 | for i in range(0, len(input), args.batch_size): 337 | data_batch = input[i:i + args.batch_size] 338 | target_batch = target[i:i + args.batch_size] 339 | output = model(data_batch) 340 | 341 | loss = 
loss_func(output, target_batch) 342 | loss = loss / args.batches_per_allreduce 343 | 344 | with torch.no_grad(): 345 | train_loss.update(loss) 346 | train_accuracy.update(accuracy(output, target_batch)) 347 | 348 | loss.backward() 349 | 350 | optimizer.synchronize() 351 | 352 | with optimizer.skip_synchronize(): 353 | optimizer.step() 354 | 355 | t.set_postfix_str("loss: {:.4f}, acc: {:.2f}%, lr: {:.4f}".format( 356 | train_loss.avg.item(), 100 * train_accuracy.avg.item(), optimizer.param_groups[0]['lr'])) 357 | t.update(1) 358 | 359 | lr_schedules.step() 360 | 361 | if args.verbose: 362 | print('') 363 | print('epoch ', epoch + 1, '/', args.epochs) 364 | print('train/loss ', train_loss.avg) 365 | print('train/accuracy ', train_accuracy.avg, '%') 366 | print('train/lr ', optimizer.param_groups[0]['lr']) 367 | 368 | if args.log_writer is not None: 369 | args.log_writer.add_scalar('train/loss', train_loss.avg, epoch) 370 | args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch) 371 | args.log_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch) 372 | 373 | 374 | def validate(epoch, model, loss_func, val_loader, args): 375 | model.eval() 376 | val_loss = Metric('val_loss') 377 | val_accuracy = Metric('val_accuracy') 378 | 379 | with tqdm(total=len(val_loader), 380 | # bar_format='{l_bar}{bar}|{postfix}', 381 | desc=' '.format(epoch + 1, args.epochs), 382 | disable=not args.verbose) as t: 383 | with torch.no_grad(): 384 | for i, data in enumerate(val_loader): 385 | input, target = data[0]['data'], data[0]['label'] 386 | # if args.cuda: 387 | # data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 388 | output = model(input) 389 | val_loss.update(loss_func(output, target)) 390 | val_accuracy.update(accuracy(output, target)) 391 | 392 | t.update(1) 393 | if i + 1 == len(val_loader): 394 | t.set_postfix_str("\b\b val_loss: {:.4f}, val_acc: {:.2f}%".format( 395 | val_loss.avg.item(), 100 * val_accuracy.avg.item()), 396 
| refresh=False) 397 | if args.verbose: 398 | print('') 399 | print('val/loss ', val_loss.avg) 400 | print('val/accuracy ', val_accuracy.avg, '%') 401 | 402 | if args.log_writer is not None: 403 | args.log_writer.add_scalar('val/loss', val_loss.avg, epoch) 404 | args.log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch) 405 | 406 | 407 | if __name__ == '__main__': 408 | 409 | args = initialize() 410 | 411 | if args.single_threaded: 412 | print('Not use torch.multiprocessing.set_start_method') 413 | else: 414 | # torch.multiprocessing.set_start_method('spawn') 415 | # torch.multiprocessing.set_start_method('forkserver') 416 | torch.multiprocessing.set_start_method('spawn') 417 | 418 | train_loader, val_loader = get_datasets(args=args) 419 | 420 | num_steps_per_epoch = len(train_loader) 421 | 422 | model, opt, lr_schedules, loss_func = get_model(args, num_steps_per_epoch) 423 | 424 | if args.verbose: 425 | print("MODEL:", args.model) 426 | 427 | start = time.time() 428 | 429 | for epoch in range(args.resume_from_epoch, args.epochs): 430 | train(epoch=epoch, model=model, optimizer=opt, lr_schedules=lr_schedules, 431 | loss_func=loss_func, train_loader=train_loader, args=args) 432 | validate(epoch, model, loss_func, val_loader, args) 433 | save_checkpoint(model, opt, args.checkpoint_format, epoch) 434 | 435 | if args.verbose: 436 | print("\nTraining time:", str(timedelta(seconds=time.time() - start))) 437 | --------------------------------------------------------------------------------