class CenterLoss(nn.Module):
    """Center loss with a manually computed gradient for the class centers.

    The module keeps one learnable center per class. ``forward`` returns the
    weighted center loss (differentiable w.r.t. the features) together with
    the update direction for the centers, which the caller is expected to
    apply by hand, e.g.::

        loss, centers_grad = center_loss(targets, features)
        loss.backward()
        center_loss.zero_grad()
        center_loss.centers.backward(centers_grad)

    Args:
        num_classes: Number of classes, i.e. number of rows in ``centers``.
        feat_dim: Dimensionality of a (flattened) feature vector.
        loss_weight: Scalar multiplier applied to the returned loss.
    """

    def __init__(self, num_classes, feat_dim, loss_weight=0.01):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.loss_weight = loss_weight
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, y, feat):
        """Compute the center loss and the manual gradient for the centers.

        Args:
            y: Class labels of the batch, one per sample, with values in
                ``[0, num_classes)``. Any numeric dtype is accepted; the
                labels are cast to ``long``.
            feat: Batch of features. Everything after the first (batch)
                dimension is flattened; the flattened width must equal
                ``feat_dim``.

        Returns:
            A tuple ``(loss, centers_grad)``:

            * ``loss`` -- scalar tensor
              ``loss_weight * 0.5 * sum_i ||feat_i - centers[y_i]||^2``.
            * ``centers_grad`` -- tensor of shape ``(num_classes, feat_dim)``
              holding, for class ``j`` with ``n_j`` samples in the batch,
              ``n_j / (1 + n_j) * (centers[j] - mean(feat of class j))``.
              Rows of classes absent from the batch are zero. The tensor is
              detached from the autograd graph.
        """
        # Work on whatever device/dtype the centers live on instead of
        # assuming CUDA, so the module runs unchanged on CPU and GPU.
        centers = self.centers
        labels = y.detach().long().reshape(-1).to(centers.device)
        feat = feat.reshape(feat.size(0), -1)

        # Differentiable part: distance of each sample to its class center.
        diff = feat - centers.index_select(0, labels)
        loss = self.loss_weight * 0.5 * diff.pow(2).sum()

        # Manual center update: per-class sample counts and feature sums,
        # accumulated in one pass rather than a Python loop over classes.
        flat_feat = feat.detach().to(centers.dtype)
        counts = centers.new_zeros(self.num_classes)
        counts.index_add_(
            0, labels, torch.ones_like(labels, dtype=centers.dtype)
        )
        sums = centers.new_zeros(self.num_classes, self.feat_dim)
        sums.index_add_(0, labels, flat_feat)

        # clamp avoids 0/0 for absent classes; their sum is already zero, so
        # the mean stays zero and the n/(1+n) factor below zeroes the row.
        feat_mean = sums / counts.clamp(min=1).unsqueeze(1)
        scale = (counts / (1 + counts)).unsqueeze(1)
        centers_grad = scale * (centers.detach() - feat_mean)
        return loss, centers_grad