├── ImageNetLoad.py
├── README.md
├── models
│   ├── FocalLoss.py
│   ├── LabelSmoothing.py
│   ├── OLS.py
│   └── model.py
├── plot_result.py
├── train.py
└── utils.py

/ImageNetLoad.py:
--------------------------------------------------------------------------------
from torch.utils.data import dataset
from PIL import Image
from torchvision import transforms
import numpy as np
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2


img_root = '/path/to/ImageNet/Image'
devkit = '/path/to/ImageNet/devkit/caffe_ilsvrc12'


# Equivalent torchvision pipeline, kept for reference:
# trans = {
#     'train':
#         transforms.Compose([
#             transforms.RandomResizedCrop(224),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
#         ]),
#     'val':
#         transforms.Compose([
#             transforms.Resize(256),
#             transforms.CenterCrop(224),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
#         ])}

trans = {
    'train':
        A.Compose([
            A.RandomResizedCrop(height=224, width=224),
            A.HorizontalFlip(p=0.5),
            # A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0, always_apply=False, p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ]),
    'val':
        A.Compose([
            A.Resize(height=256, width=256),
            A.CenterCrop(height=224, width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])}


class ImageNet(dataset.Dataset):
    def __init__(self, mode):
        assert mode in ['train', 'val']
        txt = os.path.join(devkit, '%s.txt' % mode)
        self.dataroot = os.path.join(img_root, mode, 'images')

        fpath = []
        labels = []
        with open(txt, 'r') as f:
            for i in f.readlines():
                fp, label = i.strip().split(' ')
                fpath.append(os.path.join(self.dataroot, fp))
                labels.append(int(label))

        self.fpath = fpath
        self.labels = labels
        self.mode = mode
        self.trans = trans[mode]

    def __getitem__(self, index):
        fp = self.fpath[index]
        label = self.labels[index]

        img = Image.open(fp).convert('RGB')

        # albumentations expects a numpy array, not a PIL image
        img = np.array(img)
        if self.trans is not None:
            img = self.trans(image=img)["image"]

        return img, label

    def __len__(self):
        return len(self.labels)


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import warnings

    # turn warnings into errors so corrupted images are surfaced during this scan
    warnings.filterwarnings('error')

    dataset = ImageNet(mode='train')
    print(len(dataset))

    loader = DataLoader(dataset=dataset,
                        batch_size=256,
                        shuffle=False,
                        num_workers=10,
                        pin_memory=True)

    for idx, (data, label) in enumerate(loader):
        print(idx)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OnlineLabelSmoothing

This is a re-implementation of Online Label Smoothing. The code is written based on my understanding of the paper. If you find any bug in my code, please tell me on the **Issues** page.

## Usage

```python
import torch
from OLS import OnlineLabelSmoothing

ols_loss = OnlineLabelSmoothing(num_classes=1000, use_gpu=True)

# Training
for epoch in range(total_epoch):
    # train()
    # test()
    ols_loss.update()

# Saving
torch.save({'ols': ols_loss.matrix.cpu().data}, 'ols.pth')
```
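For the `0.5*CE + 0.5*OLS` configuration used in the experiments below, one training epoch can be sketched as follows. This is a minimal sketch, not the full `train.py`: `model`, `loader`, `optimizer`, and `total_epoch` are placeholders for your own setup.

```python
import torch
import torch.nn as nn
from OLS import OnlineLabelSmoothing

ce_loss = nn.CrossEntropyLoss()
ols_loss = OnlineLabelSmoothing(num_classes=1000, use_gpu=True)

# model / loader / optimizer / total_epoch are placeholders for your own setup
for epoch in range(total_epoch):
    ols_loss.train()  # enable accumulation of correct-prediction statistics
    for data, target in loader:
        out = model(data)
        loss = 0.5 * ce_loss(out, target) + 0.5 * ols_loss(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    ols_loss.update()  # rebuild the soft-label matrix from this epoch's statistics
```

Note that `update()` is called once per epoch; before the first call the soft-label matrix is all zeros, so the CE term drives the first epoch.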
## Results

#### Environment

- Python 3.7
- PyTorch 1.6.0
- GPU: Tesla V100 32GB * 1

#### Other Settings

```yaml
num_classes: 1000
optimizer: SGD
init_lr: 0.1
weight_decay: 0.0001
momentum: 0.9
lr_gamma: 0.1
total_epoch: 250
batch_size: 256
num_workers: 20
random_seed: 2020
amp: True  # automatic mixed-precision training, offered natively by PyTorch
```

#### Train

- Use a single GPU:

```shell
python train.py --amp -s cos --loss ce ols --loss_w 0.5 0.5
```

- Use multiple GPUs on a single node:

```shell
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch\
--nproc_per_node=2 --master_addr 127.0.0.7 --master_port 23456\
train.py --multi-gpus 1 -nw 20 --amp -s multi --loss ce ols --loss_w 0.5 0.5
```

- Use multiple GPUs across multiple nodes:

```shell
# Limited computing resources
```
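On newer PyTorch releases, `torch.distributed.launch` is deprecated in favor of `torchrun`. An equivalent single-node launch would presumably look like this (untested sketch):

```shell
# sketch only: torchrun passes the rank via the LOCAL_RANK environment
# variable instead of the --local_rank argument, so train.py would need a
# small change to read it from os.environ
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 \
    train.py --multi-gpus 1 -nw 20 --amp -s multi --loss ce ols --loss_w 0.5 0.5
```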
#### Accuracy on the Validation Set of ImageNet2012

Although I used AMP (automatic mixed precision) to speed up training, it still took nearly five days, so I did not run any other experiments with OLS. There are other records of training ImageNet in my [blog](https://blog.csdn.net/u013347145/article/details/113175942).

| Model | Loss | Epochs | lr_schedule | Acc@1 | Acc@5 |
| ---- | ---- | ---- | ---- | ---- | ---- |
| ResNet50 | CE | 250 | Multi Step [75,150,225] | 76.32 | 93.06 |
| ResNet50 | CE | 250 | COS with 5 epochs warmup | 76.95 | 93.27 |
| ResNet50 | 0.5\*CE+0.5\*OLS | 250 | Multi Step [75,150,225] | 77.27 | 93.47 |
| ResNet50 | 0.5\*CE+0.5\*OLS | 250 | COS with 5 epochs warmup | 77.79 | 93.79 |
| ResNet50 | LS(e=0.1) | 250 | COS with 5 epochs warmup | 77.62 | 93.75 |
| ResNet50 | LS(e=0.2) | 250 | COS with 5 epochs warmup | 77.89 | 93.74 |


#### References

- [Delving Deep into Label Smoothing](https://arxiv.org/pdf/2011.12562.pdf)
- https://github.com/zhangchbin/OnlineLabelSmooth

--------------------------------------------------------------------------------
/models/FocalLoss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinaryFocalLoss(nn.Module):
    def __init__(self, gamma=2., alpha=0.25, reduce=True, logits=True):
        """
        :param gamma: focusing parameter
        :param alpha: class weight
        :param reduce: if True, return the mean loss over the batch
        :param logits: if True, x is raw logits (sigmoid has not been applied yet)
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduce = reduce
        self.logits = logits

    def forward(self, x, target):
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(x, target, reduction='none')
        else:
            bce = F.binary_cross_entropy(x, target, reduction='none')
        pt = torch.exp(-bce)
        loss = self.alpha * (1 - pt) ** self.gamma * bce

        if self.reduce:
            return torch.mean(loss)
        else:
            return loss


class FocalLoss(nn.Module):
    """
    Code: https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/focal_loss.py
    """
    def __init__(self, gamma=2., eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, target):
        logp = self.ce(x, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp

        return loss.mean()


class FocalLossv2(nn.Module):
    """
    Multi-class focal loss implementation.
    Code: https://github.com/ashawkey/FocalLoss.pytorch/blob/master/focalloss.py
    """
    def __init__(self, gamma=2, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, x, target):
        logpt = F.log_softmax(x, dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight)
        return loss


class FocalLossv3(nn.Module):
    """
    Code: https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
    """
    def __init__(self, gamma=2., alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = torch.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.data.exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt

        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()


if __name__ == '__main__':
    x = torch.randn((4, 10))
    y = torch.LongTensor([1, 2, 3, 0])

    f1 = FocalLoss()(x, y)
    print('Focal1:', f1)
    f2 = FocalLossv2()(x, y)
    print('Focal2:', f2)
    f3 = FocalLossv3()(x, y)
    print('Focal3:', f3)

--------------------------------------------------------------------------------
/models/LabelSmoothing.py:
--------------------------------------------------------------------------------
"""
Reference:
    code: https://mp.weixin.qq.com/s/qKQekaktQAhrZDMwMLOXpA
"""

import torch
import torch.nn as nn


class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.log_softmax(x, dim=-1)

        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
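
# Minimal sanity check (an addition for illustration, not part of the original
# file): with smoothing=0 the loss above reduces to plain cross entropy, so it
# should match nn.CrossEntropyLoss up to floating-point error.
if __name__ == '__main__':
    x = torch.randn((4, 10))
    y = torch.LongTensor([1, 2, 3, 0])
    ls = LabelSmoothing(smoothing=0.0)(x, y)
    ce = nn.CrossEntropyLoss()(x, y)
    print('ls(0.0):', ls.item(), 'ce:', ce.item())  # expected to agree closely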
--------------------------------------------------------------------------------
/models/OLS.py:
--------------------------------------------------------------------------------
"""
Reference:
    paper: Delving Deep into Label Smoothing.
"""
import torch
import torch.nn as nn
from collections import Counter


class OnlineLabelSmoothing(nn.Module):
    def __init__(self, num_classes=10, use_gpu=False):
        super().__init__()
        self.num_classes = num_classes
        # soft-label matrix used as the supervision target; row k is the
        # soft label for class k (all zeros before the first update())
        self.matrix = torch.zeros((num_classes, num_classes))
        # per-epoch accumulators: summed probabilities of correctly predicted
        # samples, and how many correct samples each class has seen
        self.grad = torch.zeros((num_classes, num_classes))
        self.count = torch.zeros((num_classes, 1))
        if use_gpu:
            self.matrix = self.matrix.cuda()
            self.grad = self.grad.cuda()
            self.count = self.count.cuda()

    def forward(self, x, target):
        target = target.view(-1,)
        logprobs = torch.log_softmax(x, dim=-1)

        softlabel = self.matrix[target]
        loss = (- softlabel * logprobs).sum(dim=-1)

        if self.training:
            # accumulate statistics of correctly predicted samples
            p = torch.softmax(x.detach(), dim=1)
            _, pred = torch.max(p, 1)
            correct_index = pred.eq(target)
            correct_p = p[correct_index]
            correct_label = target[correct_index].tolist()

            # use index_add_ here: `self.grad[correct_label] += correct_p` would
            # silently drop duplicate labels within a batch, because fancy-index
            # assignment does not accumulate repeated indices (see the note at
            # the end of this file)
            self.grad.index_add_(0, target[correct_index], correct_p.to(self.grad.dtype))
            for k, v in Counter(correct_label).items():
                self.count[k] += v

        return loss.mean()

    def update(self):
        index = torch.where(self.count > 0)[0]
        self.grad[index] = self.grad[index] / self.count[index]
        # reset matrix and update
        nn.init.constant_(self.matrix, 0.)
        norm = self.grad.sum(dim=1).view(-1, 1)
        index = torch.where(norm > 0)[0]
        self.matrix[index] = self.grad[index] / norm[index]
        # reset the accumulators for the next epoch
        nn.init.constant_(self.grad, 0.)
        nn.init.constant_(self.count, 0.)


if __name__ == '__main__':
    import random
    ols = OnlineLabelSmoothing(num_classes=6, use_gpu=False)
    x = torch.randn((10, 6))
    y = torch.LongTensor(random.choices(range(6), k=10))

    l = ols(x, y)
    print('ols:', l)
    ols.update()

""" Compare the time performance, 100 batches
y = torch.LongTensor(random.choices(range(100), k=256)).cuda()
count1 = torch.zeros((100, 1)).cuda()
count2 = torch.zeros((100, 1)).cuda()
t1 = time.time()
for i in range(100):
    for k, v in Counter(y.tolist()).items():
        count1[k] += v
t2 = time.time()

print(t2 - t1)
# 0.19294953346252441

t3 = time.time()
for i in range(100):
    for k in y:
        count2[k] += 1
t4 = time.time()
print(t4 - t3)
# 2.9499971866607666

print((count1-count2).sum())
"""
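
# Note on the index_add_ call in forward() (an added explanation): with
# duplicate indices, fancy-index `+=` keeps only one contribution per index.
# A minimal demonstration:
#
#   a = torch.zeros(3)
#   idx = torch.tensor([1, 1])
#   v = torch.ones(2)
#   a[idx] += v                  # a -> [0., 1., 0.]  (one duplicate is lost)
#   a.zero_()
#   a.index_add_(0, idx, v)      # a -> [0., 2., 0.]  (both are accumulated)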
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models


class BaseModel(nn.Module):
    def __init__(self, model_name, num_classes=1000, pretrained=False, metric='linear'):
        super().__init__()
        self.model_name = model_name

        # backbone = everything except the final avgpool and fc of the torchvision model
        if model_name == 'res18':
            backbone = nn.Sequential(*list(models.resnet18(pretrained=pretrained).children())[:-2])
            plane = 512
        elif model_name == 'res34':
            backbone = nn.Sequential(*list(models.resnet34(pretrained=pretrained).children())[:-2])
            plane = 512
        elif model_name == 'res50':
            backbone = nn.Sequential(*list(models.resnet50(pretrained=pretrained).children())[:-2])
            plane = 2048
        elif model_name == 'resx50':
            backbone = nn.Sequential(*list(models.resnext50_32x4d(pretrained=pretrained).children())[:-2])
            plane = 2048
        else:
            raise ValueError('model - {} is not supported'.format(model_name))

        self.backbone = backbone

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        if metric == 'linear':
            self.metric = nn.Linear(plane, num_classes)
        else:
            self.metric = None

    def forward(self, x):
        feat = self.backbone(x)
        feat_flat = self.pool(feat).view(feat.size(0), -1)
        out = self.metric(feat_flat)
        if self.training:
            # train.py unpacks `out, feat = model(data)`; the feature slot is
            # currently unused, hence the None placeholder
            return out, None
        else:
            return out


if __name__ == '__main__':
    model = BaseModel(model_name='res50').eval()
    x = torch.randn((1, 3, 224, 224))
    out = model(x)
    print(out.size())
    print(model)
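    # Also exercise the training-mode contract (a small added check): in train
    # mode, forward returns (out, feat) with feat currently None, matching how
    # train.py unpacks `out, feat = model(data)`.
    model.train()
    out_train, feat = model(x)
    print(out_train.size(), feat)  # torch.Size([1, 1000]) None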
--------------------------------------------------------------------------------
/plot_result.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import re


def str2dict(line):
    d = {}
    sp = line.strip().split()
    for s in sp:
        k, v = s.split(':')
        d[k] = float(v)
    return d


def dict_append(root, x):
    for k, v in x.items():
        if k not in root:
            root[k] = []
        root[k].append(v)
    return root


def plot(res, savepath='.', name='loss', mode='train', best=None):
    if isinstance(res, dict):
        for k, v in res.items():
            plt.plot(range(len(v)), v, label=k)
        plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel(name.capitalize())
    plt.grid()
    plt.savefig('{}/{}_{}'.format(savepath, mode, name), bbox_inches='tight')
    plt.close()


def plot_result(txt, savepath='.'):
    epoch_pattern = r'epoch:(.*?) '
    lr_pattern = r'lr:(.*?) '
    loss_pattern = r'loss\[(.*?)\]'
    acc_pattern = r'acc\[(.*?)\]'
    epoch = []
    lr = []
    train_loss = {}
    train_acc = {}
    val_loss = {}
    val_acc = {}
    with open(txt, 'r') as f:
        for i in f.readlines():
            if i[0] == '#':
                print(i)
            else:
                epoch.append(int(re.search(epoch_pattern, i).group(1)))
                lr.append(float(re.search(lr_pattern, i).group(1)))
                loss = re.findall(loss_pattern, i)
                acc = re.findall(acc_pattern, i)
                train_loss = dict_append(train_loss, str2dict(loss[0]))
                train_acc = dict_append(train_acc, str2dict(acc[0]))
                val_loss = dict_append(val_loss, str2dict(loss[1]))
                val_acc = dict_append(val_acc, str2dict(acc[1]))

    plot({'lr': lr}, savepath=savepath, name='lr')
    plot(train_loss, savepath=savepath)
    plot(train_acc, savepath=savepath, name='acc')
    plot(val_loss, savepath=savepath, mode='test')
    plot(val_acc, savepath=savepath, name='acc', mode='test')


if __name__ == '__main__':
    txt = 'log.txt'
    plot_result(txt)
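
# Expected log format (an added note for reference): plot_result() parses lines
# written by train.py, one per epoch, shaped like
#
#   epoch:0 lr:0.10000000 loss[ce:3.4567 ols:1.2345] acc[@1:0.4567 @5:0.7012] test loss[ce:2.3456 ols:1.0123] acc[@1:0.5012 @5:0.7534]
#
# (the first loss[...]/acc[...] pair is training, the second is validation),
# while lines beginning with '#' hold the final summary and are only printed.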
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import time
import random
import argparse
import numpy as np

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import torchvision

from models.model import BaseModel
from models.FocalLoss import FocalLoss
from models.LabelSmoothing import LabelSmoothing
from models.OLS import OnlineLabelSmoothing
from ImageNetLoad import ImageNet
from utils import MultiLossAverageMeter, AverageMeter, accuracy
from plot_result import plot_result

# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='res50', type=str)
parser.add_argument('--savepath', default='./Single_gpu', type=str)
parser.add_argument('--loss', default=['ce'], nargs='+', type=str)
parser.add_argument('--loss_w', default=[1.], nargs='+', type=float)
parser.add_argument('--smoothing', default=0.1, type=float)
parser.add_argument('-c', '--num_classes', default=1000, type=int)
parser.add_argument('-p', '--pool_type', default='avg', type=str)
parser.add_argument('--metric', default='linear', type=str)
parser.add_argument('--down', default=0, type=int)
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('-s', '--scheduler', default='step', type=str)
parser.add_argument('-r', '--resume', default=None, type=str)
parser.add_argument('--lr_step', default=30, type=int)
parser.add_argument('--warm', default=5, type=int)
parser.add_argument('--print_step', default=500, type=int)
parser.add_argument('--lr_gamma', default=0.1, type=float)
parser.add_argument('--total_epoch', default=250, type=int)
parser.add_argument('-bs', '--batch_size', default=256, type=int)
parser.add_argument('-nw', '--num_workers', default=20, type=int)
parser.add_argument('--multi-gpus', default=0, type=int)
parser.add_argument('--seed', default=2020, type=int)
parser.add_argument('--pretrained', default=0, type=int)
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--sync_bn', default=False, action='store_true')
parser.add_argument('--amp', default=False, action='store_true')

args = parser.parse_args()
print('local_rank:', args.local_rank)

ce_based_loss = ['ce', 'ls', 'fl', 'ols']


def loss_func(x, target, feat=None, training=False):
    # combine the configured losses as a weighted sum; also return the
    # detached per-loss values for logging
    loss_dict = {}
    loss_value = 0.
    for l, w in zip(args.loss, args.loss_w):
        if training:
            criterion[l].train()
        else:
            criterion[l].eval()

        loss = w * criterion[l](x, target)
        loss_value += loss
        loss_dict[l] = loss.detach().cpu().item()

    return loss_dict, loss_value


def train(epoch):
    model.train()

    loss_meter = MultiLossAverageMeter(args.loss)
    top1 = AverageMeter('Acc@1', ':.2f')
    top5 = AverageMeter('Acc@5', ':.2f')
    t1 = time.time()
    s1 = time.time()
    for idx, (data, labels) in enumerate(trainloader):
        if multi_gpus:
            data, labels = data.cuda(non_blocking=True), labels.long().cuda(non_blocking=True)
        else:
            data, labels = data.to(device), labels.long().to(device)

        optimizer.zero_grad()

        # AMP
        if args.amp:
            with torch.cuda.amp.autocast():
                out, feat = model(data)
                loss_dict, loss = loss_func(out, labels, feat, training=True)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out, feat = model(data)
            loss_dict, loss = loss_func(out, labels, feat, training=True)

            loss.backward()
            optimizer.step()

        loss_meter.update(loss_dict, data.size(0))
        acc1, acc5 = accuracy(out, labels, topk=(1, 5))
        top1.update(acc1.item(), data.size(0))
        top5.update(acc5.item(), data.size(0))

        if idx % args.print_step == 0:
            s2 = time.time()
            print(f'rank:{args.local_rank} epoch[{epoch:>3}/{args.total_epoch}] idx[{idx:>3}/{len(trainloader)}] loss[{loss_meter}] acc[@1:{top1.avg:.4f} @5:{top5.avg:.4f}] time:{s2 - s1:.2f}s')
            s1 = time.time()

    if args.local_rank == 0:
        print('=' * 30)
        print(f'rank:{args.local_rank} train loss[{loss_meter}] acc[@1:{top1.avg:.4f} @5:{top5.avg:.4f}] time:{time.time() - t1:.2f}s')

    if args.local_rank == 0:
        with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
            f.write('epoch:{} lr:{:.8f} loss[{}] acc[@1:{:.4f} @5:{:.4f}] '.format(epoch, optimizer.param_groups[0]['lr'], loss_meter, top1.avg, top5.avg))
def test(epoch):
    model.eval()

    loss_meter = MultiLossAverageMeter(args.loss)
    top1 = AverageMeter('Acc@1', ':.2f')
    top5 = AverageMeter('Acc@5', ':.2f')
    with torch.no_grad():
        for idx, (data, labels) in enumerate(valloader):
            data, labels = data.to(device), labels.long().to(device)
            out = model(data)
            loss_dict, loss = loss_func(out, labels, training=False)

            loss_meter.update(loss_dict, data.size(0))
            acc1, acc5 = accuracy(out, labels, topk=(1, 5))
            top1.update(acc1.item(), data.size(0))
            top5.update(acc5.item(), data.size(0))

    print(f'rank:{args.local_rank} test loss[{loss_meter}] acc[@1:{top1.avg:.4f} @5:{top5.avg:.4f}]', end=' ')

    global best_acc, best_epoch

    if isinstance(model, nn.parallel.distributed.DistributedDataParallel):
        state = {
            'net': model.module.state_dict(),
            'acc': top1.avg,
            'epoch': epoch,
        }
    else:
        state = {
            'net': model.state_dict(),
            'acc': top1.avg,
            'epoch': epoch,
        }

    if 'ols' in args.loss:
        # checkpoint the learned soft-label matrix together with the weights
        state['ols'] = criterion['ols'].matrix.cpu().data

    if top1.avg > best_acc:
        best_acc = top1.avg
        best_epoch = epoch
        torch.save(state, os.path.join(savepath, 'best.pth'))
        print('*')
    else:
        print()

    torch.save(state, os.path.join(savepath, 'last.pth'))

    with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
        f.write('test loss[{}] acc[@1:{:.4f} @5:{:.4f}]\n'.format(loss_meter, top1.avg, top5.avg))


if __name__ == '__main__':
    best_epoch = 0
    best_acc = 0.
    use_gpu = False
    multi_gpus = False

    start_epoch = 0
    total = args.total_epoch

    if args.seed is not None:
        print('use random seed:', args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
        # note: runs are still not bit-for-bit reproducible, because
        # cudnn.deterministic stays False and cudnn.benchmark is enabled below
        cudnn.deterministic = False

    if torch.cuda.is_available():
        use_gpu = True
        cudnn.benchmark = True

    if torch.cuda.device_count() > 1 and args.multi_gpus:
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(args.local_rank)
        multi_gpus = True

    # loss
    criterion = {
        'ce': nn.CrossEntropyLoss(),
        'fl': FocalLoss(),
        'ls': LabelSmoothing(smoothing=args.smoothing),
        'ols': OnlineLabelSmoothing(num_classes=args.num_classes, use_gpu=use_gpu)
    }

    # dataset
    trainset = ImageNet(mode='train')
    valset = ImageNet(mode='val')

    # dataloader
    train_sampler = None
    if multi_gpus:
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)

    trainloader = DataLoader(dataset=trainset,
                             batch_size=args.batch_size,
                             shuffle=(train_sampler is None),
                             sampler=train_sampler,
                             num_workers=args.num_workers,
                             pin_memory=True)

    valloader = DataLoader(dataset=valset,
                           batch_size=args.batch_size,
                           shuffle=False,
                           num_workers=args.num_workers,
                           pin_memory=True)

    # model
    model = BaseModel(model_name=args.model_name,
                      num_classes=args.num_classes,
                      pretrained=args.pretrained)

    if args.resume:
        state = torch.load(args.resume)
        print('Resume from:{}'.format(args.resume))
        model.load_state_dict(state['net'], strict=False)
        best_acc = state['acc']
        start_epoch = state['epoch'] + 1
        if 'ols' in args.loss:
            criterion['ols'].matrix = state['ols'].cuda()

    # sync_bn
    if args.sync_bn and multi_gpus:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        print('Using SyncBatchNorm')

    if multi_gpus:
        device = torch.device("cuda", args.local_rank)
        model = model.to(device)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        device = ('cuda:%d' % args.local_rank if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
    print('Device:', device)
    # optim
    optimizer = torch.optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.lr}],
        weight_decay=args.weight_decay, momentum=args.momentum)

    if args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma, last_epoch=-1)
    elif args.scheduler == 'multi':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[75, 150, 225], gamma=args.lr_gamma, last_epoch=-1)
    elif args.scheduler == 'cos':
        # linear warmup for the first `warm` epochs, then cosine decay to 0;
        # using `<` (not `<=`) so the warmup factor never overshoots 1.0
        warm_up_step = args.warm
        lambda_ = lambda epoch: (epoch + 1) / warm_up_step if epoch < warm_up_step else 0.5 * (
            np.cos((epoch - warm_up_step) / (args.total_epoch - warm_up_step) * np.pi) + 1)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_)
    else:
        raise ValueError('No such scheduler - {}'.format(args.scheduler))

    # savepath
    loss_str = '_'.join(args.loss)
    if 'ls' in args.loss:
        loss_str += str(args.smoothing)
    savepath = os.path.join(args.savepath, '{}_{}_{}_{}_{}'.format(args.model_name,
                                                                   args.pool_type,
                                                                   args.metric,
                                                                   str(args.down),
                                                                   loss_str))
    # AMP
    if args.amp:
        scaler = torch.cuda.amp.GradScaler()
        print('Using automatic mixed precision (AMP).')
        savepath += '_amp'

    if args.sync_bn:
        savepath += '_syncbn'

    savepath += args.scheduler

    if args.local_rank == 0:
        print('Init_lr={}, Weight_decay={}, Momentum={}'.format(args.lr, args.weight_decay, args.momentum))
        print('Loss:', args.loss)
        print('Loss_weight:', args.loss_w)
        print('Using {} scheduler'.format(args.scheduler))
        print('Savepath:', savepath)

    os.makedirs(savepath, exist_ok=True)

    if args.local_rank == 0 and args.resume is None:
        with open(os.path.join(savepath, 'setting.txt'), 'w') as f:
            for k, v in vars(args).items():
                f.write('{}:{}\n'.format(k, v))

        f = open(os.path.join(savepath, 'log.txt'), 'w')
        f.close()

    start = time.time()

    for epoch in range(start_epoch, total):
        train(epoch)
        scheduler.step()
        if args.local_rank == 0:
            test(epoch)
        if 'ols' in args.loss:
            # rebuild the OLS soft-label matrix from this epoch's statistics
            criterion['ols'].update()

    end = time.time()
    if args.local_rank == 0:
        print('total time:{}m{:.2f}s'.format(int((end - start) // 60), (end - start) % 60))
        print('best_epoch:', best_epoch)
        print('best_acc:', best_acc)
        with open(os.path.join(savepath, 'log.txt'), 'a+') as f:
            f.write('# best_acc:{:.4f}, best_epoch:{}'.format(best_acc, best_epoch))

    plot_result(txt=os.path.join(savepath, 'log.txt'), savepath=savepath)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from __future__ import division, absolute_import
import torch


def accuracy(out, target, topk=(1,)):
    """Compute top-k accuracy; each entry is a fraction in [0, 1]."""
    with torch.no_grad():
        maxk = max(topk)
        bs = target.size(0)

        _, pred = out.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k / bs)
        return res
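
# Example (an added note): for a batch of 4 samples where 3 top-1 predictions
# are correct, accuracy(out, target, topk=(1, 5)) returns 1-element tensors
# such as [tensor([0.75]), tensor([1.])]; train.py calls .item() on these and
# logs the fractions directly (acc@1:0.7500 in the log corresponds to 75%).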

class AverageMeter(object):
    def __init__(self, name, fmt):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0.

    def update(self, val, bs):
        self.val = val
        self.sum += val * bs
        self.count += bs
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. 'Acc@1 0.75 (0.73)' for fmt=':.2f'
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class MultiLossAverageMeter(object):
    def __init__(self, loss):
        """
        :param loss: list, names of losses
        """
        self.dict = {}
        self.loss = loss
        self.reset()

    def reset(self):
        for l in self.loss:
            self.dict[l] = 0.
        self.all = 0.
        self.count = 0.

    def update(self, val, bs=1):
        """
        :param val: dict, per-loss mean values for the current batch
        :param bs: batch size
        """
        count = self.count
        self.count += bs
        for k, v in self.dict.items():
            # running mean weighted by sample count; val[k] is a batch mean,
            # so it must be weighted by bs before merging
            self.dict[k] = (v * count + val[k] * bs) / self.count

    def __repr__(self):
        fmtstr = ''
        for k, v in self.dict.items():
            fmtstr += '{}:{:.4f} '.format(k, v)

        return fmtstr.rstrip()
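
if __name__ == '__main__':
    # Minimal self-test of the meters (an addition for illustration): running
    # averages over two batches of different sizes.
    m = AverageMeter('Acc@1', ':.2f')
    m.update(0.50, bs=2)
    m.update(1.00, bs=2)
    print(m)  # Acc@1 1.00 (0.75)

    ml = MultiLossAverageMeter(['ce', 'ols'])
    ml.update({'ce': 2.0, 'ols': 1.0}, bs=2)
    ml.update({'ce': 1.0, 'ols': 0.5}, bs=6)
    print(ml)  # ce:1.2500 ols:0.6250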