├── .gitignore
├── CE.npy
├── CE_tar.npy
├── LS.npy
├── LS_tar.npy
├── README.md
├── TSEN.py
├── assets
├── TSNE_CrossEntropy.png
└── TSNE_LabelSmoothing.png
├── checkpoint
├── CrossEntropy.bin
└── LabelSmoothing.bin
├── main.py
├── resnet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 |
--------------------------------------------------------------------------------
/CE.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/CE.npy
--------------------------------------------------------------------------------
/CE_tar.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/CE_tar.npy
--------------------------------------------------------------------------------
/LS.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/LS.npy
--------------------------------------------------------------------------------
/LS_tar.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/LS_tar.npy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # When Does Label Smoothing Help? — PyTorch implementation
2 |
3 | paper : https://arxiv.org/abs/1906.02629
4 |
5 |
6 |
7 | Cross Entropy: **python main.py --ce** -> **python TSEN.py --ce**
8 |
9 | Label Smoothing: **python main.py** -> **python TSEN.py**
10 |
11 |
12 |
13 | A simple label smoothing implementation (the same class is defined in `utils.py`):
14 |
15 | ```python
16 |
17 | class LabelSmoothingCrossEntropy(nn.Module):
18 | def __init__(self):
19 | super(LabelSmoothingCrossEntropy, self).__init__()
20 | def forward(self, x, target, smoothing=0.1):
21 | confidence = 1. - smoothing
22 | logprobs = F.log_softmax(x, dim=-1)
23 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
24 | nll_loss = nll_loss.squeeze(1)
25 | smooth_loss = -logprobs.mean(dim=-1)
26 | loss = confidence * nll_loss + smoothing * smooth_loss
27 | return loss.mean()
28 | ```
29 | ```python
30 | from utils import LabelSmoothingCrossEntropy
31 |
32 | criterion = LabelSmoothingCrossEntropy()
33 | loss = criterion(outputs, targets)
34 | loss.backward()
35 | optimizer.step()
36 | ```
37 |
38 |
39 |
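40 | Penultimate-layer features of the CIFAR-10 test set, visualized with t-SNE. As discussed in "When Does Label Smoothing Help?", training with label smoothing makes the clusters of different classes tighter and more clearly separated than training with plain cross-entropy.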
41 |
42 |
43 |

44 |

45 |
46 |
47 |
--------------------------------------------------------------------------------
/TSEN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.utils.data import DataLoader
5 | from torchvision import datasets, transforms
6 | import resnet as RN
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import matplotlib.pyplot as plt
10 | from sklearn.manifold import TSNE
11 | import argparse
12 |
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--ce', action='store_true', help='Cross Entropy use')
args = parser.parse_args()

# Select checkpoint, output files and plot title for the chosen loss variant.
if args.ce:
    path = './checkpoint/CrossEntropy.bin'
    npy_path = './CE.npy'
    npy_target = './CE_tar.npy'
    title = 'TSNE_CrossEntropy'
else:
    path = './checkpoint/LabelSmoothing.bin'
    npy_path = './LS.npy'
    npy_target = './LS_tar.npy'
    title = 'TSNE_LabelSmoothing'

# Fall back to CPU so the script also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = RN.ResNet18()
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; the model is moved to `device` below.
states = torch.load(path, map_location='cpu')
model.load_state_dict(states)
# ResNet.forward flattens the pooled features before `linear`, so replacing
# the head with Flatten (a no-op on a 2-D tensor) makes the network return
# the 512-d penultimate features that t-SNE embeds.
model.linear = nn.Flatten()

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

extract = model.to(device)
extract.eval()

out_target = []
out_output = []

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for inputs, targets in testloader:
        outputs = extract(inputs.to(device))
        out_output.append(outputs.cpu().numpy())
        # Targets never leave the CPU; keep them as an (N, 1) column.
        out_target.append(targets.numpy()[:, np.newaxis])

output_array = np.concatenate(out_output, axis=0)
target_array = np.concatenate(out_target, axis=0)
np.save(npy_path, output_array, allow_pickle=False)
np.save(npy_target, target_array, allow_pickle=False)

print('Pred shape :', output_array.shape)
print('Target shape :', target_array.shape)

# Embed the features in 2-D and color each point by its ground-truth class.
tsne = TSNE(n_components=2, init='pca', random_state=0)
embedding = tsne.fit_transform(output_array)
plt.rcParams['figure.figsize'] = 10, 10
plt.scatter(embedding[:, 0], embedding[:, 1], c=target_array[:, 0])
plt.title(title)
plt.savefig('./' + title + '.png', bbox_inches='tight')
73 |
74 |
75 |
--------------------------------------------------------------------------------
/assets/TSNE_CrossEntropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/assets/TSNE_CrossEntropy.png
--------------------------------------------------------------------------------
/assets/TSNE_LabelSmoothing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/assets/TSNE_LabelSmoothing.png
--------------------------------------------------------------------------------
/checkpoint/CrossEntropy.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/checkpoint/CrossEntropy.bin
--------------------------------------------------------------------------------
/checkpoint/LabelSmoothing.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seominseok0429/label-smoothing-visualization-pytorch/fa18717c2c723f61dc7b38155c520a113588a0b8/checkpoint/LabelSmoothing.bin
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''Train CIFAR10 with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import torch.backends.cudnn as cudnn
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 |
11 | import os
12 | import argparse
13 | import resnet as RN
14 | from utils import progress_bar, LabelSmoothingCrossEntropy, save_model
15 |
16 |
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--ce', action='store_true', help='Cross entropy use')
args = parser.parse_args()
# NOTE(review): --resume is parsed but nothing in this script reads
# args.resume, so training always starts from scratch at start_epoch.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Spawning more loader workers than there are CPUs only adds contention;
# keep the original 30 as an upper bound.
train_workers = min(30, os.cpu_count() or 1)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=train_workers)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = RN.ResNet18()
net = net.to(device)

if device == 'cuda':
    # Replicate across the visible GPUs and let cuDNN autotune conv kernels.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.ce:
    criterion = nn.CrossEntropyLoss()
    save_path = './checkpoint/CrossEntropy.bin'
    print("Use CrossEntropy")
else:
    criterion = LabelSmoothingCrossEntropy()
    save_path = './checkpoint/LabelSmoothing.bin'
    print("Use Label Smoothing")

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001, nesterov=True)
# LR is multiplied by MultiStepLR's default gamma (0.1) when the number of
# scheduler.step() calls reaches 30, 60 and 90.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90])
70 | # Training
def train(epoch):
    """Run one training epoch over ``trainloader``, then step the LR scheduler.

    Relies on the module-level ``net``, ``criterion``, ``optimizer``,
    ``scheduler`` and ``device``; running loss and accuracy are reported
    through ``progress_bar`` after every batch.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    num_batches = len(trainloader)
    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Standard SGD step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        avg_loss = running_loss / (step + 1)
        accuracy = 100. * n_correct / n_seen
        progress_bar(step, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (avg_loss, accuracy, n_correct, n_seen))
    # One scheduler step per epoch.
    scheduler.step()
93 |
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best accuracy.

    Running loss and accuracy are reported via ``progress_bar``. When the
    epoch's accuracy beats the module-level ``best_acc``, the model is written
    to ``save_path`` with ``save_model`` and ``best_acc`` is updated.
    ``epoch`` is accepted for symmetry with ``train`` but not used.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        # Create the directory save_path actually lives in (instead of a
        # hard-coded 'checkpoint'), including any missing parents.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_model(net, save_path)
        best_acc = acc
122 |
123 |
# Train for 120 epochs; each epoch is followed by an evaluation pass that
# checkpoints the model whenever test accuracy improves.
end_epoch = start_epoch + 120
for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)
127 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv+BN layers plus an additive skip path.

    If the block changes resolution (``stride != 1``) or channel count, the
    skip path is a 1x1 conv + BN projection; otherwise it is an empty
    ``nn.Sequential`` that returns its input unchanged.
    """

    # Output channels = expansion * planes; basic blocks do not widen.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        # Attribute names become state_dict keys; renaming them would break
        # loading of previously saved weights.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))
37 |
class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem (no max-pool), four residual stages,
    global average pooling and a linear classifier.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Stages double the width and halve the resolution, except stage 1.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``.

        Note: the stride list always has at least one entry, so
        ``num_blocks <= 0`` still yields a single block.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block (and the next stage) consumes this block's output.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.linear(pooled)
70 |
71 |
def ResNet18():
    """Build a ResNet-18: BasicBlock stages with two blocks each, 10 classes."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)
74 |
def test():
    """Smoke test: run ResNet18 on one random 3x32x32 input, check the
    output shape, and print the model."""
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    # ResNet18 uses ResNet's default num_classes=10.
    assert y.shape == (1, 10), 'unexpected output shape %s' % (tuple(y.shape),)
    print(net)
79 |
80 | #test()
81 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - init_params: net parameter initialization.
4 | - progress_bar: progress bar mimicking xlua.progress.
5 | '''
6 | import os
7 | import sys
8 | import time
9 | import math
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.init as init
13 | import torch.nn.functional as F
14 | from pathlib import Path
15 |
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` uniformly over all classes (the true class included), i.e.

        loss = (1 - smoothing) * (-log p_true) + smoothing * mean_c(-log p_c)

    averaged over the batch.

    Args:
        smoothing: default smoothing factor in [0, 1], used whenever
            ``forward`` is called without an explicit ``smoothing``.

    Raises:
        ValueError: if ``smoothing`` is outside [0, 1].
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        if not 0.0 <= smoothing <= 1.0:
            raise ValueError('smoothing must be in [0, 1], got %r' % (smoothing,))
        self.smoothing = smoothing

    def forward(self, x, target, smoothing=None):
        """Return the batch-mean label-smoothed cross-entropy.

        Args:
            x: unnormalized logits; assumes shape (N, C) (the gather below
                indexes with ``target.unsqueeze(1)``).
            target: integer class indices of shape (N,).
            smoothing: per-call override of the smoothing factor; ``None``
                uses the value passed to the constructor.
        """
        if smoothing is None:
            smoothing = self.smoothing
        confidence = 1. - smoothing
        logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Cross-entropy against the uniform distribution over classes.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss
        return loss.mean()
27 |
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of dataset.

    Statistics are taken over every pixel of every image (population std).
    This is the true dataset std, which is larger than the average of
    per-image stds whenever images differ from each other. Works for any
    number of channels and for images of differing spatial size.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs, where
            ``image`` is a ``(C, H, W)`` tensor.
        num_workers: worker processes for the internal DataLoader.

    Returns:
        ``(mean, std)``: float32 tensors of shape ``(C,)``.

    Raises:
        ValueError: if the dataset is empty.
    '''
    # batch_size=1 keeps images of different sizes collatable; order does
    # not matter for sums, so no shuffling.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    print('==> Computing mean and std..')
    channel_sum = 0
    channel_sq_sum = 0
    pixel_count = 0
    for inputs, _ in dataloader:
        # (B, C, H, W) -> (C, B*H*W); float64 limits accumulation error.
        flat = inputs.to(torch.float64).transpose(0, 1).reshape(inputs.size(1), -1)
        channel_sum = channel_sum + flat.sum(dim=1)
        channel_sq_sum = channel_sq_sum + (flat * flat).sum(dim=1)
        pixel_count += flat.size(1)
    if pixel_count == 0:
        raise ValueError('cannot compute statistics of an empty dataset')
    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    var = (channel_sq_sum / pixel_count - mean * mean).clamp_min(0)
    return mean.float(), var.sqrt().float()
41 |
def init_params(net):
    '''Init layer parameters in place.

    - Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped when the layer is not affine).
    - Linear: weights ~ N(0, 1e-3^2), zero bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would call bool() on a multi-element tensor and
            # raise; test for the parameter's presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when BatchNorm2d(affine=False).
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
56 |
57 |
# Terminal width, used to pad and position the progress bar. `stty size`
# prints nothing when there is no TTY (pipes, nohup, IDE consoles), which
# made the old tuple unpacking raise at import time; fall back to 80 columns.
try:
    term_width = os.get_terminal_size().columns
except OSError:
    term_width = 80
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
# Timestamps shared with progress_bar for per-step and total timing.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar on stdout.

    Renders ``[====>....]``, the last-step and total elapsed time, an optional
    ``msg``, and a ``current+1/total`` counter. The line is rewritten in place
    with ``\\r`` until the last step, which ends with a newline. Calling with
    ``current == 0`` restarts the total-time clock.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH*current/total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = [' Step: %s' % format_time(step_time), ' | Tot: %s' % format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)
    status = ''.join(pieces)
    sys.stdout.write(status)
    # Pad to the terminal width, then backspace to the middle of the bar
    # where the step counter is printed.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3))
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()
106 |
def format_time(seconds):
    '''Format a duration compactly using its two largest non-zero units.

    Units are days (D), hours (h), minutes (m), seconds (s) and milliseconds
    (ms); e.g. 3661 -> '1h1m', 1.5 -> '1s500ms'. A duration with no non-zero
    unit is rendered as '0ms'.
    '''
    days = int(seconds / 3600/24)
    seconds -= days*3600*24
    hours = int(seconds / 3600)
    seconds -= hours*3600
    minutes = int(seconds / 60)
    seconds -= minutes*60
    whole_secs = int(seconds)
    millis = int((seconds - whole_secs)*1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'), (whole_secs, 's'), (millis, 'ms'))
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'
138 |
def save_model(model, model_path):
    '''Save ``model``'s state_dict to ``model_path`` with every tensor on CPU.

    An ``nn.DataParallel`` wrapper is unwrapped first, so the saved keys carry
    no ``module.`` prefix. ``model_path`` may be a ``str`` or a
    ``pathlib.Path``.
    '''
    inner = model.module if isinstance(model, nn.DataParallel) else model
    destination = str(model_path) if isinstance(model_path, Path) else model_path

    # Replace values in place so the state_dict object (and its metadata)
    # is kept as returned by state_dict().
    weights = inner.state_dict()
    for name in list(weights):
        weights[name] = weights[name].cpu()
    torch.save(weights, destination)
148 |
--------------------------------------------------------------------------------