├── LICENSE
├── README.md
├── train.py
└── wideresnet.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 xternalz

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Wide Residual Networks (WideResNets) in PyTorch
WideResNets for CIFAR10/100 implemented in PyTorch. This implementation uses less GPU memory than the official Torch implementation: https://github.com/szagoruyko/wide-residual-networks.

Example:
```
python train.py --dataset cifar100 --layers 40 --widen-factor 4
```
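If a run is interrupted, training can be resumed from the checkpoint that `train.py` writes under `runs/<experiment name>/`, for example:
```
python train.py --dataset cifar10 --layers 28 --widen-factor 10 --name WideResNet-28-10 --resume runs/WideResNet-28-10/checkpoint.pth.tar
```

The model class can also be used outside the training script. Below is a minimal sketch (not part of the repository) that builds a WRN-28-10 for CIFAR-10 and runs a forward pass on random input:
```python
import torch
from wideresnet import WideResNet

# depth=28 and widen_factor=10 mirror the training script's defaults
model = WideResNet(depth=28, num_classes=10, widen_factor=10, dropRate=0.0)

x = torch.randn(4, 3, 32, 32)  # a batch of four CIFAR-sized RGB images
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([4, 10])
```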
# Acknowledgement
- [densenet-pytorch](https://github.com/andreasveit/densenet-pytorch)
- Wide Residual Networks (BMVC 2016) by Sergey Zagoruyko and Nikos Komodakis: http://arxiv.org/abs/1605.07146

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from wideresnet import WideResNet

# used for logging to TensorBoard
from tensorboard_logger import configure, log_value

parser = argparse.ArgumentParser(description='PyTorch WideResNet Training')
parser.add_argument('--dataset', default='cifar10', type=str,
                    help='dataset (cifar10 [default] or cifar100)')
parser.add_argument('--epochs', default=200, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
# note: argparse's type=bool treats any non-empty string as True, so this flag
# effectively stays at its default unless an empty string is passed
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--layers', default=28, type=int,
                    help='total number of layers (default: 28)')
parser.add_argument('--widen-factor', default=10, type=int,
                    help='widen factor (default: 10)')
parser.add_argument('--droprate', default=0, type=float,
                    help='dropout probability (default: 0.0)')
parser.add_argument('--no-augment', dest='augment', action='store_false',
                    help='whether to use standard augmentation (default: True)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='WideResNet-28-10', type=str,
                    help='name of experiment')
parser.add_argument('--tensorboard',
                    help='Log progress to TensorBoard', action='store_true')
parser.set_defaults(augment=True)

best_prec1 = 0


def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    # Data loading code
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            # reflection-pad by 4 pixels on each side before the random crop
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
        ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert args.dataset in ('cifar10', 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=True, download=True,
                                                transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=False, transform=transform_test),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # create model
    model = WideResNet(args.layers, 10 if args.dataset == 'cifar10' else 100,
                       args.widen_factor, dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # cosine learning rate, stepped once per batch
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, len(train_loader) * args.epochs)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, scheduler, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
    print('Best accuracy: ', best_prec1)


def train(train_loader, model, criterion, optimizer, scheduler, epoch):
    """Train for one epoch on the training set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        target = target.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1))
    # log to TensorBoard
    if args.tensorboard:
        log_value('train_loss', losses.avg, epoch)
        log_value('train_acc', top1.avg, epoch)


def validate(val_loader, model, criterion, epoch):
    """Perform validation on the validation set"""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    # log to TensorBoard
    if args.tensorboard:
        log_value('val_loss', losses.avg, epoch)
        log_value('val_acc', top1.avg, epoch)
    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "runs/%s/" % (args.name)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'runs/%s/' % (args.name) + 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (rather than view) because the sliced tensor may not be contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/wideresnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 convolution on the shortcut when the channel count or resolution changes
        self.convShortcut = None
        if not self.equalInOut:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            # only the first block in a stage changes the channel count and stride
            layers.append(block(in_planes if i == 0 else out_planes, out_planes,
                                stride if i == 0 else 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # global average pooling (an 8x8 feature map for 32x32 inputs)
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)
--------------------------------------------------------------------------------