├── .gitignore ├── CE.npy ├── CE_tar.npy ├── LS.npy ├── LS_tar.npy ├── README.md ├── TSEN.py ├── assets ├── TSNE_CrossEntropy.png └── TSNE_LabelSmoothing.png ├── checkpoint ├── CrossEntropy.bin └── LabelSmoothing.bin ├── main.py ├── resnet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | -------------------------------------------------------------------------------- /CE.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/CE.npy -------------------------------------------------------------------------------- /CE_tar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/CE_tar.npy -------------------------------------------------------------------------------- /LS.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/LS.npy -------------------------------------------------------------------------------- /LS_tar.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/LS_tar.npy -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # When Does Label Smoothing Help??? pytorch implementation 2 | 3 | paper : https://arxiv.org/abs/1906.02629 4 | 5 |
6 | 7 | Cross Entropy : **python main.py --ce** -> **python TSEN.py --ce** 8 | 9 | Label Smoothing : **python main.py** -> **python TSEN.py** 10 | 11 |
12 | 13 | simple Label Smoothing implementation code. 14 | 15 | ```python 16 | 17 | class LabelSmoothingCrossEntropy(nn.Module): 18 | def __init__(self): 19 | super(LabelSmoothingCrossEntropy, self).__init__() 20 | def forward(self, x, target, smoothing=0.1): 21 | confidence = 1. - smoothing 22 | logprobs = F.log_softmax(x, dim=-1) 23 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 24 | nll_loss = nll_loss.squeeze(1) 25 | smooth_loss = -logprobs.mean(dim=-1) 26 | loss = confidence * nll_loss + smoothing * smooth_loss 27 | return loss.mean() 28 | ``` 29 | ```python 30 | from utils import LabelSmoothingCrossEntropy 31 | 32 | criterion = LabelSmoothingCrossEntropy() 33 | loss = criterion(outputs, targets) 34 | loss.backward() 35 | optimizer.step() 36 | ``` 37 |
38 | 39 | 40 | Visualized using the t-SNE algorithm on the CIFAR10 dataset. As mentioned in "When Does Label Smoothing Help?", label smoothing makes the classes separate more clearly. 41 | 42 |
43 | 44 | 45 |
'''Visualize ResNet18 features on the CIFAR10 test set with t-SNE.

Loads one of the checkpoints under ./checkpoint (cross entropy with ``--ce``,
label smoothing otherwise), replaces the final ``linear`` layer with
``nn.Flatten`` so the network returns the features that feed the classifier,
saves features and labels as ``.npy`` files and writes a 2-D scatter plot.
'''
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import resnet as RN
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import argparse


def parse_args():
    '''Parse the command line; ``--ce`` selects the cross-entropy run.'''
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--ce', action='store_true', help='Cross Entropy use')
    return parser.parse_args()


def select_run(use_ce):
    '''Return (checkpoint path, feature file, target file, plot title).'''
    if use_ce:
        return ('./checkpoint/CrossEntropy.bin', './CE.npy',
                './CE_tar.npy', 'TSNE_CrossEntropy')
    return ('./checkpoint/LabelSmoothing.bin', './LS.npy',
            './LS_tar.npy', 'TSNE_LabelSmoothing')


def extract_features(model, loader, device):
    '''Run ``model`` over ``loader`` and collect outputs and labels.

    Returns:
        A pair of NumPy arrays: the concatenated model outputs and the
        concatenated labels as a column vector (one row per sample).
    '''
    model.to(device)
    model.eval()

    features = []
    labels = []
    # Inference only: without no_grad every batch would keep an autograd graph.
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(device))
            features.append(outputs.cpu().numpy())
            labels.append(targets.numpy()[:, np.newaxis])

    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


def main():
    '''Extract features, save them, and write the t-SNE scatter plot.'''
    args = parse_args()
    path, npy_path, npy_target, title = select_run(args.ce)

    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = RN.ResNet18()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    states = torch.load(path, map_location='cpu')
    model.load_state_dict(states)
    # Drop the classifier so the forward pass returns the pooled features.
    model.linear = nn.Flatten()

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

    output_array, target_array = extract_features(model, testloader, device)
    np.save(npy_path, output_array, allow_pickle=False)
    np.save(npy_target, target_array, allow_pickle=False)

    print('Pred shape :',output_array.shape)
    print('Target shape :',target_array.shape)

    tsne = TSNE(n_components=2, init='pca', random_state=0)
    output_array = tsne.fit_transform(output_array)
    plt.rcParams['figure.figsize'] = 10,10
    plt.scatter(output_array[:, 0], output_array[:, 1], c= target_array[:,0])
    plt.title(title)
    plt.savefig('./'+title+'.png', bbox_inches='tight')


# The guard is required because the DataLoader starts worker processes; on
# spawn-based platforms the workers re-import this module.
if __name__ == '__main__':
    main()
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import resnet as RN
from utils import progress_bar, LabelSmoothingCrossEntropy, save_model


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
# NOTE(review): --resume is accepted but nothing in this script reads it.
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--ce', action='store_true', help='Cross entropy use')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy seen so far
start_epoch = 0  # index of the first epoch


