# Repository layout (from the original dump):
#   models/
#     __init__.py
#     vanilla.py
#   run.sh.example
#   imgs/
#     lr0.01.png
#     lr0.003.png
#     paper-result.png
#   average_meter.py
#   utils.py
#   README.md
#   main.py
#
# models/__init__.py:
#   from .vanilla import Vanilla
#
# run.sh.example:
#   python main.py --repr --S1 20 --S2 10
#
# imgs/ (binary assets, referenced by URL):
#   https://raw.githubusercontent.com/siahuat0727/RePr/HEAD/imgs/lr0.01.png
#   https://raw.githubusercontent.com/siahuat0727/RePr/HEAD/imgs/lr0.003.png
#   https://raw.githubusercontent.com/siahuat0727/RePr/HEAD/imgs/paper-result.png

# average_meter.py:
class AverageMeter(object):
    """Tracks the most recent value and a running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Fold in `val`, observed `n` times, and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# models/vanilla.py follows.
-------------------------------------------------------------------------------- 1 | '''ConvNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Vanilla(nn.Module): 6 | def __init__(self, num_classes=10): 7 | super(Vanilla, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) 9 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 10 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 11 | self.fc1 = nn.Linear(32*32*32, num_classes) 12 | 13 | def forward(self, x): 14 | out = F.relu(self.conv1(x)) 15 | out = F.relu(self.conv2(out)) 16 | out = F.relu(self.conv3(out)) 17 | out = out.view(out.size(0), -1) 18 | out = self.fc1(out) 19 | return out 20 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import qr 3 | import torch 4 | 5 | def test_filter_sparsity(conv_weights): 6 | for name, W in conv_weights: 7 | zero = sum(w.nonzero().size(0) == 0 for w in W) 8 | print("filter sparsity of layer {} is {}".format(name, zero/W.size(0))) 9 | 10 | def qr_null(A, tol=None): 11 | Q, R, _ = qr(A.T, mode='full', pivoting=True) 12 | tol = np.finfo(R.dtype).eps if tol is None else tol 13 | rnk = min(A.shape) - np.abs(np.diag(R))[::-1].searchsorted(tol) 14 | return Q[:, rnk:].conj() 15 | 16 | 17 | def accuracy(output, target, topk=(1, )): 18 | """Computes the accuracy over the k top predictions for the specified values of k""" 19 | with torch.no_grad(): 20 | maxk = max(topk) 21 | batch_size = target.size(0) 22 | 23 | _, pred = output.topk(maxk, 1, True, True) 24 | pred = pred.t() 25 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 26 | 27 | res = [] 28 | for k in topk: 29 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 30 | res.append(correct_k.mul_(100.0 / batch_size)) 31 | return res 32 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RePr 2 | My implementation of RePr training scheme in PyTorch. https://arxiv.org/pdf/1811.07275.pdf 3 | 4 | ## Usage 5 | ``` 6 | $ python main.py --help 7 | usage: main.py [-h] [--lr LR] [--repr] [--S1 S1] [--S2 S2] [--epochs EPOCHS] 8 | [--workers WORKERS] [--print_freq PRINT_FREQ] [--gpu GPU] 9 | [--save_model SAVE_MODEL] [--prune_ratio PRUNE_RATIO] 10 | [--comment COMMENT] [--zero_init] 11 | 12 | PyTorch CIFAR10 Training 13 | 14 | optional arguments: 15 | -h, --help show this help message and exit 16 | --lr LR learning rate (default: 0.01) 17 | --repr whether to use RePr training scheme (default: False) 18 | --S1 S1 S1 epochs for RePr (default: 20) 19 | --S2 S2 S2 epochs for RePr (default: 10) 20 | --epochs EPOCHS total epochs for training (default: 100) 21 | --workers WORKERS number of worker to load data (default: 16) 22 | --print_freq PRINT_FREQ 23 | print frequency (default: 50) 24 | --gpu GPU gpu id (default: 0) 25 | --save_model SAVE_MODEL 26 | path to save model (default: best.pt) 27 | --prune_ratio PRUNE_RATIO 28 | prune ratio (default: 0.3) 29 | --comment COMMENT tag for tensorboardX event name (default: ) 30 | --zero_init whether to initialize with zero (default: False) 31 | ``` 32 | 33 | ## Execute example 34 | Standard training scheme 35 | ``` 36 | $ python main.py 37 | ``` 38 | 39 | RePr training scheme 40 | ``` 41 | $ python main.py --repr --S1 20 --S2 10 --epoch 110 42 | ``` 43 | 44 | ## Results 45 | 46 | ### Original paper 47 | 48 | Std | RePr 49 | ---- | ---- 50 | 72.1 | 76.4 51 | 52 | ![](https://github.com/siahuat0727/RePr/blob/master/imgs/paper-result.png) 53 | 54 | ### My implementation 55 | 56 | With data augmentation (`torchvision.transforms.RandomCrop`) 57 | 58 | learning rate = 0.01 59 | 60 | Std | RePr 61 | ---- | ---- 62 | 77.84| 74.48 63 | 64 | 
![](https://github.com/siahuat0727/RePr/blob/master/imgs/lr0.01.png) 65 | 66 | 67 | Without data augmentation 68 | 69 | learning rate = 0.003 70 | 71 | Std | RePr 72 | ---- | ---- 73 | 64.86| 69.05 74 | 75 | ![](https://github.com/siahuat0727/RePr/blob/master/imgs/lr0.003.png) 76 | 77 | For more information, please visit [my blog](https://siahuat0727.github.io/2019/03/17/repr/). (Last updated on 2019-04-28) 78 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | import math 4 | import argparse 5 | import time 6 | import datetime 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from matplotlib.colors import ListedColormap 16 | from models import Vanilla 17 | from average_meter import AverageMeter 18 | from utils import qr_null, test_filter_sparsity, accuracy 19 | from tensorboardX import SummaryWriter 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 24 | parser.add_argument('--repr', action='store_true', help="whether to use RePr training scheme") 25 | parser.add_argument('--S1', type=int, default=20, help="S1 epochs for RePr") 26 | parser.add_argument('--S2', type=int, default=10, help="S2 epochs for RePr") 27 | parser.add_argument('--epochs', type=int, default=100, help="total epochs for training") 28 | parser.add_argument('--workers', type=int, default=16, help="number of worker to load data") 29 | parser.add_argument('--print_freq', type=int, default=50, help="print frequency") 30 
| parser.add_argument('--gpu', type=int, default=0, help="gpu id") 31 | parser.add_argument('--save_model', type=str, default='best.pt', help="path to save model") 32 | parser.add_argument('--prune_ratio', type=float, default=0.3, help="prune ratio") 33 | parser.add_argument('--comment', type=str, default='', help="tag for tensorboardX event name") 34 | parser.add_argument('--zero_init', action='store_true', help="whether to initialize with zero") 35 | 36 | def train(train_loader, criterion, optimizer, epoch, model, writer, mask, args, conv_weights): 37 | batch_time = AverageMeter() 38 | data_time = AverageMeter() 39 | losses = AverageMeter() 40 | top1 = AverageMeter() 41 | 42 | # switch to train mode 43 | model.train() 44 | 45 | end = time.time() 46 | for i, (data, target) in enumerate(train_loader): 47 | # measure data loading time 48 | data_time.update(time.time() - end) 49 | 50 | if args.gpu is not None: # TODO None? 51 | data = data.cuda(args.gpu, non_blocking=True) 52 | target = target.cuda(args.gpu, non_blocking=True) 53 | 54 | output = model(data) 55 | 56 | loss = criterion(output, target) 57 | 58 | acc1, _ = accuracy(output, target, topk=(1, 5)) 59 | 60 | losses.update(loss.item(), data.size(0)) 61 | top1.update(acc1[0], data.size(0)) 62 | 63 | optimizer.zero_grad() 64 | 65 | loss.backward() 66 | 67 | S1, S2 = args.S1, args.S2 68 | if args.repr and any(s1 <= epoch < s1+S2 for s1 in range(S1, args.epochs, S1+S2)): 69 | if i == 0: 70 | print('freeze for this epoch') 71 | with torch.no_grad(): 72 | for name, W in conv_weights: 73 | W.grad[mask[name]] = 0 74 | 75 | optimizer.step() 76 | 77 | # measure elapsed time 78 | batch_time.update(time.time() - end) 79 | 80 | if i % args.print_freq == 0: 81 | print('Epoch: [{0}][{1}/{2}]\t' 82 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 83 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 84 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 85 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 86 | 'LR {lr:.3f}\t' 87 | 
.format( 88 | epoch, i, len(train_loader), batch_time=batch_time, 89 | data_time=data_time, loss=losses, top1=top1, 90 | lr=optimizer.param_groups[0]['lr'])) 91 | 92 | end = time.time() 93 | writer.add_scalar('Train/Acc', top1.avg, epoch) 94 | writer.add_scalar('Train/Loss', losses.avg, epoch) 95 | 96 | def validate(val_loader, criterion, model, writer, args, epoch, best_acc): 97 | batch_time = AverageMeter() 98 | losses = AverageMeter() 99 | top1 = AverageMeter() 100 | 101 | # switch to evaluate mode 102 | model.eval() 103 | 104 | with torch.no_grad(): 105 | end = time.time() 106 | for i, (data, target) in enumerate(val_loader): 107 | if args.gpu is not None: # TODO None? 108 | data = data.cuda(args.gpu, non_blocking=True) 109 | target = target.cuda(args.gpu, non_blocking=True) 110 | 111 | # compute output 112 | output = model(data) 113 | loss = criterion(output, target) 114 | 115 | # measure accuracy and record loss 116 | acc1, _ = accuracy(output, target, topk=(1, 5)) 117 | losses.update(loss.item(), data.size(0)) 118 | top1.update(acc1[0], data.size(0)) 119 | 120 | # measure elapsed time 121 | batch_time.update(time.time() - end) 122 | 123 | if i % args.print_freq == 0: 124 | print('Test: [{0}/{1}]\t' 125 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 126 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 127 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 128 | .format( 129 | i, len(val_loader), batch_time=batch_time, loss=losses, 130 | top1=top1)) 131 | end = time.time() 132 | 133 | print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1)) 134 | writer.add_scalar('Test/Acc', top1.avg, epoch) 135 | writer.add_scalar('Test/Loss', losses.avg, epoch) 136 | 137 | if top1.avg.item() > best_acc: 138 | print('new best_acc is {top1.avg:.3f}'.format(top1=top1)) 139 | print('saving model {}'.format(args.save_model)) 140 | torch.save(model.state_dict(), args.save_model) 141 | return top1.avg.item() 142 | 143 | def pruning(conv_weights, prune_ratio): 144 | print('Pruning...') 145 | 
# calculate inter-filter orthogonality 146 | inter_filter_ortho = {} 147 | for name, W in conv_weights: 148 | size = W.size() 149 | W2d = W.view(size[0], -1) 150 | W2d = F.normalize(W2d, p=2, dim=1) 151 | W_WT = torch.mm(W2d, W2d.transpose(0, 1)) 152 | I = torch.eye(W_WT.size()[0], dtype=torch.float32).cuda() 153 | P = torch.abs(W_WT - I) 154 | P = P.sum(dim=1) / size[0] 155 | inter_filter_ortho[name] = P.cpu().detach().numpy() 156 | # the ranking is computed overall the filters in the network 157 | ranks = np.concatenate([v.flatten() for v in inter_filter_ortho.values()]) 158 | threshold = np.percentile(ranks, 100*(1-prune_ratio)) 159 | 160 | prune = {} 161 | mask = {} 162 | drop_filters = {} 163 | for name, W in conv_weights: 164 | prune[name] = inter_filter_ortho[name] > threshold # e.g. [True, False, True, True, False] 165 | # get indice of bad filters 166 | mask[name] = np.where(prune[name])[0] # e.g. [0, 2, 3] 167 | drop_filters[name] = None 168 | if mask[name].size > 0: 169 | with torch.no_grad(): 170 | drop_filters[name] = W.data[mask[name]].view(mask[name].size, -1).cpu().numpy() 171 | W.data[mask[name]] = 0 172 | 173 | test_filter_sparsity(conv_weights) 174 | return prune, mask, drop_filters 175 | 176 | def reinitialize(mask, drop_filters, conv_weights, fc_weights, zero_init): 177 | print('Reinitializing...') 178 | with torch.no_grad(): 179 | prev_layer_name = None 180 | prev_num_filters = None 181 | for name, W in conv_weights + fc_weights: 182 | if W.dim() == 4 and drop_filters[name] is not None: # conv weights 183 | # find null space 184 | size = W.size() 185 | stdv = 1. 
/ math.sqrt(size[1]*size[2]*size[3]) # https://github.com/pytorch/pytorch/blob/08891b0a4e08e2c642deac2042a02238a4d34c67/torch/nn/modules/conv.py#L40-L47 186 | W2d = W.view(size[0], -1).cpu().numpy() 187 | null_space = qr_null(np.vstack((drop_filters[name], W2d))) 188 | null_space = torch.from_numpy(null_space).cuda() 189 | 190 | if null_space.size == 0: 191 | W.data[mask[name]].uniform_(-stdv, stdv) 192 | else: 193 | null_space = null_space.transpose(0, 1).view(-1, size[1], size[2], size[3]) 194 | null_count = 0 195 | for mask_idx in mask[name]: 196 | if null_count < null_space.size(0): 197 | W.data[mask_idx] = null_space.data[null_count].clamp_(-stdv, stdv) 198 | null_count += 1 199 | else: 200 | W.data[mask_idx].uniform_(-stdv, stdv) 201 | 202 | # mask channels of prev-layer-pruned-filters' outputs 203 | if prev_layer_name is not None: 204 | if W.dim() == 4: # conv 205 | if zero_init: 206 | W.data[:, mask[prev_layer_name]] = 0 207 | else: 208 | W.data[:, mask[prev_layer_name]].uniform_(-stdv, stdv) 209 | elif W.dim() == 2: # fc 210 | if zero_init: 211 | W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]] = 0 212 | else: 213 | stdv = 1. 
/ math.sqrt(W.size(1)) 214 | W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]].uniform_(-stdv, stdv) 215 | prev_layer_name, prev_num_filters = name, W.size(0) 216 | test_filter_sparsity(conv_weights) 217 | 218 | def main(): 219 | if not torch.cuda.is_available(): 220 | raise Exception("Only support GPU training") 221 | cudnn.benchmark = True 222 | 223 | args = parser.parse_args() 224 | 225 | # Data 226 | print('==> Preparing data..') 227 | 228 | transform_train = transforms.Compose([ 229 | transforms.RandomCrop(32, padding=4), 230 | transforms.RandomHorizontalFlip(), 231 | transforms.ToTensor(), 232 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 233 | ]) 234 | 235 | transform_test = transforms.Compose([ 236 | transforms.ToTensor(), 237 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 238 | ]) 239 | 240 | trainset = torchvision.datasets.CIFAR10( 241 | root='./data', train=True, download=True, transform=transform_train) 242 | trainloader = torch.utils.data.DataLoader( 243 | trainset, batch_size=128, shuffle=True, num_workers=args.workers) 244 | 245 | testset = torchvision.datasets.CIFAR10( 246 | root='./data', train=False, download=True, transform=transform_test) 247 | testloader = torch.utils.data.DataLoader( 248 | testset, batch_size=100, shuffle=False, num_workers=args.workers) 249 | 250 | # Model 251 | print('==> Building model..') 252 | 253 | model = Vanilla() 254 | print(model) 255 | 256 | if args.gpu is not None: 257 | torch.cuda.set_device(args.gpu) 258 | model.cuda() 259 | else: 260 | model.cuda() 261 | model = torch.nn.DataParallel(model) 262 | 263 | conv_weights = [] 264 | fc_weights = [] 265 | for name, W in model.named_parameters(): 266 | if W.dim() == 4: 267 | conv_weights.append((name, W)) 268 | elif W.dim() == 2: 269 | fc_weights.append((name, W)) 270 | 271 | criterion = nn.CrossEntropyLoss().cuda() 272 | optimizer = torch.optim.SGD(model.parameters(), args.lr) 273 | comment 
= "-{}-{}-{}".format("repr" if args.repr else "norepr", args.epochs, args.comment) 274 | writer = SummaryWriter(comment=comment) 275 | 276 | mask = None 277 | drop_filters = None 278 | best_acc = 0 # best test accuracy 279 | prune_map = [] 280 | for epoch in range(args.epochs): 281 | if args.repr: 282 | # check if the end of S1 stage 283 | if any(epoch == s for s in range(args.S1, args.epochs, args.S1+args.S2)): 284 | prune, mask, drop_filters = pruning(conv_weights, args.prune_ratio) 285 | prune_map.append(np.concatenate(list(prune.values()))) 286 | # check if the end of S2 stage 287 | if any(epoch == s for s in range(args.S1+args.S2, args.epochs, args.S1+args.S2)): 288 | reinitialize(mask, drop_filters, conv_weights, fc_weights, args.zero_init) 289 | train(trainloader, criterion, optimizer, epoch, model, writer, mask, args, conv_weights) 290 | acc = validate(testloader, criterion, model, writer, args, epoch, best_acc) 291 | best_acc = max(best_acc, acc) 292 | test_filter_sparsity(conv_weights) 293 | 294 | writer.close() 295 | print('overall best_acc is {}'.format(best_acc)) 296 | 297 | # Shows which filters turn off as training progresses 298 | if args.repr: 299 | prune_map = np.array(prune_map).transpose() 300 | print(prune_map) 301 | plt.matshow(prune_map.astype(np.int), cmap=ListedColormap(['k', 'w'])) 302 | plt.xticks(np.arange(prune_map.shape[1])) 303 | plt.yticks(np.arange(prune_map.shape[0])) 304 | plt.title('Filters on/off map\nwhite: off (pruned)\nblack: on') 305 | plt.xlabel('Pruning stage') 306 | plt.ylabel('Filter index from shallower layer to deeper layer') 307 | plt.savefig('{}-{}.png'.format( 308 | datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H:%M:%S'), 309 | comment)) 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | --------------------------------------------------------------------------------