├── .gitattributes ├── .gitignore ├── CenterLoss.py ├── LICENSE ├── MNIST_with_centerloss.py ├── README.md └── images ├── 0.gif ├── 0.jpg ├── 1.0-new.gif ├── 1.0-new.jpg ├── 1.0.gif └── 1.0.jpg /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | .venv/ 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | */.DS_Store 94 | .DS_Store 95 | 96 | .idea 97 | #*.jpg 98 | ignored/* 99 | .vscode/* 100 | -------------------------------------------------------------------------------- /CenterLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd.function import Function 4 | 5 | class CenterLoss(nn.Module): 6 | def __init__(self, num_classes, feat_dim, size_average=True): 7 | super(CenterLoss, self).__init__() 8 | self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) 9 | self.centerlossfunc = CenterlossFunc.apply 10 | self.feat_dim = feat_dim 11 | self.size_average = size_average 12 | 13 | def forward(self, label, feat): 14 | batch_size = feat.size(0) 15 | feat = feat.view(batch_size, -1) 16 | # To check the dim of centers and features 17 | if feat.size(1) != self.feat_dim: 18 | raise ValueError("Center's dim: {0} should be equal to input feature's \ 19 | dim: {1}".format(self.feat_dim,feat.size(1))) 20 | batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1) 21 | loss = self.centerlossfunc(feat, label, self.centers, batch_size_tensor) 22 | return loss 23 | 24 | 25 | class CenterlossFunc(Function): 26 | @staticmethod 27 | def forward(ctx, feature, label, centers, batch_size): 28 | ctx.save_for_backward(feature, label, centers, batch_size) 29 | centers_batch = centers.index_select(0, label.long()) 30 | return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | feature, label, centers, batch_size = ctx.saved_tensors 35 | centers_batch = centers.index_select(0, label.long()) 36 | diff = centers_batch - feature 37 | # init every iteration 38 | counts = centers.new_ones(centers.size(0)) 39 | ones = centers.new_ones(label.size(0)) 40 | grad_centers = centers.new_zeros(centers.size()) 41 | 42 | counts = counts.scatter_add_(0, label.long(), ones) 43 | grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff) 44 | grad_centers = grad_centers/counts.view(-1, 1) 45 | return - grad_output * diff / batch_size, None, grad_centers / batch_size, None 46 | 47 | 48 | def main(test_cuda=False): 49 | print('-'*80) 50 | device = torch.device("cuda" if test_cuda else "cpu") 51 | ct = CenterLoss(10,2,size_average=True).to(device) 52 | y = torch.Tensor([0,0,2,1]).to(device) 53 | feat = torch.zeros(4,2).to(device).requires_grad_() 54 | print (list(ct.parameters())) 55 | print (ct.centers.grad) 56 | out = ct(y,feat) 57 | print(out.item()) 58 | out.backward() 59 | print(ct.centers.grad) 60 | print(feat.grad) 61 | 62 | if __name__ == '__main__': 63 | torch.manual_seed(999) 64 | main(test_cuda=False) 65 | if torch.cuda.is_available(): 66 | main(test_cuda=True) 67 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 jxgu1016 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MNIST_with_centerloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torchvision import datasets, transforms 6 | from torch.utils.data import DataLoader 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | from CenterLoss import CenterLoss 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | class Net(nn.Module): 13 | def __init__(self): 14 | super(Net, self).__init__() 15 | self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2) 16 | self.prelu1_1 = nn.PReLU() 17 | self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2) 18 | self.prelu1_2 = nn.PReLU() 19 | self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2) 20 | self.prelu2_1 = nn.PReLU() 21 | self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2) 22 | self.prelu2_2 = nn.PReLU() 23 | self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2) 24 | self.prelu3_1 = nn.PReLU() 25 | self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2) 26 | self.prelu3_2 = nn.PReLU() 27 | self.preluip1 = nn.PReLU() 28 | self.ip1 = nn.Linear(128*3*3, 2) 29 | self.ip2 = nn.Linear(2, 10, bias=False) 30 | 31 | def forward(self, x): 32 | x = self.prelu1_1(self.conv1_1(x)) 33 | x = self.prelu1_2(self.conv1_2(x)) 34 | x = F.max_pool2d(x,2) 35 | x = self.prelu2_1(self.conv2_1(x)) 36 | x = self.prelu2_2(self.conv2_2(x)) 37 | x = F.max_pool2d(x,2) 38 | x = self.prelu3_1(self.conv3_1(x)) 39 | x = self.prelu3_2(self.conv3_2(x)) 40 | x = F.max_pool2d(x,2) 41 | x = x.view(-1, 128*3*3) 42 | ip1 = self.preluip1(self.ip1(x)) 43 | ip2 = self.ip2(ip1) 44 | return ip1, F.log_softmax(ip2, dim=1) 45 | 46 | def visualize(feat, labels, epoch): 47 | plt.ion() 48 | c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff', 49 | '#ff00ff', '#990000', '#999900', '#009900', '#009999'] 50 | plt.clf() 51 | for i in range(10): 52 | plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=c[i]) 53 | plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc = 'upper right') 54 | plt.xlim(xmin=-8,xmax=8) 55 | plt.ylim(ymin=-8,ymax=8) 56 | plt.text(-7.8,7.3,"epoch=%d" % epoch) 57 | plt.savefig('./images/epoch=%d.jpg' % epoch) 58 | plt.draw() 59 | plt.pause(0.001) 60 | 61 | 62 | def train(epoch): 63 | print "Training... Epoch = %d" % epoch 64 | ip1_loader = [] 65 | idx_loader = [] 66 | for i,(data, target) in enumerate(train_loader): 67 | data, target = data.to(device), target.to(device) 68 | 69 | ip1, pred = model(data) 70 | loss = nllloss(pred, target) + loss_weight * centerloss(target, ip1) 71 | 72 | optimizer4nn.zero_grad() 73 | optimzer4center.zero_grad() 74 | 75 | loss.backward() 76 | 77 | optimizer4nn.step() 78 | optimzer4center.step() 79 | 80 | ip1_loader.append(ip1) 81 | idx_loader.append((target)) 82 | 83 | feat = torch.cat(ip1_loader, 0) 84 | labels = torch.cat(idx_loader, 0) 85 | visualize(feat.data.cpu().numpy(),labels.data.cpu().numpy(),epoch) 86 | 87 | use_cuda = torch.cuda.is_available() and True 88 | device = torch.device("cuda" if use_cuda else "cpu") 89 | # Dataset 90 | trainset = datasets.MNIST('../MNIST', download=True,train=True, transform=transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.1307,), (0.3081,))])) 93 | train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4) 94 | 95 | # Model 96 | model = Net().to(device) 97 | 98 | # NLLLoss 99 | nllloss = nn.NLLLoss().to(device) #CrossEntropyLoss = log_softmax + NLLLoss 100 | # CenterLoss 101 | loss_weight = 1 102 | centerloss = CenterLoss(10, 2).to(device) 103 | 104 | # optimzer4nn 105 | optimizer4nn = optim.SGD(model.parameters(),lr=0.001,momentum=0.9, weight_decay=0.0005) 106 | sheduler = lr_scheduler.StepLR(optimizer4nn,20,gamma=0.8) 107 | 108 | # optimzer4center 109 | optimzer4center = optim.SGD(centerloss.parameters(), lr =0.5) 110 | 111 | for epoch in range(100): 112 | sheduler.step() 113 | # print optimizer4nn.param_groups[0]['lr'] 114 | train(epoch+1) 115 | 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### UPDATE(Oct. 2018) 2 | By dropping the bias of the last fc layer according to the [issue](https://github.com/jxgu1016/MNIST_center_loss_pytorch/issues/8), the centers tend to distribute around a circle as reported in the orignal paper. 3 | 4 | 5 | ### UPDATE(May. 2018) 6 | Migration to PyTorch 0.4 done! 7 | 8 | 9 | ### UPDATE(Apr. 2018) 10 | Thanks [@wenfahu](https://github.com/wenfahu) for accomplishing the optimization of *backward()*. 11 | 12 | 13 | ### UPDATE(Mar. 2018) 14 | Problems reported in the [NOTIFICATION](#jump) now has been SOLVED! Functionally, this repo is exactly the same as the official repo. New result is shown below and looks similar to the former one. 15 | If you want to try the former one, please return to [Commits on Feb 12, 2018](https://github.com/jxgu1016/MNIST_center_loss_pytorch/tree/dbeea5380de8a3c6b1b3b3f2c411b980e143dd87). 16 | 17 | Some codes can be and should be **optimized** when calculating Eq.4 in [*backword()*](https://github.com/jxgu1016/MNIST_center_loss_pytorch/blob/master/CenterLoss.py) to replace the for-loop and feel free to pull your request. 18 | 19 | ### NOTIFICATION(Feb. 2018) 20 | 21 | In the begining, it was just a practise project to get familiar with PyTorch. Surprisedly, I didn't expect that there would be so many researchers following my repo of center loss. In that case, I'd like to illustrate that this implementation is **not exactly the same** as the official one. 22 | 23 | If you read the equations in the paper carefully, the defination of center loss in the Eq. 2 can only lead you to the Eq. 3 but the update equation of centers in Eq. 4 can not be inferred arrcoding to the differentiation formulas. If not specified, the derivatives of one module are decided by the forward operation following the strategy of autograd in PyTorch. Considering the incompatibility of Eq. 3 and Eq. 4, only one of them can be implemented correctly and what I chose was the latter one. If you remvoe the *centers_count* in my code, this will lead you to the Eq. 3. 24 | 25 | This problem exists in other implementaions and the impact remains unknown but looks harmless. 26 | 27 | TO DO: To specify the derivatives just like the [original caffe repo](https://github.com/ydwen/caffe-face), instead of being calculated by autograd system. 28 | 29 | # MNIST_center_loss_pytorch 30 | 31 | A pytorch implementation of center loss on MNIST and it's a toy example of ECCV2016 paper [A Discriminative Feature Learning Approach for Deep Face Recognition](https://github.com/ydwen/caffe-face) 32 | 33 | In order to ease the classifiers, center loss was designed to make samples in each class flock together. 34 | 35 | Results are shown below: 36 | 37 |