├── __init__.py
├── Training-process-with-TensorBoard.jpg
├── .idea
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── lars-imagenet-pytorch.iml
│   ├── deployment.xml
│   └── workspace.xml
├── LICENSE
├── utils.py
├── README.md
├── lars.py
├── lamb.py
├── pytorch_imagenet_resnet.py
└── pytorch_imagenet_resnet_dali.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Training-process-with-TensorBoard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NUS-HPC-AI-Lab/LARS-ImageNet-PyTorch/HEAD/Training-process-with-TensorBoard.jpg
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(JetBrains IDE inspection-profile settings; XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(JetBrains IDE configuration file; XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(JetBrains IDE module configuration; XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/lars-imagenet-pytorch.iml:
--------------------------------------------------------------------------------
(JetBrains IDE module file; XML content not preserved in this export)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 binmakeswell
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
(JetBrains IDE deployment configuration; XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
(JetBrains IDE workspace state; XML content not preserved in this export)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import horovod.torch as hvd
4 | from torch.optim.lr_scheduler import _LRScheduler
5 |
6 |
7 | def accuracy(output, target):
8 | # get the index of the max log-probability
9 | pred = output.max(1, keepdim=True)[1]
10 | return pred.eq(target.view_as(pred)).cpu().float().mean()
11 |
12 | def save_checkpoint(model, optimizer, checkpoint_format, epoch):
13 | if hvd.rank() == 0:
14 | filepath = checkpoint_format.format(epoch=epoch + 1)
15 | state = {
16 | 'model': model.state_dict(),
17 | 'optimizer': optimizer.state_dict(),
18 | }
19 | torch.save(state, filepath)
20 |
21 | class LabelSmoothLoss(torch.nn.Module):
22 |
23 | def __init__(self, smoothing=0.0):
24 | super(LabelSmoothLoss, self).__init__()
25 | self.smoothing = smoothing
26 |
27 | def forward(self, input, target):
28 | log_prob = F.log_softmax(input, dim=-1)
29 | weight = input.new_ones(input.size()) * \
30 | self.smoothing / (input.size(-1) - 1.)
31 | weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
32 | loss = (-weight * log_prob).sum(dim=-1).mean()
33 | return loss
34 |
35 | def metric_average(val_tensor):
36 | avg_tensor = hvd.allreduce(val_tensor)
37 | return avg_tensor.item()
38 |
39 | # Horovod: average metrics from distributed training.
40 | class Metric(object):
41 | def __init__(self, name):
42 | self.name = name
43 | self.sum = torch.tensor(0.)
44 | self.n = torch.tensor(0.)
45 |
46 | def update(self, val, n=1):
47 | self.sum += float(hvd.allreduce(val.detach().cpu(), name=self.name))
48 | self.n += n
49 |
50 | @property
51 | def avg(self):
52 | return self.sum / self.n
53 |
54 | def create_lr_schedule(workers, warmup_epochs, decay_schedule, alpha=0.1):
55 | def lr_schedule(epoch):
56 | lr_adj = 1.
57 | if epoch < warmup_epochs:
58 | lr_adj = 1. / workers * (epoch * (workers - 1) / warmup_epochs + 1)
59 | else:
60 | decay_schedule.sort(reverse=True)
61 | for e in decay_schedule:
62 | if epoch >= e:
63 | lr_adj *= alpha
64 | return lr_adj
65 | return lr_schedule
66 |
67 |
68 |
69 | class PolynomialDecay(_LRScheduler):
70 | def __init__(self, optimizer, decay_steps, end_lr=0.0001, power=1.0, last_epoch=-1):
71 | self.decay_steps = decay_steps
72 | self.end_lr = end_lr
73 | self.power = power
74 | super().__init__(optimizer, last_epoch)
75 |
76 | def get_lr(self):
77 | return self._get_closed_form_lr()
78 |
79 | def _get_closed_form_lr(self):
80 | return [
81 | (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.decay_steps) /
82 | self.decay_steps) ** self.power) + self.end_lr
83 | for base_lr in self.base_lrs
84 | ]
85 |
86 | class WarmupScheduler(_LRScheduler):
87 | def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
88 | self.warmup_epochs = warmup_epochs
89 | self.after_scheduler = after_scheduler
90 | self.finished = False
91 | super().__init__(optimizer, last_epoch)
92 |
93 | def get_lr(self):
94 | if self.last_epoch >= self.warmup_epochs:
95 | if not self.finished:
96 | self.after_scheduler.base_lrs = self.base_lrs
97 | self.finished = True
98 | return self.after_scheduler.get_lr()
99 |
100 | return [self.last_epoch / self.warmup_epochs * lr for lr in self.base_lrs]
101 |
102 | def step(self, epoch=None):
103 | if self.finished:
104 | if epoch is None:
105 | self.after_scheduler.step(None)
106 | else:
107 | self.after_scheduler.step(epoch - self.warmup_epochs)
108 | else:
109 | return super().step(epoch)
110 |
111 |
112 | class PolynomialWarmup(WarmupScheduler):
113 | def __init__(self, optimizer, decay_steps, warmup_steps=0, end_lr=0.0001, power=1.0, last_epoch=-1):
114 | base_scheduler = PolynomialDecay(
115 | optimizer, decay_steps - warmup_steps, end_lr=end_lr, power=power, last_epoch=last_epoch)
116 | super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
117 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation of LARS for ImageNet with PyTorch
2 |
3 | This is the code for the paper "[Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888)": a PyTorch implementation of the large-batch optimizer LARS. Although the optimizer has been available for some time and has an official TensorFlow implementation, to our knowledge there was no reliable PyTorch implementation, so we provide one here. We use [Horovod](https://github.com/horovod/horovod) for distributed data-parallel training, and we offer gradient accumulation and the NVIDIA DALI dataloader as options.
4 |
5 | ## Requirements
6 |
7 | This code has been validated with Python 3.6.10, PyTorch 1.5.0, Horovod 0.21.1, CUDA 10.0/10.1, cuDNN 7.6.4, and NCCL 2.4.7.
8 |
9 | ## Performance on ImageNet
10 |
11 | We verified the implementation on the complete ImageNet-1K (ILSVRC2012) dataset. The parameters and performance are as follows.
12 |
13 | | Effective Batch Size | Local Batch Size | Base LR | Warmup Epochs | Epsilon | Val Accuracy | TensorBoard Color |
14 | | :------------------: | :--------------: | :-------------: | :--------------: | :-----: | :----------: | :---------------: |
15 | | 512 | 64 | 2<sup>2</sup> | 10/2<sup>6</sup> | 1e-5 | **77.02%** | Light blue |
16 | | 1024 | 128 | 2<sup>2.5</sup> | 10/2<sup>5</sup> | 1e-5 | **76.96%** | Brown |
17 | | 4096 | 128 | 2<sup>3.5</sup> | 10/2<sup>3</sup> | 1e-5 | **77.38%** | Orange |
18 | | 8192 | 128 | 2<sup>4</sup> | 10/2<sup>2</sup> | 1e-5 | **77.14%** | Deep blue |
19 | | 16384 | 128 | 2<sup>4.5</sup> | 5 | 1e-5 | **76.96%** | Pink |
20 | | 32768 | 64 | 2<sup>5</sup> | 14 | 0.0 | **76.75%** | Green |
21 |
22 | Training process with TensorBoard
23 |
24 | 
25 |
26 | We set epochs = 90, weight decay = 0.0001, and model = resnet50, and use NVIDIA Tesla V100/P100 GPUs for all experiments. We did not fine-tune the hyperparameters; you may get better performance with other settings.
27 |
28 | We thank the National Supercomputing Centre (NSCC) Singapore, the Texas Advanced Computing Center (TACC), and the Swiss National Supercomputing Centre (CSCS) for providing computing resources.
29 |
30 | ## Usage
31 |
32 | ```
33 | from lars import *
34 | ...
35 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon,
36 | momentum=args.momentum, weight_decay=args.wd,
37 | bn_bias_separately=args.bn_bias_separately)
38 | ...
39 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch,
40 | warmup_steps=args.warmup_epochs * num_steps_per_epoch,
41 | end_lr=0.0, power=lr_power, last_epoch=-1)
42 | ...
43 | ```
44 |
45 | Note that we recommend using `create_optimizer_lars` with `bn_bias_separately=True` rather than instantiating the `Lars` class directly; this lets LARS skip the BatchNorm and bias parameters, which generally gives better performance. Polynomial warmup learning-rate decay also tends to help.
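
A minimal end-to-end sketch of this recommended setup, following `pytorch_imagenet_resnet.py`; the model, the tiny random data loader, and the hyperparameter values are only placeholders:

```
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models

from lars import create_optimizer_lars
from utils import PolynomialWarmup, LabelSmoothLoss

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.resnet50().to(device)

# Tiny random stand-in for an ImageNet loader, so the sketch is self-contained.
train_loader = DataLoader(TensorDataset(torch.randn(8, 3, 224, 224),
                                        torch.randint(0, 1000, (8,))),
                          batch_size=4)

epochs, warmup_epochs = 90, 5
num_steps_per_epoch = len(train_loader)

optimizer = create_optimizer_lars(model=model, lr=5.657, epsilon=1e-5,
                                  momentum=0.9, weight_decay=1e-4,
                                  bn_bias_separately=True)  # skip BN and bias in LARS
lr_scheduler = PolynomialWarmup(optimizer,
                                decay_steps=epochs * num_steps_per_epoch,
                                warmup_steps=warmup_epochs * num_steps_per_epoch,
                                end_lr=0.0, power=2.0, last_epoch=-1)
loss_func = LabelSmoothLoss(0.1)

for epoch in range(epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_func(model(images), labels)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # stepped once per batch, as in the training scripts
```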
46 |
47 | ## Example Scripts
48 |
49 | An example script for training on ImageNet-1K with 8 GPUs and an effective batch size of 1024 is provided below.
50 |
51 | ```
52 | $ mpirun -np 8 \
53 | python pytorch_imagenet_resnet.py \
54 | --batch-size 128 \
55 | --warmup-epochs 0.3125 \
56 | --train-dir=your path/ImageNet/train/ \
57 | --val-dir=your path/ImageNet/val \
58 | --base-lr 5.6568542494924 \
59 | --base-op lars \
60 | --bn-bias-separately \
61 | --wd 0.0001 \
62 | --lr-scaling keep
63 | ```
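
For reference, here is our reading of the flag values above, assuming they follow the square-root scaling used in the performance table (this is an interpretation, not part of the original script):

```
# 8 workers x 128 per-GPU batch gives the effective batch size
print(8 * 128)      # 1024
# --base-lr matches the 2^2.5 entry for effective batch size 1024
print(2 ** 2.5)     # 5.656854249492381
# --warmup-epochs matches 10 / 2^5
print(10 / 2 ** 5)  # 0.3125
```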
64 |
65 | ## Additional Options
66 |
67 | **Accumulated gradient** When GPUs are scarce, gradient accumulation can simulate a larger effective batch size with a limited number of GPUs, although it may extend the running time to some extent. To use it, add `--batches-per-allreduce N` to the command above, where N is the scale factor. For example, setting N = 4 simulates an effective batch size of 4096 using only 8 GPUs.
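
A rough sketch of what this option does on each worker, following the inner loop of `train()` in `pytorch_imagenet_resnet.py` (Horovod's cross-worker allreduce is omitted here, and the sizes are toy placeholders):

```
import math
import torch
from torchvision import models

batch_size = 4             # stands in for --batch-size (what fits in GPU memory)
batches_per_allreduce = 2  # stands in for --batches-per-allreduce N

model = models.resnet50()
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The loader hands over batch_size * batches_per_allreduce samples at once
# (random stand-in data here); gradients from the chunks accumulate before
# a single optimizer step, which simulates the larger batch.
data = torch.randn(batch_size * batches_per_allreduce, 3, 224, 224)
target = torch.randint(0, 1000, (batch_size * batches_per_allreduce,))

optimizer.zero_grad()
for i in range(0, len(data), batch_size):
    output = model(data[i:i + batch_size])
    loss = loss_func(output, target[i:i + batch_size])
    loss = loss / math.ceil(len(data) / batch_size)  # average over the chunks
    loss.backward()
optimizer.step()
```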
68 |
69 | **DALI dataloader** NVIDIA DALI can accelerate data loading and pre-processing on the GPU rather than the CPU, at the cost of some GPU memory. It can also avoid potential conflicts between MPI libraries and Horovod on some GPU clusters. To use it, run `pytorch_imagenet_resnet_dali.py` with `--data-dir` instead of `--train-dir`/`--val-dir`. The `--data-dir` path must contain ImageNet-1K data in **TFRecord format** with the following structure:
70 |
71 | ```
72 | train-recs 'path/train/*'
73 | val-recs 'path/validation/*'
74 | train-idx 'path/idx_files/train/*'
75 | val-idx 'path/idx_files/validation/*'
76 | ```
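
Given that layout, the loaders are built as in `get_datasets()` of `pytorch_imagenet_resnet_dali.py`; a condensed sketch (the root path is a placeholder, and Horovod supplies the shard id and count):

```
import glob
import os

import horovod.torch as hvd
from pytorch_imagenet_resnet_dali import dali_dataloader

hvd.init()
root = '/path/to/imagenet_tfrecord'  # placeholder for --data-dir

train_loader = dali_dataloader(sorted(glob.glob(os.path.join(root, 'train/*'))),
                               sorted(glob.glob(os.path.join(root, 'idx_files/train/*'))),
                               shard_id=hvd.rank(), num_shards=hvd.size(),
                               batch_size=128, training=True)
val_loader = dali_dataloader(sorted(glob.glob(os.path.join(root, 'validation/*'))),
                             sorted(glob.glob(os.path.join(root, 'idx_files/validation/*'))),
                             shard_id=hvd.rank(), num_shards=hvd.size(),
                             batch_size=32, training=False)

for data in train_loader:
    images, labels = data[0]['data'], data[0]['label']  # DALI yields one dict per pipeline
    break  # feed images/labels into the usual training step
```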
77 |
79 |
80 | ## Reference
81 |
82 | [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888)
83 |
84 | [Large-Batch Training for LSTM and Beyond](https://arxiv.org/abs/1901.08256)
85 |
86 | https://www.comp.nus.edu.sg/~youy/lars_optimizer.py
87 |
88 | https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/lars_optimizer.py
89 |
--------------------------------------------------------------------------------
/lars.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | from typing import Dict, Iterable, Optional, Callable, Tuple
4 | from torch import nn
5 |
6 | """
7 | We recommend using create_optimizer_lars and setting bn_bias_separately=True
8 | instead of using class Lars directly, which helps LARS skip parameters
9 | in BatchNormalization and bias, and has better performance in general.
10 | Polynomial Warmup learning rate decay is also helpful for better performance in general.
11 | """
12 |
13 |
14 | def create_optimizer_lars(model, lr, momentum, weight_decay, bn_bias_separately, epsilon):
15 | if bn_bias_separately:
16 | optimizer = Lars([
17 | dict(params=get_common_parameters(model, exclude_func=get_norm_bias_parameters)),
18 | dict(params=get_norm_bias_parameters(model), weight_decay=0, lars=False)],
19 | lr=lr,
20 | momentum=momentum,
21 | weight_decay=weight_decay,
22 | epsilon=epsilon)
23 | else:
24 | optimizer = Lars(model.parameters(),
25 | lr=lr,
26 | momentum=momentum,
27 | weight_decay=weight_decay,
28 | epsilon=epsilon)
29 | return optimizer
30 |
31 |
32 | class Lars(Optimizer):
33 | r"""Implements the LARS optimizer from `"Large batch training of convolutional networks"
34 | `_.
35 |
36 | Args:
37 | params (iterable): iterable of parameters to optimize or dicts defining
38 | parameter groups
39 | lr (float, optional): learning rate
40 | momentum (float, optional): momentum factor (default: 0)
41 | eeta (float, optional): LARS coefficient as used in the paper (default: 1e-3)
42 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
43 | """
44 |
45 | def __init__(
46 | self,
47 | params: Iterable[torch.nn.Parameter],
48 | lr=1e-3,
49 | momentum=0,
50 | eeta=1e-3,
51 | weight_decay=0,
52 | epsilon=0.0
53 | ) -> None:
54 | if not isinstance(lr, float) or lr < 0.0:
55 | raise ValueError("Invalid learning rate: {}".format(lr))
56 | if momentum < 0.0:
57 | raise ValueError("Invalid momentum value: {}".format(momentum))
58 | if weight_decay < 0.0:
59 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
60 | if eeta <= 0 or eeta > 1:
61 | raise ValueError("Invalid eeta value: {}".format(eeta))
62 | if epsilon < 0:
63 | raise ValueError("Invalid epsilon value: {}".format(epsilon))
64 | defaults = dict(lr=lr, momentum=momentum,
65 | weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
66 |
67 | super().__init__(params, defaults)
68 |
69 | @torch.no_grad()
70 | def step(self, closure=None):
71 | """Performs a single optimization step.
72 | Arguments:
73 | closure (callable, optional): A closure that reevaluates the model
74 | and returns the loss.
75 | """
76 | loss = None
77 | if closure is not None:
78 | with torch.enable_grad():
79 | loss = closure()
80 |
81 | for group in self.param_groups:
82 | weight_decay = group['weight_decay']
83 | momentum = group['momentum']
84 | eeta = group['eeta']
85 | lr = group['lr']
86 | lars = group['lars']
87 | eps = group['epsilon']
88 |
89 | for p in group['params']:
90 | if p.grad is None:
91 | continue
92 | decayed_grad = p.grad
93 | scaled_lr = lr
94 | if lars:
95 | w_norm = torch.norm(p)
96 | g_norm = torch.norm(p.grad)
97 | trust_ratio = torch.where(
98 | (w_norm > 0) & (g_norm > 0),
99 | eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
100 | torch.ones_like(w_norm)
101 | )
102 | trust_ratio.clamp_(0.0, 50)
103 | scaled_lr *= trust_ratio.item()
104 | if weight_decay != 0:
105 | decayed_grad = decayed_grad.add(p, alpha=weight_decay)
106 | decayed_grad = torch.clamp(decayed_grad, -10.0, 10.0)
107 |
108 | if momentum != 0:
109 | param_state = self.state[p]
110 | if 'momentum_buffer' not in param_state:
111 | buf = param_state['momentum_buffer'] = torch.clone(
112 | decayed_grad).detach()
113 | else:
114 | buf = param_state['momentum_buffer']
115 | buf.mul_(momentum).add_(decayed_grad)
116 | decayed_grad = buf
117 |
118 | p.add_(decayed_grad, alpha=-scaled_lr)
119 |
120 | return loss
121 |
122 |
123 | """
124 | Functions which help to skip bias and BatchNorm
125 | """
126 | BN_CLS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
127 |
128 |
129 | def get_parameters_from_cls(module, cls_):
130 | def get_members_fn(m):
131 | if isinstance(m, cls_):
132 | return m._parameters.items()
133 | else:
134 | return dict()
135 |
136 | named_parameters = module._named_members(get_members_fn=get_members_fn)
137 | for name, param in named_parameters:
138 | yield param
139 |
140 |
141 | def get_norm_parameters(module):
142 | return get_parameters_from_cls(module, (nn.LayerNorm, *BN_CLS))
143 |
144 |
145 | def get_bias_parameters(module, exclude_func=None):
146 | excluded_parameters = set()
147 | if exclude_func is not None:
148 | for param in exclude_func(module):
149 | excluded_parameters.add(param)
150 | for name, param in module.named_parameters():
151 | if param not in excluded_parameters and 'bias' in name:
152 | yield param
153 |
154 |
155 | def get_norm_bias_parameters(module):
156 | for param in get_norm_parameters(module):
157 | yield param
158 | for param in get_bias_parameters(module, exclude_func=get_norm_parameters):
159 | yield param
160 |
161 |
162 | def get_common_parameters(module, exclude_func=None):
163 | excluded_parameters = set()
164 | if exclude_func is not None:
165 | for param in exclude_func(module):
166 | excluded_parameters.add(param)
167 | for name, param in module.named_parameters():
168 | if param not in excluded_parameters:
169 | yield param
170 |
--------------------------------------------------------------------------------
/lamb.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 | from torch import nn
5 |
6 |
7 | class LAMB(Optimizer):
8 | r"""Implements Lamb algorithm.
9 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
10 | Arguments:
11 | params (iterable): iterable of parameters to optimize or dicts defining
12 | parameter groups
13 | lr (float, optional): learning rate (default: 1e-3)
14 | betas (Tuple[float, float], optional): coefficients used for computing
15 | running averages of gradient and its square (default: (0.9, 0.999))
16 | eps (float, optional): term added to the denominator to improve
17 | numerical stability (default: 1e-8)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | adam (bool, optional): always use trust ratio = 1, which turns this into
20 | Adam. Useful for comparison purposes.
21 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
22 | https://arxiv.org/abs/1904.00962
23 | """
24 |
25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
26 | weight_decay=0):
27 | if not 0.0 <= lr:
28 | raise ValueError("Invalid learning rate: {}".format(lr))
29 | if not 0.0 <= eps:
30 | raise ValueError("Invalid epsilon value: {}".format(eps))
31 | if not 0.0 <= betas[0] < 1.0:
32 | raise ValueError(
33 | "Invalid beta parameter at index 0: {}".format(betas[0]))
34 | if not 0.0 <= betas[1] < 1.0:
35 | raise ValueError(
36 | "Invalid beta parameter at index 1: {}".format(betas[1]))
37 | defaults = dict(lr=lr, betas=betas, eps=eps,
38 | weight_decay=weight_decay)
39 | super().__init__(params, defaults)
40 |
41 | @torch.no_grad()
42 | def step(self, closure=None):
43 | """Performs a single optimization step.
44 | Arguments:
45 | closure (callable, optional): A closure that reevaluates the model
46 | and returns the loss.
47 | """
48 | loss = None
49 | if closure is not None:
50 | loss = closure()
51 |
52 | torch.nn.utils.clip_grad_norm_(
53 | parameters=[
54 | p for group in self.param_groups for p in group['params']],
55 | max_norm=1.0,
56 | norm_type=2
57 | )
58 |
59 | for group in self.param_groups:
60 | for p in group['params']:
61 | if p.grad is None:
62 | continue
63 | grad = p.grad.data
64 | if grad.is_sparse:
65 | raise RuntimeError(
66 | 'Lamb does not support sparse gradients, consider SparseAdam instead.')
67 |
68 | state = self.state[p]
69 |
70 | # State initialization
71 | if len(state) == 0:
72 | state['step'] = 0
73 | # Exponential moving average of gradient values
74 | state['exp_avg'] = torch.zeros_like(p.data)
75 | # Exponential moving average of squared gradient values
76 | state['exp_avg_sq'] = torch.zeros_like(p.data)
77 |
78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
79 | beta1, beta2 = group['betas']
80 |
81 | state['step'] += 1
82 |
83 | # Decay the first and second moment running average coefficient
84 | # m_t
85 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
86 | # v_t
87 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
88 |
89 | # Paper v3 does not use debiasing.
90 | # bias_correction1 = 1 - beta1 ** state['step']
91 | # bias_correction2 = 1 - beta2 ** state['step']
92 | # Apply bias to lr to avoid broadcast.
93 | # * math.sqrt(bias_correction2) / bias_correction1
94 | scaled_lr = group['lr']
95 | update = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
96 | if group['weight_decay'] != 0:
97 | update.add_(p.data, alpha=group['weight_decay'])
98 | w_norm = torch.norm(p)
99 | g_norm = torch.norm(update)
100 | trust_ratio = torch.where(
101 | (w_norm > 0) & (g_norm > 0),
102 | w_norm / g_norm,
103 | torch.ones_like(w_norm)
104 | )
105 | scaled_lr *= trust_ratio.item()
106 |
107 | p.data.add_(update, alpha=-scaled_lr)
108 |
109 | return loss
110 |
111 |
112 | def create_lamb_optimizer(model, lr, betas=(0.9, 0.999), eps=1e-6,
113 | weight_decay=0, exclude_layers=['bn', 'ln', 'bias']):
114 | # can only exclude BatchNorm, LayerNorm, bias layers
115 | # ['bn', 'ln'] will exclude BatchNorm, LayerNorm layers
116 | # ['bn', 'ln', 'bias'] will exclude BatchNorm, LayerNorm, bias layers
117 | # [] will not exclude any layers
118 | if 'bias' in exclude_layers:
119 | params = [
120 | dict(params=get_common_parameters(
121 | model, exclude_func=get_norm_bias_parameters)),
122 | dict(params=get_norm_bias_parameters(model), weight_decay=0)
123 | ]
124 | elif len(exclude_layers) > 0:
125 | params = [
126 | dict(params=get_common_parameters(
127 | model, exclude_func=get_norm_parameters)),
128 | dict(params=get_norm_parameters(model), weight_decay=0)
129 | ]
130 | else:
131 | params = model.parameters()
132 | optimizer = LAMB(params, lr, betas=betas, eps=eps,
133 | weight_decay=weight_decay)
134 | return optimizer
135 |
136 |
137 | BN_CLS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
138 |
139 |
140 | def get_parameters_from_cls(module, cls_):
141 | def get_members_fn(m):
142 | if isinstance(m, cls_):
143 | return m._parameters.items()
144 | else:
145 | return dict()
146 | named_parameters = module._named_members(get_members_fn=get_members_fn)
147 | for name, param in named_parameters:
148 | yield param
149 |
150 |
151 | def get_bn_parameters(module):
152 | return get_parameters_from_cls(module, BN_CLS)
153 |
154 |
155 | def get_ln_parameters(module):
156 | return get_parameters_from_cls(module, nn.LayerNorm)
157 |
158 |
159 | def get_norm_parameters(module):
160 | return get_parameters_from_cls(module, (nn.LayerNorm, *BN_CLS))
161 |
162 |
163 | def get_bias_parameters(module, exclude_func=None):
164 | excluded_parameters = set()
165 | if exclude_func is not None:
166 | for param in exclude_func(module):
167 | excluded_parameters.add(param)
168 | for name, param in module.named_parameters():
169 | if param not in excluded_parameters and 'bias' in name:
170 | yield param
171 |
172 |
173 | def get_norm_bias_parameters(module):
174 | for param in get_norm_parameters(module):
175 | yield param
176 | for param in get_bias_parameters(module, exclude_func=get_norm_parameters):
177 | yield param
178 |
179 |
180 | def get_common_parameters(module, exclude_func=None):
181 | excluded_parameters = set()
182 | if exclude_func is not None:
183 | for param in exclude_func(module):
184 | excluded_parameters.add(param)
185 | for name, param in module.named_parameters():
186 | if param not in excluded_parameters:
187 | yield param
--------------------------------------------------------------------------------
/pytorch_imagenet_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import time
4 | from datetime import datetime, timedelta
5 | import argparse
6 | import os
7 | import math
8 | import warnings
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.utils.data.distributed
14 | from torchvision import datasets, transforms, models
15 | import horovod.torch as hvd
16 | from tqdm import tqdm
17 | from distutils.version import LooseVersion
18 |
19 | from utils import *
20 | from lars import *
21 | from lamb import *
22 |
23 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
24 |
25 |
26 | def initialize():
27 | # Training settings
28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
30 | parser.add_argument('--train-dir', default='/tmp/imagenet/ILSVRC2012_img_train/',
31 | help='path to training data')
32 | parser.add_argument('--val-dir', default='/tmp/imagenet/ILSVRC2012_img_val/',
33 | help='path to validation data')
34 | parser.add_argument('--log-dir', default='./logs/imagenet',
35 | help='tensorboard/checkpoint log directory')
36 | parser.add_argument('--checkpoint-format', default='checkpoint-{epoch}.pth.tar',
37 | help='checkpoint file format')
38 | parser.add_argument('--fp16-allreduce', action='store_true', default=False,
39 | help='use fp16 compression during allreduce')
40 | parser.add_argument('--batches-per-allreduce', type=int, default=1,
41 | help='number of batches processed locally before '
42 | 'executing allreduce across workers; it multiplies '
43 | 'total batch size.')
44 |
45 | # Default settings from https://arxiv.org/abs/1706.02677.
46 | parser.add_argument('--model', default='resnet50',
47 | help='Model (resnet35, resnet50, resnet101, resnet152, resnext50, resnext101)')
48 | parser.add_argument('--batch-size', type=int, default=32,
49 | help='input batch size for training')
50 | parser.add_argument('--val-batch-size', type=int, default=32,
51 | help='input batch size for validation')
52 | parser.add_argument('--epochs', type=int, default=90,
53 | help='number of epochs to train')
54 | parser.add_argument('--base-lr', type=float, default=0.0125,
55 | help='learning rate for a single GPU')
56 | parser.add_argument('--lr-decay', nargs='+', type=int, default=[30, 60, 80],
57 | help='epoch intervals to decay lr')
58 | parser.add_argument('--warmup-epochs', type=float, default=5,
59 | help='number of warmup epochs')
60 | parser.add_argument('--momentum', type=float, default=0.9,
61 | help='SGD momentum')
62 | parser.add_argument('--wd', type=float, default=0.00005,
63 | help='weight decay')
64 | parser.add_argument('--epsilon', type=float, default=1e-5,
65 | help='epsilon for optimizer')
66 | parser.add_argument('--label-smoothing', type=float, default=0.1,
67 | help='label smoothing (default 0.1)')
68 | parser.add_argument('--base-op', type=str, default='sgd',
69 | help='base optimizer name')
70 | parser.add_argument('--bn-bias-separately', action='store_true', default=False,
71 | help='skip bn and bias')
72 | parser.add_argument('--lr-scaling', type=str, default='keep',
73 | help='lr scaling method')
74 |
75 | parser.add_argument('--no-cuda', action='store_true', default=False,
76 | help='disables CUDA training')
77 | parser.add_argument('--single-threaded', action='store_true', default=False,
78 | help='disables multi-threaded dataloading')
79 | parser.add_argument('--seed', type=int, default=42,
80 | help='random seed')
81 |
82 | args = parser.parse_args()
83 | args.cuda = not args.no_cuda and torch.cuda.is_available()
84 |
85 | hvd.init()
86 | torch.manual_seed(args.seed)
87 |
88 | args.verbose = 1 if hvd.rank() == 0 else 0
89 |
90 | print('hvd.rank() ', hvd.rank())
91 |
92 | if args.verbose:
93 | print(args)
94 |
95 | if args.cuda:
96 | torch.cuda.set_device(hvd.local_rank())
97 | torch.cuda.manual_seed(args.seed)
98 |
99 | cudnn.benchmark = True
100 |
101 | if args.bn_bias_separately:
102 | skip = "bn_bias"
103 | elif getattr(args, 'bn_separately', False):  # this flag is not defined by the parser above
104 | skip = "bn"
105 | else:
106 | skip = "no"
107 |
108 | args.log_dir = os.path.join(args.log_dir,
109 | "imagenet_{}_gpu_{}_{}_ebs{}_blr_{}_skip_{}_{}".format(
110 | args.model, hvd.size(), args.base_op,
111 | args.batch_size * hvd.size() * args.batches_per_allreduce, args.base_lr, skip,
112 | datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
113 | args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format)
114 | os.makedirs(args.log_dir, exist_ok=True)
115 |
116 | # If set > 0, will resume training from a given checkpoint.
117 | args.resume_from_epoch = 0
118 | for try_epoch in range(args.epochs, 0, -1):
119 | if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
120 | args.resume_from_epoch = try_epoch
121 | break
122 |
123 | # Horovod: broadcast resume_from_epoch from rank 0 (which will have
124 | # checkpoints) to other ranks.
125 | args.resume_from_epoch = hvd.broadcast(torch.tensor(args.resume_from_epoch),
126 | root_rank=0,
127 | name='resume_from_epoch').item()
128 |
129 | # Horovod: write TensorBoard logs on first worker.
130 | try:
131 | if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'):
132 | from torch.utils.tensorboard import SummaryWriter
133 | else:
134 | from tensorboardX import SummaryWriter
135 | args.log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None
136 | except ImportError:
137 | args.log_writer = None
138 |
139 | return args
140 |
141 |
142 | def get_datasets(args):
143 | # Horovod: limit # of CPU threads to be used per worker.
144 | if args.single_threaded:
145 | torch.set_num_threads(4)
146 | kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
147 | else:
148 | torch.set_num_threads(4)
149 | num_workers_input = hvd.size()
150 | # num_workers_input = 4
151 | kwargs = {'num_workers': num_workers_input, 'pin_memory': True} if args.cuda else {}
152 | if args.verbose:
153 | print('actual num_workers ', num_workers_input)
154 |
155 | train_dataset = datasets.ImageFolder(
156 | args.train_dir,
157 | transform=transforms.Compose([
158 | transforms.RandomResizedCrop(224),
159 | transforms.RandomHorizontalFlip(),
160 | transforms.ToTensor(),
161 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
162 | std=[0.229, 0.224, 0.225])]))
163 | val_dataset = datasets.ImageFolder(
164 | args.val_dir,
165 | transform=transforms.Compose([
166 | transforms.Resize(256),
167 | transforms.CenterCrop(224),
168 | transforms.ToTensor(),
169 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
170 | std=[0.229, 0.224, 0.225])]))
171 |
172 | # Horovod: use DistributedSampler to partition data among workers. Manually specify
173 | # `num_replicas=hvd.size()` and `rank=hvd.rank()`.
174 | train_sampler = torch.utils.data.distributed.DistributedSampler(
175 | train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
176 | train_loader = torch.utils.data.DataLoader(
177 | train_dataset, batch_size=args.batch_size * args.batches_per_allreduce,
178 | sampler=train_sampler, **kwargs)
179 | val_sampler = torch.utils.data.distributed.DistributedSampler(
180 | val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
181 | val_loader = torch.utils.data.DataLoader(
182 | val_dataset, batch_size=args.val_batch_size,
183 | sampler=val_sampler, **kwargs)
184 |
185 | if args.verbose:
186 | print('actual batch_size ', args.batch_size * args.batches_per_allreduce * hvd.size())
187 |
188 | return train_sampler, train_loader, val_sampler, val_loader
189 |
190 |
191 | def get_model(args, num_steps_per_epoch):
192 | if args.model.lower() == 'resnet50':
193 | # model = models_local.resnet50()
194 | model = models.resnet50()
195 | else:
196 | raise ValueError('Unknown model \'{}\''.format(args.model))
197 |
198 | if args.cuda:
199 | model.cuda()
200 |
201 | # Horovod: scale learning rate by the number of GPUs.
202 | if args.lr_scaling.lower() == "linear":
203 | args.base_lr = args.base_lr * hvd.size() * args.batches_per_allreduce
204 | if args.lr_scaling.lower() == "sqrt":
205 | args.base_lr = math.sqrt(args.base_lr * hvd.size() * args.batches_per_allreduce)
206 | if args.lr_scaling.lower() == "keep":
207 | args.base_lr = args.base_lr
208 | if args.verbose:
209 | print('actual base_lr ', args.base_lr)
210 |
211 | if args.base_op.lower() == "lars":
212 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon,
213 | momentum=args.momentum, weight_decay=args.wd,
214 | bn_bias_separately=args.bn_bias_separately)
215 | elif args.base_op.lower() == "lamb":
216 | optimizer = create_lamb_optimizer(model=model, lr=args.base_lr,
217 | weight_decay=args.wd)
218 | else:
219 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr,
220 | momentum=args.momentum, weight_decay=args.wd)
221 |
222 | compression = hvd.Compression.fp16 if args.fp16_allreduce \
223 | else hvd.Compression.none
224 | optimizer = hvd.DistributedOptimizer(
225 | optimizer, named_parameters=model.named_parameters(),
226 | compression=compression, op=hvd.Average,
227 | backward_passes_per_step=args.batches_per_allreduce)
228 |
229 | # Restore from a previous checkpoint, if initial_epoch is specified.
230 | # Horovod: restore on the first worker which will broadcast weights
231 | # to other workers.
232 | if args.resume_from_epoch > 0 and hvd.rank() == 0:
233 | filepath = args.checkpoint_format.format(epoch=args.resume_from_epoch)
234 | checkpoint = torch.load(filepath)
235 | model.load_state_dict(checkpoint['model'])
236 | optimizer.load_state_dict(checkpoint['optimizer'])
237 |
238 | # Horovod: broadcast parameters & optimizer state.
239 | hvd.broadcast_parameters(model.state_dict(), root_rank=0)
240 | hvd.broadcast_optimizer_state(optimizer, root_rank=0)
241 |
242 | # lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay)
243 | # lr_scheduler = [LambdaLR(optimizer, lrs)]
244 | if args.base_op.lower() == "lars":
245 | lr_power = 2.0
246 | else:
247 | lr_power = 1.0
248 |
249 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch,
250 | warmup_steps=args.warmup_epochs * num_steps_per_epoch,
251 | end_lr=0.0, power=lr_power, last_epoch=-1)
252 |
253 | loss_func = LabelSmoothLoss(args.label_smoothing)
254 |
255 | return model, optimizer, lr_scheduler, loss_func
256 |
257 |
258 | def train(epoch, model, optimizer, lr_schedules,
259 | loss_func, train_sampler, train_loader, args):
260 | model.train()
261 | train_sampler.set_epoch(epoch)
262 | train_loss = Metric('train_loss')
263 | train_accuracy = Metric('train_accuracy')
264 |
265 | with tqdm(total=len(train_loader),
266 | desc='Epoch {:3d}/{:3d}'.format(epoch + 1, args.epochs),
267 | disable=not args.verbose) as t:
268 | for batch_idx, (data, target) in enumerate(train_loader):
269 | if args.cuda:
270 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
271 | optimizer.zero_grad()
272 |
273 | for i in range(0, len(data), args.batch_size):
274 | data_batch = data[i:i + args.batch_size]
275 | target_batch = target[i:i + args.batch_size]
276 | output = model(data_batch)
277 |
278 | loss = loss_func(output, target_batch)
279 |
280 | with torch.no_grad():
281 | train_loss.update(loss)
282 | train_accuracy.update(accuracy(output, target_batch))
283 |
284 | loss.div_(math.ceil(float(len(data)) / args.batch_size))
285 | loss.backward()
286 |
287 | optimizer.synchronize()
288 |
289 | with optimizer.skip_synchronize():
290 | optimizer.step()
291 |
292 | t.set_postfix_str("loss: {:.4f}, acc: {:.2f}%".format(
293 | train_loss.avg.item(), 100 * train_accuracy.avg.item()))
294 | t.update(1)
295 |
296 | lr_schedules.step()
297 |
298 | if args.verbose:
299 | print('')
300 | print('epoch ', epoch + 1, '/', args.epochs)
301 | print('train/loss ', train_loss.avg)
302 | print('train/accuracy ', train_accuracy.avg, '%')
303 | print('train/lr ', optimizer.param_groups[0]['lr'])
304 |
305 | if args.log_writer is not None:
306 | args.log_writer.add_scalar('train/loss', train_loss.avg, epoch)
307 | args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
308 | args.log_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch)
309 |
310 |
311 | def validate(epoch, model, loss_func, val_loader, args):
312 | model.eval()
313 | val_loss = Metric('val_loss')
314 | val_accuracy = Metric('val_accuracy')
315 |
316 | with tqdm(total=len(val_loader),
317 | # bar_format='{l_bar}{bar}|{postfix}',
318 | desc=' '.format(epoch + 1, args.epochs),
319 | disable=not args.verbose) as t:
320 | with torch.no_grad():
321 | for i, (data, target) in enumerate(val_loader):
322 | if args.cuda:
323 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
324 | output = model(data)
325 | val_loss.update(loss_func(output, target))
326 | val_accuracy.update(accuracy(output, target))
327 |
328 | t.update(1)
329 | if i + 1 == len(val_loader):
330 | t.set_postfix_str("\b\b val_loss: {:.4f}, val_acc: {:.2f}%".format(
331 | val_loss.avg.item(), 100 * val_accuracy.avg.item()),
332 | refresh=False)
333 | if args.verbose:
334 | print('')
335 | print('val/loss ', val_loss.avg)
336 | print('val/accuracy ', val_accuracy.avg, '%')
337 |
338 | if args.log_writer is not None:
339 | args.log_writer.add_scalar('val/loss', val_loss.avg, epoch)
340 | args.log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)
341 |
342 |
343 | if __name__ == '__main__':
344 | torch.multiprocessing.set_start_method('spawn')
345 |
346 | args = initialize()
347 |
348 | train_sampler, train_loader, _, val_loader = get_datasets(args)
349 |
350 | num_steps_per_epoch = len(train_loader)
351 |
352 | model, opt, lr_schedules, loss_func = get_model(args, num_steps_per_epoch)
353 |
354 | if args.verbose:
355 | print("MODEL:", args.model)
356 |
357 | start = time.time()
358 |
359 | for epoch in range(args.resume_from_epoch, args.epochs):
360 | train(epoch, model, opt, lr_schedules,
361 | loss_func, train_sampler, train_loader, args)
362 | validate(epoch, model, loss_func, val_loader, args)
363 | save_checkpoint(model, opt, args.checkpoint_format, epoch)
364 |
365 | if args.verbose:
366 | print("\nTraining time:", str(timedelta(seconds=time.time() - start)))
367 |
--------------------------------------------------------------------------------
/pytorch_imagenet_resnet_dali.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import time
4 | from datetime import datetime, timedelta
5 | import argparse
6 | import os
7 | import math
8 | import warnings
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.utils.data.distributed
14 | from torchvision import datasets, transforms, models
15 | import horovod.torch as hvd
16 | from tqdm import tqdm
17 | from distutils.version import LooseVersion
18 |
19 | from nvidia.dali.pipeline import Pipeline
20 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
21 | import nvidia.dali.fn as fn
22 | import nvidia.dali.types as types
23 | import nvidia.dali.tfrecord as tfrec
24 | import glob
25 |
26 | from utils import *
27 | from lars import *
28 | from lamb import *
29 |
30 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
31 |
32 |
33 | def initialize():
34 | # Training settings
35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
36 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
37 | parser.add_argument('--train-dir', default='/tmp/imagenet/ILSVRC2012_img_train/',
38 | help='path to training data')
39 | parser.add_argument('--val-dir', default='/tmp/imagenet/ILSVRC2012_img_val/',
40 | help='path to validation data')
41 | parser.add_argument('--data-dir', default='/tmp/imagenet/',
42 | help='path to ImageNet data in TFRecord format')
43 | parser.add_argument('--log-dir', default='./logs/imagenet',
44 | help='tensorboard/checkpoint log directory')
45 | parser.add_argument('--checkpoint-format', default='checkpoint-{epoch}.pth.tar',
46 | help='checkpoint file format')
47 | parser.add_argument('--fp16-allreduce', action='store_true', default=False,
48 | help='use fp16 compression during allreduce')
49 | parser.add_argument('--batches-per-allreduce', type=int, default=1,
50 | help='number of batches processed locally before '
51 | 'executing allreduce across workers; it multiplies '
52 | 'total batch size.')
53 |
54 | # Default settings from https://arxiv.org/abs/1706.02677.
55 | parser.add_argument('--model', default='resnet50',
56 | help='Model (resnet35, resnet50, resnet101, resnet152, resnext50, resnext101)')
57 | parser.add_argument('--batch-size', type=int, default=32,
58 | help='input batch size for training')
59 | parser.add_argument('--val-batch-size', type=int, default=32,
60 | help='input batch size for validation')
61 | parser.add_argument('--epochs', type=int, default=90,
62 | help='number of epochs to train')
63 | parser.add_argument('--base-lr', type=float, default=0.0125,
64 | help='learning rate for a single GPU')
65 | parser.add_argument('--lr-decay', nargs='+', type=int, default=[30, 60, 80],
66 | help='epoch intervals to decay lr')
67 | parser.add_argument('--warmup-epochs', type=float, default=5,
68 | help='number of warmup epochs')
69 | parser.add_argument('--momentum', type=float, default=0.9,
70 | help='SGD momentum')
71 | parser.add_argument('--wd', type=float, default=0.00005,
72 | help='weight decay')
73 | parser.add_argument('--epsilon', type=float, default=1e-5,
74 | help='epsilon for optimizer')
75 | parser.add_argument('--label-smoothing', type=float, default=0.1,
76 | help='label smoothing (default 0.1)')
77 | parser.add_argument('--base-op', type=str, default='sgd',
78 | help='base optimizer name')
79 | parser.add_argument('--bn-bias-separately', action='store_true', default=False,
80 | help='skip bn and bias')
81 | parser.add_argument('--lr-scaling', type=str, default='keep',
82 | help='lr scaling method')
83 |
84 | parser.add_argument('--no-cuda', action='store_true', default=False,
85 | help='disables CUDA training')
86 | parser.add_argument('--single-threaded', action='store_true', default=False,
87 | help='disables multi-threaded dataloading')
88 | parser.add_argument('--seed', type=int, default=42,
89 | help='random seed')
90 |
91 | args = parser.parse_args()
92 | args.cuda = not args.no_cuda and torch.cuda.is_available()
93 |
94 | hvd.init()
95 | torch.manual_seed(args.seed)
96 |
97 | args.verbose = 1 if hvd.rank() == 0 else 0
98 |
99 | print('hvd.rank() ', hvd.rank())
100 |
101 | if args.verbose:
102 | print(args)
103 |
104 | if args.cuda:
105 | torch.cuda.set_device(hvd.local_rank())
106 | torch.cuda.manual_seed(args.seed)
107 | cudnn.deterministic = True
108 | cudnn.benchmark = False
109 |
110 | if args.bn_bias_separately:
111 | skip = "bn_bias"
112 | elif getattr(args, 'bn_separately', False):  # this flag is not defined by the parser above
113 | skip = "bn"
114 | else:
115 | skip = "no"
116 |
117 | args.log_dir = os.path.join(args.log_dir,
118 | "imagenet_{}_gpu_{}_{}_ebs{}_blr_{}_skip_{}_{}".format(
119 | args.model, hvd.size(), args.base_op,
120 | args.batch_size * hvd.size() * args.batches_per_allreduce, args.base_lr, skip,
121 | datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
122 | args.checkpoint_format = os.path.join(args.log_dir, args.checkpoint_format)
123 | os.makedirs(args.log_dir, exist_ok=True)
124 |
125 | # If set > 0, will resume training from a given checkpoint.
126 | args.resume_from_epoch = 0
127 | for try_epoch in range(args.epochs, 0, -1):
128 | if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
129 | args.resume_from_epoch = try_epoch
130 | break
131 |
132 | # Horovod: broadcast resume_from_epoch from rank 0 (which will have
133 | # checkpoints) to other ranks.
134 | args.resume_from_epoch = hvd.broadcast(torch.tensor(args.resume_from_epoch),
135 | root_rank=0,
136 | name='resume_from_epoch').item()
137 |
138 | # Horovod: write TensorBoard logs on first worker.
139 | try:
140 | if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'):
141 | from torch.utils.tensorboard import SummaryWriter
142 | else:
143 | from tensorboardX import SummaryWriter
144 | args.log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None
145 | except ImportError:
146 | args.log_writer = None
147 |
148 | return args
149 |
150 |
151 | def dali_dataloader(
152 | tfrec_filenames,
153 | tfrec_idx_filenames,
154 | shard_id=0, num_shards=1,
155 | batch_size=64, num_threads=2,
156 | image_size=224, num_workers=1, training=True):
157 | pipe = Pipeline(batch_size=batch_size,
158 | num_threads=num_threads, device_id=hvd.local_rank())
159 | with pipe:
160 | inputs = fn.readers.tfrecord(
161 | path=tfrec_filenames,
162 | index_path=tfrec_idx_filenames,
163 | random_shuffle=training,
164 | shard_id=shard_id,
165 | num_shards=num_shards,
166 | initial_fill=10000,
167 | read_ahead=True,
168 | pad_last_batch=True,
169 | prefetch_queue_depth=num_workers,
170 | name='Reader',
171 | features={
172 | 'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
173 | 'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
174 | })
175 | jpegs = inputs["image/encoded"]
176 | if training:
177 | images = fn.decoders.image_random_crop(
178 | jpegs,
179 | device="mixed",
180 | output_type=types.RGB,
181 | random_aspect_ratio=[0.8, 1.25],
182 | random_area=[0.1, 1.0],
183 | num_attempts=100)
184 | images = fn.resize(images,
185 | device='gpu',
186 | resize_x=image_size,
187 | resize_y=image_size,
188 | interp_type=types.INTERP_TRIANGULAR)
189 | mirror = fn.random.coin_flip(probability=0.5)
190 | else:
191 | images = fn.decoders.image(jpegs,
192 | device='mixed',
193 | output_type=types.RGB)
194 | images = fn.resize(images,
195 | device='gpu',
196 | size=int(image_size / 0.875),
197 | mode="not_smaller",
198 | interp_type=types.INTERP_TRIANGULAR)
199 | mirror = False
200 |
201 | images = fn.crop_mirror_normalize(
202 | images.gpu(),
203 | dtype=types.FLOAT,
204 | crop=(image_size, image_size),
205 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
206 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
207 | mirror=mirror)
208 | label = inputs["image/class/label"] - 1 # 0-999
209 | label = fn.element_extract(label, element_map=0) # Flatten
210 | label = label.gpu()
211 | pipe.set_outputs(images, label)
212 |
213 | pipe.build()
214 | last_batch_policy = LastBatchPolicy.DROP if training else LastBatchPolicy.PARTIAL
215 | loader = DALIClassificationIterator(
216 | pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)
217 | return loader
218 |
219 |
220 | def get_datasets(args):
221 | num_shards = hvd.size()
222 | shard_id = hvd.rank()
223 | num_workers = 1
224 | num_threads = 2
225 | root = args.data_dir
226 |
227 | train_pat = os.path.join(root, 'train/*')
228 | train_idx_pat = os.path.join(root, 'idx_files/train/*')
229 | train_loader = dali_dataloader(sorted(glob.glob(train_pat)),
230 | sorted(glob.glob(train_idx_pat)),
231 | shard_id=shard_id,
232 | num_shards=num_shards,
233 | batch_size=args.batch_size * args.batches_per_allreduce,
234 | num_workers=num_workers,
235 | num_threads=num_threads,
236 | training=True)
237 | test_pat = os.path.join(root, 'validation/*')
238 | test_idx_pat = os.path.join(root, 'idx_files/validation/*')
239 | val_loader = dali_dataloader(sorted(glob.glob(test_pat)),
240 | sorted(glob.glob(test_idx_pat)),
241 | shard_id=shard_id,
242 | num_shards=num_shards,
243 | batch_size=args.val_batch_size,
244 | num_workers=num_workers,
245 | num_threads=num_threads,
246 | training=False)
247 | if args.verbose:
248 | print('actual batch_size ', args.batch_size * args.batches_per_allreduce * hvd.size())
249 |
250 | return train_loader, val_loader
251 |
252 |
253 | def get_model(args, num_steps_per_epoch):
254 | if args.model.lower() == 'resnet50':
255 | # model = models_local.resnet50()
256 | model = models.resnet50()
257 | else:
258 | raise ValueError('Unknown model \'{}\''.format(args.model))
259 |
260 | if args.cuda:
261 | model.cuda()
262 |
263 | # Horovod: scale learning rate by the number of GPUs.
264 | if args.lr_scaling.lower() == "linear":
265 | args.base_lr = args.base_lr * hvd.size() * args.batches_per_allreduce
266 | if args.lr_scaling.lower() == "sqrt":
267 | args.base_lr = math.sqrt(args.base_lr * hvd.size() * args.batches_per_allreduce)
268 | if args.lr_scaling.lower() == "keep":
269 | args.base_lr = args.base_lr
270 | if args.verbose:
271 | print('actual base_lr ', args.base_lr)
272 |
273 | if args.base_op.lower() == "lars":
274 | optimizer = create_optimizer_lars(model=model, lr=args.base_lr, epsilon=args.epsilon,
275 | momentum=args.momentum, weight_decay=args.wd,
276 | bn_bias_separately=args.bn_bias_separately)
277 | elif args.base_op.lower() == "lamb":
278 | optimizer = create_lamb_optimizer(model=model, lr=args.base_lr,
279 | weight_decay=args.wd)
280 | else:
281 | optimizer = optim.SGD(model.parameters(), lr=args.base_lr,
282 | momentum=args.momentum, weight_decay=args.wd)
283 |
284 | compression = hvd.Compression.fp16 if args.fp16_allreduce \
285 | else hvd.Compression.none
286 | optimizer = hvd.DistributedOptimizer(
287 | optimizer, named_parameters=model.named_parameters(),
288 | compression=compression, op=hvd.Average,
289 | backward_passes_per_step=args.batches_per_allreduce)
290 |
291 | # Restore from a previous checkpoint, if initial_epoch is specified.
292 | # Horovod: restore on the first worker which will broadcast weights
293 | # to other workers.
294 | if args.resume_from_epoch > 0 and hvd.rank() == 0:
295 | filepath = args.checkpoint_format.format(epoch=args.resume_from_epoch)
296 | checkpoint = torch.load(filepath)
297 | model.load_state_dict(checkpoint['model'])
298 | optimizer.load_state_dict(checkpoint['optimizer'])
299 |
300 | # Horovod: broadcast parameters & optimizer state.
301 | hvd.broadcast_parameters(model.state_dict(), root_rank=0)
302 | hvd.broadcast_optimizer_state(optimizer, root_rank=0)
303 |
304 | # lrs = create_lr_schedule(hvd.size(), args.warmup_epochs, args.lr_decay)
305 | # lr_scheduler = [LambdaLR(optimizer, lrs)]
306 | if args.base_op.lower() == "lars":
307 | lr_power = 2.0
308 | else:
309 | lr_power = 1.0
310 |
311 | lr_scheduler = PolynomialWarmup(optimizer, decay_steps=args.epochs * num_steps_per_epoch,
312 | warmup_steps=args.warmup_epochs * num_steps_per_epoch,
313 | end_lr=0.0, power=lr_power, last_epoch=-1)
314 |
315 | loss_func = LabelSmoothLoss(args.label_smoothing)
316 |
317 | return model, optimizer, lr_scheduler, loss_func
318 |
319 |
320 | def train(epoch, model, optimizer, lr_schedules,
321 | loss_func, train_loader, args):
322 | model.train()
323 | # train_sampler.set_epoch(epoch)
324 | train_loss = Metric('train_loss')
325 | train_accuracy = Metric('train_accuracy')
326 |
327 | with tqdm(total=len(train_loader),
328 | desc='Epoch {:3d}/{:3d}'.format(epoch + 1, args.epochs),
329 | disable=not args.verbose) as t:
330 | for batch_idx, data in enumerate(train_loader):
331 | input, target = data[0]['data'], data[0]['label']
332 | # if args.cuda:
333 | # data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
334 | optimizer.zero_grad()
335 |
336 | for i in range(0, len(input), args.batch_size):
337 | data_batch = input[i:i + args.batch_size]
338 | target_batch = target[i:i + args.batch_size]
339 | output = model(data_batch)
340 |
341 | loss = loss_func(output, target_batch)
342 | loss = loss / args.batches_per_allreduce
343 |
344 | with torch.no_grad():
345 | train_loss.update(loss)
346 | train_accuracy.update(accuracy(output, target_batch))
347 |
348 | loss.backward()
349 |
350 | optimizer.synchronize()
351 |
352 | with optimizer.skip_synchronize():
353 | optimizer.step()
354 |
355 | t.set_postfix_str("loss: {:.4f}, acc: {:.2f}%, lr: {:.4f}".format(
356 | train_loss.avg.item(), 100 * train_accuracy.avg.item(), optimizer.param_groups[0]['lr']))
357 | t.update(1)
358 |
359 | lr_schedules.step()
360 |
361 | if args.verbose:
362 | print('')
363 | print('epoch ', epoch + 1, '/', args.epochs)
364 | print('train/loss ', train_loss.avg)
365 | print('train/accuracy ', train_accuracy.avg, '%')
366 | print('train/lr ', optimizer.param_groups[0]['lr'])
367 |
368 | if args.log_writer is not None:
369 | args.log_writer.add_scalar('train/loss', train_loss.avg, epoch)
370 | args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)
371 | args.log_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch)
372 |
373 |
374 | def validate(epoch, model, loss_func, val_loader, args):
375 | model.eval()
376 | val_loss = Metric('val_loss')
377 | val_accuracy = Metric('val_accuracy')
378 |
379 | with tqdm(total=len(val_loader),
380 | # bar_format='{l_bar}{bar}|{postfix}',
381 | desc=' '.format(epoch + 1, args.epochs),
382 | disable=not args.verbose) as t:
383 | with torch.no_grad():
384 | for i, data in enumerate(val_loader):
385 | input, target = data[0]['data'], data[0]['label']
386 | # if args.cuda:
387 | # data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
388 | output = model(input)
389 | val_loss.update(loss_func(output, target))
390 | val_accuracy.update(accuracy(output, target))
391 |
392 | t.update(1)
393 | if i + 1 == len(val_loader):
394 | t.set_postfix_str("\b\b val_loss: {:.4f}, val_acc: {:.2f}%".format(
395 | val_loss.avg.item(), 100 * val_accuracy.avg.item()),
396 | refresh=False)
397 | if args.verbose:
398 | print('')
399 | print('val/loss ', val_loss.avg)
400 | print('val/accuracy ', val_accuracy.avg, '%')
401 |
402 | if args.log_writer is not None:
403 | args.log_writer.add_scalar('val/loss', val_loss.avg, epoch)
404 | args.log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)
405 |
406 |
407 | if __name__ == '__main__':
408 |
409 | args = initialize()
410 |
411 | if args.single_threaded:
412 | print('Not using torch.multiprocessing.set_start_method')
413 | else:
414 | # torch.multiprocessing.set_start_method('spawn')
415 | # torch.multiprocessing.set_start_method('forkserver')
416 | torch.multiprocessing.set_start_method('spawn')
417 |
418 | train_loader, val_loader = get_datasets(args=args)
419 |
420 | num_steps_per_epoch = len(train_loader)
421 |
422 | model, opt, lr_schedules, loss_func = get_model(args, num_steps_per_epoch)
423 |
424 | if args.verbose:
425 | print("MODEL:", args.model)
426 |
427 | start = time.time()
428 |
429 | for epoch in range(args.resume_from_epoch, args.epochs):
430 | train(epoch=epoch, model=model, optimizer=opt, lr_schedules=lr_schedules,
431 | loss_func=loss_func, train_loader=train_loader, args=args)
432 | validate(epoch, model, loss_func, val_loader, args)
433 | save_checkpoint(model, opt, args.checkpoint_format, epoch)
434 |
435 | if args.verbose:
436 | print("\nTraining time:", str(timedelta(seconds=time.time() - start)))
437 |
--------------------------------------------------------------------------------