├── images
│   ├── FIg3.png
│   ├── Fig1.png
│   └── Fig4.png
├── LICENSE
├── README.md
├── densenet.py
└── train.py
/images/FIg3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasveit/densenet-pytorch/HEAD/images/FIg3.png
--------------------------------------------------------------------------------
/images/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasveit/densenet-pytorch/HEAD/images/Fig1.png
--------------------------------------------------------------------------------
/images/Fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreasveit/densenet-pytorch/HEAD/images/Fig4.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2017, Andreas Veit
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A PyTorch Implementation for Densely Connected Convolutional Networks (DenseNets)
2 |
3 | This repository contains a [PyTorch](http://pytorch.org/) implementation of the paper [Densely Connected Convolutional Networks](http://arxiv.org/abs/1608.06993). The code is based on the excellent [PyTorch example for training ResNet on Imagenet](https://github.com/pytorch/examples/tree/master/imagenet).
4 |
5 | The default setting for this repo is a DenseNet-BC (with bottleneck layers and channel reduction) with 100 layers, a growth rate of 12, and a batch size of 64.
6 |
7 | The [official Torch implementation](https://github.com/liuzhuang13/DenseNet) contains further links to implementations in other frameworks.
8 |
9 | Example usage with optional arguments for different hyperparameters (e.g., DenseNet-40-12):
10 | ```sh
11 | $ python train.py --layers 40 --growth 12 --no-bottleneck --reduce 1.0 --name DenseNet-40-12
12 | ```
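Note that `--layers` sets the total depth, not the number of layers per dense block. The pure-Python sketch below mirrors the channel bookkeeping in `DenseNet3.__init__` (densenet.py), so the block sizes for a given configuration can be checked without building a model. The helper name `densenet3_widths` is hypothetical; it reproduces only the arithmetic, not the PyTorch modules.

```python
import math


def densenet3_widths(depth, growth_rate=12, reduction=0.5, bottleneck=True):
    """Mirror the channel arithmetic of DenseNet3.__init__ (hypothetical helper).

    Returns (layers_per_block, widths), where widths lists the channel count
    after each dense block and after each transition, in forward order.
    """
    # the 4 accounts for the first conv, the two transitions and the classifier
    n = (depth - 4) / 3
    if bottleneck:
        n = n / 2  # each bottleneck layer contains two convolutions (1x1, 3x3)
    n = int(n)
    in_planes = 2 * growth_rate  # output channels of the first 3x3 conv
    widths = []
    for stage in range(3):
        in_planes = in_planes + n * growth_rate  # each layer adds growth_rate channels
        widths.append(in_planes)
        if stage < 2:  # transitions follow only the first two dense blocks
            in_planes = int(math.floor(in_planes * reduction))
            widths.append(in_planes)
    return n, widths


for name, kwargs in [("DenseNet-BC-100-12", dict(depth=100)),
                     ("DenseNet-40-12", dict(depth=40, bottleneck=False, reduction=1.0))]:
    print(name, densenet3_widths(**kwargs))
# DenseNet-BC-100-12 (16, [216, 108, 300, 150, 342])
# DenseNet-40-12 (12, [168, 168, 312, 312, 456])
```

The last width is the input size of the final `nn.Linear`. For example, the default DenseNet-BC-100-12 ends with 342 channels.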
13 |
14 | ## DenseNets
15 | [DenseNets [1]](https://arxiv.org/abs/1608.06993) were introduced in late 2016 following the discoveries by [[2]](https://arxiv.org/abs/1603.09382) and [[3]](https://arxiv.org/abs/1605.06431) that [residual networks [4]](https://arxiv.org/abs/1512.03385) exhibit extreme parameter redundancy. DenseNets address this shortcoming by reducing the size of the modules and by introducing more connections between layers. In fact, the output of each layer flows directly as input to all subsequent layers that operate at the same feature-map resolution (i.e., within the same dense block), as illustrated in their Figure 1 (below). This increases the dependency between the layers and thus reduces redundancy.
16 |
17 | ![Figure 1](images/Fig1.png)
18 |
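The connectivity pattern can be traced without PyTorch. In densenet.py, each layer returns `torch.cat([x, out], 1)`, so stacking layers in an `nn.Sequential` means layer i receives the block input plus the outputs of all earlier layers. The toy sketch below labels every channel with the layer that produced it to make this visible. `toy_dense_block` is a hypothetical helper, and Python lists stand in for feature-map channels.

```python
def toy_dense_block(in_channels, num_layers, growth_rate):
    """Return (channels, input_widths, visible_sources) for a toy dense block."""
    channels = ["input"] * in_channels  # block input, one label per channel
    input_widths, visible_sources = [], []
    for i in range(num_layers):
        input_widths.append(len(channels))  # in_planes + i * growth_rate
        # distinct producers whose channels reach layer i, in order of appearance
        visible_sources.append(list(dict.fromkeys(channels)))
        new_channels = [f"layer{i}"] * growth_rate  # this layer's growth_rate outputs
        channels = channels + new_channels  # same role as torch.cat([x, out], 1)
    return channels, input_widths, visible_sources


channels, widths, sources = toy_dense_block(in_channels=24, num_layers=4, growth_rate=12)
print(widths)         # [24, 36, 48, 60]
print(sources[3])     # ['input', 'layer0', 'layer1', 'layer2']
print(len(channels))  # 72
```

The input widths grow linearly by the growth rate. This matches the `in_planes+i*growth_rate` passed to each block in `DenseBlock._make_layer`.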
19 | The improvements in accuracy per parameter are illustrated in their results on ImageNet (Figure 3).
20 |
21 | ![Figure 3](images/FIg3.png)
22 |
23 | ## This implementation
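24 | The training code in train.py trains a DenseNet on CIFAR-10 (it loads `datasets.CIFAR10` and builds a 10-class model). To train on CIFAR-100, swap in `datasets.CIFAR100` and set the number of classes to 100. To train on ImageNet, densenet.py can be copied into the [PyTorch example for training ResNets on ImageNet](https://github.com/pytorch/examples/tree/master/imagenet), upon which this repo is based. Note that the ImageNet models in the paper use four dense blocks, whereas `DenseNet3` builds the three-block CIFAR architecture, so the model definition has to be adapted accordingly.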
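For reference, train.py optimizes with SGD (Nesterov momentum 0.9, weight decay 1e-4). Its `adjust_learning_rate` divides the initial learning rate of 0.1 by 10 after 150 epochs and again after 225 epochs. The stdlib-only sketch below reproduces that formula so the schedule can be inspected. `step_lr` is a hypothetical name.

```python
def step_lr(epoch, base_lr=0.1):
    """Same formula as adjust_learning_rate in train.py (base_lr plays the role of --lr)."""
    return base_lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))


# roughly 0.1 for epochs 0-149, 0.01 for 150-224 and 0.001 for 225-299
for epoch in (0, 149, 150, 224, 225, 299):
    print(epoch, step_lr(epoch))
```

Both factors are integer divisions of the epoch, so the formula is tied to the default `--epochs 300`. With more epochs, epochs 300 to 449 would run at 1e-4 (0.1 x 0.01 x 0.1).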
25 |
26 | This implementation is quite _memory efficient, requiring between 10% and 20% less memory_ than the original Torch implementation. We obtain a final test error of 4.76% with DenseNet-BC-100-12 (the paper reports 4.51%) and 5.35% with DenseNet-40-12 (the paper reports 5.24%).
27 |
28 | This implementation allows for __all model variants__ in the DenseNet paper, i.e., with and without bottleneck, channel reduction, data augmentation and dropout.
29 |
30 | For simple configuration of the model, this repo uses `argparse` so that key hyperparameters can be easily changed.
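The model variants map onto a handful of train.py flags:

* DenseNet-BC is the default.
* `--no-bottleneck` drops the bottleneck layers.
* `--reduce 1.0` disables channel compression.
* `--no-augment` turns off augmentation.
* `--droprate` enables dropout.

The sketch below copies those `add_argument` calls from train.py into a standalone parser to show how a command line becomes `DenseNet3` keyword arguments. `model_kwargs` is a hypothetical helper, not part of the repo. train.py itself passes the number of classes (10) as the second positional argument.

```python
import argparse

# Subset of the flags defined in train.py (same names, defaults and actions)
parser = argparse.ArgumentParser(description='DenseNet flag sketch')
parser.add_argument('--layers', default=100, type=int)
parser.add_argument('--growth', default=12, type=int)
parser.add_argument('--droprate', default=0, type=float)
parser.add_argument('--reduce', default=0.5, type=float)
parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false')
parser.add_argument('--no-augment', dest='augment', action='store_false')
parser.set_defaults(bottleneck=True, augment=True)


def model_kwargs(argv):
    """Parse argv and return (DenseNet3 keyword arguments, augment flag)."""
    args = parser.parse_args(argv)
    kwargs = dict(depth=args.layers, growth_rate=args.growth,
                  reduction=args.reduce, bottleneck=args.bottleneck,
                  dropRate=args.droprate)
    return kwargs, args.augment


# DenseNet-BC-100-12 (defaults) and DenseNet-40-12 (the README example)
print(model_kwargs([]))
print(model_kwargs(['--layers', '40', '--no-bottleneck', '--reduce', '1.0']))
```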
31 |
32 | Further, this implementation supports [easy checkpointing](https://github.com/andreasveit/densenet-pytorch/blob/master/train.py#L136), keeping track of the best model and [resuming](https://github.com/andreasveit/densenet-pytorch/blob/master/train.py#L103) training from previous checkpoints.
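In train.py, `save_checkpoint` overwrites `runs/<name>/checkpoint.pth.tar` every epoch and copies it to `model_best.pth.tar` whenever validation accuracy improves. `--resume` then restarts from the stored `epoch`. The sketch below replays that control flow, with `pickle` standing in for `torch.save`/`torch.load` so it runs without PyTorch. It writes to a temporary directory, and the per-epoch accuracies are hypothetical.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint(state, is_best, directory, filename='checkpoint.pth.tar'):
    """Write the latest state and copy it to model_best.pth.tar if it is the best so far."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    with open(path, 'wb') as f:
        pickle.dump(state, f)  # stand-in for torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, 'model_best.pth.tar'))


def load_checkpoint(path):
    with open(path, 'rb') as f:
        return pickle.load(f)  # stand-in for torch.load(path)


with tempfile.TemporaryDirectory() as run_dir:
    best_prec1 = 0
    for epoch, prec1 in enumerate([50.0, 60.0, 55.0]):  # hypothetical validation accuracies
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({'epoch': epoch + 1, 'best_prec1': best_prec1}, is_best, run_dir)
    latest = load_checkpoint(os.path.join(run_dir, 'checkpoint.pth.tar'))
    best = load_checkpoint(os.path.join(run_dir, 'model_best.pth.tar'))

print(latest)  # {'epoch': 3, 'best_prec1': 60.0}
print(best)    # {'epoch': 2, 'best_prec1': 60.0}
```

Resuming from `latest` sets `start_epoch = 3`, so `range(args.start_epoch, args.epochs)` continues with the fourth epoch. The checkpoint written by train.py stores the model weights but not the optimizer state, so the SGD momentum buffers start from zero after a resume.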
33 |
34 | ### Tracking training progress with TensorBoard
35 | To track training progress, this implementation uses [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard), which makes it easy to track and compare multiple experiments. Logging is enabled with the `--tensorboard` flag. To log PyTorch experiments to TensorBoard we use [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger), which can be installed with
36 | ```
37 | pip install tensorboard_logger
38 | ```
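train.py calls `configure(logdir)` once and `log_value(name, value, step)` once per epoch. One way to make those calls degrade gracefully when tensorboard_logger is not installed is a thin wrapper like the sketch below. The `MetricLogger` class is hypothetical. It keeps an in-memory history and forwards to tensorboard_logger only when the package is available and logging is enabled.

```python
try:
    from tensorboard_logger import configure, log_value  # optional dependency
except ImportError:
    configure = log_value = None


class MetricLogger:
    """Record (step, value) pairs per metric; forward to TensorBoard if possible."""

    def __init__(self, logdir=None, enabled=False):
        self.history = {}
        # only talk to tensorboard_logger when asked to and when it is installed
        self.enabled = enabled and log_value is not None
        if self.enabled:
            configure(logdir)

    def log(self, name, value, step):
        self.history.setdefault(name, []).append((step, value))
        if self.enabled:
            log_value(name, value, step)


logger = MetricLogger(logdir='runs/DenseNet_BC_100_12', enabled=False)
logger.log('val_acc', 90.0, 0)  # hypothetical values
logger.log('val_acc', 91.5, 1)
print(logger.history)  # {'val_acc': [(0, 90.0), (1, 91.5)]}
```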
39 | Example training curves (training loss and validation accuracy) for DenseNet-BC-100-12 (dark blue) and DenseNet-40-12 (light blue) are shown below.
40 |
41 | ![Training curves](images/Fig4.png)
42 |
43 | ### Dependencies
44 | * [PyTorch](http://pytorch.org/) with torchvision (used by train.py for the CIFAR datasets and transforms)
45 |
46 | optional:
47 | * [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
48 |
49 |
50 | ### Cite
51 | If you use DenseNets in your work, please cite the original paper as:
52 | ```
53 | @article{Huang2016Densely,
54 |   author = {Huang, Gao and Liu, Zhuang and Weinberger, Kilian Q.},
55 |   title = {Densely Connected Convolutional Networks},
56 |   journal = {arXiv preprint arXiv:1608.06993},
57 |   year = {2016}
58 | }
59 | ```
60 |
61 | If this implementation is useful to you and your project, please also consider citing or acknowledging this code repository.
62 |
63 | ### References
64 | [1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2016). Densely connected convolutional networks. arXiv preprint arXiv:1608.06993.
65 |
66 | [2] Huang, G., Sun, Y., Liu, Z., Sedra, D., & Weinberger, K. Q. (2016). Deep networks with stochastic depth. In European Conference on Computer Vision (ECCV '16).
67 |
68 | [3] Veit, A., Wilber, M. J., & Belongie, S. (2016). Residual networks behave like ensembles of relatively shallow networks. In Advances in Neural Information Processing Systems (NIPS '16).
69 |
70 | [4] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Conference on Computer Vision and Pattern Recognition (CVPR '16).
71 |
--------------------------------------------------------------------------------
/densenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
9 |         super(BasicBlock, self).__init__()
10 |         self.bn1 = nn.BatchNorm2d(in_planes)
11 |         self.relu = nn.ReLU(inplace=True)
12 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
13 |                                padding=1, bias=False)
14 |         self.droprate = dropRate
15 |     def forward(self, x):
16 |         out = self.conv1(self.relu(self.bn1(x)))
17 |         if self.droprate > 0:
18 |             out = F.dropout(out, p=self.droprate, training=self.training)
19 |         return torch.cat([x, out], 1)
20 |
21 | class BottleneckBlock(nn.Module):
22 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
23 |         super(BottleneckBlock, self).__init__()
24 |         inter_planes = out_planes * 4  # the 1x1 bottleneck conv produces 4 * growth_rate channels
25 |         self.bn1 = nn.BatchNorm2d(in_planes)
26 |         self.relu = nn.ReLU(inplace=True)
27 |         self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
28 |                                padding=0, bias=False)
29 |         self.bn2 = nn.BatchNorm2d(inter_planes)
30 |         self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
31 |                                padding=1, bias=False)
32 |         self.droprate = dropRate
33 |     def forward(self, x):
34 |         out = self.conv1(self.relu(self.bn1(x)))
35 |         if self.droprate > 0:
36 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
37 |         out = self.conv2(self.relu(self.bn2(out)))
38 |         if self.droprate > 0:
39 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
40 |         return torch.cat([x, out], 1)
41 |
42 | class TransitionBlock(nn.Module):
43 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
44 |         super(TransitionBlock, self).__init__()
45 |         self.bn1 = nn.BatchNorm2d(in_planes)
46 |         self.relu = nn.ReLU(inplace=True)
47 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
48 |                                padding=0, bias=False)
49 |         self.droprate = dropRate
50 |     def forward(self, x):
51 |         out = self.conv1(self.relu(self.bn1(x)))
52 |         if self.droprate > 0:
53 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
54 |         return F.avg_pool2d(out, 2)  # 2x2 average pooling halves the spatial resolution
55 |
56 | class DenseBlock(nn.Module):
57 |     def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0):
58 |         super(DenseBlock, self).__init__()
59 |         self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate)
60 |     def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate):
61 |         layers = []
62 |         for i in range(nb_layers):
63 |             layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate))  # layer i sees the input plus i earlier outputs
64 |         return nn.Sequential(*layers)
65 |     def forward(self, x):
66 |         return self.layer(x)
67 |
68 | class DenseNet3(nn.Module):
69 |     def __init__(self, depth, num_classes, growth_rate=12,
70 |                  reduction=0.5, bottleneck=True, dropRate=0.0):
71 |         super(DenseNet3, self).__init__()
72 |         in_planes = 2 * growth_rate
73 |         n = (depth - 4) / 3  # layers per dense block; 4 = first conv, two transitions, classifier
74 |         if bottleneck == True:
75 |             n = n/2  # each bottleneck layer contains two convolutions
76 |             block = BottleneckBlock
77 |         else:
78 |             block = BasicBlock
79 |         n = int(n)
80 |         # 1st conv before any dense block
81 |         self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1,
82 |                                padding=1, bias=False)
83 |         # 1st block
84 |         self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
85 |         in_planes = int(in_planes+n*growth_rate)
86 |         self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
87 |         in_planes = int(math.floor(in_planes*reduction))
88 |         # 2nd block
89 |         self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
90 |         in_planes = int(in_planes+n*growth_rate)
91 |         self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate)
92 |         in_planes = int(math.floor(in_planes*reduction))
93 |         # 3rd block
94 |         self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate)
95 |         in_planes = int(in_planes+n*growth_rate)
96 |         # global average pooling and classifier
97 |         self.bn1 = nn.BatchNorm2d(in_planes)
98 |         self.relu = nn.ReLU(inplace=True)
99 |         self.fc = nn.Linear(in_planes, num_classes)
100 |         self.in_planes = in_planes
101 |
102 |         for m in self.modules():
103 |             if isinstance(m, nn.Conv2d):
104 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
105 |                 m.weight.data.normal_(0, math.sqrt(2. / n))  # He initialization (fan-out)
106 |             elif isinstance(m, nn.BatchNorm2d):
107 |                 m.weight.data.fill_(1)
108 |                 m.bias.data.zero_()
109 |             elif isinstance(m, nn.Linear):
110 |                 m.bias.data.zero_()
111 |     def forward(self, x):
112 |         out = self.conv1(x)
113 |         out = self.trans1(self.block1(out))
114 |         out = self.trans2(self.block2(out))
115 |         out = self.block3(out)
116 |         out = self.relu(self.bn1(out))
117 |         out = F.avg_pool2d(out, 8)  # global pooling: 32x32 CIFAR inputs are 8x8 after two transitions
118 |         out = out.view(-1, self.in_planes)
119 |         return self.fc(out)
120 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 |
15 | import densenet as dn
16 |
17 | try:  # tensorboard_logger is optional; it is only needed with --tensorboard
18 |     from tensorboard_logger import configure, log_value
19 | except ImportError: configure = log_value = None
20 | parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
21 | parser.add_argument('--epochs', default=300, type=int,
22 |                     help='number of total epochs to run')
23 | parser.add_argument('--start-epoch', default=0, type=int,
24 |                     help='manual epoch number (useful on restarts)')
25 | parser.add_argument('-b', '--batch-size', default=64, type=int,
26 |                     help='mini-batch size (default: 64)')
27 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
28 |                     help='initial learning rate')
29 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
30 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
31 |                     help='weight decay (default: 1e-4)')
32 | parser.add_argument('--print-freq', '-p', default=10, type=int,
33 |                     help='print frequency (default: 10)')
34 | parser.add_argument('--layers', default=100, type=int,
35 |                     help='total number of layers (default: 100)')
36 | parser.add_argument('--growth', default=12, type=int,
37 |                     help='number of new channels per layer (default: 12)')
38 | parser.add_argument('--droprate', default=0, type=float,
39 |                     help='dropout probability (default: 0.0)')
40 | parser.add_argument('--no-augment', dest='augment', action='store_false',
41 |                     help='disable standard data augmentation (default: augmentation on)')
42 | parser.add_argument('--reduce', default=0.5, type=float,
43 |                     help='compression rate in transition stage (default: 0.5)')
44 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false',
45 |                     help='do not use bottleneck layers (default: bottleneck layers are used)')
46 | parser.add_argument('--resume', default='', type=str,
47 |                     help='path to latest checkpoint (default: none)')
48 | parser.add_argument('--name', default='DenseNet_BC_100_12', type=str,
49 |                     help='name of experiment')
50 | parser.add_argument('--tensorboard',
51 |                     help='Log progress to TensorBoard', action='store_true')
52 | parser.set_defaults(bottleneck=True)
53 | parser.set_defaults(augment=True)
54 |
55 | best_prec1 = 0
56 |
57 | def main():
58 |     global args, best_prec1
59 |     args = parser.parse_args()
60 |     if args.tensorboard: configure("runs/%s"%(args.name))
61 |
62 |     # Data loading code
63 |     normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
64 |                                      std=[x/255.0 for x in [63.0, 62.1, 66.7]])
65 |
66 |     if args.augment:
67 |         transform_train = transforms.Compose([
68 |             transforms.RandomCrop(32, padding=4),
69 |             transforms.RandomHorizontalFlip(),
70 |             transforms.ToTensor(),
71 |             normalize,
72 |             ])
73 |     else:
74 |         transform_train = transforms.Compose([
75 |             transforms.ToTensor(),
76 |             normalize,
77 |             ])
78 |     transform_test = transforms.Compose([
79 |         transforms.ToTensor(),
80 |         normalize
81 |         ])
82 |
83 |     kwargs = {'num_workers': 1, 'pin_memory': True}
84 |     train_loader = torch.utils.data.DataLoader(
85 |         datasets.CIFAR10('../data', train=True, download=True,
86 |                          transform=transform_train),
87 |         batch_size=args.batch_size, shuffle=True, **kwargs)
88 |     val_loader = torch.utils.data.DataLoader(
89 |         datasets.CIFAR10('../data', train=False, transform=transform_test),
90 |         batch_size=args.batch_size, shuffle=False, **kwargs)
91 |
92 |     # create model
93 |     model = dn.DenseNet3(args.layers, 10, args.growth, reduction=args.reduce,
94 |                          bottleneck=args.bottleneck, dropRate=args.droprate)
95 |
96 |     # get the number of model parameters
97 |     print('Number of model parameters: {}'.format(
98 |         sum([p.data.nelement() for p in model.parameters()])))
99 |
100 |     # for training on multiple GPUs.
101 |     # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
102 |     # model = torch.nn.DataParallel(model).cuda()
103 |     model = model.cuda()
104 |
105 |     # optionally resume from a checkpoint
106 |     if args.resume:
107 |         if os.path.isfile(args.resume):
108 |             print("=> loading checkpoint '{}'".format(args.resume))
109 |             checkpoint = torch.load(args.resume)
110 |             args.start_epoch = checkpoint['epoch']
111 |             best_prec1 = checkpoint['best_prec1']
112 |             model.load_state_dict(checkpoint['state_dict'])
113 |             print("=> loaded checkpoint '{}' (epoch {})"
114 |                   .format(args.resume, checkpoint['epoch']))
115 |         else:
116 |             print("=> no checkpoint found at '{}'".format(args.resume))
117 |
118 |     cudnn.benchmark = True
119 |
120 |     # define loss function (criterion) and optimizer
121 |     criterion = nn.CrossEntropyLoss().cuda()
122 |     optimizer = torch.optim.SGD(model.parameters(), args.lr,
123 |                                 momentum=args.momentum,
124 |                                 nesterov=True,
125 |                                 weight_decay=args.weight_decay)
126 |
127 |     for epoch in range(args.start_epoch, args.epochs):
128 |         adjust_learning_rate(optimizer, epoch)
129 |
130 |         # train for one epoch
131 |         train(train_loader, model, criterion, optimizer, epoch)
132 |
133 |         # evaluate on validation set
134 |         prec1 = validate(val_loader, model, criterion, epoch)
135 |
136 |         # remember best prec@1 and save checkpoint
137 |         is_best = prec1 > best_prec1
138 |         best_prec1 = max(prec1, best_prec1)
139 |         save_checkpoint({
140 |             'epoch': epoch + 1,
141 |             'state_dict': model.state_dict(),
142 |             'best_prec1': best_prec1,
143 |         }, is_best)
144 |     print('Best accuracy: ', best_prec1)
145 |
146 | def train(train_loader, model, criterion, optimizer, epoch):
147 |     """Train for one epoch on the training set"""
148 |     batch_time = AverageMeter()
149 |     losses = AverageMeter()
150 |     top1 = AverageMeter()
151 |
152 |     # switch to train mode
153 |     model.train()
154 |
155 |     end = time.time()
156 |     for i, (input, target) in enumerate(train_loader):
157 |         target = target.cuda(non_blocking=True)
158 |         input = input.cuda()
159 |         input_var = torch.autograd.Variable(input)
160 |         target_var = torch.autograd.Variable(target)
161 |
162 |         # compute output
163 |         output = model(input_var)
164 |         loss = criterion(output, target_var)
165 |
166 |         # measure accuracy and record loss
167 |         prec1 = accuracy(output.data, target, topk=(1,))[0]
168 |         losses.update(loss.item(), input.size(0))
169 |         top1.update(prec1.item(), input.size(0))
170 |
171 |         # compute gradient and do SGD step
172 |         optimizer.zero_grad()
173 |         loss.backward()
174 |         optimizer.step()
175 |
176 |         # measure elapsed time
177 |         batch_time.update(time.time() - end)
178 |         end = time.time()
179 |
180 |         if i % args.print_freq == 0:
181 |             print('Epoch: [{0}][{1}/{2}]\t'
182 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
183 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
184 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
185 |                       epoch, i, len(train_loader), batch_time=batch_time,
186 |                       loss=losses, top1=top1))
187 |     # log to TensorBoard
188 |     if args.tensorboard:
189 |         log_value('train_loss', losses.avg, epoch)
190 |         log_value('train_acc', top1.avg, epoch)
191 | @torch.no_grad()  # no autograd graph is needed during validation
192 | def validate(val_loader, model, criterion, epoch):
193 |     """Perform validation on the validation set"""
194 |     batch_time = AverageMeter()
195 |     losses = AverageMeter()
196 |     top1 = AverageMeter()
197 |
198 |     # switch to evaluate mode
199 |     model.eval()
200 |
201 |     end = time.time()
202 |     for i, (input, target) in enumerate(val_loader):
203 |         target = target.cuda(non_blocking=True)
204 |         input = input.cuda()
205 |         input_var = torch.autograd.Variable(input)
206 |         target_var = torch.autograd.Variable(target)
207 |
208 |         # compute output
209 |         output = model(input_var)
210 |         loss = criterion(output, target_var)
211 |
212 |         # measure accuracy and record loss
213 |         prec1 = accuracy(output.data, target, topk=(1,))[0]
214 |         losses.update(loss.item(), input.size(0))
215 |         top1.update(prec1.item(), input.size(0))
216 |
217 |         # measure elapsed time
218 |         batch_time.update(time.time() - end)
219 |         end = time.time()
220 |
221 |         if i % args.print_freq == 0:
222 |             print('Test: [{0}/{1}]\t'
223 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
224 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
225 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
226 |                       i, len(val_loader), batch_time=batch_time, loss=losses,
227 |                       top1=top1))
228 |
229 |     print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
230 |     # log to TensorBoard
231 |     if args.tensorboard:
232 |         log_value('val_loss', losses.avg, epoch)
233 |         log_value('val_acc', top1.avg, epoch)
234 |     return top1.avg
235 |
236 |
237 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
238 |     """Saves checkpoint to disk"""
239 |     directory = "runs/%s/"%(args.name)
240 |     if not os.path.exists(directory):
241 |         os.makedirs(directory)
242 |     filename = directory + filename
243 |     torch.save(state, filename)
244 |     if is_best:
245 |         shutil.copyfile(filename, 'runs/%s/'%(args.name) + 'model_best.pth.tar')
246 |
247 | class AverageMeter(object):
248 |     """Computes and stores the average and current value"""
249 |     def __init__(self):
250 |         self.reset()
251 |
252 |     def reset(self):
253 |         self.val = 0
254 |         self.avg = 0
255 |         self.sum = 0
256 |         self.count = 0
257 |
258 |     def update(self, val, n=1):
259 |         self.val = val
260 |         self.sum += val * n
261 |         self.count += n
262 |         self.avg = self.sum / self.count
263 |
264 |
265 | def adjust_learning_rate(optimizer, epoch):
266 |     """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs"""
267 |     lr = args.lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
268 |     # log to TensorBoard
269 |     if args.tensorboard:
270 |         log_value('learning_rate', lr, epoch)
271 |     for param_group in optimizer.param_groups:
272 |         param_group['lr'] = lr
273 |
274 | def accuracy(output, target, topk=(1,)):
275 |     """Computes the precision@k for the specified values of k"""
276 |     maxk = max(topk)
277 |     batch_size = target.size(0)
278 |
279 |     _, pred = output.topk(maxk, 1, True, True)
280 |     pred = pred.t()
281 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
282 |
283 |     res = []
284 |     for k in topk:
285 |         correct_k = correct[:k].reshape(-1).float().sum(0)
286 |         res.append(correct_k.mul_(100.0 / batch_size))
287 |     return res
288 |
289 | if __name__ == '__main__':
290 |     main()
291 |
--------------------------------------------------------------------------------
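A note on how train.py aggregates metrics: `AverageMeter.update(val, n)` weights each batch's value by its size `n`. As a result, `top1.avg` equals the accuracy over the whole epoch even when the last batch is smaller. For example, CIFAR-10's 10,000 test images split into 156 batches of 64 plus one batch of 16. The stdlib sketch below repeats the `AverageMeter` class from train.py and compares it with an unweighted mean of per-batch precisions. The batch counts are hypothetical.

```python
class AverageMeter(object):
    """Computes and stores the average and current value (as in train.py)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n  # weight the batch value by the batch size
        self.count += n
        self.avg = self.sum / self.count


# hypothetical (batch_size, num_correct) pairs; the last batch is smaller
batches = [(64, 32), (64, 48), (16, 16)]

top1 = AverageMeter()
for size, correct in batches:
    top1.update(100.0 * correct / size, size)  # per-batch precision@1 in percent

exact = 100.0 * sum(c for _, c in batches) / sum(s for s, _ in batches)
naive = sum(100.0 * c / s for s, c in batches) / len(batches)
print(round(top1.avg, 2), round(exact, 2), naive)  # 66.67 66.67 75.0
```

The unweighted mean overstates accuracy here because it gives the small, fully correct last batch the same weight as the full batches.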