├── .gitignore ├── LICENSE ├── README.md ├── center_loss.py ├── datasets.py ├── gifs ├── center_test.gif ├── center_train.gif ├── softmax_test.gif └── softmax_train.gif ├── main.py ├── models.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | data/ 3 | log/ 4 | 5 | # OS files 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | .static_storage/ 66 | .media/ 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 
105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Kaiyang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-center-loss 2 | Pytorch implementation of center loss: [Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.](https://ydwen.github.io/papers/WenECCV16.pdf) 3 | 4 | This loss function is also used by [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid). 
5 | 6 | ## Get started 7 | Clone this repo and run the code 8 | ```bash 9 | $ git clone https://github.com/KaiyangZhou/pytorch-center-loss 10 | $ cd pytorch-center-loss 11 | $ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot 12 | ``` 13 | You will see the following info in your terminal 14 | ```bash 15 | Currently using GPU: 0 16 | Creating dataset: mnist 17 | Creating model: cnn 18 | ==> Epoch 1/100 19 | Batch 50/469 Loss 2.332793 (2.557837) XentLoss 2.332744 (2.388296) CenterLoss 0.000048 (0.169540) 20 | Batch 100/469 Loss 2.354638 (2.463851) XentLoss 2.354637 (2.379078) CenterLoss 0.000001 (0.084773) 21 | Batch 150/469 Loss 2.361732 (2.434477) XentLoss 2.361732 (2.377962) CenterLoss 0.000000 (0.056515) 22 | Batch 200/469 Loss 2.336701 (2.417842) XentLoss 2.336700 (2.375455) CenterLoss 0.000001 (0.042386) 23 | Batch 250/469 Loss 2.404814 (2.407015) XentLoss 2.404813 (2.373106) CenterLoss 0.000001 (0.033909) 24 | Batch 300/469 Loss 2.338753 (2.398546) XentLoss 2.338752 (2.370288) CenterLoss 0.000001 (0.028258) 25 | Batch 350/469 Loss 2.367068 (2.390672) XentLoss 2.367059 (2.366450) CenterLoss 0.000009 (0.024221) 26 | Batch 400/469 Loss 2.344178 (2.384820) XentLoss 2.344142 (2.363620) CenterLoss 0.000036 (0.021199) 27 | Batch 450/469 Loss 2.329708 (2.379460) XentLoss 2.329661 (2.360611) CenterLoss 0.000047 (0.018848) 28 | ==> Test 29 | Accuracy (%): 10.32 Error rate (%): 89.68 30 | ... ... 
31 | ==> Epoch 30/100 32 | Batch 50/469 Loss 0.141117 (0.155986) XentLoss 0.084169 (0.091617) CenterLoss 0.056949 (0.064369) 33 | Batch 100/469 Loss 0.138201 (0.151291) XentLoss 0.089146 (0.092839) CenterLoss 0.049055 (0.058452) 34 | Batch 150/469 Loss 0.151055 (0.151985) XentLoss 0.090816 (0.092405) CenterLoss 0.060239 (0.059580) 35 | Batch 200/469 Loss 0.150803 (0.153333) XentLoss 0.092857 (0.092156) CenterLoss 0.057946 (0.061176) 36 | Batch 250/469 Loss 0.162954 (0.154971) XentLoss 0.094889 (0.092099) CenterLoss 0.068065 (0.062872) 37 | Batch 300/469 Loss 0.162895 (0.156038) XentLoss 0.093100 (0.092034) CenterLoss 0.069795 (0.064004) 38 | Batch 350/469 Loss 0.146187 (0.156491) XentLoss 0.082508 (0.091787) CenterLoss 0.063679 (0.064704) 39 | Batch 400/469 Loss 0.171533 (0.157390) XentLoss 0.092526 (0.091674) CenterLoss 0.079007 (0.065716) 40 | Batch 450/469 Loss 0.209196 (0.158371) XentLoss 0.098388 (0.091560) CenterLoss 0.110808 (0.066811) 41 | ==> Test 42 | Accuracy (%): 98.51 Error rate (%): 1.49 43 | ... ... 44 | ``` 45 | 46 | Please run `python main.py -h` for more details regarding input arguments. 47 | 48 | ## Results 49 | We visualize the feature learning process below. 50 | 51 | Softmax only. Left: training set. Right: test set. 52 |
53 | train 54 | test 55 |
56 | 57 | Softmax + center loss. Left: training set. Right: test set. 58 |
59 | train 60 | test 61 |
62 | 63 | ## How to use center loss in your own project 64 | 1. All you need is the `center_loss.py` file 65 | ```python 66 | from center_loss import CenterLoss 67 | ``` 68 | 2. Initialize center loss in the main function 69 | ```python 70 | center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=True) 71 | ``` 72 | 3. Construct an optimizer for center loss 73 | ```python 74 | optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5) 75 | ``` 76 | Alternatively, you can merge the optimizers of the model and the center loss into one: 77 | ```python 78 | params = list(model.parameters()) + list(center_loss.parameters()) 79 | optimizer = torch.optim.SGD(params, lr=0.1) # here lr is the overall learning rate 80 | ``` 81 | 82 | 4. Update the class centers just as you would update a PyTorch model 83 | ```python 84 | # features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dim) 85 | # labels (torch long tensor): 1D torch long tensor with shape (batch_size) 86 | # alpha (float): weight for center loss 87 | loss = center_loss(features, labels) * alpha + other_loss 88 | optimizer_centloss.zero_grad() 89 | loss.backward() 90 | # multiply by (1./alpha) to remove the effect of alpha on the center updates 91 | for param in center_loss.parameters(): 92 | param.grad.data *= (1./alpha) 93 | optimizer_centloss.step() 94 | ``` 95 | If you adopt the second way (i.e. one optimizer for both the model and the center loss), the update code looks like this 96 | ```python 97 | loss = center_loss(features, labels) * alpha + other_loss 98 | optimizer.zero_grad() 99 | loss.backward() 100 | for param in center_loss.parameters(): 101 | # lr_cent is the learning rate for the center loss, e.g.
lr_cent = 0.5 102 | param.grad.data *= (lr_cent / (alpha * lr)) 103 | optimizer.step() 104 | ``` -------------------------------------------------------------------------------- /center_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CenterLoss(nn.Module): 5 | """Center loss. 6 | 7 | Reference: 8 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 9 | 10 | Args: 11 | num_classes (int): number of classes. 12 | feat_dim (int): feature dimension. 13 | """ 14 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): 15 | super(CenterLoss, self).__init__() 16 | self.num_classes = num_classes 17 | self.feat_dim = feat_dim 18 | self.use_gpu = use_gpu 19 | 20 | if self.use_gpu: 21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 22 | else: 23 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 24 | 25 | def forward(self, x, labels): 26 | """ 27 | Args: 28 | x: feature matrix with shape (batch_size, feat_dim). 29 | labels: ground truth labels with shape (batch_size). 
30 | """ 31 | batch_size = x.size(0) 32 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 33 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 34 | distmat.addmm_(1, -2, x, self.centers.t()) 35 | 36 | classes = torch.arange(self.num_classes).long() 37 | if self.use_gpu: classes = classes.cuda() 38 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 39 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 40 | 41 | dist = distmat * mask.float() 42 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 43 | 44 | return loss 45 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import DataLoader 4 | 5 | import transforms 6 | 7 | class MNIST(object): 8 | def __init__(self, batch_size, use_gpu, num_workers): 9 | transform = transforms.Compose([ 10 | transforms.ToTensor(), 11 | transforms.Normalize((0.1307,), (0.3081,)) 12 | ]) 13 | 14 | pin_memory = True if use_gpu else False 15 | 16 | trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform) 17 | 18 | trainloader = torch.utils.data.DataLoader( 19 | trainset, batch_size=batch_size, shuffle=True, 20 | num_workers=num_workers, pin_memory=pin_memory, 21 | ) 22 | 23 | testset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform) 24 | 25 | testloader = torch.utils.data.DataLoader( 26 | testset, batch_size=batch_size, shuffle=False, 27 | num_workers=num_workers, pin_memory=pin_memory, 28 | ) 29 | 30 | self.trainloader = trainloader 31 | self.testloader = testloader 32 | self.num_classes = 10 33 | 34 | __factory = { 35 | 'mnist': MNIST, 36 | } 37 | 38 | def create(name, batch_size, use_gpu, num_workers): 39 | if name not in 
__factory.keys(): 40 | raise KeyError("Unknown dataset: {}".format(name)) 41 | return __factory[name](batch_size, use_gpu, num_workers) -------------------------------------------------------------------------------- /gifs/center_test.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/center_test.gif -------------------------------------------------------------------------------- /gifs/center_train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/center_train.gif -------------------------------------------------------------------------------- /gifs/softmax_test.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/softmax_test.gif -------------------------------------------------------------------------------- /gifs/softmax_train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/softmax_train.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import datetime 5 | import time 6 | import os.path as osp 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | from matplotlib import pyplot as plt 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.optim import lr_scheduler 15 | import torch.backends.cudnn as cudnn 16 | 17 | import datasets 18 | import models 19 | from utils import
AverageMeter, Logger 20 | from center_loss import CenterLoss 21 | 22 | parser = argparse.ArgumentParser("Center Loss Example") 23 | # dataset 24 | parser.add_argument('-d', '--dataset', type=str, default='mnist', choices=['mnist']) 25 | parser.add_argument('-j', '--workers', default=4, type=int, 26 | help="number of data loading workers (default: 4)") 27 | # optimization 28 | parser.add_argument('--batch-size', type=int, default=128) 29 | parser.add_argument('--lr-model', type=float, default=0.001, help="learning rate for model") 30 | parser.add_argument('--lr-cent', type=float, default=0.5, help="learning rate for center loss") 31 | parser.add_argument('--weight-cent', type=float, default=1, help="weight for center loss") 32 | parser.add_argument('--max-epoch', type=int, default=100) 33 | parser.add_argument('--stepsize', type=int, default=20) 34 | parser.add_argument('--gamma', type=float, default=0.5, help="learning rate decay") 35 | # model 36 | parser.add_argument('--model', type=str, default='cnn') 37 | # misc 38 | parser.add_argument('--eval-freq', type=int, default=10) 39 | parser.add_argument('--print-freq', type=int, default=50) 40 | parser.add_argument('--gpu', type=str, default='0') 41 | parser.add_argument('--seed', type=int, default=1) 42 | parser.add_argument('--use-cpu', action='store_true') 43 | parser.add_argument('--save-dir', type=str, default='log') 44 | parser.add_argument('--plot', action='store_true', help="whether to plot features for every epoch") 45 | 46 | args = parser.parse_args()  # NOTE(review): runs at import time, so importing main.py also parses sys.argv 47 | 48 | def main(): 49 | torch.manual_seed(args.seed) 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 51 | use_gpu = torch.cuda.is_available() 52 | if args.use_cpu: use_gpu = False  # --use-cpu overrides CUDA availability 53 | 54 | sys.stdout = Logger(osp.join(args.save_dir, 'log_' + args.dataset + '.txt'))  # from here on, prints go to the console and to <save_dir>/log_<dataset>.txt 55 | 56 | if use_gpu: 57 | print("Currently using GPU: {}".format(args.gpu)) 58 | cudnn.benchmark = True 59 | torch.cuda.manual_seed_all(args.seed) 60 | else: 61 | print("Currently using
CPU") 62 | 63 | print("Creating dataset: {}".format(args.dataset)) 64 | dataset = datasets.create( 65 | name=args.dataset, batch_size=args.batch_size, use_gpu=use_gpu, 66 | num_workers=args.workers, 67 | ) 68 | 69 | trainloader, testloader = dataset.trainloader, dataset.testloader 70 | 71 | print("Creating model: {}".format(args.model)) 72 | model = models.create(name=args.model, num_classes=dataset.num_classes) 73 | 74 | if use_gpu: 75 | model = nn.DataParallel(model).cuda() 76 | 77 | criterion_xent = nn.CrossEntropyLoss() 78 | criterion_cent = CenterLoss(num_classes=dataset.num_classes, feat_dim=2, use_gpu=use_gpu) 79 | optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=5e-04, momentum=0.9) 80 | optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)  # the centers get their own optimizer so they can use --lr-cent 81 | 82 | if args.stepsize > 0: 83 | scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize, gamma=args.gamma) 84 | 85 | start_time = time.time() 86 | 87 | for epoch in range(args.max_epoch): 88 | print("==> Epoch {}/{}".format(epoch+1, args.max_epoch)) 89 | train(model, criterion_xent, criterion_cent, 90 | optimizer_model, optimizer_centloss, 91 | trainloader, use_gpu, dataset.num_classes, epoch) 92 | 93 | if args.stepsize > 0: scheduler.step() 94 | 95 | if args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0 or (epoch+1) == args.max_epoch:  # parsed as (eval_freq > 0 and epoch+1 is a multiple of eval_freq) or (this is the last epoch) 96 | print("==> Test") 97 | acc, err = test(model, testloader, use_gpu, dataset.num_classes, epoch) 98 | print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err)) 99 | 100 | elapsed = round(time.time() - start_time) 101 | elapsed = str(datetime.timedelta(seconds=elapsed)) 102 | print("Finished.
Total elapsed time (h:m:s): {}".format(elapsed)) 103 | 104 | def train(model, criterion_xent, criterion_cent, 105 | optimizer_model, optimizer_centloss, 106 | trainloader, use_gpu, num_classes, epoch): 107 | model.train() 108 | xent_losses = AverageMeter() 109 | cent_losses = AverageMeter() 110 | losses = AverageMeter() 111 | 112 | if args.plot: 113 | all_features, all_labels = [], [] 114 | 115 | for batch_idx, (data, labels) in enumerate(trainloader): 116 | if use_gpu: 117 | data, labels = data.cuda(), labels.cuda() 118 | features, outputs = model(data) 119 | loss_xent = criterion_xent(outputs, labels) 120 | loss_cent = criterion_cent(features, labels) 121 | loss_cent *= args.weight_cent 122 | loss = loss_xent + loss_cent 123 | optimizer_model.zero_grad() 124 | optimizer_centloss.zero_grad() 125 | loss.backward() 126 | optimizer_model.step() 127 | # by doing so, weight_cent would not impact on the learning of centers 128 | for param in criterion_cent.parameters(): 129 | param.grad.data *= (1. 
/ args.weight_cent) 130 | optimizer_centloss.step() 131 | 132 | losses.update(loss.item(), labels.size(0)) 133 | xent_losses.update(loss_xent.item(), labels.size(0)) 134 | cent_losses.update(loss_cent.item(), labels.size(0)) 135 | 136 | if args.plot: 137 | if use_gpu: 138 | all_features.append(features.data.cpu().numpy()) 139 | all_labels.append(labels.data.cpu().numpy()) 140 | else: 141 | all_features.append(features.data.numpy()) 142 | all_labels.append(labels.data.numpy()) 143 | 144 | if (batch_idx+1) % args.print_freq == 0: 145 | print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) CenterLoss {:.6f} ({:.6f})" \ 146 | .format(batch_idx+1, len(trainloader), losses.val, losses.avg, xent_losses.val, xent_losses.avg, cent_losses.val, cent_losses.avg)) 147 | 148 | if args.plot: 149 | all_features = np.concatenate(all_features, 0) 150 | all_labels = np.concatenate(all_labels, 0) 151 | plot_features(all_features, all_labels, num_classes, epoch, prefix='train') 152 | 153 | def test(model, testloader, use_gpu, num_classes, epoch): 154 | model.eval() 155 | correct, total = 0, 0 156 | if args.plot: 157 | all_features, all_labels = [], [] 158 | 159 | with torch.no_grad(): 160 | for data, labels in testloader: 161 | if use_gpu: 162 | data, labels = data.cuda(), labels.cuda() 163 | features, outputs = model(data) 164 | predictions = outputs.data.max(1)[1] 165 | total += labels.size(0) 166 | correct += (predictions == labels.data).sum() 167 | 168 | if args.plot: 169 | if use_gpu: 170 | all_features.append(features.data.cpu().numpy()) 171 | all_labels.append(labels.data.cpu().numpy()) 172 | else: 173 | all_features.append(features.data.numpy()) 174 | all_labels.append(labels.data.numpy()) 175 | 176 | if args.plot: 177 | all_features = np.concatenate(all_features, 0) 178 | all_labels = np.concatenate(all_labels, 0) 179 | plot_features(all_features, all_labels, num_classes, epoch, prefix='test') 180 | 181 | acc = correct * 100. / total 182 | err = 100. 
- acc 183 | return acc, err 184 | 185 | def plot_features(features, labels, num_classes, epoch, prefix): 186 | """Plot features on 2D plane. 187 | 188 | Args: 189 | features: (num_instances, num_features). 190 | labels: (num_instances). 191 | """ 192 | colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9'] 193 | for label_idx in range(num_classes): 194 | plt.scatter( 195 | features[labels==label_idx, 0], 196 | features[labels==label_idx, 1], 197 | c=colors[label_idx], 198 | s=1, 199 | ) 200 | plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right') 201 | dirname = osp.join(args.save_dir, prefix) 202 | if not osp.exists(dirname): 203 | os.mkdir(dirname) 204 | save_name = osp.join(dirname, 'epoch_' + str(epoch+1) + '.png') 205 | plt.savefig(save_name, bbox_inches='tight') 206 | plt.close() 207 | 208 | if __name__ == '__main__': 209 | main() 210 | 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | import math 6 | 7 | class ConvNet(nn.Module): 8 | """LeNet++ as described in the Center Loss paper.""" 9 | def __init__(self, num_classes): 10 | super(ConvNet, self).__init__() 11 | self.conv1_1 = nn.Conv2d(1, 32, 5, stride=1, padding=2) 12 | self.prelu1_1 = nn.PReLU() 13 | self.conv1_2 = nn.Conv2d(32, 32, 5, stride=1, padding=2) 14 | self.prelu1_2 = nn.PReLU() 15 | 16 | self.conv2_1 = nn.Conv2d(32, 64, 5, stride=1, padding=2) 17 | self.prelu2_1 = nn.PReLU() 18 | self.conv2_2 = nn.Conv2d(64, 64, 5, stride=1, padding=2) 19 | self.prelu2_2 = nn.PReLU() 20 | 21 | self.conv3_1 = nn.Conv2d(64, 128, 5, stride=1, padding=2) 22 | self.prelu3_1 = nn.PReLU() 23 | self.conv3_2 = nn.Conv2d(128, 128, 5, stride=1, padding=2) 24 | self.prelu3_2 = nn.PReLU() 25 | 26 | self.fc1 = nn.Linear(128*3*3, 2) 27 | 
self.prelu_fc1 = nn.PReLU() 28 | self.fc2 = nn.Linear(2, num_classes) 29 | 30 | def forward(self, x): 31 | x = self.prelu1_1(self.conv1_1(x)) 32 | x = self.prelu1_2(self.conv1_2(x)) 33 | x = F.max_pool2d(x, 2) 34 | 35 | x = self.prelu2_1(self.conv2_1(x)) 36 | x = self.prelu2_2(self.conv2_2(x)) 37 | x = F.max_pool2d(x, 2) 38 | 39 | x = self.prelu3_1(self.conv3_1(x)) 40 | x = self.prelu3_2(self.conv3_2(x)) 41 | x = F.max_pool2d(x, 2) 42 | 43 | x = x.view(-1, 128*3*3) 44 | x = self.prelu_fc1(self.fc1(x)) 45 | y = self.fc2(x) 46 | 47 | return x, y 48 | 49 | __factory = { 50 | 'cnn': ConvNet, 51 | } 52 | 53 | def create(name, num_classes): 54 | if name not in __factory.keys(): 55 | raise KeyError("Unknown model: {}".format(name)) 56 | return __factory[name](num_classes) 57 | 58 | if __name__ == '__main__': 59 | pass -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from PIL import Image 3 | 4 | class ToGray(object): 5 | """ 6 | Convert image from RGB to gray level. 7 | """ 8 | def __call__(self, img): 9 | return img.convert('L') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import errno 4 | import shutil 5 | import os.path as osp 6 | 7 | import torch 8 | 9 | def mkdir_if_missing(directory): 10 | if not osp.exists(directory): 11 | try: 12 | os.makedirs(directory) 13 | except OSError as e: 14 | if e.errno != errno.EEXIST: 15 | raise 16 | 17 | class AverageMeter(object): 18 | """Computes and stores the average and current value. 
19 | 20 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 21 | """ 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 38 | mkdir_if_missing(osp.dirname(fpath)) 39 | torch.save(state, fpath) 40 | if is_best: 41 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 42 | 43 | class Logger(object): 44 | """ 45 | Write console output to external text file. 46 | 47 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 48 | """ 49 | def __init__(self, fpath=None): 50 | self.console = sys.stdout 51 | self.file = None 52 | if fpath is not None: 53 | mkdir_if_missing(os.path.dirname(fpath)) 54 | self.file = open(fpath, 'w') 55 | 56 | def __del__(self): 57 | self.close() 58 | 59 | def __enter__(self): 60 | pass 61 | 62 | def __exit__(self, *args): 63 | self.close() 64 | 65 | def write(self, msg): 66 | self.console.write(msg) 67 | if self.file is not None: 68 | self.file.write(msg) 69 | 70 | def flush(self): 71 | self.console.flush() 72 | if self.file is not None: 73 | self.file.flush() 74 | os.fsync(self.file.fileno()) 75 | 76 | def close(self): 77 | self.console.close() 78 | if self.file is not None: 79 | self.file.close() --------------------------------------------------------------------------------