├── .gitignore
├── README.md
├── data.py
├── lsoftmax.py
├── lsoftmax_test.py
├── model.py
└── train_mnist.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
data
*.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Large-Margin Softmax (L-Softmax) Loss
PyTorch implementation of the L-Softmax loss from Liu et al., "Large-Margin Softmax Loss for Convolutional Neural Networks" (ICML 2016).
See `lsoftmax.py` for the loss layer and `train_mnist.py` for an MNIST training example.

## Caution
This quick and dirty implementation is getting some stars.
Please note that I haven't heavily tested the code, so if you want to use this in your research, business, or whatever, please re-check the correctness of the implementation yourself.
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
# NOTE: currently an unused placeholder; the MNIST data pipeline for the
# experiment is built directly in train_mnist.py.
from torch.utils.data.sampler import SequentialSampler
from torchvision import datasets, transforms
--------------------------------------------------------------------------------
/lsoftmax.py:
--------------------------------------------------------------------------------
import math

import torch
from torch import nn
from torch.autograd import Variable

from scipy.special import binom


class LSoftmaxLinear(nn.Module):

    def __init__(self, input_dim, output_dim, margin):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.margin = margin

        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))

        # Precompute the constants of the multiple-angle expansion
        #   cos(m*theta) = sum_n (-1)^n * C(m, 2n) * cos^(m-2n)(theta) * sin^(2n)(theta)
        self.divisor = math.pi / self.margin
        self.coeffs = binom(margin, range(0, margin + 1, 2))
        self.cos_exps = range(self.margin, -1, -2)
        self.sin_sq_exps = range(len(self.cos_exps))
        self.signs = [1]
        for i in range(1, len(self.sin_sq_exps)):
            self.signs.append(self.signs[-1] * -1)

        # Initialize the weight here as well, so the module is usable even if the
        # caller never invokes reset_parameters() explicitly.
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal(self.weight.data.t())

    def find_k(self, cos):
        # k indexes the interval [k*pi/m, (k+1)*pi/m] that contains theta.
        acos = cos.acos()
        k = (acos / self.divisor).floor().detach()
        return k

    def forward(self, input, target=None):
        if self.training:
            assert target is not None
            logit = input.matmul(self.weight)
            batch_size = logit.size(0)
            logit_target = logit[range(batch_size), target]
            weight_target_norm = self.weight[:, target].norm(p=2, dim=0)
            input_norm = input.norm(p=2, dim=1)
            # norm_target_prod: (batch_size,)
            norm_target_prod = weight_target_norm * input_norm
            # cos_target: (batch_size,)
            cos_target = logit_target / (norm_target_prod + 1e-10)
            sin_sq_target = 1 - cos_target**2

            num_ns = self.margin//2 + 1
            # coeffs, cos_exps, sin_sq_exps, signs: (num_ns,)
            coeffs = Variable(input.data.new(self.coeffs))
            cos_exps = Variable(input.data.new(self.cos_exps))
            sin_sq_exps = Variable(input.data.new(self.sin_sq_exps))
            signs = Variable(input.data.new(self.signs))

            cos_terms = cos_target.unsqueeze(1) ** cos_exps.unsqueeze(0)
            sin_sq_terms = (sin_sq_target.unsqueeze(1)
                            ** sin_sq_exps.unsqueeze(0))

            cosm_terms = (signs.unsqueeze(0) * coeffs.unsqueeze(0)
                          * cos_terms * sin_sq_terms)
            cosm = cosm_terms.sum(1)
            k = self.find_k(cos_target)

            # psi(theta) = (-1)^k * cos(m*theta) - 2k, scaled by ||w_y|| * ||x||.
            ls_target = norm_target_prod * (((-1)**k * cosm) - 2*k)
            logit[range(batch_size), target] = ls_target
            return logit
        else:
            assert target is None
            return input.matmul(self.weight)
--------------------------------------------------------------------------------
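For reference, the forward pass above builds cos(m*theta) from the binomial multiple-angle identity whose constants are precomputed in __init__, and find_k selects the branch of the piecewise margin function. The following standalone sketch (not part of the repository; plain math/scipy only) checks that identity and the margin function psi(theta) = (-1)^k * cos(m*theta) - 2k numerically:

import math
from scipy.special import binom

def cos_m_theta(theta, m):
    # cos(m*theta) via the expansion used in LSoftmaxLinear.forward:
    # sum_n (-1)^n * C(m, 2n) * cos(theta)^(m-2n) * sin(theta)^(2n)
    cos, sin_sq = math.cos(theta), math.sin(theta) ** 2
    return sum((-1) ** n * binom(m, 2 * n) * cos ** (m - 2 * n) * sin_sq ** n
               for n in range(m // 2 + 1))

def psi(theta, m):
    # Piecewise margin function; k indexes the interval [k*pi/m, (k+1)*pi/m]
    # containing theta (cf. find_k).
    k = math.floor(theta / (math.pi / m))
    return (-1) ** k * cos_m_theta(theta, m) - 2 * k

for m in (2, 3, 4):
    for theta in (0.1, 0.7, 1.5, 2.9):
        assert abs(cos_m_theta(theta, m) - math.cos(m * theta)) < 1e-9
        # psi(theta) <= cos(theta): the target logit is only ever made harder.
        assert psi(theta, m) <= math.cos(theta) + 1e-9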
/lsoftmax_test.py:
--------------------------------------------------------------------------------
import unittest

import numpy as np
import torch
from torch.autograd import Variable

from lsoftmax import LSoftmaxLinear


class LSoftmaxTestCase(unittest.TestCase):

    def test_trivial(self):
        m = LSoftmaxLinear(input_dim=2, output_dim=4, margin=1)
        x = Variable(torch.arange(6).view(3, 2))
        target = Variable(torch.LongTensor([3, 0, 1]))

        m.eval()
        eval_out = m.forward(input=x).data.tolist()
        m.train()
        train_out = m.forward(input=x, target=target).data.tolist()
        np.testing.assert_array_almost_equal(eval_out, train_out)

    def test_odd(self):
        m = LSoftmaxLinear(input_dim=2, output_dim=4, margin=3)
        m.weight.data.copy_(torch.arange(-4, 4).view(2, 4))
        x = Variable(torch.arange(6).view(3, 2))
        target = Variable(torch.LongTensor([3, 0, 1]))

        m.eval()
        eval_out = m.forward(input=x).data.tolist()
        eval_gold = [[0, 1, 2, 3], [-8, -3, 2, 7], [-16, -7, 2, 11]]
        np.testing.assert_array_almost_equal(eval_out, eval_gold, decimal=5)

        m.train()
        train_out = m.forward(input=x, target=target).data.tolist()
        train_gold = [[0, 1, 2, 1.7999999999999996],
                      [-43.53497425357768, -3, 2, 7],
                      [-16, -58.150571999218542, 2, 11]]
        np.testing.assert_array_almost_equal(train_out, train_gold, decimal=5)

    def test_even(self):
        m = LSoftmaxLinear(input_dim=2, output_dim=4, margin=4)
        m.weight.data.copy_(torch.arange(-4, 4).view(2, 4))
        x = Variable(torch.arange(6).view(3, 2))
        target = Variable(torch.LongTensor([3, 0, 1]))

        m.eval()
        eval_out = m.forward(input=x).data.tolist()
        eval_gold = [[0, 1, 2, 3], [-8, -3, 2, 7], [-16, -7, 2, 11]]
        np.testing.assert_array_almost_equal(eval_out, eval_gold, decimal=5)

        m.train()
        train_out = m.forward(input=x, target=target).data.tolist()
        train_gold = [[0, 1, 2, 0.88543774484714499],
                      [-67.844100922931872, -3, 2, 7],
                      [-16, -77.791173935544478, 2, 11]]
        np.testing.assert_array_almost_equal(train_out, train_gold, decimal=5)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
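As a sanity check on where the hard-coded gold values come from, the first entry of train_gold in test_odd can be reproduced by hand. This is an illustrative derivation only, not part of the repository: the first sample is x = [0, 1] with target class 3, whose weight column is [-1, 3], and the margin is 3.

import math

cos = 3 / math.sqrt(10)            # <x, w_3> / (||x|| * ||w_3||); the raw logit is 0*(-1) + 1*3 = 3
theta = math.acos(cos)             # ~0.3217 rad, below pi/3, so k = 0
cos3 = 4 * cos ** 3 - 3 * cos      # cos(3*theta) = 1.8 / sqrt(10)
print(math.sqrt(10) * 1.0 * cos3)  # ||w_3|| * ||x|| * psi(theta) = 1.8, i.e. train_gold[0][3]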
/model.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

from torch import nn
from torch.nn import init

from lsoftmax import LSoftmaxLinear


class MNISTModel(nn.Module):

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

        cnn_layers = OrderedDict()
        cnn_layers['pre_bn'] = nn.BatchNorm2d(1)
        cnn_layers['conv0_0'] = nn.Conv2d(in_channels=1, out_channels=64,
                                          kernel_size=(3, 3), padding=1)
        cnn_layers['prelu0_0'] = nn.PReLU(64)
        cnn_layers['bn0_0'] = nn.BatchNorm2d(64)
        # conv1.x
        for x in range(3):
            cnn_layers[f'conv1_{x}'] = nn.Conv2d(
                in_channels=64, out_channels=64, kernel_size=(3, 3),
                padding=1)
            cnn_layers[f'prelu1_{x}'] = nn.PReLU(64)
            cnn_layers[f'bn1_{x}'] = nn.BatchNorm2d(64)
        cnn_layers['pool1'] = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        # conv2.x
        for x in range(4):
            cnn_layers[f'conv2_{x}'] = nn.Conv2d(
                in_channels=64, out_channels=64, kernel_size=(3, 3),
                padding=1)
            cnn_layers[f'prelu2_{x}'] = nn.PReLU(64)
            cnn_layers[f'bn2_{x}'] = nn.BatchNorm2d(64)
        cnn_layers['pool2'] = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        # conv3.x
        for x in range(4):
            cnn_layers[f'conv3_{x}'] = nn.Conv2d(
                in_channels=64, out_channels=64, kernel_size=(3, 3),
                padding=1)
            cnn_layers[f'prelu3_{x}'] = nn.PReLU(64)
            cnn_layers[f'bn3_{x}'] = nn.BatchNorm2d(64)
        cnn_layers['pool3'] = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.net = nn.Sequential(cnn_layers)
        self.fc = nn.Sequential(OrderedDict([
            # 576 = 64 channels * 3 * 3 spatial positions left after the three
            # 2x2 max-pools (every 3x3 convolution keeps the size via padding=1).
            ('fc0', nn.Linear(in_features=576, out_features=256)),
            # ('fc1', nn.Linear(in_features=256, out_features=10))
            ('fc0_bn', nn.BatchNorm1d(256))
        ]))

        self.lsoftmax_linear = LSoftmaxLinear(
            input_dim=256, output_dim=10, margin=margin)
        self.reset_parameters()

    def reset_parameters(self):
        def init_kaiming(layer):
            init.kaiming_normal(layer.weight.data)
            init.constant(layer.bias.data, val=0)

        init_kaiming(self.net.conv0_0)
        for x in range(3):
            init_kaiming(getattr(self.net, f'conv1_{x}'))
        for x in range(4):
            init_kaiming(getattr(self.net, f'conv2_{x}'))
        for x in range(4):
            init_kaiming(getattr(self.net, f'conv3_{x}'))
        init_kaiming(self.fc.fc0)
        self.lsoftmax_linear.reset_parameters()
        # init_kaiming(self.fc.fc1)

    def forward(self, input, target=None):
        """
        Args:
            input: A variable of size (N, 1, 28, 28).
            target: A long variable of size (N,).

        Returns:
            logit: A variable of size (N, 10).
        """

        conv_output = self.net(input)
        batch_size = conv_output.size(0)
        fc_input = conv_output.view(batch_size, -1)
        fc_output = self.fc(fc_input)
        logit = self.lsoftmax_linear(input=fc_output, target=target)
        return logit
--------------------------------------------------------------------------------
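A quick smoke test of the shapes above (illustrative only, not part of the repository, and written against the same old-style PyTorch/Variable API the code already uses): only the three 2x2 max-pools shrink the 28x28 input, to 14, then 7, then 3, which gives the 576 = 64 * 3 * 3 input features of fc0.

import torch
from torch.autograd import Variable
from model import MNISTModel

model = MNISTModel(margin=2)
model.eval()                                 # eval mode, so no target is needed
dummy = Variable(torch.randn(4, 1, 28, 28))  # a fake batch of four MNIST-sized images
print(model.net(dummy).size())               # (4, 64, 3, 3), flattened to (4, 576)
print(model(dummy).size())                   # (4, 10)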
/train_mnist.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms

from model import MNISTModel


def train(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    ])
    train_dataset = datasets.MNIST(
        root='data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(
        root='data', train=False, transform=transform, download=True)
    # The first 55,000 training images are used for training; the last 5,000
    # are held out for validation.
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=256,
        sampler=sampler.SubsetRandomSampler(list(range(0, 55000))))
    valid_loader = DataLoader(
        dataset=train_dataset, batch_size=256,
        sampler=sampler.SubsetRandomSampler(list(range(55000, 60000))))
    test_loader = DataLoader(dataset=test_dataset, batch_size=256)

    model = MNISTModel(margin=args.margin)
    if args.gpu > -1:
        model.cuda(args.gpu)
    criterion = nn.CrossEntropyLoss()

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9,
                              weight_decay=0.0005)
        min_lr = 0.001
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), weight_decay=0.0005)
        min_lr = 0.00001
    else:
        raise ValueError('Unknown optimizer')

    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer, mode='max', factor=0.1, patience=5, verbose=True,
        min_lr=min_lr)

    summary_writer = SummaryWriter(os.path.join(args.save_dir, 'log'))

    def var(tensor, volatile=False):
        if args.gpu > -1:
            tensor = tensor.cuda(args.gpu)
        return Variable(tensor, volatile=volatile)

    global_step = 0

    def train_epoch():
        nonlocal global_step
        model.train()
        for train_batch in train_loader:
            train_x, train_y = var(train_batch[0]), var(train_batch[1])
            logit = model(input=train_x, target=train_y)
            loss = criterion(input=logit, target=train_y)
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(model.parameters(), max_norm=10)
            optimizer.step()
            global_step += 1
            summary_writer.add_scalar(
                tag='train_loss', scalar_value=loss.data[0],
                global_step=global_step)

    def validate():
        model.eval()
        loss_sum = num_correct = denom = 0
        for valid_batch in valid_loader:
            valid_x, valid_y = (var(valid_batch[0], volatile=True),
                                var(valid_batch[1], volatile=True))
            logit = model(valid_x)
            y_pred = logit.max(1)[1]
            loss = criterion(input=logit, target=valid_y)
            loss_sum += loss.data[0] * valid_x.size(0)
            num_correct += y_pred.eq(valid_y).long().sum().data[0]
            denom += valid_x.size(0)
        loss = loss_sum / denom
        accuracy = num_correct / denom
        summary_writer.add_scalar(tag='valid_loss', scalar_value=loss,
                                  global_step=global_step)
        summary_writer.add_scalar(tag='valid_accuracy', scalar_value=accuracy,
                                  global_step=global_step)
        lr_scheduler.step(accuracy)
        return loss, accuracy

    def test():
        model.eval()
        num_correct = denom = 0
        for test_batch in test_loader:
            test_x, test_y = (var(test_batch[0], volatile=True),
                              var(test_batch[1], volatile=True))
            logit = model(test_x)
            y_pred = logit.max(1)[1]
            num_correct += y_pred.eq(test_y).long().sum().data[0]
            denom += test_x.size(0)
        accuracy = num_correct / denom
        summary_writer.add_scalar(tag='test_accuracy', scalar_value=accuracy,
                                  global_step=global_step)
        return accuracy

    best_valid_accuracy = 0
    for epoch in range(1, args.max_epoch + 1):
        train_epoch()
        valid_loss, valid_accuracy = validate()
        print(f'Epoch {epoch}: Valid loss = {valid_loss:.5f}')
        print(f'Epoch {epoch}: Valid accuracy = {valid_accuracy:.5f}')
        test_accuracy = test()
        print(f'Epoch {epoch}: Test accuracy = {test_accuracy:.5f}')
        if valid_accuracy > best_valid_accuracy:
            model_filename = (f'{epoch:02d}'
                              f'-{valid_loss:.5f}'
                              f'-{valid_accuracy:.5f}'
                              f'-{test_accuracy:.5f}.pt')
            model_path = os.path.join(args.save_dir, model_filename)
            torch.save(model.state_dict(), model_path)
            print(f'Epoch {epoch}: Saved the new best model to: {model_path}')
            best_valid_accuracy = valid_accuracy


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--margin', default=1, type=int)
    parser.add_argument('--optimizer', default='sgd')
    parser.add_argument('--max-epoch', default=50, type=int)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--save-dir', required=True)
    args = parser.parse_args()

    train(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
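Example usage (the save directory below is just a placeholder): running `python train_mnist.py --margin 4 --save-dir checkpoints/m4` downloads MNIST into `data/` on first use, trains with SGD by default (`--optimizer adam` is also accepted), logs to TensorBoard via `tensorboardX` under `<save-dir>/log`, and saves a checkpoint whenever the validation accuracy improves. Pass `--gpu -1` to run on the CPU.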