├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── model.py ├── optimizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Minho Ryu 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-admm-prunning 2 | This is a PyTorch implementation of DNN weight pruning with ADMM, as described in [**A Systematic DNN Weight Pruning Framework using Alternating Direction Method of Multipliers**](https://arxiv.org/abs/1804.03294). 3 | 4 | ## _Train and test_ 5 | - You can run the code with 6 | ``` 7 | $ python main.py 8 | ``` 9 | 10 | - In the paper, the authors use **l2-norm regularization**; you can add it with 11 | ``` 12 | $ python main.py --l2 13 | ``` 14 | 15 | - Beyond the paper, if you don't want to use a _predefined pruning ratio_, ADMM with **l1-norm regularization** is a good alternative and can be tested with 16 | ``` 17 | $ python main.py --l1 18 | ``` 19 | 20 | - There are two datasets you can test with this code: **[mnist, cifar10]**. The default is mnist; you can switch datasets with 21 | ``` 22 | $ python main.py --dataset cifar10 23 | ``` 24 | 25 | ## _Models_ 26 | - This code provides two models: **[LeNet, AlexNet]**. By default, LeNet is used for mnist and AlexNet for cifar10. 27 | 28 | ## _Optimizer_ 29 | - To prevent pruned weights from being updated by the optimizer, I modified Adam (named PruneAdam). A minimal usage sketch of the whole pipeline appears at the end of this README. 30 | 31 | ## _References_ 32 | For this repository, I referred to _[KaiqiZhang's tensorflow implementation](https://github.com/KaiqiZhang/admm-pruning)_. 
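## _Usage sketch_
- The snippet below is a minimal sketch of how the pieces fit together: hard pruning with `apply_prune` followed by a masked `PruneAdam.prune_step`, mirroring what `main.py` does after ADMM training. It is illustrative rather than a drop-in script: the `argparse.Namespace` stand-in for parsed arguments, the dummy batch, and the skipped training loop are assumptions for brevity, and since `optimizer.py` imports `torch._six`, the code targets older PyTorch releases that still provide that module.
```
import argparse

import torch
import torch.nn.functional as F

from model import LeNet
from optimizer import PruneAdam
from utils import apply_prune, print_prune

# Stand-in for the parsed command-line arguments used in main.py;
# one pruning percentage per weight tensor (conv1, conv2, fc1, fc2).
args = argparse.Namespace(percent=[0.8, 0.92, 0.991, 0.93])
device = torch.device("cpu")

model = LeNet().to(device)
optimizer = PruneAdam(model.named_parameters(), lr=1e-3, eps=1e-8)

# ... pretraining and the ADMM epochs (train() in main.py) would run here ...

mask = apply_prune(model, device, args)  # zero weights below each layer's percentile, keep binary masks
print_prune(model)                       # report per-layer and total sparsity

# One masked retraining step: prune_step() zeroes the Adam moments of pruned
# weights, so they stay exactly zero while the remaining weights are updated.
data = torch.randn(8, 1, 28, 28)         # dummy MNIST-shaped batch
target = torch.randint(0, 10, (8,))
optimizer.zero_grad()
loss = F.nll_loss(model(data), target)
loss.backward()
optimizer.prune_step(mask)
```
- Note that `prune_step` only applies the mask to parameters whose name ends in `weight`, so biases are always updated.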
33 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn.functional as F 5 | from optimizer import PruneAdam 6 | from model import LeNet, AlexNet 7 | from utils import regularized_nll_loss, admm_loss, \ 8 | initialize_Z_and_U, update_X, update_Z, update_Z_l1, update_U, \ 9 | print_convergence, print_prune, apply_prune, apply_l1_prune 10 | from torchvision import datasets, transforms 11 | from tqdm import tqdm 12 | 13 | 14 | def train(args, model, device, train_loader, test_loader, optimizer): 15 | for epoch in range(args.num_pre_epochs): 16 | print('Pre epoch: {}'.format(epoch + 1)) 17 | model.train() 18 | for batch_idx, (data, target) in enumerate(tqdm(train_loader)): 19 | data, target = data.to(device), target.to(device) 20 | optimizer.zero_grad() 21 | output = model(data) 22 | loss = regularized_nll_loss(args, model, output, target) 23 | loss.backward() 24 | optimizer.step() 25 | test(args, model, device, test_loader) 26 | 27 | Z, U = initialize_Z_and_U(model) 28 | for epoch in range(args.num_epochs): 29 | model.train() 30 | print('Epoch: {}'.format(epoch + 1)) 31 | for batch_idx, (data, target) in enumerate(tqdm(train_loader)): 32 | data, target = data.to(device), target.to(device) 33 | optimizer.zero_grad() 34 | output = model(data) 35 | loss = admm_loss(args, device, model, Z, U, output, target) 36 | loss.backward() 37 | optimizer.step() 38 | X = update_X(model) 39 | Z = update_Z_l1(X, U, args) if args.l1 else update_Z(X, U, args) 40 | U = update_U(U, X, Z) 41 | print_convergence(model, X, Z) 42 | test(args, model, device, test_loader) 43 | 44 | 45 | def test(args, model, device, test_loader): 46 | model.eval() 47 | test_loss = 0 48 | correct = 0 49 | with torch.no_grad(): 50 | for data, target in test_loader: 51 | data, target = data.to(device), target.to(device) 52 | output = model(data) 53 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 54 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 55 | correct += pred.eq(target.view_as(pred)).sum().item() 56 | 57 | test_loss /= len(test_loader.dataset) 58 | 59 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 60 | test_loss, correct, len(test_loader.dataset), 61 | 100. 
* correct / len(test_loader.dataset))) 62 | 63 | 64 | def retrain(args, model, mask, device, train_loader, test_loader, optimizer): 65 | for epoch in range(args.num_re_epochs): 66 | print('Re epoch: {}'.format(epoch + 1)) 67 | model.train() 68 | for batch_idx, (data, target) in enumerate(tqdm(train_loader)): 69 | data, target = data.to(device), target.to(device) 70 | optimizer.zero_grad() 71 | output = model(data) 72 | loss = F.nll_loss(output, target) 73 | loss.backward() 74 | optimizer.prune_step(mask) 75 | 76 | test(args, model, device, test_loader) 77 | 78 | 79 | def main(): 80 | # Training settings 81 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 82 | parser.add_argument('--dataset', type=str, default="mnist", choices=["mnist", "cifar10"], 83 | metavar='D', help='training dataset (mnist or cifar10)') 84 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 85 | help='input batch size for training (default: 64)') 86 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 87 | help='input batch size for testing (default: 1000)') 88 | parser.add_argument('--percent', nargs='+', type=float, default=[0.8, 0.92, 0.991, 0.93], 89 | metavar='P', help='per-layer pruning percentages (default: [0.8, 0.92, 0.991, 0.93])') 90 | parser.add_argument('--alpha', type=float, default=5e-4, metavar='L', 91 | help='l2 norm weight (default: 5e-4)') 92 | parser.add_argument('--rho', type=float, default=1e-2, metavar='R', 93 | help='cardinality weight (default: 1e-2)') 94 | parser.add_argument('--l1', default=False, action='store_true', 95 | help='prune weights with l1 regularization instead of cardinality') 96 | parser.add_argument('--l2', default=False, action='store_true', 97 | help='apply l2 regularization') 98 | parser.add_argument('--num_pre_epochs', type=int, default=3, metavar='P', 99 | help='number of epochs to pretrain (default: 3)') 100 | parser.add_argument('--num_epochs', type=int, default=10, metavar='N', 101 | help='number of epochs to train (default: 10)') 102 | parser.add_argument('--num_re_epochs', type=int, default=3, metavar='R', 103 | help='number of epochs to retrain (default: 3)') 104 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 105 | help='learning rate (default: 1e-3)') 106 | parser.add_argument('--adam_epsilon', type=float, default=1e-8, metavar='E', 107 | help='adam epsilon (default: 1e-8)') 108 | parser.add_argument('--no-cuda', action='store_true', default=False, 109 | help='disables CUDA training') 110 | parser.add_argument('--seed', type=int, default=1, metavar='S', 111 | help='random seed (default: 1)') 112 | parser.add_argument('--save-model', action='store_true', default=False, 113 | help='For Saving the current Model') 114 | args = parser.parse_args() 115 | 116 | use_cuda = not args.no_cuda and torch.cuda.is_available() 117 | 118 | torch.manual_seed(args.seed) 119 | 120 | device = torch.device("cuda" if use_cuda else "cpu") 121 | 122 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 123 | 124 | if args.dataset == "mnist": 125 | train_loader = torch.utils.data.DataLoader( 126 | datasets.MNIST('data', train=True, download=True, 127 | transform=transforms.Compose([ 128 | transforms.ToTensor(), 129 | transforms.Normalize((0.1307,), (0.3081,)) 130 | ])), 131 | batch_size=args.batch_size, shuffle=True, **kwargs) 132 | 133 | test_loader = torch.utils.data.DataLoader( 134 | datasets.MNIST('data', train=False, transform=transforms.Compose([ 135 | transforms.ToTensor(), 136 | transforms.Normalize((0.1307,), 
(0.3081,)) 137 | ])), 138 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 139 | 140 | else: 141 | args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93] 142 | args.num_pre_epochs = 5 143 | args.num_epochs = 20 144 | args.num_re_epochs = 5 145 | train_loader = torch.utils.data.DataLoader( 146 | datasets.CIFAR10('data', train=True, download=True, 147 | transform=transforms.Compose([ 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.49139968, 0.48215827, 0.44653124), 150 | (0.24703233, 0.24348505, 0.26158768)) 151 | ])), shuffle=True, batch_size=args.batch_size, **kwargs) 152 | 153 | test_loader = torch.utils.data.DataLoader( 154 | datasets.CIFAR10('data', train=False, download=True, 155 | transform=transforms.Compose([ 156 | transforms.ToTensor(), 157 | transforms.Normalize((0.49139968, 0.48215827, 0.44653124), 158 | (0.24703233, 0.24348505, 0.26158768)) 159 | ])), shuffle=True, batch_size=args.test_batch_size, **kwargs) 160 | 161 | model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(device) 162 | optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon) 163 | 164 | train(args, model, device, train_loader, test_loader, optimizer) 165 | mask = apply_l1_prune(model, device, args) if args.l1 else apply_prune(model, device, args) 166 | print_prune(model) 167 | test(args, model, device, test_loader) 168 | retrain(args, model, mask, device, train_loader, test_loader, optimizer) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 9 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 10 | self.fc1 = nn.Linear(4*4*50, 500) 11 | self.fc2 = nn.Linear(500, 10) 12 | 13 | def forward(self, x): 14 | x = F.relu(self.conv1(x)) 15 | x = F.max_pool2d(x, 2, 2) 16 | x = F.relu(self.conv2(x)) 17 | x = F.max_pool2d(x, 2, 2) 18 | x = x.view(-1, 4*4*50) 19 | x = F.relu(self.fc1(x)) 20 | x = self.fc2(x) 21 | return F.log_softmax(x, dim=1) 22 | 23 | 24 | class AlexNet(nn.Module): 25 | def __init__(self): 26 | super(AlexNet, self).__init__() 27 | 28 | self.features = nn.Sequential( 29 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 30 | nn.ReLU(inplace=True), 31 | nn.MaxPool2d(kernel_size=2), 32 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 33 | nn.ReLU(inplace=True), 34 | nn.MaxPool2d(kernel_size=2), 35 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 38 | nn.ReLU(inplace=True), 39 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 40 | nn.ReLU(inplace=True), 41 | nn.MaxPool2d(kernel_size=2), 42 | ) 43 | 44 | self.classifier = nn.Sequential( 45 | nn.Dropout(), 46 | nn.Linear(256 * 2 * 2, 4096), 47 | nn.ReLU(inplace=True), 48 | nn.Dropout(), 49 | nn.Linear(4096, 4096), 50 | nn.ReLU(inplace=True), 51 | nn.Linear(4096, 10), 52 | ) 53 | 54 | def forward(self, x): 55 | x = self.features(x) 56 | x = x.view(x.shape[0], -1) 57 | x = self.classifier(x) 58 | return F.log_softmax(x, dim=1) 59 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from official pytorch 
documentation (https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html) 3 | I modified the optimizer to use the name of each parameter, to prevent pruned weights from being updated by gradients 4 | """ 5 | 6 | import math 7 | from collections import defaultdict 8 | from torch._six import container_abcs 9 | import torch 10 | from copy import deepcopy 11 | from itertools import chain 12 | 13 | 14 | class _RequiredParameter(object): 15 | """Singleton class representing a required parameter for an Optimizer.""" 16 | def __repr__(self): 17 | return "<required parameter>" 18 | 19 | required = _RequiredParameter() 20 | 21 | 22 | class NameOptimizer(object): 23 | r"""Base class for all optimizers. 24 | 25 | .. warning:: 26 | Parameters need to be specified as collections that have a deterministic 27 | ordering that is consistent between runs. Examples of objects that don't 28 | satisfy those properties are sets and iterators over values of dictionaries. 29 | 30 | Arguments: 31 | params (iterable): an iterable of :class:`torch.Tensor` s or 32 | :class:`dict` s. Specifies what Tensors should be optimized. 33 | defaults: (dict): a dict containing default values of optimization 34 | options (used when a parameter group doesn't specify them). 35 | """ 36 | 37 | def __init__(self, named_params, defaults): 38 | self.defaults = defaults 39 | 40 | if isinstance(named_params, torch.Tensor): 41 | raise TypeError("params argument given to the optimizer should be " 42 | "an iterable of Tensors or dicts, but got " + 43 | torch.typename(named_params)) 44 | 45 | self.state = defaultdict(dict) 46 | self.param_groups = [] 47 | 48 | param_groups = list(named_params) 49 | if len(param_groups) == 0: 50 | raise ValueError("optimizer got an empty parameter list") 51 | if not isinstance(param_groups[0], dict): 52 | param_groups = [{'params': param_groups}] 53 | 54 | for param_group in param_groups: 55 | self.add_param_group(param_group) 56 | 57 | def __getstate__(self): 58 | return { 59 | 'defaults': self.defaults, 60 | 'state': self.state, 61 | 'param_groups': self.param_groups, 62 | } 63 | 64 | def __setstate__(self, state): 65 | self.__dict__.update(state) 66 | 67 | def __repr__(self): 68 | format_string = self.__class__.__name__ + ' (' 69 | for i, group in enumerate(self.param_groups): 70 | format_string += '\n' 71 | format_string += 'Parameter Group {0}\n'.format(i) 72 | for key in sorted(group.keys()): 73 | if key != 'params': 74 | format_string += ' {0}: {1}\n'.format(key, group[key]) 75 | format_string += ')' 76 | return format_string 77 | 78 | def state_dict(self): 79 | r"""Returns the state of the optimizer as a :class:`dict`. 80 | 81 | It contains two entries: 82 | 83 | * state - a dict holding current optimization state. Its content 84 | differs between optimizer classes. 85 | * param_groups - a dict containing all parameter groups 86 | """ 87 | # Save ids instead of Tensors 88 | def pack_group(group): 89 | packed = {k: v for k, v in group.items() if k != 'params'} 90 | packed['params'] = [id(p) for p in group['params']] 91 | return packed 92 | param_groups = [pack_group(g) for g in self.param_groups] 93 | # Remap state to use ids as keys 94 | packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v 95 | for k, v in self.state.items()} 96 | return { 97 | 'state': packed_state, 98 | 'param_groups': param_groups, 99 | } 100 | 101 | def load_state_dict(self, state_dict): 102 | r"""Loads the optimizer state. 103 | 104 | Arguments: 105 | state_dict (dict): optimizer state. 
Should be an object returned 106 | from a call to :meth:`state_dict`. 107 | """ 108 | # deepcopy, to be consistent with module API 109 | state_dict = deepcopy(state_dict) 110 | # Validate the state_dict 111 | groups = self.param_groups 112 | saved_groups = state_dict['param_groups'] 113 | 114 | if len(groups) != len(saved_groups): 115 | raise ValueError("loaded state dict has a different number of " 116 | "parameter groups") 117 | param_lens = (len(g['params']) for g in groups) 118 | saved_lens = (len(g['params']) for g in saved_groups) 119 | if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): 120 | raise ValueError("loaded state dict contains a parameter group " 121 | "that doesn't match the size of optimizer's group") 122 | 123 | # Update the state 124 | id_map = {old_id: p for old_id, p in 125 | zip(chain(*(g['params'] for g in saved_groups)), 126 | chain(*(g['params'] for g in groups)))} 127 | 128 | def cast(param, value): 129 | r"""Make a deep copy of value, casting all tensors to device of param.""" 130 | if isinstance(value, torch.Tensor): 131 | # Floating-point types are a bit special here. They are the only ones 132 | # that are assumed to always match the type of params. 133 | if param.is_floating_point(): 134 | value = value.to(param.dtype) 135 | value = value.to(param.device) 136 | return value 137 | elif isinstance(value, dict): 138 | return {k: cast(param, v) for k, v in value.items()} 139 | elif isinstance(value, container_abcs.Iterable): 140 | return type(value)(cast(param, v) for v in value) 141 | else: 142 | return value 143 | 144 | # Copy state assigned to params (and cast tensors to appropriate types). 145 | # State that is not assigned to params is copied as is (needed for 146 | # backward compatibility). 147 | state = defaultdict(dict) 148 | for k, v in state_dict['state'].items(): 149 | if k in id_map: 150 | param = id_map[k] 151 | state[param] = cast(param, v) 152 | else: 153 | state[k] = v 154 | 155 | # Update parameter groups, setting their 'params' value 156 | def update_group(group, new_group): 157 | new_group['params'] = group['params'] 158 | return new_group 159 | param_groups = [ 160 | update_group(g, ng) for g, ng in zip(groups, saved_groups)] 161 | self.__setstate__({'state': state, 'param_groups': param_groups}) 162 | 163 | def zero_grad(self): 164 | r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" 165 | for group in self.param_groups: 166 | for name, p in group['params']: 167 | if p.grad is not None: 168 | p.grad.detach_() 169 | p.grad.zero_() 170 | 171 | def step(self, closure): 172 | r"""Performs a single optimization step (parameter update). 173 | 174 | Arguments: 175 | closure (callable): A closure that reevaluates the model and 176 | returns the loss. Optional for most optimizers. 177 | """ 178 | raise NotImplementedError 179 | 180 | def add_param_group(self, param_group): 181 | r"""Add a param group to the :class:`Optimizer` s `param_groups`. 182 | 183 | This can be useful when fine tuning a pre-trained network as frozen layers can be made 184 | trainable and added to the :class:`Optimizer` as training progresses. 185 | 186 | Arguments: 187 | param_group (dict): Specifies what Tensors should be optimized along with group 188 | specific optimization options. 
189 | """ 190 | assert isinstance(param_group, dict), "param group must be a dict" 191 | 192 | params = param_group['params'] 193 | if isinstance(params, torch.Tensor): 194 | param_group['params'] = [params] 195 | elif isinstance(params, set): 196 | raise TypeError('optimizer parameters need to be organized in ordered collections, but ' 197 | 'the ordering of tensors in sets will change between runs. Please use a list instead.') 198 | else: 199 | param_group['params'] = list(params) 200 | 201 | for name, param in param_group['params']: 202 | if not isinstance(param, torch.Tensor): 203 | raise TypeError("optimizer can only optimize Tensors, " 204 | "but one of the params is " + torch.typename(param)) 205 | if not param.is_leaf: 206 | raise ValueError("can't optimize a non-leaf Tensor") 207 | 208 | for name, default in self.defaults.items(): 209 | if default is required and name not in param_group: 210 | raise ValueError("parameter group didn't specify a value of required optimization parameter " + 211 | name) 212 | else: 213 | param_group.setdefault(name, default) 214 | 215 | param_set = set() 216 | for group in self.param_groups: 217 | param_set.update(set(group['params'])) 218 | 219 | if not param_set.isdisjoint(set(param_group['params'])): 220 | raise ValueError("some parameters appear in more than one parameter group") 221 | 222 | self.param_groups.append(param_group) 223 | 224 | 225 | class PruneAdam(NameOptimizer): 226 | r"""Implements Adam algorithm. 227 | 228 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 229 | 230 | Arguments: 231 | params (iterable): iterable of parameters to optimize or dicts defining 232 | parameter groups 233 | lr (float, optional): learning rate (default: 1e-3) 234 | betas (Tuple[float, float], optional): coefficients used for computing 235 | running averages of gradient and its square (default: (0.9, 0.999)) 236 | eps (float, optional): term added to the denominator to improve 237 | numerical stability (default: 1e-8) 238 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 239 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 240 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 241 | (default: False) 242 | 243 | .. _Adam\: A Method for Stochastic Optimization: 244 | https://arxiv.org/abs/1412.6980 245 | .. _On the Convergence of Adam and Beyond: 246 | https://openreview.net/forum?id=ryQu7f-RZ 247 | """ 248 | 249 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 250 | weight_decay=0, amsgrad=False): 251 | if not 0.0 <= lr: 252 | raise ValueError("Invalid learning rate: {}".format(lr)) 253 | if not 0.0 <= eps: 254 | raise ValueError("Invalid epsilon value: {}".format(eps)) 255 | if not 0.0 <= betas[0] < 1.0: 256 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 257 | if not 0.0 <= betas[1] < 1.0: 258 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 259 | defaults = dict(lr=lr, betas=betas, eps=eps, 260 | weight_decay=weight_decay, amsgrad=amsgrad) 261 | super(PruneAdam, self).__init__(params, defaults) 262 | 263 | def __setstate__(self, state): 264 | super(PruneAdam, self).__setstate__(state) 265 | for group in self.param_groups: 266 | group.setdefault('amsgrad', False) 267 | 268 | def step(self, closure=None): 269 | """Performs a single optimization step. 270 | 271 | Arguments: 272 | closure (callable, optional): A closure that reevaluates the model 273 | and returns the loss. 
274 | """ 275 | loss = None 276 | if closure is not None: 277 | loss = closure() 278 | 279 | for group in self.param_groups: 280 | for name, p in group['params']: 281 | if p.grad is None: 282 | continue 283 | grad = p.grad.data 284 | if grad.is_sparse: 285 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 286 | amsgrad = group['amsgrad'] 287 | 288 | state = self.state[p] 289 | 290 | # State initialization 291 | if len(state) == 0: 292 | state['step'] = 0 293 | # Exponential moving average of gradient values 294 | state['exp_avg'] = torch.zeros_like(p.data) 295 | # Exponential moving average of squared gradient values 296 | state['exp_avg_sq'] = torch.zeros_like(p.data) 297 | if amsgrad: 298 | # Maintains max of all exp. moving avg. of sq. grad. values 299 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 300 | 301 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 302 | if amsgrad: 303 | max_exp_avg_sq = state['max_exp_avg_sq'] 304 | beta1, beta2 = group['betas'] 305 | 306 | state['step'] += 1 307 | 308 | if group['weight_decay'] != 0: 309 | grad.add_(group['weight_decay'], p.data) 310 | 311 | # Decay the first and second moment running average coefficient 312 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 313 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 314 | if amsgrad: 315 | # Maintains the maximum of all 2nd moment running avg. till now 316 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 317 | # Use the max. for normalizing running avg. of gradient 318 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 319 | else: 320 | denom = exp_avg_sq.sqrt().add_(group['eps']) 321 | 322 | bias_correction1 = 1 - beta1 ** state['step'] 323 | bias_correction2 = 1 - beta2 ** state['step'] 324 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 325 | 326 | p.data.addcdiv_(-step_size, exp_avg, denom) 327 | 328 | return loss 329 | 330 | def prune_step(self, mask, closure=None): 331 | """Performs a single optimization step, keeping pruned weights at zero. 332 | 333 | Arguments: 334 | closure (callable, optional): A closure that reevaluates the model 335 | and returns the loss. 336 | mask (dict): pruning masks, keyed by parameter name, used to keep pruned weights from being updated. 337 | """ 338 | loss = None 339 | if closure is not None: 340 | loss = closure() 341 | for group in self.param_groups: 342 | for name, p in group['params']: 343 | if p.grad is None: 344 | continue 345 | grad = p.grad.data 346 | if grad.is_sparse: 347 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 348 | amsgrad = group['amsgrad'] 349 | 350 | state = self.state[p] 351 | 352 | # State initialization 353 | if len(state) == 0: 354 | state['step'] = 0 355 | # Exponential moving average of gradient values 356 | state['exp_avg'] = torch.zeros_like(p.data) 357 | # Exponential moving average of squared gradient values 358 | state['exp_avg_sq'] = torch.zeros_like(p.data) 359 | if amsgrad: 360 | # Maintains max of all exp. moving avg. of sq. grad. 
values 361 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 362 | 363 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 364 | if amsgrad: 365 | max_exp_avg_sq = state['max_exp_avg_sq'] 366 | beta1, beta2 = group['betas'] 367 | 368 | state['step'] += 1 369 | 370 | if group['weight_decay'] != 0: 371 | grad.add_(group['weight_decay'], p.data) 372 | 373 | # Decay the first and second moment running average coefficient 374 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 375 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 376 | 377 | if name.split('.')[-1] == "weight": 378 | exp_avg_sq.mul_(mask[name]) 379 | 380 | if amsgrad: 381 | # Maintains the maximum of all 2nd moment running avg. till now 382 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 383 | # Use the max. for normalizing running avg. of gradient 384 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 385 | else: 386 | denom = exp_avg_sq.sqrt().add_(group['eps']) 387 | 388 | bias_correction1 = 1 - beta1 ** state['step'] 389 | bias_correction2 = 1 - beta2 ** state['step'] 390 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 391 | 392 | if name.split('.')[-1] == "weight": 393 | exp_avg.mul_(mask[name]) 394 | p.data.addcdiv_(-step_size, exp_avg, denom) 395 | 396 | return loss 397 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def regularized_nll_loss(args, model, output, target): 7 | index = 0 8 | loss = F.nll_loss(output, target) 9 | if args.l2: 10 | for name, param in model.named_parameters(): 11 | if name.split('.')[-1] == "weight": 12 | loss += args.alpha * param.norm() 13 | index += 1 14 | return loss 15 | 16 | 17 | def admm_loss(args, device, model, Z, U, output, target): 18 | idx = 0 19 | loss = F.nll_loss(output, target) 20 | for name, param in model.named_parameters(): 21 | if name.split('.')[-1] == "weight": 22 | u = U[idx].to(device) 23 | z = Z[idx].to(device) 24 | loss += args.rho / 2 * (param - z + u).norm() 25 | if args.l2: 26 | loss += args.alpha * param.norm() 27 | idx += 1 28 | return loss 29 | 30 | 31 | def initialize_Z_and_U(model): 32 | Z = () 33 | U = () 34 | for name, param in model.named_parameters(): 35 | if name.split('.')[-1] == "weight": 36 | Z += (param.detach().cpu().clone(),) 37 | U += (torch.zeros_like(param).cpu(),) 38 | return Z, U 39 | 40 | 41 | def update_X(model): 42 | X = () 43 | for name, param in model.named_parameters(): 44 | if name.split('.')[-1] == "weight": 45 | X += (param.detach().cpu().clone(),) 46 | return X 47 | 48 | 49 | def update_Z(X, U, args): 50 | new_Z = () 51 | idx = 0 52 | for x, u in zip(X, U): 53 | z = x + u 54 | pcen = np.percentile(abs(z), 100*args.percent[idx]) 55 | under_threshold = abs(z) < pcen 56 | z.data[under_threshold] = 0 57 | new_Z += (z,) 58 | idx += 1 59 | return new_Z 60 | 61 | 62 | def update_Z_l1(X, U, args): 63 | new_Z = () 64 | delta = args.alpha / args.rho 65 | for x, u in zip(X, U): 66 | z = x + u 67 | new_z = z.clone() 68 | if (z > delta).sum() != 0: 69 | new_z[z > delta] = z[z > delta] - delta 70 | if (z < -delta).sum() != 0: 71 | new_z[z < -delta] = z[z < -delta] + delta 72 | if (abs(z) <= delta).sum() != 0: 73 | new_z[abs(z) <= delta] = 0 74 | new_Z += (new_z,) 75 | return new_Z 76 | 77 | 78 | def update_U(U, X, Z): 79 | new_U = () 80 | for u, x, z in zip(U, X, Z): 81 | new_u = u + x - z 
82 | new_U += (new_u,) 83 | return new_U 84 | 85 | 86 | def prune_weight(weight, device, percent): 87 | # To work with ADMM, the percentile is computed over all elements instead of only nonzero elements. 88 | weight_numpy = weight.detach().cpu().numpy() 89 | pcen = np.percentile(abs(weight_numpy), 100*percent) 90 | under_threshold = abs(weight_numpy) < pcen 91 | weight_numpy[under_threshold] = 0 92 | mask = torch.Tensor(abs(weight_numpy) >= pcen).to(device) 93 | return mask 94 | 95 | 96 | def prune_l1_weight(weight, device, delta): 97 | weight_numpy = weight.detach().cpu().numpy() 98 | under_threshold = abs(weight_numpy) < delta 99 | weight_numpy[under_threshold] = 0 100 | mask = torch.Tensor(abs(weight_numpy) >= delta).to(device) 101 | return mask 102 | 103 | 104 | def apply_prune(model, device, args): 105 | # returns a dictionary of binary masks (1 for kept weights, 0 for pruned weights) 106 | print("Apply Pruning based on percentile") 107 | dict_mask = {} 108 | idx = 0 109 | for name, param in model.named_parameters(): 110 | if name.split('.')[-1] == "weight": 111 | mask = prune_weight(param, device, args.percent[idx]) 112 | param.data.mul_(mask) 113 | # param.data = torch.Tensor(weight_pruned).to(device) 114 | dict_mask[name] = mask 115 | idx += 1 116 | return dict_mask 117 | 118 | 119 | def apply_l1_prune(model, device, args): 120 | delta = args.alpha / args.rho 121 | print("Apply Pruning based on l1 threshold (alpha / rho)") 122 | dict_mask = {} 123 | idx = 0 124 | for name, param in model.named_parameters(): 125 | if name.split('.')[-1] == "weight": 126 | mask = prune_l1_weight(param, device, delta) 127 | param.data.mul_(mask) 128 | dict_mask[name] = mask 129 | idx += 1 130 | return dict_mask 131 | 132 | 133 | def print_convergence(model, X, Z): 134 | idx = 0 135 | print("normalized norm of (weight - projection)") 136 | for name, _ in model.named_parameters(): 137 | if name.split('.')[-1] == "weight": 138 | x, z = X[idx], Z[idx] 139 | print("({}): {:.4f}".format(name, (x-z).norm().item() / x.norm().item())) 140 | idx += 1 141 | 142 | 143 | def print_prune(model): 144 | prune_param, total_param = 0, 0 145 | for name, param in model.named_parameters(): 146 | if name.split('.')[-1] == "weight": 147 | print("[at weight {}]".format(name)) 148 | print("percentage of pruned: {:.4f}%".format(100 * (abs(param) == 0).sum().item() / param.numel())) 149 | print("nonzero parameters after pruning: {} / {}\n".format((param != 0).sum().item(), param.numel())) 150 | total_param += param.numel() 151 | prune_param += (param != 0).sum().item() 152 | print("total nonzero parameters after pruning: {} / {} ({:.4f}% pruned)". 153 | format(prune_param, total_param, 154 | 100 * (total_param - prune_param) / total_param)) 155 | --------------------------------------------------------------------------------