├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── data_splits ├── c100_train_idx.npy ├── c100_valid_idx.npy ├── c10_train_idx.npy ├── c10_valid_idx.npy ├── svhn_train_idx.npy ├── svhn_valid_idx.npy ├── tin_train_idx.npy └── tin_valid_idx.npy ├── datasets ├── __init__.py ├── data.py └── tinyimagenet.py ├── finetuning.py ├── models ├── __init__.py ├── base_model.py ├── layers.py ├── mobilenet.py ├── models.py ├── network_slimming_resnet.py └── resnet.py ├── pretraining.py ├── pruning.py ├── requirements.txt ├── script.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/* 2 | # Jupyter Notebook 3 | .ipynb_checkpoints 4 | 5 | # Environments 6 | .env 7 | .venv 8 | env/ 9 | venv/ 10 | ENV/ 11 | env.bak/ 12 | venv.bak/ 13 | 14 | __pypackages__/ 15 | __pycache__ 16 | .__pycache__ 17 | .vscode 18 | data/ 19 | pythonenv3.8 20 | *.ipynb 21 | logs/* 22 | logs -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Rishabh Tiwari 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChipNet 2 | 3 | This is the official repository for the ICLR 2021 paper "[ChipNet: Budget-Aware Pruning with Heaviside Continuous Approximations]( 4 | https://openreview.net/pdf?id=xCxXwTzx4L1)" by Rishabh Tiwari, Udbhav Bamba, Arnav Chavan, Deepak Gupta. 5 | 6 | ## Getting Started 7 | 8 | You will need [Python 3.7](https://www.python.org/downloads) and the packages specified in _requirements.txt_. 9 | We recommend setting up a [virtual environment with pip](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) 10 | and installing the packages there. 11 | 12 | Install packages with: 13 | 14 | ``` 15 | $ pip install -r requirements.txt 16 | ``` 17 | 18 | ## Configure and Run 19 | 20 | All configuration options concerning data, model, training, etc. can be set using command-line arguments. 
21 | 22 | The main script offers many options; here are the most important ones: 23 | 24 | ### Pretraining 25 | 26 | ``` 27 | usage: python pretraining.py [dataset] [model name] --epochs [number of epochs] --decay [weight decay] --batch_size [batch size] --lr [learning rate] --scheduler_type {1, 2} 28 | ``` 29 | 30 | ### Pruning 31 | ``` 32 | usage: python pruning.py [dataset] [model name] --Vc {float value between 0-1} --budget_type {channel_ratio, volume_ratio, parameter_ratio, flops_ratio} --epochs 20 33 | ``` 34 | 35 | ### Finetuning 36 | ``` 37 | usage: python finetuning.py [dataset] [model name] --Vc {float value between 0-1} --budget_type {channel_ratio, volume_ratio, parameter_ratio, flops_ratio} --name {model name}_{dataset}_{budget}_{budget_type} --epochs [number of epochs] --decay [weight decay] --batch_size [batch size] --lr [learning rate] --scheduler_type {1, 2} 38 | ``` 39 | 40 | ### Mask Transfer 41 | ``` 42 | usage: python finetuning.py [dataset] [model name] --Vc {float value between 0-1} --budget_type {channel_ratio, volume_ratio, parameter_ratio, flops_ratio} --host_name {model name}_{dataset}_{budget}_{budget_type} --epochs [number of epochs] --decay [weight decay] --batch_size [batch size] --lr [learning rate] --scheduler_type {1, 2} 43 | ``` 44 | * Parameter and FLOPs budgets are currently supported only for models that use the ResNetCifar module. 45 | 46 | ## Citation 47 | Please cite our paper in your publications if it helps your research. 
Even if it does not, and you want to make us happy, do cite it :) 48 | 49 | @inproceedings{ 50 | tiwari2021chipnet, 51 | title={ChipNet: Budget-Aware Pruning with Heaviside Continuous Approximations}, 52 | author={Rishabh Tiwari and Udbhav Bamba and Arnav Chavan and Deepak Gupta}, 53 | booktitle={International Conference on Learning Representations}, 54 | year={2021}, 55 | url={https://openreview.net/forum?id=xCxXwTzx4L1} 56 | } 57 | 58 | 59 | 60 | ## License 61 | 62 | This project is licensed under the MIT License. 63 | -------------------------------------------------------------------------------- /data_splits/c100_train_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/c100_train_idx.npy -------------------------------------------------------------------------------- /data_splits/c100_valid_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/c100_valid_idx.npy -------------------------------------------------------------------------------- /data_splits/c10_train_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/c10_train_idx.npy -------------------------------------------------------------------------------- /data_splits/c10_valid_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/c10_valid_idx.npy -------------------------------------------------------------------------------- /data_splits/svhn_train_idx.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/svhn_train_idx.npy -------------------------------------------------------------------------------- /data_splits/svhn_valid_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/svhn_valid_idx.npy -------------------------------------------------------------------------------- /data_splits/tin_train_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/tin_train_idx.npy -------------------------------------------------------------------------------- /data_splits/tin_valid_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/transmuteAI/ChipNet/e957d2a4e7503179fa0931739977564227f39719/data_splits/tin_valid_idx.npy -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import DataManager -------------------------------------------------------------------------------- /datasets/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.utils.data as data 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | from torchvision import transforms, datasets 7 | from torchvision.datasets import CIFAR10, CIFAR100, SVHN 8 | from .tinyimagenet import TinyImageNet 9 | from sklearn.model_selection import train_test_split 10 | import numpy as np 11 | 12 | class DataManager: 13 | def __init__(self, args): 14 | 
self.dataset_name = args.dataset 15 | self.batch_size = args.batch_size 16 | self.workers = args.workers 17 | self.valid_size = args.valid_size 18 | self.num_train = 0 19 | self.num_classes = {'c10': 10, 'c100': 100, 'tin': 200, 'svhn': 10}[self.dataset_name] 20 | self.insize = {'c10': 32, 'c100': 32, 'tin': 64, 'svhn': 32}[self.dataset_name] 21 | 22 | def prepare_data(self): 23 | print('... Preparing data ...') 24 | if self.dataset_name in ['c10', 'c100']: 25 | norm_mean = [0.49139968, 0.48215827, 0.44653124] 26 | norm_std = [0.24703233, 0.24348505, 0.26158768] 27 | norm_transform = transforms.Normalize(norm_mean, norm_std) 28 | train_transform = transforms.Compose([ 29 | transforms.RandomCrop(32, padding=4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | norm_transform 33 | ]) 34 | val_transform = transforms.Compose([ 35 | transforms.ToTensor(), 36 | norm_transform 37 | ]) 38 | dataset_choice = {'c10': CIFAR10, 'c100': CIFAR100}[self.dataset_name] 39 | trainset = dataset_choice(root='./data', train=True, download=True, 40 | transform=train_transform) 41 | 42 | valset = dataset_choice(root='./data', train=True, download=True, 43 | transform=val_transform) 44 | 45 | testset = dataset_choice(root='./data', train=False, download=True, 46 | transform=val_transform) 47 | 48 | elif self.dataset_name == 'svhn': 49 | norm_mean =[0.4309, 0.4302, 0.4463] 50 | norm_std = [0.1253, 0.1282, 0.1147] 51 | norm_transform = transforms.Normalize(norm_mean, norm_std) 52 | train_transform = transforms.Compose([ 53 | transforms.RandomCrop(32, padding=4), 54 | transforms.RandomHorizontalFlip(), 55 | transforms.ToTensor(), 56 | norm_transform 57 | ]) 58 | val_transform = transforms.Compose([ 59 | transforms.ToTensor(), 60 | norm_transform 61 | ]) 62 | trainset = SVHN(root='./data', split='train', download=True, 63 | transform=train_transform) 64 | 65 | valset = SVHN(root='./data', split='train', download=True, 66 | transform=val_transform) 67 | 68 | 69 | testset 
= SVHN(root='./data', split='test', download=True, 70 | transform=val_transform) 71 | 72 | else: 73 | norm_mean = [0.485, 0.456, 0.406] 74 | norm_std = [0.229, 0.224, 0.225] 75 | norm_transform = transforms.Normalize(norm_mean, norm_std) 76 | train_transform = transforms.Compose([ 77 | transforms.RandomAffine(degrees=20.0, scale=(0.8, 1.2), shear=20.0), 78 | transforms.RandomHorizontalFlip(), 79 | transforms.ToTensor(), 80 | norm_transform, 81 | ]) 82 | val_transform = transforms.Compose([ 83 | transforms.ToTensor(), 84 | norm_transform 85 | ]) 86 | trainset = TinyImageNet('./data', train=True, transform=train_transform) 87 | valset = TinyImageNet('./data', train=True, transform=val_transform) 88 | testset = TinyImageNet('./data', train=False, transform=val_transform) 89 | 90 | self.num_train = len(trainset) 91 | train_idx, val_idx = self.get_split() 92 | train_sampler = SubsetRandomSampler(train_idx) 93 | val_sampler = SubsetRandomSampler(val_idx) 94 | train_loader = data.DataLoader(trainset, self.batch_size, num_workers=self.workers, 95 | sampler=train_sampler, pin_memory=True) 96 | val_loader = data.DataLoader(valset, self.batch_size, num_workers=self.workers, sampler=val_sampler, 97 | pin_memory=True) 98 | test_loader = data.DataLoader(testset, self.batch_size, num_workers=self.workers, shuffle=False, 99 | pin_memory=False) 100 | return train_loader, val_loader, test_loader 101 | 102 | def get_split(self): 103 | if(os.path.exists(f'data_splits/{self.dataset_name}_train_idx.npy') and os.path.exists(f'data_splits/{self.dataset_name}_valid_idx.npy')): 104 | print('using fixed split') 105 | train_idx, valid_idx = np.load(f'data_splits/{self.dataset_name}_train_idx.npy'), np.load(f'data_splits/{self.dataset_name}_valid_idx.npy') 106 | print(len(train_idx),len(valid_idx)) 107 | else: 108 | print('creating a split') 109 | indices = list(range(self.num_train)) 110 | train_idx, valid_idx = train_test_split(indices, test_size=self.valid_size) 111 | 
np.save(f'data_splits/{self.dataset_name}_train_idx.npy',train_idx) 112 | np.save(f'data_splits/{self.dataset_name}_valid_idx.npy',valid_idx) 113 | return train_idx, valid_idx 114 | 115 | 116 | -------------------------------------------------------------------------------- /datasets/tinyimagenet.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import glob 3 | import numpy as np 4 | import os 5 | from torchvision.datasets.folder import pil_loader 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | class TinyImageNet(Dataset): 9 | def __init__(self, root, train, transform, download=True): 10 | 11 | self.url = "http://cs231n.stanford.edu/tiny-imagenet-200" 12 | self.root = root 13 | if download: 14 | if os.path.exists(f'{self.root}/tiny-imagenet-200/'): 15 | print('File already downloaded') 16 | else: 17 | download_and_extract_archive(self.url, root, filename="tiny-imagenet-200.zip") 18 | 19 | self.root = os.path.join(self.root, "tiny-imagenet-200") 20 | self.train = train 21 | self.transform = transform 22 | self.ids_string = np.sort(np.loadtxt(f"{self.root}/wnids.txt", "str")) 23 | self.ids = {class_string: i for i, class_string in enumerate(self.ids_string)} 24 | if train: 25 | self.paths = glob.glob(f"{self.root}/train/*/images/*") 26 | self.label = [self.ids[path.split("/")[-3]] for path in self.paths] 27 | else: 28 | self.val_annotations = np.loadtxt(f"{self.root}/val/val_annotations.txt", "str") 29 | self.paths = [f"{self.root}/val/images/{sample[0]}" for sample in self.val_annotations] 30 | self.label = [self.ids[sample[1]] for sample in self.val_annotations] 31 | 32 | def __len__(self): 33 | return len(self.paths) 34 | 35 | def __getitem__(self, idx): 36 | image = pil_loader(self.paths[idx]) 37 | 38 | if self.transform is not None: 39 | image = self.transform(image) 40 | 41 | return image, self.label[idx] 42 | 
-------------------------------------------------------------------------------- /finetuning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tqdm import tqdm as tqdm_notebook 10 | from datasets import DataManager 11 | from utils import * 12 | from models import get_model 13 | 14 | seed_everything(43) 15 | 16 | ap = argparse.ArgumentParser(description='finetuning') 17 | ap.add_argument('dataset', choices=['c10', 'c100', 'tin','svhn'], type=str, help='Dataset choice') 18 | ap.add_argument('model', type=str, help='Model choice') 19 | ap.add_argument('--budget_type', choices=['channel_ratio', 'volume_ratio','parameter_ratio','flops_ratio'], default = 'channel_ratio', type=str, help='Budget Type') 20 | ap.add_argument('--Vc', default=0.5, type=float, help='Budget Constraint') 21 | ap.add_argument('--batch_size', default=128, type=int, help='Batch Size') 22 | ap.add_argument('--epochs', default=300, type=int, help='Epochs') 23 | ap.add_argument('--name', type=str, help='name of model') 24 | ap.add_argument('--host_name',default = None, type=str, help='transfer the mask from this model') 25 | 26 | ap.add_argument('--valid_size', '-v', type=float, default=0.1, help='valid_size') 27 | ap.add_argument('--lr', default=0.05, type=float, help='Learning rate') 28 | ap.add_argument('--scheduler_type', '-st', type=int, choices=[1, 2], default=1, help='lr scheduler type') 29 | ap.add_argument('--decay', '-d', type=float, default=0.001, help='weight decay') 30 | ap.add_argument('--test_only', '-t', type=bool, default=False, help='test the best model') 31 | ap.add_argument('--workers', default=0, type=int, help='number of workers') 32 | ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number') 33 | args = ap.parse_args() 34 | 35 | valid_size=args.valid_size 36 | 
Vc = torch.FloatTensor([args.Vc]) 37 | if args.host_name == None: 38 | model_path = f"checkpoints/{args.name}_pruned.pth" 39 | else: 40 | # model_path = f"checkpoints/{args.name}_pretrained.pth" 41 | model_path = f"checkpoints/{args.host_name}_pruned.pth" 42 | 43 | ############################### preparing dataset ################################ 44 | 45 | data_object = DataManager(args) 46 | trainloader, valloader, testloader = data_object.prepare_data() 47 | dataloaders = { 48 | 'train': trainloader, 'val': valloader, "test": testloader 49 | } 50 | 51 | ############################### preparing model ################################### 52 | 53 | model = get_model(args.model, 'prune', data_object.num_classes, data_object.insize) 54 | if args.host_name is not None: 55 | host_state = torch.load(model_path)['state_dict'] 56 | model.load_state_dict(get_mask_dict(model.state_dict(), host_state), strict = False) 57 | else: 58 | state = torch.load(model_path)['state_dict'] 59 | model.load_state_dict(state, strict=False) 60 | CE = nn.CrossEntropyLoss() 61 | def criterion(model, y_pred, y_true): 62 | ce_loss = CE(y_pred, y_true) 63 | return ce_loss 64 | 65 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay) 66 | device = torch.device(f"cuda:{str(args.cuda_id)}") 67 | model.to(device) 68 | Vc.to(device) 69 | 70 | def train(model, loss_fn, optimizer): 71 | model.train() 72 | counter = 0 73 | tk1 = tqdm_notebook(dataloaders['train'], total=len(dataloaders['train'])) 74 | running_loss = 0. 
75 | for x_var, y_var in tk1: 76 | counter +=1 77 | x_var = x_var.to(device=device) 78 | y_var = y_var.to(device=device) 79 | scores = model(x_var) 80 | loss = loss_fn(model, scores, y_var) 81 | running_loss+=loss.item() 82 | tk1.set_postfix(loss=running_loss/counter) 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | return running_loss/counter 87 | 88 | def test(model, loss_fn, optimizer, phase): 89 | model.eval() 90 | counter = 0 91 | tk1 = tqdm_notebook(dataloaders[phase], total=len(dataloaders[phase])) 92 | running_loss = 0 93 | running_acc = 0 94 | total = 0 95 | with torch.no_grad(): 96 | for x_var, y_var in tk1: 97 | counter +=1 98 | x_var = x_var.to(device=device) 99 | y_var = y_var.to(device=device) 100 | scores = model(x_var) 101 | loss = loss_fn(model, scores, y_var) 102 | _, scores = torch.max(scores.data, 1) 103 | y_var = y_var.cpu().detach().numpy() 104 | scores = scores.cpu().detach().numpy() 105 | 106 | correct = (scores == y_var).sum().item() 107 | running_loss+=loss.item() 108 | running_acc+=correct 109 | total+=scores.shape[0] 110 | tk1.set_postfix(loss=running_loss/counter, acc=running_acc/total) 111 | return running_acc/total, running_loss/counter 112 | 113 | ############################## training starts here ############################# 114 | 115 | model.prepare_for_finetuning(device, Vc.item(), budget_type=args.budget_type) # sets beta and gamma and unfreezes network except zetas 116 | 117 | best_accuracy=0 118 | num_epochs = args.epochs 119 | train_losses = [] 120 | valid_losses = [] 121 | valid_accuracy = [] 122 | if args.test_only == False: 123 | for epoch in range(num_epochs): 124 | adjust_learning_rate(optimizer, epoch, args) 125 | print('Starting epoch %d / %d' % (epoch + 1, num_epochs)) 126 | train_loss = train(model, criterion, optimizer) 127 | accuracy, valid_loss = test(model, criterion, optimizer, "val") 128 | remaining = model.get_remaining(20.,args.budget_type).item() 129 | 130 | if 
accuracy>best_accuracy: 131 | print("**Saving model**") 132 | best_accuracy=accuracy 133 | torch.save({ 134 | "epoch": epoch + 1, 135 | "state_dict" : model.state_dict(), 136 | "acc" : best_accuracy, 137 | "rem" : remaining, 138 | }, f"checkpoints/{args.name}_{args.dataset}_finetuned.pth") 139 | 140 | train_losses.append(train_loss) 141 | valid_losses.append(valid_loss) 142 | valid_accuracy.append(accuracy) 143 | df_data=np.array([train_losses, valid_losses, valid_accuracy]).T 144 | df = pd.DataFrame(df_data,columns = ['train_losses','valid_losses','valid_accuracy']) 145 | df.to_csv(f"logs/{args.name}_{args.dataset}_finetuned.csv") 146 | 147 | state = torch.load(f"checkpoints/{args.name}_{args.dataset}_finetuned.pth") 148 | model.load_state_dict(state['state_dict'],strict=True) 149 | acc, v_loss = test(model, criterion, optimizer, "test") 150 | print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}") -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import get_model -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from collections import defaultdict 5 | from .layers import PrunableBatchNorm2d 6 | 7 | class BaseModel(nn.Module): 8 | def __init__(self): 9 | super(BaseModel, self).__init__() 10 | self.prunable_modules = [] 11 | self.prev_module = defaultdict() 12 | # self.next_module = defaultdict() 13 | pass 14 | 15 | def set_threshold(self, threshold): 16 | self.prune_threshold = threshold 17 | 18 | def init_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 22 | if m.bias is not None: 23 | 
nn.init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | nn.init.constant_(m.weight, 1) 26 | nn.init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.Linear): 28 | nn.init.normal_(m.weight, 0, 0.01) 29 | nn.init.constant_(m.bias, 0) 30 | 31 | def calculate_prune_threshold(self, Vc, budget_type = 'channel_ratio'): 32 | zetas = self.give_zetas() 33 | if budget_type in ['volume_ratio']: 34 | zeta_weights = self.give_zeta_weights() 35 | zeta_weights = zeta_weights[np.argsort(zetas)] 36 | zetas = sorted(zetas) 37 | if budget_type == 'volume_ratio': 38 | curr_budget = 0 39 | indx = 0 40 | while(curr_budget<(1.-Vc)): 41 | indx+=1 42 | curr_budget+=zeta_weights[indx] 43 | prune_threshold = zetas[indx] 44 | else: 45 | prune_threshold = zetas[int((1.-Vc)*len(zetas))] 46 | return prune_threshold 47 | 48 | def smoothRound(self, x, steepness=20.): 49 | return 1./(1.+torch.exp(-1*steepness*(x-0.5))) 50 | 51 | def n_remaining(self, m, steepness=20.): 52 | return (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness)).sum() 53 | 54 | def is_all_pruned(self, m): 55 | return self.n_remaining(m) == 0 56 | 57 | def get_remaining(self, steepness=20., budget_type = 'channel_ratio'): 58 | """return the fraction of active zeta_t (i.e > 0.5)""" 59 | n_rem = 0 60 | n_total = 0 61 | for l_block in self.prunable_modules: 62 | if budget_type == 'volume_ratio': 63 | n_rem += (self.n_remaining(l_block, steepness)*l_block._conv_module.output_area) 64 | n_total += (l_block.num_gates*l_block._conv_module.output_area) 65 | elif budget_type == 'channel_ratio': 66 | n_rem += self.n_remaining(l_block, steepness) 67 | n_total += l_block.num_gates 68 | elif budget_type == 'parameter_ratio': 69 | k = l_block._conv_module.kernel_size[0] 70 | prev_total = 3 if self.prev_module[l_block] is None else self.prev_module[l_block].num_gates 71 | prev_remaining = 3 if self.prev_module[l_block] is None else self.n_remaining(self.prev_module[l_block], steepness) 72 | n_rem += 
self.n_remaining(l_block, steepness)*prev_remaining*k*k 73 | n_total += l_block.num_gates*prev_total*k*k 74 | elif budget_type == 'flops_ratio': 75 | k = l_block._conv_module.kernel_size[0] 76 | output_area = l_block._conv_module.output_area 77 | prev_total = 3 if self.prev_module[l_block] is None else self.prev_module[l_block].num_gates 78 | prev_remaining = 3 if self.prev_module[l_block] is None else self.n_remaining(self.prev_module[l_block], steepness) 79 | curr_remaining = self.n_remaining(l_block, steepness) 80 | n_rem += curr_remaining*prev_remaining*k*k*output_area + curr_remaining*output_area 81 | n_total += l_block.num_gates*prev_total*k*k*output_area + l_block.num_gates*output_area 82 | return n_rem/n_total 83 | 84 | def give_zetas(self): 85 | zetas = [] 86 | for l_block in self.prunable_modules: 87 | zetas.append(l_block.get_zeta_t().cpu().detach().numpy().tolist()) 88 | zetas = [z for k in zetas for z in k ] 89 | return zetas 90 | 91 | def give_zeta_weights(self): 92 | zeta_weights = [] 93 | for l_block in self.prunable_modules: 94 | zeta_weights.append([l_block._conv_module.output_area]*l_block.num_gates) 95 | zeta_weights = [z for k in zeta_weights for z in k ] 96 | return zeta_weights/np.sum(zeta_weights) 97 | 98 | def plot_zt(self): 99 | """plots the distribution of zeta_t and returns the same""" 100 | zetas = self.give_zetas() 101 | exactly_zeros = np.sum(np.array(zetas)==0.0) 102 | exactly_ones = np.sum(np.array(zetas)==1.0) 103 | plt.hist(zetas) 104 | plt.show() 105 | return exactly_zeros, exactly_ones 106 | 107 | def get_crispnessLoss(self, device): 108 | """loss reponsible for making zeta_t 1 or 0""" 109 | loss = torch.FloatTensor([]).to(device) 110 | for l_block in self.prunable_modules: 111 | loss = torch.cat([loss, torch.pow(l_block.get_zeta_t()-l_block.get_zeta_i(), 2)]) 112 | return torch.mean(loss).to(device) 113 | 114 | def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=None): 115 | """prunes the network to 
make zeta_t exactly 1 and 0""" 116 | 117 | if budget_type == 'parameter_ratio': 118 | zetas = sorted(self.give_zetas()) 119 | high = len(zetas)-1 120 | low = 0 121 | while low0).sum().item() 16 | module.output_area = out_tensor.size(2) * out_tensor.size(3) 17 | conv_module.register_forward_hook(fo_hook) 18 | self._conv_module = conv_module 19 | beta=1. 20 | gamma=2. 21 | for n, x in zip(('beta', 'gamma'), (torch.tensor([x], requires_grad=False) for x in (beta, gamma))): 22 | self.register_buffer(n, x) # self.beta will be created (same for gamma, zeta) 23 | 24 | def forward(self, input): 25 | out = super(PrunableBatchNorm2d, self).forward(input) 26 | z = self.pruned_zeta if self.is_pruned else self.get_zeta_t() 27 | out *= z[None, :, None, None] # broadcast the mask to all samples in the batch, and all locations 28 | return out 29 | 30 | def get_zeta_i(self): 31 | return self.__generalized_logistic(self.zeta) 32 | 33 | def get_zeta_t(self): 34 | zeta_i = self.get_zeta_i() 35 | return self.__continous_heavy_side(zeta_i) 36 | 37 | def set_beta_gamma(self, beta, gamma): 38 | self.beta.data.copy_(torch.Tensor([beta])) 39 | self.gamma.data.copy_(torch.Tensor([gamma])) 40 | 41 | def __generalized_logistic(self, x): 42 | return 1./(1.+torch.exp(-self.beta*x)) 43 | 44 | def __continous_heavy_side(self, x): 45 | return 1-torch.exp(-self.gamma*x)+x*torch.exp(-self.gamma) 46 | 47 | def prune(self, threshold): 48 | self.is_pruned = True 49 | self.pruned_zeta = (self.get_zeta_t()>threshold).float() 50 | # if self.is_imp and self.pruned_zeta.sum()==0: 51 | # self.pruned_zeta[torch.argmax(self.get_zeta_t()).item()] = 1. 
52 | self.zeta.requires_grad = False 53 | 54 | def unprune(self): 55 | self.is_pruned = False 56 | self.zeta.requires_grad = True 57 | 58 | def get_params_count(self): 59 | total_conv_params = self._conv_module.in_channels*self.pruned_zeta.shape[0]*self._conv_module.kernel_size[0]*self._conv_module.kernel_size[1] 60 | bn_params = self.num_gates*2 61 | active_bn_params = self.pruned_zeta.sum().item()*2 62 | active_conv_params = self._conv_module.num_input_active_channels*self.pruned_zeta.sum().item()*self._conv_module.kernel_size[0]*self._conv_module.kernel_size[1] 63 | return active_conv_params+active_bn_params, total_conv_params+bn_params 64 | 65 | def get_volume(self): 66 | total_volume = self._conv_module.output_area*self.num_gates 67 | active_volume = self._conv_module.output_area*self.pruned_zeta.sum().item() 68 | return active_volume, total_volume 69 | 70 | def get_flops(self): 71 | k_area = self._conv_module.kernel_size[0]*self._conv_module.kernel_size[1] 72 | total_flops = self._conv_module.output_area*self.num_gates*self._conv_module.in_channels*k_area 73 | active_flops = self._conv_module.output_area*self.pruned_zeta.sum().item()*self._conv_module.num_input_active_channels*k_area 74 | return active_flops, total_flops 75 | 76 | @staticmethod 77 | def from_batchnorm(bn_module, conv_module): 78 | new_bn = PrunableBatchNorm2d(bn_module.num_features, conv_module) 79 | return new_bn, conv_module 80 | 81 | 82 | class ModuleInjection: 83 | pruning_method = 'full' 84 | prunable_modules = [] 85 | 86 | @staticmethod 87 | def make_prunable(conv_module, bn_module): 88 | """Make a (conv, bn) sequence prunable. 
89 | :param conv_module: A Conv2d module 90 | :param bn_module: The BatchNorm2d module following the Conv2d above 91 | :param prune_before_bn: Whether the pruning gates will be applied before or after the Batch Norm 92 | :return: a pair (conv, bn) that can be trained to 93 | """ 94 | if ModuleInjection.pruning_method == 'full': 95 | return conv_module, bn_module 96 | new_bn, conv_module = PrunableBatchNorm2d.from_batchnorm(bn_module, conv_module=conv_module) 97 | ModuleInjection.prunable_modules.append(new_bn) 98 | return conv_module, new_bn -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from .layers import ModuleInjection, PrunableBatchNorm2d 5 | from .base_model import BaseModel 6 | import numpy as np 7 | 8 | '''MobileNet in PyTorch. 9 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 10 | for more details. 
11 | Code is taken from https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py 12 | ''' 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | class Block(nn.Module): 19 | '''expand + depthwise + pointwise''' 20 | def __init__(self, in_planes, out_planes, expansion, stride): 21 | super(Block, self).__init__() 22 | self.stride = stride 23 | 24 | planes = expansion * in_planes 25 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv1, self.bn1 = ModuleInjection.make_prunable(self.conv1, self.bn1) 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn3 = nn.BatchNorm2d(out_planes) 32 | self.conv3, self.bn3 = ModuleInjection.make_prunable(self.conv3, self.bn3) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride == 1 and in_planes != out_planes: 36 | conv_module = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 37 | bn_module = nn.BatchNorm2d(out_planes) 38 | conv_module, bn_module = ModuleInjection.make_prunable(conv_module, bn_module) 39 | if hasattr(bn_module, 'is_imp'): 40 | bn_module.is_imp = True 41 | self.shortcut = nn.Sequential( 42 | conv_module, 43 | bn_module 44 | ) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = F.relu(self.bn2(self.conv2(out))) 49 | out = self.bn3(self.conv3(out)) 50 | out = out + self.shortcut(x) if self.stride==1 else out 51 | return out 52 | 53 | 54 | class MobileNetv2(BaseModel): 55 | # (expansion, out_planes, num_blocks, stride) 56 | cfg = [(1, 16, 1, 1), 57 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 58 | (6, 32, 3, 2), 59 | (6, 64, 4, 2), 60 | (6, 96, 3, 1), 61 | (6, 160, 3, 2), 62 | (6, 320, 
1, 1)] 63 | 64 | def __init__(self, num_classes=10): 65 | super(MobileNetv2, self).__init__() 66 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 67 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(32) 69 | self.conv1, self.bn1 = ModuleInjection.make_prunable(self.conv1, self.bn1) 70 | if hasattr(self.bn1, 'is_imp'): 71 | self.bn1.is_imp = True 72 | self.layers = self._make_layers(in_planes=32) 73 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 74 | self.bn2 = nn.BatchNorm2d(1280) 75 | self.conv2, self.bn2 = ModuleInjection.make_prunable(self.conv2, self.bn2) 76 | if hasattr(self.bn2, 'is_imp'): 77 | self.bn2.is_imp = True 78 | self.linear = nn.Linear(1280, num_classes) 79 | 80 | def _make_layers(self, in_planes): 81 | layers = [] 82 | for expansion, out_planes, num_blocks, stride in self.cfg: 83 | strides = [stride] + [1]*(num_blocks-1) 84 | for stride in strides: 85 | layers.append(Block(in_planes, out_planes, expansion, stride)) 86 | in_planes = out_planes 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.layers(out) 92 | out = F.relu(self.bn2(self.conv2(out))) 93 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | def removable_orphans(self): 100 | num_removed = 0 101 | for b in self.layers: 102 | m1, m2 = b.bn1, b.bn3 103 | if self.is_all_pruned(m1) or self.is_all_pruned(m2): 104 | num_removed += self.n_remaining(m1) + self.n_remaining(m2) 105 | return num_removed 106 | 107 | def remove_orphans(self): 108 | num_removed = 0 109 | for b in self.layers: 110 | m1, m2 = b.bn1, b.bn3 111 | if self.is_all_pruned(m1) or self.is_all_pruned(m2): 112 | num_removed += self.n_remaining(m1) + self.n_remaining(m2) 113 | m1.pruned_zeta.data.copy_(torch.zeros_like(m1.pruned_zeta)) 
def get_model(model, method, num_classes, insize):
    """Returns the requested model, ready for training/pruning with the specified method.

    :param model: str, model_name
    :param method: full or prune
    :param num_classes: int, num classes in the dataset
    :param insize: input size; only forwarded to the ResNet builders
    :return: A prunable model
    :raises ValueError: if ``model`` is not one of the supported names
    """
    if model in ('wrn', 'r50', 'r101', 'r110', 'r152', 'r32', 'r18', 'r56', 'r20'):
        return get_resnet_model(model, method, num_classes, insize)
    if model == 'r164':
        return get_network_slimming_model(method, num_classes)
    if model == 'mobilenetv2':
        return get_mobilenet(model, method, num_classes)
    # Previously an unknown name fell through to `return net` with `net`
    # unbound, surfacing as an UnboundLocalError far from the real cause.
    raise ValueError(f"Unknown model name: {model!r}")
class resnet(BaseModel):
    """Bottleneck ResNet (network-slimming layout) with three stages.

    Batch norms are passed through ``ModuleInjection.make_prunable``; the
    per-block channel widths come from ``cfg``.
    """

    def __init__(self, num_classes, depth=164, cfg=None):
        super().__init__()
        assert (depth - 2) % 9 == 0, 'depth should be 9n+2'

        blocks_per_stage = (depth - 2) // 9
        block = Bottleneck

        if cfg is None:
            # Default widths: one leading triple per stage, then a repeated
            # triple for the remaining blocks, and the classifier input last.
            groups = [
                [16, 16, 16], [64, 16, 16] * (blocks_per_stage - 1),
                [64, 32, 32], [128, 32, 32] * (blocks_per_stage - 1),
                [128, 64, 64], [256, 64, 64] * (blocks_per_stage - 1),
                [256],
            ]
            cfg = []
            for group in groups:
                cfg.extend(group)

        self.inplanes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        span = 3 * blocks_per_stage  # three cfg entries per block
        self.layer1 = self._make_layer(block, 16, blocks_per_stage, cfg=cfg[0:span])
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, cfg=cfg[span:2 * span], stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, cfg=cfg[2 * span:3 * span], stride=2)
        final_bn = nn.BatchNorm2d(64 * block.expansion)
        _, self.bn = ModuleInjection.make_prunable(None, final_bn)
        # Only set the flag on modules that expose it.
        if hasattr(self.bn, 'is_imp'):
            self.bn.is_imp = True
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        self.fc = nn.Linear(cfg[-1], num_classes)

        # Weight init: normal(0, sqrt(2 / (kh * kw * out_channels))) for convs,
        # constant 0.5 scale and zero shift for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size[0], module.kernel_size[1]
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(0.5)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, cfg, stride=1):
        """Build one stage of ``blocks`` blocks; the first may downsample."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
            )

        stage = [block(self.inplanes, planes, cfg[0:3], stride, downsample)]
        self.inplanes = out_planes
        for idx in range(1, blocks):
            stage.append(block(self.inplanes, planes, cfg[3 * idx: 3 * (idx + 1)]))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Stem conv, three stages, final BN + ReLU, pooling, then the classifier."""
        out = self.conv1(x)

        # Spatial sizes in the comments assume a 32x32 input.
        out = self.layer1(out)  # 32x32
        out = self.layer2(out)  # 16x16
        out = self.layer3(out)  # 8x8
        out = self.relu(self.bn(out))

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)

    def removable_orphans(self):
        """Sum ``n_remaining`` over blocks where any of bn1/bn2/bn3 is fully pruned."""
        total = 0
        for stage in (self.layer1, self.layer2, self.layer3):
            for blk in stage:
                bns = (blk.bn1, blk.bn2, blk.bn3)
                if any(self.is_all_pruned(bn) for bn in bns):
                    total += sum(self.n_remaining(bn) for bn in bns)
        return total

    def remove_orphans(self):
        """Like ``removable_orphans``, but also zero ``pruned_zeta`` of those blocks."""
        total = 0
        for stage in (self.layer1, self.layer2, self.layer3):
            for blk in stage:
                bns = (blk.bn1, blk.bn2, blk.bn3)
                if any(self.is_all_pruned(bn) for bn in bns):
                    total += sum(self.n_remaining(bn) for bn in bns)
                    for bn in bns:
                        bn.pruned_zeta.data.copy_(torch.zeros_like(bn.pruned_zeta))
        return total
class Bottleneck(nn.Module):
    """Residual block of three convs (1x1 -> 3x3 -> 1x1).

    The output has ``planes * expansion`` channels. Every conv/BN pair is
    passed through ``ModuleInjection.make_prunable`` once all plain modules
    have been created.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion

        # Plain modules first; creation and registration order is kept as-is.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.activ = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        # Swap each pair for whatever make_prunable hands back.
        self.conv1, self.bn1 = ModuleInjection.make_prunable(self.conv1, self.bn1)
        self.conv2, self.bn2 = ModuleInjection.make_prunable(self.conv2, self.bn2)
        self.conv3, self.bn3 = ModuleInjection.make_prunable(self.conv3, self.bn3)

    def forward(self, x):
        """Apply the three stages, add the (optionally downsampled) input, then ReLU."""
        out = self.activ(self.bn1(self.conv1(x)))
        out = self.activ(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.activ(out)
num_classes) 106 | self.init_weights() 107 | 108 | assert block is BasicBlock 109 | prev = self.bn1 110 | for l_block in [self.layer1, self.layer2, self.layer3]: 111 | for b in l_block: 112 | self.prev_module[b.bn1] = prev 113 | self.prev_module[b.bn2] = b.bn1 114 | if b.downsample is not None: 115 | self.prev_module[b.downsample[1]] = prev 116 | prev = b.bn2 117 | 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | conv_module = nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False) 123 | bn_module = nn.BatchNorm2d(planes * block.expansion) 124 | conv_module, bn_module = ModuleInjection.make_prunable(conv_module, bn_module) 125 | if hasattr(bn_module, 'is_imp'): 126 | bn_module.is_imp = True 127 | downsample = nn.Sequential(conv_module, bn_module) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.activ(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | 146 | x = self.avgpool(x) 147 | x = x.view(x.size(0), -1) 148 | x = self.fc(x) 149 | 150 | return x 151 | 152 | def removable_orphans(self): 153 | num_removed = 0 154 | for l_blocks in [self.layer1, self.layer2, self.layer3]: 155 | for b in l_blocks: 156 | m1, m2 = b.bn1, b.bn2 157 | if self.is_all_pruned(m1) or self.is_all_pruned(m2): 158 | num_removed += self.n_remaining(m1) + self.n_remaining(m2) 159 | return num_removed 160 | 161 | def remove_orphans(self): 162 | num_removed = 0 163 | for l_blocks in [self.layer1, self.layer2, self.layer3]: 164 | for b in l_blocks: 165 | m1, m2 = b.bn1, b.bn2 166 | if 
self.is_all_pruned(m1) or self.is_all_pruned(m2): 167 | num_removed += self.n_remaining(m1) + self.n_remaining(m2) 168 | m1.pruned_zeta.data.copy_(torch.zeros_like(m1.pruned_zeta)) 169 | m2.pruned_zeta.data.copy_(torch.zeros_like(m2.pruned_zeta)) 170 | return num_removed 171 | 172 | def calc_params(self, a): 173 | ans = a[0]*a[1]*9 174 | current_loc = 2 175 | current_max = a[1] 176 | downsample_n = a[2] 177 | do_downsample = True if self.width>1 else False 178 | for l in self.layers_size: 179 | for i in range(l): 180 | if do_downsample: 181 | downsample_n = a[current_loc] 182 | ans+=current_max*a[current_loc] 183 | current_loc+=1 184 | 185 | ans+=current_max*a[current_loc]*9 186 | ans+=a[current_loc]*a[current_loc+1]*9 187 | if do_downsample: 188 | current_max = max(downsample_n, a[current_loc+1]) 189 | else: 190 | current_max = max(current_max, a[current_loc+1]) 191 | do_downsample = False 192 | current_loc+=2 193 | do_downsample = True 194 | return ans + a[-1]*self.num_classes + 2*np.sum(a) 195 | 196 | def calc_flops(self, a): 197 | ans=a[0]*a[1]*9*self.insize**2 + a[1]*self.insize**2 198 | current_loc = 2 199 | current_max = a[1] 200 | downsample_n = a[2] 201 | size = self.insize*2 202 | do_downsample = True if self.width>1 else False 203 | for l in self.layers_size: 204 | for i in range(l): 205 | if do_downsample: 206 | downsample_n = a[current_loc] 207 | size = size//2 208 | ans+=(current_max+1)*a[current_loc]*size**2 209 | current_loc+=1 210 | 211 | ans+=current_max*a[current_loc]*9*size**2 + a[current_loc]*size**2 212 | ans+=a[current_loc]*a[current_loc+1]*9*size**2 + a[current_loc+1]*size**2 213 | if do_downsample: 214 | current_max = max(downsample_n, a[current_loc+1]) 215 | else: 216 | current_max = max(current_max, a[current_loc+1]) 217 | do_downsample = False 218 | current_loc+=2 219 | do_downsample = True 220 | return 2*ans + 2*(current_max-1)*100 221 | 222 | def params(self): 223 | a = [3] 224 | b = [3] 225 | for i in self.prunable_modules: 226 | 
class ResNet(BaseModel):
    """ResNet with a 64-channel stem and four residual stages.

    Construction, ``forward`` and the orphan helpers work for both
    ``BasicBlock`` (bn1, bn2) and ``Bottleneck`` (bn1, bn2, bn3).
    ``calc_params`` / ``calc_flops`` consume three channel counts per block
    and therefore only match the ``Bottleneck`` layout.
    """

    def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=False, init_weights=True, insize=32):
        """
        :param block: residual block class (``BasicBlock`` or ``Bottleneck``)
        :param layers: number of blocks in each of the four stages
        :param width: multiplier applied to every stage's base channel count
        :param num_classes: output size of the final linear layer
        :param produce_vectors: if True, ``forward`` also returns the pooled features
        :param init_weights: accepted for compatibility; ``self.init_weights()``
            is always called regardless of this value
        :param insize: input resolution; below 128 the stem is 3x3 stride 1,
            otherwise 7x7 stride 2
        """
        super(ResNet, self).__init__()
        self.layers_size = layers
        self.num_classes = num_classes
        self.insize = insize
        self.produce_vectors = produce_vectors
        # `block` is a class: `block.__class__.__name__` would always be
        # 'type', which silently disabled every 'Bottleneck' branch below.
        self.block_type = block.__name__
        self.inplanes = 64
        if insize < 128:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv1, self.bn1 = ModuleInjection.make_prunable(self.conv1, self.bn1)
        # presumably `prev_module` is created by BaseModel.__init__ - confirm
        self.prev_module[self.bn1] = None
        self.activ = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64 * width, layers[0])
        self.layer2 = self._make_layer(block, 128 * width, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256 * width, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512 * width, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)  # Global Avg Pool
        self.fc = nn.Linear(512 * block.expansion * width, num_classes)

        self.init_weights()

        # Record, for every batch norm, the batch norm that feeds it.
        prev = self.bn1
        for stage in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for b in stage:
                self.prev_module[b.bn1] = prev
                self.prev_module[b.bn2] = b.bn1
                last_bn = b.bn2
                # BasicBlock has no bn3; reading it unconditionally raised
                # AttributeError when building the 'r18' model.
                if hasattr(b, 'bn3'):
                    self.prev_module[b.bn3] = b.bn2
                    last_bn = b.bn3
                if b.downsample is not None:
                    self.prev_module[b.downsample[1]] = prev
                prev = last_bn

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            conv_module = nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)
            bn_module = nn.BatchNorm2d(planes * block.expansion)
            conv_module, bn_module = ModuleInjection.make_prunable(conv_module, bn_module)
            # Only set the flag on modules that expose it.
            if hasattr(bn_module, 'is_imp'):
                bn_module.is_imp = True
            downsample = nn.Sequential(conv_module, bn_module)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits, or ``(logits, pooled_features)`` when ``produce_vectors`` is set."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activ(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        feature_vectors = x.view(x.size(0), -1)
        x = self.fc(feature_vectors)

        if self.produce_vectors:
            return x, feature_vectors
        return x

    def _block_bns(self, b):
        """Return the batch norms of block ``b`` that the orphan checks look at."""
        if self.block_type == 'Bottleneck':
            return (b.bn1, b.bn2, b.bn3)
        return (b.bn1, b.bn2)

    def _stages(self):
        """Return the four residual stages in forward order."""
        return [self.layer1, self.layer2, self.layer3, self.layer4]

    def removable_orphans(self):
        """Sum ``n_remaining`` over blocks in which any batch norm is fully pruned."""
        num_removed = 0
        for l_blocks in self._stages():
            for b in l_blocks:
                bns = self._block_bns(b)
                if any(self.is_all_pruned(m) for m in bns):
                    num_removed += sum(self.n_remaining(m) for m in bns)
        return num_removed

    def remove_orphans(self):
        """Like ``removable_orphans``, but also zero ``pruned_zeta`` in those blocks."""
        num_removed = 0
        for l_blocks in self._stages():
            for b in l_blocks:
                bns = self._block_bns(b)
                if any(self.is_all_pruned(m) for m in bns):
                    # Count first, then zero, so the count reflects the state before removal.
                    num_removed += sum(self.n_remaining(m) for m in bns)
                    for m in bns:
                        m.pruned_zeta.data.copy_(torch.zeros_like(m.pruned_zeta))
        return num_removed

    def _channel_counts(self):
        """Return ``(kept, total)`` channel lists, each prefixed with the 3 input channels."""
        kept = [3]
        total = [3]
        for m in self.prunable_modules:
            kept.append(int(m.pruned_zeta.sum()))
            total.append(len(m.pruned_zeta))
        return kept, total

    def calc_params(self, a):
        """Estimate the parameter count for channel list ``a``.

        ``a[0]`` is the input channel count, ``a[1]`` the stem width, and the
        remaining entries are consumed per block as
        ``[downsample (first block of a stage only), c1, c2, c3]``.
        The stem is always counted as a 3x3 conv.
        """
        ans = a[0]*a[1]*9
        current_loc = 2
        current_max = a[1]
        downsample_n = a[2]
        do_downsample = True
        for n_blocks in self.layers_size:
            for _ in range(n_blocks):
                if do_downsample:
                    # 1x1 downsample conv on the shortcut.
                    downsample_n = a[current_loc]
                    ans += current_max*a[current_loc]
                    current_loc += 1

                ans += current_max*a[current_loc]*1          # 1x1
                ans += a[current_loc]*a[current_loc+1]*9     # 3x3
                ans += a[current_loc+1]*a[current_loc+2]*1   # 1x1
                # The residual sum is as wide as the wider of its two inputs.
                if do_downsample:
                    current_max = max(downsample_n, a[current_loc+2])
                else:
                    current_max = max(current_max, a[current_loc+2])
                do_downsample = False
                current_loc += 3
            do_downsample = True
        return ans + a[-1]*self.num_classes + 2*np.sum(a)

    def params(self):
        """Return the ratio of estimated parameters after pruning to before."""
        a, b = self._channel_counts()
        return self.calc_params(a)/self.calc_params(b)

    def calc_flops(self, a):
        """Estimate FLOPs for channel list ``a`` (same layout as ``calc_params``)."""
        ans = a[0]*a[1]*9*self.insize**2 + a[1]*self.insize**2
        current_loc = 2
        current_max = a[1]
        downsample_n = a[2]
        # Doubled up front because the first stage halves it before use.
        size = self.insize*2
        do_downsample = True
        for n_blocks in self.layers_size:
            for _ in range(n_blocks):
                if do_downsample:
                    downsample_n = a[current_loc]
                    size = size//2
                    ans += (current_max+1)*a[current_loc]*size**2
                    current_loc += 1

                ans += current_max*a[current_loc]*1*size**2 + a[current_loc]*size**2
                ans += a[current_loc]*a[current_loc+1]*9*size**2 + a[current_loc+1]*size**2
                ans += a[current_loc+1]*a[current_loc+2]*1*size**2 + a[current_loc+2]*size**2
                if do_downsample:
                    current_max = max(downsample_n, a[current_loc+2])
                else:
                    current_max = max(current_max, a[current_loc+2])
                do_downsample = False
                current_loc += 3
            do_downsample = True
        # NOTE(review): the 100 looks like a hard-coded class count - confirm
        # whether self.num_classes was intended before changing it.
        return 2*ans + 2*(current_max-1)*100

    def flops(self):
        """Return the ratio of estimated FLOPs after pruning to before."""
        a, b = self._channel_counts()
        return self.calc_flops(a)/self.calc_flops(b)
def get_resnet_model(model, method, num_classes, insize):
    """Returns the requested model, ready for training/pruning with the specified method.

    :param model: str, one of wrn, r18, r20, r32, r50, r56, r101, r110, r152
    :param method: full or prune
    :param num_classes: int, num classes in the dataset
    :param insize: input size forwarded to the model constructor
    :return: A prunable ResNet model
    :raises ValueError: if ``model`` is not a supported name
    """
    # Entries are thunks, so an unknown name is rejected before any builder
    # runs and before the shared ModuleInjection state is reset.
    builders = {
        'wrn': lambda: make_wide_resnet(num_classes, insize),
        'r18': lambda: make_resnet18(num_classes, insize),
        'r20': lambda: make_resnet20(num_classes, insize),
        'r32': lambda: make_resnet32(num_classes, insize),
        'r50': lambda: make_resnet50(num_classes, insize),
        'r56': lambda: make_resnet56(num_classes, insize),
        'r101': lambda: make_resnet101(num_classes, insize),
        'r110': lambda: make_resnet110(num_classes, insize),
        'r152': lambda: make_resnet152(num_classes, insize),
    }
    if model not in builders:
        # Previously this fell through to `net.prunable_modules = ...` with
        # `net` unbound, after the global state had already been clobbered.
        raise ValueError(
            f"Unknown ResNet model {model!r}; expected one of {sorted(builders)}"
        )

    ModuleInjection.pruning_method = method
    ModuleInjection.prunable_modules = []
    net = builders[model]()
    net.prunable_modules = ModuleInjection.prunable_modules
    return net
import argparse
import os

import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm as tqdm_notebook
from datasets import DataManager
from utils import *
from models import get_model

seed_everything(43)


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string as True, so ``-t False``
    used to enable test-only mode. This accepts the usual spellings and
    rejects anything else with a clear argparse error.
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


ap = argparse.ArgumentParser(description='pretraining')
ap.add_argument('dataset', choices=['c10', 'c100', 'tin','svhn'], type=str, help='Dataset choice')
ap.add_argument('model', type=str, help='Model choice')
ap.add_argument('--test_only', '-t', type=_str2bool, default=False, help='test the best model')
ap.add_argument('--valid_size', '-v', type=float, default=0.1, help='valid_size')
ap.add_argument('--batch_size', default=128, type=int, help='Batch Size')
ap.add_argument('--lr', default=0.05, type=float, help='Learning rate')
ap.add_argument('--scheduler_type', '-st', type=int, choices=[1, 2], default=1, help='lr scheduler type')
ap.add_argument('--decay', '-d', type=float, default=0.001, help='weight decay')
ap.add_argument('--epochs', default=200, type=int, help='Epochs')
ap.add_argument('--workers', default=0, type=int, help='number of workers')
ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number')
args = ap.parse_args()

############################### preparing dataset ################################

# DataManager reads its settings from the parsed args and provides the three loaders.
data_object = DataManager(args)
trainloader, valloader, testloader = data_object.prepare_data()
dataloaders = {
    'train': trainloader, 'val': valloader, "test": testloader
}

############################### preparing model ###################################

# 'full' builds the model without pruning gates (see ModuleInjection.make_prunable).
model = get_model(args.model, 'full', data_object.num_classes, data_object.insize)

############################## preparing for training
def train(model, loss_fn, optimizer, scheduler=None):
    """Run one training epoch over ``dataloaders['train']``.

    :param model: network to optimise; switched to train mode here
    :param loss_fn: callable ``(scores, targets) -> loss``
    :param optimizer: optimizer stepped once per batch
    :param scheduler: unused
    :return: mean per-batch loss over the epoch
    """
    model.train()
    loader = dataloaders['train']
    progress = tqdm_notebook(loader, total=len(loader))
    total_loss = 0
    step = 0
    for step, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device=device)
        targets = targets.to(device=device)

        loss = loss_fn(model(inputs), targets)
        total_loss += loss.item()
        # Show the running mean loss on the progress bar.
        progress.set_postfix(loss=total_loss / step)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return total_loss / step
train_losses = [] 107 | valid_losses = [] 108 | valid_accuracy = [] 109 | if args.test_only == False: 110 | for epoch in range(num_epochs): 111 | adjust_learning_rate(optimizer, epoch, args) 112 | print('Starting epoch %d / %d' % (epoch + 1, num_epochs)) 113 | t_loss = train(model, criterion, optimizer) 114 | acc, v_loss = test(model, criterion, optimizer, "val") 115 | 116 | if acc>best_acc: 117 | print("**Saving model**") 118 | best_acc=acc 119 | torch.save({ 120 | "epoch": epoch + 1, 121 | "state_dict" : model.state_dict(), 122 | "acc" : best_acc, 123 | }, f"checkpoints/{args.model}_{args.dataset}_pretrained.pth") 124 | 125 | train_losses.append(t_loss) 126 | valid_losses.append(v_loss) 127 | valid_accuracy.append(acc) 128 | df_data=np.array([train_losses, valid_losses, valid_accuracy]).T 129 | df = pd.DataFrame(df_data, columns = ['train_losses','valid_losses','valid_accuracy']) 130 | df.to_csv(f'logs/{args.model}_{args.dataset}_pretrained.csv') 131 | 132 | state = torch.load(f"checkpoints/{args.model}_{args.dataset}_pretrained.pth") 133 | model.load_state_dict(state['state_dict'],strict=True) 134 | acc, v_loss = test(model, criterion, optimizer, "test") 135 | print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}") -------------------------------------------------------------------------------- /pruning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import numpy as np 6 | import pandas as pd 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tqdm import tqdm as tqdm_notebook 10 | 11 | from utils import * 12 | from models import get_model 13 | from datasets import DataManager 14 | 15 | seed_everything(43) 16 | 17 | ap = argparse.ArgumentParser(description='pruning with heaviside continuous approximations and logistic curves') 18 | ap.add_argument('dataset', choices=['c10', 'c100', 'tin','svhn'], type=str, help='Dataset choice') 19 | 
ap.add_argument('model', type=str, help='Model choice') 20 | ap.add_argument('--budget_type', choices=['channel_ratio', 'volume_ratio','parameter_ratio','flops_ratio'], default='channel_ratio', type=str, help='Budget Type') 21 | ap.add_argument('--Vc', default=0.25, type=float, help='Budget Constraint') 22 | ap.add_argument('--batch_size', default=32, type=int, help='Batch Size') 23 | ap.add_argument('--epochs', default=20, type=int, help='Epochs') 24 | ap.add_argument('--workers', default=0, type=int, help='Number of CPU workers') 25 | ap.add_argument('--valid_size', '-v', type=float, default=0.1, help='valid_size') 26 | ap.add_argument('--lr', default=0.001, type=float, help='Learning rate') 27 | ap.add_argument('--test_only','-t', default=False, type=bool, help='Testing') 28 | 29 | ap.add_argument('--decay', default=0.001, type=float, help='Weight decay') 30 | ap.add_argument('--w1', default=30., type=float, help='weightage to budget loss') 31 | ap.add_argument('--w2', default=10., type=float, help='weightage to crispness loss') 32 | ap.add_argument('--b_inc', default=5., type=float, help='beta increment') 33 | ap.add_argument('--g_inc', default=2., type=float, help='gamma increment') 34 | 35 | ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number') 36 | args = ap.parse_args() 37 | 38 | valid_size = args.valid_size 39 | BATCH_SIZE = args.batch_size 40 | Vc = torch.FloatTensor([args.Vc]) 41 | 42 | ############################### preparing dataset ################################ 43 | 44 | data_object = DataManager(args) 45 | trainloader, valloader, testloader = data_object.prepare_data() 46 | dataloaders = { 47 | 'train': trainloader, 'val': valloader, "test": testloader 48 | } 49 | 50 | ############################### preparing model ################################### 51 | 52 | model = get_model(args.model, 'prune', data_object.num_classes, data_object.insize) 53 | state = 
torch.load(f"checkpoints/{args.model}_{args.dataset}_pretrained.pth") 54 | model.load_state_dict(state['state_dict'], strict=False) 55 | 56 | ############################### preparing for pruning ################################### 57 | 58 | if os.path.exists('logs') == False: 59 | os.mkdir("logs") 60 | 61 | if os.path.exists('checkpoints') == False: 62 | os.mkdir("checkpoints") 63 | 64 | 65 | weightage1 = args.w1 #weightage given to budget loss 66 | weightage2 = args.w2 #weightage given to crispness loss 67 | steepness = 10. # steepness of gate_approximator 68 | 69 | CE = nn.CrossEntropyLoss() 70 | def criterion(model, y_pred, y_true): 71 | global steepness 72 | ce_loss = CE(y_pred, y_true) 73 | budget_loss = ((model.get_remaining(steepness, args.budget_type).to(device)-Vc.to(device))**2).to(device) 74 | crispness_loss = model.get_crispnessLoss(device) 75 | return budget_loss*weightage1 + crispness_loss*weightage2 + ce_loss 76 | 77 | param_optimizer = list(model.named_parameters()) 78 | no_decay = ["zeta"] 79 | optimizer_parameters = [ 80 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.decay,'lr':args.lr}, 81 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,'lr':args.lr}, 82 | ] 83 | optimizer = optim.AdamW(optimizer_parameters) 84 | 85 | device = torch.device(f"cuda:{str(args.cuda_id)}") 86 | model.to(device) 87 | Vc.to(device) 88 | 89 | def train(model, loss_fn, optimizer, epoch): 90 | global steepness 91 | model.train() 92 | counter = 0 93 | tk1 = tqdm_notebook(dataloaders['train'], total=len(dataloaders['train'])) 94 | running_loss = 0 95 | for x_var, y_var in tk1: 96 | counter +=1 97 | x_var = x_var.to(device=device) 98 | y_var = y_var.to(device=device) 99 | scores = model(x_var) 100 | loss = loss_fn(model,scores, y_var) 101 | optimizer.zero_grad() 102 | loss.backward() 103 | running_loss+=loss.item() 104 | tk1.set_postfix(loss=running_loss/counter) 
105 | optimizer.step() 106 | steepness=min(60,steepness+5./len(tk1)) 107 | return running_loss/counter 108 | 109 | def test(model, loss_fn, optimizer, phase, epoch): 110 | model.eval() 111 | counter = 0 112 | tk1 = tqdm_notebook(dataloaders[phase], total=len(dataloaders[phase])) 113 | running_loss = 0 114 | running_acc = 0 115 | total = 0 116 | with torch.no_grad(): 117 | for x_var, y_var in tk1: 118 | counter +=1 119 | x_var = x_var.to(device=device) 120 | y_var = y_var.to(device=device) 121 | scores = model(x_var) 122 | loss = loss_fn(model,scores, y_var) 123 | _, scores = torch.max(scores.data, 1) 124 | y_var = y_var.cpu().detach().numpy() 125 | scores = scores.cpu().detach().numpy() 126 | 127 | correct = (scores == y_var).sum().item() 128 | running_loss+=loss.item() 129 | running_acc+=correct 130 | total+=scores.shape[0] 131 | tk1.set_postfix(loss=(running_loss /counter), acc=(running_acc/total)) 132 | return running_acc/total 133 | 134 | best_acc = 0 135 | beta, gamma = 1., 2. 136 | model.set_beta_gamma(beta, gamma) 137 | 138 | remaining_before_pruning = [] 139 | remaining_after_pruning = [] 140 | valid_accuracy = [] 141 | pruning_accuracy = [] 142 | pruning_threshold = [] 143 | # exact_zeros = [] 144 | # exact_ones = [] 145 | problems = [] 146 | name = f'{args.model}_{args.dataset}_{str(np.round(Vc.item(),decimals=6))}_{args.budget_type}_pruned' 147 | if args.test_only == False: 148 | for epoch in range(args.epochs): 149 | print(f'Starting epoch {epoch + 1} / {args.epochs}') 150 | model.unprune() 151 | train(model, criterion, optimizer, epoch) 152 | print(f'[{epoch + 1} / {args.epochs}] Validation before pruning') 153 | acc = test(model, criterion, optimizer, "val", epoch) 154 | remaining = model.get_remaining(steepness, args.budget_type).item() 155 | remaining_before_pruning.append(remaining) 156 | valid_accuracy.append(acc) 157 | # exactly_zeros, exactly_ones = model.plot_zt() 158 | # exact_zeros.append(exactly_zeros) 159 | # exact_ones.append(exactly_ones) 
160 | 161 | print(f'[{epoch + 1} / {args.epochs}] Validation after pruning') 162 | threshold, problem = model.prune(args.Vc, args.budget_type) 163 | acc = test(model, criterion, optimizer, "val", epoch) 164 | remaining = model.get_remaining(steepness, args.budget_type).item() 165 | pruning_accuracy.append(acc) 166 | pruning_threshold.append(threshold) 167 | remaining_after_pruning.append(remaining) 168 | problems.append(problem) 169 | 170 | # 171 | beta=min(6., beta+(0.1/args.b_inc)) 172 | gamma=min(256, gamma*(2**(1./args.g_inc))) 173 | model.set_beta_gamma(beta, gamma) 174 | print("Changed beta to", beta, "changed gamma to", gamma) 175 | 176 | if acc>best_acc: 177 | print("**Saving checkpoint**") 178 | best_acc=acc 179 | torch.save({ 180 | "epoch" : epoch+1, 181 | "beta" : beta, 182 | "gamma" : gamma, 183 | "prune_threshold":threshold, 184 | "state_dict" : model.state_dict(), 185 | "accuracy" : acc, 186 | }, f"checkpoints/{name}.pth") 187 | 188 | df_data=np.array([remaining_before_pruning, remaining_after_pruning, valid_accuracy, pruning_accuracy, pruning_threshold, problems]).T 189 | df = pd.DataFrame(df_data,columns = ['Remaining before pruning', 'Remaining after pruning', 'Valid accuracy', 'Pruning accuracy', 'Pruning threshold', 'problems']) 190 | df.to_csv(f"logs/{name}.csv") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | pandas==0.24.2 3 | torch==1.4.0 4 | torchvision==0.5.0 5 | matplotlib==3.0.3 6 | tqdm==4.48.2 7 | glob2==0.6 8 | -------------------------------------------------------------------------------- /script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo 3 | mkdir checkpoints 4 | cd checkpoints 5 | wget https://transfer.sh/IhuwC/wrn_c100_pretrained.pth 6 | cd ../ 7 | 8 | # python pretraining.py $1 $2 --epochs 160 --batch_size 64 9 
| python pruning.py $1 $2 --Vc 0.0625 --budget_type 'flops_ratio' 10 | python finetuning.py $1 $2 --name $2\_$1\_0.0625\_parameter\_ratio --epochs 300 --Vc 0.0625 --budget_type flops_ratio 11 | # python pruning.py $1 $2 --Vc 0.4 12 | # python pruning.py $1 $2 --Vc 0.2 13 | # python pruning.py $1 $2 --Vc 0.1 --w1 45. 14 | 15 | # python finetuning.py $1 $2 --name $2\_$1\_0.6\_channel\_ratio --epochs 160 --Vc 0.6 --batch_size 64 16 | # python finetuning.py $1 $2 --name $2\_$1\_0.4\_channel\_ratio --epochs 160 --Vc 0.4 --batch_size 64 17 | # python finetuning.py $1 $2 --name $2\_$1\_0.2\_channel\_ratio --epochs 160 --Vc 0.2 --batch_size 64 18 | # python finetuning.py $1 $2 --name $2\_$1\_0.1\_channel\_ratio --epochs 160 --Vc 0.1 --batch_size 64 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import random 7 | import pandas as pd 8 | 9 | def seed_everything(seed): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | 17 | def get_mask_dict(own_state, state_dict): 18 | for name, param in state_dict.items(): 19 | if name not in own_state: 20 | continue 21 | if 'zeta' not in name and 'beta' not in name and 'gamma' not in name: 22 | continue 23 | if isinstance(param, nn.Parameter): 24 | # backwards compatibility for serialized parameters 25 | param = param.data 26 | own_state[name].copy_(param) 27 | return own_state 28 | 29 | def adjust_learning_rate(optimizer, epoch, args): 30 | """Sets the learning rate to the initial LR decayed by 2 every 30 epochs""" 31 | if args.scheduler_type==1: 32 | lr = args.lr * (0.5 ** (epoch // 30)) 33 | for param_group in optimizer.param_groups: 34 | 
param_group['lr'] = lr 35 | else: 36 | if epoch in [args.epochs*0.5, args.epochs*0.75]: 37 | for param_group in optimizer.param_groups: 38 | param_group['lr'] *= 0.1 39 | 40 | def plot_learning_curves(logger_name): 41 | train_loss = [] 42 | val_loss = [] 43 | val_acc = [] 44 | df = pd.read_csv('logs/'+logger_name) 45 | 46 | train_loss = df.iloc[1:,1] 47 | val_loss = df.iloc[1:,2] 48 | val_acc = df.iloc[1:,3]*100 49 | 50 | plt.style.use('seaborn') 51 | plt.plot(np.arange(len(train_loss)), train_loss, label = 'Training error') 52 | plt.plot(np.arange(len(train_loss)), val_loss, label = 'Validation error') 53 | plt.ylabel('Loss', fontsize = 14) 54 | plt.xlabel('Epochs', fontsize = 14) 55 | plt.title('Loss Curve', fontsize = 18, y = 1.03) 56 | plt.legend() 57 | plt.ylim(0,4) 58 | plt.show() 59 | print() 60 | 61 | plt.style.use('seaborn') 62 | plt.plot(np.arange(len(train_loss)), val_acc, label = 'Validation Accuracy') 63 | plt.ylabel('Accuracy', fontsize = 14) 64 | plt.xlabel('Epochs', fontsize = 14) 65 | plt.title('Accuracy curve', fontsize = 18, y = 1.03) 66 | plt.legend() 67 | plt.ylim(0,100) 68 | plt.show() 69 | print() 70 | 71 | 72 | 73 | def visualize_model_architecture(model, budget, budget_type): 74 | pruned_model = [3,] 75 | full_model = [3,] 76 | device = torch.device('cpu') 77 | model.to(device) 78 | model(torch.rand(1,3,32,32)) 79 | model.prepare_for_finetuning(device=device,budget=budget,budget_type=budget_type) 80 | for l_block in model.prunable_modules: 81 | gates = l_block.pruned_zeta.cpu().detach().numpy().tolist() 82 | full_model.append(len(gates)) 83 | pruned_model.append(np.sum(gates)) 84 | fig = plt.figure() 85 | ax = fig.add_axes([0,0,1,1]) 86 | full_model = np.array(full_model) 87 | pruned_model = np.array(pruned_model) 88 | ax.bar(np.arange(len(full_model)), full_model, width = 0.5, color = 'b') 89 | ax.bar(np.arange(len(pruned_model)), pruned_model, width = 0.5, color = 'r') 90 | print(full_model) 91 | print(pruned_model) 92 | plt.show() 93 | 
if hasattr(model, 'calc_params') and budget_type!='parameter_ratio': 94 | total_params = model.calc_params(full_model) 95 | active_params = model.calc_params(pruned_model) 96 | else: 97 | active_params, total_params = model.get_params_count() 98 | if hasattr(model, 'calc_flops') and budget_type!='flops_ratio': 99 | total_flops = model.calc_flops(full_model) 100 | active_flops = model.calc_flops(pruned_model) 101 | else: 102 | active_flops, total_flops = model.get_flops() 103 | active_volume, total_volume = model.get_volume() 104 | active_channels, total_channels = model.get_channels() 105 | 106 | print(f'\nTotal parameter count: {total_params}') 107 | print(f'Remaining parameter count: {active_params}') 108 | print(f'Remaining Parameter Fraction: {active_params/total_params}') 109 | 110 | print(f'\nTotal volume count: {total_volume}') 111 | print(f'Remaining volume count: {active_volume}') 112 | print(f'Remaining volume Fraction: {active_volume/total_volume}') 113 | 114 | print(f'\nTotal flops count: {total_flops}') 115 | print(f'Remaining flops count: {active_flops}') 116 | print(f'Remaining flops Fraction: {active_flops/total_flops}') 117 | 118 | print(f'\nTotal channels count: {total_channels}') 119 | print(f'Remaining channels count: {active_channels}') 120 | print(f'Remaining channels Fraction: {active_channels/total_channels}') 121 | 122 | return [full_model, pruned_model] 123 | 124 | --------------------------------------------------------------------------------