├── utils ├── __init__.py ├── preprocess.py ├── common.py └── options.py ├── data ├── __init__.py └── cifar10.py ├── model ├── __init__.py ├── discriminator.py └── resnet.py ├── run.sh ├── fista.py ├── README.md ├── finetune.py └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import Data as cifar10 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminator import Discriminator 2 | from .resnet import resnet_56, resnet_56_sparse -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Training 2 | # CUDA_VISIBLE_DEVICES=0 python main.py --teacher_dir [pre-trained model dir] 3 | 4 | # Fine-tuning 5 | # CUDA_VISIBLE_DEVICES=0 python finetune.py --refine [pruend model dir] -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, num_classes=10): 8 | super(Discriminator, self).__init__() 9 | self.filters = [num_classes, 128, 256, 128] 10 | block = [ 11 | nn.Linear(self.filters[i], self.filters[i+1]) \ 12 | for i in range(3) 13 | ] 14 | self.body = nn.Sequential(*block) 15 | 16 | self._initialize_weights() 17 | 18 | def forward(self, input): 19 | x = self.body(input) 20 | return x 21 | 22 | def _initialize_weights(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | elif isinstance(m, nn.BatchNorm2d): 29 | nn.init.constant_(m.weight, 1) 30 | nn.init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.Linear): 32 | nn.init.normal_(m.weight, 0, 0.01) 33 | nn.init.constant_(m.bias, 0) -------------------------------------------------------------------------------- /data/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision.transforms as transforms 4 | 5 | class Data: 6 | def __init__(self, args): 7 | # pin_memory = False 8 | # if args.gpu is not None: 9 | pin_memory = True 10 | 11 | transform_train = transforms.Compose([ 12 | transforms.RandomCrop(32, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | 18 | transform_test = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 21 | ]) 22 | 23 | trainset = CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train) 24 | 25 | self.loader_train = DataLoader( 26 | trainset, batch_size=args.train_batch_size, shuffle=True, 27 | num_workers=2, pin_memory=pin_memory 28 | ) 29 | 30 | testset = 
CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 31 | self.loader_test = DataLoader( 32 | testset, batch_size=args.eval_batch_size, shuffle=False, 33 | num_workers=2, pin_memory=pin_memory) -------------------------------------------------------------------------------- /fista.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | class FISTA(Optimizer): 5 | def __init__(self, params, lr=1e-2, gamma=0.1): 6 | defaults = dict(lr=lr, gamma=gamma) 7 | super(FISTA, self).__init__(params, defaults) 8 | 9 | def step(self, decay=1, closure=None): 10 | loss = None 11 | 12 | if closure is not None: 13 | loss = closure() 14 | 15 | for group in self.param_groups: 16 | for p in group['params']: 17 | if p.grad is None: 18 | continue 19 | 20 | grad = p.grad.data 21 | state = self.state[p] 22 | 23 | if 'alpha' not in state or decay: 24 | state['alpha'] = torch.ones_like(p.data) 25 | state['data'] = p.data 26 | y = p.data 27 | else: 28 | alpha = state['alpha'] 29 | data = state['data'] 30 | state['alpha'] = (1 + (1 + 4 * alpha**2).sqrt()) / 2 31 | y = p.data + ((alpha - 1) / state['alpha']) * (p.data - data) 32 | state['data'] = p.data 33 | 34 | mom = y - group['lr'] * grad 35 | p.data = self._prox(mom, group['lr'] * group['gamma']) 36 | 37 | # no-negative 38 | p.data = torch.max(p.data, torch.zeros_like(p.data)) 39 | 40 | return loss 41 | 42 | def _prox(self, x, gamma): 43 | y = torch.max(torch.abs(x) - gamma, torch.zeros_like(x)) 44 | 45 | return torch.sign(x) * y 46 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import re 5 | import numpy as np 6 | from collections import OrderedDict 7 | from model import resnet_56, resnet_56_sparse 8 | import torch 9 | 10 | def prune_resnet(args, state_dict): 11 | thre = args.thre 12 | num_layers = int(args.student_model.split('_')[1]) 13 | n = (num_layers - 2) // 6 14 | layers = np.arange(0, 3*n ,n) 15 | 16 | mask_block = [] 17 | for name, weight in state_dict.items(): 18 | if 'mask' in name: 19 | mask_block.append(weight.item()) 20 | 21 | pruned_num = sum(m <= thre for m in mask_block) 22 | pruned_blocks = [int(m) for m in np.argwhere(np.array(mask_block) <= thre)] 23 | 24 | old_block = 0 25 | layer = 'layer1' 26 | layer_num = int(layer[-1]) 27 | new_block = 0 28 | new_state_dict = OrderedDict() 29 | 30 | for key, value in state_dict.items(): 31 | if 'layer' in key: 32 | if key.split('.')[0] != layer: 33 | layer = key.split('.')[0] 34 | layer_num = int(layer[-1]) 35 | new_block = 0 36 | 37 | if key.split('.')[1] != old_block: 38 | old_block = key.split('.')[1] 39 | 40 | if mask_block[layers[layer_num-1] + int(old_block)] == 0: 41 | if layer_num != 1 and old_block == '0' and 'mask' in key: 42 | new_block = 1 43 | continue 44 | 45 | new_key = re.sub(r'\.\d+\.', '.{}.'.format(new_block), key, 1) 46 | if 'mask' in new_key: new_block += 1 47 | 48 | new_state_dict[new_key] = state_dict[key] 49 | 50 | else: 51 | new_state_dict[key] = state_dict[key] 52 | 53 | model = resnet_56_sparse(has_mask=mask_block).to(args.gpus[0]) 54 | 55 | print('\n---- After Prune ----\n') 56 | print(f"Pruned / Total: {pruned_num} / {len(mask_block)}") 57 | print("Pruned blocks", pruned_blocks) 58 | 59 | save_dir = f'{args.job_dir}/pruned.pt' 60 | print(f'Saving pruned model to 
{save_dir}...') 61 | 62 | save_state_dict = {} 63 | save_state_dict['state_dict_s'] = new_state_dict 64 | save_state_dict['mask'] = mask_block 65 | torch.save(save_state_dict, save_dir) 66 | 67 | if not args.random: 68 | model.load_state_dict(new_state_dict) 69 | 70 | return model 71 | 72 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import datetime 3 | import shutil 4 | from pathlib import Path 5 | import pdb 6 | import os 7 | 8 | import torch 9 | import functools 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0.0 18 | self.avg = 0.0 19 | self.sum = 0.0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | class checkpoint(): 29 | def __init__(self, args): 30 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 31 | today = datetime.date.today() 32 | 33 | self.args = args 34 | self.job_dir = Path(args.job_dir) 35 | self.ckpt_dir = self.job_dir / 'checkpoint' 36 | self.run_dir = self.job_dir / 'run' 37 | 38 | if args.reset: 39 | os.system('rm -rf ' + args.job_dir) 40 | 41 | def _make_dir(path): 42 | if not os.path.exists(path): os.makedirs(path) 43 | 44 | _make_dir(self.job_dir) 45 | _make_dir(self.ckpt_dir) 46 | _make_dir(self.run_dir) 47 | 48 | config_dir = self.job_dir / 'config.txt' 49 | with open(config_dir, 'w') as f: 50 | f.write(now + '\n\n') 51 | for arg in vars(args): 52 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 53 | f.write('\n') 54 | 55 | def save_model(self, state, epoch, is_best): 56 | save_path = f'{self.ckpt_dir}/model_{epoch}.pt' 57 | print('=> Saving model to {}'.format(save_path)) 58 | torch.save(state, save_path) 59 | if is_best: 60 | shutil.copyfile(save_path, f'{self.ckpt_dir}/model_best.pt') 61 | 62 | def accuracy(output, target, topk=(1,)): 63 | """Computes the precision@k for the specified values of k""" 64 | with torch.no_grad(): 65 | maxk = max(topk) 66 | batch_size = target.size(0) 67 | 68 | _, pred = output.topk(maxk, 1, True, True) 69 | pred = pred.t() 70 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 71 | 72 | res = [] 73 | for k in topk: 74 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 75 | res.append(correct_k.mul_(100.0 / batch_size)) 76 | return res 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Optimal Structured CNN Pruning via Generative Adversarial Learning(GAL) 2 | 3 | PyTorch implementation for GAL. 4 | 5 | 6 | 7 | ![GAL-framework](https://user-images.githubusercontent.com/47294246/54805147-021eb500-4cb1-11e9-85ac-861ecbada3e1.png) 8 | 9 | An illustration of GAL. Blue solid block, branch and channel elements are active, while red dotted elements are inactive and can be pruned since their corresponding scaling factors in the soft mask are 0. 10 | 11 | 12 | 13 | ## Abstract 14 | 15 | Structured pruning of filters or neurons has received increased focus for compressing convolutional neural networks. 
Most existing methods rely on multi-stage optimization in a layer-wise manner for iterative pruning and retraining, which may not be optimal and may be computationally intensive. Besides, these methods are designed to prune a specific structure, such as filters or blocks, without jointly pruning heterogeneous structures. In this paper, we propose an effective structured pruning approach that jointly prunes filters as well as other structures in an end-to-end manner. To accomplish this, we first introduce a soft mask to scale the output of these structures and define a new objective function with sparsity regularization that aligns the output of the baseline and the masked network. We then effectively solve the optimization problem by generative adversarial learning (GAL), which learns a sparse soft mask in a label-free and end-to-end manner. By forcing more scaling factors in the soft mask to zero, the fast iterative shrinkage-thresholding algorithm (FISTA) can be leveraged to quickly and reliably remove the corresponding structures. Extensive experiments demonstrate the effectiveness of GAL on different datasets, including MNIST, CIFAR-10 and ImageNet ILSVRC 2012. For example, on ImageNet ILSVRC 2012, the pruned ResNet-50 achieves 10.88% Top-5 error with a 3.7x speedup, significantly outperforming state-of-the-art methods. 16 | 17 | 18 | 19 | ## Citation 20 | If you find GAL useful in your research, please consider citing: 21 | 22 | ``` 23 | @inproceedings{lin2019towards, 24 | title = {Towards Optimal Structured CNN Pruning via Generative Adversarial Learning}, 25 | author = {Lin, Shaohui and Ji, Rongrong and Yan, Chenqian and Zhang, Baochang and Cao, Liujuan and Ye, Qixiang and Huang, Feiyue and Doermann, David}, 26 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 27 | year = {2019} 28 | } 29 | ``` 30 | 31 | 32 | 33 | ## Running Code 34 | 35 | This code runs our ResNet-56 model on the CIFAR-10 dataset. It has been tested with Python 3.6, [PyTorch 0.4.1](https://pytorch.org/) and CUDA 9.0 on Ubuntu 16.04. 36 | 37 | 38 | 39 | ### Run examples 40 | 41 | The training and fine-tuning commands are provided in `run.sh`; uncomment the appropriate line in `run.sh` to run the corresponding stage. 42 | 43 | ```shell 44 | sh run.sh 45 | ``` 46 | 47 | 48 | 49 | **For training**, set `--teacher_dir` to the location of the pre-trained teacher model. 50 | 51 | ```shell 52 | # run.sh 53 | python main.py --teacher_dir [pre-trained model dir] 54 | ``` 55 | 56 | The pruned model will be saved as `pruned.pt`. 57 | 58 | 59 | 60 | **For fine-tuning**, set `--refine` to the location of the pruned model to be fine-tuned. 61 | 62 | ```shell 63 | # run.sh 64 | python finetune.py --refine [pruned model dir] 65 | ``` 66 | 67 | You can set `--pruned` to reuse `pruned.pt`. If you want to initialize the weights randomly, just set `--random`. 68 | 69 | 70 | 71 | We also provide our [baseline](https://drive.google.com/open?id=1XHNxyFklGjvzNpTjzlkjpKc61-LLjt5T) model. Enjoy your training and testing! 72 | 73 | 74 | 75 | ## Tips 76 | 77 | If you find any problems, please feel free to contact the authors (shaohuilin007@gmail.com or Im.cqyan@gmail.com).
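## Soft-mask thresholding at a glance

The snippet below is a minimal, self-contained sketch (not part of the training pipeline) of how the FISTA proximal step implemented in `fista.py` drives soft-mask scaling factors to exactly zero, which is what allows `utils/preprocess.py` to drop the corresponding blocks. In `main.py` the same optimizer is built with `gamma=args.sparse_lambda` and stepped every `args.mask_step` iterations; the mask values and hyper-parameters below are made up purely for illustration, and the snippet assumes it is run from the repository root.

```python
import torch
import torch.nn as nn

from fista import FISTA

# Three hypothetical block-mask scaling factors.
mask = nn.Parameter(torch.tensor([0.80, 0.004, 0.25]))
optimizer = FISTA([mask], lr=1e-2, gamma=0.6)

# Pretend the data/adversarial losses contributed zero gradient this step,
# so only the proximal (soft-thresholding) part of the update is visible.
mask.grad = torch.zeros_like(mask)
optimizer.step(decay=1)

# _prox computes sign(x) * max(|x| - lr * gamma, 0), and the result is then
# clamped to be non-negative, so any factor smaller than lr * gamma (= 0.006)
# becomes exactly 0 while the others shrink by 0.006.
print(mask.data)  # ~ tensor([0.7940, 0.0000, 0.2440])
```

Blocks whose scaling factor reaches zero are the ones reported as pruned in `main.py`'s `test()` and removed by `prune_resnet` before fine-tuning.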
-------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pdb 4 | parser = argparse.ArgumentParser(description='Generatvie Adversarial Learning') 5 | 6 | ## Warm-up 7 | parser.add_argument( 8 | '--gpus', 9 | type=int, 10 | nargs='+', 11 | default=[0], 12 | help='Select gpu to use') 13 | parser.add_argument( 14 | '--dataset', 15 | type=str, 16 | default='cifar10', 17 | help='Dataset to train') 18 | parser.add_argument( 19 | '--data_dir', 20 | type=str, 21 | default='data/', 22 | help='The directory where the input data is stored.') 23 | parser.add_argument( 24 | '--job_dir', 25 | type=str, 26 | default='experiments/', 27 | help='The directory where the summaries will be stored.') 28 | parser.add_argument( 29 | '--teacher_dir', 30 | type=str, 31 | default='pretrained/', 32 | help='The directory where the teacher model saved.') 33 | parser.add_argument( 34 | '--reset', 35 | action='store_true', 36 | help='Reset the directory?') 37 | parser.add_argument( 38 | '--resume', 39 | type=str, 40 | default=None, 41 | help='Load the model from the specified checkpoint.') 42 | parser.add_argument( 43 | '--refine', 44 | type=str, 45 | default=None, 46 | help='Path to the model to be fine tuned.') 47 | 48 | ## Training 49 | parser.add_argument( 50 | '--student_model', 51 | type=str, 52 | default='resnet_56_sparse', 53 | help='The model of student.') 54 | parser.add_argument( 55 | '--teacher_model', 56 | type=str, 57 | default='resnet_56', 58 | help='The model of teacher.') 59 | parser.add_argument( 60 | '--num_epochs', 61 | type=int, 62 | default=100, 63 | help='The num of epochs to train.') 64 | parser.add_argument( 65 | '--train_batch_size', 66 | type=int, 67 | default=128, 68 | help='Batch size for training.') 69 | parser.add_argument( 70 | '--eval_batch_size', 71 | type=int, 72 | default=100, 73 | help='Batch size for validation.') 74 | parser.add_argument( 75 | '--momentum', 76 | type=float, 77 | default=0.9, 78 | help='Momentum for MomentumOptimizer.') 79 | parser.add_argument( 80 | '--lr', 81 | type=float, 82 | default=1e-2 83 | ) 84 | parser.add_argument( 85 | '--lr_decay_step', 86 | type=int, 87 | default=30 88 | ) 89 | parser.add_argument( 90 | '--mask_step', 91 | type=int, 92 | default=200, 93 | help='The frequency of mask to update' 94 | ) 95 | parser.add_argument( 96 | '--weight_decay', 97 | type=float, 98 | default=2e-4, 99 | help='The weight decay of loss.') 100 | parser.add_argument( 101 | '--miu', 102 | type=float, 103 | default=1, 104 | help='The miu of data loss.') 105 | parser.add_argument( 106 | '--lambda', 107 | dest='sparse_lambda', 108 | type=float, 109 | default=0.6, 110 | help='The sparse lambda for l1 loss') 111 | parser.add_argument( 112 | '--random', 113 | action='store_true', 114 | help='Random weight initialize for finetune') 115 | parser.add_argument( 116 | '--pruned', 117 | action='store_true', 118 | help='Load pruned model') 119 | parser.add_argument( 120 | '--thre', 121 | type=float, 122 | default=0.0, 123 | help='Thred of mask to be pruned') 124 | parser.add_argument( 125 | '--keep_grad', 126 | action='store_true', 127 | help='Keep gradients of mask for finetune') 128 | 129 | ## Status 130 | parser.add_argument( 131 | '--print_freq', 132 | type=int, 133 | default=200, 134 | help='The frequency to print loss.') 135 | parser.add_argument( 136 | '--test_only', 137 | action='store_true', 138 | help='Test 
only?') 139 | 140 | args = parser.parse_args() 141 | 142 | if args.resume is not None and not os.path.isfile(args.resume): 143 | raise ValueError('No checkpoint found at {} to resume'.format(args.resume)) 144 | 145 | if args.refine is not None and not os.path.isfile(args.refine): 146 | raise ValueError('No checkpoint found at {} to refine'.format(args.refine)) 147 | 148 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import utils.common as utils 5 | from importlib import import_module 6 | from tensorboardX import SummaryWriter 7 | from collections import OrderedDict 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import collections 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.optim.lr_scheduler import StepLR 16 | from utils.options import args 17 | from utils.preprocess import prune_resnet 18 | from model import resnet_56_sparse 19 | 20 | def main(): 21 | start_epoch = 0 22 | best_prec1, best_prec5 = 0.0, 0.0 23 | 24 | ckpt = utils.checkpoint(args) 25 | writer_train = SummaryWriter(args.job_dir + '/run/train') 26 | writer_test = SummaryWriter(args.job_dir + '/run/test') 27 | 28 | # Data loading 29 | print('=> Preparing data..') 30 | loader = import_module('data.' + args.dataset).Data(args) 31 | 32 | # Create model 33 | print('=> Building model...') 34 | criterion = nn.CrossEntropyLoss() 35 | 36 | # Fine tune from a checkpoint 37 | refine = args.refine 38 | assert refine is not None, 'refine is required' 39 | checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}")) 40 | 41 | if args.pruned: 42 | mask = checkpoint['mask'] 43 | pruned = sum([1 for m in mask if m == 0])  # count blocks whose mask is exactly zero 44 | print(f"Pruned / Total: {pruned} / {len(mask)}") 45 | model = resnet_56_sparse(has_mask=mask).to(args.gpus[0]) 46 | model.load_state_dict(checkpoint['state_dict_s']) 47 | else: 48 | model = prune_resnet(args, checkpoint['state_dict_s']) 49 | 50 | test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test) 51 | print(f"Test after pruning: {test_prec1:.3f}") 52 | 53 | if args.test_only: 54 | return 55 | 56 | if args.keep_grad: 57 | for name, weight in model.named_parameters(): 58 | if 'mask' in name: 59 | weight.requires_grad = False 60 | 61 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 62 | scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1) 63 | 64 | resume = args.resume 65 | if resume: 66 | print('=> Loading checkpoint {}'.format(resume)) 67 | checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}")) 68 | start_epoch = checkpoint['epoch'] 69 | best_prec1 = checkpoint['best_prec1'] 70 | model.load_state_dict(checkpoint['state_dict']) 71 | optimizer.load_state_dict(checkpoint['optimizer']) 72 | scheduler.load_state_dict(checkpoint['scheduler']) 73 | print('=> Continue from epoch {}...'.format(start_epoch)) 74 | 75 | for epoch in range(start_epoch, args.num_epochs): 76 | scheduler.step(epoch) 77 | 78 | train(args, loader.loader_train, model, criterion, optimizer, writer_train, epoch) 79 | test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, epoch) 80 | 81 | is_best = best_prec1 < test_prec1 82 | best_prec1 = max(test_prec1, best_prec1) 83 | best_prec5 = max(test_prec5, best_prec5)
84 | 85 | state = { 86 | 'state_dict_s': model.state_dict(), 87 | 'best_prec1': best_prec1, 88 | 'best_prec5': best_prec5, 89 | 'optimizer': optimizer.state_dict(), 90 | 'scheduler': scheduler.state_dict(), 91 | 'epoch': epoch + 1 92 | } 93 | 94 | ckpt.save_model(state, epoch + 1, is_best) 95 | 96 | print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}") 97 | 98 | def train(args, loader_train, model, criterion, optimizer, writer_train, epoch): 99 | losses = utils.AverageMeter() 100 | top1 = utils.AverageMeter() 101 | top5 = utils.AverageMeter() 102 | 103 | model.train() 104 | num_iterations = len(loader_train) 105 | 106 | for i, (inputs, targets) in enumerate(loader_train, 1): 107 | 108 | inputs = inputs.to(args.gpus[0]) 109 | targets = targets.to(args.gpus[0]) 110 | 111 | logits = model(inputs) 112 | loss = criterion(logits, targets) 113 | 114 | prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5)) 115 | losses.update(loss.item(), inputs.size(0)) 116 | 117 | top1.update(prec1[0], inputs.size(0)) 118 | top5.update(prec5[0], inputs.size(0)) 119 | 120 | optimizer.zero_grad() 121 | loss.backward() 122 | optimizer.step() 123 | 124 | 125 | def test(args, loader_test, model, criterion, writer_test, epoch=0): 126 | losses = utils.AverageMeter() 127 | top1 = utils.AverageMeter() 128 | top5 = utils.AverageMeter() 129 | 130 | model.eval() 131 | num_iterations = len(loader_test) 132 | 133 | print("=> Evaluating...") 134 | with torch.no_grad(): 135 | for i, (inputs, targets) in enumerate(loader_test, 1): 136 | 137 | inputs = inputs.to(args.gpus[0]) 138 | targets = targets.to(args.gpus[0]) 139 | 140 | logits = model(inputs) 141 | loss = criterion(logits, targets) 142 | 143 | prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5)) 144 | losses.update(loss.item(), inputs.size(0)) 145 | top1.update(prec1[0], inputs.size(0)) 146 | top5.update(prec5[0], inputs.size(0)) 147 | 148 | print(f'* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}') 149 | 150 | if not args.test_only: 151 | writer_test.add_scalar('test_top1', top1.avg, epoch) 152 | 153 | return top1.avg, top5.avg 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | from torch.distributions.normal import Normal 6 | 7 | norm_mean, norm_var = 0.0, 1.0 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | class Mask(nn.Module): 14 | def __init__(self, init_value=[1], planes=None): 15 | super().__init__() 16 | self.planes = planes 17 | self.weight = Parameter(torch.Tensor(init_value)) 18 | 19 | def forward(self, input): 20 | weight = self.weight 21 | 22 | if self.planes is not None: 23 | weight = self.weight[None, :, None, None] 24 | 25 | return input * weight 26 | 27 | class LambdaLayer(nn.Module): 28 | def __init__(self, lambd): 29 | super(LambdaLayer, self).__init__() 30 | self.lambd = lambd 31 | 32 | def forward(self, x): 33 | return self.lambd(x) 34 | 35 | class ResBasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1): 39 | super(ResBasicBlock, self).__init__() 40 | self.inplanes = inplanes 41 | self.planes = planes 42 | self.conv1 = conv3x3(inplanes, planes, 
stride) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = nn.BatchNorm2d(planes) 47 | self.stride = stride 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or inplanes != planes: 50 | self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 51 | 52 | def forward(self, x): 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | out += self.shortcut(x) 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | class SparseResBasicBlock(nn.Module): 66 | expansion = 1 67 | 68 | def __init__(self, inplanes, planes, stride=1): 69 | super(SparseResBasicBlock, self).__init__() 70 | 71 | self.inplanes = inplanes 72 | self.planes = planes 73 | self.conv1 = conv3x3(inplanes, planes, stride) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.conv2 = conv3x3(planes, planes) 77 | self.bn2 = nn.BatchNorm2d(planes) 78 | self.stride = stride 79 | self.shortcut = nn.Sequential() 80 | if stride != 1 or inplanes != planes: 81 | self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 82 | 83 | m = Normal(torch.tensor([norm_mean]), torch.tensor([norm_var])).sample() 84 | self.mask = Mask(m) 85 | 86 | def forward(self, x): 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | 94 | out = self.mask(out) 95 | 96 | out += self.shortcut(x) 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | class ResNet(nn.Module): 102 | def __init__(self, block, num_layers, num_classes=10, has_mask=None): 103 | super(ResNet, self).__init__() 104 | assert (num_layers - 2) % 6 == 0, 'depth should be 6n+2' 105 | n = (num_layers - 2) // 6 106 | 107 | if has_mask is None : has_mask = [1]*3*n 108 | 109 | self.inplanes = 16 110 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(self.inplanes) 112 | self.relu = nn.ReLU(inplace=True) 113 | 114 | self.layer1 = self._make_layer(block, 16, blocks=n, stride=1, has_mask=has_mask[0:n]) 115 | self.layer2 = self._make_layer(block, 32, blocks=n, stride=2, has_mask=has_mask[n:2*n]) 116 | self.layer3 = self._make_layer(block, 64, blocks=n, stride=2, has_mask=has_mask[2*n:3*n]) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(64 * block.expansion, num_classes) 119 | 120 | self.initialize() 121 | 122 | def initialize(self): 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | nn.init.kaiming_normal_(m.weight) 126 | elif isinstance(m, nn.BatchNorm2d): 127 | nn.init.constant_(m.weight, 1) 128 | nn.init.constant_(m.bias, 0) 129 | 130 | def _make_layer(self, block, planes, blocks, stride, has_mask): 131 | layers = [] 132 | if has_mask[0] == 0 and (stride != 1 or self.inplanes != planes): 133 | layers.append(LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))) 134 | if not has_mask[0] == 0: 135 | layers.append(block(self.inplanes, planes, stride)) 136 | 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | if not has_mask[i] == 0: 140 | layers.append(block(self.inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | 
x = self.relu(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.avgpool(x) 152 | x = x.view(x.size(0), -1) 153 | x = self.fc(x) 154 | 155 | return x 156 | 157 | def resnet_56(**kwargs): 158 | return ResNet(ResBasicBlock, 56, **kwargs) 159 | 160 | def resnet_56_sparse(**kwargs): 161 | return ResNet(SparseResBasicBlock, 56, **kwargs) 162 | 163 | def resnet_110(**kwargs): 164 | return ResNet(ResBasicBlock, 110, **kwargs) 165 | 166 | def resnet_110_sparse(**kwargs): 167 | return ResNet(SparseResBasicBlock, 110, **kwargs) 168 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utils.common as utils 3 | from utils.options import args 4 | from utils.preprocess import prune_resnet 5 | from tensorboardX import SummaryWriter 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.optim.lr_scheduler import StepLR 12 | 13 | from fista import FISTA 14 | from model import Discriminator, resnet_56, resnet_56_sparse 15 | from data import cifar10 16 | 17 | def main(): 18 | checkpoint = utils.checkpoint(args) 19 | writer_train = SummaryWriter(args.job_dir + '/run/train') 20 | writer_test = SummaryWriter(args.job_dir + '/run/test') 21 | 22 | start_epoch = 0 23 | best_prec1 = 0.0 24 | best_prec5 = 0.0 25 | 26 | # Data loading 27 | print('=> Preparing data..') 28 | loader = cifar10(args) 29 | 30 | # Create model 31 | print('=> Building model...') 32 | model_t = resnet_56().to(args.gpus[0]) 33 | 34 | # Load teacher model 35 | ckpt_t = torch.load(args.teacher_dir, map_location=torch.device(f"cuda:{args.gpus[0]}")) 36 | state_dict_t = ckpt_t['state_dict'] 37 | model_t.load_state_dict(state_dict_t) 38 | model_t = model_t.to(args.gpus[0]) 39 | 40 | for para in list(model_t.parameters())[:-2]: 41 | para.requires_grad = False 42 | 43 | model_s = resnet_56_sparse().to(args.gpus[0]) 44 | 45 | model_dict_s = model_s.state_dict() 46 | model_dict_s.update(state_dict_t) 47 | model_s.load_state_dict(model_dict_s) 48 | 49 | if len(args.gpus) != 1: 50 | model_s = nn.DataParallel(model_s, device_ids=args.gpus) 51 | 52 | model_d = Discriminator().to(args.gpus[0]) 53 | 54 | models = [model_t, model_s, model_d] 55 | 56 | optimizer_d = optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 57 | 58 | param_s = [param for name, param in model_s.named_parameters() if 'mask' not in name] 59 | param_m = [param for name, param in model_s.named_parameters() if 'mask' in name] 60 | 61 | optimizer_s = optim.SGD(param_s, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 62 | optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda) 63 | 64 | scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1) 65 | scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1) 66 | scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1) 67 | 68 | resume = args.resume 69 | if resume: 70 | print('=> Resuming from ckpt {}'.format(resume)) 71 | ckpt = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}")) 72 | best_prec1 = ckpt['best_prec1'] 73 | start_epoch = ckpt['epoch'] 74 | model_s.load_state_dict(ckpt['state_dict_s']) 75 | model_d.load_state_dict(ckpt['state_dict_d']) 76 | optimizer_d.load_state_dict(ckpt['optimizer_d']) 77 | 
optimizer_s.load_state_dict(ckpt['optimizer_s']) 78 | optimizer_m.load_state_dict(ckpt['optimizer_m']) 79 | scheduler_d.load_state_dict(ckpt['scheduler_d']) 80 | scheduler_s.load_state_dict(ckpt['scheduler_s']) 81 | scheduler_m.load_state_dict(ckpt['scheduler_m']) 82 | print('=> Continue from epoch {}...'.format(start_epoch)) 83 | 84 | optimizers = [optimizer_d, optimizer_s, optimizer_m] 85 | schedulers = [scheduler_d, scheduler_s, scheduler_m] 86 | 87 | if args.test_only: 88 | test_prec1, test_prec5 = test(args, loader.loader_test, model_s) 89 | print('=> Test Prec@1: {:.2f}'.format(test_prec1)) 90 | return 91 | 92 | for epoch in range(start_epoch, args.num_epochs): 93 | for s in schedulers: 94 | s.step(epoch) 95 | 96 | train(args, loader.loader_train, models, optimizers, epoch, writer_train) 97 | test_prec1, test_prec5 = test(args, loader.loader_test, model_s) 98 | 99 | is_best = best_prec1 < test_prec1 100 | best_prec1 = max(test_prec1, best_prec1) 101 | best_prec5 = max(test_prec5, best_prec5) 102 | 103 | model_state_dict = model_s.module.state_dict() if len(args.gpus) > 1 else model_s.state_dict() 104 | 105 | state = { 106 | 'state_dict_s': model_state_dict, 107 | 'state_dict_d': model_d.state_dict(), 108 | 'best_prec1': best_prec1, 109 | 'best_prec5': best_prec5, 110 | 'optimizer_d': optimizer_d.state_dict(), 111 | 'optimizer_s': optimizer_s.state_dict(), 112 | 'optimizer_m': optimizer_m.state_dict(), 113 | 'scheduler_d': scheduler_d.state_dict(), 114 | 'scheduler_s': scheduler_s.state_dict(), 115 | 'scheduler_m': scheduler_m.state_dict(), 116 | 'epoch': epoch + 1 117 | } 118 | checkpoint.save_model(state, epoch + 1, is_best) 119 | 120 | print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}") 121 | 122 | best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt', map_location=torch.device(f"cuda:{args.gpus[0]}")) 123 | 124 | model = prune_resnet(args, best_model['state_dict_s']) 125 | 126 | 127 | def train(args, loader_train, models, optimizers, epoch, writer_train): 128 | losses_d = utils.AverageMeter() 129 | losses_data = utils.AverageMeter() 130 | losses_g = utils.AverageMeter() 131 | losses_sparse = utils.AverageMeter() 132 | top1 = utils.AverageMeter() 133 | top5 = utils.AverageMeter() 134 | 135 | model_t = models[0] 136 | model_s = models[1] 137 | model_d = models[2] 138 | 139 | bce_logits = nn.BCEWithLogitsLoss() 140 | 141 | optimizer_d = optimizers[0] 142 | optimizer_s = optimizers[1] 143 | optimizer_m = optimizers[2] 144 | 145 | # switch to train mode 146 | model_d.train() 147 | model_s.train() 148 | 149 | num_iterations = len(loader_train) 150 | 151 | real_label = 1 152 | fake_label = 0 153 | 154 | for i, (inputs, targets) in enumerate(loader_train, 1): 155 | num_iters = num_iterations * epoch + i 156 | 157 | inputs = inputs.to(args.gpus[0]) 158 | targets = targets.to(args.gpus[0]) 159 | 160 | features_t = model_t(inputs) 161 | features_s = model_s(inputs) 162 | 163 | ############################ 164 | # (1) Update D network 165 | ########################### 166 | 167 | for p in model_d.parameters(): 168 | p.requires_grad = True 169 | 170 | optimizer_d.zero_grad() 171 | 172 | output_t = model_d(features_t.detach()) 173 | 174 | labels_real = torch.full_like(output_t, real_label, device=args.gpus[0]) 175 | error_real = bce_logits(output_t, labels_real) 176 | 177 | output_s = model_d(features_s.to(args.gpus[0]).detach()) 178 | 179 | labels_fake = torch.full_like(output_t, fake_label, device=args.gpus[0]) 180 | error_fake = bce_logits(output_s, labels_fake) 
181 | 182 | error_d = error_real + error_fake 183 | 184 | labels = torch.full_like(output_s, real_label, device=args.gpus[0]) 185 | error_d += bce_logits(output_s, labels) 186 | 187 | error_d.backward() 188 | losses_d.update(error_d.item(), inputs.size(0)) 189 | writer_train.add_scalar( 190 | 'discriminator_loss', error_d.item(), num_iters) 191 | 192 | optimizer_d.step() 193 | 194 | if i % args.print_freq == 0: 195 | print( 196 | '=> D_Epoch[{0}]({1}/{2}):\t' 197 | 'Loss_d {loss_d.val:.4f} ({loss_d.avg:.4f})\t'.format( 198 | epoch, i, num_iterations, loss_d=losses_d)) 199 | 200 | ############################ 201 | # (2) Update student network 202 | ########################### 203 | 204 | for p in model_d.parameters(): 205 | p.requires_grad = False 206 | 207 | optimizer_s.zero_grad() 208 | optimizer_m.zero_grad() 209 | 210 | error_data = args.miu * F.mse_loss(features_t, features_s.to(args.gpus[0])) 211 | 212 | losses_data.update(error_data.item(), inputs.size(0)) 213 | writer_train.add_scalar( 214 | 'data_loss', error_data.item(), num_iters) 215 | error_data.backward(retain_graph=True) 216 | 217 | # fool discriminator 218 | output_s = model_d(features_s.to(args.gpus[0])) 219 | 220 | labels = torch.full_like(output_s, real_label, device=args.gpus[0]) 221 | error_g = bce_logits(output_s, labels) 222 | losses_g.update(error_g.item(), inputs.size(0)) 223 | writer_train.add_scalar( 224 | 'generator_loss', error_g.item(), num_iters) 225 | error_g.backward(retain_graph=True) 226 | 227 | # train mask 228 | mask = [] 229 | for name, param in model_s.named_parameters(): 230 | if 'mask' in name: 231 | mask.append(param.view(-1)) 232 | mask = torch.cat(mask) 233 | error_sparse = args.sparse_lambda * F.l1_loss(mask, torch.zeros(mask.size()).to(args.gpus[0]), reduction='sum') 234 | error_sparse.backward() 235 | 236 | losses_sparse.update(error_sparse.item(), inputs.size(0)) 237 | writer_train.add_scalar( 238 | 'sparse_loss', error_sparse.item(), num_iters) 239 | 240 | optimizer_s.step() 241 | 242 | decay = (epoch % args.lr_decay_step == 0 and i == 1) 243 | if i % args.mask_step == 0: 244 | optimizer_m.step(decay) 245 | 246 | prec1, prec5 = utils.accuracy(features_s.to(args.gpus[0]), targets.to(args.gpus[0]), topk=(1, 5)) 247 | top1.update(prec1[0], inputs.size(0)) 248 | top5.update(prec5[0], inputs.size(0)) 249 | 250 | if i % args.print_freq == 0: 251 | print( 252 | '=> G_Epoch[{0}]({1}/{2}):\t' 253 | 'Loss_sparse {loss_sparse.val:.4f} ({loss_sparse.avg:.4f})\t' 254 | 'Loss_data {loss_data.val:.4f} ({loss_data.avg:.4f})\t' 255 | 'Loss_d {loss_d.val:.4f} ({loss_d.avg:.4f})\t' 256 | 'Loss_g {loss_g.val:.4f} ({loss_g.avg:.4f})\t' 257 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 258 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 259 | epoch, i, num_iterations, loss_sparse=losses_sparse, loss_data=losses_data, loss_g=losses_g, loss_d=losses_d, top1=top1, top5=top5)) 260 | 261 | def test(args, loader_test, model_s): 262 | losses = utils.AverageMeter() 263 | top1 = utils.AverageMeter() 264 | top5 = utils.AverageMeter() 265 | 266 | cross_entropy = nn.CrossEntropyLoss() 267 | 268 | # switch to eval mode 269 | model_s.eval() 270 | 271 | with torch.no_grad(): 272 | for i, (inputs, targets) in enumerate(loader_test, 1): 273 | 274 | inputs = inputs.to(args.gpus[0]) 275 | targets = targets.to(args.gpus[0]) 276 | 277 | logits = model_s(inputs).to(args.gpus[0]) 278 | loss = cross_entropy(logits, targets) 279 | 280 | prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5)) 281 | losses.update(loss.item(), 
inputs.size(0)) 282 | top1.update(prec1[0], inputs.size(0)) 283 | top5.update(prec5[0], inputs.size(0)) 284 | 285 | 286 | print('* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 287 | .format(top1=top1, top5=top5)) 288 | 289 | mask = [] 290 | for name, weight in model_s.named_parameters(): 291 | if 'mask' in name: 292 | mask.append(weight.item()) 293 | 294 | print("* Pruned {} / {}".format(sum(m == 0 for m in mask), len(mask))) 295 | 296 | return top1.avg, top5.avg 297 | 298 | if __name__ == '__main__': 299 | main() 300 | 301 | 302 | --------------------------------------------------------------------------------