├── L_GM_loss.py
├── README.md
└── test_loss.py

/L_GM_loss.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn


class LGMLoss(nn.Module):
    """Large-margin Gaussian Mixture (L-GM) loss.

    From "Rethinking Feature Distribution for Loss Functions in Image
    Classification" (Wan et al., CVPR 2018). Returns margin logits
    (negative squared distances to learnable class means) and a
    likelihood regularization term; the full loss is the cross-entropy
    of the logits plus the regularization term.
    """

    def __init__(self, num_classes, feat_dim, alpha=0.1, lambda_=0.01):
        super(LGMLoss, self).__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.lambda_ = lambda_
        self.means = nn.Parameter(torch.randn(num_classes, feat_dim))
        nn.init.xavier_uniform_(self.means, gain=math.sqrt(2.0))

    def forward(self, feat, labels=None):
        batch_size = feat.size(0)

        # Squared Euclidean distance to every class mean, expanded as
        # ||x||^2 - 2*x.y + ||y||^2; the logits are the negated halves.
        XY = torch.matmul(feat, self.means.t())
        XX = torch.sum(feat ** 2, dim=1, keepdim=True)
        YY = torch.sum(self.means.t() ** 2, dim=0, keepdim=True)
        neg_sqr_dist = -0.5 * (XX - 2.0 * XY + YY)

        if labels is None:
            # Inference: no margin; regularize towards the nearest mean.
            pseudo_labels = torch.argmax(neg_sqr_dist, dim=1)
            means_batch = torch.index_select(self.means, dim=0, index=pseudo_labels)
            likelihood_reg_loss = self.lambda_ * (torch.sum((feat - means_batch) ** 2) / 2) / batch_size
            return neg_sqr_dist, likelihood_reg_loss, self.means

        # Training: scale the ground-truth class logit by (1 + alpha) to
        # apply the large margin. Building the mask on feat's device
        # avoids CPU/GPU mismatches.
        labels_reshaped = labels.view(batch_size, -1)
        ALPHA = torch.zeros(batch_size, self.num_classes, device=feat.device).scatter_(1, labels_reshaped, self.alpha)
        K = ALPHA + 1.0

        logits_with_margin = torch.mul(neg_sqr_dist, K)
        means_batch = torch.index_select(self.means, dim=0, index=labels)
        likelihood_reg_loss = self.lambda_ * (torch.sum((feat - means_batch) ** 2) / 2) / batch_size
        return logits_with_margin, likelihood_reg_loss, self.means
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# L-GM_loss_pytorch
PyTorch implementation of the L-GM loss from "Rethinking Feature Distribution for Loss Functions in Image Classification" (Wan et al., CVPR 2018).

This implementation follows the authors' TensorFlow version: https://github.com/WeitaoVan/L-GM-loss/tree/master/tensorflow
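
A minimal usage sketch (the feature and label tensors, sizes, and optimizer settings below are illustrative placeholders; in practice `feat` comes from your network's embedding layer):

```python
import torch
import torch.nn.functional as F
import torch.optim as optim

from L_GM_loss import LGMLoss

# Illustrative sizes: batch of 32, 10 classes, 2-D features.
lgm = LGMLoss(num_classes=10, feat_dim=2, alpha=0.1, lambda_=0.01)
optimizer4lgm = optim.SGD(lgm.parameters(), lr=0.1)  # updates the class means

feat = torch.randn(32, 2, requires_grad=True)  # stand-in for network features
labels = torch.randint(0, 10, (32,))

logits, likelihood_reg_loss, _ = lgm(feat, labels)
loss = F.cross_entropy(logits, labels) + likelihood_reg_loss

optimizer4lgm.zero_grad()
loss.backward()  # gradients flow to both the features and the class means
optimizer4lgm.step()
```

Note that `LGMLoss` owns the learnable class means, so its parameters need their own optimizer (or must be added to the model's optimizer parameter groups).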
--------------------------------------------------------------------------------

/test_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import numpy as np

from L_GM_loss import LGMLoss

# Reference (functional) implementation of the loss, ported from the
# authors' TensorFlow version:
# https://github.com/WeitaoVan/L-GM-loss/tree/master/tensorflow

def tc_lgm_logits(feat, num_classes, labels=None, alpha=0.1, lambda_=0.01, means=None):
    N = feat.size(0)

    XY = torch.matmul(feat, means.t())
    XX = torch.sum(feat ** 2, dim=1, keepdim=True)
    YY = torch.sum(means.t() ** 2, dim=0, keepdim=True)
    neg_sqr_dist = -0.5 * (XX - 2.0 * XY + YY)

    if labels is None:
        pseudo_labels = torch.argmax(neg_sqr_dist, dim=1)
        means_batch = torch.index_select(means, dim=0, index=pseudo_labels)
        likelihood_reg_loss = lambda_ * (torch.sum((feat - means_batch) ** 2) / 2) / N
        return neg_sqr_dist, likelihood_reg_loss, means

    label = labels.view(N, -1)
    ALPHA = torch.zeros(N, num_classes).scatter_(1, label, alpha)
    K = ALPHA + torch.ones([N, num_classes])

    logits_with_margin = torch.mul(neg_sqr_dist, K)
    means_batch = torch.index_select(means, dim=0, index=labels)
    likelihood_reg_loss = lambda_ * (torch.sum((feat - means_batch) ** 2) / 2) / N
    return logits_with_margin, likelihood_reg_loss, means


if __name__ == '__main__':
    num_data = 2
    feat_dim = 5
    num_classes = 3

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    np_feat = np.random.randn(num_data, feat_dim).astype(np.float32)
    np_labels = np.random.randint(0, num_classes, size=num_data).astype(np.int64)

    lgmloss = LGMLoss(num_classes, feat_dim, alpha=1.0, lambda_=1.0).to(device)

    # This optimizer only updates the loss module's parameter ('means').
    # In real training you still need a separate optimizer for the model,
    # for example:
    #   optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer4lgm = optim.SGD(lgmloss.parameters(), lr=0.1)

    # Simulate two iterations.
    for _ in range(2):
        tc_feat = torch.tensor(np_feat).to(device)
        tc_labels = torch.tensor(np_labels).to(device)

        _, loss, _ = lgmloss(tc_feat, tc_labels)

        # Cross-check the module against the reference function on the CPU;
        # the two printed losses should match.
        _, tc_loss, _ = tc_lgm_logits(tc_feat.cpu(), num_classes,
                                      labels=tc_labels.cpu(),
                                      alpha=1.0,
                                      lambda_=1.0,
                                      means=lgmloss.means.detach().cpu())
        print(loss)
        print(tc_loss)
        print('--' * 10)

        # In real training, also zero the model's gradients here:
        #   optimizer.zero_grad()
        optimizer4lgm.zero_grad()
        loss.backward()

        # print(lgmloss.means.grad)
        print(lgmloss.means)
        print('--' * 10)

        # ...and apply the model's update as well:
        #   optimizer.step()
        optimizer4lgm.step()
--------------------------------------------------------------------------------