├── README.md
├── load_datasets
│   ├── __init__.py
│   └── load_datasets.py
├── main.py
└── models
    ├── __init__.py
    ├── layers.py
    └── networks.py

/README.md:
--------------------------------------------------------------------------------
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/winning-the-lottery-with-continuous-1/network-pruning-on-imagenet-resnet-50-90)](https://paperswithcode.com/sota/network-pruning-on-imagenet-resnet-50-90?p=winning-the-lottery-with-continuous-1)

# Continuous Sparsification

Implementation of Continuous Sparsification (CS), a method based on l_0 regularization to find sparse neural networks, proposed in [[Winning the Lottery with Continuous Sparsification](https://arxiv.org/abs/1912.04427)].

## Requirements
```
Python 2/3, PyTorch == 1.1.0
```
## Training a ResNet on CIFAR with Continuous Sparsification

The main.py script can be used to train a ResNet-18 on CIFAR-10 with Continuous Sparsification. By default it will perform 3 rounds of training, each round consisting of 85 epochs. With the default hyperparameter values for the mask initialization, mask penalty, and final temperature, the method will find a sub-network with 20-30% sparsity which achieves 91.5-92.0% test accuracy when trained after rewinding (the dense network achieves 90-91%). The training and rewinding protocols follow the ones in the Lottery Ticket Hypothesis papers by Frankle.

In general, the sparsity of the final sub-network can be controlled by changing the value used to initialize the soft mask parameters. This can be done with, for example:

```
python main.py --mask-initial-value 0.1
```

The default value is 0.0 and increasing it will result in less sparse sub-networks. High-sparsity sub-networks can be found by setting it to -0.1.

## Extending the code

To train other network models with Continuous Sparsification, the first step is to choose which layers you want to sparsify and then implement PyTorch modules that perform soft masking on their original parameters. This repository contains code for 2D convolutions with soft masking: the SoftMaskedConv2d module in models/layers.py:

```
class SoftMaskedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, mask_initial_value=0.):
        super(SoftMaskedConv2d, self).__init__()
        self.mask_initial_value = mask_initial_value

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.xavier_normal_(self.weight)
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        self.mask_weight = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        scaling = 1. / sigmoid(self.mask_initial_value)
        if ticket: mask = (self.mask_weight > 0).float()
        else: mask = F.sigmoid(temp * self.mask_weight)
        return scaling * mask

    def prune(self, temp):
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.conv2d(x, masked_weight, stride=self.stride, padding=self.padding)
        return out

    def checkpoint(self):
        self.init_weight.data = self.weight.clone()

    def rewind_weights(self):
        self.weight.data = self.init_weight.clone()

    def extra_repr(self):
        return '{}, {}, kernel_size={}, stride={}, padding={}'.format(
            self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)
```

Extending it to other layers is straightforward, since you only need to change the init, init_mask and forward methods. In init_mask, you should create a mask parameter (of PyTorch Parameter type) for each parameter set that you want to sparsify -- each mask parameter must have the same dimensions as the corresponding parameter.

```
def init_mask(self):
    self.mask_weight = nn.Parameter(torch.Tensor(...))
    nn.init.constant_(self.mask_weight, self.mask_initial_value)
```

In the forward method, you need to compute the masked parameter for each parameter to be sparsified (e.g. masked weights for a Linear layer), and then compute the output of the layer with the corresponding PyTorch functional call (e.g. F.linear for Linear layers). For example:

```
def forward(self, x, temp=1, ticket=False):
    self.mask = self.compute_mask(temp, ticket)
    masked_weight = self.weight * self.mask
    out = F.linear(x, masked_weight)
    return out
```

Once all the required layers have been implemented, it remains to implement the network which CS will sparsify. In models/networks.py, you can find code for the ResNet-18 and use it as a base to implement other networks. In general, your network can inherit from MaskedNet instead of nn.Module and most of the required functionalities will be immediately available. What remains is to use the layers you implemented (the ones with soft masked parameters) in your network, and remember to pass temp and ticket as additional inputs: temp is the current temperature of CS (assumed to be the attribute model.temp in main.py), while ticket is a boolean variable that controls whether the parameters' masks should be soft (ticket=False) or hard (ticket=True). Having ticket=True means that the mask will be binary and the masked parameters will actually be sparse. Use ticket=False for training (i.e. sub-network search) and ticket=True once you are done and want to evaluate the sparse sub-network. A full sketch for a Linear layer, assembled from these pieces, is included below.

## Future plans

We plan to make the effort of applying CS to other layers/networks considerably smaller. This will hopefully be achieved by offering a function that receives a standard PyTorch Module object and returns another Module with the mask parameters properly created and the forward passes overloaded to use the masked parameters instead.

If there are specific functionalities that would help you in your research or in applying our method in general, feel free to suggest them and we will consider implementing them.
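
## Example: soft masking a Linear layer

Putting the snippets from the "Extending the code" section together, a complete soft-masked Linear layer could look like the sketch below. This is an illustration only: SoftMaskedLinear is not part of this repository, and it simply mirrors SoftMaskedConv2d from models/layers.py, using torch.sigmoid and math.exp in place of the module-level sigmoid helper and F.sigmoid.

```
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftMaskedLinear(nn.Module):
    def __init__(self, in_features, out_features, mask_initial_value=0.):
        super(SoftMaskedLinear, self).__init__()
        self.mask_initial_value = mask_initial_value
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        # One mask entry per weight entry, all starting at mask_initial_value.
        self.mask_weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        # 1 / sigmoid(mask_initial_value), so the effective mask starts at exactly 1.
        scaling = 1. + math.exp(-self.mask_initial_value)
        if ticket: mask = (self.mask_weight > 0).float()
        else: mask = torch.sigmoid(temp * self.mask_weight)
        return scaling * mask

    def prune(self, temp):
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight)

    def checkpoint(self):
        self.init_weight.data = self.weight.clone()

    def rewind_weights(self):
        self.weight.data = self.init_weight.clone()
```

A network using such a layer would inherit from MaskedNet, include the layer when building its mask_modules list, and pass temp and ticket to it in the forward pass, exactly as ResNet does with SoftMaskedConv2d in models/networks.py.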

## Citation

If you use our method for research purposes, please cite our work:

```
@article{ssm2019cs,
  author = {Savarese, Pedro and Silva, Hugo and Maire, Michael},
  title = {Winning the Lottery with Continuous Sparsification},
  journal = {arXiv:1912.04427},
  year = "2019"
}
```

--------------------------------------------------------------------------------
/load_datasets/__init__.py:
--------------------------------------------------------------------------------
from .load_datasets import *
--------------------------------------------------------------------------------
/load_datasets/load_datasets.py:
--------------------------------------------------------------------------------
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import torch

def generate_loaders(val_set_size, batch_size, n_workers):
    mean, std = [x / 255 for x in [125.3, 123.0, 113.9]], [x / 255 for x in [63.0, 62.1, 66.7]]
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)])
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    d_fun = datasets.CIFAR10
    n_classes = 10

    addr = './data/cifar10'
    train_dataset = d_fun(addr, train=True, download=True, transform=train_transform)
    val_dataset = d_fun(addr, train=True, download=True, transform=test_transform)

    label_dict = {}
    for idx in range(len(train_dataset)):
        _, label = train_dataset[idx]
        if label not in label_dict:
            label_dict[label] = [idx]
        else:
            label_dict[label].append(idx)

    train_indices = []
    val_indices = []
    for label, idxs in label_dict.items():
        np.random.shuffle(idxs)
        train_indices += idxs[(val_set_size//n_classes):]
        val_indices += idxs[:(val_set_size//n_classes)]

    test_dataset = d_fun(addr, train=False, download=True, transform=test_transform)
    assert val_set_size < len(train_dataset)

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=False, sampler=train_sampler,
        batch_size=batch_size, num_workers=n_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, sampler=valid_sampler,
        batch_size=batch_size, num_workers=n_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False,
        batch_size=batch_size, num_workers=n_workers, pin_memory=True)

    return train_loader, val_loader, test_loader
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import random
from models import *
from load_datasets import *

parser = argparse.ArgumentParser(description='Training a ResNet on CIFAR-10 with Continuous Sparsification')
parser.add_argument('--which-gpu', type=int, default=0, help='which GPU to use')
parser.add_argument('--batch-size', type=int, default=128, metavar='N', help='input batch size for training/val/test (default: 128)')
parser.add_argument('--epochs', type=int, default=85, help='number of epochs to train (default: 85)')
parser.add_argument('--rounds', type=int, default=3, help='number of rounds to train (default: 3)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1234, metavar='S', help='random seed (default: 1234)')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--val-set-size', type=int, default=5000, help='how much of the training set to use for validation (default: 5000)')
parser.add_argument('--lr-schedule', type=int, nargs='+', default=[56,71], help='epochs at which the learning rate will be dropped')
parser.add_argument('--lr-drops', type=float, nargs='+', default=[0.1, 0.1], help='how much to drop the lr at each epoch in the schedule')
parser.add_argument('--decay', type=float, default=0.0001, help='weight decay (default: 0.0001)')
parser.add_argument('--rewind-epoch', type=int, default=2, help='epoch to rewind weights to (default: 2)')
parser.add_argument('--lmbda', type=float, default=1e-8, help='lambda for L1 mask regularization (default: 1e-8)')
parser.add_argument('--final-temp', type=float, default=200, help='temperature at the end of each round (default: 200)')
parser.add_argument('--mask-initial-value', type=float, default=0., help='initial value for mask parameters')
args = parser.parse_args()

args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.cuda.set_device(args.which_gpu)

torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

if args.cuda:
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

train_loader, val_loader, test_loader = generate_loaders(args.val_set_size, args.batch_size, args.workers)

model = ResNet(args.mask_initial_value)

if args.cuda: model.cuda()
print(model)

def adjust_learning_rate(optimizer, epoch):
    lr = args.lr
    assert len(args.lr_schedule) == len(args.lr_drops), "length of lr-drops and lr-schedule should be equal"
    for (drop, step) in zip(args.lr_drops, args.lr_schedule):
        if (epoch >= step): lr = lr * drop
        else: break
    for param_group in optimizer.param_groups: param_group['lr'] = lr

def compute_remaining_weights(masks):
    return 1 - sum(float((m == 0).sum()) for m in masks) / sum(m.numel() for m in masks)

def train(outer_round):
    for epoch in range(args.epochs):
        print('\t--------- Epoch {} -----------'.format(epoch))
        model.train()
        if epoch > 0: model.temp *= temp_increase
        if outer_round == 0 and epoch == args.rewind_epoch: model.checkpoint()
        for optimizer in optimizers: adjust_learning_rate(optimizer, epoch)

        for batch_idx, (data, target) in enumerate(train_loader):
            if args.cuda: data, target = data.cuda(), target.cuda(non_blocking=True)
            for optimizer in optimizers: optimizer.zero_grad()
            output = model(data)
            pred = output.max(1)[1]
            batch_correct = pred.eq(target.data.view_as(pred)).sum()
            masks = [m.mask for m in model.mask_modules]
            entries_sum = sum(m.sum() for m in masks)
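            # Soft mask entries are sigmoid outputs (strictly positive), so their sum is the
            # L1 norm of the masks; args.lmbda weighs this penalty against the task loss below.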
            loss = F.cross_entropy(output, target) + args.lmbda * entries_sum
            loss.backward()
            for optimizer in optimizers: optimizer.step()

        val_acc = test(val_loader)
        test_acc = test(test_loader)
        remaining_weights = compute_remaining_weights(masks)
        print('\t\tTemp: {:.1f}\tRemaining weights: {:.4f}\tVal acc: {:.1f}\tTest acc: {}'.format(model.temp, remaining_weights, val_acc, test_acc))

def test(loader):
    model.eval()
    correct = 0.
    total = 0.
    with torch.no_grad():
        for data, target in loader:
            if args.cuda: data, target = data.cuda(), target.cuda(non_blocking=True)
            output = model(data)
            pred = output.max(1)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
            total += data.size()[0]
    acc = 100. * correct.item() / total
    return acc

# The temperature grows geometrically so that it reaches args.final_temp on the last epoch of each round.
iters_per_reset = args.epochs-1
temp_increase = args.final_temp**(1./iters_per_reset)

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
num_params = sum([p.numel() for p in trainable_params])
print("Total number of parameters: {}".format(num_params))

# Materialize the parameter groups as lists: under Python 3, map/filter return one-shot
# iterators, and weight_params is reused below when building the final-ticket optimizer.
weight_params = list(map(lambda a: a[1], filter(lambda p: p[1].requires_grad and 'mask' not in p[0], model.named_parameters())))
mask_params = list(map(lambda a: a[1], filter(lambda p: p[1].requires_grad and 'mask' in p[0], model.named_parameters())))

model.ticket = False
weight_optim = optim.SGD(weight_params, lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.decay)
mask_optim = optim.SGD(mask_params, lr=args.lr, momentum=0.9, nesterov=True)
optimizers = [weight_optim, mask_optim]
for outer_round in range(args.rounds):
    print('--------- Round {} -----------'.format(outer_round))
    train(outer_round)
    model.temp = 1
    if outer_round != args.rounds-1: model.prune()

print('--------- Training final ticket -----------')
optimizers = [optim.SGD(weight_params, lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.decay)]
model.ticket = True
model.rewind_weights()
train(outer_round)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .networks import *
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np

def sigmoid(x):
    return float(1./(1.+np.exp(-x)))

class SoftMaskedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, mask_initial_value=0.):
        super(SoftMaskedConv2d, self).__init__()
        self.mask_initial_value = mask_initial_value

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.xavier_normal_(self.weight)
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        self.mask_weight = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        scaling = 1. / sigmoid(self.mask_initial_value)
        if ticket: mask = (self.mask_weight > 0).float()
        else: mask = F.sigmoid(temp * self.mask_weight)
        return scaling * mask

    def prune(self, temp):
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.conv2d(x, masked_weight, stride=self.stride, padding=self.padding)
        return out

    def checkpoint(self):
        self.init_weight.data = self.weight.clone()

    def rewind_weights(self):
        self.weight.data = self.init_weight.clone()

    def extra_repr(self):
        return '{}, {}, kernel_size={}, stride={}, padding={}'.format(
            self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import *
import functools
from torch.nn import init
import copy

class MaskedNet(nn.Module):
    def __init__(self):
        super(MaskedNet, self).__init__()
        self.ticket = False

    def checkpoint(self):
        for m in self.mask_modules: m.checkpoint()
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):
                m.checkpoint = copy.deepcopy(m.state_dict())

    def rewind_weights(self):
        for m in self.mask_modules: m.rewind_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):
                m.load_state_dict(m.checkpoint)

    def prune(self):
        for m in self.mask_modules: m.prune(self.temp)

class ResBlock(nn.Module):
    def __init__(self, Conv, in_channels, out_channels, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv_a = Conv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn_a = nn.BatchNorm2d(out_channels)
        self.conv_b = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn_b = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x, temp, ticket):
        residual = x
        out = self.conv_a(x, temp, ticket)
        out = self.bn_a(out)
        out = F.relu(out, inplace=True)
        out = self.conv_b(out, temp, ticket)
        out = self.bn_b(out)
        if self.downsample is not None: residual = self.downsample(x)
        return F.relu(residual + out, inplace=True)

class ResStage(nn.Module):
    def __init__(self, Conv, in_channels, out_channels, stride=1):
        super(ResStage, self).__init__()
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Conv2d(in_channels, out_channels, 1, 2, 0, bias=False)

        self.block1 = ResBlock(Conv, in_channels, out_channels, stride, downsample)
        self.block2 = ResBlock(Conv, out_channels, out_channels)
        self.block3 = ResBlock(Conv, out_channels, out_channels)

    def forward(self, x, temp, ticket):
        out = self.block1(x, temp, ticket)
        out = self.block2(out, temp, ticket)
        out = self.block3(out, temp, ticket)
        return out

class ResNet(MaskedNet):
    def __init__(self, mask_initial_value=0.):
        super(ResNet, self).__init__()

        Conv = functools.partial(SoftMaskedConv2d, mask_initial_value=mask_initial_value)

        self.conv0 = Conv(3, 16, 3, 1, 1)
        self.bn0 = nn.BatchNorm2d(16)

        self.stage1 = ResStage(Conv, 16, 16, 1)
        self.stage2 = ResStage(Conv, 16, 32, 2)
        self.stage3 = ResStage(Conv, 32, 64, 2)

        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(64, 10)
        self.mask_modules = [m for m in self.modules() if type(m) == SoftMaskedConv2d]
        self.temp = 1

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.xavier_normal_(m.weight)
            if isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                m.bias.data.zero_()

    def forward(self, x):
        out = F.relu(self.bn0(self.conv0(x, self.temp, self.ticket)), inplace=True)
        out = self.stage1(out, self.temp, self.ticket)
        out = self.stage2(out, self.temp, self.ticket)
        out = self.stage3(out, self.temp, self.ticket)
        out = self.avgpool(out)
        out = out.view(x.size(0), -1)
        out = self.classifier(out)
        return out
--------------------------------------------------------------------------------