├── .gitignore
├── LICENSE
├── README.md
├── center_loss.py
├── datasets.py
├── gifs
│   ├── center_test.gif
│   ├── center_train.gif
│   ├── softmax_test.gif
│   └── softmax_train.gif
├── main.py
├── models.py
├── transforms.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | data/
3 | log/
4 |
5 | # OS files
6 | .DS_Store
7 | .AppleDouble
8 | .LSOverride
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | .static_storage/
66 | .media/
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Kaiyang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-center-loss
2 | PyTorch implementation of center loss: [Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.](https://ydwen.github.io/papers/WenECCV16.pdf)
3 |
4 | This loss function is also used by [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid).
5 |
6 | ## Get started
7 | Clone this repo and run the code:
8 | ```bash
9 | $ git clone https://github.com/KaiyangZhou/pytorch-center-loss
10 | $ cd pytorch-center-loss
11 | $ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot
12 | ```
13 | You will see the following output in your terminal:
14 | ```bash
15 | Currently using GPU: 0
16 | Creating dataset: mnist
17 | Creating model: cnn
18 | ==> Epoch 1/100
19 | Batch 50/469 Loss 2.332793 (2.557837) XentLoss 2.332744 (2.388296) CenterLoss 0.000048 (0.169540)
20 | Batch 100/469 Loss 2.354638 (2.463851) XentLoss 2.354637 (2.379078) CenterLoss 0.000001 (0.084773)
21 | Batch 150/469 Loss 2.361732 (2.434477) XentLoss 2.361732 (2.377962) CenterLoss 0.000000 (0.056515)
22 | Batch 200/469 Loss 2.336701 (2.417842) XentLoss 2.336700 (2.375455) CenterLoss 0.000001 (0.042386)
23 | Batch 250/469 Loss 2.404814 (2.407015) XentLoss 2.404813 (2.373106) CenterLoss 0.000001 (0.033909)
24 | Batch 300/469 Loss 2.338753 (2.398546) XentLoss 2.338752 (2.370288) CenterLoss 0.000001 (0.028258)
25 | Batch 350/469 Loss 2.367068 (2.390672) XentLoss 2.367059 (2.366450) CenterLoss 0.000009 (0.024221)
26 | Batch 400/469 Loss 2.344178 (2.384820) XentLoss 2.344142 (2.363620) CenterLoss 0.000036 (0.021199)
27 | Batch 450/469 Loss 2.329708 (2.379460) XentLoss 2.329661 (2.360611) CenterLoss 0.000047 (0.018848)
28 | ==> Test
29 | Accuracy (%): 10.32 Error rate (%): 89.68
30 | ... ...
31 | ==> Epoch 30/100
32 | Batch 50/469 Loss 0.141117 (0.155986) XentLoss 0.084169 (0.091617) CenterLoss 0.056949 (0.064369)
33 | Batch 100/469 Loss 0.138201 (0.151291) XentLoss 0.089146 (0.092839) CenterLoss 0.049055 (0.058452)
34 | Batch 150/469 Loss 0.151055 (0.151985) XentLoss 0.090816 (0.092405) CenterLoss 0.060239 (0.059580)
35 | Batch 200/469 Loss 0.150803 (0.153333) XentLoss 0.092857 (0.092156) CenterLoss 0.057946 (0.061176)
36 | Batch 250/469 Loss 0.162954 (0.154971) XentLoss 0.094889 (0.092099) CenterLoss 0.068065 (0.062872)
37 | Batch 300/469 Loss 0.162895 (0.156038) XentLoss 0.093100 (0.092034) CenterLoss 0.069795 (0.064004)
38 | Batch 350/469 Loss 0.146187 (0.156491) XentLoss 0.082508 (0.091787) CenterLoss 0.063679 (0.064704)
39 | Batch 400/469 Loss 0.171533 (0.157390) XentLoss 0.092526 (0.091674) CenterLoss 0.079007 (0.065716)
40 | Batch 450/469 Loss 0.209196 (0.158371) XentLoss 0.098388 (0.091560) CenterLoss 0.110808 (0.066811)
41 | ==> Test
42 | Accuracy (%): 98.51 Error rate (%): 1.49
43 | ... ...
44 | ```
45 |
46 | Please run `python main.py -h` for more details regarding input arguments.
47 |
48 | ## Results
49 | We visualize the feature learning process below.
50 |
51 | Softmax only. Left: training set. Right: test set.
52 |
53 | ![softmax training set](gifs/softmax_train.gif)
54 | ![softmax test set](gifs/softmax_test.gif)
55 |
56 |
57 | Softmax + center loss. Left: training set. Right: test set.
58 |
59 | ![center loss training set](gifs/center_train.gif)
60 | ![center loss test set](gifs/center_test.gif)
61 |
62 |
63 | ## How to use center loss in your own project
64 | 1. All you need is the `center_loss.py` file
65 | ```python
66 | from center_loss import CenterLoss
67 | ```
68 | 2. Initialize center loss in the main function
69 | ```python
70 | center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=True)
71 | ```
72 | 3. Construct an optimizer for center loss
73 | ```python
74 | optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)
75 | ```
76 | Alternatively, you can merge the parameters of the model and the center loss into a single optimizer:
77 | ```python
78 | params = list(model.parameters()) + list(center_loss.parameters())
79 | optimizer = torch.optim.SGD(params, lr=0.1) # here lr is the overall learning rate
80 | ```
81 |
82 | 4. Update the class centers just as you would update a PyTorch model
83 | ```python
84 | # features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dim)
85 | # labels (torch long tensor): 1D torch long tensor with shape (batch_size)
86 | # alpha (float): weight for center loss
87 | loss = center_loss(features, labels) * alpha + other_loss
88 | optimizer_centloss.zero_grad()
89 | loss.backward()
90 | # multiply by (1./alpha) to cancel the effect of alpha on the center updates
91 | for param in center_loss.parameters():
92 | param.grad.data *= (1./alpha)
93 | optimizer_centloss.step()
94 | ```
95 | If you adopt the second way (i.e. a single optimizer for both the model and the center loss), the update code would look like the following. Backprop scales the center gradients by `alpha` and the shared optimizer applies `lr`, so multiplying the gradients by `lr_cent / (alpha * lr)` gives the centers an effective learning rate of `lr_cent`:
96 | ```python
97 | loss = center_loss(features, labels) * alpha + other_loss
98 | optimizer.zero_grad()
99 | loss.backward()
100 | for param in center_loss.parameters():
101 | # lr_cent is learning rate for center loss, e.g. lr_cent = 0.5
102 | param.grad.data *= (lr_cent / (alpha * lr))
103 | optimizer.step()
104 | ```
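105 |
106 | Putting the pieces together, here is a minimal, self-contained sketch of one training step (not the exact code in `main.py`): it runs on the CPU with random data and a hypothetical two-layer backbone standing in for the ConvNet from `models.py`, which returns 2-D features.
107 | ```python
108 | import torch
109 | import torch.nn as nn
110 | from center_loss import CenterLoss
111 |
112 | # hypothetical backbone producing 2-D features, plus a logits head
113 | backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 2))
114 | classifier = nn.Linear(2, 10)
115 |
116 | center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=False)
117 | xent = nn.CrossEntropyLoss()
118 | alpha = 1.0
119 |
120 | params = list(backbone.parameters()) + list(classifier.parameters())
121 | optimizer_model = torch.optim.SGD(params, lr=0.001, momentum=0.9)
122 | optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)
123 |
124 | # one training step on random data
125 | data = torch.randn(128, 1, 28, 28)
126 | labels = torch.randint(0, 10, (128,))
127 |
128 | features = backbone(data)
129 | logits = classifier(features)
130 | loss = xent(logits, labels) + center_loss(features, labels) * alpha
131 |
132 | optimizer_model.zero_grad()
133 | optimizer_centloss.zero_grad()
134 | loss.backward()
135 | optimizer_model.step()
136 | for param in center_loss.parameters():
137 |     param.grad.data *= (1. / alpha)  # undo alpha so centers move at lr_cent
138 | optimizer_centloss.step()
139 | ```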
--------------------------------------------------------------------------------
/center_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class CenterLoss(nn.Module):
5 | """Center loss.
6 |
7 | Reference:
8 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
9 |
10 | Args:
11 | num_classes (int): number of classes.
12 | feat_dim (int): feature dimension.
13 | """
14 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
15 | super(CenterLoss, self).__init__()
16 | self.num_classes = num_classes
17 | self.feat_dim = feat_dim
18 | self.use_gpu = use_gpu
19 |
20 | if self.use_gpu:
21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
22 | else:
23 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
24 |
25 | def forward(self, x, labels):
26 | """
27 | Args:
28 | x: feature matrix with shape (batch_size, feat_dim).
29 | labels: ground truth labels with shape (batch_size).
30 | """
31 | batch_size = x.size(0)
32 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
33 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
34 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # distmat = ||x||^2 + ||c||^2 - 2 x c^T
35 |
36 | classes = torch.arange(self.num_classes).long()
37 | if self.use_gpu: classes = classes.cuda()
38 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
39 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
40 |
41 | dist = distmat * mask.float()
42 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
43 |
44 | return loss
45 |
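46 | if __name__ == '__main__':
47 |     # Added sanity check (not part of the original file): the vectorized
48 |     # distance matrix used above should agree with a naive per-sample loop.
49 |     torch.manual_seed(0)
50 |     criterion = CenterLoss(num_classes=10, feat_dim=2, use_gpu=False)
51 |     x = torch.randn(16, 2)
52 |     labels = torch.randint(0, 10, (16,))
53 |     naive = sum(((x[i] - criterion.centers[labels[i]]) ** 2).sum()
54 |                 for i in range(16)) / 16
55 |     print(criterion(x, labels).item(), naive.item())  # values should match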
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch.utils.data import DataLoader
4 |
5 | import transforms
6 |
7 | class MNIST(object):
8 | def __init__(self, batch_size, use_gpu, num_workers):
9 | transform = transforms.Compose([
10 | transforms.ToTensor(),
11 | transforms.Normalize((0.1307,), (0.3081,))
12 | ])
13 |
14 |         pin_memory = use_gpu
15 |
16 | trainset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=transform)
17 |
18 | trainloader = torch.utils.data.DataLoader(
19 | trainset, batch_size=batch_size, shuffle=True,
20 | num_workers=num_workers, pin_memory=pin_memory,
21 | )
22 |
23 | testset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True, transform=transform)
24 |
25 | testloader = torch.utils.data.DataLoader(
26 | testset, batch_size=batch_size, shuffle=False,
27 | num_workers=num_workers, pin_memory=pin_memory,
28 | )
29 |
30 | self.trainloader = trainloader
31 | self.testloader = testloader
32 | self.num_classes = 10
33 |
34 | __factory = {
35 | 'mnist': MNIST,
36 | }
37 |
38 | def create(name, batch_size, use_gpu, num_workers):
39 | if name not in __factory.keys():
40 | raise KeyError("Unknown dataset: {}".format(name))
41 | return __factory[name](batch_size, use_gpu, num_workers)
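42 |
43 | if __name__ == '__main__':
44 |     # Added usage sketch (not part of the original file): build the MNIST
45 |     # loaders on the CPU and inspect one batch (downloads MNIST on first run).
46 |     dataset = create('mnist', batch_size=128, use_gpu=False, num_workers=0)
47 |     images, labels = next(iter(dataset.trainloader))
48 |     print(images.shape, labels.shape)  # torch.Size([128, 1, 28, 28]) torch.Size([128])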
--------------------------------------------------------------------------------
/gifs/center_test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/center_test.gif
--------------------------------------------------------------------------------
/gifs/center_train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/center_train.gif
--------------------------------------------------------------------------------
/gifs/softmax_test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/softmax_test.gif
--------------------------------------------------------------------------------
/gifs/softmax_train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaiyangZhou/pytorch-center-loss/082ffa21c065426843f26129be51bb1cfd554806/gifs/softmax_train.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import datetime
5 | import time
6 | import os.path as osp
7 | import matplotlib
8 | matplotlib.use('Agg')
9 | from matplotlib import pyplot as plt
10 | import numpy as np
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.optim import lr_scheduler
15 | import torch.backends.cudnn as cudnn
16 |
17 | import datasets
18 | import models
19 | from utils import AverageMeter, Logger
20 | from center_loss import CenterLoss
21 |
22 | parser = argparse.ArgumentParser("Center Loss Example")
23 | # dataset
24 | parser.add_argument('-d', '--dataset', type=str, default='mnist', choices=['mnist'])
25 | parser.add_argument('-j', '--workers', default=4, type=int,
26 | help="number of data loading workers (default: 4)")
27 | # optimization
28 | parser.add_argument('--batch-size', type=int, default=128)
29 | parser.add_argument('--lr-model', type=float, default=0.001, help="learning rate for model")
30 | parser.add_argument('--lr-cent', type=float, default=0.5, help="learning rate for center loss")
31 | parser.add_argument('--weight-cent', type=float, default=1, help="weight for center loss")
32 | parser.add_argument('--max-epoch', type=int, default=100)
33 | parser.add_argument('--stepsize', type=int, default=20)
34 | parser.add_argument('--gamma', type=float, default=0.5, help="learning rate decay")
35 | # model
36 | parser.add_argument('--model', type=str, default='cnn')
37 | # misc
38 | parser.add_argument('--eval-freq', type=int, default=10)
39 | parser.add_argument('--print-freq', type=int, default=50)
40 | parser.add_argument('--gpu', type=str, default='0')
41 | parser.add_argument('--seed', type=int, default=1)
42 | parser.add_argument('--use-cpu', action='store_true')
43 | parser.add_argument('--save-dir', type=str, default='log')
44 | parser.add_argument('--plot', action='store_true', help="whether to plot features for every epoch")
45 |
46 | args = parser.parse_args()
47 |
48 | def main():
49 | torch.manual_seed(args.seed)
50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
51 | use_gpu = torch.cuda.is_available()
52 | if args.use_cpu: use_gpu = False
53 |
54 | sys.stdout = Logger(osp.join(args.save_dir, 'log_' + args.dataset + '.txt'))
55 |
56 | if use_gpu:
57 | print("Currently using GPU: {}".format(args.gpu))
58 | cudnn.benchmark = True
59 | torch.cuda.manual_seed_all(args.seed)
60 | else:
61 | print("Currently using CPU")
62 |
63 | print("Creating dataset: {}".format(args.dataset))
64 | dataset = datasets.create(
65 | name=args.dataset, batch_size=args.batch_size, use_gpu=use_gpu,
66 | num_workers=args.workers,
67 | )
68 |
69 | trainloader, testloader = dataset.trainloader, dataset.testloader
70 |
71 | print("Creating model: {}".format(args.model))
72 | model = models.create(name=args.model, num_classes=dataset.num_classes)
73 |
74 | if use_gpu:
75 | model = nn.DataParallel(model).cuda()
76 |
77 | criterion_xent = nn.CrossEntropyLoss()
78 | criterion_cent = CenterLoss(num_classes=dataset.num_classes, feat_dim=2, use_gpu=use_gpu)
79 | optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=5e-04, momentum=0.9)
80 | optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)
81 |
82 | if args.stepsize > 0:
83 | scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize, gamma=args.gamma)
84 |
85 | start_time = time.time()
86 |
87 | for epoch in range(args.max_epoch):
88 | print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
89 | train(model, criterion_xent, criterion_cent,
90 | optimizer_model, optimizer_centloss,
91 | trainloader, use_gpu, dataset.num_classes, epoch)
92 |
93 | if args.stepsize > 0: scheduler.step()
94 |
95 |         if (args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0) or (epoch+1) == args.max_epoch:
96 | print("==> Test")
97 | acc, err = test(model, testloader, use_gpu, dataset.num_classes, epoch)
98 | print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))
99 |
100 | elapsed = round(time.time() - start_time)
101 | elapsed = str(datetime.timedelta(seconds=elapsed))
102 | print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
103 |
104 | def train(model, criterion_xent, criterion_cent,
105 | optimizer_model, optimizer_centloss,
106 | trainloader, use_gpu, num_classes, epoch):
107 | model.train()
108 | xent_losses = AverageMeter()
109 | cent_losses = AverageMeter()
110 | losses = AverageMeter()
111 |
112 | if args.plot:
113 | all_features, all_labels = [], []
114 |
115 | for batch_idx, (data, labels) in enumerate(trainloader):
116 | if use_gpu:
117 | data, labels = data.cuda(), labels.cuda()
118 | features, outputs = model(data)
119 | loss_xent = criterion_xent(outputs, labels)
120 | loss_cent = criterion_cent(features, labels)
121 | loss_cent *= args.weight_cent
122 | loss = loss_xent + loss_cent
123 | optimizer_model.zero_grad()
124 | optimizer_centloss.zero_grad()
125 | loss.backward()
126 | optimizer_model.step()
127 |         # rescale the gradients so that weight_cent does not change the effective learning rate of the centers
128 | for param in criterion_cent.parameters():
129 | param.grad.data *= (1. / args.weight_cent)
130 | optimizer_centloss.step()
131 |
132 | losses.update(loss.item(), labels.size(0))
133 | xent_losses.update(loss_xent.item(), labels.size(0))
134 | cent_losses.update(loss_cent.item(), labels.size(0))
135 |
136 | if args.plot:
137 | if use_gpu:
138 | all_features.append(features.data.cpu().numpy())
139 | all_labels.append(labels.data.cpu().numpy())
140 | else:
141 | all_features.append(features.data.numpy())
142 | all_labels.append(labels.data.numpy())
143 |
144 | if (batch_idx+1) % args.print_freq == 0:
145 | print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) CenterLoss {:.6f} ({:.6f})" \
146 | .format(batch_idx+1, len(trainloader), losses.val, losses.avg, xent_losses.val, xent_losses.avg, cent_losses.val, cent_losses.avg))
147 |
148 | if args.plot:
149 | all_features = np.concatenate(all_features, 0)
150 | all_labels = np.concatenate(all_labels, 0)
151 | plot_features(all_features, all_labels, num_classes, epoch, prefix='train')
152 |
153 | def test(model, testloader, use_gpu, num_classes, epoch):
154 | model.eval()
155 | correct, total = 0, 0
156 | if args.plot:
157 | all_features, all_labels = [], []
158 |
159 | with torch.no_grad():
160 | for data, labels in testloader:
161 | if use_gpu:
162 | data, labels = data.cuda(), labels.cuda()
163 | features, outputs = model(data)
164 | predictions = outputs.data.max(1)[1]
165 | total += labels.size(0)
166 | correct += (predictions == labels.data).sum()
167 |
168 | if args.plot:
169 | if use_gpu:
170 | all_features.append(features.data.cpu().numpy())
171 | all_labels.append(labels.data.cpu().numpy())
172 | else:
173 | all_features.append(features.data.numpy())
174 | all_labels.append(labels.data.numpy())
175 |
176 | if args.plot:
177 | all_features = np.concatenate(all_features, 0)
178 | all_labels = np.concatenate(all_labels, 0)
179 | plot_features(all_features, all_labels, num_classes, epoch, prefix='test')
180 |
181 |     acc = correct.item() * 100. / total
182 | err = 100. - acc
183 | return acc, err
184 |
185 | def plot_features(features, labels, num_classes, epoch, prefix):
186 | """Plot features on 2D plane.
187 |
188 | Args:
189 | features: (num_instances, num_features).
190 | labels: (num_instances).
191 | """
192 | colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
193 | for label_idx in range(num_classes):
194 | plt.scatter(
195 | features[labels==label_idx, 0],
196 | features[labels==label_idx, 1],
197 | c=colors[label_idx],
198 | s=1,
199 | )
200 | plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
201 | dirname = osp.join(args.save_dir, prefix)
202 | if not osp.exists(dirname):
203 | os.mkdir(dirname)
204 | save_name = osp.join(dirname, 'epoch_' + str(epoch+1) + '.png')
205 | plt.savefig(save_name, bbox_inches='tight')
206 | plt.close()
207 |
208 | if __name__ == '__main__':
209 | main()
210 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | import math
6 |
7 | class ConvNet(nn.Module):
8 | """LeNet++ as described in the Center Loss paper."""
9 | def __init__(self, num_classes):
10 | super(ConvNet, self).__init__()
11 | self.conv1_1 = nn.Conv2d(1, 32, 5, stride=1, padding=2)
12 | self.prelu1_1 = nn.PReLU()
13 | self.conv1_2 = nn.Conv2d(32, 32, 5, stride=1, padding=2)
14 | self.prelu1_2 = nn.PReLU()
15 |
16 | self.conv2_1 = nn.Conv2d(32, 64, 5, stride=1, padding=2)
17 | self.prelu2_1 = nn.PReLU()
18 | self.conv2_2 = nn.Conv2d(64, 64, 5, stride=1, padding=2)
19 | self.prelu2_2 = nn.PReLU()
20 |
21 | self.conv3_1 = nn.Conv2d(64, 128, 5, stride=1, padding=2)
22 | self.prelu3_1 = nn.PReLU()
23 | self.conv3_2 = nn.Conv2d(128, 128, 5, stride=1, padding=2)
24 | self.prelu3_2 = nn.PReLU()
25 |
26 | self.fc1 = nn.Linear(128*3*3, 2)
27 | self.prelu_fc1 = nn.PReLU()
28 | self.fc2 = nn.Linear(2, num_classes)
29 |
30 | def forward(self, x):
31 | x = self.prelu1_1(self.conv1_1(x))
32 | x = self.prelu1_2(self.conv1_2(x))
33 | x = F.max_pool2d(x, 2)
34 |
35 | x = self.prelu2_1(self.conv2_1(x))
36 | x = self.prelu2_2(self.conv2_2(x))
37 | x = F.max_pool2d(x, 2)
38 |
39 | x = self.prelu3_1(self.conv3_1(x))
40 | x = self.prelu3_2(self.conv3_2(x))
41 | x = F.max_pool2d(x, 2)
42 |
43 | x = x.view(-1, 128*3*3)
44 | x = self.prelu_fc1(self.fc1(x))
45 | y = self.fc2(x)
46 |
47 | return x, y
48 |
49 | __factory = {
50 | 'cnn': ConvNet,
51 | }
52 |
53 | def create(name, num_classes):
54 | if name not in __factory.keys():
55 | raise KeyError("Unknown model: {}".format(name))
56 | return __factory[name](num_classes)
57 |
58 | if __name__ == '__main__':
59 | pass
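60 |     # Added smoke test (not part of the original file): the network maps an
61 |     # MNIST-sized batch to 2-D features and per-class logits.
62 |     net = create('cnn', num_classes=10)
63 |     feats, logits = net(torch.randn(4, 1, 28, 28))
64 |     print(feats.shape, logits.shape)  # torch.Size([4, 2]) torch.Size([4, 10])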
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import *
2 | from PIL import Image
3 |
4 | class ToGray(object):
5 | """
6 | Convert image from RGB to gray level.
7 | """
8 | def __call__(self, img):
9 | return img.convert('L')
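10 |
11 | if __name__ == '__main__':
12 |     # Added usage sketch (not part of the original file): ToGray converts a
13 |     # PIL image to single-channel ('L') mode and composes with the
14 |     # torchvision transforms re-exported above.
15 |     img = Image.new('RGB', (28, 28))
16 |     pipeline = Compose([ToGray(), ToTensor()])
17 |     print(pipeline(img).shape)  # torch.Size([1, 28, 28])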
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import errno
4 | import shutil
5 | import os.path as osp
6 |
7 | import torch
8 |
9 | def mkdir_if_missing(directory):
10 | if not osp.exists(directory):
11 | try:
12 | os.makedirs(directory)
13 | except OSError as e:
14 | if e.errno != errno.EEXIST:
15 | raise
16 |
17 | class AverageMeter(object):
18 | """Computes and stores the average and current value.
19 |
20 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
21 | """
22 | def __init__(self):
23 | self.reset()
24 |
25 | def reset(self):
26 | self.val = 0
27 | self.avg = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def update(self, val, n=1):
32 | self.val = val
33 | self.sum += val * n
34 | self.count += n
35 | self.avg = self.sum / self.count
36 |
37 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
38 | mkdir_if_missing(osp.dirname(fpath))
39 | torch.save(state, fpath)
40 | if is_best:
41 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
42 |
43 | class Logger(object):
44 | """
45 | Write console output to external text file.
46 |
47 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
48 | """
49 | def __init__(self, fpath=None):
50 | self.console = sys.stdout
51 | self.file = None
52 | if fpath is not None:
53 | mkdir_if_missing(os.path.dirname(fpath))
54 | self.file = open(fpath, 'w')
55 |
56 | def __del__(self):
57 | self.close()
58 |
59 | def __enter__(self):
60 |         return self
61 |
62 | def __exit__(self, *args):
63 | self.close()
64 |
65 | def write(self, msg):
66 | self.console.write(msg)
67 | if self.file is not None:
68 | self.file.write(msg)
69 |
70 | def flush(self):
71 | self.console.flush()
72 | if self.file is not None:
73 | self.file.flush()
74 | os.fsync(self.file.fileno())
75 |
76 | def close(self):
77 | self.console.close()
78 | if self.file is not None:
79 | self.file.close()
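80 |
81 | if __name__ == '__main__':
82 |     # Added usage sketch (not part of the original file): AverageMeter keeps
83 |     # a running average weighted by the number of samples, as in main.py.
84 |     meter = AverageMeter()
85 |     meter.update(2.0, n=4)  # e.g. mean loss 2.0 over a batch of 4
86 |     meter.update(1.0, n=4)  # mean loss 1.0 over another batch of 4
87 |     print(meter.val, meter.avg)  # 1.0 1.5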
--------------------------------------------------------------------------------