# Training
def train(epoch):
    '''Run one training epoch and advance the learning-rate schedule.

    Uses the module-level ``net``, ``criterion``, ``optimizer``, ``scheduler``
    and ``trainloader`` that are created in the ``__main__`` section below.
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # The schedule is stepped once per epoch, after all batches.
    scheduler.step()


def test(epoch):
    '''Evaluate on the test set and save the model when accuracy improves.

    Uses the module-level ``net``, ``criterion``, ``testloader`` and
    ``save_path`` created in the ``__main__`` section, and updates the global
    ``best_acc``.
    '''
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        save_model(net, save_path)
        best_acc = acc


# Everything below runs only when the file is executed as a script. The guard
# is required because the DataLoaders start worker processes; on spawn-based
# platforms the workers re-import this module.
if __name__ == '__main__':
    args = parser.parse_args()

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=30)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    print('==> Building model..')
    net = RN.ResNet18()
    net = net.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # The loss choice also decides which checkpoint file gets written.
    if args.ce:
        criterion = nn.CrossEntropyLoss()
        save_path = './checkpoint/CrossEntropy.bin'
        print("Use CrossEntropy")
    else:
        criterion = LabelSmoothingCrossEntropy()
        save_path = './checkpoint/LabelSmoothing.bin'
        print("Use Label Smooting")

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov= True)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,60,90])

    for epoch in range(start_epoch, start_epoch+120):
        train(epoch)
        test(epoch)
'''ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    '''Two 3x3 conv + batch-norm layers with a residual shortcut.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # The shortcut is the identity unless the block changes the spatial
        # size or the channel count; then a 1x1 conv + batch norm projects x.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return F.relu(residual)


class ResNet(nn.Module):
    '''Conv stem, four residual stages, global average pooling and a linear
    classifier.'''

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input width for the next block; updated by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage; only its first block uses ``stride``.'''
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = torch.flatten(self.avgpool(out), 1)
        return self.linear(pooled)


def ResNet18():
    '''Return a ResNet of BasicBlocks with two blocks per stage.'''
    return ResNet(BasicBlock, [2,2,2,2])


def test():
    '''Smoke test: push one random 32x32 RGB image through ResNet18 and print
    the network.'''
    net = ResNet18()
    _ = net(torch.randn(1,3,32,32))
    print(net)

#test()
# Helper utilities for training: a label-smoothing loss, dataset statistics,
# parameter initialisation, a console progress bar and checkpoint saving.
import os
import sys
import time
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from pathlib import Path


class LabelSmoothingCrossEntropy(nn.Module):
    '''Cross entropy against targets smoothed towards the uniform distribution.

    Per sample the loss is
    ``(1 - smoothing) * (-log p[target]) + smoothing * mean_over_classes(-log p)``
    and the result is averaged over the batch.
    '''
    def __init__(self):
        super(LabelSmoothingCrossEntropy, self).__init__()

    def forward(self, x, target, smoothing=0.1):
        '''Return the mean label-smoothed loss.

        Args:
            x: logits with the classes on the last dimension.
            target: class indices; ``unsqueeze(1)`` + ``gather`` below imply
                one index per row of a 2-D ``x``.
            smoothing: weight given to the uniform component.
        '''
        confidence = 1. - smoothing
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss
        return loss.mean()


def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std of a 3-channel image dataset.

    Samples are loaded one at a time. The returned std is the average of the
    per-image channel stds, not the std over all pixels of the dataset.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Initialise Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    Conv weights use Kaiming-normal (fan_out), batch-norm weights are set to
    1, linear weights are drawn from N(0, 1e-3) and every bias is set to 0.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would raise for a bias with more than one element;
            # the intent is "has a bias".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def _get_term_width(default=80):
    '''Return the terminal width in columns, or ``default`` without a TTY.'''
    # Try stdout first (where the bar is drawn), then stdin.
    for fd in (1, 0):
        try:
            columns = os.get_terminal_size(fd).columns
        except (OSError, ValueError):
            continue
        if columns > 0:
            return columns
    return default


# Looked up once at import time; falling back to a default keeps the module
# importable when output is piped or there is no terminal at all.
term_width = _get_term_width()

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    '''Draw a one-line progress bar on stdout.

    Args:
        current: zero-based index of the step that just finished.
        total: total number of steps.
        msg: optional text appended after the timing information.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    status = ' Step: %s' % format_time(step_time)
    status += ' | Tot: %s' % format_time(tot_time)
    if msg:
        status += ' | ' + msg

    # A negative count yields an empty string, so narrow terminals simply get
    # no padding.
    padding = term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3
    # Backspaces that move the cursor back to the center of the bar.
    rewind = term_width - int(TOTAL_BAR_LENGTH/2) + 2
    # '\r' keeps redrawing the same line; the last step ends it with '\n'.
    line_end = '\r' if current < total-1 else '\n'

    sys.stdout.write(''.join([
        ' [', '=' * cur_len, '>', '.' * rest_len, ']',
        status,
        ' ' * padding,
        '\b' * rewind,
        ' %d/%d ' % (current+1, total),
        line_end,
    ]))
    sys.stdout.flush()


def format_time(seconds):
    '''Format a duration in seconds using its two largest non-zero units.

    Units are D, h, m, s and ms; a duration with no positive unit gives '0ms'.
    '''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'), (secondsf, 's'), (millis, 'ms'))
    shown = [str(value) + suffix for value, suffix in units if value > 0]
    return ''.join(shown[:2]) or '0ms'


def save_model(model, model_path):
    '''Save ``model``'s state dict to ``model_path`` with all tensors on CPU.

    An ``nn.DataParallel`` wrapper is removed first, so the saved keys have no
    ``module.`` prefix.
    '''
    if isinstance(model_path, Path):
        model_path = str(model_path)
    if isinstance(model, nn.DataParallel):
        model = model.module
    state_dict = model.state_dict()
    # Only values of existing keys are replaced, which is safe while iterating.
    for key in state_dict:
        state_dict[key] = state_dict[key].cpu()
    torch.save(state_dict, model_path)