├── LICENSE ├── README.md ├── config ├── CIFAR10 │ ├── pyramidnet.yaml │ ├── shake-shake.yaml │ └── wrn-28-10.yaml ├── CIFAR100 │ ├── pyramidnet.yaml │ ├── shake-shake.yaml │ └── wrn-28-10.yaml └── ImageNet │ └── resnet50.yaml ├── lib ├── __init__.py ├── augmentation │ ├── __init__.py │ ├── augmentation_container.py │ ├── cutout.py │ ├── imagenet_augmentation.py │ ├── nn_aug.py │ └── replay_buffer.py ├── losses │ ├── __init__.py │ └── non_saturating_loss.py ├── models │ ├── __init__.py │ ├── pyramidnet.py │ ├── resnet.py │ ├── shakedrop.py │ ├── shakeshake │ │ ├── __init__.py │ │ ├── shake_resnet.py │ │ ├── shake_resnext.py │ │ └── shakeshake.py │ └── wide_resnet.py ├── teachaugment.py └── utils │ ├── __init__.py │ ├── lr_scheduler.py │ └── utils.py └── main.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Denso IT Laboratory, Inc. 2 | All Rights Reserved 3 | 4 | Denso IT Laboratory, Inc. retains sole and exclusive ownership of all 5 | intellectual property rights including copyrights and patents related to this 6 | Software. 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of the Software and accompanying documentation to use, copy, modify, merge, 10 | publish, or distribute the Software or software derived from it for 11 | non-commercial purposes, such as academic study, education and personal use, 12 | subject to the following conditions: 13 | 14 | 1. Redistributions of source code must retain the above copyright notice, 15 | this list of conditions and the following disclaimer. 16 | 17 | 2. Redistributions in binary form must reproduce the above copyright notice, 18 | this list of conditions and the following disclaimer in the documentation 19 | and/or other materials provided with the distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TeachAugment: Data Augmentation Optimization Using Teacher Knowledge (CVPR2022, Oral) 2 | Official Implementation of TeachAugment in PyTorch. 
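In brief, TeachAugment learns the augmentation online and adversarially: an augmentation model is updated to make images harder for the target model, while a teacher model (an exponential moving average of the target) must still recognize them, which keeps the learned augmentation meaningful. A rough sketch of the two alternating updates implemented in `lib/teachaugment.py` (illustrative pseudo-code using that file's names, not the exact training loop):
```
# 1) update the target model on learnably augmented images
aug_x, _ = trainable_aug(x, c)                   # color + geometric augmentation
loss_cls = F.cross_entropy(model(aug_x), y)

# 2) update the augmentation model adversarially, constrained by the teacher
aug_x, c_reg = trainable_aug(x, c, update=True)
loss_adv = adv_criterion(model(aug_x), y)        # non-saturating adversarial loss
loss_tea = F.cross_entropy(ema_model(aug_x), y)  # teacher must still be correct
loss_aug = loss_adv + loss_tea + c_reg           # c_reg: color regularization
```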
3 | arXiv: https://arxiv.org/abs/2202.12513 4 | 5 | ## Requirements 6 | - PyTorch >= 1.9 7 | - Torchvision >= 0.10 8 | 9 | ## Run 10 | Training with a single GPU 11 | ``` 12 | python main.py --yaml ./config/$DATASET_NAME/$MODEL 13 | ``` 14 | 15 | Training on a single node with multiple GPUs 16 | ``` 17 | python -m torch.distributed.launch --nproc_per_node=$N_GPUS main.py \ 18 | --yaml ./config/$DATASET_NAME/$MODEL --dist 19 | ``` 20 | 21 | Examples 22 | ``` 23 | # Training WRN-28-10 on CIFAR-100 24 | python main.py --yaml ./config/CIFAR100/wrn-28-10.yaml 25 | # Training ResNet-50 on ImageNet with 4 GPUs 26 | python -m torch.distributed.launch --nproc_per_node=4 main.py \ 27 | --yaml ./config/ImageNet/resnet50.yaml --dist 28 | ``` 29 | If computational resources are limited, please try the `--save_memory` option. 30 | 31 | 32 | ## Citation 33 | If you find our project useful in your research, please cite it as follows: 34 | ``` 35 | @InProceedings{Suzuki_2022_CVPR, 36 | author = {Suzuki, Teppei}, 37 | title = {TeachAugment: Data Augmentation Optimization Using Teacher Knowledge}, 38 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 39 | month = {June}, 40 | year = {2022}, 41 | pages = {10904-10914} 42 | } 43 | ``` 44 | 45 | ## Acknowledgement 46 | The files in ```./lib/models``` and the code in ```./lib/augmentation/imagenet_augmentation.py``` are based on the implementation of [Fast AutoAugment](https://github.com/kakaobrain/fast-autoaugment). 47 | 48 | -------------------------------------------------------------------------------- /config/CIFAR10/pyramidnet.yaml: -------------------------------------------------------------------------------- 1 | model: pyramidnet 2 | dataset: CIFAR10 3 | 4 | optim: 5 | lr: 0.05 6 | batch_size: 64 7 | weight_decay: 0.00005 8 | n_epochs: 1800 9 | 10 | label_smooth: 11 | epsilon: 0.5 12 | 13 | rb_decay: 0.99 -------------------------------------------------------------------------------- /config/CIFAR10/shake-shake.yaml: -------------------------------------------------------------------------------- 1 | model: ss96 2 | dataset: CIFAR10 3 | 4 | optim: 5 | lr: 0.01 6 | weight_decay: 0.001 7 | n_epochs: 1800 8 | 9 | label_smooth: 10 | epsilon: 0.5 11 | 12 | rb_decay: 0.99 -------------------------------------------------------------------------------- /config/CIFAR10/wrn-28-10.yaml: -------------------------------------------------------------------------------- 1 | model: wrn-28-10 2 | dataset: CIFAR10 3 | 4 | optim: 5 | lr: 0.1 6 | weight_decay: 0.0005 7 | n_epochs: 200 8 | 9 | label_smooth: 10 | epsilon: 0.2 11 | -------------------------------------------------------------------------------- /config/CIFAR100/pyramidnet.yaml: -------------------------------------------------------------------------------- 1 | model: pyramidnet 2 | dataset: CIFAR100 3 | 4 | optim: 5 | lr: 0.05 6 | batch_size: 64 7 | weight_decay: 0.00005 8 | n_epochs: 1800 9 | 10 | label_smooth: 11 | epsilon: 0.1 12 | 13 | rb_decay: 0.99 -------------------------------------------------------------------------------- /config/CIFAR100/shake-shake.yaml: -------------------------------------------------------------------------------- 1 | model: ss96 2 | dataset: CIFAR100 3 | 4 | optim: 5 | lr: 0.01 6 | weight_decay: 0.001 7 | n_epochs: 1800 8 | 9 | label_smooth: 10 | epsilon: 0.2 11 | 12 | rb_decay: 0.99 -------------------------------------------------------------------------------- /config/CIFAR100/wrn-28-10.yaml: -------------------------------------------------------------------------------- 1 | model: wrn-28-10 2 | dataset: CIFAR100 3 | 4 | optim: 5 | lr: 0.1 6 | weight_decay: 0.0005 7 | n_epochs: 200 8 | 9 | label_smooth: 10 | epsilon: 0.1 11 | -------------------------------------------------------------------------------- /config/ImageNet/resnet50.yaml: -------------------------------------------------------------------------------- 1 | model: resnet50 2 | dataset: ImageNet 3 | 4 | optim: 5 | lr: 0.05 6 | weight_decay: 0.0001 7 | n_epochs: 270 8 | 9 | label_smooth: 10 | epsilon: 0.1 11 | 12 | replay_buffer: 13 | sampling_freq: 1 14 | 
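# A hedged reading of the config fields above (the actual wiring lives in
# main.py): `model` selects an architecture via lib/models/build_model,
# `optim` holds the optimizer hyper-parameters, and `label_smooth.epsilon` is
# the label-smoothing strength used by NonSaturatingLoss. `rb_decay`
# presumably sets the replay buffer's priority decay (ReplayBuffer.decay_rate
# in lib/augmentation/replay_buffer.py), and `replay_buffer.sampling_freq`
# presumably controls how often stored augmentation policies are re-sampled.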
-------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.datasets as D 4 | import torchvision.transforms as T 5 | 6 | 7 | def build_dataset(dataset_name, root, train_transform=None, val_transform=None): 8 | if train_transform is None: 9 | train_transform = T.Compose([T.ToTensor()]) 10 | if val_transform is None: 11 | val_transform = T.Compose([T.ToTensor()]) 12 | if dataset_name != 'ImageNet': 13 | train_data = D.__dict__[dataset_name](root, download=True, transform=train_transform) 14 | test_data = D.__dict__[dataset_name](root, train=False, transform=val_transform) 15 | n_classes = 10 if dataset_name == 'CIFAR10' else 100 16 | else: 17 | train_data = D.ImageFolder(os.path.join(root, 'train'), train_transform) 18 | test_data = D.ImageFolder(os.path.join(root, 'val'), val_transform) 19 | n_classes = 1000 20 | 21 | return train_data, test_data, n_classes 22 | -------------------------------------------------------------------------------- /lib/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from . import augmentation_container 2 | from . import cutout 3 | from . import imagenet_augmentation 4 | from . import nn_aug 5 | from . 
import replay_buffer 6 | 7 | 8 | def get_transforms(dataset): 9 | import torchvision.transforms as T 10 | if dataset == 'ImageNet': 11 | normalizer = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 12 | train_transform = T.Compose([imagenet_augmentation.EfficientNetRandomCrop(224), 13 | T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC), 14 | T.ToTensor()]) 15 | base_aug = [T.RandomHorizontalFlip(), 16 | T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 17 | imagenet_augmentation.Lighting(), 18 | normalizer] 19 | val_transform = T.Compose([imagenet_augmentation.EfficientNetCenterCrop(224), 20 | T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC), 21 | T.ToTensor(), 22 | normalizer]) 23 | else: 24 | normalizer = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 25 | train_transform = T.Compose([T.RandomCrop(32, padding=4), 26 | T.RandomHorizontalFlip(), 27 | T.ToTensor()]) 28 | base_aug = [normalizer, 29 | cutout.Cutout()] 30 | val_transform = T.Compose([T.ToTensor(), normalizer]) 31 | return base_aug, train_transform, val_transform, normalizer 32 | 33 | 34 | def build_augmentation(n_classes, g_scale, c_scale, c_reg_coef=0, normalizer=None, replay_buffer=None, n_chunk=16, with_context=True): 35 | g_aug = nn_aug.GeometricAugmentation(n_classes, g_scale, with_context=with_context) 36 | c_aug = nn_aug.ColorAugmentation(n_classes, c_scale, with_context=with_context) 37 | augmentation = augmentation_container.AugmentationContainer(c_aug, g_aug, c_reg_coef, normalizer, replay_buffer, n_chunk) 38 | return augmentation 39 | -------------------------------------------------------------------------------- /lib/augmentation/augmentation_container.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def sliced_Wasserstein_distance(x1, x2, n_projection=128): 6 | x1 = x1.flatten(-2).transpose(1, 2).contiguous() # (b, 3, h, w) -> (b, n, 3) 7 | x2 = x2.flatten(-2).transpose(1, 2).contiguous() 8 | rand_proj = torch.randn(3, n_projection, device=x1.device) 9 | rand_proj = rand_proj / (rand_proj.norm(2, dim=0, keepdim=True) + 1e-12) 10 | sorted_proj_x1 = torch.matmul(x1, rand_proj).sort(0)[0] 11 | sorted_proj_x2 = torch.matmul(x2, rand_proj).sort(0)[0] 12 | return (sorted_proj_x1 - sorted_proj_x2).pow(2).mean() 13 | 14 | 
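# A minimal usage sketch of the helper above (the shapes are assumptions,
# not part of the training pipeline):
#
#   x1 = torch.rand(8, 3, 32, 32)
#   x2 = torch.rand(8, 3, 32, 32)
#   d = sliced_Wasserstein_distance(x1, x2)  # non-negative scalar tensor
#
# Each batch is treated as a cloud of RGB points, projected onto n_projection
# random unit directions; comparing the sorted projections gives a cheap
# approximation of the Wasserstein distance between the two color
# distributions. AugmentationContainer below uses it as a regularizer that
# keeps the learned color augmentation close to the original colors.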
15 | class AugmentationContainer(nn.Module): 16 | def __init__( 17 | self, c_aug, g_aug, c_reg_coef=0, 18 | normalizer=None, replay_buffer=None, n_chunk=16): 19 | super().__init__() 20 | self.c_aug = c_aug 21 | self.g_aug = g_aug 22 | self.c_reg_coef = c_reg_coef 23 | self.normalizer = normalizer 24 | self.replay_buffer = replay_buffer 25 | self.n_chunk = n_chunk 26 | 27 | def get_params(self, x, c, c_aug, g_aug): 28 | # sample noise vector from a unit Gaussian 29 | noise = x.new(x.shape[0], self.g_aug.n_dim).normal_() 30 | target = self.normalizer(x) if self.normalizer is not None else x 31 | # sample augmentation parameters 32 | grid = g_aug(target, noise, c) 33 | scale, shift = c_aug(target, noise, c) 34 | return (scale, shift), grid 35 | 36 | def augmentation(self, x, c, c_aug, g_aug, update=False): 37 | c_param, g_param = self.get_params(x, c, c_aug, g_aug) 38 | # color augmentation 39 | aug_x = c_aug.transform(x, *c_param) 40 | # color regularization 41 | if update and self.c_reg_coef > 0: 42 | if self.normalizer is not None: 43 | swd = self.c_reg_coef * sliced_Wasserstein_distance(self.normalizer(x), self.normalizer(aug_x)) 44 | else: 45 | swd = self.c_reg_coef * sliced_Wasserstein_distance(x, aug_x) 46 | else: 47 | swd = torch.zeros(1, device=x.device) 48 | # geometric augmentation 49 | aug_x = g_aug.transform(aug_x, g_param) 50 | return aug_x, swd 51 | 52 | def forward(self, x, c, update=False): 53 | if update or self.replay_buffer is None or len(self.replay_buffer) == 0: 54 | x, swd = self.augmentation(x, c, self.c_aug, self.g_aug, update) 55 | else: 56 | policies = self.replay_buffer.sampling(self.n_chunk, self.get_augmentation_model()) 57 | if c is not None: 58 | x = torch.cat([self.augmentation(_x, _c, *policy, update)[0] 59 | for _x, _c, policy in zip(x.chunk(self.n_chunk), c.chunk(self.n_chunk), policies)]) 60 | else: 61 | x = torch.cat([self.augmentation(_x, None, *policy, update)[0] 62 | for _x, policy in zip(x.chunk(self.n_chunk), policies)]) 63 | 64 | swd = torch.zeros(1, device=x.device) 65 | return x, swd 66 | 67 | def get_augmentation_model(self): 68 | return nn.ModuleList([self.c_aug, self.g_aug]) 69 | 70 | def reset(self): 71 | # initialize parameters 72 | self.c_aug.reset() 73 | self.g_aug.reset() 74 | -------------------------------------------------------------------------------- /lib/augmentation/cutout.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | 6 | def _gen_cutout_coord(height, width, size): 7 | height_loc = random.randint(0, height - 1) 8 | width_loc = random.randint(0, width - 1) 9 | 10 | upper_coord = (max(0, height_loc - size // 2), 11 | max(0, width_loc - size // 2)) 12 | lower_coord = (min(height, height_loc + size // 2), 13 | min(width, width_loc + size // 2)) 14 | 15 | return upper_coord, lower_coord 16 | 17 | 18 | class Cutout(torch.nn.Module): 19 | def __init__(self, size=16): 20 | super().__init__() 21 | self.size = size 22 | 23 | def forward(self, img): 24 | h, w = img.shape[-2:] 25 | upper_coord, lower_coord = _gen_cutout_coord(h, w, self.size) 26 | 27 | mask_height = lower_coord[0] - upper_coord[0] 28 | mask_width = lower_coord[1] - upper_coord[1] 29 | assert mask_height > 0 30 | assert mask_width > 0 31 | 32 | mask = torch.ones_like(img) 33 | mask[..., upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1]] = 0 34 | return img * mask 35 | -------------------------------------------------------------------------------- /lib/augmentation/imagenet_augmentation.py: -------------------------------------------------------------------------------- 1 | # this code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | import math 3 | import random 4 | 5 | import torch 6 | 7 | 8 | class EfficientNetRandomCrop: 9 | def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3), area_range=(0.08, 1.0), max_attempts=10): 10 | assert 0.0 < min_covered 11 | assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1] 12 | assert 0 < area_range[0] <= area_range[1] 13 | assert 1 <= max_attempts 14 | 15 | self.min_covered = min_covered 16 | self.aspect_ratio_range = aspect_ratio_range 17 | self.area_range = area_range 18 | self.max_attempts = max_attempts 19 | self._fallback = EfficientNetCenterCrop(imgsize) 20 | 21 | def __call__(self, img): 22 | # https://github.com/tensorflow/tensorflow/blob/9274bcebb31322370139467039034f8ff852b004/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc#L111 23 | original_width, original_height = img.size 24 | min_area = self.area_range[0] * (original_width * original_height) 25 | max_area = self.area_range[1] * (original_width * original_height) 26 | 27 | 
for _ in range(self.max_attempts): 28 | aspect_ratio = random.uniform(*self.aspect_ratio_range) 29 | height = int(round(math.sqrt(min_area / aspect_ratio))) 30 | max_height = int(round(math.sqrt(max_area / aspect_ratio))) 31 | 32 | if max_height * aspect_ratio > original_width: 33 | max_height = (original_width + 0.5 - 1e-7) / aspect_ratio 34 | max_height = int(max_height) 35 | if max_height * aspect_ratio > original_width: 36 | max_height -= 1 37 | 38 | if max_height > original_height: 39 | max_height = original_height 40 | 41 | if height >= max_height: 42 | height = max_height 43 | 44 | height = int(round(random.uniform(height, max_height))) 45 | width = int(round(height * aspect_ratio)) 46 | area = width * height 47 | 48 | if area < min_area or area > max_area: 49 | continue 50 | if width > original_width or height > original_height: 51 | continue 52 | if area < self.min_covered * (original_width * original_height): 53 | continue 54 | if width == original_width and height == original_height: 55 | return self._fallback(img) # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L102 56 | 57 | x = random.randint(0, original_width - width) 58 | y = random.randint(0, original_height - height) 59 | return img.crop((x, y, x + width, y + height)) 60 | 61 | return self._fallback(img) 62 | 63 | 64 | class EfficientNetCenterCrop: 65 | def __init__(self, imgsize): 66 | self.imgsize = imgsize 67 | 68 | def __call__(self, img): 69 | """Crop the given PIL Image and resize it to desired size. 70 | 71 | Args: 72 | img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. 73 | output_size (sequence or int): (height, width) of the crop box. If int, 74 | it is used for both directions 75 | Returns: 76 | PIL Image: Cropped image. 
77 | """ 78 | image_width, image_height = img.size 79 | image_short = min(image_width, image_height) 80 | 81 | crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short 82 | 83 | crop_height, crop_width = crop_size, crop_size 84 | crop_top = int(round((image_height - crop_height) / 2.)) 85 | crop_left = int(round((image_width - crop_width) / 2.)) 86 | return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)) 87 | 88 | 89 | class Lighting(torch.nn.Module): 90 | """Lighting noise(AlexNet - style PCA - based noise)""" 91 | def __init__(self, alphastd=0.1): 92 | super().__init__() 93 | self.alphastd = alphastd 94 | self.register_buffer('eigval', torch.Tensor([0.2175, 0.0188, 0.0045])) 95 | self.register_buffer('eigvec', torch.Tensor([ 96 | [-0.5675, 0.7192, 0.4009], 97 | [-0.5808, -0.0045, -0.8140], 98 | [-0.5836, -0.6948, 0.4203], 99 | ])) 100 | 101 | def forward(self, img): 102 | if self.alphastd == 0: 103 | return img 104 | 105 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 106 | rgb = self.eigvec.type_as(img).clone() \ 107 | .mul(alpha.view(1, 3).expand(3, 3)) \ 108 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 109 | .sum(1).squeeze() 110 | 111 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 112 | 113 | -------------------------------------------------------------------------------- /lib/augmentation/nn_aug.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def relaxed_bernoulli(logits, temp=0.05, device='cpu'): 9 | u = torch.rand_like(logits, device=device) 10 | l = torch.log(u) - torch.log(1 - u) 11 | return ((l + logits)/temp).sigmoid() 12 | 13 | 14 | class TriangleWave(torch.autograd.Function): 15 | @staticmethod 16 | def forward(self, x): 17 | o = torch.acos(torch.cos(x * math.pi)) / math.pi 18 | self.save_for_backward(x) 19 | return o 20 | 21 | @staticmethod 22 | def backward(self, grad): 23 | o = self.saved_tensors[0] 24 | # avoid nan gradient at the peak by replacing it with the right derivative 25 | o = torch.floor(o) % 2 26 | grad[o == 1] *= -1 27 | return grad 28 | 29 | 30 | class ColorAugmentation(nn.Module): 31 | def __init__(self, n_classes=10, scale=1, hidden=128, n_dim=128, dropout_ratio=0.8, with_context=True): 32 | super().__init__() 33 | 34 | n_hidden = 4 * n_dim 35 | conv = lambda ic, io, k : nn.Conv2d(ic, io, k, padding=k//2, bias=False) 36 | linear = lambda ic, io : nn.Linear(ic, io, False) 37 | bn2d = lambda c : nn.BatchNorm2d(c, track_running_stats=False) 38 | bn1d = lambda c : nn.BatchNorm1d(c, track_running_stats=False) 39 | 40 | # embedding layer for context vector 41 | if with_context: 42 | self.context_layer = conv(n_classes, hidden, 1) 43 | else: 44 | self.context_layer = None 45 | # embedding layer for RGB 46 | self.color_enc1 = conv(3, hidden, 1) 47 | # body for RGB 48 | self.color_enc_body = nn.Sequential( 49 | bn2d(hidden), 50 | nn.LeakyReLU(0.2, True), 51 | nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else nn.Sequential(), 52 | conv(hidden, hidden, 1), 53 | bn2d(hidden), 54 | nn.LeakyReLU(0.2, True), 55 | nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else nn.Sequential() 56 | ) 57 | # output layer for RGB 58 | self.c_regress = conv(hidden, 6, 1) 59 | # body for noise vector 60 | self.noise_enc = nn.Sequential( 61 | linear(n_dim + n_classes if with_context else n_dim, n_hidden), 62 | bn1d(n_hidden), 63 | nn.LeakyReLU(0.2, True), 64 | 
nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Sequential(), 65 | linear(n_hidden, n_hidden), 66 | bn1d(n_hidden), 67 | nn.LeakyReLU(0.2, True), 68 | nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Sequential(), 69 | ) 70 | # output layer for noise vector 71 | self.n_regress = linear(n_hidden, 2) 72 | 73 | if with_context: 74 | self.register_parameter('logits', nn.Parameter(torch.zeros(n_classes))) 75 | else: 76 | self.register_parameter('logits', nn.Parameter(torch.zeros(1))) 77 | # initialize parameters 78 | self.reset() 79 | 80 | self.with_context = with_context 81 | self.n_classes = n_classes 82 | self.n_dim = n_dim 83 | self.scale = scale 84 | self.relax = True 85 | self.stochastic = True 86 | 87 | def sampling(self, scale, shift, y, temp=0.05): 88 | if self.stochastic: # random apply 89 | if self.with_context: 90 | logits = self.logits[y].reshape(-1, 1, 1, 1) 91 | else: 92 | logits = self.logits.repeat(scale.shape[0]).reshape(-1, 1, 1, 1) 93 | prob = relaxed_bernoulli(logits, temp, device=scale.device) 94 | if not self.relax: # hard sampling 95 | prob = (prob > 0.5).float() 96 | scale = 1 - prob + prob * scale 97 | shift = prob * shift # omit "+ (1 - prob) * 0" 98 | return scale, shift 99 | 100 | def forward(self, x, noise, c=None): 101 | if self.with_context: 102 | # integer to onehot vector 103 | onehot_c = nn.functional.one_hot(c, self.n_classes).float() 104 | noise = torch.cat([onehot_c, noise], 1) 105 | # vector to 2d image 106 | onehot_c = onehot_c.reshape(*onehot_c.shape, 1, 1) 107 | # global scale and shift 108 | gfactor = self.noise_enc(noise) 109 | gfactor = self.n_regress(gfactor).reshape(-1, 2, 1, 1) 110 | # per-pixel scale and shift 111 | feature = self.color_enc1(x) 112 | # add context information 113 | if self.with_context: 114 | feature = self.context_layer(onehot_c) + feature 115 | feature = self.color_enc_body(feature) 116 | factor = self.c_regress(feature) 117 | # add up parameters 118 | scale, shift = factor.chunk(2, dim=1) 119 | g_scale, g_shift = gfactor.chunk(2, dim=1) 120 | scale = (g_scale + scale).sigmoid() 121 | shift = (g_shift + shift).sigmoid() 122 | # scaling 123 | scale = self.scale * (scale - 0.5) + 1 124 | shift = shift - 0.5 125 | # random apply 126 | scale, shift = self.sampling(scale, shift, c) 127 | 128 | return scale, shift 129 | 130 | def reset(self): 131 | for m in self.modules(): 132 | if isinstance(m, (nn.Conv2d, nn.Linear)): 133 | nn.init.kaiming_normal_(m.weight, 0.2, 'fan_out') 134 | if m.bias is not None: 135 | nn.init.constant_(m.bias, 0) 136 | # zero initialization 137 | nn.init.constant_(self.c_regress.weight, 0) 138 | nn.init.constant_(self.n_regress.weight, 0) 139 | nn.init.constant_(self.logits, 0) 140 | 141 | def transform(self, x, scale, shift): 142 | # ignore zero padding region 143 | with torch.no_grad(): 144 | h, w = x.shape[-2:] 145 | mask = (x.sum(1, keepdim=True) == 0).float() # mask pixels having (0, 0, 0) color 146 | mask = torch.logical_and(mask.sum(-1, keepdim=True) < w, 147 | mask.sum(-2, keepdim=True) < h) # mask zero padding region 148 | 149 | x = (scale * x + shift) * mask 150 | return TriangleWave.apply(x) 151 | 152 | 153 | class GeometricAugmentation(nn.Module): 154 | def __init__(self, n_classes=10, scale=0.5, n_dim=128, dropout_ratio=0.8, with_context=True): 155 | super().__init__() 156 | 157 | hidden = 4 * n_dim 158 | linear = lambda ic, io : nn.Linear(ic, io, False) 159 | bn1d = lambda c : nn.BatchNorm1d(c, track_running_stats=False) 160 | 161 | self.body = nn.Sequential( 162 | linear(n_dim + 
n_classes if with_context else n_dim, hidden), 163 | bn1d(hidden), 164 | nn.LeakyReLU(0.2, True), 165 | nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Sequential(), 166 | linear(hidden, hidden), 167 | bn1d(hidden), 168 | nn.LeakyReLU(0.2, True), 169 | nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Sequential(), 170 | ) 171 | 172 | self.regressor = linear(hidden, 6) 173 | # identity matrix 174 | self.register_buffer('i_matrix', torch.Tensor([[1, 0, 0], [0, 1, 0]]).reshape(1, 2, 3)) 175 | 176 | if with_context: 177 | self.register_parameter('logits', nn.Parameter(torch.zeros(n_classes))) 178 | else: 179 | self.register_parameter('logits', nn.Parameter(torch.zeros(1))) 180 | # initialize parameters 181 | self.reset() 182 | 183 | self.with_context = with_context 184 | self.n_classes = n_classes 185 | self.n_dim = n_dim 186 | self.scale = scale 187 | 188 | self.relax = True 189 | self.stochastic = True 190 | 191 | def sampling(self, A, y=None, temp=0.05): 192 | if self.stochastic: # random apply 193 | if self.with_context: 194 | logits = self.logits[y].reshape(-1, 1, 1) 195 | else: 196 | logits = self.logits.repeat(A.shape[0]).reshape(-1, 1, 1) 197 | prob = relaxed_bernoulli(logits, temp, device=logits.device) 198 | if not self.relax: # hard sampling 199 | prob = (prob > 0.5).float() 200 | return ((1 - prob) * self.i_matrix + prob * A) 201 | else: 202 | return A 203 | 204 | def forward(self, x, noise, c=None): 205 | if self.with_context: 206 | with torch.no_grad(): 207 | # integer to onehot vector 208 | onehot_c = nn.functional.one_hot(c, self.n_classes).float() 209 | noise = torch.cat([onehot_c, noise], 1) 210 | features = self.body(noise) 211 | A = self.regressor(features).reshape(-1, 2, 3) 212 | # scaling 213 | A = self.scale * (A.sigmoid() - 0.5) + self.i_matrix 214 | # random apply 215 | A = self.sampling(A, c) 216 | # matrix to grid representation 217 | grid = nn.functional.affine_grid(A, x.shape) 218 | return grid 219 | 220 | def reset(self): 221 | for m in self.modules(): 222 | if isinstance(m, nn.Linear): 223 | nn.init.kaiming_normal_(m.weight, 0.2, 'fan_out') 224 | if m.bias is not None: 225 | nn.init.constant_(m.bias, 0) 226 | # zero initialization 227 | nn.init.constant_(self.regressor.weight, 0) 228 | nn.init.constant_(self.logits, 0) 229 | 230 | def transform(self, x, grid): 231 | x = F.grid_sample(x, grid, mode='bilinear') 232 | return x 233 | -------------------------------------------------------------------------------- /lib/augmentation/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | import torch.nn as nn 5 | 6 | 7 | class ReplayBuffer(nn.Module): 8 | def __init__(self, decay_rate=0.9, buffer_size=-1): 9 | super().__init__() 10 | self.decay_rate = decay_rate 11 | self.buffer = nn.ModuleList([]) 12 | self.priority = [] 13 | self.buffer_size = buffer_size 14 | 15 | def store(self, augmentation): 16 | self.buffer.append(copy.deepcopy(augmentation)) 17 | self.priority.append(1) 18 | self.priority = list(map(lambda x: self.decay_rate * x, self.priority)) # decay 19 | if self.buffer_size > 0 and len(self.priority) > self.buffer_size: 20 | self.buffer = self.buffer[-self.buffer_size:] 21 | self.priority = self.priority[-self.buffer_size:] 22 | 23 | def sampling(self, n_samples, latest_aug=None): 24 | if latest_aug is not None: 25 | buffer = list(self.buffer._modules.values()) + [latest_aug] 26 | priority = self.priority + [1] 27 | else: 28 | buffer = self.buffer 29 | priority 
= self.priority 30 | return random.choices(buffer, priority, k=n_samples) 31 | 32 | def __len__(self): 33 | return len(self.buffer) 34 | 35 | def initialize(self, length, module): 36 | # This function must be called before the "load_state_dict" function. 37 | # placeholder to load state dict 38 | self.buffer = nn.ModuleList([copy.deepcopy(module) for _ in range(length)]) 39 | self.priority = [self.decay_rate**(i+1) for i in reversed(range(length))] 40 | -------------------------------------------------------------------------------- /lib/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DensoITLab/TeachAugment/66ec099a0afab99e18531c5437182cfe17dc30c8/lib/losses/__init__.py -------------------------------------------------------------------------------- /lib/losses/non_saturating_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def non_saturating_loss(logits, targets): 6 | probs = logits.softmax(1) 7 | log_prob = torch.log(1 - probs + 1e-12) 8 | if targets.ndim == 2: 9 | return - (targets * log_prob).sum(1).mean() 10 | else: 11 | return F.nll_loss(log_prob, targets) 12 | 13 | 14 | class NonSaturatingLoss(torch.nn.Module): 15 | def __init__(self, epsilon=0): 16 | super().__init__() 17 | self.epsilon = epsilon 18 | 19 | def forward(self, logits, targets): 20 | if self.epsilon > 0: # label smoothing 21 | n_classes = logits.shape[1] 22 | onehot_targets = F.one_hot(targets, n_classes).float() 23 | targets = (1 - self.epsilon) * onehot_targets + self.epsilon / n_classes 24 | return non_saturating_loss(logits, targets) 25 | 
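# A minimal sketch of the loss above (a toy 4-class example; the tensors are
# assumptions for illustration):
#
#   logits = torch.randn(2, 4)
#   targets = torch.tensor([0, 3])
#   loss = NonSaturatingLoss(epsilon=0.1)(logits, targets)
#
# The augmentation model is trained to fool the target model; minimizing
# -log(1 - p(correct class)) drives that probability down while keeping a
# useful gradient signal, avoiding the saturation of directly maximizing the
# plain cross entropy (analogous to the non-saturating GAN generator loss).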
-------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import resnet 2 | from . import pyramidnet 3 | from . import wide_resnet 4 | from .shakeshake import shake_resnet 5 | 6 | 7 | def build_model(model_name, num_classes=10): 8 | if model_name in ['wideresnet-28-10', 'wrn-28-10']: 9 | model = wide_resnet.WideResNet(28, 10, 0, num_classes) 10 | 11 | elif model_name in ['wideresnet-40-2', 'wrn-40-2']: 12 | model = wide_resnet.WideResNet(40, 2, 0, num_classes) 13 | 14 | elif model_name in ['shakeshake26_2x32d', 'ss32']: 15 | model = shake_resnet.ShakeResNet(26, 32, num_classes) 16 | 17 | elif model_name in ['shakeshake26_2x96d', 'ss96']: 18 | model = shake_resnet.ShakeResNet(26, 96, num_classes) 19 | 20 | elif model_name in ['shakeshake26_2x112d', 'ss112']: 21 | model = shake_resnet.ShakeResNet(26, 112, num_classes) 22 | 23 | elif model_name == 'pyramidnet': 24 | model = pyramidnet.PyramidNet('cifar10', depth=272, alpha=200, num_classes=num_classes, bottleneck=True) 25 | 26 | elif model_name == 'resnet200': 27 | model = resnet.ResNet('imagenet', 200, num_classes, True) 28 | 29 | elif model_name == 'resnet50': 30 | model = resnet.ResNet('imagenet', 50, num_classes, True) 31 | 32 | return model 33 | -------------------------------------------------------------------------------- /lib/models/pyramidnet.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | from .shakedrop import ShakeDrop 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """ 11 | 3x3 convolution with padding 12 | """ 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | outchannel_ratio = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 20 | super(BasicBlock, self).__init__() 21 | self.bn1 = nn.BatchNorm2d(inplanes) 22 | self.conv1 = conv3x3(inplanes, planes, stride) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn3 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | self.shake_drop = ShakeDrop(p_shakedrop) 30 | 31 | def forward(self, x): 32 | 33 | out = self.bn1(x) 34 | out = self.conv1(out) 35 | out = self.bn2(out) 36 | out = self.relu(out) 37 | out = self.conv2(out) 38 | out = self.bn3(out) 39 | 40 | if self.training: 41 | out = self.shake_drop(out) 42 | else: 43 | out = (1 - self.shake_drop.p_drop) * out 44 | 45 | if self.downsample is not None: 46 | shortcut = self.downsample(x) 47 | featuremap_size = shortcut.size()[2:4] 48 | else: 49 | shortcut = x 50 | featuremap_size = out.size()[2:4] 51 | 52 | batch_size = out.size()[0] 53 | residual_channel = out.size()[1] 54 | shortcut_channel = shortcut.size()[1] 55 | 56 | if residual_channel != shortcut_channel: 57 | padding = torch.autograd.Variable( 58 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 59 | featuremap_size[1]).fill_(0)) 60 | out = out + torch.cat((shortcut, padding), 1) 61 | else: 62 | out = out + shortcut 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | outchannel_ratio = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 71 | super(Bottleneck, self).__init__() 72 | self.bn1 = nn.BatchNorm2d(inplanes) 73 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(planes) 75 | self.conv2 = 
nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, 76 | padding=1, bias=False) 77 | self.bn3 = nn.BatchNorm2d((planes * 1)) 78 | self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 79 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = downsample 82 | self.stride = stride 83 | self.shake_drop = ShakeDrop(p_shakedrop) 84 | 85 | def forward(self, x): 86 | 87 | out = self.bn1(x) 88 | out = self.conv1(out) 89 | 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | out = self.conv2(out) 93 | 94 | out = self.bn3(out) 95 | out = self.relu(out) 96 | out = self.conv3(out) 97 | 98 | out = self.bn4(out) 99 | 100 | if self.training: 101 | out = self.shake_drop(out) 102 | else: 103 | out = (1 - self.shake_drop.p_drop) * out 104 | 105 | 106 | if self.downsample is not None: 107 | shortcut = self.downsample(x) 108 | featuremap_size = shortcut.size()[2:4] 109 | else: 110 | shortcut = x 111 | featuremap_size = out.size()[2:4] 112 | 113 | batch_size = out.size()[0] 114 | residual_channel = out.size()[1] 115 | shortcut_channel = shortcut.size()[1] 116 | 117 | if residual_channel != shortcut_channel: 118 | padding = torch.autograd.Variable( 119 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 120 | featuremap_size[1]).fill_(0)) 121 | out = out + torch.cat((shortcut, padding), 1) 122 | else: 123 | out = out + shortcut 124 | 125 | return out 126 | 127 | 128 | class PyramidNet(nn.Module): 129 | 130 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=True): 131 | super(PyramidNet, self).__init__() 132 | self.dataset = dataset 133 | if self.dataset.startswith('cifar'): 134 | self.inplanes = 16 135 | if bottleneck: 136 | n = int((depth - 2) / 9) 137 | block = Bottleneck 138 | else: 139 | n = int((depth - 2) / 6) 140 | block = BasicBlock 141 | 142 | self.addrate = alpha / (3 * n * 1.0) 143 | self.ps_shakedrop = [1. 
- (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)] 144 | 145 | self.input_featuremap_dim = self.inplanes 146 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 147 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 148 | 149 | self.featuremap_dim = self.input_featuremap_dim 150 | self.layer1 = self.pyramidal_make_layer(block, n) 151 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 152 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 153 | 154 | self.final_featuremap_dim = self.input_featuremap_dim 155 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 156 | self.relu_final = nn.ReLU(inplace=True) 157 | self.avgpool = nn.AvgPool2d(8) 158 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 159 | 160 | elif dataset == 'imagenet': 161 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 162 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 163 | 200: [3, 24, 36, 3]} 164 | 165 | if layers.get(depth) is None: 166 | if bottleneck == True: 167 | blocks[depth] = Bottleneck 168 | temp_cfg = int((depth - 2) / 12) 169 | else: 170 | blocks[depth] = BasicBlock 171 | temp_cfg = int((depth - 2) / 8) 172 | 173 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 174 | print('=> the layer configuration for each stage is set to', layers[depth]) 175 | 176 | self.inplanes = 64 177 | self.addrate = alpha / (sum(layers[depth]) * 1.0) 178 | 179 | self.input_featuremap_dim = self.inplanes 180 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 181 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 182 | self.relu = nn.ReLU(inplace=True) 183 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 184 | 185 | self.featuremap_dim = self.input_featuremap_dim 186 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 187 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 188 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 189 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 190 | 191 | self.final_featuremap_dim = self.input_featuremap_dim 192 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 193 | self.relu_final = nn.ReLU(inplace=True) 194 | self.avgpool = nn.AvgPool2d(7) 195 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 196 | 197 | for m in self.modules(): 198 | if isinstance(m, nn.Conv2d): 199 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 200 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 201 | elif isinstance(m, nn.BatchNorm2d): 202 | m.weight.data.fill_(1) 203 | m.bias.data.zero_() 204 | 205 | assert len(self.ps_shakedrop) == 0, self.ps_shakedrop 206 | 207 | def pyramidal_make_layer(self, block, block_depth, stride=1): 208 | downsample = None 209 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 210 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) 211 | 212 | layers = [] 213 | self.featuremap_dim = self.featuremap_dim + self.addrate 214 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, p_shakedrop=self.ps_shakedrop.pop(0))) 215 | for i in range(1, block_depth): 216 | temp_featuremap_dim = self.featuremap_dim + self.addrate 217 | layers.append( 218 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, p_shakedrop=self.ps_shakedrop.pop(0))) 219 | self.featuremap_dim = temp_featuremap_dim 220 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 221 | 222 | return nn.Sequential(*layers) 223 | 224 | def forward(self, x): 225 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 226 | x = self.conv1(x) 227 | x = self.bn1(x) 228 | 229 | x = self.layer1(x) 230 | x = self.layer2(x) 231 | x = self.layer3(x) 232 | 233 | x = self.bn_final(x) 234 | x = self.relu_final(x) 235 | x = self.avgpool(x) 236 | x = x.view(x.size(0), -1) 237 | x = self.fc(x) 238 | 239 | elif self.dataset == 'imagenet': 240 | x = self.conv1(x) 241 | x = self.bn1(x) 242 | x = self.relu(x) 243 | x = self.maxpool(x) 244 | 245 | x = self.layer1(x) 246 | x = self.layer2(x) 247 | x = self.layer3(x) 248 | x = self.layer4(x) 249 | 250 | x = self.bn_final(x) 251 | x = self.relu_final(x) 252 | x = self.avgpool(x) 253 | x = x.view(x.size(0), -1) 254 | x = self.fc(x) 255 | 256 | return x 257 | -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | 3 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = conv3x3(planes, planes) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 
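# The 1x1 -> 3x3 -> 1x1 bottleneck keeps computation low: conv1 first reduces
# the channel width to `planes`, conv2 does the 3x3 spatial work (carrying the
# block's stride), and conv3 expands back to planes * expansion channels
# before the residual addition.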
55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 60 | self.relu = nn.ReLU(inplace=True) 61 | 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 88 | super(ResNet, self).__init__() 89 | self.dataset = dataset 90 | if self.dataset.startswith('cifar'): 91 | self.inplanes = 16 92 | print(bottleneck) 93 | if bottleneck == True: 94 | n = int((depth - 2) / 9) 95 | block = Bottleneck 96 | else: 97 | n = int((depth - 2) / 6) 98 | block = BasicBlock 99 | 100 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 101 | self.bn1 = nn.BatchNorm2d(self.inplanes) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.layer1 = self._make_layer(block, 16, n) 104 | self.layer2 = self._make_layer(block, 32, n, stride=2) 105 | self.layer3 = self._make_layer(block, 64, n, stride=2) 106 | # self.avgpool = nn.AvgPool2d(8) 107 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 108 | self.fc = nn.Linear(64 * block.expansion, num_classes) 109 | 110 | elif dataset == 'imagenet': 111 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 112 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 113 | assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 114 | 115 | self.inplanes = 64 116 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 117 | self.bn1 = nn.BatchNorm2d(64) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 120 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 121 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 122 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 123 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 124 | # self.avgpool = nn.AvgPool2d(7) 125 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 126 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 127 | 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 131 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 132 | elif isinstance(m, nn.BatchNorm2d): 133 | m.weight.data.fill_(1) 134 | m.bias.data.zero_() 135 | 136 | def _make_layer(self, block, planes, blocks, stride=1): 137 | downsample = None 138 | if stride != 1 or self.inplanes != planes * block.expansion: 139 | downsample = nn.Sequential( 140 | nn.Conv2d(self.inplanes, planes * block.expansion, 141 | kernel_size=1, stride=stride, bias=False), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, downsample)) 147 | self.inplanes = planes * block.expansion 148 | for i in range(1, blocks): 149 | layers.append(block(self.inplanes, planes)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 155 | x = self.conv1(x) 156 | x = self.bn1(x) 157 | x = self.relu(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | 163 | x = self.avgpool(x) 164 | x = x.view(x.size(0), -1) 165 | x = self.fc(x) 166 | 167 | elif self.dataset == 'imagenet': 168 | x = self.conv1(x) 169 | x = self.bn1(x) 170 | x = self.relu(x) 171 | x = self.maxpool(x) 172 | 173 | x = self.layer1(x) 174 | x = self.layer2(x) 175 | x = self.layer3(x) 176 | x = self.layer4(x) 177 | 178 | x = self.avgpool(x) 179 | x = x.view(x.size(0), -1) 180 | x = self.fc(x) 181 | 182 | return x 183 | -------------------------------------------------------------------------------- /lib/models/shakedrop.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | class ShakeDropFunction(torch.autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 14 | if training: 15 | gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop) 16 | ctx.save_for_backward(gate) 17 | if gate.item() == 0: 18 | alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range) 19 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 20 | return alpha * x 21 | else: 22 | return x 23 | else: 24 | return (1 - p_drop) * x 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | gate = ctx.saved_tensors[0] 29 | if gate.item() == 0: 30 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1) 31 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 32 | beta = Variable(beta) 33 | return beta * grad_output, None, None, None 34 | else: 35 | return grad_output, None, None, None 36 | 37 | 38 | class ShakeDrop(nn.Module): 39 | 40 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 41 | super(ShakeDrop, self).__init__() 42 | self.p_drop = p_drop 43 | self.alpha_range = alpha_range 44 | 45 | def forward(self, x): 46 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 47 | -------------------------------------------------------------------------------- /lib/models/shakeshake/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DensoITLab/TeachAugment/66ec099a0afab99e18531c5437182cfe17dc30c8/lib/models/shakeshake/__init__.py -------------------------------------------------------------------------------- /lib/models/shakeshake/shake_resnet.py: 
-------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | # -*- coding: utf-8 -*- 3 | 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .shakeshake import ShakeShake 10 | from .shakeshake import Shortcut 11 | 12 | 13 | class ShakeBlock(nn.Module): 14 | 15 | def __init__(self, in_ch, out_ch, stride=1): 16 | super(ShakeBlock, self).__init__() 17 | self.equal_io = in_ch == out_ch 18 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 19 | 20 | self.branch1 = self._make_branch(in_ch, out_ch, stride) 21 | self.branch2 = self._make_branch(in_ch, out_ch, stride) 22 | 23 | def forward(self, x): 24 | h1 = self.branch1(x) 25 | h2 = self.branch2(x) 26 | h = ShakeShake.apply(h1, h2, self.training) 27 | h0 = x if self.equal_io else self.shortcut(x) 28 | return h + h0 29 | 30 | def _make_branch(self, in_ch, out_ch, stride=1): 31 | return nn.Sequential( 32 | nn.ReLU(inplace=False), 33 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(out_ch), 35 | nn.ReLU(inplace=False), 36 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), 37 | nn.BatchNorm2d(out_ch)) 38 | 39 | 40 | class ShakeResNet(nn.Module): 41 | 42 | def __init__(self, depth, w_base, label): 43 | super(ShakeResNet, self).__init__() 44 | n_units = (depth - 2) / 6 45 | 46 | in_chs = [16, w_base, w_base * 2, w_base * 4] 47 | self.in_chs = in_chs 48 | 49 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) 50 | self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1]) 51 | self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2) 52 | self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2) 53 | self.fc = nn.Linear(in_chs[3], label) 54 | 55 | # Initialize parameters 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 59 | m.weight.data.normal_(0, math.sqrt(2. / n)) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.Linear): 64 | m.bias.data.zero_() 65 | 66 | def forward(self, x): 67 | h = self.c_in(x) 68 | h = self.layer1(h) 69 | h = self.layer2(h) 70 | h = self.layer3(h) 71 | h = F.relu(h) 72 | h = F.avg_pool2d(h, 8) 73 | h = h.view(-1, self.in_chs[3]) 74 | h = self.fc(h) 75 | return h 76 | 77 | def _make_layer(self, n_units, in_ch, out_ch, stride=1): 78 | layers = [] 79 | for i in range(int(n_units)): 80 | layers.append(ShakeBlock(in_ch, out_ch, stride=stride)) 81 | in_ch, stride = out_ch, 1 82 | return nn.Sequential(*layers) 83 | 
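# A minimal construction sketch for the model above (mirrors build_model in
# lib/models/__init__.py; the input shape is an assumption):
#
#   net = ShakeResNet(depth=26, w_base=96, label=10)  # 'ss96' on CIFAR-10
#   logits = net(torch.rand(2, 3, 32, 32))            # -> shape (2, 10)
#
# During training, ShakeShake.apply mixes the two residual branches with a
# random per-sample alpha (and an independent random beta in the backward
# pass); at test time the branches are averaged with alpha = 0.5.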
-------------------------------------------------------------------------------- /lib/models/shakeshake/shake_resnext.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | # -*- coding: utf-8 -*- 3 | 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .shakeshake import ShakeShake 10 | from .shakeshake import Shortcut 11 | 12 | 13 | class ShakeBottleNeck(nn.Module): 14 | 15 | def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 16 | super(ShakeBottleNeck, self).__init__() 17 | self.equal_io = in_ch == out_ch 18 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 19 | 20 | self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 21 | self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 22 | 23 | def forward(self, x): 24 | h1 = self.branch1(x) 25 | h2 = self.branch2(x) 26 | if self.training: 27 | h = ShakeShake.apply(h1, h2, self.training) 28 | else: 29 | h = 0.5 * (h1 + h2) # avoiding stochastic gradient when updating augmentor 30 | h0 = x if self.equal_io else self.shortcut(x) 31 | return h + h0 32 | 33 | def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 34 | return nn.Sequential( 35 | nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False), 36 | nn.BatchNorm2d(mid_ch), 37 | nn.ReLU(inplace=False), 38 | nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False), 39 | nn.BatchNorm2d(mid_ch), 40 | nn.ReLU(inplace=False), 41 | nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False), 42 | nn.BatchNorm2d(out_ch)) 43 | 44 | 45 | class ShakeResNeXt(nn.Module): 46 | 47 | def __init__(self, depth, w_base, cardinary, label): 48 | super(ShakeResNeXt, self).__init__() 49 | n_units = (depth - 2) // 9 50 | n_chs = [64, 128, 256, 1024] 51 | self.n_chs = n_chs 52 | self.in_ch = n_chs[0] 53 | 54 | self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1) 55 | self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary) 56 | self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2) 57 | self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2) 58 | self.fc_out = nn.Linear(n_chs[3], label) 59 | 60 | # Initialize parameters 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 64 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | m.weight.data.fill_(1) 67 | m.bias.data.zero_() 68 | elif isinstance(m, nn.Linear): 69 | m.bias.data.zero_() 70 | 71 | def forward(self, x): 72 | h = self.c_in(x) 73 | h = self.layer1(h) 74 | h = self.layer2(h) 75 | h = self.layer3(h) 76 | h = F.relu(h) 77 | h = F.avg_pool2d(h, 8) 78 | h = h.view(-1, self.n_chs[3]) 79 | h = self.fc_out(h) 80 | return h 81 | 82 | def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1): 83 | layers = [] 84 | mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4 85 | for i in range(n_units): 86 | layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride)) 87 | self.in_ch, stride = out_ch, 1 88 | return nn.Sequential(*layers) 89 | -------------------------------------------------------------------------------- /lib/models/shakeshake/shakeshake.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | class ShakeShake(torch.autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, x1, x2, training=True): 14 | if training: 15 | alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_() 16 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1) 17 | else: 18 | alpha = 0.5 19 | return alpha * x1 + (1 - alpha) * x2 20 | 21 | @staticmethod 22 | def backward(ctx, grad_output): 23 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_() 24 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 25 | beta = Variable(beta) 26 | 27 | return beta * grad_output, (1 - beta) * grad_output, None 28 | 29 | 30 | class Shortcut(nn.Module): 31 | 32 | def __init__(self, in_ch, out_ch, stride): 33 | super(Shortcut, self).__init__() 34 | self.stride = stride 35 | self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 36 | self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 37 | self.bn = nn.BatchNorm2d(out_ch) 38 | 39 | def forward(self, x): 40 | h = F.relu(x) 41 | 42 | h1 = F.avg_pool2d(h, 1, self.stride) 43 | h1 = self.conv1(h1) 44 | 45 | h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride) 46 | h2 = self.conv2(h2) 47 | 48 | h = torch.cat((h1, h2), 1) 49 | return self.bn(h) 50 | -------------------------------------------------------------------------------- /lib/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # the code is taken from https://github.com/kakaobrain/fast-autoaugment 2 | import math 3 | 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 11 | 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=math.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | 23 | class WideBasic(nn.Module): 24 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 25 | super(WideBasic, self).__init__() 26 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.9) 27 | self.conv1 = nn.Conv2d(in_planes, 
28 |         self.dropout = nn.Dropout(p=dropout_rate)
29 |         self.bn2 = nn.BatchNorm2d(planes, momentum=0.9)
30 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
31 | 
32 |         self.shortcut = nn.Sequential()
33 |         if stride != 1 or in_planes != planes:
34 |             self.shortcut = nn.Sequential(
35 |                 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
36 |             )
37 | 
38 |     def forward(self, x):
39 |         out = self.dropout(self.conv1(F.relu(self.bn1(x))))
40 |         out = self.conv2(F.relu(self.bn2(out)))
41 |         out += self.shortcut(x)
42 | 
43 |         return out
44 | 
45 | 
46 | class WideResNet(nn.Module):
47 |     def __init__(self, depth, widen_factor, dropout_rate, num_classes):
48 |         super(WideResNet, self).__init__()
49 |         self.in_planes = 16
50 | 
51 |         assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
52 |         n = (depth - 4) // 6
53 |         k = widen_factor
54 | 
55 |         nStages = [16, 16*k, 32*k, 64*k]
56 | 
57 |         self.conv1 = conv3x3(3, nStages[0])
58 |         self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
59 |         self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
60 |         self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
61 |         self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
62 |         self.linear = nn.Linear(nStages[3], num_classes)
63 | 
64 |         # self.apply(conv_init)
65 | 
66 |     def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
67 |         strides = [stride] + [1]*(num_blocks-1)
68 |         layers = []
69 | 
70 |         for stride in strides:
71 |             layers.append(block(self.in_planes, planes, dropout_rate, stride))
72 |             self.in_planes = planes
73 | 
74 |         return nn.Sequential(*layers)
75 | 
76 |     def forward(self, x):
77 |         out = self.conv1(x)
78 |         out = self.layer1(out)
79 |         out = self.layer2(out)
80 |         out = self.layer3(out)
81 |         out = F.relu(self.bn1(out))
82 |         # out = F.avg_pool2d(out, 8)
83 |         out = F.adaptive_avg_pool2d(out, (1, 1))
84 |         out = out.view(out.size(0), -1)
85 |         out = self.linear(out)
86 | 
87 |         return out
88 | 
--------------------------------------------------------------------------------
/lib/teachaugment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class TeachAugment(nn.Module):
7 |     """
8 |     Args:
9 |         model: nn.Module
10 |             the target model
11 |         ema_model: nn.Module
12 |             exponential moving average of the target model
13 |         trainable_aug: nn.Module
14 |             augmentation model
15 |         adv_criterion:
16 |             criterion for the adversarial loss
17 |         weight_decay: float
18 |             coefficient for weight decay
19 |         base_aug: nn.Module
20 |             baseline augmentation
21 |     """
22 |     def __init__(
23 |             self, model, ema_model,
24 |             trainable_aug, adv_criterion,
25 |             weight_decay=0, base_aug=None,
26 |             normalizer=None, save_memory=False):
27 |         super().__init__()
28 |         # model
29 |         self.model = model
30 |         self.ema_model = ema_model
31 |         # loss
32 |         self.adv_criterion = adv_criterion
33 |         self.weight_decay = weight_decay
34 |         # augmentation
35 |         self.trainable_aug = trainable_aug
36 |         self.base_aug = base_aug
37 |         self.normalizer = normalizer
38 |         # misc
39 |         self.save_memory = save_memory
40 | 
41 |     def manual_weight_decay(self, model, coef):
42 |         return coef * (1. / 2.) * sum([params.pow(2).sum()
43 |                                        for name, params in model.named_parameters()
44 |                                        if not ('_bn' in name or '.bn' in name)])
45 | 
46 |     def forward(self, x, y, c=None, loss='cls'):
47 |         """
48 |         Args:
49 |             x: torch.Tensor
50 |                 images
51 |             y: torch.Tensor
52 |                 labels
53 |             c: torch.Tensor
54 |                 context vector (optional)
55 |             loss: str
56 |                 loss type
57 |                 - cls
58 |                     loss for the target model
59 |                 - aug
60 |                     loss for the augmentation model (sum of the following loss_adv and loss_tea)
61 |                 - loss_adv
62 |                     adversarial loss
63 |                 - loss_tea
64 |                     loss for the teacher model
65 |         """
66 |         if loss == 'cls':
67 |             return self.loss_classifier(x, y, c)
68 |         elif loss == 'aug':
69 |             return self.loss_augmentation(x, y, c)
70 |         elif loss == 'loss_adv':
71 |             return self.loss_adversarial(x, y, c)
72 |         elif loss == 'loss_tea':
73 |             return self.loss_teacher(x, y, c)
74 |         else:
75 |             raise NotImplementedError
76 | 
77 |     def loss_classifier(self, x, y, c=None):
78 |         # augmentation
79 |         with torch.no_grad():
80 |             aug_x, _ = self.trainable_aug(x, c)
81 |         if self.base_aug is not None:
82 |             inputs = torch.stack([self.base_aug(_x) for _x in aug_x])
83 |         else:
84 |             inputs = aug_x
85 |         # calculate loss
86 |         pred = self.model(inputs)
87 |         loss = F.cross_entropy(pred, y)
88 |         res = {'loss cls.': loss.item()}
89 |         if self.weight_decay > 0:
90 |             loss += self.manual_weight_decay(self.model, self.weight_decay)
91 |         return loss, res, aug_x
92 | 
93 |     def loss_augmentation(self, x, y, c=None):
94 |         # avoid updating bn running stats because they have already been updated in loss_classifier
95 |         self.stop_bn_track_running_stats(self.model)
96 |         # augmentation
97 |         x, c_reg = self.trainable_aug(x, c, update=True)
98 |         if self.normalizer is not None:
99 |             x = self.normalizer(x)
100 |         # calculate loss
101 |         tar_pred = self.model(x)
102 |         loss_adv = self.adv_criterion(tar_pred, y)
103 |         # compute the gradient here to release the memory held by the computational graph
104 |         # NOTE: save_memory does NOT work with DDP.
105 |         # Under DDP, compute loss_tea and loss_adv independently using loss_teacher and loss_adversarial;
106 |         # see main.py l130-l138 for more details
107 |         if self.save_memory:
108 |             grad = torch.autograd.grad(loss_adv, x)[0]
109 |             x.backward(grad, retain_graph=True)
110 |         tea_pred = self.ema_model(x)
111 |         loss_tea = F.cross_entropy(tea_pred, y)
112 |         # accuracy
113 |         with torch.no_grad():
114 |             teacher_acc = (tea_pred.argmax(1) == y).float().mean()
115 |             target_acc = (tar_pred.argmax(1) == y).float().mean()
116 | 
117 |         res = {'loss adv.': loss_adv.item(),
118 |                'loss teacher': loss_tea.item(),
119 |                'color reg.': c_reg.item(),
120 |                'acc.': target_acc.item(),
121 |                'acc. teacher': teacher_acc.item()}
122 | 
123 |         self.activate_bn_track_running_stats(self.model)
124 | 
125 |         if self.save_memory:
126 |             return loss_tea + c_reg, res
127 | 
128 |         return loss_adv + loss_tea + c_reg, res
129 | 
130 |     def loss_adversarial(self, x, y, c=None):
131 |         # avoid updating bn running stats twice with the same samples
132 |         self.stop_bn_track_running_stats(self.model)
133 |         # augmentation
134 |         x, c_reg = self.trainable_aug(x, c, update=True)
135 |         if self.normalizer is not None:
136 |             x = self.normalizer(x)
137 |         # calculate loss
138 |         tar_pred = self.model(x)
139 |         loss_adv = self.adv_criterion(tar_pred, y)
140 |         # accuracy
141 |         with torch.no_grad():
142 |             acc = (tar_pred.argmax(1) == y).float().mean()
143 | 
144 |         self.activate_bn_track_running_stats(self.model)
145 | 
146 |         return loss_adv, c_reg, acc
147 | 
148 |     def loss_teacher(self, x, y, c=None):
149 |         # augmentation
150 |         x, c_reg = self.trainable_aug(x, c, update=True)
151 |         if self.normalizer is not None:
152 |             x = self.normalizer(x)
153 |         # calculate loss
154 |         tea_pred = self.ema_model(x)
155 |         loss_tea = F.cross_entropy(tea_pred, y)
156 |         # accuracy
157 |         with torch.no_grad():
158 |             acc = (tea_pred.argmax(1) == y).float().mean()
159 | 
160 |         return loss_tea, c_reg, acc
161 | 
162 |     def stop_bn_track_running_stats(self, model):
163 |         for m in model.modules():
164 |             if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
165 |                 m.track_running_stats = False
166 | 
167 |     def activate_bn_track_running_stats(self, model):
168 |         for m in model.modules():
169 |             if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
170 |                 m.track_running_stats = True
171 | 
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DensoITLab/TeachAugment/66ec099a0afab99e18531c5437182cfe17dc30c8/lib/utils/__init__.py
--------------------------------------------------------------------------------
/lib/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | from collections import Counter
4 | from torch.optim.lr_scheduler import _LRScheduler
5 | 
6 | 
7 | class CosineAnnealingWithLinearWarmup(_LRScheduler):
8 |     def __init__(
9 |             self, optimizer, warmup_epoch,
10 |             T_max, adjust_epoch=False, eta_min=0,
11 |             last_epoch=-1, verbose=False):
12 |         self.warmup_epoch = warmup_epoch
13 |         self.T_max = T_max
14 |         self.eta_min = eta_min
15 |         self.adjust_epoch = adjust_epoch
16 |         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
17 | 
18 |     def get_lr(self):
19 |         if self.last_epoch <= self.warmup_epoch:  # linear warmup
20 |             rate = self.last_epoch / self.warmup_epoch
21 |             return [lr * rate for lr in self.base_lrs]
22 |         else:  # cosine annealing
23 |             cur_epoch = self.last_epoch - self.warmup_epoch
24 |             if self.adjust_epoch:
25 |                 max_epoch = self.T_max - self.warmup_epoch
26 |             else:
27 |                 max_epoch = self.T_max
28 |             rate = (1 + math.cos(cur_epoch / max_epoch * math.pi))
29 |             return [self.eta_min + 0.5 * (lr - self.eta_min) * rate for lr in self.base_lrs]
30 | 
31 | 
32 | class MultiStepLRWithLinearWarmup(_LRScheduler):
33 |     def __init__(self, optimizer, warmup_epoch, milestones, gamma, adjust_epoch=False, last_epoch=-1, verbose=False):
34 |         self.warmup_epoch = warmup_epoch
35 |         self.milestones = Counter(milestones)
36 |         self.gamma = gamma
37 |         self.adjust_epoch = adjust_epoch
38 |         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
39 | 
40 |     def get_lr(self):
41 |         if self.last_epoch <= self.warmup_epoch:  # linear warmup
42 |             rate = self.last_epoch / self.warmup_epoch
43 |             return [lr * rate for lr in self.base_lrs]
44 |         else:  # multi step lr decay
45 |             if self.adjust_epoch:
46 |                 cur_epoch = self.last_epoch - self.warmup_epoch
47 |             else:
48 |                 cur_epoch = self.last_epoch
49 |             if cur_epoch not in self.milestones:
50 |                 return [group['lr'] for group in self.optimizer.param_groups]
51 |             return [group['lr'] * self.gamma ** self.milestones[cur_epoch]
52 |                     for group in self.optimizer.param_groups]
53 | 
--------------------------------------------------------------------------------
/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import yaml
4 | import json
5 | import logging
6 | 
7 | import torch
8 | 
9 | 
10 | def set_seed(seed):
11 |     import random
12 |     import numpy.random
13 |     random.seed(seed)
14 |     numpy.random.seed(seed)
15 |     torch.manual_seed(seed)
16 | 
17 | 
18 | def setup_ddp(args):
19 |     torch.cuda.set_device(args.local_rank)
20 |     if getattr(args, 'port', None) is not None:
21 |         torch.distributed.init_process_group(
22 |             backend='nccl',
23 |             init_method=f'tcp://127.0.0.1:{args.port}',
24 |             world_size=args.world_size,
25 |             rank=args.local_rank,
26 |         )
27 |     else:
28 |         torch.distributed.init_process_group(backend='nccl', init_method='env://')
29 |         args.world_size = int(os.environ['WORLD_SIZE'])
30 |     args.num_workers //= args.world_size
31 |     args.lr *= args.world_size
32 |     return args
33 | 
34 | 
35 | def accuracy(output, target, topk=(1,)):
36 |     """Computes the accuracy over the k top predictions for the specified values of k"""
37 |     with torch.no_grad():
38 |         maxk = max(topk)
39 |         batch_size = target.size(0)
40 | 
41 |         _, pred = output.topk(maxk, 1, True, True)
42 |         pred = pred.t()
43 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
44 | 
45 |         res = []
46 |         for k in topk:
47 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
48 |             res.append(correct_k.mul_(100.0 / batch_size).item())
49 |         return res
50 | 
51 | 
52 | def setup_logger(log_dir=None, resume=False):
53 |     plain_formatter = logging.Formatter(
54 |         '[%(asctime)s] %(name)s %(levelname)s: %(message)s', datefmt='%m/%d %H:%M:%S'
55 |     )
56 |     logger = logging.getLogger()  # root logger
57 |     logger.setLevel(logging.INFO)
58 |     s_handler = logging.StreamHandler(stream=sys.stdout)
59 |     s_handler.setFormatter(plain_formatter)
60 |     s_handler.setLevel(logging.INFO)
61 |     logger.addHandler(s_handler)
62 |     if log_dir is not None:
63 |         os.makedirs(log_dir, exist_ok=True)
64 |         if not resume and os.path.exists(os.path.join(log_dir, 'console.log')):
65 |             os.remove(os.path.join(log_dir, 'console.log'))
66 |         f_handler = logging.FileHandler(os.path.join(log_dir, 'console.log'))
67 |         f_handler.setFormatter(plain_formatter)
68 |         f_handler.setLevel(logging.INFO)
69 |         logger.addHandler(f_handler)
70 | 
71 | 
72 | class AvgMeter:
73 |     def __init__(self, ema_coef=0.9):
74 |         self.ema_coef = ema_coef
75 |         self.ema_params = {}
76 |         self.sum_params = {}
77 |         self.counter = {}
78 | 
79 |     def add(self, params: dict, ignores: tuple = ()):
80 |         for k, v in params.items():
81 |             if k in ignores:
82 |                 continue
83 |             if k not in self.ema_params:
84 |                 self.ema_params[k] = v
85 |                 self.counter[k] = 1
86 |             else:
87 |                 self.ema_params[k] -= (1 - self.ema_coef) * (self.ema_params[k] - v)
88 |                 self.counter[k] += 1
89 |             if k not in self.sum_params:
90 |                 self.sum_params[k] = v
91 |             else:
92 |                 self.sum_params[k] += v
93 | 
94 |     def state(self, header="", footer="", ignore_keys=None):
95 |         if ignore_keys is None:
96 |             ignore_keys = set()
97 |         state = header
98 |         for k, v in self.ema_params.items():
99 |             if k in ignore_keys:
100 |                 continue
101 |             state += f" {k} {v:.6g} |"
102 |         return state + " " + footer
103 | 
104 |     def mean_state(self, header="", footer=""):
105 |         state = header
106 |         for k, v in self.sum_params.items():
107 |             state += f" {k} {v/self.counter[k]:.6g} |"
108 |             self.counter[k] = 0
109 |         state += footer
110 | 
111 |         self.sum_params = {}
112 | 
113 |         return state
114 | 
115 |     def reset(self):
116 |         self.ema_params = {}
117 |         self.sum_params = {}
118 |         self.counter = {}
119 | 
120 | 
121 | def override_config(args, dict_param):
122 |     for k, v in dict_param.items():
123 |         if isinstance(v, dict):
124 |             args = override_config(args, v)
125 |         else:
126 |             setattr(args, k, v)
127 |     return args
128 | 
129 | 
130 | def load_yaml(path):
131 |     with open(path, 'r') as f:
132 |         d = yaml.safe_load(f)
133 |     return d
134 | 
135 | 
136 | def load_json(path):
137 |     with open(path, 'r') as f:
138 |         d = json.load(f)
139 |     return d
140 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import warnings
4 | 
5 | logger = logging.getLogger(__name__)
6 | warnings.simplefilter('ignore', UserWarning)
7 | 
8 | import torch
9 | import torch.optim as optim
10 | 
11 | from torch.utils.data import DataLoader
12 | from torchvision.utils import save_image
13 | 
14 | from lib import augmentation, build_dataset, teachaugment
15 | from lib.utils import utils, lr_scheduler
16 | from lib.models import build_model
17 | from lib.losses import non_saturating_loss
18 | 
19 | 
20 | def main(args):
21 |     main_process = args.local_rank == 0
22 |     if main_process:
23 |         logger.info(args)
24 |     # Setup GPU
25 |     if torch.cuda.is_available():
26 |         device = 'cuda'
27 |         if args.disable_cudnn:
28 |             # torch.nn.functional.grid_sample, which is used for geometric augmentation, is non-deterministic,
29 |             # so reproducibility is not ensured even when the following option is True
30 |             torch.backends.cudnn.deterministic = True
31 |         else:
32 |             torch.backends.cudnn.benchmark = True
33 |     else:
34 |         raise NotImplementedError('CUDA is unavailable.')
35 |     # Dataset
36 |     base_aug, train_trans, val_trans, normalizer = augmentation.get_transforms(args.dataset)
37 |     train_data, eval_data, n_classes = build_dataset(args.dataset, args.root, train_trans, val_trans)
38 |     sampler = torch.utils.data.DistributedSampler(train_data, num_replicas=args.world_size, rank=args.local_rank) if args.dist else None
39 |     train_loader = DataLoader(train_data, args.batch_size, not args.dist, sampler,
40 |                               num_workers=args.num_workers, pin_memory=True,
41 |                               drop_last=True)
42 |     eval_loader = DataLoader(eval_data, 1)
43 |     # Model
44 |     model = build_model(args.model, n_classes).to(device)
45 |     model.train()
46 |     # EMA Teacher
47 |     avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: \
48 |         args.ema_rate * averaged_model_parameter + (1 - args.ema_rate) * model_parameter
49 |     ema_model = optim.swa_utils.AveragedModel(model, avg_fn=avg_fn)
50 |     for ema_p in ema_model.parameters():
51 |         ema_p.requires_grad_(False)
52 |     ema_model.train()
53 |     # Trainable Augmentation
54 |     rbuffer = augmentation.replay_buffer.ReplayBuffer(args.rb_decay)
55 |     trainable_aug = augmentation.build_augmentation(n_classes, args.g_scale, args.c_scale,
56 |                                                     args.c_reg_coef, normalizer, rbuffer,
57 |                                                     args.batch_size // args.group_size,
58 |                                                     not args.wo_context).to(device)
59 |     # Baseline augmentation
60 |     base_aug = torch.nn.Sequential(*base_aug).to(device)
61 |     if main_process:
62 |         logger.info('augmentation')
63 |         logger.info(trainable_aug)
64 |         logger.info(base_aug)
65 |     # Optimizer
66 |     optim_cls = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0)
67 |     optim_aug = optim.AdamW(trainable_aug.parameters(), lr=args.aug_lr, weight_decay=args.aug_weight_decay)
68 |     if args.dataset == 'ImageNet':
69 |         scheduler = lr_scheduler.MultiStepLRWithLinearWarmup(optim_cls, 5, [90, 180, 240], 0.1)
70 |     else:
71 |         scheduler = lr_scheduler.CosineAnnealingWithLinearWarmup(optim_cls, 5, args.n_epochs)
72 | 
73 |     # Following Fast AutoAugment (https://github.com/kakaobrain/fast-autoaugment),
74 |     # pytorch-gradual-warmup-lr (https://github.com/ildoonet/pytorch-gradual-warmup-lr) was used for the paper experiments.
75 |     # The implementation of our "*WithLinearWarmup" schedulers is slightly different from GradualWarmupScheduler.
76 |     # Thus, to reproduce the experimental results strictly, please use the following scheduler instead of the one above.
77 | 
78 |     # Don't forget to install pytorch-gradual-warmup-lr:
79 |     # pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
80 | 
81 |     # from warmup_scheduler import GradualWarmupScheduler
82 |     # if args.dataset == 'ImageNet':
83 |     #     base_scheduler = optim.lr_scheduler.MultiStepLR(optim_cls, [90, 180, 240], 0.1)
84 |     #     scheduler = GradualWarmupScheduler(optim_cls, 1, 5, base_scheduler)
85 |     # else:
86 |     #     base_scheduler = optim.lr_scheduler.CosineAnnealingLR(optim_cls, args.n_epochs)
87 |     #     scheduler = GradualWarmupScheduler(optim_cls, 1, 5, base_scheduler)
88 | 
89 |     # Objective function
90 |     adv_criterion = non_saturating_loss.NonSaturatingLoss(args.epsilon)
91 |     objective = teachaugment.TeachAugment(model, ema_model, trainable_aug,
92 |                                           adv_criterion, args.weight_decay,
93 |                                           base_aug, normalizer, not args.dist and args.save_memory).to(device)
94 |     # DDP
95 |     if args.dist:
96 |         objective = torch.nn.parallel.DistributedDataParallel(objective, device_ids=[args.local_rank],
97 |                                                               output_device=args.local_rank, find_unused_parameters=True)
98 |     # Resume
99 |     st_epoch = 1
100 |     if args.resume:
101 |         checkpoint = torch.load(os.path.join(args.log_dir, 'checkpoint.pth'))
102 |         st_epoch += checkpoint['epoch']
103 |         if main_process:
104 |             logger.info(f'resume from epoch {st_epoch}')
105 |         buffer_length = checkpoint['epoch'] // args.sampling_freq
106 |         rbuffer.initialize(buffer_length, trainable_aug.get_augmentation_model())  # define placeholder for load_state_dict
107 |         objective.load_state_dict(checkpoint['objective'])  # including model, ema teacher, trainable_aug, and replay buffer
108 |         optim_cls.load_state_dict(checkpoint['optim_cls'])
109 |         optim_aug.load_state_dict(checkpoint['optim_aug'])
110 |         scheduler.load_state_dict(checkpoint['scheduler'])
111 |     # Training loop
112 |     if main_process:
113 |         logger.info('training')
114 |     meter = utils.AvgMeter()
115 |     for epoch in range(st_epoch, args.n_epochs + 1):
116 |         if args.dist:
117 |             train_loader.sampler.set_epoch(epoch)
118 |         for i, data in enumerate(train_loader):
119 |             torch.cuda.synchronize()
120 |             inputs, targets = data
121 |             inputs, targets = inputs.to(device), targets.to(device)
122 |             if args.wo_context:
123 |                 context = None
124 |             else:
125 |                 context = targets
126 |             # update teacher model
127 |             ema_model.update_parameters(model)
128 |             # Update augmentation
129 |             if i % args.n_inner == 0:
130 |                 optim_aug.zero_grad()
131 |                 if args.dist and args.save_memory:  # compute gradients independently to save memory
132 |                     loss_adv, c_reg, acc_tar = objective(inputs, targets, context, 'loss_adv')
133 |                     (loss_adv + 0.5 * c_reg).backward()
134 |                     loss_tea, c_reg, acc_tea = objective(inputs, targets, context, 'loss_tea')
135 |                     (loss_tea + 0.5 * c_reg).backward()
136 |                     res = {'loss adv.': loss_adv.item(),
137 |                            'loss teacher': loss_tea.item(),
138 |                            'color reg.': c_reg.item(),
139 |                            'acc.': acc_tar.item(),
140 |                            'acc. teacher': acc_tea.item()}
141 |                 else:
142 |                     loss_aug, res = objective(inputs, targets, context, 'aug')
143 |                     loss_aug.backward()
144 |                 optim_aug.step()
145 |                 meter.add(res)
146 |             # Update target model
147 |             optim_cls.zero_grad()
148 |             loss_cls, res, aug_img = objective(inputs, targets, context, 'cls')
149 |             loss_cls.backward()
150 |             if args.dataset != 'ImageNet':
151 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
152 |             optim_cls.step()
153 |             # Adjust learning rate
154 |             scheduler.step(epoch - 1. + (i + 1.) / len(train_loader))
155 |             # Print losses and accuracy
156 |             meter.add(res)
157 |             if main_process and (i + 1) % args.print_freq == 0:
158 |                 logger.info(meter.state(f'epoch {epoch} [{i+1}/{len(train_loader)}]',
159 |                                         f'lr {optim_cls.param_groups[0]["lr"]:.4e}'))
160 |         # Store augmentation in buffer
161 |         if args.sampling_freq > 0 and epoch % args.sampling_freq == 0:
162 |             rbuffer.store(trainable_aug.get_augmentation_model())
163 |             if main_process:
164 |                 logger.info(f'store augmentation (buffer length: {len(rbuffer)})')
165 |         # Save checkpoint
166 |         if main_process:
167 |             logger.info(meter.mean_state(f'epoch [{epoch}/{args.n_epochs}]',
168 |                                          f'lr {optim_cls.param_groups[0]["lr"]:.4e}'))
169 |             checkpoint = {'model': model.state_dict(),
170 |                           'objective': objective.state_dict(),  # including ema model and replay buffer
171 |                           'optim_cls': optim_cls.state_dict(),
172 |                           'optim_aug': optim_aug.state_dict(),
173 |                           'scheduler': scheduler.state_dict(),
174 |                           'epoch': epoch}
175 |             torch.save(checkpoint, os.path.join(args.log_dir, 'checkpoint.pth'))
176 |             # Save augmented images
177 |             if args.vis:
178 |                 save_image(aug_img, os.path.join(args.log_dir, f'{epoch}epoch_aug_img.png'))
179 |                 save_image(inputs, os.path.join(args.log_dir, f'{epoch}epoch_img.png'))
180 |     # Evaluation
181 |     if main_process:
182 |         logger.info('evaluation')
183 |         acc1, acc5 = 0, 0
184 |         model.eval()
185 |         n_samples = len(eval_loader)
186 |         with torch.no_grad():
187 |             for data in eval_loader:
188 |                 input, target = data
189 |                 output = model(input.to(device))
190 |                 accs = utils.accuracy(output, target.to(device), (1, 5))
191 |                 acc1 += accs[0]
192 |                 acc5 += accs[1]
193 |         logger.info(f'{args.dataset} error rate (%) | Top1 {100 - acc1/n_samples} | Top5 {100 - acc5/n_samples}')
194 | 
195 | 
196 | if __name__ == '__main__':
197 |     import argparse
198 | 
199 |     parser = argparse.ArgumentParser()
200 |     # Dataset
201 |     parser.add_argument('--dataset', default='CIFAR10', choices=['CIFAR10', 'CIFAR100', 'ImageNet'])
202 |     parser.add_argument('--root', default='./data', type=str,
203 |                         help='/path/to/dataset')
204 |     # Model
205 |     parser.add_argument('--model', default='wrn-28-10', type=str)
206 |     # Optimization
207 |     parser.add_argument('--lr', default=0.1, type=float,
208 |                         help='learning rate')
209 |     parser.add_argument('--weight_decay', '-wd', default=5e-4, type=float)
210 |     parser.add_argument('--n_epochs', default=200, type=int)
211 |     parser.add_argument('--batch_size', '-bs', default=128, type=int)
212 |     parser.add_argument('--aug_lr', default=1e-3, type=float,
213 |                         help='learning rate for augmentation model')
214 |     parser.add_argument('--aug_weight_decay', '-awd', default=1e-2, type=float,
215 |                         help='weight decay for augmentation model')
216 |     # Augmentation
217 |     parser.add_argument('--g_scale', default=0.5, type=float,
218 |                         help='the search range of the magnitude of geometric augmentation')
219 |     parser.add_argument('--c_scale', default=0.8, type=float,
220 |                         help='the search range of the magnitude of color augmentation')
221 |     parser.add_argument('--group_size', default=8, type=int)
222 |     parser.add_argument('--wo_context', action='store_true',
223 |                         help='without context vector as input')
224 |     # TeachAugment
225 |     parser.add_argument('--n_inner', default=5, type=int,
226 |                         help='the number of inner-loop iterations (i.e., classifier updates per augmentation update)')
227 |     parser.add_argument('--ema_rate', default=0.999, type=float,
228 |                         help='decay rate for the ema teacher')
229 |     # Improvement techniques
230 |     parser.add_argument('--c_reg_coef', default=10, type=float,
231 |                         help='coefficient of the color regularization')
232 |     parser.add_argument('--rb_decay', default=0.9, type=float,
233 |                         help='decay rate for replay buffer')
234 |     parser.add_argument('--sampling_freq', default=10, type=int,
235 |                         help='frequency (in epochs) for storing augmentations in the replay buffer')
236 |     parser.add_argument('--epsilon', default=0.1, type=float,
237 |                         help='epsilon for the label smoothing')
238 |     # Distributed data parallel
239 |     parser.add_argument('--dist', action='store_true',
240 |                         help='use distributed data parallel')
241 |     parser.add_argument('--local_rank', default=0, type=int)
242 |     parser.add_argument('--world_size', '-ws', default=1, type=int)
243 |     parser.add_argument('--port', default=None, type=str)
244 |     # Misc
245 |     parser.add_argument('--seed', default=0, type=int)
246 |     parser.add_argument('--print_freq', default=100, type=int)
247 |     parser.add_argument('--log_dir', default='./log', type=str)
248 |     parser.add_argument('--disable_cudnn', action='store_true',
249 |                         help='disable cudnn for reproducibility')
250 |     parser.add_argument('--resume', action='store_true',
251 |                         help='resume training')
252 |     parser.add_argument('--num_workers', '-j', default=8, type=int,
253 |                         help='the number of data loading workers')
254 |     parser.add_argument('--vis', action='store_true',
255 |                         help='visualize augmented images')
256 |     parser.add_argument('--save_memory', action='store_true',
257 |                         help='calculate the adversarial loss and the teacher loss \
258 |                               independently to save memory')
259 |     parser.add_argument('--yaml', default=None, type=str,
260 |                         help='given path to .yaml, override config from the .yaml file')
261 |     parser.add_argument('--json', default=None, type=str,
262 |                         help='given path to .json, override config from the .json file')
263 | 
264 |     args = parser.parse_args()
265 | 
266 |     # override args
267 |     if args.yaml is not None:
268 |         yaml_cfg = utils.load_yaml(args.yaml)
269 |         args = utils.override_config(args, yaml_cfg)
270 |     if args.json is not None:
271 |         json_cfg = utils.load_json(args.json)
272 |         args = utils.override_config(args, json_cfg)
273 | 
274 |     utils.set_seed(args.seed)
275 |     if args.local_rank == 0:
276 |         utils.setup_logger(args.log_dir, args.resume)
277 |     if args.dist:
278 |         utils.setup_ddp(args)
279 | 
280 |     main(args)
281 | 
--------------------------------------------------------------------------------
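A note on the adversarial criterion built in `main.py`: the file `lib/losses/non_saturating_loss.py` is not included in this listing. The motivation for a non-saturating adversarial loss is that simply maximizing the cross entropy `-log p_y` gives a gradient on the true-class logit that scales with `1 - p_y`, which vanishes exactly while the classifier is still confidently correct; minimizing `-log(1 - p_y)` instead gives a gradient that scales with `p_y`. The sketch below is an illustration of that idea, not the repository's exact `NonSaturatingLoss` (in particular, it omits the label smoothing controlled by `--epsilon`):
```
import torch


def non_saturating_adv_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # probability the classifier assigns to the ground-truth class
    p_y = logits.softmax(dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    # minimizing -log(1 - p_y) pushes p_y toward 0 with informative gradients
    # while the classifier is still correct, and flattens out once it is fooled
    return -torch.log((1. - p_y).clamp_min(1e-12)).mean()
```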
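The warmup schedulers in `lib/utils/lr_scheduler.py` are stepped with a fractional epoch in `main.py`, so the learning rate is annealed every iteration rather than once per epoch. A minimal usage sketch (the toy model and the iteration count here are illustrative; the 5 warmup epochs and 200 total epochs match the CIFAR configs):
```
import torch
import torch.optim as optim

from lib.utils import lr_scheduler

model = torch.nn.Linear(10, 10)  # toy model just to drive the scheduler
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = lr_scheduler.CosineAnnealingWithLinearWarmup(optimizer, 5, 200)

iters_per_epoch = 391  # e.g., 50,000 CIFAR images at batch size 128
for epoch in range(1, 201):
    for i in range(iters_per_epoch):
        # ... forward / backward / optimizer.step() ...
        # fractional-epoch stepping, as in main.py
        scheduler.step(epoch - 1. + (i + 1.) / iters_per_epoch)
```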
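Because `checkpoint.pth` stores the target network's weights under the `'model'` key (see the checkpoint dict in `main.py`), a trained model can also be evaluated offline. A minimal sketch, assuming a CIFAR-100 WRN-28-10 run with the default `--log_dir` and `--root` (adjust these names to your setup):
```
import os

import torch
from torch.utils.data import DataLoader

from lib import augmentation, build_dataset
from lib.models import build_model
from lib.utils import utils

log_dir, dataset, arch, root = './log', 'CIFAR100', 'wrn-28-10', './data'
_, train_trans, val_trans, _ = augmentation.get_transforms(dataset)
_, eval_data, n_classes = build_dataset(dataset, root, train_trans, val_trans)

model = build_model(arch, n_classes).cuda()
checkpoint = torch.load(os.path.join(log_dir, 'checkpoint.pth'), map_location='cuda')
model.load_state_dict(checkpoint['model'])
model.eval()

acc1 = acc5 = 0.
with torch.no_grad():
    for x, y in DataLoader(eval_data, 100):
        a1, a5 = utils.accuracy(model(x.cuda()), y.cuda(), (1, 5))
        acc1 += a1 * x.size(0)  # utils.accuracy returns per-batch percentages
        acc5 += a5 * x.size(0)
print(f'Top-1 err. {100 - acc1 / len(eval_data):.2f}% | Top-5 err. {100 - acc5 / len(eval_data):.2f}%')
```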