├── __init__.py ├── UniformAugment ├── networks │ ├── shakeshake │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── shake_resnet.cpython-36.pyc │ │ │ └── shakeshake.cpython-36.pyc │ │ ├── shakeshake.py │ │ ├── shake_resnet.py │ │ └── shake_resnext.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── pyramidnet.cpython-36.pyc │ │ ├── shakedrop.cpython-36.pyc │ │ └── wideresnet.cpython-36.pyc │ ├── efficientnet_pytorch │ │ ├── __init__.py │ │ ├── condconv.py │ │ ├── model.py │ │ └── utils.py │ ├── shakedrop.py │ ├── wideresnet.py │ ├── __init__.py │ ├── resnet.py │ └── pyramidnet.py ├── __init__.py ├── __pycache__ │ ├── data.cpython-36.pyc │ ├── common.cpython-36.pyc │ ├── metrics.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── lr_scheduler.cpython-36.pyc │ └── augmentations.cpython-36.pyc ├── lr_scheduler.py ├── aug_mixup.py ├── common.py ├── metrics.py ├── augmentations.py ├── imagenet.py ├── data.py └── train.py ├── confs ├── wresnet28x10_svhn.yaml ├── wresnet28x10_cifar.yaml ├── wresnet40x2_cifar.yaml ├── resnet200.yaml ├── resnet50.yaml ├── shake26_2x112d_cifar.yaml ├── shake26_2x32d_cifar.yaml ├── shake26_2x96d_cifar.yaml ├── resnet50_mixup.yaml ├── pyramid272_cifar.yaml ├── efficientnet_b0.yaml ├── efficientnet_b1.yaml ├── efficientnet_b2.yaml ├── efficientnet_b3.yaml ├── efficientnet_b4.yaml └── efficientnet_b0_condconv.yaml ├── requirements.txt ├── setup.py ├── LICENSE └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /UniformAugment/__init__.py: -------------------------------------------------------------------------------- 1 | from UniformAugment.augmentations import UniformAugment 2 | -------------------------------------------------------------------------------- /UniformAugment/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/__pycache__/augmentations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/__pycache__/augmentations.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/__pycache__/pyramidnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/__pycache__/pyramidnet.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/__pycache__/shakedrop.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/__pycache__/shakedrop.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/__pycache__/wideresnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/__pycache__/wideresnet.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/shakeshake/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/__pycache__/shake_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/shakeshake/__pycache__/shake_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/__pycache__/shakeshake.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tgilewicz/uniformaugment/HEAD/UniformAugment/networks/shakeshake/__pycache__/shakeshake.cpython-36.pyc -------------------------------------------------------------------------------- /UniformAugment/networks/efficientnet_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.1" 2 | from .model import EfficientNet, RoutingFn 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) -------------------------------------------------------------------------------- /confs/wresnet28x10_svhn.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: svhn 4 | aug: uniformaugment 5 | cutout: 20 6 | batch: 128 7 | epoch: 200 8 | lr: 0.01 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0005 18 | ema: 0 -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar10 4 | aug: uniformaugment 5 | cutout: 16 6 | batch: 128 7 | epoch: 200 8 | lr: 0.1 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0005 18 | ema: 0 -------------------------------------------------------------------------------- /confs/wresnet40x2_cifar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet40_2 3 | dataset: cifar10 4 | aug: uniformaugment 5 | cutout: 16 6 | batch: 128 7 | epoch: 200 8 | lr: 0.1 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0002 18 | ema: 0 -------------------------------------------------------------------------------- /confs/resnet200.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: resnet200 3 | dataset: imagenet 4 | aug: uniformaugment 5 | cutout: 0 6 | batch: 64 7 | epoch: 270 8 | lr: 0.025 9 | lr_schedule: 10 | type: 'resnet' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0001 18 | clip: 0 19 | ema: 0 20 | -------------------------------------------------------------------------------- /confs/resnet50.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: resnet50 3 | dataset: imagenet 4 | aug: uniformaugment 5 | cutout: 0 6 | batch: 128 7 | epoch: 270 8 | lr: 0.05 9 | lr_schedule: 10 | type: 'resnet' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0001 18 | clip: 0 19 | ema: 0 20 | -------------------------------------------------------------------------------- /confs/shake26_2x112d_cifar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: shakeshake26_2x112d 3 | dataset: cifar10 4 | aug: uniformaugment 5 | cutout: 16 6 | batch: 128 7 | epoch: 1800 8 | lr: 0.01 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.002 18 | ema: 0 19 | -------------------------------------------------------------------------------- /confs/shake26_2x32d_cifar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: shakeshake26_2x32d 3 | dataset: cifar10 4 | aug: uniformaugment 5 | cutout: 16 6 | batch: 128 7 | epoch: 1800 8 | lr: 0.01 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.001 18 | ema: 0 19 | -------------------------------------------------------------------------------- /confs/shake26_2x96d_cifar.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | type: shakeshake26_2x96d 3 | dataset: cifar10 4 | aug: uniformaugment 5 | cutout: 16 6 | batch: 128 7 | epoch: 1800 8 | lr: 0.01 9 | lr_schedule: 10 | type: 'cosine' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.001 18 | ema: 0 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/wbaek/theconf@de32022f8c0651a043dc812d17194cdfd62066e8 2 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@08f7d5e 3 | git+https://github.com/ildoonet/pystopwatch2.git 4 | git+https://github.com/hyperopt/hyperopt.git 5 | 6 | pretrainedmodels 7 | tqdm 8 | tensorboardx 9 | sklearn 10 | ray 11 | matplotlib 12 | psutil 13 | requests -------------------------------------------------------------------------------- /confs/resnet50_mixup.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: resnet50 3 | dataset: imagenet 4 | aug: uniformaugment 5 | cutout: 0 6 | batch: 128 7 | epoch: 270 8 | lr: 0.05 9 | lr_schedule: 10 | type: 'resnet' 11 | warmup: 12 | multiplier: 1 13 | epoch: 5 14 | optimizer: 15 | type: sgd 16 | nesterov: True 17 | decay: 0.0001 18 | clip: 0 19 | ema: 0 20 | #lb_smooth: 0.1 21 | mixup: 0.2 22 | -------------------------------------------------------------------------------- /confs/pyramid272_cifar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: pyramid 3 | depth: 272 4 | alpha: 200 5 | bottleneck: True 6 | dataset: cifar10 7 | aug: uniformaugment 8 | cutout: 16 9 | batch: 64 10 | epoch: 1800 11 | lr: 0.05 12 | lr_schedule: 13 | type: 'cosine' 14 | warmup: 15 | multiplier: 1 16 | epoch: 5 17 | optimizer: 18 | type: sgd 19 | nesterov: True 20 | decay: 0.00005 21 | ema: 0 22 | -------------------------------------------------------------------------------- /confs/efficientnet_b0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b0 3 | condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. 4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 128 # per gpu 8 | epoch: 350 9 | lr: 0.008 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 -------------------------------------------------------------------------------- /confs/efficientnet_b1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b1 3 | condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. 
4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 128 # per gpu 8 | epoch: 350 9 | lr: 0.008 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 22 | -------------------------------------------------------------------------------- /confs/efficientnet_b2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b2 3 | condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. 4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 128 # per gpu 8 | epoch: 350 9 | lr: 0.008 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 22 | -------------------------------------------------------------------------------- /confs/efficientnet_b3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b3 3 | condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. 4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 64 # per gpu 8 | epoch: 350 9 | lr: 0.004 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 22 | -------------------------------------------------------------------------------- /confs/efficientnet_b4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b4 3 | condconv_num_expert: 1 # if this is greater than 1(eg. 4), it activates condconv. 4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 32 # per gpu 8 | epoch: 350 9 | lr: 0.002 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 22 | -------------------------------------------------------------------------------- /confs/efficientnet_b0_condconv.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: efficientnet-b0 3 | condconv_num_expert: 8 # if this is greater than 1(eg. 4), it activates condconv. 
4 | dataset: imagenet 5 | aug: uniformaugment 6 | cutout: 0 7 | batch: 128 # per gpu 8 | epoch: 350 9 | lr: 0.008 # 0.256 for 4096 batch 10 | lr_schedule: 11 | type: 'efficientnet' 12 | warmup: 13 | multiplier: 1 14 | epoch: 5 15 | optimizer: 16 | type: rmsprop 17 | decay: 0.00001 18 | clip: 0 19 | ema: 0.9999 20 | ema_interval: -1 21 | lb_smooth: 0.1 22 | mixup: 0.2 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import setuptools 6 | 7 | _VERSION = '0.1' 8 | REQUIRED_PACKAGES = [] 9 | DEPENDENCY_LINKS = [] 10 | 11 | setuptools.setup( 12 | name='UniformAugment', 13 | version=_VERSION, 14 | description='Unofficial PyTorch Reimplementation of UniformAugment', 15 | install_requires=REQUIRED_PACKAGES, 16 | dependency_links=DEPENDENCY_LINKS, 17 | url='https://github.com/tgilewicz/uniformaugment', 18 | license='MIT License', 19 | package_dir={}, 20 | packages=setuptools.find_packages(exclude=['tests']), 21 | ) 22 | -------------------------------------------------------------------------------- /UniformAugment/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import MultiStepLR 3 | from theconf import Config as C 4 | 5 | 6 | def adjust_learning_rate_resnet(optimizer): 7 | """ 8 | Sets the learning rate to the initial LR decayed by a factor of 10 at predefined epochs. 9 | Ref: AutoAugment 10 | """ 11 | 12 | if C.get()['epoch'] == 90: 13 | return MultiStepLR_HotFix(optimizer, [30, 60, 80]) 14 | elif C.get()['epoch'] == 270: # autoaugment 15 | return MultiStepLR_HotFix(optimizer, [90, 180, 240]) 16 | else: 17 | raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch']) 18 | 19 | 20 | class MultiStepLR_HotFix(MultiStepLR): 21 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 22 | super(MultiStepLR_HotFix, self).__init__(optimizer, milestones, gamma, last_epoch) 23 | self.milestones = list(milestones) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ildoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /UniformAugment/aug_mixup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference : 3 | - https://github.com/hysts/pytorch_image_classification/blob/master/augmentations/mixup.py 4 | - https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/imagenet_input.py#L120 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from UniformAugment.metrics import CrossEntropyLabelSmooth 11 | 12 | 13 | def mixup(data, targets, alpha): 14 | indices = torch.randperm(data.size(0)) 15 | shuffled_data = data[indices] 16 | shuffled_targets = targets[indices] 17 | 18 | lam = np.random.beta(alpha, alpha) 19 | lam = max(lam, 1. - lam) 20 | assert 0.0 <= lam <= 1.0, lam 21 | data = data * lam + shuffled_data * (1 - lam) 22 | 23 | return data, targets, shuffled_targets, lam 24 | 25 | 26 | class CrossEntropyMixUpLabelSmooth(torch.nn.Module): 27 | def __init__(self, num_classes, epsilon, reduction='mean'): 28 | super(CrossEntropyMixUpLabelSmooth, self).__init__() 29 | self.ce = CrossEntropyLabelSmooth(num_classes, epsilon, reduction=reduction) 30 | 31 | def forward(self, input, target1, target2, lam): # pylint: disable=redefined-builtin 32 | return lam * self.ce(input, target1) + (1 - lam) * self.ce(input, target2) 33 | -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/shakeshake.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeShake(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x1, x2, training=True): 13 | if training: 14 | alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_() 15 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1) 16 | else: 17 | alpha = 0.5 18 | return alpha * x1 + (1 - alpha) * x2 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_() 23 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 24 | beta = Variable(beta) 25 | 26 | return beta * grad_output, (1 - beta) * grad_output, None 27 | 28 | 29 | class Shortcut(nn.Module): 30 | 31 | def __init__(self, in_ch, out_ch, stride): 32 | super(Shortcut, self).__init__() 33 | self.stride = stride 34 | self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 35 | self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 36 | self.bn = nn.BatchNorm2d(out_ch) 37 | 38 | def forward(self, x): 39 | h = F.relu(x) 40 | 41 | h1 = F.avg_pool2d(h, 1, self.stride) 42 | h1 = self.conv1(h1) 43 | 44 | h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride) 45 | h2 = self.conv2(h2) 46 | 47 | h = torch.cat((h1, h2), 1) 48 | return self.bn(h) 49 | -------------------------------------------------------------------------------- /UniformAugment/networks/shakedrop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeDropFunction(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 13 | if training: 14 | gate = 
torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop) 15 | ctx.save_for_backward(gate) 16 | if gate.item() == 0: 17 | alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range) 18 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 19 | return alpha * x 20 | else: 21 | return x 22 | else: 23 | return (1 - p_drop) * x 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | gate = ctx.saved_tensors[0] 28 | if gate.item() == 0: 29 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1) 30 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 31 | beta = Variable(beta) 32 | return beta * grad_output, None, None, None 33 | else: 34 | return grad_output, None, None, None 35 | 36 | 37 | class ShakeDrop(nn.Module): 38 | 39 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 40 | super(ShakeDrop, self).__init__() 41 | self.p_drop = p_drop 42 | self.alpha_range = alpha_range 43 | 44 | def forward(self, x): 45 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 46 | -------------------------------------------------------------------------------- /UniformAugment/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import warnings 4 | 5 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 6 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 7 | warnings.filterwarnings("ignore", "DeprecationWarning: 'saved_variables' is deprecated", UserWarning) 8 | 9 | 10 | def get_logger(name, level=logging.DEBUG): 11 | logger = logging.getLogger(name) 12 | logger.handlers.clear() 13 | logger.setLevel(level) 14 | ch = logging.StreamHandler() 15 | ch.setLevel(level) 16 | ch.setFormatter(formatter) 17 | logger.addHandler(ch) 18 | return logger 19 | 20 | 21 | def add_filehandler(logger, filepath, level=logging.DEBUG): 22 | fh = logging.FileHandler(filepath) 23 | fh.setLevel(level) 24 | fh.setFormatter(formatter) 25 | logger.addHandler(fh) 26 | 27 | 28 | class EMA: 29 | def __init__(self, mu): 30 | self.mu = mu 31 | self.shadow = {} 32 | 33 | def state_dict(self): 34 | return copy.deepcopy(self.shadow) 35 | 36 | def __len__(self): 37 | return len(self.shadow) 38 | 39 | def __call__(self, module, step=None): 40 | if step is None: 41 | mu = self.mu 42 | else: 43 | # see : https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/ExponentialMovingAverage?hl=PL 44 | mu = min(self.mu, (1. + step) / (10 + step)) 45 | 46 | for name, x in module.state_dict().items(): 47 | if name in self.shadow: 48 | new_average = (1.0 - mu) * x + mu * self.shadow[name] 49 | self.shadow[name] = new_average.clone() 50 | else: 51 | self.shadow[name] = x.clone() 52 | -------------------------------------------------------------------------------- /UniformAugment/metrics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import numpy as np 5 | from collections import defaultdict 6 | 7 | from torch import nn 8 | 9 | 10 | def accuracy(output, target, topk=(1,)): 11 | """Computes the precision@k for the specified values of k""" 12 | maxk = max(topk) 13 | batch_size = target.size(0) 14 | 15 | _, pred = output.topk(maxk, 1, True, True) 16 | pred = pred.t() 17 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 18 | 19 | res = [] 20 | for k in topk: 21 | correct_k = correct[:k].view(-1).float().sum(0) 22 | res.append(correct_k.mul_(1. 
/ batch_size)) 23 | return res 24 | 25 | 26 | class CrossEntropyLabelSmooth(torch.nn.Module): 27 | def __init__(self, num_classes, epsilon, reduction='mean'): 28 | super(CrossEntropyLabelSmooth, self).__init__() 29 | self.num_classes = num_classes 30 | self.epsilon = epsilon 31 | self.reduction = reduction 32 | self.logsoftmax = torch.nn.LogSoftmax(dim=1) 33 | 34 | def forward(self, input, target): # pylint: disable=redefined-builtin 35 | log_probs = self.logsoftmax(input) 36 | targets = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1) 37 | if self.epsilon > 0.0: 38 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 39 | targets = targets.detach() 40 | loss = (-targets * log_probs) 41 | 42 | if self.reduction in ['avg', 'mean']: 43 | loss = torch.mean(torch.sum(loss, dim=1)) 44 | elif self.reduction == 'sum': 45 | loss = loss.sum() 46 | return loss 47 | 48 | 49 | class Accumulator: 50 | def __init__(self): 51 | self.metrics = defaultdict(lambda: 0.) 52 | 53 | def add(self, key, value): 54 | self.metrics[key] += value 55 | 56 | def add_dict(self, dict): 57 | for key, value in dict.items(): 58 | self.add(key, value) 59 | 60 | def __getitem__(self, item): 61 | return self.metrics[item] 62 | 63 | def __setitem__(self, key, value): 64 | self.metrics[key] = value 65 | 66 | def get_dict(self): 67 | return copy.deepcopy(dict(self.metrics)) 68 | 69 | def items(self): 70 | return self.metrics.items() 71 | 72 | def __str__(self): 73 | return str(dict(self.metrics)) 74 | 75 | def __truediv__(self, other): 76 | newone = Accumulator() 77 | for key, value in self.items(): 78 | if isinstance(other, str): 79 | if other != key: 80 | newone[key] = value / self[other] 81 | else: 82 | newone[key] = value 83 | else: 84 | newone[key] = value / other 85 | return newone 86 | 87 | 88 | class SummaryWriterDummy: 89 | def __init__(self, log_dir): 90 | pass 91 | 92 | def add_scalar(self, *args, **kwargs): 93 | pass 94 | -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/shake_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from UniformAugment.networks.shakeshake.shakeshake import ShakeShake 9 | from UniformAugment.networks.shakeshake.shakeshake import Shortcut 10 | 11 | 12 | class ShakeBlock(nn.Module): 13 | 14 | def __init__(self, in_ch, out_ch, stride=1): 15 | super(ShakeBlock, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 18 | 19 | self.branch1 = self._make_branch(in_ch, out_ch, stride) 20 | self.branch2 = self._make_branch(in_ch, out_ch, stride) 21 | 22 | def forward(self, x): 23 | h1 = self.branch1(x) 24 | h2 = self.branch2(x) 25 | h = ShakeShake.apply(h1, h2, self.training) 26 | h0 = x if self.equal_io else self.shortcut(x) 27 | return h + h0 28 | 29 | def _make_branch(self, in_ch, out_ch, stride=1): 30 | return nn.Sequential( 31 | nn.ReLU(inplace=False), 32 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_ch), 34 | nn.ReLU(inplace=False), 35 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), 36 | nn.BatchNorm2d(out_ch)) 37 | 38 | 39 | class ShakeResNet(nn.Module): 40 | 41 | def __init__(self, depth, w_base, label): 42 | super(ShakeResNet, self).__init__() 43 | n_units = (depth - 2) // 6 44 |
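# (note) Each of the 3 stages stacks n_units two-conv ShakeBlocks, so a depth-d network uses (d - 2) // 6 blocks per stage; the stage widths grow as 16 -> w_base -> 2*w_base -> 4*w_base: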
in_chs = [16, w_base, w_base * 2, w_base * 4] 46 | self.in_chs = in_chs 47 | 48 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) 49 | self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1]) 50 | self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2) 51 | self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2) 52 | self.fc_out = nn.Linear(in_chs[3], label) 53 | 54 | # Initialize parameters 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | m.weight.data.normal_(0, math.sqrt(2. / n)) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | h = self.c_in(x) 67 | h = self.layer1(h) 68 | h = self.layer2(h) 69 | h = self.layer3(h) 70 | h = F.relu(h) 71 | h = F.avg_pool2d(h, 8) 72 | h = h.view(-1, self.in_chs[3]) 73 | h = self.fc_out(h) 74 | return h 75 | 76 | def _make_layer(self, n_units, in_ch, out_ch, stride=1): 77 | layers = [] 78 | for i in range(int(n_units)): 79 | layers.append(ShakeBlock(in_ch, out_ch, stride=stride)) 80 | in_ch, stride = out_ch, 1 81 | return nn.Sequential(*layers) 82 | -------------------------------------------------------------------------------- /UniformAugment/networks/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 9 | 10 | 11 | def conv_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv') != -1: 14 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 15 | init.constant_(m.bias, 0) 16 | elif classname.find('BatchNorm') != -1: 17 | init.constant_(m.weight, 1) 18 | init.constant_(m.bias, 0) 19 | 20 | 21 | class WideBasic(nn.Module): 22 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 23 | super(WideBasic, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.9) 25 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 26 | self.dropout = nn.Dropout(p=dropout_rate) 27 | self.bn2 = nn.BatchNorm2d(planes, momentum=0.9) 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 29 | 30 | self.shortcut = nn.Sequential() 31 | if stride != 1 or in_planes != planes: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 34 | ) 35 | 36 | def forward(self, x): 37 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 38 | out = self.conv2(F.relu(self.bn2(out))) 39 | out += self.shortcut(x) 40 | 41 | return out 42 | 43 | 44 | class WideResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(WideResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 50 | n = int((depth - 4) / 6) 51 | k = widen_factor 52 | 53 | nStages = [16, 16*k, 32*k, 64*k] 54 | 55 | self.conv1 = conv3x3(3, nStages[0]) 56 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) 57 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) 58 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
59 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 60 | self.linear = nn.Linear(nStages[3], num_classes) 61 | 62 | # self.apply(conv_init) 63 | 64 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 65 | strides = [stride] + [1]*(num_blocks-1) 66 | layers = [] 67 | 68 | for stride in strides: 69 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 70 | self.in_planes = planes 71 | 72 | return nn.Sequential(*layers) 73 | 74 | def forward(self, x): 75 | out = self.conv1(x) 76 | out = self.layer1(out) 77 | out = self.layer2(out) 78 | out = self.layer3(out) 79 | out = F.relu(self.bn1(out)) 80 | # out = F.avg_pool2d(out, 8) 81 | out = F.adaptive_avg_pool2d(out, (1, 1)) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /UniformAugment/networks/shakeshake/shake_resnext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from UniformAugment.networks.shakeshake.shakeshake import ShakeShake 9 | from UniformAugment.networks.shakeshake.shakeshake import Shortcut 10 | 11 | 12 | class ShakeBottleNeck(nn.Module): 13 | 14 | def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 15 | super(ShakeBottleNeck, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 18 | 19 | self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 20 | self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 21 | 22 | def forward(self, x): 23 | h1 = self.branch1(x) 24 | h2 = self.branch2(x) 25 | h = ShakeShake.apply(h1, h2, self.training) 26 | h0 = x if self.equal_io else self.shortcut(x) 27 | return h + h0 28 | 29 | def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 30 | return nn.Sequential( 31 | nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False), 32 | nn.BatchNorm2d(mid_ch), 33 | nn.ReLU(inplace=False), 34 | nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False), 35 | nn.BatchNorm2d(mid_ch), 36 | nn.ReLU(inplace=False), 37 | nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False), 38 | nn.BatchNorm2d(out_ch)) 39 | 40 | 41 | class ShakeResNeXt(nn.Module): 42 | 43 | def __init__(self, depth, w_base, cardinary, label): 44 | super(ShakeResNeXt, self).__init__() 45 | n_units = (depth - 2) // 9 46 | n_chs = [64, 128, 256, 1024] 47 | self.n_chs = n_chs 48 | self.in_ch = n_chs[0] 49 | 50 | self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1) 51 | self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary) 52 | self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2) 53 | self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2) 54 | self.fc_out = nn.Linear(n_chs[3], label) 55 | 56 | # Initialize parameters 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 60 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(1) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.Linear): 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x): 68 | h = self.c_in(x) 69 | h = self.layer1(h) 70 | h = self.layer2(h) 71 | h = self.layer3(h) 72 | h = F.relu(h) 73 | h = F.avg_pool2d(h, 8) 74 | h = h.view(-1, self.n_chs[3]) 75 | h = self.fc_out(h) 76 | return h 77 | 78 | def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1): 79 | layers = [] 80 | mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4 81 | for i in range(n_units): 82 | layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride)) 83 | self.in_ch, stride = out_ch, 1 84 | return nn.Sequential(*layers) 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniformAugment 2 | 3 | Unofficial PyTorch Reimplementation of [UniformAugment](https://arxiv.org/abs/2003.14348). Most of the code is from [Fast AutoAugment](https://github.com/kakaobrain/fast-autoaugment) and [PyTorch RandAugment](https://github.com/ildoonet/pytorch-randaugment). 4 | 5 | ## Introduction 6 | UniformAugment is an automated data augmentation approach that completely avoids a search phase. Its effectiveness is comparable to that of the known search-based methods, while remaining highly efficient because no search is required. 7 | 8 | ## Install 9 | ``` 10 | pip install git+https://github.com/tgilewicz/uniformaugment/ 11 | ``` 12 | 13 | ## Usage 14 | 15 | ```python 16 | from torchvision.transforms import transforms 17 | from UniformAugment import UniformAugment 18 | 19 | transform_train = transforms.Compose([ 20 | transforms.RandomCrop(32, padding=4), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 24 | ]) 25 | # Add UniformAugment with the ops_num hyperparameter (ops_num=2 is optimal) 26 | transform_train.transforms.insert(0, UniformAugment()) 27 | ``` Here `_CIFAR_MEAN` and `_CIFAR_STD` stand for the per-channel mean and standard deviation of the CIFAR training images; substitute the statistics of your own dataset. 28 | 29 | ## Experiment 30 | 31 | The details of the experiment were confirmed in consultation with the authors of the UniformAugment paper.
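The training configs under `confs/` enable the augmentation via the `aug` key. For reference, a minimal sketch of the relevant settings from `confs/wresnet28x10_cifar.yaml` (indentation reconstructed from the flattened dump; only `type` is nested under `model`, matching how `get_model` reads `conf['type']`):

```yaml
model:
  type: wresnet28_10
dataset: cifar10
aug: uniformaugment
cutout: 16
batch: 128
epoch: 200
lr: 0.1
```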
32 | 33 | You can run an example experiment with, 34 | 35 | ```bash 36 | $ python UniformAugment/train.py -c confs/wresnet28x10_cifar.yaml --dataset cifar10 \ 37 | --save cifar10_wres28x10.pth --dataroot ~/data --tag v1 38 | ``` 39 | 40 | ### CIFAR-10 Classification, TOP1 Accuracy 41 | 42 | | Model | Paper's Result | Run1 | Run2 | Run3 | Run4 | Avg (Ours) | 43 | |-------------------|---------------:|-------------:|-------------:|-------------:|-------------:|-------------:| 44 | | Wide-ResNet 28x10 | **97.33** | 97.26 | 97.31 | 97.33 | 97.42 | **97.33** | 45 | | Wide-ResNet 40x2 | **96.25** | 96.27 | 96.36 | 96.5 | 96.54 | **96.41** | 46 | 47 | ### CIFAR-100 Classification, TOP1 Accuracy 48 | 49 | | Model | Paper's Result | Run1 | Run2 | Run3 | Run4 | Avg (Ours) | 50 | |-------------------|---------------:|-------------:|-------------:|-------------:|-------------:|-------------:| 51 | | Wide-ResNet 28x10 | **82.82** | 83.55 | 82.56 | 82.66 | 82.72 | **82.87** | 52 | | Wide-ResNet 40x2 | **79.01** | 79.06 | 79.08 | 79.09 | 78.77 | **79.00** | 53 | 54 | 55 | 56 | ### ImageNet Classification 57 | 58 | | Model | Paper's Result | Ours | 59 | |-------------------|---------------:|-------------:| 60 | | ResNet-50 | **77.63** | **77.80** | 61 | | ResNet-200 | **80.4** | Stay tuned | 62 | 63 | 64 | ## Core class 65 | ```python 66 | class UniformAugment: 67 | def __init__(self, ops_num=2): 68 | self._augment_list = augment_list(for_autoaug=False) 69 | self._ops_num = ops_num 70 | 71 | def __call__(self, img): 72 | # Selecting unique num_ops transforms for each image would help the 73 | # training procedure. 74 | ops = random.choices(self._augment_list, k=self._ops_num) 75 | 76 | for op in ops: 77 | augment_fn, low, high = op 78 | probability = random.random() 79 | if random.random() < probability: 80 | img = augment_fn(img.copy(), random.uniform(low, high)) 81 | 82 | return img 83 | ``` 84 | 85 | 86 | ## References 87 | 88 | - UniformAugment : [Paper](https://arxiv.org/abs/2003.14348) 89 | - Fast AutoAugment : [Code](https://github.com/kakaobrain/fast-autoaugment) [Paper](https://arxiv.org/abs/1905.00397) 90 | - Pytorch RandAugment: [Code](https://github.com/ildoonet/pytorch-randaugment) 91 | -------------------------------------------------------------------------------- /UniformAugment/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | from UniformAugment.networks.efficientnet_pytorch import EfficientNet, RoutingFn 5 | from UniformAugment.networks.pyramidnet import PyramidNet 6 | from UniformAugment.networks.resnet import ResNet 7 | from UniformAugment.networks.shakeshake.shake_resnet import ShakeResNet 8 | from UniformAugment.networks.shakeshake.shake_resnext import ShakeResNeXt 9 | from UniformAugment.networks.wideresnet import WideResNet 10 | from torch import nn 11 | from torch.nn.parallel import DistributedDataParallel 12 | 13 | 14 | def get_model(conf, num_class=10, local_rank=-1): 15 | name = conf['type'] 16 | 17 | if name == 'resnet50': 18 | model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True) 19 | elif name == 'resnet200': 20 | model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True) 21 | elif name == 'wresnet40_2': 22 | model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class) 23 | elif name == 'wresnet28_10': 24 | model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class) 
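# (note) The Shake-Shake variants below are all 26 layers deep; the suffix (2x32d .. 2x112d) is the base width of the two residual branches, matching the confs/shake26_*.yaml configs.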
25 | 26 | elif name == 'shakeshake26_2x32d': 27 | model = ShakeResNet(26, 32, num_class) 28 | elif name == 'shakeshake26_2x64d': 29 | model = ShakeResNet(26, 64, num_class) 30 | elif name == 'shakeshake26_2x96d': 31 | model = ShakeResNet(26, 96, num_class) 32 | elif name == 'shakeshake26_2x112d': 33 | model = ShakeResNet(26, 112, num_class) 34 | 35 | elif name == 'shakeshake26_2x96d_next': 36 | model = ShakeResNeXt(26, 96, 4, num_class) 37 | 38 | elif name == 'pyramid': 39 | model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], num_classes=num_class, bottleneck=conf['bottleneck']) 40 | 41 | elif 'efficientnet' in name: 42 | model = EfficientNet.from_name(name, condconv_num_expert=conf['condconv_num_expert'], norm_layer=None) # TpuBatchNormalization 43 | if local_rank >= 0: 44 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 45 | def kernel_initializer(module): 46 | def get_fan_in_out(module): 47 | num_input_fmaps = module.weight.size(1) 48 | num_output_fmaps = module.weight.size(0) 49 | receptive_field_size = 1 50 | if module.weight.dim() > 2: 51 | receptive_field_size = module.weight[0][0].numel() 52 | fan_in = num_input_fmaps * receptive_field_size 53 | fan_out = num_output_fmaps * receptive_field_size 54 | return fan_in, fan_out 55 | 56 | if isinstance(module, torch.nn.Conv2d): 57 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L58 58 | fan_in, fan_out = get_fan_in_out(module) 59 | torch.nn.init.normal_(module.weight, mean=0.0, std=np.sqrt(2.0 / fan_out)) 60 | if module.bias is not None: 61 | torch.nn.init.constant_(module.bias, val=0.) 62 | elif isinstance(module, RoutingFn): 63 | torch.nn.init.xavier_uniform_(module.weight) 64 | torch.nn.init.constant_(module.bias, val=0.) 65 | elif isinstance(module, torch.nn.Linear): 66 | # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L82 67 | fan_in, fan_out = get_fan_in_out(module) 68 | delta = 1.0 / np.sqrt(fan_out) 69 | torch.nn.init.uniform_(module.weight, a=-delta, b=delta) 70 | if module.bias is not None: 71 | torch.nn.init.constant_(module.bias, val=0.) 
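# (note) This mirrors the TF EfficientNet initializers linked above: conv kernels are drawn from N(0, 2/fan_out), RoutingFn weights use Xavier-uniform, other Linear weights are uniform in [-1/sqrt(fan_out), 1/sqrt(fan_out)], and all biases start at zero.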
72 | model.apply(kernel_initializer) 73 | else: 74 | raise NameError('no model named, %s' % name) 75 | 76 | if local_rank >= 0: 77 | device = torch.device('cuda', local_rank) 78 | model = model.to(device) 79 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) 80 | else: 81 | model = model.cuda() 82 | # model = DataParallel(model) 83 | 84 | cudnn.benchmark = True 85 | return model 86 | 87 | 88 | def num_class(dataset): 89 | return { 90 | 'cifar10': 10, 91 | 'reduced_cifar10': 10, 92 | 'cifar10.1': 10, 93 | 'cifar100': 100, 94 | 'svhn': 10, 95 | 'reduced_svhn': 10, 96 | 'imagenet': 1000, 97 | 'reduced_imagenet': 120, 98 | }[dataset] 99 | -------------------------------------------------------------------------------- /UniformAugment/networks/resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 86 | super(ResNet, self).__init__() 87 | self.dataset = dataset 88 | if self.dataset.startswith('cifar'): 89 | self.inplanes = 16 90 | print(bottleneck) 91 | if bottleneck == True: 92 | n = int((depth - 2) / 9) 93 | block = Bottleneck 94 | else: 95 | n = int((depth - 2) / 6) 96 | block = BasicBlock 97 | 98 | self.conv1 = 
nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.layer1 = self._make_layer(block, 16, n) 102 | self.layer2 = self._make_layer(block, 32, n, stride=2) 103 | self.layer3 = self._make_layer(block, 64, n, stride=2) 104 | # self.avgpool = nn.AvgPool2d(8) 105 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 106 | self.fc = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | elif dataset == 'imagenet': 109 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 110 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 111 | assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 112 | 113 | self.inplanes = 64 114 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 119 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 120 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 121 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 122 | # self.avgpool = nn.AvgPool2d(7) 123 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 124 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | if self.dataset.startswith('cifar'): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.relu(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | 161 | x = self.avgpool(x) 162 | x = x.view(x.size(0), -1) 163 | x = self.fc(x) 164 | 165 | elif self.dataset == 'imagenet': 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.relu(x) 169 | x = self.maxpool(x) 170 | 171 | x = self.layer1(x) 172 | x = self.layer2(x) 173 | x = self.layer3(x) 174 | x = self.layer4(x) 175 | 176 | x = self.avgpool(x) 177 | x = x.view(x.size(0), -1) 178 | x = self.fc(x) 179 | 180 | return x 181 | -------------------------------------------------------------------------------- /UniformAugment/augmentations.py: -------------------------------------------------------------------------------- 1 | # code in this file is adapted from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from torchvision.transforms.transforms import Compose 9 | 10 | random_mirror = True 11 | 12 | 13 | def ShearX(img, v): # [-0.3, 0.3] 14 | assert -0.3 <= v <= 0.3 15 | if random_mirror and random.random() > 0.5: 16 | v = -v 17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 18 | 19 | 20 | def ShearY(img, v): # [-0.3, 0.3] 21 | assert -0.3 <= v <= 0.3 22 | if random_mirror and random.random() > 0.5: 23 | v = -v 24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 25 | 26 | 27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 28 | assert -0.45 <= v <= 0.45 29 | if random_mirror and random.random() > 0.5: 30 | v = -v 31 | v = v * img.size[0] 32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 33 | 34 | 35 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 36 | assert -0.45 <= v <= 0.45 37 | if random_mirror and random.random() > 0.5: 38 | v = -v 39 | v = v * img.size[1] 40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 41 | 42 | 43 | def TranslateXAbs(img, v): # [0, 10] in pixels 44 | assert 0 <= v <= 10 45 | if random.random() > 0.5: 46 | v = -v 47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 48 | 49 | 50 | def TranslateYAbs(img, v): # [0, 10] in pixels 51 | assert 0 <= v <= 10 52 | if random.random() > 0.5: 53 | v = -v 54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def Rotate(img, v): # [-30, 30] 58 | assert -30 <= v <= 30 59 | if random_mirror and random.random() > 0.5: 60 | v = -v 61 | return img.rotate(v) 62 | 
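# (note) The next few ops (AutoContrast, Invert, Equalize, Flip) ignore their magnitude argument.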
63 | 64 | def AutoContrast(img, _): 65 | return PIL.ImageOps.autocontrast(img) 66 | 67 | 68 | def Invert(img, _): 69 | return PIL.ImageOps.invert(img) 70 | 71 | 72 | def Equalize(img, _): 73 | return PIL.ImageOps.equalize(img) 74 | 75 | 76 | def Flip(img, _): # not from the paper 77 | return PIL.ImageOps.mirror(img) 78 | 79 | 80 | def Solarize(img, v): # [0, 256] 81 | assert 0 <= v <= 256 82 | return PIL.ImageOps.solarize(img, v) 83 | 84 | 85 | def Posterize(img, v): # [4, 8] 86 | assert 4 <= v <= 8 87 | v = int(v) 88 | return PIL.ImageOps.posterize(img, v) 89 | 90 | 91 | def Posterize2(img, v): # [0, 4] 92 | assert 0 <= v <= 4 93 | v = int(v) 94 | return PIL.ImageOps.posterize(img, v) 95 | 96 | 97 | def Contrast(img, v): # [0.1,1.9] 98 | assert 0.1 <= v <= 1.9 99 | return PIL.ImageEnhance.Contrast(img).enhance(v) 100 | 101 | 102 | def Color(img, v): # [0.1,1.9] 103 | assert 0.1 <= v <= 1.9 104 | return PIL.ImageEnhance.Color(img).enhance(v) 105 | 106 | 107 | def Brightness(img, v): # [0.1,1.9] 108 | assert 0.1 <= v <= 1.9 109 | return PIL.ImageEnhance.Brightness(img).enhance(v) 110 | 111 | 112 | def Sharpness(img, v): # [0.1,1.9] 113 | assert 0.1 <= v <= 1.9 114 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 115 | 116 | 117 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 118 | assert 0.0 <= v <= 0.2 119 | if v <= 0.: 120 | return img 121 | 122 | v = v * img.size[0] 123 | return CutoutAbs(img, v) 124 | 125 | 126 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 127 | # assert 0 <= v <= 20 128 | if v < 0: 129 | return img 130 | w, h = img.size 131 | x0 = np.random.uniform(w) 132 | y0 = np.random.uniform(h) 133 | 134 | x0 = int(max(0, x0 - v / 2.)) 135 | y0 = int(max(0, y0 - v / 2.)) 136 | x1 = min(w, x0 + v) 137 | y1 = min(h, y0 + v) 138 | 139 | xy = (x0, y0, x1, y1) 140 | color = (125, 123, 114) 141 | # color = (0, 0, 0) 142 | img = img.copy() 143 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 144 | return img 145 | 146 | 147 | def SamplePairing(imgs): # [0, 0.4] 148 | def f(img1, v): 149 | i = np.random.choice(len(imgs)) 150 | img2 = PIL.Image.fromarray(imgs[i]) 151 | return PIL.Image.blend(img1, img2, v) 152 | 153 | return f 154 | 155 | 156 | def augment_list(for_autoaug=True): # 16 operations and their ranges 157 | l = [ 158 | (ShearX, -0.3, 0.3), # 0 159 | (ShearY, -0.3, 0.3), # 1 160 | (TranslateX, -0.45, 0.45), # 2 161 | (TranslateY, -0.45, 0.45), # 3 162 | (Rotate, -30, 30), # 4 163 | (AutoContrast, 0, 1), # 5 164 | (Invert, 0, 1), # 6 165 | (Equalize, 0, 1), # 7 166 | (Solarize, 0, 256), # 8 167 | (Posterize, 4, 8), # 9 168 | (Contrast, 0.1, 1.9), # 10 169 | (Color, 0.1, 1.9), # 11 170 | (Brightness, 0.1, 1.9), # 12 171 | (Sharpness, 0.1, 1.9), # 13 172 | (Cutout, 0, 0.2), # 14 173 | # (SamplePairing(imgs), 0, 0.4), # 15 174 | ] 175 | if for_autoaug: 176 | l += [ 177 | (CutoutAbs, 0, 20), # compatible with auto-augment 178 | (Posterize2, 0, 4), 179 | (TranslateXAbs, 0, 10), 180 | (TranslateYAbs, 0, 10), 181 | ] 182 | return l 183 | 184 | 185 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 186 | 187 | 188 | def get_augment(name): 189 | return augment_dict[name] 190 | 191 | 192 | def apply_augment(img, name, level): 193 | augment_fn, low, high = get_augment(name) 194 | return augment_fn(img.copy(), level * (high - low) + low) 195 | 196 | 197 | class Lighting(object): 198 | """Lighting noise (AlexNet-style PCA-based noise)""" 199 | 200 | def __init__(self, alphastd, eigval, eigvec): 201 | self.alphastd = alphastd 202 | 
self.eigval = torch.Tensor(eigval) 203 | self.eigvec = torch.Tensor(eigvec) 204 | 205 | def __call__(self, img): 206 | if self.alphastd == 0: 207 | return img 208 | 209 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 210 | rgb = self.eigvec.type_as(img).clone() \ 211 | .mul(alpha.view(1, 3).expand(3, 3)) \ 212 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 213 | .sum(1).squeeze() 214 | 215 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 216 | 217 | 218 | class UniformAugment: 219 | def __init__(self, ops_num=2): 220 | self._augment_list = augment_list(for_autoaug=False) 221 | self._ops_num = ops_num 222 | 223 | def __call__(self, img): 224 | # Selecting unique num_ops transforms for each image would help the 225 | # training procedure. 226 | ops = random.choices(self._augment_list, k=self._ops_num) 227 | 228 | for op in ops: 229 | augment_fn, low, high = op 230 | probability = random.random() 231 | if random.random() < probability: 232 | img = augment_fn(img.copy(), random.uniform(low, high)) 233 | 234 | return img 235 | -------------------------------------------------------------------------------- /UniformAugment/networks/efficientnet_pytorch/condconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch._six import container_abcs 5 | 6 | from itertools import repeat 7 | from functools import partial 8 | from typing import Union, List, Tuple, Optional, Callable 9 | import numpy as np 10 | import math 11 | 12 | 13 | def _ntuple(n): 14 | def parse(x): 15 | if isinstance(x, container_abcs.Iterable): 16 | return x 17 | return tuple(repeat(x, n)) 18 | return parse 19 | 20 | 21 | _single = _ntuple(1) 22 | _pair = _ntuple(2) 23 | _triple = _ntuple(3) 24 | _quadruple = _ntuple(4) 25 | 26 | 27 | def _is_static_pad(kernel_size, stride=1, dilation=1, **_): 28 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 29 | 30 | 31 | def _get_padding(kernel_size, stride=1, dilation=1, **_): 32 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 33 | return padding 34 | 35 | 36 | def _calc_same_pad(i: int, k: int, s: int, d: int): 37 | return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) 38 | 39 | 40 | def conv2d_same( 41 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 42 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 43 | ih, iw = x.size()[-2:] 44 | kh, kw = weight.size()[-2:] 45 | pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) 46 | pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) 47 | if pad_h > 0 or pad_w > 0: 48 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 49 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 50 | 51 | 52 | def get_padding_value(padding, kernel_size, **kwargs): 53 | dynamic = False 54 | if isinstance(padding, str): 55 | # for any string padding, the padding will be calculated for you, one of three ways 56 | padding = padding.lower() 57 | if padding == 'same': 58 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 59 | if _is_static_pad(kernel_size, **kwargs): 60 | # static case, no extra overhead 61 | padding = _get_padding(kernel_size, **kwargs) 62 | else: 63 | # dynamic padding 64 | padding = 0 65 | dynamic = True 66 | elif padding == 'valid': 67 | # 'VALID' padding, same as padding=0 68 | padding = 0 69 | else: 70 | # 
Default to PyTorch style 'same'-ish symmetric padding 71 | padding = _get_padding(kernel_size, **kwargs) 72 | return padding, dynamic 73 | 74 | 75 | def get_condconv_initializer(initializer, num_experts, expert_shape): 76 | def condconv_initializer(weight): 77 | """CondConv initializer function.""" 78 | num_params = np.prod(expert_shape) 79 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params): 80 | raise (ValueError('CondConv variables must have shape [num_experts, num_params]')) 81 | for i in range(num_experts): 82 | initializer(weight[i].view(expert_shape)) 83 | return condconv_initializer 84 | 85 | 86 | class CondConv2d(nn.Module): 87 | """ Conditional Convolution 88 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 89 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 90 | https://github.com/pytorch/pytorch/issues/17983 91 | """ 92 | __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] 93 | 94 | def __init__(self, in_channels, out_channels, kernel_size=3, 95 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 96 | super(CondConv2d, self).__init__() 97 | assert num_experts > 1 98 | 99 | if isinstance(stride, container_abcs.Iterable) and len(stride) == 1: 100 | stride = stride[0] 101 | # print('CondConv', num_experts) 102 | 103 | self.in_channels = in_channels 104 | self.out_channels = out_channels 105 | self.kernel_size = _pair(kernel_size) 106 | self.stride = _pair(stride) 107 | padding_val, is_padding_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 108 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 109 | self.padding = _pair(padding_val) 110 | self.dilation = _pair(dilation) 111 | self.groups = groups 112 | self.num_experts = num_experts 113 | 114 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 115 | weight_num_param = 1 116 | for wd in self.weight_shape: 117 | weight_num_param *= wd 118 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 119 | 120 | if bias: 121 | self.bias_shape = (self.out_channels,) 122 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 123 | else: 124 | self.register_parameter('bias', None) 125 | 126 | self.reset_parameters() 127 | 128 | def reset_parameters(self): 129 | num_input_fmaps = self.weight.size(1) 130 | num_output_fmaps = self.weight.size(0) 131 | receptive_field_size = 1 132 | if self.weight.dim() > 2: 133 | receptive_field_size = self.weight[0][0].numel() 134 | fan_in = num_input_fmaps * receptive_field_size 135 | fan_out = num_output_fmaps * receptive_field_size 136 | 137 | init_weight = get_condconv_initializer(partial(nn.init.normal_, mean=0.0, std=np.sqrt(2.0 / fan_out)), self.num_experts, self.weight_shape) 138 | init_weight(self.weight) 139 | if self.bias is not None: 140 | # fan_in = np.prod(self.weight_shape[1:]) 141 | # bound = 1 / math.sqrt(fan_in) 142 | init_bias = get_condconv_initializer(partial(nn.init.constant_, val=0), self.num_experts, self.bias_shape) 143 | init_bias(self.bias) 144 | 145 | def forward(self, x, routing_weights): 146 | x_orig = x 147 | B, C, H, W = x.shape 148 | weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3) 149 | new_weight_shape = (B * self.out_channels, 
self.in_channels // self.groups) + self.kernel_size 150 | weight = weight.view(new_weight_shape) # (B*out x in x 3 x 3) 151 | bias = None 152 | if self.bias is not None: 153 | bias = torch.matmul(routing_weights, self.bias) 154 | bias = bias.view(B * self.out_channels) 155 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 156 | x = x.view(1, B * C, H, W) 157 | if self.dynamic_padding: 158 | out = conv2d_same( 159 | x, weight, bias, stride=self.stride, padding=self.padding, 160 | dilation=self.dilation, groups=self.groups * B) 161 | else: 162 | out = F.conv2d( 163 | x, weight, bias, stride=self.stride, padding=self.padding, 164 | dilation=self.dilation, groups=self.groups * B) 165 | 166 | # out : (1 x B*out x ...) 167 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 168 | 169 | # out2 = self.forward_legacy(x_orig, routing_weights) 170 | # lt = torch.lt(torch.abs(torch.add(out, -out2)), 1e-8) 171 | # assert torch.all(lt), torch.abs(torch.add(out, -out2))[lt] 172 | # print('checked') 173 | return out 174 | 175 | def forward_legacy(self, x, routing_weights): 176 | # Literal port (from TF definition) 177 | B, C, H, W = x.shape 178 | weight = torch.matmul(routing_weights, self.weight) # (Expert x out x in x 3x3) --> (B x out x in x 3x3) 179 | x = torch.split(x, 1, 0) 180 | weight = torch.split(weight, 1, 0) 181 | if self.bias is not None: 182 | bias = torch.matmul(routing_weights, self.bias) 183 | bias = torch.split(bias, 1, 0) 184 | else: 185 | bias = [None] * B 186 | out = [] 187 | if self.dynamic_padding: 188 | conv_fn = conv2d_same 189 | else: 190 | conv_fn = F.conv2d 191 | for xi, wi, bi in zip(x, weight, bias): 192 | wi = wi.view(*self.weight_shape) 193 | if bi is not None: 194 | bi = bi.view(*self.bias_shape) 195 | out.append(conv_fn( 196 | xi, wi, bi, stride=self.stride, padding=self.padding, 197 | dilation=self.dilation, groups=self.groups)) 198 | out = torch.cat(out, 0) 199 | return out 200 | -------------------------------------------------------------------------------- /UniformAugment/imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import shutil 4 | import torch 5 | 6 | ARCHIVE_DICT = { 7 | 'train': { 8 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar', 9 | 'md5': '1d675b47d978889d74fa0da5fadfb00e', 10 | }, 11 | 'val': { 12 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar', 13 | 'md5': '29b22e2961454d5413ddabcf34fc5622', 14 | }, 15 | 'devkit': { 16 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz', 17 | 'md5': 'fa75699e90414af021442c21a62c3abf', 18 | } 19 | } 20 | 21 | 22 | import torchvision 23 | from torchvision.datasets.utils import check_integrity, download_url 24 | 25 | 26 | # copy ILSVRC/ImageSets/CLS-LOC/train_cls.txt to ./root/ 27 | # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file 28 | class ImageNet(torchvision.datasets.ImageFolder): 29 | """`ImageNet `_ 2012 Classification Dataset. 30 | 31 | Args: 32 | root (string): Root directory of the ImageNet Dataset. 33 | split (string, optional): The dataset split, supports ``train``, or ``val``. 34 | download (bool, optional): If true, downloads the dataset from the internet and 35 | puts it in root directory. 
If dataset is already downloaded, it is not 36 | downloaded again. 37 | transform (callable, optional): A function/transform that takes in an PIL image 38 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 39 | target_transform (callable, optional): A function/transform that takes in the 40 | target and transforms it. 41 | loader (callable, optional): A function to load an image given its path. 42 | 43 | Attributes: 44 | classes (list): List of the class names. 45 | class_to_idx (dict): Dict with items (class_name, class_index). 46 | wnids (list): List of the WordNet IDs. 47 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 48 | imgs (list): List of (image path, class_index) tuples 49 | targets (list): The class_index value for each image in the dataset 50 | """ 51 | 52 | def __init__(self, root, split='train', download=False, **kwargs): 53 | root = self.root = os.path.expanduser(root) 54 | self.split = self._verify_split(split) 55 | 56 | if download: 57 | self.download() 58 | wnid_to_classes = self._load_meta_file()[0] 59 | 60 | # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file 61 | listfile = os.path.join(root, 'train_cls.txt') 62 | if split == 'train' and os.path.exists(listfile): 63 | torchvision.datasets.VisionDataset.__init__(self, root, **kwargs) 64 | with open(listfile, 'r') as f: 65 | datalist = [ 66 | line.strip().split(' ')[0] 67 | for line in f.readlines() 68 | if line.strip() 69 | ] 70 | 71 | classes = list(set([line.split('/')[0] for line in datalist])) 72 | classes.sort() 73 | class_to_idx = {classes[i]: i for i in range(len(classes))} 74 | 75 | samples = [ 76 | (os.path.join(self.split_folder, line + '.JPEG'), class_to_idx[line.split('/')[0]]) 77 | for line in datalist 78 | ] 79 | 80 | self.loader = torchvision.datasets.folder.default_loader 81 | self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS 82 | 83 | self.classes = classes 84 | self.class_to_idx = class_to_idx 85 | self.samples = samples 86 | self.targets = [s[1] for s in samples] 87 | 88 | self.imgs = self.samples 89 | else: 90 | super(ImageNet, self).__init__(self.split_folder, **kwargs) 91 | 92 | self.root = root 93 | 94 | idcs = [idx for _, idx in self.imgs] 95 | self.wnids = self.classes 96 | self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)} 97 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 98 | self.class_to_idx = {cls: idx 99 | for clss, idx in zip(self.classes, idcs) 100 | for cls in clss} 101 | 102 | def download(self): 103 | if not check_integrity(self.meta_file): 104 | tmpdir = os.path.join(self.root, 'tmp') 105 | 106 | archive_dict = ARCHIVE_DICT['devkit'] 107 | download_and_extract_tar(archive_dict['url'], self.root, 108 | extract_root=tmpdir, 109 | md5=archive_dict['md5']) 110 | devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] 111 | meta = parse_devkit(os.path.join(tmpdir, devkit_folder)) 112 | self._save_meta_file(*meta) 113 | 114 | shutil.rmtree(tmpdir) 115 | 116 | if not os.path.isdir(self.split_folder): 117 | archive_dict = ARCHIVE_DICT[self.split] 118 | download_and_extract_tar(archive_dict['url'], self.root, 119 | extract_root=self.split_folder, 120 | md5=archive_dict['md5']) 121 | 122 | if self.split == 'train': 123 | prepare_train_folder(self.split_folder) 124 | elif self.split == 'val': 125 | val_wnids = self._load_meta_file()[1] 126 | prepare_val_folder(self.split_folder, val_wnids) 127 | else: 128 | msg = ("You set download=True, but a folder '{}' already exist 
in " 129 | "the root directory. If you want to re-download or re-extract the " 130 | "archive, delete the folder.") 131 | print(msg.format(self.split)) 132 | 133 | @property 134 | def meta_file(self): 135 | return os.path.join(self.root, 'meta.bin') 136 | 137 | def _load_meta_file(self): 138 | if check_integrity(self.meta_file): 139 | return torch.load(self.meta_file) 140 | raise RuntimeError("Meta file not found or corrupted.", 141 | "You can use download=True to create it.") 142 | 143 | def _save_meta_file(self, wnid_to_class, val_wnids): 144 | torch.save((wnid_to_class, val_wnids), self.meta_file) 145 | 146 | def _verify_split(self, split): 147 | if split not in self.valid_splits: 148 | msg = "Unknown split {} .".format(split) 149 | msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits)) 150 | raise ValueError(msg) 151 | return split 152 | 153 | @property 154 | def valid_splits(self): 155 | return 'train', 'val' 156 | 157 | @property 158 | def split_folder(self): 159 | return os.path.join(self.root, self.split) 160 | 161 | def extra_repr(self): 162 | return "Split: {split}".format(**self.__dict__) 163 | 164 | 165 | def extract_tar(src, dest=None, gzip=None, delete=False): 166 | import tarfile 167 | 168 | if dest is None: 169 | dest = os.path.dirname(src) 170 | if gzip is None: 171 | gzip = src.lower().endswith('.gz') 172 | 173 | mode = 'r:gz' if gzip else 'r' 174 | with tarfile.open(src, mode) as tarfh: 175 | tarfh.extractall(path=dest) 176 | 177 | if delete: 178 | os.remove(src) 179 | 180 | 181 | def download_and_extract_tar(url, download_root, extract_root=None, filename=None, 182 | md5=None, **kwargs): 183 | download_root = os.path.expanduser(download_root) 184 | if extract_root is None: 185 | extract_root = download_root 186 | if filename is None: 187 | filename = os.path.basename(url) 188 | 189 | if not check_integrity(os.path.join(download_root, filename), md5): 190 | download_url(url, download_root, filename=filename, md5=md5) 191 | 192 | extract_tar(os.path.join(download_root, filename), extract_root, **kwargs) 193 | 194 | 195 | def parse_devkit(root): 196 | idx_to_wnid, wnid_to_classes = parse_meta(root) 197 | val_idcs = parse_val_groundtruth(root) 198 | val_wnids = [idx_to_wnid[idx] for idx in val_idcs] 199 | return wnid_to_classes, val_wnids 200 | 201 | 202 | def parse_meta(devkit_root, path='data', filename='meta.mat'): 203 | import scipy.io as sio 204 | 205 | metafile = os.path.join(devkit_root, path, filename) 206 | meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] 207 | nums_children = list(zip(*meta))[4] 208 | meta = [meta[idx] for idx, num_children in enumerate(nums_children) 209 | if num_children == 0] 210 | idcs, wnids, classes = list(zip(*meta))[:3] 211 | classes = [tuple(clss.split(', ')) for clss in classes] 212 | idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} 213 | wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} 214 | return idx_to_wnid, wnid_to_classes 215 | 216 | 217 | def parse_val_groundtruth(devkit_root, path='data', 218 | filename='ILSVRC2012_validation_ground_truth.txt'): 219 | with open(os.path.join(devkit_root, path, filename), 'r') as txtfh: 220 | val_idcs = txtfh.readlines() 221 | return [int(val_idx) for val_idx in val_idcs] 222 | 223 | 224 | def prepare_train_folder(folder): 225 | for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: 226 | extract_tar(archive, os.path.splitext(archive)[0], delete=True) 227 | 228 | 229 | def prepare_val_folder(folder, wnids): 230 | 
img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) 231 | 232 | for wnid in set(wnids): 233 | os.mkdir(os.path.join(folder, wnid)) 234 | 235 | for wnid, img_file in zip(wnids, img_files): 236 | shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) 237 | 238 | 239 | def _splitexts(root): 240 | exts = [] 241 | ext = '.' 242 | while ext: 243 | root, ext = os.path.splitext(root) 244 | exts.append(ext) 245 | return root, ''.join(reversed(exts)) 246 | -------------------------------------------------------------------------------- /UniformAugment/networks/pyramidnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from UniformAugment.networks.shakedrop import ShakeDrop 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """ 10 | 3x3 convolution with padding 11 | """ 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | outchannel_ratio = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 19 | super(BasicBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(inplanes) 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn3 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.downsample = downsample 27 | self.stride = stride 28 | self.shake_drop = ShakeDrop(p_shakedrop) 29 | 30 | def forward(self, x): 31 | 32 | out = self.bn1(x) 33 | out = self.conv1(out) 34 | out = self.bn2(out) 35 | out = self.relu(out) 36 | out = self.conv2(out) 37 | out = self.bn3(out) 38 | 39 | out = self.shake_drop(out) 40 | 41 | if self.downsample is not None: 42 | shortcut = self.downsample(x) 43 | featuremap_size = shortcut.size()[2:4] 44 | else: 45 | shortcut = x 46 | featuremap_size = out.size()[2:4] 47 | 48 | batch_size = out.size()[0] 49 | residual_channel = out.size()[1] 50 | shortcut_channel = shortcut.size()[1] 51 | 52 | if residual_channel != shortcut_channel: 53 | padding = torch.autograd.Variable( 54 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 55 | featuremap_size[1]).fill_(0)) 56 | out += torch.cat((shortcut, padding), 1) 57 | else: 58 | out += shortcut 59 | 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | outchannel_ratio = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 67 | super(Bottleneck, self).__init__() 68 | self.bn1 = nn.BatchNorm2d(inplanes) 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d((planes * 1)) 74 | self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 75 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | self.shake_drop = ShakeDrop(p_shakedrop) 80 | 81 | def forward(self, x): 82 | 83 | out = self.bn1(x) 84 | out = self.conv1(out) 85 | 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | out = self.conv2(out) 89 | 90 | out = self.bn3(out) 91 | out = self.relu(out) 92 | out = self.conv3(out) 93 | 94 | out = 
self.bn4(out) 95 | 96 | out = self.shake_drop(out) 97 | 98 | if self.downsample is not None: 99 | shortcut = self.downsample(x) 100 | featuremap_size = shortcut.size()[2:4] 101 | else: 102 | shortcut = x 103 | featuremap_size = out.size()[2:4] 104 | 105 | batch_size = out.size()[0] 106 | residual_channel = out.size()[1] 107 | shortcut_channel = shortcut.size()[1] 108 | 109 | if residual_channel != shortcut_channel: 110 | padding = torch.autograd.Variable( 111 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 112 | featuremap_size[1]).fill_(0)) 113 | out += torch.cat((shortcut, padding), 1) 114 | else: 115 | out += shortcut 116 | 117 | return out 118 | 119 | 120 | class PyramidNet(nn.Module): 121 | 122 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=True): 123 | super(PyramidNet, self).__init__() 124 | self.dataset = dataset 125 | if self.dataset.startswith('cifar'): 126 | self.inplanes = 16 127 | if bottleneck: 128 | n = int((depth - 2) / 9) 129 | block = Bottleneck 130 | else: 131 | n = int((depth - 2) / 6) 132 | block = BasicBlock 133 | 134 | self.addrate = alpha / (3 * n * 1.0) 135 | self.ps_shakedrop = [1. - (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)] 136 | 137 | self.input_featuremap_dim = self.inplanes 138 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 139 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 140 | 141 | self.featuremap_dim = self.input_featuremap_dim 142 | self.layer1 = self.pyramidal_make_layer(block, n) 143 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 144 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 145 | 146 | self.final_featuremap_dim = self.input_featuremap_dim 147 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 148 | self.relu_final = nn.ReLU(inplace=True) 149 | self.avgpool = nn.AvgPool2d(8) 150 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 151 | 152 | elif dataset == 'imagenet': 153 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 154 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 155 | 200: [3, 24, 36, 3]} 156 | 157 | if layers.get(depth) is None: 158 | if bottleneck == True: 159 | blocks[depth] = Bottleneck 160 | temp_cfg = int((depth - 2) / 12) 161 | else: 162 | blocks[depth] = BasicBlock 163 | temp_cfg = int((depth - 2) / 8) 164 | 165 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 166 | print('=> the layer configuration for each stage is set to', layers[depth]) 167 | 168 | self.inplanes = 64 169 | self.addrate = alpha / (sum(layers[depth]) * 1.0) 170 | 171 | self.input_featuremap_dim = self.inplanes 172 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 173 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 176 | 177 | self.featuremap_dim = self.input_featuremap_dim 178 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 179 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 180 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 181 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 182 | 183 | self.final_featuremap_dim = 
self.input_featuremap_dim 184 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 185 | self.relu_final = nn.ReLU(inplace=True) 186 | self.avgpool = nn.AvgPool2d(7) 187 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 192 | m.weight.data.normal_(0, math.sqrt(2. / n)) 193 | elif isinstance(m, nn.BatchNorm2d): 194 | m.weight.data.fill_(1) 195 | m.bias.data.zero_() 196 | 197 | assert len(self.ps_shakedrop) == 0, self.ps_shakedrop 198 | 199 | def pyramidal_make_layer(self, block, block_depth, stride=1): 200 | downsample = None 201 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 202 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) 203 | 204 | layers = [] 205 | self.featuremap_dim = self.featuremap_dim + self.addrate 206 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, p_shakedrop=self.ps_shakedrop.pop(0))) 207 | for i in range(1, block_depth): 208 | temp_featuremap_dim = self.featuremap_dim + self.addrate 209 | layers.append( 210 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, p_shakedrop=self.ps_shakedrop.pop(0))) 211 | self.featuremap_dim = temp_featuremap_dim 212 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 213 | 214 | return nn.Sequential(*layers) 215 | 216 | def forward(self, x): 217 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 218 | x = self.conv1(x) 219 | x = self.bn1(x) 220 | 221 | x = self.layer1(x) 222 | x = self.layer2(x) 223 | x = self.layer3(x) 224 | 225 | x = self.bn_final(x) 226 | x = self.relu_final(x) 227 | x = self.avgpool(x) 228 | x = x.view(x.size(0), -1) 229 | x = self.fc(x) 230 | 231 | elif self.dataset == 'imagenet': 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.bn_final(x) 243 | x = self.relu_final(x) 244 | x = self.avgpool(x) 245 | x = x.view(x.size(0), -1) 246 | x = self.fc(x) 247 | 248 | return x 249 | -------------------------------------------------------------------------------- /UniformAugment/networks/efficientnet_pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from functools import partial 6 | from .utils import ( 7 | round_filters, 8 | round_repeats, 9 | drop_connect, 10 | get_same_padding_conv2d, 11 | get_model_params, 12 | efficientnet_params, 13 | load_pretrained_weights, 14 | MemoryEfficientSwish, 15 | ) 16 | 17 | 18 | class RoutingFn(nn.Linear): 19 | pass 20 | 21 | 22 | class MBConvBlock(nn.Module): 23 | """ 24 | Mobile Inverted Residual Bottleneck Block 25 | 26 | Args: 27 | block_args (namedtuple): BlockArgs, see above 28 | global_params (namedtuple): GlobalParam, see above 29 | 30 | Attributes: 31 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
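        id_skip (bool): Whether a skip connection (with drop connect) may be
            applied; it is only used when stride == 1 and the input and
            output filter counts match (see forward()).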
32 | """ 33 | 34 | def __init__(self, block_args, global_params, norm_layer=None): 35 | super().__init__() 36 | self._block_args = block_args 37 | self._bn_mom = 1 - global_params.batch_norm_momentum 38 | self._bn_eps = global_params.batch_norm_epsilon 39 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 40 | self.id_skip = block_args.id_skip # skip connection and drop connect 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | 44 | self.condconv_num_expert = block_args.condconv_num_expert 45 | if self._is_condconv(): 46 | self.routing_fn = RoutingFn(self._block_args.input_filters, self.condconv_num_expert) 47 | 48 | # Get static or dynamic convolution depending on image size 49 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size, condconv_num_expert=block_args.condconv_num_expert) 50 | Conv2dse = get_same_padding_conv2d(image_size=global_params.image_size) 51 | 52 | # Expansion phase 53 | inp = self._block_args.input_filters # number of input channels 54 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 55 | if self._block_args.expand_ratio != 1: 56 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 57 | self._bn0 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 58 | 59 | # Depthwise convolution phase 60 | k = self._block_args.kernel_size 61 | s = self._block_args.stride 62 | self._depthwise_conv = Conv2d( 63 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 64 | kernel_size=k, stride=s, bias=False) 65 | self._bn1 = norm_layer(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 66 | 67 | # Squeeze and Excitation layer, if desired 68 | if self.has_se: 69 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 70 | self._se_reduce = Conv2dse(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 71 | self._se_expand = Conv2dse(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 72 | 73 | # Output phase 74 | final_oup = self._block_args.output_filters 75 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 76 | self._bn2 = norm_layer(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 77 | self._swish = MemoryEfficientSwish() 78 | 79 | def _is_condconv(self): 80 | return self.condconv_num_expert > 1 81 | 82 | def forward(self, inputs, drop_connect_rate=None): 83 | """ 84 | :param inputs: input tensor 85 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 86 | :return: output of block 87 | """ 88 | 89 | if self._is_condconv(): 90 | feat = F.adaptive_avg_pool2d(inputs, 1).flatten(1) 91 | routing_w = torch.sigmoid(self.routing_fn(feat)) 92 | 93 | if self._block_args.expand_ratio != 1: 94 | _expand_conv = partial(self._expand_conv, routing_weights=routing_w) 95 | _depthwise_conv = partial(self._depthwise_conv, routing_weights=routing_w) 96 | _project_conv = partial(self._project_conv, routing_weights=routing_w) 97 | else: 98 | if self._block_args.expand_ratio != 1: 99 | _expand_conv = self._expand_conv 100 | _depthwise_conv, _project_conv = self._depthwise_conv, self._project_conv 101 | 102 | # Expansion and Depthwise Convolution 103 | x = inputs 104 | if self._block_args.expand_ratio != 1: 105 | x = self._swish(self._bn0(_expand_conv(inputs))) 106 | x = self._swish(self._bn1(_depthwise_conv(x))) 107 | 108 | # Squeeze and 
Excitation 109 | if self.has_se: 110 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 111 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 112 | x = torch.sigmoid(x_squeezed) * x 113 | 114 | x = self._bn2(_project_conv(x)) 115 | 116 | # Skip connection and drop connect 117 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 118 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 119 | if drop_connect_rate: 120 | x = drop_connect(x, drop_p=drop_connect_rate, training=self.training) 121 | x = x + inputs # skip connection 122 | return x 123 | 124 | def set_swish(self): 125 | """Sets swish function as memory efficient (for training) or standard (for export)""" 126 | self._swish = MemoryEfficientSwish() 127 | 128 | 129 | class EfficientNet(nn.Module): 130 | """ 131 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods 132 | 133 | Args: 134 | blocks_args (list): A list of BlockArgs to construct blocks 135 | global_params (namedtuple): A set of GlobalParams shared between blocks 136 | 137 | Example: 138 | model = EfficientNet.from_pretrained('efficientnet-b0') 139 | 140 | """ 141 | 142 | def __init__(self, blocks_args=None, global_params=None, norm_layer=None): 143 | super().__init__() 144 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 145 | assert len(blocks_args) > 0, 'block args must be greater than 0' 146 | self._global_params = global_params 147 | self._blocks_args = blocks_args 148 | if norm_layer is None: 149 | norm_layer = nn.BatchNorm2d 150 | 151 | # Get static or dynamic convolution depending on image size 152 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 153 | 154 | # Batch norm parameters 155 | bn_mom = 1 - self._global_params.batch_norm_momentum 156 | bn_eps = self._global_params.batch_norm_epsilon 157 | 158 | # Stem 159 | in_channels = 3 # rgb 160 | out_channels = round_filters(32, self._global_params) # number of output channels 161 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 162 | self._bn0 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 163 | 164 | # Build blocks 165 | self._blocks = nn.ModuleList([]) 166 | for idx, block_args in enumerate(self._blocks_args): 167 | # Update block input and output filters based on depth multiplier. 168 | block_args = block_args._replace( 169 | input_filters=round_filters(block_args.input_filters, self._global_params), 170 | output_filters=round_filters(block_args.output_filters, self._global_params), 171 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 172 | ) 173 | 174 | # The first block needs to take care of stride and filter size increase. 
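            # E.g. the stage string 'r2_k3_s22_e6_i16_o24_se0.25' yields one
            # stride-2 16->24 block here, followed by a stride-1 24->24 repeat
            # created via the _replace() call below.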
175 | self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer)) 176 | if block_args.num_repeat > 1: 177 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 178 | for _ in range(block_args.num_repeat - 1): 179 | self._blocks.append(MBConvBlock(block_args, self._global_params, norm_layer=norm_layer)) 180 | 181 | # Head 182 | in_channels = block_args.output_filters # output of final block 183 | out_channels = round_filters(1280, self._global_params) 184 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 185 | self._bn1 = norm_layer(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 186 | 187 | # Final linear layer 188 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 189 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 190 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 191 | self._swish = MemoryEfficientSwish() 192 | 193 | def set_swish(self): 194 | """Sets swish function as memory efficient (for training) or standard (for export)""" 195 | self._swish = MemoryEfficientSwish() 196 | for block in self._blocks: 197 | block.set_swish() 198 | 199 | def extract_features(self, inputs): 200 | """ Returns output of the final convolution layer """ 201 | 202 | # Stem 203 | x = self._swish(self._bn0(self._conv_stem(inputs))) 204 | 205 | # Blocks 206 | for idx, block in enumerate(self._blocks): 207 | drop_connect_rate = self._global_params.drop_connect_rate 208 | if drop_connect_rate: 209 | drop_connect_rate *= float(idx) / len(self._blocks) 210 | x = block(x, drop_connect_rate=drop_connect_rate) 211 | 212 | # Head 213 | x = self._swish(self._bn1(self._conv_head(x))) 214 | 215 | return x 216 | 217 | def forward(self, inputs): 218 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ 219 | bs = inputs.size(0) 220 | # Convolution layers 221 | x = self.extract_features(inputs) 222 | 223 | # Pooling and final linear layer 224 | x = self._avg_pooling(x) 225 | x = x.view(bs, -1) 226 | x = self._dropout(x) 227 | x = self._fc(x) 228 | return x 229 | 230 | @classmethod 231 | def from_name(cls, model_name, override_params=None, norm_layer=None, condconv_num_expert=1): 232 | cls._check_model_name_is_valid(model_name) 233 | blocks_args, global_params = get_model_params(model_name, override_params, condconv_num_expert=condconv_num_expert) 234 | return cls(blocks_args, global_params, norm_layer=norm_layer) 235 | 236 | @classmethod 237 | def from_pretrained(cls, model_name, num_classes=1000): 238 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 239 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 240 | 241 | return model 242 | 243 | @classmethod 244 | def get_image_size(cls, model_name): 245 | cls._check_model_name_is_valid(model_name) 246 | _, _, res, _ = efficientnet_params(model_name) 247 | return res 248 | 249 | @classmethod 250 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): 251 | """ Validates model name. None that pretrained weights are only available for 252 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. 
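        With also_need_pretrained_weights=False (the default), every name from
        efficientnet-b0 to efficientnet-b7 is accepted.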
""" 253 | num_models = 4 if also_need_pretrained_weights else 8 254 | valid_models = ['efficientnet-b'+str(i) for i in range(num_models)] 255 | if model_name not in valid_models: 256 | raise ValueError(f'model_name={model_name} should be one of: ' + ', '.join(valid_models)) 257 | -------------------------------------------------------------------------------- /UniformAugment/networks/efficientnet_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains helper functions for building the model and for loading model parameters. 3 | These helper functions are built to mirror those in the official TensorFlow implementation. 4 | """ 5 | 6 | import re 7 | import math 8 | import collections 9 | from functools import partial 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.utils import model_zoo 14 | 15 | ######################################################################## 16 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### 17 | ######################################################################## 18 | 19 | 20 | # Parameters for the entire model (stem, all blocks, and head) 21 | from UniformAugment.networks.efficientnet_pytorch.condconv import CondConv2d 22 | 23 | GlobalParams = collections.namedtuple('GlobalParams', [ 24 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 25 | 'num_classes', 'width_coefficient', 'depth_coefficient', 26 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 27 | 28 | # Parameters for an individual model block 29 | BlockArgs = collections.namedtuple('BlockArgs', [ 30 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 31 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'condconv_num_expert']) 32 | 33 | # Change namedtuple defaults 34 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 35 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 36 | 37 | 38 | class SwishImplementation(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, i): 41 | result = i * torch.sigmoid(i) 42 | ctx.save_for_backward(i) 43 | return result 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | i = ctx.saved_tensors[0] 48 | sigmoid_i = torch.sigmoid(i) 49 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 50 | 51 | 52 | class MemoryEfficientSwish(nn.Module): 53 | def forward(self, x): 54 | return SwishImplementation.apply(x) 55 | 56 | 57 | def round_filters(filters, global_params): 58 | """ Calculate and round number of filters based on depth multiplier. """ 59 | multiplier = global_params.width_coefficient 60 | if not multiplier: 61 | return filters 62 | divisor = global_params.depth_divisor 63 | min_depth = global_params.min_depth 64 | filters *= multiplier 65 | min_depth = min_depth or divisor 66 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 67 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 68 | new_filters += divisor 69 | return int(new_filters) 70 | 71 | 72 | def round_repeats(repeats, global_params): 73 | """ Round number of filters based on depth multiplier. """ 74 | multiplier = global_params.depth_coefficient 75 | if not multiplier: 76 | return repeats 77 | return int(math.ceil(multiplier * repeats)) 78 | 79 | 80 | def drop_connect(inputs, drop_p, training): 81 | """ Drop connect. """ 82 | if not training: 83 | return inputs * (1. 
- drop_p) 84 | batch_size = inputs.shape[0] 85 | random_tensor = torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 86 | binary_tensor = random_tensor > drop_p 87 | output = inputs * binary_tensor.float() 88 | # output = inputs / (1. - drop_p) * binary_tensor.float() 89 | return output 90 | 91 | # if not training: return inputs 92 | # batch_size = inputs.shape[0] 93 | # keep_prob = 1 - drop_p 94 | # random_tensor = keep_prob 95 | # random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 96 | # binary_tensor = torch.floor(random_tensor) 97 | # output = inputs / keep_prob * binary_tensor 98 | # return output 99 | 100 | 101 | def get_same_padding_conv2d(image_size=None, condconv_num_expert=1): 102 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. 103 | Static padding is necessary for ONNX exporting of models. """ 104 | if condconv_num_expert > 1: 105 | return partial(CondConv2d, num_experts=condconv_num_expert) 106 | elif image_size is None: 107 | return Conv2dDynamicSamePadding 108 | else: 109 | return partial(Conv2dStaticSamePadding, image_size=image_size) 110 | 111 | 112 | class Conv2dDynamicSamePadding(nn.Conv2d): 113 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 114 | 115 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 116 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 117 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 118 | 119 | def forward(self, x): 120 | ih, iw = x.size()[-2:] 121 | kh, kw = self.weight.size()[-2:] 122 | sh, sw = self.stride 123 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 124 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 125 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 126 | if pad_h > 0 or pad_w > 0: 127 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 128 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 129 | 130 | 131 | class Conv2dStaticSamePadding(nn.Conv2d): 132 | """ 2D Convolutions like TensorFlow, for a fixed image size""" 133 | 134 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): 135 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 136 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 137 | 138 | # Calculate padding based on image size and save it 139 | assert image_size is not None 140 | ih, iw = image_size if type(image_size) == list else [image_size, image_size] 141 | kh, kw = self.weight.size()[-2:] 142 | sh, sw = self.stride 143 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 144 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 145 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 146 | if pad_h > 0 or pad_w > 0: 147 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 148 | else: 149 | self.static_padding = Identity() 150 | 151 | def forward(self, x): 152 | x = self.static_padding(x) 153 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 154 | return x 155 | 156 | 157 | class Identity(nn.Module): 158 | def __init__(self, ): 159 | super(Identity, self).__init__() 160 | 161 | def 
forward(self, input):
162 |         return input
163 |
164 |
165 | ########################################################################
166 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
167 | ########################################################################
168 |
169 |
170 | def efficientnet_params(model_name):
171 |     """ Map EfficientNet model name to parameter coefficients. """
172 |     params_dict = {
173 |         # Coefficients: width, depth, resolution, dropout
174 |         'efficientnet-b0': (1.0, 1.0, 224, 0.2),
175 |         'efficientnet-b1': (1.0, 1.1, 240, 0.2),
176 |         'efficientnet-b2': (1.1, 1.2, 260, 0.3),
177 |         'efficientnet-b3': (1.2, 1.4, 300, 0.3),
178 |         'efficientnet-b4': (1.4, 1.8, 380, 0.4),
179 |         'efficientnet-b5': (1.6, 2.2, 456, 0.4),
180 |         'efficientnet-b6': (1.8, 2.6, 528, 0.5),
181 |         'efficientnet-b7': (2.0, 3.1, 600, 0.5),
182 |     }
183 |     return params_dict[model_name]
184 |
185 |
186 | class BlockDecoder(object):
187 |     """ Block Decoder for readability, straight from the official TensorFlow repository """
188 |
189 |     @staticmethod
190 |     def _decode_block_string(block_string):
191 |         """ Gets a block through a string notation of arguments. """
192 |         assert isinstance(block_string, str)
193 |
194 |         ops = block_string.split('_')
195 |         options = {}
196 |         for op in ops:
197 |             splits = re.split(r'(\d.*)', op)
198 |             if len(splits) >= 2:
199 |                 key, value = splits[:2]
200 |                 options[key] = value
201 |
202 |         # Check stride
203 |         assert (('s' in options and len(options['s']) == 1) or
204 |                 (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
205 |
206 |         return BlockArgs(
207 |             kernel_size=int(options['k']),
208 |             num_repeat=int(options['r']),
209 |             input_filters=int(options['i']),
210 |             output_filters=int(options['o']),
211 |             expand_ratio=int(options['e']),
212 |             id_skip=('noskip' not in block_string),
213 |             se_ratio=float(options['se']) if 'se' in options else None,
214 |             stride=[int(options['s'][0])],
215 |             condconv_num_expert=0
216 |         )
217 |
218 |     @staticmethod
219 |     def _encode_block_string(block):
220 |         """Encodes a block to a string."""
221 |         args = [
222 |             'r%d' % block.num_repeat,
223 |             'k%d' % block.kernel_size,
224 |             's%d%d' % (block.stride[0], block.stride[-1]),  # BlockArgs stores 'stride' (a 1- or 2-element list), not 'strides'
225 |             'e%s' % block.expand_ratio,
226 |             'i%d' % block.input_filters,
227 |             'o%d' % block.output_filters
228 |         ]
229 |         if block.se_ratio is not None and 0 < block.se_ratio <= 1:
230 |             args.append('se%s' % block.se_ratio)
231 |         if block.id_skip is False:
232 |             args.append('noskip')
233 |         return '_'.join(args)
234 |
235 |     @staticmethod
236 |     def decode(string_list):
237 |         """
238 |         Decodes a list of string notations to specify blocks inside the network.
239 |
240 |         :param string_list: a list of strings, each string is a notation of block
241 |         :return: a list of BlockArgs namedtuples of block args
242 |         """
243 |         assert isinstance(string_list, list)
244 |         blocks_args = []
245 |         for block_string in string_list:
246 |             blocks_args.append(BlockDecoder._decode_block_string(block_string))
247 |         return blocks_args
248 |
249 |     @staticmethod
250 |     def encode(blocks_args):
251 |         """
252 |         Encodes a list of BlockArgs to a list of strings.
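        Notation (see _decode_block_string): r=num_repeat, k=kernel_size,
        s=stride, e=expand_ratio, i=input_filters, o=output_filters,
        se=se_ratio, and a trailing 'noskip' clears id_skip.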
253 | 254 | :param blocks_args: a list of BlockArgs namedtuples of block args 255 | :return: a list of strings, each string is a notation of block 256 | """ 257 | block_strings = [] 258 | for block in blocks_args: 259 | block_strings.append(BlockDecoder._encode_block_string(block)) 260 | return block_strings 261 | 262 | 263 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, 264 | drop_connect_rate=0.2, image_size=None, num_classes=1000, condconv_num_expert=1): 265 | """ Creates a efficientnet model. """ 266 | 267 | blocks_args = [ 268 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 269 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 270 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 271 | 'r1_k3_s11_e6_i192_o320_se0.25', 272 | ] 273 | blocks_args = BlockDecoder.decode(blocks_args) 274 | 275 | blocks_args_new = blocks_args[:-3] 276 | for blocks_arg in blocks_args[-3:]: 277 | blocks_arg = blocks_arg._replace(condconv_num_expert=condconv_num_expert) 278 | blocks_args_new.append(blocks_arg) 279 | blocks_args = blocks_args_new 280 | 281 | global_params = GlobalParams( 282 | batch_norm_momentum=0.99, 283 | batch_norm_epsilon=1e-3, 284 | dropout_rate=dropout_rate, 285 | drop_connect_rate=drop_connect_rate, 286 | # data_format='channels_last', # removed, this is always true in PyTorch 287 | num_classes=num_classes, 288 | width_coefficient=width_coefficient, 289 | depth_coefficient=depth_coefficient, 290 | depth_divisor=8, 291 | min_depth=None, 292 | image_size=image_size, 293 | ) 294 | 295 | return blocks_args, global_params 296 | 297 | 298 | def get_model_params(model_name, override_params, condconv_num_expert=1): 299 | """ Get the block args and global params for a given model """ 300 | if model_name.startswith('efficientnet'): 301 | w, d, s, p = efficientnet_params(model_name) 302 | # note: all models have drop connect rate = 0.2 303 | blocks_args, global_params = efficientnet( 304 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s, condconv_num_expert=condconv_num_expert) 305 | else: 306 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 307 | if override_params: 308 | # ValueError will be raised here if override_params has fields not included in global_params. 309 | global_params = global_params._replace(**override_params) 310 | return blocks_args, global_params 311 | 312 | 313 | url_map = { 314 | 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', 315 | 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', 316 | 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', 317 | 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', 318 | 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', 319 | 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', 320 | 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', 321 | 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', 322 | } 323 | 324 | 325 | def load_pretrained_weights(model, model_name, load_fc=True): 326 | """ Loads pretrained weights, and downloads if loading for the first time. 
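    A minimal usage sketch (hypothetical class count), mirroring
    from_pretrained() above:
        model = EfficientNet.from_name('efficientnet-b0', override_params={'num_classes': 10})
        load_pretrained_weights(model, 'efficientnet-b0', load_fc=False)  # head shape differs, so skip the fc weights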
""" 327 | state_dict = model_zoo.load_url(url_map[model_name]) 328 | if load_fc: 329 | model.load_state_dict(state_dict) 330 | else: 331 | state_dict.pop('_fc.weight') 332 | state_dict.pop('_fc.bias') 333 | res = model.load_state_dict(state_dict, strict=False) 334 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' 335 | print('Loaded pretrained weights for {}'.format(model_name)) 336 | -------------------------------------------------------------------------------- /UniformAugment/data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torchvision 8 | from PIL import Image 9 | from UniformAugment.augmentations import * 10 | from UniformAugment.augmentations import UniformAugment 11 | from UniformAugment.common import get_logger 12 | from UniformAugment.imagenet import ImageNet 13 | from UniformAugment.networks.efficientnet_pytorch.model import EfficientNet 14 | from sklearn.model_selection import StratifiedShuffleSplit 15 | from theconf import Config as C 16 | from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset 17 | from torchvision.transforms import transforms 18 | 19 | logger = get_logger('Fast AutoAugment') 20 | logger.setLevel(logging.INFO) 21 | _IMAGENET_PCA = { 22 | 'eigval': [0.2175, 0.0188, 0.0045], 23 | 'eigvec': [ 24 | [-0.5675, 0.7192, 0.4009], 25 | [-0.5808, -0.0045, -0.8140], 26 | [-0.5836, -0.6948, 0.4203], 27 | ] 28 | } 29 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 30 | 31 | 32 | def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, multinode=False, target_lb=-1): 33 | if 'cifar' in dataset or 'svhn' in dataset: 34 | transform_train = transforms.Compose([ 35 | transforms.RandomCrop(32, padding=4), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 39 | ]) 40 | transform_test = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 43 | ]) 44 | elif 'imagenet' in dataset: 45 | input_size = 224 46 | sized_size = 256 47 | 48 | if 'efficientnet' in C.get()['model']['type']: 49 | input_size = EfficientNet.get_image_size(C.get()['model']['type']) 50 | sized_size = input_size + 32 # TODO 51 | # sized_size = int(round(input_size / 224. * 256)) 52 | # sized_size = input_size 53 | logger.info('size changed to %d/%d.' 
% (input_size, sized_size)) 54 | 55 | transform_train = transforms.Compose([ 56 | EfficientNetRandomCrop(input_size), 57 | transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC), 58 | # transforms.RandomResizedCrop(input_size, scale=(0.1, 1.0), interpolation=Image.BICUBIC), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ColorJitter( 61 | brightness=0.4, 62 | contrast=0.4, 63 | saturation=0.4, 64 | ), 65 | transforms.ToTensor(), 66 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 67 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 68 | ]) 69 | 70 | transform_test = transforms.Compose([ 71 | EfficientNetCenterCrop(input_size), 72 | transforms.Resize((input_size, input_size), interpolation=Image.BICUBIC), 73 | transforms.ToTensor(), 74 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 75 | ]) 76 | 77 | else: 78 | raise ValueError('dataset=%s' % dataset) 79 | 80 | total_aug = augs = None 81 | if isinstance(C.get()['aug'], list): 82 | logger.debug('augmentation provided.') 83 | transform_train.transforms.insert(0, Augmentation(C.get()['aug'])) 84 | else: 85 | logger.debug('augmentation: %s' % C.get()['aug']) 86 | if C.get()['aug'] == 'uniformaugment': 87 | transform_train.transforms.insert(0, UniformAugment()) 88 | elif C.get()['aug'] in ['default']: 89 | pass 90 | else: 91 | raise ValueError('not found augmentations. %s' % C.get()['aug']) 92 | 93 | if C.get()['cutout'] > 0: 94 | transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) 95 | 96 | if dataset == 'cifar10': 97 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 98 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 99 | elif dataset == 'reduced_cifar10': 100 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 101 | sss = StratifiedShuffleSplit(n_splits=1, test_size=46000, random_state=0) # 4000 trainset 102 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 103 | train_idx, valid_idx = next(sss) 104 | targets = [total_trainset.targets[idx] for idx in train_idx] 105 | total_trainset = Subset(total_trainset, train_idx) 106 | total_trainset.targets = targets 107 | 108 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 109 | elif dataset == 'cifar100': 110 | total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train) 111 | testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test) 112 | elif dataset == 'svhn': 113 | trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train) 114 | extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train) 115 | total_trainset = ConcatDataset([trainset, extraset]) 116 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test) 117 | elif dataset == 'reduced_svhn': 118 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train) 119 | sss = StratifiedShuffleSplit(n_splits=1, test_size=73257-1000, random_state=0) # 1000 trainset 120 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 121 | train_idx, 
121 |         train_idx, valid_idx = next(sss)
122 |         targets = [total_trainset.targets[idx] for idx in train_idx]
123 |         total_trainset = Subset(total_trainset, train_idx)
124 |         total_trainset.targets = targets
125 | 
126 |         testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
127 |     elif dataset == 'imagenet':
128 |         total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
129 |         testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
130 | 
131 |         # compatibility: expose the labels as .targets, as the torchvision CIFAR/SVHN datasets do
132 |         total_trainset.targets = [lb for _, lb in total_trainset.samples]
133 |     elif dataset == 'reduced_imagenet':
134 |         # randomly chosen indices
135 |         # idx120 = sorted(random.sample(list(range(1000)), k=120))
136 |         idx120 = [16, 23, 52, 57, 76, 93, 95, 96, 99, 121, 122, 128, 148, 172, 181, 189, 202, 210, 232, 238, 257, 258, 259, 277, 283, 289, 295, 304, 307, 318, 322, 331, 337, 338, 345, 350, 361, 375, 376, 381, 388, 399, 401, 408, 424, 431, 432, 440, 447, 462, 464, 472, 483, 497, 506, 512, 530, 541, 553, 554, 557, 564, 570, 584, 612, 614, 619, 626, 631, 632, 650, 657, 658, 660, 674, 675, 680, 682, 691, 695, 699, 711, 734, 736, 741, 754, 757, 764, 769, 770, 780, 781, 787, 797, 799, 811, 822, 829, 830, 835, 837, 842, 843, 845, 873, 883, 897, 900, 902, 905, 913, 920, 925, 937, 938, 940, 941, 944, 949, 959]
137 |         total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train)
138 |         testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test)
139 | 
140 |         # compatibility: expose the labels as .targets, as above
141 |         total_trainset.targets = [lb for _, lb in total_trainset.samples]
142 | 
143 |         sss = StratifiedShuffleSplit(n_splits=1, test_size=len(total_trainset) - 50000, random_state=0)  # keep a stratified 50,000-example subsample
144 |         sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
145 |         train_idx, valid_idx = next(sss)
146 | 
147 |         # filter out everything outside the 120 retained classes (uses the .targets list built above)
148 |         train_idx = list(filter(lambda x: total_trainset.targets[x] in idx120, train_idx))
149 |         valid_idx = list(filter(lambda x: total_trainset.targets[x] in idx120, valid_idx))
150 |         test_idx = list(filter(lambda x: testset.samples[x][1] in idx120, range(len(testset))))
151 | 
152 |         targets = [idx120.index(total_trainset.targets[idx]) for idx in train_idx]
153 |         for idx in range(len(total_trainset.samples)):
154 |             if total_trainset.samples[idx][1] not in idx120:
155 |                 continue
156 |             total_trainset.samples[idx] = (total_trainset.samples[idx][0], idx120.index(total_trainset.samples[idx][1]))
157 |         total_trainset = Subset(total_trainset, train_idx)
158 |         total_trainset.targets = targets
159 | 
160 |         for idx in range(len(testset.samples)):
161 |             if testset.samples[idx][1] not in idx120:
162 |                 continue
163 |             testset.samples[idx] = (testset.samples[idx][0], idx120.index(testset.samples[idx][1]))
164 |         testset = Subset(testset, test_idx)
165 |         print('reduced_imagenet train=', len(total_trainset))
166 |     else:
167 |         raise ValueError('invalid dataset name=%s' % dataset)
168 | 
169 |     if total_aug is not None and augs is not None:
170 |         total_trainset.set_preaug(augs, total_aug)
171 |         print('set_preaug-')
172 | 
173 |     train_sampler = None
174 |     if split > 0.0:
175 |         sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
176 |         sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
177 |         for _ in range(split_idx + 1):
178 |             train_idx, valid_idx = next(sss)
179 | 
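    |         # target_lb >= 0 restricts both index sets to a single class label
    |         # before the samplers below are constructed.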
180 |         if target_lb >= 0:
181 |             train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb]
182 |             valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb]
183 | 
184 |         train_sampler = SubsetRandomSampler(train_idx)
185 |         valid_sampler = SubsetSampler(valid_idx)
186 | 
187 |         if multinode:
188 |             train_sampler = torch.utils.data.distributed.DistributedSampler(Subset(total_trainset, train_idx), num_replicas=dist.get_world_size(), rank=dist.get_rank())
189 |     else:
190 |         valid_sampler = SubsetSampler([])
191 | 
192 |         if multinode:
193 |             train_sampler = torch.utils.data.distributed.DistributedSampler(total_trainset, num_replicas=dist.get_world_size(), rank=dist.get_rank())
194 |             logger.info(f'----- dataset with DistributedSampler {dist.get_rank()}/{dist.get_world_size()}')
195 | 
196 |     trainloader = torch.utils.data.DataLoader(
197 |         total_trainset, batch_size=batch, shuffle=train_sampler is None, num_workers=8, pin_memory=True,
198 |         sampler=train_sampler, drop_last=True)
199 |     validloader = torch.utils.data.DataLoader(  # iterates the training set; valid_sampler selects the held-out fold
200 |         total_trainset, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True,
201 |         sampler=valid_sampler, drop_last=False)
202 | 
203 |     testloader = torch.utils.data.DataLoader(
204 |         testset, batch_size=batch, shuffle=False, num_workers=8, pin_memory=True,
205 |         drop_last=False
206 |     )
207 |     return train_sampler, trainloader, validloader, testloader
208 | 
209 | 
210 | class CutoutDefault(object):
211 |     """
212 |     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
213 |     """
214 |     def __init__(self, length):
215 |         self.length = length
216 | 
217 |     def __call__(self, img):
218 |         h, w = img.size(1), img.size(2)
219 |         mask = np.ones((h, w), np.float32)
220 |         y = np.random.randint(h)
221 |         x = np.random.randint(w)
222 | 
223 |         y1 = np.clip(y - self.length // 2, 0, h)
224 |         y2 = np.clip(y + self.length // 2, 0, h)
225 |         x1 = np.clip(x - self.length // 2, 0, w)
226 |         x2 = np.clip(x + self.length // 2, 0, w)
227 | 
228 |         mask[y1: y2, x1: x2] = 0.
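    |         # boxes centered near a border are clipped, so the effective cutout area
    |         # varies; the zeroed mask is broadcast across all channels below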
229 |         mask = torch.from_numpy(mask)
230 |         mask = mask.expand_as(img)
231 |         img *= mask
232 |         return img
233 | 
234 | 
235 | class Augmentation(object):
236 |     def __init__(self, policies):
237 |         self.policies = policies
238 | 
239 |     def __call__(self, img):
240 |         # apply one randomly chosen policy; each op in it fires with its own probability
241 |         policy = random.choice(self.policies)
242 |         for name, pr, level in policy:
243 |             if random.random() > pr:
244 |                 continue
245 |             img = apply_augment(img, name, level)
246 |         return img
247 | 
248 | 
249 | class EfficientNetRandomCrop:
250 |     def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3), area_range=(0.08, 1.0), max_attempts=10):
251 |         assert 0.0 < min_covered
252 |         assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
253 |         assert 0 < area_range[0] <= area_range[1]
254 |         assert 1 <= max_attempts
255 | 
256 |         self.min_covered = min_covered
257 |         self.aspect_ratio_range = aspect_ratio_range
258 |         self.area_range = area_range
259 |         self.max_attempts = max_attempts
260 |         self._fallback = EfficientNetCenterCrop(imgsize)
261 | 
262 |     def __call__(self, img):
263 |         # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111
264 |         original_width, original_height = img.size
265 |         min_area = self.area_range[0] * (original_width * original_height)
266 |         max_area = self.area_range[1] * (original_width * original_height)
267 | 
268 |         for _ in range(self.max_attempts):
269 |             aspect_ratio = random.uniform(*self.aspect_ratio_range)
270 |             height = int(round(math.sqrt(min_area / aspect_ratio)))
271 |             max_height = int(round(math.sqrt(max_area / aspect_ratio)))
272 | 
273 |             if max_height * aspect_ratio > original_width:
274 |                 max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
275 |                 max_height = int(max_height)
276 |                 if max_height * aspect_ratio > original_width:
277 |                     max_height -= 1
278 | 
279 |             if max_height > original_height:
280 |                 max_height = original_height
281 | 
282 |             if height >= max_height:
283 |                 height = max_height
284 | 
285 |             height = int(round(random.uniform(height, max_height)))
286 |             width = int(round(height * aspect_ratio))
287 |             area = width * height
288 | 
289 |             if area < min_area or area > max_area:
290 |                 continue
291 |             if width > original_width or height > original_height:
292 |                 continue
293 |             if area < self.min_covered * (original_width * original_height):
294 |                 continue
295 |             if width == original_width and height == original_height:
296 |                 return self._fallback(img)  # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L102
297 | 
298 |             x = random.randint(0, original_width - width)
299 |             y = random.randint(0, original_height - height)
300 |             return img.crop((x, y, x + width, y + height))
301 | 
302 |         return self._fallback(img)
303 | 
304 | 
305 | class EfficientNetCenterCrop:
306 |     def __init__(self, imgsize):
307 |         self.imgsize = imgsize
308 | 
309 |     def __call__(self, img):
310 |         """Center-crop the given PIL Image (no resizing is performed here).
311 | 
312 |         Args:
313 |             img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
314 |         Note:
315 |             the crop side is imgsize / (imgsize + 32) of the shorter edge (0.875 for imgsize=224).
316 |         Returns:
317 |             PIL Image: Cropped image.
318 |         """
319 |         image_width, image_height = img.size
320 |         image_short = min(image_width, image_height)
321 | 
322 |         crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short
323 | 
324 |         crop_height, crop_width = crop_size, crop_size
325 |         crop_top = int(round((image_height - crop_height) / 2.))
326 |         crop_left = int(round((image_width - crop_width) / 2.))
327 |         return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
328 | 
329 | 
330 | class SubsetSampler(Sampler):
331 |     r"""Samples elements from a given list of indices, without replacement.
332 | 
333 |     Arguments:
334 |         indices (sequence): a sequence of indices
335 |     """
336 | 
337 |     def __init__(self, indices):
338 |         self.indices = indices
339 | 
340 |     def __iter__(self):
341 |         return (i for i in self.indices)
342 | 
343 |     def __len__(self):
344 |         return len(self.indices)
345 | 
--------------------------------------------------------------------------------
/UniformAugment/train.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import sys
3 | 
4 | sys.path.append(str(pathlib.Path(__file__).parent.parent.absolute()))
5 | 
6 | import itertools
7 | import json
8 | import logging
9 | import math
10 | import os
11 | from collections import OrderedDict
12 | 
13 | import torch
14 | from torch import nn, optim
15 | from torch.nn.parallel.data_parallel import DataParallel
16 | from torch.nn.parallel import DistributedDataParallel
17 | import torch.distributed as dist
18 | 
19 | from tqdm import tqdm
20 | from theconf import Config as C, ConfigArgumentParser
21 | 
22 | from UniformAugment.common import get_logger, EMA, add_filehandler
23 | from UniformAugment.data import get_dataloaders
24 | from UniformAugment.lr_scheduler import adjust_learning_rate_resnet
25 | from UniformAugment.metrics import accuracy, Accumulator, CrossEntropyLabelSmooth
26 | from UniformAugment.networks import get_model, num_class
27 | from UniformAugment.aug_mixup import CrossEntropyMixUpLabelSmooth, mixup
28 | from warmup_scheduler import GradualWarmupScheduler
29 | 
30 | logger = get_logger('Fast AutoAugment')
31 | logger.setLevel(logging.INFO)
32 | 
33 | 
34 | def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None, is_master=True, ema=None, wd=0.0, tqdm_disabled=False):
35 |     if verbose:
36 |         loader = tqdm(loader, disable=tqdm_disabled)
37 |         loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch']))
38 | 
39 |     params_without_bn = [params for name, params in model.named_parameters() if not ('_bn' in name or '.bn' in name)]
40 | 
41 |     loss_ema = None
42 |     metrics = Accumulator()
43 |     cnt = 0
44 |     total_steps = len(loader)
45 |     steps = 0
46 |     for data, label in loader:
47 |         steps += 1
48 |         data, label = data.cuda(), label.cuda()
49 | 
50 |         if C.get().conf.get('mixup', 0.0) <= 0.0 or optimizer is None:
51 |             preds = model(data)
52 |             loss = loss_fn(preds, label)
53 |         else:  # mixup
54 |             data, targets, shuffled_targets, lam = mixup(data, label, C.get()['mixup'])
55 |             preds = model(data)
56 |             loss = loss_fn(preds, targets, shuffled_targets, lam)
57 |             del shuffled_targets, lam
58 | 
59 |         if optimizer:
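    |             # manual weight decay: add (wd / 2) * sum(||p||^2) over all non-BN
    |             # parameters; the SGD optimizer in train_and_eval() is constructed
    |             # with weight_decay=0.0, so decay is not applied twice.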
60 |             loss += wd * (1. / 2.) * sum([torch.sum(p ** 2) for p in params_without_bn])
61 |             loss.backward()
62 |             grad_clip = C.get()['optimizer'].get('clip', 5.0)
63 |             if grad_clip > 0:
64 |                 nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
65 |             optimizer.step()
66 |             optimizer.zero_grad()
67 | 
68 |             if ema is not None:
69 |                 ema(model, (epoch - 1) * total_steps + steps)
70 | 
71 |         top1, top5 = accuracy(preds, label, (1, 5))
72 |         metrics.add_dict({
73 |             'loss': loss.item() * len(data),
74 |             'top1': top1.item() * len(data),
75 |             'top5': top5.item() * len(data),
76 |         })
77 |         cnt += len(data)
78 |         if loss_ema is not None:
79 |             loss_ema = loss_ema * 0.9 + loss.item() * 0.1
80 |         else:
81 |             loss_ema = loss.item()
82 |         if verbose:
83 |             postfix = metrics / cnt
84 |             if optimizer:
85 |                 postfix['lr'] = optimizer.param_groups[0]['lr']
86 |             postfix['loss_ema'] = loss_ema
87 |             loader.set_postfix(postfix)
88 | 
89 |         if scheduler is not None:
90 |             scheduler.step(epoch - 1 + float(steps) / total_steps)
91 | 
92 |         del preds, loss, top1, top5, data, label
93 | 
94 |     if tqdm_disabled and verbose:
95 |         if optimizer:
96 |             logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr'])
97 |         else:
98 |             logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt)
99 | 
100 |     metrics /= cnt
101 |     if optimizer:
102 |         metrics.metrics['lr'] = optimizer.param_groups[0]['lr']
103 |     if verbose:
104 |         for key, value in metrics.items():
105 |             writer.add_scalar(key, value, epoch)
106 |     return metrics
107 | 
108 | 
109 | def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, local_rank=-1, evaluation_interval=5):
110 |     total_batch = C.get()["batch"]
111 |     if local_rank >= 0:
112 |         dist.init_process_group(backend='nccl', init_method='env://', world_size=int(os.environ['WORLD_SIZE']))
113 |         device = torch.device('cuda', local_rank)
114 |         torch.cuda.set_device(device)
115 | 
116 |         C.get()['lr'] *= dist.get_world_size()  # linear learning-rate scaling with the number of workers
117 |         logger.info(f'local batch={C.get()["batch"]} world_size={dist.get_world_size()} ----> total batch={C.get()["batch"] * dist.get_world_size()}')
118 |         total_batch = C.get()["batch"] * dist.get_world_size()
119 | 
120 |     is_master = local_rank < 0 or dist.get_rank() == 0
121 |     if is_master:
122 |         add_filehandler(logger, args.save + '.log')  # relies on the module-level `args` parsed in __main__
123 | 
124 |     if not reporter:
125 |         reporter = lambda **kwargs: 0
126 | 
127 |     max_epoch = C.get()['epoch']
128 |     trainsampler, trainloader, validloader, testloader_ = get_dataloaders(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold, multinode=(local_rank >= 0))
129 | 
130 |     # create a model & an optimizer
131 |     model = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=local_rank)
132 |     model_ema = get_model(C.get()['model'], num_class(C.get()['dataset']), local_rank=-1)
133 |     model_ema.eval()
134 | 
135 |     criterion_ce = criterion = CrossEntropyLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0))
136 |     if C.get().conf.get('mixup', 0.0) > 0.0:
137 |         criterion = CrossEntropyMixUpLabelSmooth(num_class(C.get()['dataset']), C.get().conf.get('lb_smooth', 0))
138 |     if C.get()['optimizer']['type'] == 'sgd':
139 |         optimizer = optim.SGD(
140 |             model.parameters(),
141 |             lr=C.get()['lr'],
142 |             momentum=C.get()['optimizer'].get('momentum', 0.9),
143 |             weight_decay=0.0,  # decay is applied manually in run_epoch (BN parameters excluded)
144 |             nesterov=C.get()['optimizer'].get('nesterov', True)
145 |         )
146 |     else:
147 |         raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type'])
148 | 
149 |     lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine')
150 |     if lr_scheduler_type == 'cosine':
151 |         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.)
152 |     elif lr_scheduler_type == 'resnet':
153 |         scheduler = adjust_learning_rate_resnet(optimizer)
154 |     elif lr_scheduler_type == 'efficientnet':
155 |         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 0.97 ** int((x + C.get()['lr_schedule']['warmup']['epoch']) / 2.4))
156 |     else:
157 |         raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)
158 | 
159 |     if C.get()['lr_schedule'].get('warmup', None) and C.get()['lr_schedule']['warmup']['epoch'] > 0:
160 |         scheduler = GradualWarmupScheduler(
161 |             optimizer,
162 |             multiplier=C.get()['lr_schedule']['warmup']['multiplier'],
163 |             total_epoch=C.get()['lr_schedule']['warmup']['epoch'],
164 |             after_scheduler=scheduler
165 |         )
166 | 
167 |     if not tag or not is_master:
168 |         from UniformAugment.metrics import SummaryWriterDummy as SummaryWriter
169 |         logger.warning('tag not provided or non-master rank; tensorboard logging is disabled.')
170 |     else:
171 |         from tensorboardX import SummaryWriter
172 |     writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']]
173 | 
174 |     if C.get()['optimizer']['ema'] > 0.0 and is_master:
175 |         # https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4?u=ildoonet
176 |         ema = EMA(C.get()['optimizer']['ema'])
177 |     else:
178 |         ema = None
179 | 
180 |     result = OrderedDict()
181 |     epoch_start = 1
182 |     if save_path != 'test.pth':  # every rank must load the checkpoint itself; the state cannot be broadcast later
183 |         if save_path and os.path.exists(save_path):
184 |             logger.info('%s file found. loading...' % save_path)
185 |             data = torch.load(save_path)
186 |             key = 'model' if 'model' in data else 'state_dict'
187 | 
188 |             if 'epoch' not in data:
189 |                 model.load_state_dict(data)
190 |             else:
191 |                 logger.info('checkpoint epoch@%d' % data['epoch'])
192 |                 if not isinstance(model, (DataParallel, DistributedDataParallel)):
193 |                     model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()})
194 |                 else:
195 |                     model.load_state_dict({k if 'module.' in k else 'module.' + k: v for k, v in data[key].items()})
196 |                 logger.info('optimizer.load_state_dict+')
197 |                 optimizer.load_state_dict(data['optimizer'])
198 |                 if data['epoch'] < C.get()['epoch']:
199 |                     epoch_start = data['epoch']
200 |                 else:
201 |                     only_eval = True
202 |                 if ema is not None:
203 |                     ema.shadow = data.get('ema', {}) if isinstance(data.get('ema', {}), dict) else data['ema'].state_dict()
204 |             del data
205 |         else:
206 |             logger.info('"%s" checkpoint not found; training from scratch.' % save_path)
207 |             if only_eval:
208 |                 logger.warning('model checkpoint not found. only-evaluation mode is turned off.')
209 |                 only_eval = False
210 | 
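    |     # in multinode runs, rank 0's (possibly checkpoint-restored) parameters are
    |     # broadcast so that every worker starts from an identical model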
211 |     if local_rank >= 0:
212 |         for name, x in model.state_dict().items():
213 |             dist.broadcast(x, 0)
214 |         logger.info(f'multinode init. local_rank={dist.get_rank()} is_master={is_master}')
215 |         torch.cuda.synchronize()
216 | 
217 |     tqdm_disabled = bool(os.environ.get('TASK_NAME', '')) and local_rank != 0  # KakaoBrain environment
218 | 
219 |     if only_eval:
220 |         logger.info('evaluation only+')
221 |         model.eval()
222 |         rs = dict()
223 |         rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0], is_master=is_master)
224 | 
225 |         with torch.no_grad():
226 |             rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1], is_master=is_master)
227 |             rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2], is_master=is_master)
228 |             if ema is not None and len(ema) > 0:
229 |                 model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
230 |                 rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=0, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
231 |                 rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=0, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
232 |         for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
233 |             if setname not in rs:
234 |                 continue
235 |             result['%s_%s' % (key, setname)] = rs[setname][key]
236 |         result['epoch'] = 0
237 |         return result
238 | 
239 |     # train loop
240 |     best_top1 = 0
241 |     for epoch in range(epoch_start, max_epoch + 1):
242 |         if local_rank >= 0:
243 |             trainsampler.set_epoch(epoch)
244 | 
245 |         model.train()
246 |         rs = dict()
247 |         rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train', epoch=epoch, writer=writers[0], verbose=(is_master and local_rank <= 0), scheduler=scheduler, ema=ema, wd=C.get()['optimizer']['decay'], tqdm_disabled=tqdm_disabled)
248 |         model.eval()
249 | 
250 |         if math.isnan(rs['train']['loss']):
251 |             raise Exception('train loss is NaN.')
252 | 
253 |         if ema is not None and C.get()['optimizer']['ema_interval'] > 0 and epoch % C.get()['optimizer']['ema_interval'] == 0:
254 |             logger.info(f'ema synced+ rank={dist.get_rank()}')
255 |             # load the EMA shadow weights into the live model ...
256 |             model.load_state_dict(ema.state_dict())
257 |             for name, x in model.state_dict().items():
258 |                 # ... and broadcast them so every rank picks up the synced weights
259 |                 dist.broadcast(x, 0)
260 |             torch.cuda.synchronize()
261 |             logger.info(f'ema synced- rank={dist.get_rank()}')
262 | 
263 |         if is_master and (epoch % evaluation_interval == 0 or epoch == max_epoch):
264 |             with torch.no_grad():
265 |                 rs['valid'] = run_epoch(model, validloader, criterion_ce, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
266 |                 rs['test'] = run_epoch(model, testloader_, criterion_ce, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
267 | 
268 |                 if ema is not None:
269 |                     model_ema.load_state_dict({k.replace('module.', ''): v for k, v in ema.state_dict().items()})
270 |                     rs['valid'] = run_epoch(model_ema, validloader, criterion_ce, None, desc_default='valid(EMA)', epoch=epoch, writer=writers[1], verbose=is_master, tqdm_disabled=tqdm_disabled)
271 |                     rs['test'] = run_epoch(model_ema, testloader_, criterion_ce, None, desc_default='*test(EMA)', epoch=epoch, writer=writers[2], verbose=is_master, tqdm_disabled=tqdm_disabled)
272 | 
273 |             logger.info(
274 |                 f'epoch={epoch} '
275 |                 f'[train] loss={rs["train"]["loss"]:.4f} top1={rs["train"]["top1"]:.4f} '
276 |                 f'[valid] loss={rs["valid"]["loss"]:.4f} top1={rs["valid"]["top1"]:.4f} '
277 |                 f'[test] loss={rs["test"]["loss"]:.4f} top1={rs["test"]["top1"]:.4f} '
278 |             )
279 | 
280 |             if metric == 'last' or rs[metric]['top1'] > best_top1:
281 |                 if metric != 'last':
282 |                     best_top1 = rs[metric]['top1']
283 |                 for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']):
284 |                     result['%s_%s' % (key, setname)] = rs[setname][key]
285 |                 result['epoch'] = epoch
286 | 
287 |                 writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
288 |                 writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)
289 | 
290 |                 reporter(
291 |                     loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
292 |                     loss_test=rs['test']['loss'], top1_test=rs['test']['top1']
293 |                 )
294 | 
295 |                 # save checkpoint
296 |                 if is_master and save_path:
297 |                     logger.info('save model@%d to %s, err=%.4f' % (epoch, save_path, 1 - best_top1))
298 |                     torch.save({
299 |                         'epoch': epoch,
300 |                         'log': {
301 |                             'train': rs['train'].get_dict(),
302 |                             'valid': rs['valid'].get_dict(),
303 |                             'test': rs['test'].get_dict(),
304 |                         },
305 |                         'optimizer': optimizer.state_dict(),
306 |                         'model': model.state_dict(),
307 |                         'ema': ema.state_dict() if ema is not None else None,
308 |                     }, save_path)
309 | 
310 |     del model
311 | 
312 |     result['top1_test'] = best_top1
313 |     return result
314 | 
315 | 
316 | if __name__ == '__main__':
317 |     parser = ConfigArgumentParser(conflict_handler='resolve')
318 |     parser.add_argument('--tag', type=str, default='')
319 |     parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder')
320 |     parser.add_argument('--save', type=str, default='test.pth')
321 |     parser.add_argument('--cv-ratio', type=float, default=0.0)
322 |     parser.add_argument('--cv', type=int, default=0)
323 |     parser.add_argument('--local_rank', type=int, default=-1)
324 |     parser.add_argument('--evaluation-interval', type=int, default=5)
325 |     parser.add_argument('--only-eval', action='store_true')
326 |     args = parser.parse_args()
327 | 
328 |     assert not args.only_eval or args.save, 'checkpoint path not provided in evaluation mode.'
329 | 
330 |     if not args.only_eval:
331 |         if args.save:
332 |             logger.info('checkpoint will be saved at %s' % args.save)
333 |         else:
334 |             logger.warning('Provide the --save argument to keep checkpoints. Without it, the training result will not be saved!')
335 | 
336 |     import time
337 |     t = time.time()
338 |     result = train_and_eval(args.tag, args.dataroot, test_ratio=args.cv_ratio, cv_fold=args.cv, save_path=args.save, only_eval=args.only_eval, local_rank=args.local_rank, metric='test', evaluation_interval=args.evaluation_interval)
339 |     elapsed = time.time() - t
340 | 
341 |     logger.info('done.')
342 |     logger.info('model: %s' % C.get()['model'])
343 |     logger.info('augmentation: %s' % C.get()['aug'])
344 |     logger.info('\n' + json.dumps(result, indent=4))
345 |     logger.info('elapsed time: %.3f Hours' % (elapsed / 3600.))
346 |     logger.info('top1 error in testset: %.4f' % (1. - result['top1_test']))
347 |     logger.info(args.save)
348 | 
--------------------------------------------------------------------------------
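Usage sketch (illustrative; not from the repository): data.py wires the augmentation into training with transform_train.transforms.insert(0, UniformAugment()), so the same transform can be dropped into a plain torchvision pipeline. A minimal sketch, assuming UniformAugment() operates on PIL images as that call implies; the CIFAR-10 normalization statistics and the './data' root below are illustrative choices, not values from this repository:

import torchvision
import torchvision.transforms as transforms

from UniformAugment import UniformAugment  # re-exported by UniformAugment/__init__.py

# UniformAugment is applied first, while the input is still a PIL image,
# mirroring the insert-at-index-0 placement used in data.py.
transform_train = transforms.Compose([
    UniformAugment(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),  # widely used CIFAR-10 stats
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

The reference entry point is train.py, driven by theconf's ConfigArgumentParser, e.g. python UniformAugment/train.py -c confs/wresnet28x10_cifar.yaml --dataroot ./data --tag wres28x10_ua (the --dataroot and --tag values here are placeholders).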