├── README.md ├── spectral_norm.py ├── eval.py ├── resnet.py └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 | Bounding Singular Values of Convolution Layers 2 | ===================================== 3 | 4 | Code for reproducing experiments in ["Bounding Singular Values of Convolution Layers"] 5 | 6 | ## Prerequisites 7 | 8 | - Python, NumPy, Pytorch, Argparse 9 | - A recent NVIDIA GPU 10 | 11 | ## Basic Usage 12 | 13 | To train the model with parameter value BETA, run python trainer.py --beta BETA. To access all the parameters use python trainer.py --help. 14 | -------------------------------------------------------------------------------- /spectral_norm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | 7 | import torch 8 | from torch.autograd import gradcheck 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import resnet 17 | 18 | def power_iteration(W, u=None, v=None, num_iters=50, return_vectors=False): 19 | if u is None: 20 | u = torch.randn((1, W.shape[1]), device='cuda') 21 | u_norm = torch.norm(u, dim=1) 22 | u_n = u/u_norm 23 | else: 24 | u_n = u 25 | v_n = v 26 | for i in range(num_iters): 27 | v = torch.matmul(u_n, W.t()) 28 | v_norm = torch.norm(v, dim=1) 29 | v_n = v/v_norm 30 | 31 | u = torch.matmul(v_n, W) 32 | u_norm = torch.norm(u, dim=1) 33 | u_n = u/u_norm 34 | sigma = (v_n.mm(W)).mm(u_n.t()) 35 | if return_vectors: 36 | return sigma[0, 0], u_n, v_n 37 | return sigma[0, 0] 38 | 39 | class ConvFilterNorm(nn.Module): 40 | def __init__(self, conv_filter): 41 | super(ConvFilterNorm, self).__init__() 42 | out_ch, in_ch, h, w = conv_filter.shape 43 | conv_filter_permute = conv_filter.permute(dims=(0, 2, 1, 3)) 44 | conv_filter_matrix = conv_filter_permute.contiguous().view(out_ch*h, -1) 45 | 46 | self.sigma, u, v = power_iteration(conv_filter_matrix, num_iters=50, return_vectors=True) 47 | self.u = u.detach() 48 | self.v = v.detach() 49 | 50 | def forward(self, conv_filter): 51 | out_ch, in_ch, h, w = conv_filter.shape 52 | conv_filter_permute = conv_filter.permute(dims=(0, 2, 1, 3)) 53 | conv_filter_matrix = conv_filter_permute.contiguous().view(out_ch*h, -1) 54 | 55 | _, u, v = power_iteration(conv_filter_matrix, self.u, self.v, num_iters=10, return_vectors=True) 56 | self.u = u.detach() 57 | self.v = v.detach() 58 | return math.sqrt(h*w)*MatrixNormFunction.apply(conv_filter_matrix, self.u, self.v) 59 | 60 | class MatrixNorm(nn.Module): 61 | def __init__(self, matrix): 62 | super(MatrixNorm, self).__init__() 63 | self.sigma, u, v = power_iteration(matrix, num_iters=50, return_vectors=True) 64 | self.u = u.detach() 65 | self.v = v.detach() 66 | 67 | def forward(self, matrix): 68 | _, u, v = power_iteration(matrix, self.u, self.v, num_iters=10, return_vectors=True) 69 | self.u = u.detach() 70 | self.v = v.detach() 71 | return MatrixNormFunction.apply(matrix, self.u, self.v) 72 | 73 | class MatrixNormFunction(torch.autograd.Function): 74 | @staticmethod 75 | def forward(ctx, matrix, u, v): 76 | sigma = (v.mm(matrix)).mm(u.t()) 77 | ctx.save_for_backward(matrix, u, v) 78 | return sigma 79 | 80 | @staticmethod 81 | def backward(ctx, grad_output): 82 | filter_matrix, u, v = ctx.saved_tensors 83 | grad_weight = grad_output.clone() 84 | grad_singular = ((v.t()).mm(u)) 85 | return grad_weight*grad_singular, None, None 86 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import resnet 17 | 18 | from modelsummary import summary 19 | 20 | model_names = sorted(name for name in resnet.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and name.startswith("resnet") 23 | and callable(resnet.__dict__[name])) 24 | 25 | parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch') 26 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32', 27 | choices=model_names, 28 | help='model architecture: ' + ' | '.join(model_names) + 29 | ' (default: resnet32)') 30 | parser.add_argument('-b', '--batch-size', default=100, type=int, 31 | metavar='N', help='mini-batch size (default: 100)') 32 | parser.add_argument('--load-dir', dest='load_dir', 33 | help='The directory used to load the trained models', 34 | default='save_resnet32', type=str) 35 | parser.add_argument('--checkpoint', action='store_true', 36 | help='use the last checkpoint model') 37 | 38 | def main(): 39 | global args, best_prec1 40 | args = parser.parse_args() 41 | 42 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) 43 | model.cuda() 44 | 45 | checkpoint = torch.load(os.path.join(args.load_dir, 'model.th')) 46 | if args.checkpoint: 47 | checkpoint = torch.load(os.path.join(args.load_dir, 'checkpoint.th')) 48 | best_prec1 = checkpoint['best_prec1'] 49 | model.load_state_dict(checkpoint['state_dict']) 50 | 51 | cudnn.benchmark = True 52 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | 55 | val_loader = torch.utils.data.DataLoader( 56 | datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([ 57 | transforms.ToTensor(), 58 | normalize, 59 | ])), 60 | batch_size=args.batch_size, shuffle=False, 61 | num_workers=4, pin_memory=True) 62 | 63 | # define loss function (criterion) and pptimizer 64 | criterion = nn.CrossEntropyLoss().cuda() 65 | 66 | validate(val_loader, model, criterion) 67 | return 68 | 69 | def validate(val_loader, model, criterion): 70 | """ 71 | Run evaluation 72 | """ 73 | batch_time = AverageMeter() 74 | losses = AverageMeter() 75 | top1 = AverageMeter() 76 | 77 | # switch to evaluate mode 78 | model.eval() 79 | 80 | end = time.time() 81 | with torch.no_grad(): 82 | for i, (input, target) in enumerate(val_loader): 83 | target = target.cuda() 84 | input_var = input.cuda() 85 | target_var = target.cuda() 86 | 87 | # compute output 88 | output = model(input_var) 89 | loss = criterion(output, target_var) 90 | 91 | output = output.float() 92 | loss = loss.float() 93 | 94 | # measure accuracy and record loss 95 | prec1 = accuracy(output.data, target)[0] 96 | losses.update(loss.item(), input.size(0)) 97 | top1.update(prec1.item(), input.size(0)) 98 | 99 | # measure elapsed time 100 | batch_time.update(time.time() - end) 101 | end = time.time() 102 | 103 | if i % 50 == 0: 104 | print('Test: [{0}/{1}]\t' 105 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 106 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 107 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 108 | i, len(val_loader), batch_time=batch_time, loss=losses, 109 | top1=top1)) 110 | 111 | print(' * Prec@1 {top1.avg:.3f}' 112 | .format(top1=top1)) 113 | 114 | return top1.avg 115 | 116 | class AverageMeter(object): 117 | """Computes and stores the average and current value""" 118 | def __init__(self): 119 | self.reset() 120 | 121 | def reset(self): 122 | self.val = 0 123 | self.avg = 0 124 | self.sum = 0 125 | self.count = 0 126 | 127 | def update(self, val, n=1): 128 | self.val = val 129 | self.sum += val * n 130 | self.count += n 131 | self.avg = self.sum / self.count 132 | 133 | 134 | def accuracy(output, target, topk=(1,)): 135 | """Computes the precision@k for the specified values of k""" 136 | maxk = max(topk) 137 | batch_size = target.size(0) 138 | 139 | _, pred = output.topk(maxk, 1, True, True) 140 | pred = pred.t() 141 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 142 | 143 | res = [] 144 | for k in topk: 145 | correct_k = correct[:k].view(-1).float().sum(0) 146 | res.append(correct_k.mul_(100.0 / batch_size)) 147 | return res 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file is hugely influenced by [2] 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most of the implementations on the web is copy-paste from 7 | torchvision's resnet and has wrong number of params. 8 | 9 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 10 | number of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4m 19 | 20 | which this implementation indeed has. 21 | 22 | Reference: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in you work, please don't forget to mention the 28 | author, Yerlan Idelbayev. 29 | ''' 30 | import math 31 | 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | import torch.nn.init as init 36 | 37 | from torch.autograd import Variable 38 | 39 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 40 | 41 | def _weights_init(m): 42 | classname = m.__class__.__name__ 43 | #print(classname) 44 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 45 | init.kaiming_normal_(m.weight) 46 | 47 | def power_iteration(W, u=None, v=None, num_iters=50, return_vectors=False): 48 | if u is None: 49 | u = torch.randn((1, W.shape[1]), device='cuda') 50 | u_norm = torch.norm(u, dim=1) 51 | u_n = u/u_norm 52 | for i in range(num_iters): 53 | v = torch.matmul(u_n, W.t()) 54 | v_norm = torch.norm(v, dim=1) 55 | v_n = v/v_norm 56 | 57 | u = torch.matmul(v_n, W) 58 | u_norm = torch.norm(u, dim=1) 59 | u_n = u/u_norm 60 | sigma = (v_n.mm(W)).mm(u_n.t()) 61 | if return_vectors: 62 | return sigma[0, 0], u_n, v_n 63 | return sigma 64 | 65 | class SpectralConv2d(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): 67 | super(SpectralConv2d, self).__init__() 68 | self.in_channels = in_channels 69 | self.out_channels = out_channels 70 | self.stride = (stride, stride) 71 | self.padding = (padding, padding) 72 | 73 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size)).to('cuda') 74 | if bias: 75 | self.bias = nn.Parameter(torch.Tensor(out_channels)).to('cuda') 76 | else: 77 | self.register_parameter('bias', None) 78 | 79 | init.kaiming_normal_(self.weight) 80 | if self.bias is not None: 81 | self.bias.data.uniform_(-0.1, 0.1) 82 | 83 | out_ch, in_ch, h, w = self.weight.shape 84 | filter_permute = self.weight.permute(dims=(0, 2, 1, 3)) 85 | filter_matrix = filter_permute.contiguous().view(out_ch*h, -1) 86 | 87 | self.sigma, self.u, self.v = power_iteration(filter_matrix, num_iters=50, return_vectors=True) 88 | self.sigma = math.sqrt(h*w)*self.sigma 89 | 90 | def forward(self, input): 91 | out_ch, in_ch, h, w = self.weight.shape 92 | filter_permute = self.weight.permute(dims=(0, 2, 1, 3)) 93 | filter_matrix = filter_permute.contiguous().view(out_ch*h, -1) 94 | self.sigma, self.u, self.v = power_iteration(filter_matrix, num_iters=1, return_vectors=True) 95 | self.sigma = math.sqrt(h*w)*self.sigma 96 | return F.conv2d(input, self.weight, self.bias, self.stride, self.padding) 97 | 98 | class LambdaLayer(nn.Module): 99 | def __init__(self, lambd): 100 | super(LambdaLayer, self).__init__() 101 | self.lambd = lambd 102 | 103 | def forward(self, x): 104 | return self.lambd(x) 105 | 106 | conv_module = nn.Conv2d 107 | # conv_module = SpectralConv2d 108 | 109 | class BasicBlock(nn.Module): 110 | expansion = 1 111 | 112 | def __init__(self, in_planes, planes, stride=1, option='A'): 113 | super(BasicBlock, self).__init__() 114 | self.conv1 = conv_module(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 115 | self.bn1 = nn.BatchNorm2d(planes) 116 | self.conv2 = conv_module(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 117 | self.bn2 = nn.BatchNorm2d(planes) 118 | 119 | self.shortcut = nn.Sequential() 120 | if stride != 1 or in_planes != planes: 121 | if option == 'A': 122 | """ 123 | For CIFAR10 ResNet paper uses option A. 124 | """ 125 | self.shortcut = LambdaLayer(lambda x: 126 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 127 | elif option == 'B': 128 | self.shortcut = nn.Sequential( 129 | conv_module(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 130 | nn.BatchNorm2d(self.expansion * planes) 131 | ) 132 | 133 | def forward(self, x): 134 | out = F.relu(self.bn1(self.conv1(x))) 135 | out = self.bn2(self.conv2(out)) 136 | out += self.shortcut(x) 137 | out = F.relu(out) 138 | return out 139 | 140 | 141 | class ResNet(nn.Module): 142 | def __init__(self, block, num_blocks, num_classes=10): 143 | super(ResNet, self).__init__() 144 | self.in_planes = 16 145 | 146 | self.conv1 = conv_module(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 147 | self.bn1 = nn.BatchNorm2d(16) 148 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 149 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 150 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 151 | self.linear = nn.Linear(64, num_classes) 152 | 153 | self.apply(_weights_init) 154 | 155 | def _make_layer(self, block, planes, num_blocks, stride): 156 | strides = [stride] + [1]*(num_blocks-1) 157 | layers = [] 158 | for stride in strides: 159 | layers.append(block(self.in_planes, planes, stride)) 160 | self.in_planes = planes * block.expansion 161 | 162 | return nn.Sequential(*layers) 163 | 164 | def forward(self, x): 165 | out = F.relu(self.bn1(self.conv1(x))) 166 | out = self.layer1(out) 167 | out = self.layer2(out) 168 | out = self.layer3(out) 169 | out = F.avg_pool2d(out, out.size()[3]) 170 | out = out.view(out.size(0), -1) 171 | out = self.linear(out) 172 | return out 173 | 174 | 175 | def resnet20(): 176 | return ResNet(BasicBlock, [3, 3, 3]) 177 | 178 | 179 | def resnet32(): 180 | return ResNet(BasicBlock, [5, 5, 5]) 181 | 182 | 183 | def resnet44(): 184 | return ResNet(BasicBlock, [7, 7, 7]) 185 | 186 | 187 | def resnet56(): 188 | return ResNet(BasicBlock, [9, 9, 9]) 189 | 190 | 191 | def resnet110(): 192 | return ResNet(BasicBlock, [18, 18, 18]) 193 | 194 | 195 | def resnet1202(): 196 | return ResNet(BasicBlock, [200, 200, 200]) 197 | 198 | 199 | def test(net): 200 | import numpy as np 201 | total_params = 0 202 | 203 | for x in filter(lambda p: p.requires_grad, net.parameters()): 204 | total_params += np.prod(x.data.numpy().shape) 205 | print("Total number of params", total_params) 206 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 207 | 208 | 209 | if __name__ == "__main__": 210 | for net_name in __all__: 211 | if net_name.startswith('resnet'): 212 | print(net_name) 213 | test(globals()[net_name]()) 214 | print() 215 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | 7 | import torch 8 | from torch.autograd import gradcheck 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import resnet 17 | 18 | from spectral_norm import ConvFilterNorm, MatrixNorm 19 | 20 | model_names = sorted(name for name in resnet.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and name.startswith("resnet") 23 | and callable(resnet.__dict__[name])) 24 | 25 | print(model_names) 26 | 27 | parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch') 28 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32', 29 | choices=model_names, 30 | help='model architecture: ' + ' | '.join(model_names) + 31 | ' (default: resnet32)') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=128, type=int, 39 | metavar='N', help='mini-batch size (default: 128)') 40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | metavar='LR', help='initial learning rate') 42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 45 | metavar='W', help='weight decay (default: 0)') 46 | parser.add_argument('--print-freq', '-p', default=50, type=int, 47 | metavar='N', help='print frequency (default: 20)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--save-dir', dest='save_dir', 55 | help='The directory used to save the trained models', 56 | default='save', type=str) 57 | parser.add_argument('--beta', dest='beta', 58 | help='argument used to regularize spectral norm', 59 | default=0.001, type=float) 60 | best_prec1 = 0 61 | 62 | def main(): 63 | global args, best_prec1 64 | args = parser.parse_args() 65 | print(args) 66 | 67 | if args.arch: 68 | args.save_dir = args.save_dir + '_' + args.arch + '_' + str(args.beta) 69 | 70 | # Check the save_dir exists or not 71 | if not os.path.exists(args.save_dir): 72 | os.makedirs(args.save_dir) 73 | 74 | model = torch.nn.DataParallel(resnet.__dict__[args.arch]()) 75 | model.cuda() 76 | 77 | # optionally resume from a checkpoint 78 | if args.resume: 79 | if os.path.isfile(args.resume): 80 | print("=> loading checkpoint '{}'".format(args.resume)) 81 | checkpoint = torch.load(args.resume) 82 | args.start_epoch = checkpoint['epoch'] 83 | best_prec1 = checkpoint['best_prec1'] 84 | model.load_state_dict(checkpoint['state_dict']) 85 | print("=> loaded checkpoint '{}' (epoch {})" 86 | .format(args.evaluate, checkpoint['epoch'])) 87 | else: 88 | print("=> no checkpoint found at '{}'".format(args.resume)) 89 | 90 | cudnn.benchmark = True 91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 92 | std=[0.229, 0.224, 0.225]) 93 | 94 | train_loader = torch.utils.data.DataLoader( 95 | datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([ 96 | transforms.RandomHorizontalFlip(), 97 | transforms.RandomCrop(32, 4), 98 | transforms.ToTensor(), 99 | normalize, 100 | ]), download=True), 101 | batch_size=args.batch_size, shuffle=True, 102 | num_workers=args.workers, pin_memory=True) 103 | 104 | val_loader = torch.utils.data.DataLoader( 105 | datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([ 106 | transforms.ToTensor(), 107 | normalize, 108 | ])), 109 | batch_size=100, shuffle=False, 110 | num_workers=args.workers, pin_memory=True) 111 | 112 | # define loss function (criterion) and pptimizer 113 | criterion = nn.CrossEntropyLoss().cuda() 114 | 115 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 116 | momentum=args.momentum, 117 | weight_decay=args.weight_decay) 118 | 119 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 120 | milestones=[100, 150], last_epoch=args.start_epoch - 1) 121 | 122 | if args.arch in ['resnet1202', 'resnet110']: 123 | # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up 124 | # then switch back. In this setup it will correspond for first epoch. 125 | for param_group in optimizer.param_groups: 126 | param_group['lr'] = args.lr*0.1 127 | 128 | 129 | if args.evaluate: 130 | validate(val_loader, model, criterion) 131 | return 132 | 133 | spectral_dict = spectralnorm_init(model) 134 | for epoch in range(args.start_epoch, args.epochs): 135 | # train for one epoch 136 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 137 | train(train_loader, model, criterion, optimizer, epoch, spectral_dict, args.beta) 138 | lr_scheduler.step() 139 | 140 | # evaluate on validation set 141 | prec1 = validate(val_loader, model, criterion) 142 | 143 | # remember best prec@1 and save checkpoint 144 | is_best = prec1 > best_prec1 145 | best_prec1 = max(prec1, best_prec1) 146 | 147 | save_checkpoint({ 148 | 'epoch': epoch + 1, 149 | 'state_dict': model.state_dict(), 150 | 'best_prec1': best_prec1, 151 | }, filename=os.path.join(args.save_dir, 'checkpoint.th')) 152 | 153 | if is_best: 154 | save_checkpoint({ 155 | 'state_dict': model.state_dict(), 156 | 'best_prec1': best_prec1, 157 | }, filename=os.path.join(args.save_dir, 'model.th')) 158 | 159 | def spectralnorm_init(model): 160 | spectral_dict = {} 161 | for name, param in model.named_parameters(): 162 | if len(param.shape) > 2: 163 | spectral_dict[name] = ConvFilterNorm(param) 164 | elif len(param.shape) == 2: 165 | spectral_dict[name] = MatrixNorm(param) 166 | return spectral_dict 167 | 168 | def spectralnorm_sum(model, spectral_dict): 169 | sigma_list = [] 170 | for name, param in model.named_parameters(): 171 | if len(param.shape) >= 2: 172 | spectral_norm = spectral_dict[name](param) 173 | sigma_list.append(torch.log(spectral_norm)) 174 | return torch.stack(sigma_list, dim=0).sum() 175 | 176 | def train(train_loader, model, criterion, optimizer, epoch, spectral_dict, beta): 177 | """ 178 | Run one train epoch 179 | """ 180 | 181 | batch_time = AverageMeter() 182 | data_time = AverageMeter() 183 | losses = AverageMeter() 184 | top1 = AverageMeter() 185 | 186 | # switch to train mode 187 | model.train() 188 | 189 | end = time.time() 190 | for i, (input, target) in enumerate(train_loader): 191 | 192 | # measure data loading time 193 | data_time.update(time.time() - end) 194 | 195 | target = target.cuda() 196 | input_var = input.cuda() 197 | target_var = target 198 | 199 | spectral_loss = spectralnorm_sum(model, spectral_dict) 200 | 201 | # compute output 202 | output = model(input_var) 203 | loss = criterion(output, target_var) + beta*spectral_loss 204 | 205 | # compute gradient and do SGD step 206 | optimizer.zero_grad() 207 | loss.backward() 208 | optimizer.step() 209 | 210 | output = output.float() 211 | loss = loss.float() 212 | 213 | # measure accuracy and record loss 214 | prec1 = accuracy(output.data, target)[0] 215 | losses.update(loss.item(), input.size(0)) 216 | top1.update(prec1.item(), input.size(0)) 217 | 218 | # measure elapsed time 219 | batch_time.update(time.time() - end) 220 | end = time.time() 221 | 222 | if i % args.print_freq == 0: 223 | print('Epoch: [{0}][{1}/{2}]\t' 224 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 225 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 226 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 227 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 228 | epoch, i, len(train_loader), batch_time=batch_time, 229 | data_time=data_time, loss=losses, top1=top1)) 230 | 231 | 232 | def validate(val_loader, model, criterion): 233 | """ 234 | Run evaluation 235 | """ 236 | batch_time = AverageMeter() 237 | losses = AverageMeter() 238 | top1 = AverageMeter() 239 | 240 | # switch to evaluate mode 241 | model.eval() 242 | 243 | end = time.time() 244 | with torch.no_grad(): 245 | for i, (input, target) in enumerate(val_loader): 246 | target = target.cuda() 247 | input_var = input.cuda() 248 | target_var = target.cuda() 249 | 250 | # compute output 251 | output = model(input_var) 252 | loss = criterion(output, target_var) 253 | 254 | output = output.float() 255 | loss = loss.float() 256 | 257 | # measure accuracy and record loss 258 | prec1 = accuracy(output.data, target)[0] 259 | losses.update(loss.item(), input.size(0)) 260 | top1.update(prec1.item(), input.size(0)) 261 | 262 | # measure elapsed time 263 | batch_time.update(time.time() - end) 264 | end = time.time() 265 | 266 | if i % args.print_freq == 0: 267 | print('Test: [{0}/{1}]\t' 268 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 269 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 270 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 271 | i, len(val_loader), batch_time=batch_time, loss=losses, 272 | top1=top1)) 273 | 274 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 275 | return top1.avg 276 | 277 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 278 | """ 279 | Save the training model 280 | """ 281 | torch.save(state, filename) 282 | 283 | class AverageMeter(object): 284 | """Computes and stores the average and current value""" 285 | def __init__(self): 286 | self.reset() 287 | 288 | def reset(self): 289 | self.val = 0 290 | self.avg = 0 291 | self.sum = 0 292 | self.count = 0 293 | 294 | def update(self, val, n=1): 295 | self.val = val 296 | self.sum += val * n 297 | self.count += n 298 | self.avg = self.sum / self.count 299 | 300 | 301 | def accuracy(output, target, topk=(1,)): 302 | """Computes the precision@k for the specified values of k""" 303 | maxk = max(topk) 304 | batch_size = target.size(0) 305 | 306 | _, pred = output.topk(maxk, 1, True, True) 307 | pred = pred.t() 308 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 309 | 310 | res = [] 311 | for k in topk: 312 | correct_k = correct[:k].view(-1).float().sum(0) 313 | res.append(correct_k.mul_(100.0 / batch_size)) 314 | return res 315 | 316 | 317 | if __name__ == '__main__': 318 | main() 319 | --------------------------------------------------------------------------------