├── README.md
├── __init__.py
├── autoaugment.py
├── focal_loss.py
├── labelsmoothing.py
├── mixup.py
├── randomerasing.py
└── scheduler.py

/README.md:
--------------------------------------------------------------------------------
1 | # Simple_Tool_Pytorch
2 | 
3 | ## Getting Started
4 | ```Shell
5 | cd <your_project_dir>
6 | git clone https://github.com/cjf8899/simple_tool_pytorch.git
7 | 
8 | ```
9 | ## Summary
10 | - [Auto-Augment](https://github.com/cjf8899/simple_tool_pytorch#auto-augment)
11 | - [Warmup-Cosine-Lr](https://github.com/cjf8899/simple_tool_pytorch#warmup-cosine-lr)
12 | - [Mixup](https://github.com/cjf8899/simple_tool_pytorch#mixup)
13 | - [Label-Smoothing](https://github.com/cjf8899/simple_tool_pytorch#label-smoothing)
14 | - [Random-erasing-augmentation](https://github.com/cjf8899/simple_tool_pytorch#random-erasing-augmentation)
15 | - [Focal-Loss](https://github.com/cjf8899/simple_tool_pytorch#focal-loss)
16 | 
17 | 
18 | ## Auto-Augment
19 | 
20 | ```python
21 | from simple_tool_pytorch import ImageNetPolicy, CIFAR10Policy, SVHNPolicy
22 | ...
23 | 
24 | data = ImageFolder(rootdir, transform=transforms.Compose(
25 |     [transforms.RandomResizedCrop(256),
26 |      transforms.RandomHorizontalFlip(),
27 |      ImageNetPolicy(),  # or CIFAR10Policy(), SVHNPolicy()
28 |      transforms.ToTensor(),
29 |      transforms.Normalize(...)]))
30 | loader = DataLoader(data, ...)
31 | ...
32 | ```
33 | source : https://github.com/DeepVoltaire/AutoAugment
34 | 
35 | ## Warmup-Cosine-Lr
36 | ```python
37 | from simple_tool_pytorch import GradualWarmupScheduler
38 | ...
39 | 
40 | criterion = nn.CrossEntropyLoss()
41 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
42 | cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0, last_epoch=-1)
43 | scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5, after_scheduler=cosine_scheduler)
44 | 
45 | for epoch in range(args.epochs):
46 |     for i, (images, labels) in enumerate(train_data):
47 |         ...
48 |     scheduler.step(epoch)  # step once per epoch, after the batch loop
49 | ```
50 | source : https://github.com/seominseok0429/pytorch-warmup-cosine-lr
51 | 
52 | ## Mixup
53 | 
54 | ```python
55 | from simple_tool_pytorch import mixup_data, mixup_criterion
56 | ...
57 | 
58 | alpha = 0.2  # Beta distribution parameter; 0.2 is recommended.
59 | criterion = torch.nn.CrossEntropyLoss()
60 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
61 | 
62 | for i, (images, labels) in enumerate(train_data):
63 |     images = images.cuda()
64 |     labels = labels.cuda()
65 | 
66 |     data, labels_a, labels_b, lam = mixup_data(images, labels, alpha)
67 |     optimizer.zero_grad()
68 |     outputs = model(data)  # forward pass on the mixed images
69 |     loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
70 | 
71 |     loss.backward()
72 |     optimizer.step()
73 |     ...
74 | ```
75 | 
76 | ## Label-Smoothing
77 | 
78 | ```python
79 | from simple_tool_pytorch import LabelSmoothingCrossEntropy
80 | ...
81 | 
82 | criterion = LabelSmoothingCrossEntropy()
83 | ...
84 | 
85 | for i, (images, labels) in enumerate(train_data):
86 |     ...
87 | 
88 |     loss = criterion(outputs, labels)
89 |     loss.backward()
90 |     optimizer.step()
91 |     ...
92 | ```
93 | 
94 | source : https://github.com/seominseok0429/label-smoothing-visualization-pytorch
95 | 
96 | ## Random-erasing-augmentation
97 | 
98 | ```python
99 | from simple_tool_pytorch import RandomErasing
100 | ...
101 | 
102 | erasing_percent = 0.5
103 | data = ImageFolder(rootdir, transform=transforms.Compose(
104 |     [transforms.RandomHorizontalFlip(),
105 |      transforms.ToTensor(),
106 |      transforms.Normalize((0.4914, 0.4822, 0.4465), (...)),
107 |      RandomErasing(probability=erasing_percent, mean=[0.4914, 0.4822, 0.4465])]))
108 | loader = DataLoader(data, ...)
109 | 
110 | ```
111 | 
112 | ## Focal-Loss
113 | Example: using FocalLoss inside the MultiBox loss of an SSD detector.
114 | 
115 | ```python
116 | from simple_tool_pytorch import FocalLoss
117 | ...
118 | 
119 | ...
120 | pos_idx = pos.unsqueeze(2).expand_as(conf_data)
121 | neg_idx = neg.unsqueeze(2).expand_as(conf_data)
122 | 
123 | 
124 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
125 | targets_weighted = conf_t[(pos+neg).gt(0)]
126 | 
127 | ### Focal loss
128 | compute_c_loss = FocalLoss(alpha=None, gamma=2, class_num=num_classes, size_average=False)
129 | loss_c = compute_c_loss(conf_p, targets_weighted)
130 | 
131 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
132 | 
133 | N = num_pos.data.sum()
134 | loss_l /= N
135 | loss_c /= N
136 | return loss_l, loss_c
137 | 
138 | 
139 | ```
140 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .scheduler import *
2 | from .autoaugment import *
3 | from .mixup import *
4 | from .labelsmoothing import *
5 | from .randomerasing import *
6 | from .focal_loss import *
7 | 
--------------------------------------------------------------------------------
/autoaugment.py:
--------------------------------------------------------------------------------
1 | 
2 | from PIL import Image, ImageEnhance, ImageOps
3 | import numpy as np
4 | import random
5 | 
6 | 
7 | class ImageNetPolicy(object):
8 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
9 | Example: 10 | >>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | Example as a PyTorch Transform: 13 | >>> transform=transforms.Compose([ 14 | >>> transforms.Resize(256), 15 | >>> ImageNetPolicy(), 16 | >>> transforms.ToTensor()]) 17 | """ 18 | def __init__(self, fillcolor=(128, 128, 128)): 19 | self.policies = [ 20 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 21 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 22 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 23 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 24 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 25 | 26 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 27 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 28 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 29 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 30 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 31 | 32 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 33 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 34 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 35 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 37 | 38 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 39 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 40 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 41 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 42 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 43 | 44 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 45 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 46 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 47 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 48 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 49 | ] 50 | 51 | 52 | def __call__(self, img): 53 | policy_idx = random.randint(0, len(self.policies) - 1) 54 | return self.policies[policy_idx](img) 55 | 56 | def __repr__(self): 57 | return "AutoAugment ImageNet Policy" 58 | 59 | 60 | class CIFAR10Policy(object): 61 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
62 | Example: 63 | >>> policy = CIFAR10Policy() 64 | >>> transformed = policy(image) 65 | Example as a PyTorch Transform: 66 | >>> transform=transforms.Compose([ 67 | >>> transforms.Resize(256), 68 | >>> CIFAR10Policy(), 69 | >>> transforms.ToTensor()]) 70 | """ 71 | def __init__(self, fillcolor=(128, 128, 128)): 72 | self.policies = [ 73 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 74 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 75 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 76 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 77 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 78 | 79 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 80 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 81 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 82 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 83 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 84 | 85 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 86 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 87 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 88 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 89 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 90 | 91 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 92 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 93 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 94 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 95 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 96 | 97 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 98 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 99 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 100 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 101 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 102 | ] 103 | 104 | 105 | def __call__(self, img): 106 | policy_idx = random.randint(0, len(self.policies) - 1) 107 | return self.policies[policy_idx](img) 108 | 109 | def __repr__(self): 110 | return "AutoAugment CIFAR10 Policy" 111 | 112 | 113 | class SVHNPolicy(object): 114 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
115 | Example:
116 | >>> policy = SVHNPolicy()
117 | >>> transformed = policy(image)
118 | Example as a PyTorch Transform:
119 | >>> transform=transforms.Compose([
120 | >>> transforms.Resize(256),
121 | >>> SVHNPolicy(),
122 | >>> transforms.ToTensor()])
123 | """
124 | def __init__(self, fillcolor=(128, 128, 128)):
125 | self.policies = [
126 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
127 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
128 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
129 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
130 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
131 | 
132 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
133 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
134 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
135 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
136 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
137 | 
138 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
139 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
140 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
141 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
142 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
143 | 
144 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
145 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
146 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
147 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
148 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
149 | 
150 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
151 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
152 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
153 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
154 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
155 | ]
156 | 
157 | 
158 | def __call__(self, img):
159 | policy_idx = random.randint(0, len(self.policies) - 1)
160 | return self.policies[policy_idx](img)
161 | 
162 | def __repr__(self):
163 | return "AutoAugment SVHN Policy"
164 | 
165 | 
166 | class SubPolicy(object):
167 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
168 | ranges = {
169 | "shearX": np.linspace(0, 0.3, 10),
170 | "shearY": np.linspace(0, 0.3, 10),
171 | "translateX": np.linspace(0, 150 / 331, 10),
172 | "translateY": np.linspace(0, 150 / 331, 10),
173 | "rotate": np.linspace(0, 30, 10),
174 | "color": np.linspace(0.0, 0.9, 10),
175 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # built-in int; the np.int alias was removed in NumPy 1.24
176 | "solarize": np.linspace(256, 0, 10),
177 | "contrast": np.linspace(0.0, 0.9, 10),
178 | "sharpness": np.linspace(0.0, 0.9, 10),
179 | "brightness": np.linspace(0.0, 0.9, 10),
180 | "autocontrast": [0] * 10,
181 | "equalize": [0] * 10,
182 | "invert": [0] * 10
183 | }
184 | 
185 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
186 | def rotate_with_fill(img, magnitude):
187 | rot = img.convert("RGBA").rotate(magnitude)
188 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
189 | 
190 | func = {
191 | "shearX": lambda img, magnitude: img.transform(
192 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
193 | Image.BICUBIC, fillcolor=fillcolor),
194 | "shearY": lambda img, magnitude: img.transform(
195 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
196 | Image.BICUBIC, fillcolor=fillcolor),
197 | "translateX": lambda img, magnitude: img.transform(
198 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
199 | fillcolor=fillcolor),
200 | "translateY": lambda img, magnitude: img.transform(
201 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
202 | fillcolor=fillcolor),
203 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
204 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
205 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
206 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
207 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
208 | 1 + magnitude * random.choice([-1, 1])),
209 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
210 | 1 + magnitude * random.choice([-1, 1])),
211 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
212 | 1 + magnitude * random.choice([-1, 1])),
213 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
214 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
215 | "invert": lambda img, magnitude: ImageOps.invert(img)
216 | }
217 | 
218 | self.p1 = p1
219 | self.operation1 = func[operation1]
220 | self.magnitude1 = ranges[operation1][magnitude_idx1]
221 | self.p2 = p2
222 | self.operation2 = func[operation2]
223 | self.magnitude2 = ranges[operation2][magnitude_idx2]
224 | 
225 | 
226 | def __call__(self, img):
227 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
228 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
229 | return img
230 | 
--------------------------------------------------------------------------------
/focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | class FocalLoss(nn.Module):
7 | r"""
8 | This criterion is an implementation of Focal Loss, which is proposed in
9 | Focal Loss for Dense Object Detection.
10 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
11 | The losses are averaged across observations for each minibatch.
12 | Args:
13 | alpha(1D Tensor, Variable) : the scalar factor for this criterion
14 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
15 | putting more focus on hard, misclassified examples
16 | size_average(bool): By default, the losses are averaged over observations for each minibatch.
17 | However, if the field size_average is set to False, the losses are
18 | instead summed for each minibatch.
19 | """
20 | def __init__(self, alpha, gamma=2, class_num=5, size_average=False):
21 | super(FocalLoss, self).__init__()
22 | if alpha is None:
23 | self.alpha = Variable(torch.ones(class_num, 1))
24 | else:
25 | if isinstance(alpha, Variable):
26 | self.alpha = alpha
27 | else:
28 | self.alpha = Variable(alpha)
29 | 
30 | self.gamma = gamma
31 | 
32 | # self.class_num = class_num
33 | self.size_average = size_average
34 | 
35 | def forward(self, inputs, targets):
36 | N = inputs.size(0) # batch_size
37 | C = inputs.size(1) # number of classes
38 | P = F.softmax(inputs, dim=1)
39 | 
40 | class_mask = inputs.data.new(N, C).fill_(0)
41 | class_mask = Variable(class_mask)
42 | ids = targets.view(-1, 1)
43 | class_mask.scatter_(1, ids.data, 1.)
44 | # print(class_mask)
45 | 
46 | if inputs.is_cuda and not self.alpha.is_cuda:
47 | self.alpha = self.alpha.cuda()
48 | alpha = self.alpha[ids.data.view(-1)]
49 | 
50 | probs = (P*class_mask).sum(1).view(-1, 1)
51 | 
52 | log_p = probs.log()
53 | # print('probs size= {}'.format(probs.size()))
54 | # print(probs)
55 | 
56 | batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
57 | 
58 | # print('-----batch_loss------')
59 | # print(batch_loss)
60 | 
61 | if self.size_average:
62 | loss = batch_loss.mean()
63 | else:
64 | loss = batch_loss.sum()
65 | return loss
66 | 
--------------------------------------------------------------------------------
/labelsmoothing.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | import torch.nn.functional as F
10 | from pathlib import Path
11 | 
12 | class LabelSmoothingCrossEntropy(nn.Module):
13 | def __init__(self):
14 | super(LabelSmoothingCrossEntropy, self).__init__()
15 | def forward(self, x, target, smoothing=0.1):
16 | confidence = 1. - smoothing
17 | logprobs = F.log_softmax(x, dim=-1)
18 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
19 | nll_loss = nll_loss.squeeze(1)
20 | smooth_loss = -logprobs.mean(dim=-1)
21 | loss = confidence * nll_loss + smoothing * smooth_loss
22 | return loss.mean()
23 | 
24 | def get_mean_and_std(dataset):
25 | '''Compute the mean and std value of dataset.'''
26 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
27 | mean = torch.zeros(3)
28 | std = torch.zeros(3)
29 | print('==> Computing mean and std..')
30 | for inputs, targets in dataloader:
31 | for i in range(3):
32 | mean[i] += inputs[:,i,:,:].mean()
33 | std[i] += inputs[:,i,:,:].std()
34 | mean.div_(len(dataset))
35 | std.div_(len(dataset))
36 | return mean, std
37 | 
38 | def init_params(net):
39 | '''Init layer parameters.'''
40 | for m in net.modules():
41 | if isinstance(m, nn.Conv2d):
42 | init.kaiming_normal_(m.weight, mode='fan_out')
43 | if m.bias is not None:
44 | init.constant_(m.bias, 0)
45 | elif isinstance(m, nn.BatchNorm2d):
46 | init.constant_(m.weight, 1)
47 | init.constant_(m.bias, 0)
48 | elif isinstance(m, nn.Linear):
49 | init.normal_(m.weight, std=1e-3)
50 | if m.bias is not None:
51 | init.constant_(m.bias, 0)
52 | 
53 | 
54 | _, term_width = os.popen('stty size', 'r').read().split()
55 | term_width = int(term_width)
56 | 
57 | TOTAL_BAR_LENGTH = 65.
58 | last_time = time.time()
59 | begin_time = last_time
60 | def progress_bar(current, total, msg=None):
61 | global last_time, begin_time
62 | if current == 0:
63 | begin_time = time.time() # Reset for new bar.
64 | 65 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 66 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 67 | 68 | sys.stdout.write(' [') 69 | for i in range(cur_len): 70 | sys.stdout.write('=') 71 | sys.stdout.write('>') 72 | for i in range(rest_len): 73 | sys.stdout.write('.') 74 | sys.stdout.write(']') 75 | 76 | cur_time = time.time() 77 | step_time = cur_time - last_time 78 | last_time = cur_time 79 | tot_time = cur_time - begin_time 80 | 81 | L = [] 82 | L.append(' Step: %s' % format_time(step_time)) 83 | L.append(' | Tot: %s' % format_time(tot_time)) 84 | if msg: 85 | L.append(' | ' + msg) 86 | 87 | msg = ''.join(L) 88 | sys.stdout.write(msg) 89 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 90 | sys.stdout.write(' ') 91 | 92 | # Go back to the center of the bar. 93 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 94 | sys.stdout.write('\b') 95 | sys.stdout.write(' %d/%d ' % (current+1, total)) 96 | 97 | if current < total-1: 98 | sys.stdout.write('\r') 99 | else: 100 | sys.stdout.write('\n') 101 | sys.stdout.flush() 102 | 103 | def format_time(seconds): 104 | days = int(seconds / 3600/24) 105 | seconds = seconds - days*3600*24 106 | hours = int(seconds / 3600) 107 | seconds = seconds - hours*3600 108 | minutes = int(seconds / 60) 109 | seconds = seconds - minutes*60 110 | secondsf = int(seconds) 111 | seconds = seconds - secondsf 112 | millis = int(seconds*1000) 113 | 114 | f = '' 115 | i = 1 116 | if days > 0: 117 | f += str(days) + 'D' 118 | i += 1 119 | if hours > 0 and i <= 2: 120 | f += str(hours) + 'h' 121 | i += 1 122 | if minutes > 0 and i <= 2: 123 | f += str(minutes) + 'm' 124 | i += 1 125 | if secondsf > 0 and i <= 2: 126 | f += str(secondsf) + 's' 127 | i += 1 128 | if millis > 0 and i <= 2: 129 | f += str(millis) + 'ms' 130 | i += 1 131 | if f == '': 132 | f = '0ms' 133 | return f 134 | 135 | def save_model(model, model_path): 136 | if isinstance(model_path, Path): 137 | model_path = str(model_path) 138 | if isinstance(model, nn.DataParallel): 139 | model = model.module 140 | state_dict = model.state_dict() 141 | for key in state_dict: 142 | state_dict[key] = state_dict[key].cpu() 143 | torch.save(state_dict, model_path) 144 | -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | __all__ = ['mixup_data', 'mixup_criterion'] 6 | 7 | @torch.no_grad() 8 | def mixup_data(x, y, alpha=0.2): 9 | """Returns mixed inputs, pairs of targets, and lambda 10 | """ 11 | if alpha > 0: 12 | lam = np.random.beta(alpha, alpha) 13 | else: 14 | lam = 1 15 | 16 | mixed_x = lam * x + (1 - lam) * x.flip(dims=(0,)) 17 | y_a, y_b = y, y.flip(dims=(0,)) 18 | return mixed_x, y_a, y_b, lam 19 | 20 | 21 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 22 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 23 | -------------------------------------------------------------------------------- /randomerasing.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import random 4 | 5 | 6 | class RandomErasing(object): 7 | """ Randomly selects a rectangle region in an image and erases its pixels. 8 | 'Random Erasing Data Augmentation' by Zhong et al. 9 | See https://arxiv.org/pdf/1708.04896.pdf 10 | Args: 11 | probability: The probability that the Random Erasing operation will be performed. 
12 | sl: Minimum proportion of erased area against input image. 13 | sh: Maximum proportion of erased area against input image. 14 | r1: Minimum aspect ratio of erased area. 15 | mean: Erasing value. 16 | """ 17 | 18 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 19 | self.probability = probability 20 | self.mean = mean 21 | self.sl = sl 22 | self.sh = sh 23 | self.r1 = r1 24 | 25 | def __call__(self, img): 26 | 27 | if random.uniform(0, 1) >= self.probability: 28 | return img 29 | 30 | for attempt in range(100): 31 | area = img.size()[1] * img.size()[2] 32 | 33 | target_area = random.uniform(self.sl, self.sh) * area 34 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 35 | 36 | h = int(round(math.sqrt(target_area * aspect_ratio))) 37 | w = int(round(math.sqrt(target_area / aspect_ratio))) 38 | 39 | if w < img.size()[2] and h < img.size()[1]: 40 | x1 = random.randint(0, img.size()[1] - h) 41 | y1 = random.randint(0, img.size()[2] - w) 42 | if img.size()[0] == 3: 43 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 44 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 45 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 46 | else: 47 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 48 | return img 49 | 50 | return img 51 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | import torch 5 | import matplotlib.pyplot as plt 6 | 7 | class GradualWarmupScheduler(_LRScheduler): 8 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 9 | self.multiplier = multiplier 10 | self.total_epoch = total_epoch 11 | self.after_scheduler = after_scheduler 12 | self.finished = False 13 | super().__init__(optimizer) 14 | 15 | def get_lr(self): 16 | if self.last_epoch > self.total_epoch: 17 | if self.after_scheduler: 18 | if not self.finished: 19 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 20 | self.finished = True 21 | return self.after_scheduler.get_lr() 22 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 23 | 24 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 25 | 26 | 27 | def step(self, epoch=None, metrics=None): 28 | if self.finished and self.after_scheduler: 29 | if epoch is None: 30 | self.after_scheduler.step(None) 31 | else: 32 | self.after_scheduler.step(epoch - self.total_epoch) 33 | else: 34 | return super(GradualWarmupScheduler, self).step(epoch) 35 | 36 | if __name__ == '__main__': 37 | v = torch.zeros(10) 38 | optim = torch.optim.SGD([v], lr=0.01) 39 | cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100, eta_min=0, last_epoch=-1) 40 | scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler) 41 | a = [] 42 | b = [] 43 | for epoch in range(1, 100): 44 | scheduler.step(epoch) 45 | a.append(epoch) 46 | b.append(optim.param_groups[0]['lr']) 47 | print(epoch, optim.param_groups[0]['lr']) 48 | 49 | plt.plot(a,b) 50 | plt.show() 51 | --------------------------------------------------------------------------------