├── sensAI-logo.png ├── utils ├── images │ ├── cifar.png │ └── imagenet.png ├── __init__.py ├── eval.py ├── misc.py ├── visualize.py └── logger.py ├── models ├── imagenet │ ├── __init__.py │ └── resnext.py └── cifar │ ├── __init__.py │ ├── mobilenetv2.py │ ├── wrn.py │ ├── vgg.py │ ├── densenet.py │ ├── resnet.py │ ├── resnext.py │ └── shufflenetv2.py ├── requirements.txt ├── scripts ├── activations_grouped_5_5_vgg19.sh ├── train_pruned_grouped.sh ├── activations_grouped_vgg19.sh ├── activations_grouped_vgg19_cifar100.sh ├── activations_grouped_resnet110_cifar100.sh ├── activations_grouped_resnet164_cifar100.sh ├── train_pruned_grouped.py └── training_scheduler.py ├── load_model.py ├── datasets ├── utils.py └── cifar.py ├── .gitignore ├── apoz_policy_imagenet.py ├── retrain_grouped_model.py ├── get_prune_candidates.py ├── compute_flops.py ├── README.md ├── logger.py ├── apoz_policy.py ├── group_selection.py ├── imagenet_evaluate_grouped.py ├── imagenet_dataset.py ├── prune_utils ├── layer_prune.py └── prune.py ├── even_k_means.py ├── regularize_model.py ├── LICENSE.md ├── prune_and_get_model.py ├── imagenet_activations.py └── evaluate.py /sensAI-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/sensAI-logo.png -------------------------------------------------------------------------------- /utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/utils/images/cifar.png -------------------------------------------------------------------------------- /models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnext import * 4 | -------------------------------------------------------------------------------- /utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/utils/images/imagenet.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | torch==1.5.0 4 | torchvision==0.6.0 5 | tqdm==4.46.1 6 | scikit-learn==0.21.3 7 | -------------------------------------------------------------------------------- /models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | from .resnet import * 5 | from .mobilenetv2 import * 6 | from .shufflenetv2 import * 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) -------------------------------------------------------------------------------- /scripts/activations_grouped_5_5_vgg19.sh: -------------------------------------------------------------------------------- 1 | MODEL=./checkpoint_bearclaw.pth.tar 2 | rm -r prune_candidate_logs 3 | mkdir prune_candidate_logs 4 | 5 | python3 
get_prune_candidates.py -a vgg19_bn --resume $MODEL --evaluate --grouped 1 3 5 7 9 6 | python3 get_prune_candidates.py -a vgg19_bn --resume $MODEL --evaluate --grouped 2 4 6 8 0 7 | -------------------------------------------------------------------------------- /scripts/train_pruned_grouped.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | EPOCHS=2 3 | FROM=./pruned_models 4 | SAVE=${FROM}_retrained 5 | mkdir ${SAVE} 6 | rm ${SAVE}/* -r 7 | mkdir $SAVE/resnet164/ 8 | mkdir $SAVE/logs/ 9 | i=0 10 | group_idx=0 11 | for file in ${FROM}/resnet164/* 12 | do 13 | CUDA_VISIBLE_DEVICES=$i python3 cifar_group.py -a resnet164 --epochs ${EPOCHS} --pruned --schedule 40 60 --gamma 0.1 --resume $file --checkpoint $SAVE/ --train-batch 256 --dataset cifar100 > ${SAVE}/logs/log${group_idx}.txt & 14 | group_idx=$((group_idx+1)) 15 | i=$((i+1)) 16 | i=$(( $i % 4 )) 17 | if [ $i -eq 0 ] ; then 18 | wait 19 | fi 20 | done 21 | 22 | -------------------------------------------------------------------------------- /scripts/activations_grouped_vgg19.sh: -------------------------------------------------------------------------------- 1 | MODEL=./vgg19bn-cifar100.pth.tar 2 | rm -r prune_candidate_logs 3 | mkdir prune_candidate_logs 4 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 1 2 13 32 46 51 62 77 91 93 5 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 20 23 24 29 30 58 69 72 73 95 6 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 33 47 49 52 56 59 66 67 76 96 7 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 5 11 31 37 38 39 64 75 84 97 8 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 16 21 28 41 48 81 86 87 94 99 9 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import pdb 3 | 4 | __all__ = ['accuracy', 'accuracy_binary'] 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the precision@k for the specified values of k""" 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | res = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(0) 18 | res.append(correct_k.mul_(100.0 / batch_size)) 19 | return res 20 | 21 | def accuracy_binary(output, target): 22 | pred = output >= 0.0 23 | pred = pred.flatten().long() 24 | acc = pred.eq(target).sum().float() / target.numel() 25 | return acc.data 26 | -------------------------------------------------------------------------------- /scripts/activations_grouped_vgg19_cifar100.sh: -------------------------------------------------------------------------------- 1 | MODEL=./vgg19bn-cifar100.pth.tar 2 | rm -r prune_candidate_logs 3 | mkdir prune_candidate_logs 4 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 1 2 13 32 46 51 62 77 91 93 5 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 20 23 24 29 30 58 69 72 73 95 6 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 33 47 49 52 56 59 66 67 76 96 7 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 
5 11 31 37 38 39 64 75 84 97 8 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 16 21 28 41 48 81 86 87 94 99 9 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 12 15 17 25 60 68 71 85 89 90 10 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 3 6 19 34 35 36 43 65 80 88 11 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 0 9 14 54 57 63 82 83 92 98 12 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 8 10 22 26 40 50 53 61 70 79 13 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 4 7 18 27 42 44 45 55 74 78 14 | -------------------------------------------------------------------------------- /scripts/activations_grouped_resnet110_cifar100.sh: -------------------------------------------------------------------------------- 1 | MODEL=/home/ubuntu/baseModel/pytorch-classification/checkpoints/cifar100/resnet-110/model_best.pth.tar 2 | rm -r prune_candidate_logs 3 | mkdir prune_candidate_logs 4 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 0 10 53 54 57 62 70 82 83 92 5 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 23 30 32 49 61 67 71 73 91 95 6 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 5 16 20 25 28 40 84 86 87 94 7 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 15 19 34 38 42 43 66 75 88 97 8 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 9 11 12 17 37 39 68 69 76 98 9 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 18 26 27 29 44 45 78 79 93 99 10 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 8 13 41 46 48 58 81 85 89 90 11 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 14 22 33 47 51 52 56 59 60 96 12 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 3 4 21 31 55 63 64 72 74 80 13 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 1 2 6 7 24 35 36 50 65 77 14 | -------------------------------------------------------------------------------- /scripts/activations_grouped_resnet164_cifar100.sh: -------------------------------------------------------------------------------- 1 | MODEL=/home/ubuntu/baseModel/pytorch-classification/checkpoints/cifar100/resnet-164/model_best.pth.tar 2 | rm -r prune_candidate_logs 3 | mkdir prune_candidate_logs 4 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 0 10 53 54 57 61 62 70 83 92 5 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 28 30 39 67 69 71 73 91 95 99 6 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 5 9 16 20 22 25 84 86 87 94 7 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 34 35 36 38 50 65 66 88 97 98 8 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 6 7 14 15 19 24 40 51 75 79 9 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 23 33 47 49 52 56 59 60 82 96 10 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 18 26 27 29 42 44 74 77 78 93 11 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL 
--grouped 2 8 11 41 45 46 48 58 85 89 12 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 3 4 21 31 43 55 63 64 72 80 13 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 1 12 13 17 32 37 68 76 81 90 14 | -------------------------------------------------------------------------------- /scripts/train_pruned_grouped.py: -------------------------------------------------------------------------------- 1 | from training_scheduler import train 2 | import os 3 | import shutil 4 | ''' 5 | #!/bin/sh 6 | EPOCHS=80 7 | FROM=./pruned_models 8 | SAVE=${FROM}_retrained 9 | mkdir ${SAVE} 10 | rm ${SAVE}/* -r 11 | mkdir $SAVE/vgg19_bn/ 12 | mkdir $SAVE/logs/ 13 | i=0 14 | group_idx=0 15 | for file in ${FROM}/vgg19_bn/* 16 | do 17 | CUDA_VISIBLE_DEVICES=$i python3 cifar_group.py -a vgg19_bn --epochs ${EPOCHS} --pruned --schedule 40 60 --gamma 0.1 --resume $file --checkpoint $SAVE/ --train-batch 256 --dataset cifar100 > ${SAVE}/logs/log${group_idx}.txt & 18 | group_idx=$((group_idx+1)) 19 | i=$((i+1)) 20 | i=$(( $i % 4 )) 21 | if [ $i -eq 0 ] ; then 22 | wait 23 | fi 24 | done 25 | ''' 26 | 27 | num_epochs = 80 28 | model_dir = "./pruned_models" 29 | save_dir = model_dir + "_retrained" 30 | if os.path.isdir(save_dir): 31 | shutil.rmtree(save_dir) 32 | os.mkdir(save_dir) 33 | os.mkdir(save_dir+"/vgg19_bn/") 34 | os.mkdir(save_dir+"/logs/") 35 | 36 | i = 0 37 | group_idx = 0 38 | commands = [] 39 | for file in os.listdir(model_dir+"/vgg19_bn/"): 40 | command = "python3 cifar_group.py -a vgg19_bn --epochs " + str(num_epochs) + " --pruned --schedule 40 60 --gamma 0.1 --resume " + model_dir + "/vgg19_bn/" + file + " --checkpoint " + save_dir + "/ --train-batch 256 --dataset cifar100 > " + save_dir + "/logs/log" + str(group_idx) + ".txt" 41 | group_idx += 1 42 | i = (i + 1) % 4 43 | commands.append(command) 44 | print(commands) 45 | # train(executables=commands) -------------------------------------------------------------------------------- /load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import models.cifar as cifar_models 5 | 6 | 7 | def model_arches(dataset): 8 | if dataset == 'cifar': 9 | return sorted(name for name in cifar_models.__dict__ 10 | if name.islower() and not name.startswith("__") 11 | and callable(cifar_models.__dict__[name])) 12 | else: 13 | raise NotImplementedError 14 | 15 | 16 | 17 | def load_pretrain_model(arch, dataset, resume_checkpoint, num_classes, use_cuda): 18 | print('==> Resuming from checkpoint..') 19 | assert os.path.isfile(resume_checkpoint), 'Error: no checkpoint found!' 
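# Load the checkpoint on the GPU when CUDA is available; otherwise map the stored tensors to the CPU.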
20 | if use_cuda: 21 | checkpoint = torch.load(resume_checkpoint) 22 | else: 23 | checkpoint = torch.load( 24 | resume_checkpoint, map_location=torch.device('cpu')) 25 | if dataset.startswith('cifar'): 26 | model = cifar_models.__dict__[arch](num_classes=num_classes) 27 | else: 28 | raise NotImplementedError(f"Unsupported dataset: {dataset}.") 29 | 30 | if use_cuda: 31 | model.cuda() 32 | state_dict = {} 33 | # deal with old torch version 34 | if arch != 'mobilenetv2' and arch != 'shufflenetv2': 35 | for k, v in checkpoint['state_dict'].items(): 36 | state_dict[k.replace('module.', '')] = v 37 | model.load_state_dict(state_dict) 38 | else: 39 | for k, v in checkpoint['net'].items(): 40 | state_dict[k.replace('module.', '')] = v 41 | model.load_state_dict(state_dict) 42 | return model 43 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | import math 4 | 5 | 6 | class DataSetWrapper(object): 7 | def __init__(self, dataset, class_group: Tuple[int], negative_samples=False): 8 | # The original dataset has been shuffled. Skip shuffling this dataset 9 | # for consistency. 10 | self.dataset = dataset 11 | self.class_group = class_group 12 | self.negative_samples = negative_samples 13 | self.targets = np.asarray(self.dataset.targets) 14 | # This is the bool mask for all classes in the given group. 15 | positive_mask = np.zeros_like(self.targets, dtype=bool) 16 | for class_index in class_group: 17 | positive_mask |= (self.targets == class_index) 18 | positive_class_indices = np.where(positive_mask)[0] 19 | if negative_samples: 20 | # For N negative samples, P positive samples, we need to append 21 | # (k * N - P) positive samples. 
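# After appending, the positive samples total k * N (roughly N per in-group class),
# which balances them against the N negative samples across the k + 1 output classes.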
22 | k = len(class_group) 23 | P = len(positive_class_indices) 24 | N = len(self.targets) - P 25 | assert N >= P, "there are already more positive classes" 26 | ext_P = k * N - P 27 | repeat_n = math.ceil(ext_P / P) 28 | extented_indices = np.repeat( 29 | positive_class_indices, repeat_n)[:ext_P] 30 | # fuse and shuffle 31 | all_indices = np.arange(len(self.targets)) 32 | fullset = np.concatenate([all_indices, extented_indices]) 33 | np.random.shuffle(fullset) 34 | self.mapping = fullset 35 | else: 36 | self.mapping = positive_class_indices 37 | 38 | def __getitem__(self, i): 39 | index = self.mapping[i] 40 | data, label = self.dataset[index] 41 | if label in self.class_group: 42 | label = list(self.class_group).index(label) + 1 43 | else: 44 | label = 0 45 | return data, label 46 | 47 | def __len__(self): 48 | return len(self.mapping) 49 | 50 | @property 51 | def num_classes(self): 52 | return len(self.class_group) + 1 53 | -------------------------------------------------------------------------------- /datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets 3 | from torchvision import transforms 4 | from typing import List, Tuple 5 | 6 | from datasets import utils 7 | 8 | 9 | # Transformations 10 | RC = transforms.RandomCrop(32, padding=4) 11 | RHF = transforms.RandomHorizontalFlip() 12 | RVF = transforms.RandomVerticalFlip() 13 | NRM = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 14 | TT = transforms.ToTensor() 15 | TPIL = transforms.ToPILImage() 16 | 17 | # Transforms object for trainset with augmentation 18 | transform_with_aug = transforms.Compose([RC, RHF, TT, NRM]) 19 | # Transforms object for testset with NO augmentation 20 | transform_no_aug = transforms.Compose([TT, NRM]) 21 | 22 | 23 | DATASET_ROOT = './data/' 24 | 25 | 26 | class CIFAR10TrainingSetWrapper(utils.DataSetWrapper): 27 | def __init__(self, class_group: Tuple[int], negative_samples=False): 28 | dataset = datasets.CIFAR10(root=DATASET_ROOT, train=True, 29 | download=True, transform=transform_with_aug) 30 | super().__init__(dataset, class_group, negative_samples) 31 | 32 | 33 | class CIFAR10TestingSetWrapper(utils.DataSetWrapper): 34 | def __init__(self, class_group: Tuple[int], negative_samples=False): 35 | dataset = datasets.CIFAR10(root=DATASET_ROOT, train=False, 36 | download=True, transform=transform_no_aug) 37 | super().__init__(dataset, class_group, negative_samples) 38 | 39 | 40 | class CIFAR100TrainingSetWrapper(utils.DataSetWrapper): 41 | def __init__(self, class_group: Tuple[int], negative_samples=False): 42 | dataset = datasets.CIFAR100(root=DATASET_ROOT, train=True, 43 | download=True, transform=transform_with_aug) 44 | super().__init__(dataset, class_group, negative_samples) 45 | 46 | 47 | class CIFAR100TestingSetWrapper(utils.DataSetWrapper): 48 | def __init__(self, class_group: Tuple[int], negative_samples=False): 49 | dataset = datasets.CIFAR100(root=DATASET_ROOT, train=False, 50 | download=True, transform=transform_no_aug) 51 | super().__init__(dataset, class_group, negative_samples) 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | limo/vgg-pruning/pruned_models_* 2 | pytorch-classification/data/ 3 | limo/vgg-pruning/pruned_models/ 4 | pytorch-classification/checkpoints/ 5 | limo/vgg-pruning/pruned_models_20_epochs/ 6 | 7 | # Byte-compiled / 
optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .dmypy.json 119 | dmypy.json 120 | 121 | # Pyre type checker 122 | .pyre/ 123 | 124 | *.pth.tar 125 | pruned_models/ 126 | pruned_models_retrained/ 127 | prune_candidate_logs/ 128 | data/ 129 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | 16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 17 | 18 | 19 | def get_mean_and_std(dataset): 20 | '''Compute the mean and std value of dataset.''' 21 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 22 | 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | def mkdir_p(path): 50 | '''make dir if not exist''' 51 | try: 52 | os.makedirs(path) 53 | except OSError as exc: # Python >2.5 54 | if exc.errno == errno.EEXIST and os.path.isdir(path): 55 | pass 56 | else: 57 | raise 58 | 59 | class AverageMeter(object): 60 | """Computes and stores the average and current value 61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 62 | """ 63 | def __init__(self): 64 | self.reset() 65 | 66 | def reset(self): 67 | self.val = 0 68 | self.avg = 0 69 | self.sum = 0 70 | self.count = 0 71 | 72 | def update(self, val, n=1): 73 | self.val = val 74 | self.sum += val * n 75 | self.count += n 76 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /apoz_policy_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import io 4 | 5 | """ 6 | Calculate the Average Percentage of Zeros Score of the feature map activation layer output 7 | """ 8 | def apoz_scoring(activation): 9 | if activation.dim() == 4: 10 | view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channels) x (h*w) 11 | featuremap_apoz = view_2d.abs().gt(0.005).sum(dim=1).float() / (activation.size(2) * activation.size(3)) # (batch*channels) x 1 12 | featuremap_apoz_mat = featuremap_apoz.view(activation.size(0), activation.size(1)) # batch x channels 13 | elif activation.dim() == 2 and activation.shape[1] == 1: 14 | featuremap_apoz_mat = activation.abs().gt(0.005).sum(dim=1).float() / activation.size(1) 15 | elif activation.dim() == 2: # FC Case: (batch x channels) 16 | featuremap_apoz_mat = activation.abs().gt(0.005).sum(dim=0).float() 17 | return 100 - featuremap_apoz_mat.mul(100) 18 | else: 19 | raise ValueError("activation_channels_apoz: Unsupported shape: ".format(activation.shape)) 20 | return 100 - featuremap_apoz_mat.mean(dim=0).mul(100) 21 | 22 | 23 | def avg_scoring(activation): 24 | if activation.dim() == 4: 25 | view_2d = activation.view(-1, activation.size(2) * activation.size(3)) 26 | featuremap_avg = view_2d.abs().sum(dim = 1).float() / (activation.size(2) * activation.size(3)) 27 | featuremap_avg_mat = 
featuremap_avg.view(activation.size(0), activation.size(1)) 28 | elif activation.dim() == 2 and activation.shape[1] == 1: 29 | featuremap_avg_mat = activation.abs().sum(dim = 1).float() / activation.size(1) 30 | elif activation.dim() == 2: 31 | featuremap_avg_mat = activation.abs().float() 32 | else: 33 | raise ValueError("activation_channels_avg: Unsupported shape: ".format(activation.shape)) 34 | return featuremap_avg_mat.mean(dim = 0) 35 | 36 | def pruning_candidates(group_id, thresholds, file_name): 37 | layers_channels = [] 38 | fmap_file = open(file_name, "rb") 39 | data_buffer = io.BytesIO(fmap_file.read()) 40 | for _ in range(16): 41 | layers_channels.append(torch.load(data_buffer)) 42 | 43 | candidates_by_layer = [] 44 | print("Calculating pruning candidates for classe(s) {}".format(group_id)) 45 | for index, layer in enumerate(layers_channels): 46 | apoz_score = apoz_scoring(layer) 47 | print(apoz_score.mean()) 48 | 49 | curr_threshold = thresholds[index] 50 | while True: 51 | num_candidates = apoz_score.gt(curr_threshold).sum() 52 | print("Greater than {} %".format(curr_threshold), num_candidates) 53 | if num_candidates < apoz_score.size()[0]: 54 | candidates = [x[0] for x in apoz_score.gt(curr_threshold).nonzero().tolist()] 55 | break 56 | curr_threshold += 5 57 | 58 | print("Class Index: {}, Layer {}, Number of neurons with apoz > {}%: {}/{}".format(group_id, index, curr_threshold, len(candidates), apoz_score.size()[0])) 59 | candidates_by_layer.append(candidates) 60 | print("Zero channels out of total in layer {}: {}/{}".format(index, len(candidates) ,len(layer))) 61 | return candidates_by_layer 62 | -------------------------------------------------------------------------------- /models/cifar/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Block(nn.Module): 7 | '''expand + depthwise + pointwise''' 8 | def __init__(self, in_planes, out_planes, expansion, stride): 9 | super(Block, self).__init__() 10 | self.stride = stride 11 | 12 | planes = expansion * in_planes 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride == 1 and in_planes != out_planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | out = out + self.shortcut(x) if self.stride==1 else out 32 | return out 33 | 34 | 35 | class MobileNetV2(nn.Module): 36 | # (expansion, out_planes, num_blocks, stride) 37 | cfg = [(1, 16, 1, 1), 38 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 39 | (6, 32, 3, 2), 40 | (6, 64, 4, 2), 41 | (6, 96, 3, 1), 42 | (6, 160, 3, 2), 43 | (6, 320, 1, 1)] 44 | 45 | def __init__(self, num_classes=10): 46 | super(MobileNetV2, self).__init__() 47 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 48 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, 
padding=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(32) 50 | self.layers = self._make_layers(in_planes=32) 51 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 52 | self.bn2 = nn.BatchNorm2d(1280) 53 | self.linear = nn.Linear(1280, num_classes) 54 | 55 | def _make_layers(self, in_planes): 56 | layers = [] 57 | for expansion, out_planes, num_blocks, stride in self.cfg: 58 | strides = [stride] + [1]*(num_blocks-1) 59 | for stride in strides: 60 | layers.append(Block(in_planes, out_planes, expansion, stride)) 61 | in_planes = out_planes 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x, features_only=False): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = self.layers(out) 67 | out = F.relu(self.bn2(self.conv2(out))) 68 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 69 | out = F.avg_pool2d(out, 4) 70 | out = out.view(out.size(0), -1) 71 | if not features_only: 72 | out = self.linear(out) 73 | return out 74 | 75 | def mobilenetv2(**kwargs): 76 | model = MobileNetV2(num_classes=10) 77 | return model -------------------------------------------------------------------------------- /retrain_grouped_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import subprocess as sp 5 | import numpy as np 6 | parser = argparse.ArgumentParser(description='retrain pruned model') 7 | parser.add_argument('-d', '--dataset', required=True, type=str) 8 | parser.add_argument('--epochs', required=True, type=int) 9 | parser.add_argument('-a', '--arch', default='vgg19_bn', 10 | type=str, help='The architecture of the trained model') 11 | parser.add_argument('-r', '--resume', default='', type=str, 12 | help='The path to the checkpoints') ### pruned models are saved here 13 | parser.add_argument('--num_gpus', default=4, type=int) 14 | parser.add_argument('--train_batch', default=256, type=int) 15 | parser.add_argument('--data', default='/home/ubuntu/imagenet', required=False, type=str, 16 | help='location of the imagenet dataset that includes train/val') 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | def main(): 22 | save = args.resume[:-1] +'_retrained/' 23 | groups = np.load(open(args.resume + "grouping_config.npy", "rb")) 24 | resultExist = os.path.exists(save) 25 | if resultExist: 26 | rm_cmd = 'rm -rf ' + save 27 | sp.Popen(rm_cmd, shell=True) 28 | os.mkdir(save) 29 | np.save(open(os.path.join(save[:-1], "grouping_config.npy"), "wb"), groups) 30 | save += args.arch 31 | os.mkdir(save) 32 | files = [f for f in glob.glob(args.resume + args.arch+"/*.pth", recursive=False)] 33 | process_list = [None for _ in range(args.num_gpus)] 34 | if args.dataset in ['cifar10', 'cifar100']: 35 | for i, file in enumerate(files): 36 | if process_list[i % args.num_gpus]: 37 | process_list[i % args.num_gpus].wait() 38 | exec_cmd = 'python3 cifar_group.py' +\ 39 | ' --arch %s' % args.arch +\ 40 | ' --resume %s' % file +\ 41 | ' --schedule 40 60' +\ 42 | ' --gamma 0.1' +\ 43 | ' --epochs %d' % args.epochs +\ 44 | ' --checkpoint %s' % save +\ 45 | ' --train-batch %d' % args.train_batch +\ 46 | ' --dataset %s' % args.dataset +\ 47 | ' --grouping_dir %s' % args.resume +\ 48 | ' --pruned' +\ 49 | ' --gpu_id %d' % (i % args.num_gpus) 50 | process_list[i % args.num_gpus] = sp.Popen(exec_cmd, shell=True) 51 | elif args.dataset in 'imagenet': 52 | for i, file in enumerate(files): 53 | if process_list[i % args.num_gpus]: 54 | process_list[i % args.num_gpus].wait() 55 | exec_cmd = 
'python3 imagenet_official_retrain.py' +\ 56 | ' --data %s' % args.data +\ 57 | ' --arch %s' % args.arch +\ 58 | ' --resume %s' % file +\ 59 | ' --schedule 10 15' +\ 60 | ' --config %s' % args.resume + '/grouping_config.npy' +\ 61 | ' --gamma 0.1 ' +\ 62 | ' --batch_size %d' % args.train_batch +\ 63 | ' --epochs %d' % args.epochs +\ 64 | ' --checkpoint %s' % save +\ 65 | ' --gpu %s' % (i % args.num_gpus) 66 | process_list[i % args.num_gpus] = sp.Popen(exec_cmd, shell=True) 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /get_prune_candidates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import torch.backends.cudnn as cudnn 7 | 8 | from apoz_policy import ActivationRecord 9 | from datasets import cifar 10 | import load_model 11 | from tqdm import tqdm 12 | import os 13 | from regularize_model import standard 14 | 15 | 16 | parser = argparse.ArgumentParser( 17 | description='PyTorch CIFAR10/100 Generate Class Specific Information') 18 | # Datasets 19 | parser.add_argument('-d', '--dataset', required=True, type=str) 20 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 21 | help='number of data loading workers (default: 4)') 22 | parser.add_argument('--resume', required=True, default='', type=str, metavar='PATH', 23 | help='path to latest checkpoint (default: none)') 24 | # Architecture 25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20', 26 | choices=load_model.model_arches('cifar'), 27 | help='model architecture: ' + 28 | ' | '.join(load_model.model_arches('cifar')) + 29 | ' (default: resnet18)') 30 | # Miscs 31 | parser.add_argument('--seed', type=int, default=42, help='manual seed') 32 | parser.add_argument('--grouped', required=True, type=int, nargs='+', default=[], 33 | help='Generate activations based on the these class indices') 34 | parser.add_argument('--group_number', required=True, type=int, 35 | help='Group number') 36 | parser.add_argument('--gpu_num', default='0', type=str, 37 | help='GPU number') 38 | 39 | 40 | args = parser.parse_args() 41 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num 42 | use_cuda = torch.cuda.is_available() 43 | 44 | # Random seed 45 | torch.manual_seed(args.seed) 46 | if use_cuda: 47 | torch.cuda.manual_seed_all(args.seed) 48 | 49 | assert args.grouped 50 | 51 | 52 | def main(): 53 | if args.dataset == 'cifar10': 54 | dataset = cifar.CIFAR10TrainingSetWrapper(args.grouped, False) 55 | num_classes = 10 56 | elif args.dataset == 'cifar100': 57 | dataset = cifar.CIFAR100TrainingSetWrapper(args.grouped, False) 58 | num_classes = 100 59 | else: 60 | raise NotImplementedError( 61 | f"There's no support for '{args.dataset}' dataset.") 62 | 63 | pruning_loader = torch.utils.data.DataLoader( 64 | dataset, 65 | batch_size=1000, 66 | num_workers=args.workers, 67 | pin_memory=False) 68 | 69 | model = load_model.load_pretrain_model( 70 | args.arch, 'cifar', args.resume, num_classes, use_cuda) 71 | 72 | if args.arch in ["mobilenetv2", "shufflenetv2"]: 73 | model = standard(model, args.arch, num_classes) 74 | 75 | if use_cuda: 76 | model.cuda() 77 | print('\nMake a test run to generate activations. 
\n Using training set.\n') 78 | with ActivationRecord(model, args.arch) as recorder: 79 | # collect pruning data 80 | #bar = tqdm(total=len(pruning_loader)) 81 | for batch_idx, (inputs, _) in enumerate(pruning_loader): 82 | #bar.update(1) 83 | if use_cuda: 84 | inputs = inputs.cuda() 85 | recorder.record_batch(inputs) 86 | candidates_by_layer = recorder.generate_pruned_candidates() 87 | 88 | with open(f"prune_candidate_logs/group_{args.group_number}_apoz_layer_thresholds.npy", "wb") as f: 89 | pickle.dump(candidates_by_layer, f) 90 | print(candidates_by_layer) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = 
im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | -------------------------------------------------------------------------------- /models/cifar/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['wrn'] 7 | 8 | class BasicBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 10 | super(BasicBlock, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.relu1 = nn.ReLU(inplace=True) 13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(out_planes) 16 | self.relu2 = nn.ReLU(inplace=True) 17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 18 | padding=1, bias=False) 19 | self.droprate = dropRate 20 | self.equalInOut = (in_planes == out_planes) 21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 22 | padding=0, bias=False) or None 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 39 | layers = [] 40 | for i in range(nb_layers): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | class WideResNet(nn.Module): 47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 48 | super(WideResNet, self).__init__() 49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 51 | n = (depth - 4) // 6 52 | block = BasicBlock 53 | # 1st conv before any network block 54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 55 | padding=1, bias=False) 56 | # 
1st block 57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 58 | # 2nd block 59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 60 | # 3rd block 61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 62 | # global average pooling and classifier 63 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.fc = nn.Linear(nChannels[3], num_classes) 66 | self.nChannels = nChannels[3] 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 71 | m.weight.data.normal_(0, math.sqrt(2. / n)) 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.weight.data.fill_(1) 74 | m.bias.data.zero_() 75 | elif isinstance(m, nn.Linear): 76 | m.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.block1(out) 81 | out = self.block2(out) 82 | out = self.block3(out) 83 | out = self.relu(self.bn1(out)) 84 | out = F.avg_pool2d(out, 8) 85 | out = out.view(-1, self.nChannels) 86 | return self.fc(out) 87 | 88 | def wrn(**kwargs): 89 | """ 90 | Constructs a Wide Residual Networks. 91 | """ 92 | model = WideResNet(**kwargs) 93 | return model 94 | -------------------------------------------------------------------------------- /scripts/training_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import threading 3 | import subprocess 4 | import multiprocessing as mp 5 | import os 6 | 7 | pruned_model_path="./pruned_models/vgg19_bn/" 8 | retrained_model_path="./retrained_model/vgg19_bn/" 9 | ''' 10 | 1) initialize bounded producer/consumer queue of size max(num_devices (param), output from torch.cuda.device_count()) 11 | ''' 12 | def train(executables, allowable_devices=range(torch.cuda.device_count())): 13 | free_devices = mp.Queue(maxsize=len(allowable_devices)) 14 | for i in allowable_devices: 15 | free_devices.put(i) 16 | for executable in executables: 17 | assigned_device = free_devices.get() 18 | print("script: '" + str(executable) + "' assigned to GPU: " + str(assigned_device)) 19 | mp.Process(target=execute_on_device, args=(assigned_device, executable, free_devices)).start() 20 | 21 | def execute_on_device(GPU_ID, executable, free_devices): 22 | # train the model 23 | os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID) 24 | executable_tokens = executable.split(" ") 25 | stdout_file = None 26 | if ">" in executable_tokens: 27 | idx = executable_tokens.index(">") 28 | stdout_file = open(executable_tokens[idx+1], "w") 29 | executable_tokens = executable_tokens[:idx] 30 | print(stdout_file) 31 | subprocess.run(executable_tokens, stdout=stdout_file) 32 | # mark this GPU as free 33 | free_devices.put(GPU_ID) 34 | if stdout_file is not None: 35 | stdout_file.close() 36 | 37 | def get_stdout(executable_tokens): 38 | if '>' in executable_tokens: 39 | return executable 40 | else: 41 | return None 42 | 43 | if __name__ == '__main__': 44 | to_train = [ 45 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_0_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_0_pruned_model --train-batch 64 --class-index 0", 46 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_1_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_1_pruned_model --train-batch 64 
--class-index 1", 47 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_2_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_2_pruned_model --train-batch 64 --class-index 2", 48 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_3_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_3_pruned_model --train-batch 64 --class-index 3", 49 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_4_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_4_pruned_model --train-batch 64 --class-index 4", 50 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_5_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_5_pruned_model --train-batch 64 --class-index 5", 51 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_6_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_6_pruned_model --train-batch 64 --class-index 6", 52 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_7_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_7_pruned_model --train-batch 64 --class-index 7", 53 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_8_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_8_pruned_model --train-batch 64 --class-index 8", 54 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_9_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_9_pruned_model --train-batch 64 --class-index 9", 55 | ] 56 | train(to_train) 57 | -------------------------------------------------------------------------------- /compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 
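# Example usage (a minimal sketch, assuming this is run from the repository root
# with the CIFAR model definitions in models/cifar available):
#
#   import models.cifar as cifar_models
#   from compute_flops import print_model_param_nums, print_model_param_flops
#
#   model = cifar_models.vgg19_bn(num_classes=10)
#   print_model_param_nums(model)                 # parameter count, in millions
#   print_model_param_flops(model, input_res=32)  # FLOPs for a 32x32 CIFAR input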
2 | import numpy as np 3 | 4 | import torch 5 | import torchvision 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | def print_model_param_nums(model=None, multiply_adds=True): 11 | if model == None: 12 | model = torchvision.models.alexnet() 13 | total = sum([param.nelement() for param in model.parameters()]) 14 | print(' + Number of params: %.2fM' % (total / 1e6)) 15 | 16 | def print_model_param_flops(model=None, input_res=224, multiply_adds=True): 17 | 18 | prods = {} 19 | def save_hook(name): 20 | def hook_per(self, input, output): 21 | prods[name] = np.prod(input[0].shape) 22 | return hook_per 23 | 24 | list_1=[] 25 | def simple_hook(self, input, output): 26 | list_1.append(np.prod(input[0].shape)) 27 | list_2={} 28 | def simple_hook2(self, input, output): 29 | list_2['names'] = np.prod(input[0].shape) 30 | 31 | list_conv=[] 32 | def conv_hook(self, input, output): 33 | batch_size, input_channels, input_height, input_width = input[0].size() 34 | output_channels, output_height, output_width = output[0].size() 35 | 36 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 37 | bias_ops = 1 if self.bias is not None else 0 38 | 39 | params = output_channels * (kernel_ops + bias_ops) 40 | flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 41 | 42 | list_conv.append(flops) 43 | 44 | list_linear=[] 45 | def linear_hook(self, input, output): 46 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 47 | 48 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 49 | bias_ops = self.bias.nelement() 50 | 51 | flops = batch_size * (weight_ops + bias_ops) 52 | list_linear.append(flops) 53 | 54 | list_bn=[] 55 | def bn_hook(self, input, output): 56 | list_bn.append(input[0].nelement() * 2) 57 | 58 | list_relu=[] 59 | def relu_hook(self, input, output): 60 | list_relu.append(input[0].nelement()) 61 | 62 | list_pooling=[] 63 | def pooling_hook(self, input, output): 64 | batch_size, input_channels, input_height, input_width = input[0].size() 65 | output_channels, output_height, output_width = output[0].size() 66 | 67 | kernel_ops = self.kernel_size * self.kernel_size 68 | bias_ops = 0 69 | params = 0 70 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 71 | 72 | list_pooling.append(flops) 73 | 74 | list_upsample=[] 75 | # For bilinear upsample 76 | def upsample_hook(self, input, output): 77 | batch_size, input_channels, input_height, input_width = input[0].size() 78 | output_channels, output_height, output_width = output[0].size() 79 | 80 | flops = output_height * output_width * output_channels * batch_size * 12 81 | list_upsample.append(flops) 82 | 83 | def foo(net): 84 | childrens = list(net.children()) 85 | if not childrens: 86 | if isinstance(net, torch.nn.Conv2d): 87 | net.register_forward_hook(conv_hook) 88 | if isinstance(net, torch.nn.Linear): 89 | net.register_forward_hook(linear_hook) 90 | if isinstance(net, torch.nn.BatchNorm2d): 91 | net.register_forward_hook(bn_hook) 92 | if isinstance(net, torch.nn.ReLU): 93 | net.register_forward_hook(relu_hook) 94 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 95 | net.register_forward_hook(pooling_hook) 96 | if isinstance(net, torch.nn.Upsample): 97 | net.register_forward_hook(upsample_hook) 98 | return 99 | for c in childrens: 100 | foo(c) 101 | 102 | if model == None: 103 | model = torchvision.models.alexnet() 104 | 
foo(model) 105 | input = torch.rand(3, 3, input_res, input_res) 106 | if input.is_cuda: 107 | model.cuda() 108 | else: 109 | model.cpu() 110 | with torch.no_grad(): 111 | _ = model(input) 112 | 113 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 114 | 115 | print(' + Number of FLOPs: %.5fG' % (total_flops / 1e9)) 116 | 117 | return total_flops 118 | -------------------------------------------------------------------------------- /models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import math 4 | 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 8 | 'vgg19_bn', 'vgg19', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | } 18 | 19 | 20 | class VGG(nn.Module): 21 | 22 | def __init__(self, features, num_classes=1000): 23 | super(VGG, self).__init__() 24 | self.features = features 25 | self.classifier = nn.Linear(512, num_classes) 26 | self._initialize_weights() 27 | 28 | def forward(self, x, features_only=False): 29 | x = self.features(x) 30 | x = x.view(x.size(0), -1) 31 | if not features_only: 32 | x = self.classifier(x) 33 | return x 34 | 35 | def _initialize_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 39 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | n = m.weight.size(1) 47 | m.weight.data.normal_(0, 0.01) 48 | m.bias.data.zero_() 49 | 50 | 51 | def make_layers(cfg, batch_norm=False): 52 | layers = [] 53 | in_channels = 3 54 | for v in cfg: 55 | if v == 'M': 56 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 57 | else: 58 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 59 | if batch_norm: 60 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 61 | else: 62 | layers += [conv2d, nn.ReLU(inplace=True)] 63 | in_channels = v 64 | return nn.Sequential(*layers) 65 | 66 | 67 | cfg = { 68 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 69 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 70 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 71 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 72 | } 73 | 74 | 75 | def vgg11(**kwargs): 76 | """VGG 11-layer model (configuration "A") 77 | 78 | Args: 79 | pretrained (bool): If True, returns a model pre-trained on ImageNet 80 | """ 81 | model = VGG(make_layers(cfg['A']), **kwargs) 82 | return model 83 | 84 | 85 | def vgg11_bn(**kwargs): 86 | """VGG 11-layer model (configuration "A") with batch normalization""" 87 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 88 | return model 89 | 90 | 91 | def vgg13(**kwargs): 92 | """VGG 13-layer model (configuration "B") 93 | 94 | Args: 95 | pretrained (bool): If True, returns a model pre-trained on ImageNet 96 | """ 97 | model = VGG(make_layers(cfg['B']), **kwargs) 98 | return model 99 | 100 | 101 | def vgg13_bn(**kwargs): 102 | """VGG 13-layer model (configuration "B") with batch normalization""" 103 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 104 | return model 105 | 106 | 107 | def vgg16(**kwargs): 108 | """VGG 16-layer model (configuration "D") 109 | 110 | Args: 111 | pretrained (bool): If True, returns a model pre-trained on ImageNet 112 | """ 113 | model = VGG(make_layers(cfg['D']), **kwargs) 114 | return model 115 | 116 | 117 | def vgg16_bn(**kwargs): 118 | """VGG 16-layer model (configuration "D") with batch normalization""" 119 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 120 | return model 121 | 122 | 123 | def vgg19(**kwargs): 124 | """VGG 19-layer model (configuration "E") 125 | 126 | Args: 127 | pretrained (bool): If True, returns a model pre-trained on ImageNet 128 | """ 129 | model = VGG(make_layers(cfg['E']), **kwargs) 130 | return model 131 | 132 | 133 | def vgg19_bn(**kwargs): 134 | """VGG 19-layer model (configuration 'E') with batch normalization""" 135 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 136 | return model 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
![sensAI logo](sensAI-logo.png)
4 | 5 | # sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data 6 | 7 | ## Environment 8 | 9 | Linux, python 3.6+ 10 | 11 | ## Setup 12 | 13 | ```bash 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Instruction 18 | 19 | Supported CNN architectures and datasets: 20 | 21 | | Dataset | Architecture(`ARCH`) | 22 | | ------------- |:-------------:| 23 | | CIFAR-10 | vgg19_bn, resnet110, resnet164, mobilenetv2, shufflenetv2| 24 | | CIFAR-100 | vgg19_bn, resnet110, resnet164| 25 | | ImageNet-1K | vgg19_bn, resnet50| 26 | 27 | 28 | ### 1. Generate class groups 29 | 30 | For CIFAR-10/CIFAR-100: 31 | ```bash 32 | python3 group_selection.py \ 33 | --arch $ARCH \ 34 | --resume $pretrained_model \ 35 | --dataset $DATASET \ 36 | --ngroups $number_of_groups \ 37 | --gpu_num $number_of_gpu 38 | ``` 39 | For ImageNet-1K: 40 | ```bash 41 | python3 group_selection.py \ 42 | --arch $ARCH \ 43 | --dataset imagenet \ 44 | --ngroups $number_of_groups \ 45 | --gpu_num $number_of_gpu \ 46 | --data /{path_to_imagenet_dataset}/ 47 | ``` 48 | 49 | Pruning candidate now stored in `./prune_candidate_logs/` 50 | 51 | ### 2. Prune models 52 | 53 | For CIFAR-10/CIFAR-100: 54 | ```bash 55 | python3 prune_and_get_model.py \ 56 | -a $ARCH \ 57 | --dataset $DATASET \ 58 | --resume $pretrained_model \ 59 | -c ./prune_candidate_logs/ \ 60 | -s ./{TO_SAVE_PRUNED_MODEL_DIR}/ 61 | ``` 62 | For ImageNet-1K: 63 | ```bash 64 | python3 prune_and_get_model.py \ 65 | -a $ARCH \ 66 | --dataset imagenet \ 67 | -c ./prune_candidate_logs/ \ 68 | -s ./{TO_SAVE_PRUNED_MODEL_DIR}/ \ 69 | --pretrained 70 | ``` 71 | 72 | Pruned models are now saved in `./{TO_SAVE_PRUNED_MODEL_DIR}/$ARCH/` 73 | 74 | ### 3. Retrain pruned models 75 | 76 | For CIFAR-10/CIFAR-100: 77 | ```bash 78 | python3 retrain_grouped_model.py \ 79 | -a $ARCH \ 80 | --dataset $DATASET \ 81 | --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \ 82 | --train_batch $batch_size \ 83 | --epochs $number_of_epochs \ 84 | --num_gpus $number_of_gpus 85 | ``` 86 | For ImageNet-1K: 87 | ```bash 88 | python3 retrain_grouped_model.py \ 89 | -a $ARCH \ 90 | --dataset imagenet \ 91 | --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \ 92 | --epochs $number_of_epochs \ 93 | --num_gpus $number_of_gpus \ 94 | --train_batch $batch_size \ 95 | --data /{path_to_imagenet_dataset}/ 96 | ``` 97 | 98 | Retrained models now saved in `./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/$ARCH/` 99 | 100 | ### 4. 
Evaluate 101 | 102 | For CIFAR-10/CIFAR-100: 103 | ```bash 104 | python3 evaluate.py \ 105 | -a $ARCH \ 106 | --dataset=$DATASET \ 107 | --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \ 108 | --test-batch $batch_size 109 | ``` 110 | For ImageNet-1K: 111 | ```bash 112 | python3 evaluate.py \ 113 | -d imagenet \ 114 | -a $ARCH \ 115 | --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \ 116 | --data /{path_to_imagenet_dataset}/ 117 | ``` 118 | 119 | ## Contributors 120 | 121 | Thanks for all the main contributors to this repository: 122 | 123 | * [Brandon Hsieh](https://github.com/hsiehbrandon) 124 | 125 | * [Zhuang Liu](https://github.com/liuzhuang13) 126 | 127 | * [Kenan Jiang](https://github.com/Kenan-Jiang) 128 | 129 | * [Kehan Wang](https://github.com/Jason-Khan) 130 | 131 | * [Siyuan Zhuang](https://github.com/suquark) 132 | 133 | And many others [Zihao Fan](https://github.com/zihao-fan), [Hank O'Brien](https://github.com/hjobrien) , [Yaoqing Yang](https://github.com/nsfzyzz), [Adarsh Karnati](https://github.com/akarnati11), [Jichan Chung](https://github.com/jichan3751), [Yingxin Kang](https://github.com/Miiira), [ 134 | Balaji Veeramani](https://github.com/bveeramani), [Sahil Rao](https://github.com/sahilrao21). 135 | 136 | 137 | 138 | 139 | ## Citation 140 | 141 | ```text 142 | @inproceedings{wang2021sensAI, 143 | author = {Guanhua Wang and Zhuang Liu and Brandon Hsieh and Siyuan Zhuang and Joseph Gonzalez and Trevor Darrell and Ion Stoica}, 144 | title = {{sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data}}, 145 | booktitle = {Proceedings of Fourth Conference on Machine Learning and Systems (MLSys'21)}, 146 | year = {2021} 147 | } 148 | ``` 149 | 150 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import matplotlib.pyplot as plt 3 | import os 4 | import sys 5 | import numpy as np 6 | 7 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 8 | 9 | def savefig(fname, dpi=None): 10 | dpi = 150 if dpi == None else dpi 11 | plt.savefig(fname, dpi=dpi) 12 | 13 | def plot_overlap(logger, names=None): 14 | names = logger.names if names == None else names 15 | numbers = logger.numbers 16 | for _, name in enumerate(names): 17 | x = np.arange(len(numbers[name])) 18 | plt.plot(x, np.asarray(numbers[name])) 19 | return [logger.title + '(' + name + ')' for name in names] 20 | 21 | class Logger(object): 22 | '''Save training process to log file with simple plot function.''' 23 | def __init__(self, fpath, title=None, resume=False): 24 | self.file = None 25 | self.resume = resume 26 | self.title = '' if title == None else title 27 | if fpath is not None: 28 | if resume: 29 | self.file = open(fpath, 'r') 30 | name = self.file.readline() 31 | self.names = name.rstrip().split('\t') 32 | self.numbers = {} 33 | for _, name in enumerate(self.names): 34 | self.numbers[name] = [] 35 | 36 | for numbers in self.file: 37 | numbers = numbers.rstrip().split('\t') 38 | for i in range(0, len(numbers)): 39 | self.numbers[self.names[i]].append(numbers[i]) 40 | self.file.close() 41 | self.file = open(fpath, 'a') 42 | else: 43 | self.file = open(fpath, 'w') 44 | 45 | def set_names(self, names): 46 | if self.resume: 47 | pass 48 | # initialize numbers as empty list 49 | self.numbers = {} 50 | self.names = names 51 | for _, name in enumerate(self.names): 52 | self.file.write(name) 53 | 
self.file.write('\t') 54 | self.numbers[name] = [] 55 | self.file.write('\n') 56 | self.file.flush() 57 | 58 | 59 | def append(self, numbers): 60 | assert len(self.names) == len(numbers), 'Numbers do not match names' 61 | for index, num in enumerate(numbers): 62 | self.file.write("{0:.6f}".format(num)) 63 | self.file.write('\t') 64 | self.numbers[self.names[index]].append(num) 65 | self.file.write('\n') 66 | self.file.flush() 67 | 68 | def plot(self, names=None): 69 | names = self.names if names == None else names 70 | numbers = self.numbers 71 | for _, name in enumerate(names): 72 | x = np.arange(len(numbers[name])) 73 | plt.plot(x, np.asarray(numbers[name])) 74 | plt.legend([self.title + '(' + name + ')' for name in names]) 75 | plt.grid(True) 76 | 77 | def close(self): 78 | if self.file is not None: 79 | self.file.close() 80 | 81 | class LoggerMonitor(object): 82 | '''Load and visualize multiple logs.''' 83 | def __init__ (self, paths): 84 | '''paths is a distionary with {name:filepath} pair''' 85 | self.loggers = [] 86 | for title, path in paths.items(): 87 | logger = Logger(path, title=title, resume=True) 88 | self.loggers.append(logger) 89 | 90 | def plot(self, names=None): 91 | plt.figure() 92 | plt.subplot(121) 93 | legend_text = [] 94 | for logger in self.loggers: 95 | legend_text += plot_overlap(logger, names) 96 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 97 | plt.grid(True) 98 | 99 | if __name__ == '__main__': 100 | # # Example 101 | # logger = Logger('test.txt') 102 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 103 | 104 | # length = 100 105 | # t = np.arange(length) 106 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 107 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 108 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | 110 | # for i in range(0, length): 111 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 112 | # logger.plot() 113 | 114 | # Example: logger monitor 115 | paths = { 116 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 117 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 118 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 119 | } 120 | 121 | field = ['Valid Acc.'] 122 | 123 | monitor = LoggerMonitor(paths) 124 | monitor.plot(names=field) 125 | savefig('test.eps') -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import matplotlib.pyplot as plt 3 | import os 4 | import sys 5 | import numpy as np 6 | 7 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 8 | 9 | def savefig(fname, dpi=None): 10 | dpi = 150 if dpi == None else dpi 11 | plt.savefig(fname, dpi=dpi) 12 | 13 | def plot_overlap(logger, names=None): 14 | names = logger.names if names == None else names 15 | numbers = logger.numbers 16 | for _, name in enumerate(names): 17 | x = np.arange(len(numbers[name])) 18 | plt.plot(x, np.asarray(numbers[name])) 19 | return [logger.title + '(' + name + ')' for name in names] 20 | 21 | class Logger(object): 22 | '''Save training process to log file with simple plot function.''' 23 | def __init__(self, fpath, title=None, resume=False): 24 | self.file = None 25 | self.resume = resume 26 | self.title = '' if title == None else title 
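        # When resume=True, the branch below re-reads the existing log file so its
        # tab-separated header and previously written rows are restored into
        # self.names / self.numbers before the file is re-opened in append mode;
        # otherwise a fresh log file is created in write mode.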
27 | if fpath is not None: 28 | if resume: 29 | self.file = open(fpath, 'r') 30 | name = self.file.readline() 31 | self.names = name.rstrip().split('\t') 32 | self.numbers = {} 33 | for _, name in enumerate(self.names): 34 | self.numbers[name] = [] 35 | 36 | for numbers in self.file: 37 | numbers = numbers.rstrip().split('\t') 38 | for i in range(0, len(numbers)): 39 | self.numbers[self.names[i]].append(numbers[i]) 40 | self.file.close() 41 | self.file = open(fpath, 'a') 42 | else: 43 | self.file = open(fpath, 'w') 44 | 45 | def set_names(self, names): 46 | if self.resume: 47 | pass 48 | # initialize numbers as empty list 49 | self.numbers = {} 50 | self.names = names 51 | for _, name in enumerate(self.names): 52 | self.file.write(name) 53 | self.file.write('\t') 54 | self.numbers[name] = [] 55 | self.file.write('\n') 56 | self.file.flush() 57 | 58 | 59 | def append(self, numbers): 60 | assert len(self.names) == len(numbers), 'Numbers do not match names' 61 | for index, num in enumerate(numbers): 62 | self.file.write("{0:.6f}".format(num)) 63 | self.file.write('\t') 64 | self.numbers[self.names[index]].append(num) 65 | self.file.write('\n') 66 | self.file.flush() 67 | 68 | def plot(self, names=None): 69 | names = self.names if names == None else names 70 | numbers = self.numbers 71 | for _, name in enumerate(names): 72 | x = np.arange(len(numbers[name])) 73 | plt.plot(x, np.asarray(numbers[name])) 74 | plt.legend([self.title + '(' + name + ')' for name in names]) 75 | plt.grid(True) 76 | 77 | def close(self): 78 | if self.file is not None: 79 | self.file.close() 80 | 81 | class LoggerMonitor(object): 82 | '''Load and visualize multiple logs.''' 83 | def __init__ (self, paths): 84 | '''paths is a distionary with {name:filepath} pair''' 85 | self.loggers = [] 86 | for title, path in paths.items(): 87 | logger = Logger(path, title=title, resume=True) 88 | self.loggers.append(logger) 89 | 90 | def plot(self, names=None): 91 | plt.figure() 92 | plt.subplot(121) 93 | legend_text = [] 94 | for logger in self.loggers: 95 | legend_text += plot_overlap(logger, names) 96 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
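            # The legend call above pins the combined legend just outside the axes
            # (its upper-left corner anchored at (1.05, 1)), keeping curves from
            # multiple logs readable.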
97 | plt.grid(True) 98 | 99 | if __name__ == '__main__': 100 | # # Example 101 | # logger = Logger('test.txt') 102 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 103 | 104 | # length = 100 105 | # t = np.arange(length) 106 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 107 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 108 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | 110 | # for i in range(0, length): 111 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 112 | # logger.plot() 113 | 114 | # Example: logger monitor 115 | paths = { 116 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 117 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 118 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 119 | } 120 | 121 | field = ['Valid Acc.'] 122 | 123 | monitor = LoggerMonitor(paths) 124 | monitor.plot(names=field) 125 | savefig('test.eps') -------------------------------------------------------------------------------- /apoz_policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import contextlib 4 | import torch.nn.functional as F 5 | 6 | def apoz_scoring(activation): 7 | """ 8 | Calculate the Average Percentage of Zeros Score of the feature map activation layer output 9 | """ 10 | activation = (activation.abs() <= 0.005).float() 11 | if activation.dim() == 4: 12 | featuremap_apoz_mat = activation.mean(dim=(0, 2, 3)) 13 | elif activation.dim() == 2: 14 | featuremap_apoz_mat = activation.mean(dim=(0, 1)) 15 | else: 16 | raise ValueError( 17 | f"activation_channels_avg: Unsupported shape: {activation.shape}") 18 | return featuremap_apoz_mat.mul(100).cpu() 19 | 20 | 21 | def avg_scoring(activation): 22 | activation = activation.abs() 23 | if activation.dim() == 4: 24 | featuremap_avg_mat = activation.mean(dim=(0, 2, 3)) 25 | elif activation.dim() == 2: 26 | featuremap_avg_mat = activation.mean(dim=(0, 1)) 27 | else: 28 | raise ValueError( 29 | f"activation_channels_avg: Unsupported shape: {activation.shape}") 30 | return featuremap_avg_mat.cpu() 31 | 32 | 33 | class ActivationRecord: 34 | def __init__(self, model, arch): 35 | self.apoz_scores_by_layer = [] 36 | self.avg_scores_by_layer = [] 37 | self.num_batches = 0 38 | self.layer_idx = 0 39 | self._candidates_by_layer = None 40 | self._model = model 41 | # switch to evaluate mode 42 | self._model.eval() 43 | self._model.apply(lambda m: m.register_forward_hook(self._hook)) 44 | self.arch = arch 45 | 46 | def parse_activation(self, feature_map): 47 | apoz_score = apoz_scoring(feature_map).numpy() 48 | avg_score = avg_scoring(feature_map).numpy() 49 | 50 | if self.num_batches == 0: 51 | self.apoz_scores_by_layer.append(apoz_score) 52 | self.avg_scores_by_layer.append(avg_score) 53 | else: 54 | self.apoz_scores_by_layer[self.layer_idx] += apoz_score 55 | self.avg_scores_by_layer[self.layer_idx] += avg_score 56 | self.layer_idx += 1 57 | 58 | def __enter__(self): 59 | return self 60 | 61 | def __exit__(self, exception_type, exception_value, traceback): 62 | for score in self.apoz_scores_by_layer: 63 | score /= self.num_batches 64 | for score in self.avg_scores_by_layer: 65 | score /= self.num_batches 66 | 67 | def record_batch(self, *args, **kwargs): 68 | # reset layer index 69 | self.layer_idx = 0 70 | with torch.no_grad(): 71 | # output is 
not used 72 | _ = self._model(*args, **kwargs) 73 | self.num_batches += 1 74 | 75 | def _hook(self, module, input, output): 76 | """Apply a hook to RelU layer""" 77 | if self.arch == "shufflenetv2": 78 | if module.__class__.__name__ == 'BatchNorm2d': 79 | self.parse_activation(F.relu(output)) 80 | else: 81 | if module.__class__.__name__ == 'ReLU': 82 | self.parse_activation(output) 83 | 84 | def generate_pruned_candidates(self): 85 | num_layers = len(self.apoz_scores_by_layer) 86 | thresholds = [73] * num_layers 87 | avg_thresholds = [0.01] * num_layers 88 | 89 | candidates_by_layer = [] 90 | for layer_idx, (apoz_scores, avg_scores) in enumerate(zip(self.apoz_scores_by_layer, self.avg_scores_by_layer)): 91 | if self.arch == "mobilenetv2": 92 | apoz_scores = torch.Tensor(apoz_scores) 93 | avg_scores = torch.Tensor(avg_scores) 94 | avg_candidates = [idx for idx, score in enumerate( 95 | avg_scores) if score >= avg_thresholds[layer_idx]] 96 | candidates = [(idx,float(score)) for idx, score in enumerate(apoz_scores) if score >= thresholds[layer_idx]] 97 | candidates = sorted(candidates, key = lambda x: x[1])[:int(len(candidates)/2)] 98 | candidates = [x[0] for x in candidates] 99 | else: 100 | apoz_scores = torch.Tensor(apoz_scores) 101 | avg_scores = torch.Tensor(avg_scores) 102 | avg_candidates = [idx for idx, score in enumerate( 103 | avg_scores) if score >= avg_thresholds[layer_idx]] 104 | candidates = [x[0] for x in apoz_scores.gt( 105 | thresholds[layer_idx]).nonzero().tolist()] 106 | difference_candidates = list( 107 | set(candidates).difference(set(avg_candidates))) 108 | candidates_by_layer.append(difference_candidates) 109 | """ 110 | DEBUG: Printing out remaining neuron IDs 111 | all_neuron = [idx for idx, score in enumerate(avg_scores)] 112 | remaining = list(set(all_neuron)-set(difference_candidates)) 113 | print("\nThose remaining neuron index for layer ", layer_idx) 114 | print(remaining) 115 | """ 116 | print( 117 | f"Total pruned candidates: {sum(len(l) for l in candidates_by_layer)}") 118 | return candidates_by_layer 119 | -------------------------------------------------------------------------------- /models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | __all__ = ['densenet'] 8 | 9 | 10 | from torch.autograd import Variable 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 14 | super(Bottleneck, self).__init__() 15 | planes = expansion * growthRate 16 | self.bn1 = nn.BatchNorm2d(inplanes) 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 20 | padding=1, bias=False) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.dropRate = dropRate 23 | 24 | def forward(self, x): 25 | out = self.bn1(x) 26 | out = self.relu(out) 27 | out = self.conv1(out) 28 | out = self.bn2(out) 29 | out = self.relu(out) 30 | out = self.conv2(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 41 | super(BasicBlock, self).__init__() 42 | planes = expansion * growthRate 43 | self.bn1 = nn.BatchNorm2d(inplanes) 44 | self.conv1 = 
nn.Conv2d(inplanes, growthRate, kernel_size=3, 45 | padding=1, bias=False) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.dropRate = dropRate 48 | 49 | def forward(self, x): 50 | out = self.bn1(x) 51 | out = self.relu(out) 52 | out = self.conv1(out) 53 | if self.dropRate > 0: 54 | out = F.dropout(out, p=self.dropRate, training=self.training) 55 | 56 | out = torch.cat((x, out), 1) 57 | 58 | return out 59 | 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, inplanes, outplanes): 63 | super(Transition, self).__init__() 64 | self.bn1 = nn.BatchNorm2d(inplanes) 65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 66 | bias=False) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, x): 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | out = F.avg_pool2d(out, 2) 74 | return out 75 | 76 | 77 | class DenseNet(nn.Module): 78 | 79 | def __init__(self, depth=22, block=Bottleneck, 80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2): 81 | super(DenseNet, self).__init__() 82 | 83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 84 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 85 | 86 | self.growthRate = growthRate 87 | self.dropRate = dropRate 88 | 89 | # self.inplanes is a global variable used across multiple 90 | # helper functions 91 | self.inplanes = growthRate * 2 92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 93 | bias=False) 94 | self.dense1 = self._make_denseblock(block, n) 95 | self.trans1 = self._make_transition(compressionRate) 96 | self.dense2 = self._make_denseblock(block, n) 97 | self.trans2 = self._make_transition(compressionRate) 98 | self.dense3 = self._make_denseblock(block, n) 99 | self.bn = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(self.inplanes, num_classes) 103 | 104 | # Weight initialization 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_denseblock(self, block, blocks): 114 | layers = [] 115 | for i in range(blocks): 116 | # Currently we fix the expansion ratio as the default value 117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 118 | self.inplanes += self.growthRate 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def _make_transition(self, compressionRate): 123 | inplanes = self.inplanes 124 | outplanes = int(math.floor(self.inplanes // compressionRate)) 125 | self.inplanes = outplanes 126 | return Transition(inplanes, outplanes) 127 | 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | 132 | x = self.trans1(self.dense1(x)) 133 | x = self.trans2(self.dense2(x)) 134 | x = self.dense3(x) 135 | x = self.bn(x) 136 | x = self.relu(x) 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | 145 | def densenet(**kwargs): 146 | """ 147 | Constructs a ResNet model. 
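    Illustrative example (a sketch, not from the original source; the argument
    values are placeholders):
        >>> import torch
        >>> model = densenet(depth=100, num_classes=10, growthRate=12)
        >>> model(torch.randn(2, 3, 32, 32)).shape
        torch.Size([2, 10])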
148 | """ 149 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | 86 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 87 | super(ResNet, self).__init__() 88 | # Model type specifies number of layers for CIFAR-10 model 89 | if block_name.lower() == 'basicblock': 90 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 91 | n = (depth - 2) // 6 92 | block = BasicBlock 93 | elif block_name.lower() == 'bottleneck': 94 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 95 | n = (depth - 2) // 9 96 | block = Bottleneck 97 | else: 98 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 99 | 100 | 101 | self.inplanes = 16 102 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(16) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.layer1 = self._make_layer(block, 16, n) 107 | self.layer2 = self._make_layer(block, 32, n, stride=2) 108 | self.layer3 = self._make_layer(block, 64, n, stride=2) 109 | self.avgpool = nn.AvgPool2d(8) 110 | self.fc = nn.Linear(64 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x, features_only=False): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) # 32x32 141 | 142 | x = self.layer1(x) # 32x32 143 | x = self.layer2(x) # 16x16 144 | x = self.layer3(x) # 8x8 145 | 146 | x = self.avgpool(x) 147 | x = x.view(x.size(0), -1) 148 | if features_only: 149 | return x 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | def resnet110(**kwargs): 156 | """ 157 | Constructs a ResNet-110 model. 158 | """ 159 | return ResNet(depth=110, block_name='bottleneck', **kwargs) 160 | 161 | 162 | def resnet164(**kwargs): 163 | """ 164 | Constructs a ResNet-164 model. 165 | """ 166 | return ResNet(depth=164, block_name='bottleneck', **kwargs) 167 | 168 | 169 | __all__ = ['resnet110', 'resnet164'] 170 | -------------------------------------------------------------------------------- /models/cifar/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py 8 | """ 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | 13 | __all__ = ['resnext'] 14 | 15 | class ResNeXtBottleneck(nn.Module): 16 | """ 17 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 18 | """ 19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): 20 | """ Constructor 21 | Args: 22 | in_channels: input channel dimensionality 23 | out_channels: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | cardinality: num of convolution groups. 26 | widen_factor: factor to reduce the input dimensionality before convolution. 
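        Illustrative example (a sketch, not from the original source; the channel
        sizes below are placeholder values):
            >>> import torch
            >>> block = ResNeXtBottleneck(64, 256, stride=1, cardinality=8, widen_factor=4)
            >>> block(torch.randn(2, 64, 32, 32)).shape
            torch.Size([2, 256, 32, 32])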
27 | """ 28 | super(ResNeXtBottleneck, self).__init__() 29 | D = cardinality * out_channels // widen_factor 30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn_reduce = nn.BatchNorm2d(D) 32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 33 | self.bn = nn.BatchNorm2d(D) 34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 35 | self.bn_expand = nn.BatchNorm2d(out_channels) 36 | 37 | self.shortcut = nn.Sequential() 38 | if in_channels != out_channels: 39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)) 40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels)) 41 | 42 | def forward(self, x): 43 | bottleneck = self.conv_reduce.forward(x) 44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) 45 | bottleneck = self.conv_conv.forward(bottleneck) 46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) 47 | bottleneck = self.conv_expand.forward(bottleneck) 48 | bottleneck = self.bn_expand.forward(bottleneck) 49 | residual = self.shortcut.forward(x) 50 | return F.relu(residual + bottleneck, inplace=True) 51 | 52 | 53 | class CifarResNeXt(nn.Module): 54 | """ 55 | ResNext optimized for the Cifar dataset, as specified in 56 | https://arxiv.org/pdf/1611.05431.pdf 57 | """ 58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0): 59 | """ Constructor 60 | Args: 61 | cardinality: number of convolution groups. 62 | depth: number of layers. 63 | num_classes: number of classes 64 | widen_factor: factor to adjust the channel dimensionality 65 | """ 66 | super(CifarResNeXt, self).__init__() 67 | self.cardinality = cardinality 68 | self.depth = depth 69 | self.block_depth = (self.depth - 2) // 9 70 | self.widen_factor = widen_factor 71 | self.num_classes = num_classes 72 | self.output_size = 64 73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor] 74 | 75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 76 | self.bn_1 = nn.BatchNorm2d(64) 77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 80 | self.classifier = nn.Linear(1024, num_classes) 81 | init.kaiming_normal(self.classifier.weight) 82 | 83 | for key in self.state_dict(): 84 | if key.split('.')[-1] == 'weight': 85 | if 'conv' in key: 86 | init.kaiming_normal(self.state_dict()[key], mode='fan_out') 87 | if 'bn' in key: 88 | self.state_dict()[key][...] = 1 89 | elif key.split('.')[-1] == 'bias': 90 | self.state_dict()[key][...] = 0 91 | 92 | def block(self, name, in_channels, out_channels, pool_stride=2): 93 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 94 | Args: 95 | name: string name of the current block. 96 | in_channels: number of input channels 97 | out_channels: number of output channels 98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 99 | Returns: a Module consisting of n sequential bottlenecks. 
100 | """ 101 | block = nn.Sequential() 102 | for bottleneck in range(self.block_depth): 103 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 104 | if bottleneck == 0: 105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality, 106 | self.widen_factor)) 107 | else: 108 | block.add_module(name_, 109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor)) 110 | return block 111 | 112 | def forward(self, x): 113 | x = self.conv_1_3x3.forward(x) 114 | x = F.relu(self.bn_1.forward(x), inplace=True) 115 | x = self.stage_1.forward(x) 116 | x = self.stage_2.forward(x) 117 | x = self.stage_3.forward(x) 118 | x = F.avg_pool2d(x, 8, 1) 119 | x = x.view(-1, 1024) 120 | return self.classifier(x) 121 | 122 | def resnext(**kwargs): 123 | """Constructs a ResNeXt. 124 | """ 125 | model = CifarResNeXt(**kwargs) 126 | return model -------------------------------------------------------------------------------- /models/cifar/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5): 34 | super(BasicBlock, self).__init__() 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | out = F.relu(self.bn3(self.conv3(out))) 53 | out = torch.cat([x1, out], 1) 54 | out = self.shuffle(out) 55 | return out 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super(DownBlock, self).__init__() 61 | mid_channels = out_channels // 2 62 | # left 63 | self.conv1 = nn.Conv2d(in_channels, in_channels, 64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 65 | self.bn1 = nn.BatchNorm2d(in_channels) 66 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 67 | kernel_size=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(mid_channels) 69 | # right 70 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn3 = 
nn.BatchNorm2d(mid_channels) 73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 75 | self.bn4 = nn.BatchNorm2d(mid_channels) 76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn5 = nn.BatchNorm2d(mid_channels) 79 | 80 | self.shuffle = ShuffleBlock() 81 | 82 | def forward(self, x): 83 | # left 84 | out1 = self.bn1(self.conv1(x)) 85 | out1 = F.relu(self.bn2(self.conv2(out1))) 86 | # right 87 | out2 = F.relu(self.bn3(self.conv3(x))) 88 | out2 = self.bn4(self.conv4(out2)) 89 | out2 = F.relu(self.bn5(self.conv5(out2))) 90 | # concat 91 | out = torch.cat([out1, out2], 1) 92 | out = self.shuffle(out) 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, net_size): 98 | super(ShuffleNetV2, self).__init__() 99 | out_channels = configs[net_size]['out_channels'] 100 | num_blocks = configs[net_size]['num_blocks'] 101 | 102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 103 | stride=1, padding=1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(24) 105 | self.in_channels = 24 106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 110 | kernel_size=1, stride=1, padding=0, bias=False) 111 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 112 | self.linear = nn.Linear(out_channels[3], 10) 113 | 114 | def _make_layer(self, out_channels, num_blocks): 115 | layers = [DownBlock(self.in_channels, out_channels)] 116 | for i in range(num_blocks): 117 | layers.append(BasicBlock(out_channels)) 118 | self.in_channels = out_channels 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x, features_only=False): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = F.relu(self.bn2(self.conv2(out))) 128 | out = F.avg_pool2d(out, 4) 129 | out = out.view(out.size(0), -1) 130 | if not features_only: 131 | out = self.linear(out) 132 | return out 133 | 134 | 135 | configs = { 136 | 0.5: { 137 | 'out_channels': (48, 96, 192, 1024), 138 | 'num_blocks': (3, 7, 3) 139 | }, 140 | 141 | 1: { 142 | 'out_channels': (116, 232, 464, 1024), 143 | 'num_blocks': (3, 7, 3) 144 | }, 145 | 1.5: { 146 | 'out_channels': (176, 352, 704, 1024), 147 | 'num_blocks': (3, 7, 3) 148 | }, 149 | 2: { 150 | 'out_channels': (224, 488, 976, 2048), 151 | 'num_blocks': (3, 7, 3) 152 | } 153 | } 154 | 155 | 156 | def shufflenetv2(**kwargs): 157 | model = ShuffleNetV2(1) 158 | return model 159 | 160 | -------------------------------------------------------------------------------- /models/imagenet/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 
7 | import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua 8 | """ 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | import torch 14 | 15 | __all__ = ['resnext50', 'resnext101', 'resnext152'] 16 | 17 | class Bottleneck(nn.Module): 18 | """ 19 | RexNeXt bottleneck type C 20 | """ 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None): 24 | """ Constructor 25 | Args: 26 | inplanes: input channel dimensionality 27 | planes: output channel dimensionality 28 | baseWidth: base width. 29 | cardinality: num of convolution groups. 30 | stride: conv stride. Replaces pooling layer. 31 | """ 32 | super(Bottleneck, self).__init__() 33 | 34 | D = int(math.floor(planes * (baseWidth / 64))) 35 | C = cardinality 36 | 37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False) 38 | self.bn1 = nn.BatchNorm2d(D*C) 39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) 40 | self.bn2 = nn.BatchNorm2d(D*C) 41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes * 4) 43 | self.relu = nn.ReLU(inplace=True) 44 | 45 | self.downsample = downsample 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | """ 72 | ResNext optimized for the ImageNet dataset, as specified in 73 | https://arxiv.org/pdf/1611.05431.pdf 74 | """ 75 | def __init__(self, baseWidth, cardinality, layers, num_classes): 76 | """ Constructor 77 | Args: 78 | baseWidth: baseWidth for ResNeXt. 79 | cardinality: number of convolution groups. 80 | layers: config of layers, e.g., [3, 4, 6, 3] 81 | num_classes: number of classes 82 | """ 83 | super(ResNeXt, self).__init__() 84 | block = Bottleneck 85 | 86 | self.cardinality = cardinality 87 | self.baseWidth = baseWidth 88 | self.num_classes = num_classes 89 | self.inplanes = 64 90 | self.output_size = 64 91 | 92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | self.layer1 = self._make_layer(block, 64, layers[0]) 97 | self.layer2 = self._make_layer(block, 128, layers[1], 2) 98 | self.layer3 = self._make_layer(block, 256, layers[2], 2) 99 | self.layer4 = self._make_layer(block, 512, layers[3], 2) 100 | self.avgpool = nn.AvgPool2d(7) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 
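        For example, with layers=[3, 4, 6, 3] (ResNeXt-50) the second call builds a stage of
        4 bottlenecks with planes=128, the first of which uses stride=2 together with a
        downsample shortcut.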
113 | Args: 114 | block: block type used to construct ResNext 115 | planes: number of output channels (need to multiply by block.expansion) 116 | blocks: number of blocks to be built 117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 118 | Returns: a Module consisting of n sequential bottlenecks. 119 | """ 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool1(x) 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | x = self.fc(x) 148 | 149 | return x 150 | 151 | 152 | def resnext50(baseWidth, cardinality): 153 | """ 154 | Construct ResNeXt-50. 155 | """ 156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000) 157 | return model 158 | 159 | 160 | def resnext101(baseWidth, cardinality): 161 | """ 162 | Construct ResNeXt-101. 163 | """ 164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000) 165 | return model 166 | 167 | 168 | def resnext152(baseWidth, cardinality): 169 | """ 170 | Construct ResNeXt-152. 
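    Illustrative example (a sketch, not from the original source; the 32x4d
    setting is just one common choice of baseWidth/cardinality):
        >>> import torch
        >>> model = resnext152(baseWidth=4, cardinality=32)
        >>> model(torch.randn(1, 3, 224, 224)).shape
        torch.Size([1, 1000])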
171 | """ 172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000) 173 | return model 174 | -------------------------------------------------------------------------------- /group_selection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import torch.backends.cudnn as cudnn 7 | import load_model 8 | from tqdm import tqdm 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | import numpy as np 12 | import subprocess as sp 13 | import os 14 | 15 | from even_k_means import kmeans_lloyd 16 | 17 | parser = argparse.ArgumentParser( 18 | description='PyTorch CIFAR10/100/Imagenet Generate Group Info') 19 | # Datasets 20 | parser.add_argument('-d', '--dataset', required=True, type=str) 21 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 22 | help='number of data loading workers (default: 4)') 23 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 24 | help='path to latest checkpoint (default: none)') 25 | parser.add_argument('--data', default='/home/ubuntu/imagenet', required=False, type=str, 26 | help='location of the imagenet dataset that includes train/val') 27 | # Architecture 28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet20', 29 | #choices=load_model.model_arches('cifar'), 30 | help='model architecture: ' + 31 | ' | '.join(load_model.model_arches('cifar')) + 32 | ' (default: resnet18)') 33 | parser.add_argument('-n', '--ngroups', required=True, type=int, metavar='N', 34 | help='number of groups') 35 | parser.add_argument('-g', '--gpu_num', default=1, type=int, 36 | help='number of gpus') 37 | 38 | # Miscs 39 | parser.add_argument('--seed', type=int, default=42, help='manual seed') 40 | args = parser.parse_args() 41 | use_cuda = torch.cuda.is_available() and True 42 | 43 | # Random seed 44 | torch.manual_seed(args.seed) 45 | if use_cuda: 46 | torch.cuda.manual_seed_all(args.seed) 47 | 48 | def main(): 49 | print('==> Preparing dataset %s' % args.dataset) 50 | resultExist = os.path.exists("./prune_candidate_logs") 51 | if resultExist: 52 | rm_cmd = 'rm -rf ./prune_candidate_logs' 53 | sp.Popen(rm_cmd, shell=True) 54 | mkdir_cmd = 'mkdir ./prune_candidate_logs' 55 | sp.Popen(mkdir_cmd, shell=True) 56 | # cifar10/100 group selection 57 | if args.dataset in ['cifar10', 'cifar100']: 58 | if args.dataset == 'cifar10': 59 | dataset_loader = datasets.CIFAR10 60 | elif args.dataset == 'cifar100': 61 | dataset_loader = datasets.CIFAR100 62 | 63 | dataset = dataset_loader( 64 | root='./data', 65 | download=True, 66 | train=True, 67 | transform=transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize( 70 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 71 | ])) 72 | data_loader = torch.utils.data.DataLoader( 73 | dataset, 74 | batch_size=1000, 75 | num_workers=args.workers, 76 | pin_memory=False) 77 | 78 | model = load_model.load_pretrain_model( 79 | args.arch, 'cifar', args.resume, len(dataset.classes), use_cuda) 80 | 81 | all_features = [] 82 | all_targets = [] 83 | 84 | model.eval() 85 | print('\nMake a test run to generate groups. 
\n Using training set.\n') 86 | with tqdm(total=len(data_loader)) as bar: 87 | for batch_idx, (inputs, targets) in enumerate(data_loader): 88 | bar.update() 89 | if use_cuda: 90 | inputs = inputs.cuda() 91 | with torch.no_grad(): 92 | features = model(inputs, features_only=True) 93 | all_features.append(features) 94 | all_targets.append(targets) 95 | 96 | all_features = torch.cat(all_features) 97 | all_targets = torch.cat(all_targets) 98 | 99 | groups = kmeans_grouping(all_features, all_targets, 100 | args.ngroups, same_group_size=True) 101 | print("groups: ", groups) 102 | print("\n====================== Grouping Result ========================\n") 103 | process_list = [None for _ in range(args.gpu_num)] 104 | for i, group in enumerate(groups): 105 | if process_list[i % args.gpu_num]: 106 | process_list[i % args.gpu_num].wait() 107 | print(f"Group #{i}: {' '.join(str(idx) for idx in group)}") 108 | exec_cmd = 'python3 get_prune_candidates.py' +\ 109 | ' -a %s' % args.arch + ' -d %s' % args.dataset + ' --resume ./%s' % args.resume + \ 110 | ' --grouped ' + str(group)[1:-1].replace(",", "") + ' --group_number %d' % i + ' --gpu_num %d' % (i % args.gpu_num) 111 | process_list[i % args.gpu_num] = sp.Popen(exec_cmd, shell=True) 112 | 113 | np.save(open("prune_candidate_logs/grouping_config.npy", "wb"), groups) 114 | 115 | # imagenet group selection 116 | elif args.dataset == 'imagenet': 117 | num_gpus = args.gpu_num 118 | num_groups = args.ngroups 119 | group_size = 1000 // num_groups 120 | groups = [[i for i in range((j) * group_size, (j+1) * group_size)] for j in range(num_groups) ] 121 | process_list = [None for _ in range(num_gpus)] 122 | for i, group in enumerate(groups): 123 | if process_list[i % num_gpus]: 124 | process_list[i % num_gpus].wait() 125 | exec_cmd = 'python3 imagenet_activations.py ' +\ 126 | ' --data %s' % args.data +\ 127 | ' --gpu %d' % (i % num_gpus) +\ 128 | ' --arch %s' % args.arch + ' --evaluate --pretrained --group %s' % (' '.join(str(digit) for digit in group)) + \ 129 | ' --name %s' % (str(i)) 130 | process_list[i % num_gpus] = sp.Popen(exec_cmd, shell=True) 131 | # Save the grouping class index partition information 132 | np.save(open("prune_candidate_logs/grouping_config.npy", "wb"), groups) 133 | else: 134 | raise NotImplementedError(f"There's no support for '{args.dataset}' dataset.") 135 | 136 | def kmeans_grouping(features, targets, n_groups, same_group_size=True): 137 | class_indices = targets.unique().sort().values 138 | mean_vectors = [] 139 | for t in class_indices: 140 | mean_vec = features[targets == t.item(), :].mean(dim=0) 141 | mean_vectors.append(mean_vec.cpu().numpy()) 142 | X = np.asarray(mean_vectors) 143 | class_indices = class_indices.cpu().numpy() 144 | assert X.ndim == 2 145 | best_labels, best_inertia, best_centers, _ = kmeans_lloyd( 146 | X, None, n_groups, verbose=True, 147 | same_cluster_size=same_group_size, 148 | random_state=args.seed, 149 | tol=1e-6) 150 | groups = [] 151 | for i in range(n_groups): 152 | groups.append(class_indices[best_labels == i].tolist()) 153 | return groups 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /imagenet_evaluate_grouped.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import sys 6 | import glob 7 | import re 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import 
torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import imagenet_dataset as datasets 21 | import torchvision.models as models 22 | 23 | from compute_flops import print_model_param_flops 24 | 25 | def main_worker(gpu, ngpus_per_node, args): 26 | global best_acc1 27 | args.gpu = gpu 28 | args.evaluate = True 29 | 30 | cudnn.benchmark = True 31 | model_list = [] 32 | num_flops = [] 33 | avg_num_param = 0.0 34 | args.checkpoint = os.path.dirname(args.retrained_dir) 35 | criterion = nn.CrossEntropyLoss() 36 | 37 | # load groups 38 | file_names = [f for f in glob.glob(args.retrained_dir + "/" + args.arch + "/*.pth", recursive=False)] 39 | group_id_list = [filename_to_index(filename) for filename in file_names] 40 | group_config = np.load(open(args.retrained_dir + '/grouping_config.npy', "rb")) 41 | 42 | permutation_indices = [] # To allow for arbitrary grouping 43 | for group_id in group_id_list: 44 | permutation_indices.extend(group_config[int(group_id[0])]) 45 | permutation_indices = torch.eye(1000)[permutation_indices].cuda(args.gpu) 46 | 47 | # load models 48 | for index, (group_id, file_name) in enumerate(zip(group_id_list, file_names)): 49 | model = torch.load(file_name) 50 | model = model.cuda(index % ngpus_per_node) 51 | avg_num_param += sum(p.numel() for p in model.parameters())/1000000.0 52 | print('Group {} model has total params: {:2f}M'.format(group_id ,sum(p.numel() for p in model.parameters())/1000000.0)) 53 | model_list.append(model) 54 | 55 | # generate dataloader 56 | valdir = os.path.join(args.data, 'val') 57 | traindir = os.path.join(args.data, 'train') 58 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 59 | std=[0.229, 0.224, 0.225]) 60 | 61 | val_loader = torch.utils.data.DataLoader( 62 | datasets.ImageFolder(valdir, transforms.Compose([ 63 | transforms.Resize(256), 64 | transforms.CenterCrop(224), 65 | transforms.ToTensor(), 66 | normalize, 67 | ])), 68 | batch_size=args.batch_size, shuffle=False, 69 | num_workers=args.workers, pin_memory=True) 70 | 71 | if args.evaluate: 72 | validate(val_loader, model_list, criterion, args, permutation_indices, ngpus_per_node) 73 | return 74 | 75 | def validate(val_loader, model_list, criterion, args, p_indices, gpu_nums): 76 | batch_time = AverageMeter('Time', ':6.3f') 77 | losses = AverageMeter('Loss', ':.4e') 78 | top1 = AverageMeter('Acc@1', ':6.2f') 79 | top5 = AverageMeter('Acc@5', ':6.2f') 80 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 81 | prefix='Test: ') 82 | 83 | # switch to evaluate mode 84 | for model in model_list: 85 | model.eval() 86 | 87 | with torch.no_grad(): 88 | end = time.time() 89 | for i, (input, target) in enumerate(val_loader): 90 | input_list = [] 91 | for index in range(gpu_nums): 92 | input = input.cuda(index) 93 | input_list.append(input) 94 | target = target.cuda(0) ### send same input and target to each gpu 95 | 96 | # compute output 97 | output_list = torch.Tensor().cuda(0) 98 | for index, model in enumerate(model_list): 99 | temp = model(input_list[index%gpu_nums]) 100 | output = nn.Softmax(dim=1)(temp)[:, 1:] 101 | output_list= torch.cat((output_list, output), 1) 102 | output = torch.mm(output_list, p_indices) 103 | 104 | loss = criterion(output, target) 105 | 106 | # measure accuracy and record loss 107 | acc1, acc5 = 
accuracy(output, target, topk=(1, 5)) 108 | losses.update(loss.item(), input.size(0)) 109 | top1.update(acc1[0], input.size(0)) 110 | top5.update(acc5[0], input.size(0)) 111 | 112 | # measure elapsed time 113 | batch_time.update(time.time() - end) 114 | end = time.time() 115 | 116 | if i % args.print_freq == 0: 117 | progress.print(i) 118 | # TODO: this should also be done with the ProgressMeter 119 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 120 | .format(top1=top1, top5=top5)) 121 | 122 | return top1.avg 123 | 124 | 125 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 126 | torch.save(state, filename) 127 | if is_best: 128 | shutil.copyfile(filename, 'model_best.pth.tar') 129 | 130 | def filename_to_index(filename): 131 | filename = [int(s) for s in filename.split('_') if s.isdigit()] 132 | return filename 133 | 134 | 135 | class AverageMeter(object): 136 | """Computes and stores the average and current value""" 137 | def __init__(self, name, fmt=':f'): 138 | self.name = name 139 | self.fmt = fmt 140 | self.reset() 141 | 142 | def reset(self): 143 | self.val = 0 144 | self.avg = 0 145 | self.sum = 0 146 | self.count = 0 147 | 148 | def update(self, val, n=1): 149 | self.val = val 150 | self.sum += val * n 151 | self.count += n 152 | self.avg = self.sum / self.count 153 | 154 | def __str__(self): 155 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 156 | return fmtstr.format(**self.__dict__) 157 | 158 | 159 | class ProgressMeter(object): 160 | def __init__(self, num_batches, *meters, prefix=""): 161 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 162 | self.meters = meters 163 | self.prefix = prefix 164 | 165 | def print(self, batch): 166 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 167 | entries += [str(meter) for meter in self.meters] 168 | print('\t'.join(entries)) 169 | 170 | def _get_batch_fmtstr(self, num_batches): 171 | num_digits = len(str(num_batches // 1)) 172 | fmt = '{:' + str(num_digits) + 'd}' 173 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 174 | 175 | 176 | def accuracy(output, target, topk=(1,)): 177 | """Computes the accuracy over the k top predictions for the specified values of k""" 178 | with torch.no_grad(): 179 | maxk = max(topk) 180 | batch_size = target.size(0) 181 | 182 | _, pred = output.topk(maxk, 1, True, True) 183 | pred = pred.t() 184 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 185 | 186 | res = [] 187 | for k in topk: 188 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 189 | res.append(correct_k.mul_(100.0 / batch_size)) 190 | return res -------------------------------------------------------------------------------- /imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | import random 8 | 9 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 10 | 11 | 12 | def is_image_file(filename): 13 | """Checks if a file is an image. 
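    Illustrative example (filenames are placeholders):
        >>> is_image_file('dog.JPEG')
        True
        >>> is_image_file('notes.txt')
        False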
14 | 15 | Args: 16 | filename (string): path to a file 17 | 18 | Returns: 19 | bool: True if the filename ends with a known image extension 20 | """ 21 | filename_lower = filename.lower() 22 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 23 | 24 | def find_classes(dir): 25 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 26 | classes.sort() 27 | class_to_idx = {classes[i]: i for i in range(len(classes))} 28 | return classes, class_to_idx 29 | 30 | def make_dataset(dir, class_to_idx, group = None, target_abs_index = None): 31 | images = [] 32 | dir = os.path.expanduser(dir) 33 | for target in sorted(os.listdir(dir)): 34 | # pdb.set_trace() 35 | if target not in class_to_idx: 36 | continue 37 | if int(class_to_idx[target]) not in group: 38 | continue 39 | 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | for root, _, fnames in sorted(os.walk(d)): 44 | for fname in sorted(fnames): 45 | if is_image_file(fname): 46 | path = os.path.join(root, fname) 47 | if target_abs_index != None : 48 | item = (path, target_abs_index) 49 | else: 50 | item = (path, class_to_idx[target]) 51 | images.append(item) 52 | 53 | return images # random.sample(images, 5000) # Used for debug 54 | 55 | def pil_loader(path): 56 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 57 | with open(path, 'rb') as f: 58 | with Image.open(f) as img: 59 | return img.convert('RGB') 60 | 61 | def accimage_loader(path): 62 | import accimage 63 | try: 64 | return accimage.Image(path) 65 | except IOError: 66 | # Potentially a decoding problem, fall back to PIL.Image 67 | return pil_loader(path) 68 | 69 | def default_loader(path): 70 | from torchvision import get_image_backend 71 | if get_image_backend() == 'accimage': 72 | return accimage_loader(path) 73 | else: 74 | return pil_loader(path) 75 | 76 | class ImageFolder(data.Dataset): 77 | """A generic data loader where the images are arranged in this way: :: 78 | 79 | root/dog/xxx.png 80 | root/dog/xxy.png 81 | root/dog/xxz.png 82 | 83 | root/cat/123.png 84 | root/cat/nsdf3.png 85 | root/cat/asd932_.png 86 | 87 | Args: 88 | root (string): Root directory path. 89 | transform (callable, optional): A function/transform that takes in an PIL image 90 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 91 | target_transform (callable, optional): A function/transform that takes in the 92 | target and transforms it. 93 | loader (callable, optional): A function to load an image given its path. 94 | 95 | Attributes: 96 | classes (list): List of the class names. 97 | class_to_idx (dict): Dict with items (class_name, class_index). 
98 | imgs (list): List of (image path, class_index) tuples 99 | """ 100 | 101 | def __init__(self, root, transform=None, target_transform=None, 102 | loader=default_loader, activations = False, group = None, retrain = False): 103 | classes, class_to_idx = find_classes(root) 104 | 105 | # Case: Evaluate but pull from training set 106 | if activations and group: 107 | imgs = make_dataset(root, class_to_idx, group) 108 | elif group is not None: # Case: Train / Evaluate: pos/neg according to group 109 | if retrain: # Subcase: Retraining (Training Set Creation) 110 | imgs = [] 111 | for abs_index, class_index in enumerate(group): 112 | pos_imgs = make_dataset(root, \ 113 | class_to_idx, \ 114 | group=[class_index], \ 115 | target_abs_index=abs_index + 1) 116 | multiplier = max(1, 0) # Multiple used to balance, if wanted 117 | imgs.extend(pos_imgs) 118 | negative_numbers = len(imgs) 119 | negative_indices = [i for i in range(1000) if i not in group] 120 | neg_imgs = make_dataset(root, \ 121 | class_to_idx, \ 122 | group=negative_indices, \ 123 | target_abs_index=0) 124 | neg_imgs = random.sample(neg_imgs, negative_numbers) 125 | imgs.extend(neg_imgs) 126 | print("Num images in training set: {}".format(len(imgs))) 127 | # print("Added {} positive images with target index {}".format(len(pos_imgs)*multiplier, abs_index)) 128 | else: # Subcase: Evaluation (Validation Set Creation) 129 | imgs = [] 130 | for abs_index, class_index in enumerate(group): 131 | pos_imgs = make_dataset(root, \ 132 | class_to_idx, \ 133 | group=[class_index], \ 134 | target_abs_index=abs_index + 1) 135 | imgs.extend(pos_imgs) 136 | negative_numbers = len(imgs) 137 | print("positive images in val loader: ", negative_numbers) 138 | negative_indices = [i for i in range(1000) if i not in group] 139 | neg_imgs = make_dataset(root, \ 140 | class_to_idx, \ 141 | group=negative_indices, \ 142 | target_abs_index=0) 143 | 144 | neg_imgs = random.sample(neg_imgs, negative_numbers) 145 | imgs.extend(neg_imgs) 146 | print("Num images in validation set {}".format(len(imgs))) 147 | else: # Case: Default 148 | imgs = make_dataset(root, class_to_idx, group = [i for i in range(1000)]) 149 | if len(imgs) == 0: 150 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 151 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 152 | 153 | self.root = root 154 | self.imgs = imgs 155 | self.classes = classes 156 | self.class_to_idx = class_to_idx 157 | self.transform = transform 158 | self.target_transform = target_transform 159 | self.loader = loader 160 | 161 | def __getitem__(self, index): 162 | """ 163 | Args: 164 | index (int): Index 165 | 166 | Returns: 167 | tuple: (image, target) where target is class_index of the target class. 
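            When the dataset was built with a class ``group`` (retraining or
            grouped evaluation), targets are remapped: the i-th class of the
            group becomes target i + 1, and all sampled out-of-group images
            share target 0. In the activation-collection case the original
            class_to_idx indices are kept.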
168 | """ 169 | path, target = self.imgs[index] 170 | img = self.loader(path) 171 | if self.transform is not None: 172 | img = self.transform(img) 173 | if self.target_transform is not None: 174 | target = self.target_transform(target) 175 | 176 | return img, target 177 | 178 | def __len__(self): 179 | return len(self.imgs) 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /prune_utils/layer_prune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from fractions import gcd 4 | 5 | def prune_output_linear_layer_(linear_layer, class_indices, use_bce=False): 6 | if use_bce: 7 | assert len(class_indices) == 1 8 | else: 9 | # use 0 as the placeholder of the negative class 10 | class_indices = [0] + list(class_indices) 11 | linear_layer.bias.data = linear_layer.bias.data[class_indices] 12 | linear_layer.weight.data = linear_layer.weight.data[class_indices, :] 13 | if not use_bce: 14 | # reinitialize the negative sample class 15 | linear_layer.weight.data[0].normal_(0, 0.01) 16 | linear_layer.out_features = len(class_indices) 17 | 18 | 19 | def prune_linear_in_features(fc, pruned_indices): 20 | new_fc = nn.Linear(fc.in_features - len(pruned_indices), fc.out_features) 21 | new_fc.bias.data = fc.bias.data.clone() 22 | new_fc.weight.data = prune_tensor(fc.weight.data, 1, pruned_indices) 23 | return new_fc 24 | 25 | 26 | def prune_linear_in_features_(fc, pruned_indices): 27 | fc.in_features -= len(pruned_indices) 28 | fc.weight.data = prune_tensor(fc.weight.data, 1, pruned_indices) 29 | 30 | 31 | def prune_tensor(tensor, dim, pruned_indices): 32 | if tensor.shape[dim] == 1: 33 | return tensor 34 | included_indices = [i for i in range( 35 | tensor.shape[dim]) if i not in pruned_indices] 36 | indexer = [] 37 | for i in range(tensor.ndim): 38 | indexer.append(slice(None) if i != dim else included_indices) 39 | return tensor[indexer] 40 | 41 | def prune_batchnorm2d(bn, pruned_indices): 42 | new_bn = nn.BatchNorm2d(bn.num_features - len(pruned_indices)) 43 | new_bn.weight.data = prune_tensor(bn.weight.data, 0, pruned_indices) 44 | new_bn.bias.data = prune_tensor(bn.bias.data, 0, pruned_indices) 45 | new_bn.running_mean.data = prune_tensor( 46 | bn.running_mean.data, 0, pruned_indices) 47 | new_bn.running_var.data = prune_tensor( 48 | bn.running_var.data, 0, pruned_indices) 49 | return new_bn 50 | 51 | 52 | def prune_batchnorm2d_(bn, pruned_indices): 53 | bn.num_features -= len(pruned_indices) 54 | bn.weight.data = prune_tensor(bn.weight.data, 0, pruned_indices) 55 | bn.bias.data = prune_tensor(bn.bias.data, 0, pruned_indices) 56 | bn.running_mean.data = prune_tensor( 57 | bn.running_mean.data, 0, pruned_indices) 58 | bn.running_var.data = prune_tensor(bn.running_var.data, 0, pruned_indices) 59 | return bn 60 | 61 | 62 | def prune_conv2d_out_channels(conv, pruned_indices): 63 | new_conv = nn.Conv2d(in_channels=conv.in_channels, 64 | out_channels=conv.out_channels - len(pruned_indices), 65 | kernel_size=conv.kernel_size, 66 | stride=conv.stride, 67 | padding=conv.padding, 68 | dilation=conv.dilation, 69 | groups=conv.groups, 70 | bias=conv.bias is not None) 71 | 72 | new_conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices) 73 | 74 | if conv.bias is not None: 75 | new_conv.bias.data = prune_tensor(conv.bias.data, 0, pruned_indices) 76 | return new_conv 77 | 78 | 79 | def prune_conv2d_out_channels_(conv, pruned_indices): 80 | conv.out_channels -= 
len(pruned_indices) 81 | conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices) 82 | if conv.bias is not None: 83 | conv.bias.data = prune_tensor(conv.bias.data, 0, pruned_indices) 84 | return conv 85 | 86 | 87 | def prune_conv2d_in_channels(conv, pruned_indices): 88 | new_conv = nn.Conv2d(in_channels=conv.in_channels - len(pruned_indices), 89 | out_channels=conv.out_channels, 90 | kernel_size=conv.kernel_size, 91 | stride=conv.stride, 92 | padding=conv.padding, 93 | dilation=conv.dilation, 94 | groups=conv.groups, 95 | bias=conv.bias is not None) 96 | 97 | new_conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices) 98 | 99 | if conv.bias is not None: 100 | new_conv.bias.data = conv.bias.data.clone() 101 | return new_conv 102 | 103 | 104 | def prune_conv2d_in_channels_(conv, pruned_indices): 105 | conv.in_channels -= len(pruned_indices) 106 | conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices) 107 | return conv 108 | 109 | 110 | def prune_contiguous_conv2d_(conv_p, conv_n, pruned_indices, bn=None): 111 | prune_conv2d_out_channels_(conv_p, pruned_indices) 112 | prune_conv2d_in_channels_(conv_n, pruned_indices) 113 | if bn is not None: 114 | prune_batchnorm2d_(bn, pruned_indices) 115 | 116 | def prune_contiguous_conv2d_last(conv_p, conv_n, pruned_indices, bn=None): 117 | prune_conv2d_out_channels_(conv_p, pruned_indices) 118 | if bn is not None: 119 | prune_batchnorm2d_(bn, pruned_indices) 120 | 121 | def prune_mobile_conv2d_in_channels(conv, pruned_indices): 122 | conv.in_channels -= len(pruned_indices) 123 | conv.groups = conv.in_channels 124 | 125 | conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices) 126 | return conv 127 | 128 | def prune_mobile_conv2d_out_channels(conv, pruned_indices): 129 | if conv.groups != 1: 130 | pruned_indices = pruned_indices[:(conv.out_channels - conv.in_channels)] 131 | conv.out_channels -= len(pruned_indices) 132 | conv.groups = conv.out_channels 133 | conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices) 134 | return conv 135 | 136 | def prune_contiguous_conv2d_mobile_a(conv_p, conv_n, pruned_indices, bn=None): 137 | prune_conv2d_out_channels_(conv_p, pruned_indices) 138 | prune_mobile_conv2d_in_channels(conv_n, pruned_indices) 139 | if bn is not None: 140 | prune_batchnorm2d_(bn, pruned_indices) 141 | 142 | def prune_contiguous_conv2d_mobile_b(conv_p, conv_n, pruned_indices, bn=None): 143 | prune_mobile_conv2d_out_channels(conv_p, pruned_indices) 144 | prune_conv2d_in_channels_(conv_n, pruned_indices[:(conv_n.in_channels-conv_p.out_channels)]) 145 | if bn is not None: 146 | prune_batchnorm2d_(bn, pruned_indices[:(bn.num_features-conv_p.out_channels)]) 147 | 148 | def prune_mobile_block(conv_1, conv_2, conv_3, pruned_indices_1, pruned_indices_2, bn_1, bn_2): 149 | small_len = min(len(pruned_indices_1), len(pruned_indices_2)) 150 | if len(pruned_indices_2) < len(pruned_indices_1): 151 | pruned_indices_1 = pruned_indices_1[:small_len] 152 | prune_contiguous_conv2d_mobile_a(conv_1, conv_2, pruned_indices_1, bn=bn_1) 153 | prune_contiguous_conv2d_mobile_b(conv_2, conv_3, pruned_indices_2, bn=bn_2) 154 | 155 | def prune_downblock(block, layer_candidates): 156 | conv3 = block.conv3 157 | bn3 = block.bn3 158 | conv4 = block.conv4 159 | bn4 = block.bn4 160 | conv5 = block.conv5 161 | pruned_indices_3_4 = layer_candidates[2] 162 | pruned_indices_4_5 = layer_candidates[3] 163 | small_len = min(len(pruned_indices_3_4), len(pruned_indices_4_5)) 164 | if len(pruned_indices_4_5) < 
len(pruned_indices_3_4): 165 | pruned_indices_3_4 = pruned_indices_4_5[:small_len] 166 | prune_contiguous_conv2d_mobile_a(conv3, conv4, pruned_indices_3_4, bn=bn3) 167 | prune_contiguous_conv2d_mobile_b(conv4, conv5, pruned_indices_4_5, bn=bn4) 168 | 169 | def prune_basicblock(block, layer_candidates): 170 | conv_1 = block.conv1 171 | bn_1 = block.bn1 172 | conv_2 = block.conv2 173 | bn_2 = block.bn2 174 | conv_3 = block.conv3 175 | pruned_indices_1 = layer_candidates[0] 176 | pruned_indices_2 = layer_candidates[1] 177 | small_len = min(len(pruned_indices_1), len(pruned_indices_2)) 178 | if len(pruned_indices_2) < len(pruned_indices_1): 179 | pruned_indices_1 = pruned_indices_1[:small_len] 180 | prune_contiguous_conv2d_mobile_a(conv_1, conv_2, pruned_indices_1, bn=bn_1) 181 | prune_contiguous_conv2d_mobile_b(conv_2, conv_3, pruned_indices_2, bn=bn_2) 182 | 183 | def prune_shuffle_layer(layer, layer_candidates): 184 | for idx, block in enumerate(layer): 185 | if idx == 0: 186 | prune_downblock(block, layer_candidates[:5]) 187 | else: 188 | candidates = layer_candidates[idx*3+2:idx*3+5] 189 | prune_basicblock(block, candidates) -------------------------------------------------------------------------------- /even_k_means.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster.k_means_ import check_random_state, _check_sample_weight, _init_centroids 2 | from sklearn.metrics.pairwise import pairwise_distances_argmin_min, euclidean_distances 3 | from sklearn.utils.extmath import row_norms, squared_norm 4 | import numpy as np 5 | 6 | 7 | def _labels_inertia(X, sample_weight, x_squared_norms, centers, distances, same_cluster_size=False): 8 | """E step of the K-means EM algorithm. 9 | Compute the labels and the inertia of the given samples and centers. 10 | This will compute the distances in-place. 11 | Parameters 12 | ---------- 13 | X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features) 14 | The input samples to assign to the labels. 15 | sample_weight : array-like, shape (n_samples,) 16 | The weights for each observation in X. 17 | x_squared_norms : array, shape (n_samples,) 18 | Precomputed squared euclidean norm of each data point, to speed up 19 | computations. 20 | centers : float array, shape (k, n_features) 21 | The cluster centers. 22 | distances : float array, shape (n_samples,) 23 | Pre-allocated array to be filled in with each sample's distance 24 | to the closest center. 25 | Returns 26 | ------- 27 | labels : int array of shape(n) 28 | The resulting assignment 29 | inertia : float 30 | Sum of squared distances of samples to their closest cluster center. 
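    Notes
    -----
    When ``same_cluster_size`` is True, samples are assigned greedily in order
    of decreasing certainty (the gap between each sample's nearest and farthest
    center), and every cluster is capped at n_samples // n_clusters members, so
    all clusters end up the same size.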
31 | """ 32 | sample_weight = _check_sample_weight(X, sample_weight) 33 | n_samples = X.shape[0] 34 | n_clusters = centers.shape[0] 35 | 36 | # See http://jmonlong.github.io/Hippocamplus/2018/06/09/cluster-same-size/#same-size-k-means-variation 37 | if same_cluster_size: 38 | cluster_size = n_samples // n_clusters 39 | labels = np.zeros(n_samples, dtype=np.int32) 40 | mindist = np.zeros(n_samples, dtype=np.float32) 41 | # count how many samples have been labeled in a cluster 42 | counters = np.zeros(n_clusters, dtype=np.int32) 43 | # dist: (n_samples, n_clusters) 44 | dist = euclidean_distances(X, centers, squared=False) 45 | closeness = dist.min(axis=-1) - dist.max(axis=-1) 46 | ranking = np.argsort(closeness) 47 | for r in ranking: 48 | while True: 49 | label = dist[r].argmin() 50 | if counters[label] < cluster_size: 51 | labels[r] = label 52 | counters[label] += 1 53 | # squared distances are used for inertia in this function 54 | mindist[r] = dist[r, label] ** 2 55 | break 56 | else: 57 | dist[r, label] = np.inf 58 | else: 59 | # Breakup nearest neighbor distance computation into batches to prevent 60 | # memory blowup in the case of a large number of samples and clusters. 61 | # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs. 62 | labels, mindist = pairwise_distances_argmin_min( 63 | X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True}) 64 | 65 | # cython k-means code assumes int32 inputs 66 | labels = labels.astype(np.int32, copy=False) 67 | if n_samples == distances.shape[0]: 68 | # distances will be changed in-place 69 | distances[:] = mindist 70 | inertia = (mindist * sample_weight).sum() 71 | return labels, inertia 72 | 73 | 74 | def _centers_dense(X, sample_weight, labels, n_clusters, distances): 75 | """M step of the K-means EM algorithm 76 | Computation of cluster centers / means. 77 | Parameters 78 | ---------- 79 | X : array-like, shape (n_samples, n_features) 80 | sample_weight : array-like, shape (n_samples,) 81 | The weights for each observation in X. 82 | labels : array of integers, shape (n_samples) 83 | Current label assignment 84 | n_clusters : int 85 | Number of desired clusters 86 | distances : array-like, shape (n_samples) 87 | Distance to closest cluster for each sample. 88 | Returns 89 | ------- 90 | centers : array, shape (n_clusters, n_features) 91 | The resulting centers 92 | """ 93 | # TODO: add support for CSR input 94 | n_samples = X.shape[0] 95 | n_features = X.shape[1] 96 | 97 | dtype = np.float32 98 | centers = np.zeros((n_clusters, n_features), dtype=dtype) 99 | weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) 100 | 101 | for i in range(n_samples): 102 | c = labels[i] 103 | weight_in_cluster[c] += sample_weight[i] 104 | empty_clusters = np.where(weight_in_cluster == 0)[0] 105 | # maybe also relocate small clusters? 
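    # Reseed any cluster that received zero total weight: each empty cluster is
    # moved onto one of the samples farthest from its assigned center before the
    # per-feature sums below are accumulated.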
106 | 107 | if len(empty_clusters): 108 | # find points to reassign empty clusters to 109 | far_from_centers = distances.argsort()[::-1] 110 | 111 | for i, cluster_id in enumerate(empty_clusters): 112 | # XXX two relocated clusters could be close to each other 113 | far_index = far_from_centers[i] 114 | new_center = X[far_index] * sample_weight[far_index] 115 | centers[cluster_id] = new_center 116 | weight_in_cluster[cluster_id] = sample_weight[far_index] 117 | 118 | for i in range(n_samples): 119 | for j in range(n_features): 120 | centers[labels[i], j] += X[i, j] * sample_weight[i] 121 | 122 | centers /= weight_in_cluster[:, np.newaxis] 123 | 124 | return centers 125 | 126 | 127 | def kmeans_lloyd(X, sample_weight, n_clusters, max_iter=300, 128 | init='k-means++', verbose=False, x_squared_norms=None, 129 | random_state=None, tol=1e-4, same_cluster_size=False): 130 | """A single run of k-means, assumes preparation completed prior. 131 | Parameters 132 | ---------- 133 | X : array-like of floats, shape (n_samples, n_features) 134 | The observations to cluster. 135 | n_clusters : int 136 | The number of clusters to form as well as the number of 137 | centroids to generate. 138 | sample_weight : array-like, shape (n_samples,) 139 | The weights for each observation in X. 140 | max_iter : int, optional, default 300 141 | Maximum number of iterations of the k-means algorithm to run. 142 | init : {'k-means++', 'random', or ndarray, or a callable}, optional 143 | Method for initialization, default to 'k-means++': 144 | 'k-means++' : selects initial cluster centers for k-mean 145 | clustering in a smart way to speed up convergence. See section 146 | Notes in k_init for more details. 147 | 'random': choose k observations (rows) at random from data for 148 | the initial centroids. 149 | If an ndarray is passed, it should be of shape (k, p) and gives 150 | the initial centers. 151 | If a callable is passed, it should take arguments X, k and 152 | and a random state and return an initialization. 153 | tol : float, optional 154 | The relative increment in the results before declaring convergence. 155 | verbose : boolean, optional 156 | Verbosity mode 157 | x_squared_norms : array 158 | Precomputed x_squared_norms. 159 | precompute_distances : boolean, default: True 160 | Precompute distances (faster but takes more memory). 161 | random_state : int, RandomState instance or None (default) 162 | Determines random number generation for centroid initialization. Use 163 | an int to make the randomness deterministic. 164 | See :term:`Glossary `. 165 | Returns 166 | ------- 167 | centroid : float ndarray with shape (k, n_features) 168 | Centroids found at the last iteration of k-means. 169 | label : integer ndarray with shape (n_samples,) 170 | label[i] is the code or index of the centroid the 171 | i'th observation is closest to. 172 | inertia : float 173 | The final value of the inertia criterion (sum of squared distances to 174 | the closest centroid for all observations in the training set). 175 | n_iter : int 176 | Number of iterations run. 
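    Notes
    -----
    When ``same_cluster_size`` is True, the number of samples must be divisible
    by ``n_clusters``; the E-step then uses the greedy equal-size assignment in
    ``_labels_inertia``, so every cluster receives exactly len(X) // n_clusters
    members. A minimal illustrative sketch (hypothetical data: 12 samples in 2
    dimensions):

    >>> X = np.random.RandomState(0).rand(12, 2)
    >>> labels, inertia, centers, n_iter = kmeans_lloyd(
    ...     X, None, 3, same_cluster_size=True, random_state=0)
    >>> np.bincount(labels)
    array([4, 4, 4])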
177 | """ 178 | random_state = check_random_state(random_state) 179 | if same_cluster_size: 180 | assert len(X) % n_clusters == 0, "#samples is not divisible by #clusters" 181 | 182 | if verbose: 183 | print("\n==> Starting k-means clustering...\n") 184 | 185 | sample_weight = _check_sample_weight(X, sample_weight) 186 | x_squared_norms = row_norms(X, squared=True) 187 | 188 | best_labels, best_inertia, best_centers = None, None, None 189 | # init 190 | centers = _init_centroids(X, n_clusters, init, random_state=random_state, 191 | x_squared_norms=x_squared_norms) 192 | if verbose: 193 | print("Initialization complete") 194 | 195 | # Allocate memory to store the distances for each sample to its 196 | # closer center for reallocation in case of ties 197 | distances = np.zeros(shape=(X.shape[0],), dtype=X.dtype) 198 | 199 | # iterations 200 | for i in range(max_iter): 201 | centers_old = centers.copy() 202 | # labels assignment is also called the E-step of EM 203 | labels, inertia = \ 204 | _labels_inertia(X, sample_weight, x_squared_norms, 205 | centers, distances=distances, same_cluster_size=same_cluster_size) 206 | 207 | # computation of the means is also called the M-step of EM 208 | centers = _centers_dense( 209 | X, sample_weight, labels, n_clusters, distances) 210 | 211 | if verbose: 212 | print("Iteration %2d, inertia %.3f" % (i, inertia)) 213 | 214 | if best_inertia is None or inertia < best_inertia: 215 | best_labels = labels.copy() 216 | best_centers = centers.copy() 217 | best_inertia = inertia 218 | 219 | center_shift_total = squared_norm(centers_old - centers) 220 | if center_shift_total <= tol: 221 | if verbose: 222 | print("Converged at iteration %d: " 223 | "center shift %e within tolerance %e" 224 | % (i, center_shift_total, tol)) 225 | break 226 | 227 | if center_shift_total > 0: 228 | # rerun E-step in case of non-convergence so that predicted labels 229 | # match cluster centers 230 | best_labels, best_inertia = \ 231 | _labels_inertia(X, sample_weight, x_squared_norms, 232 | best_centers, distances=distances, same_cluster_size=same_cluster_size) 233 | 234 | return best_labels, best_inertia, best_centers, i + 1 235 | -------------------------------------------------------------------------------- /regularize_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def standard(model, arch, num_classes): 6 | if arch == "mobilenetv2": 7 | new_model = MobileNetV2(num_classes=num_classes) 8 | new_model.conv1 = model.conv1 9 | new_model.bn1 = model.bn1 10 | for new_layer, layer in zip(new_model.layers, model.layers): 11 | new_layer.conv1 = layer.conv1 12 | new_layer.bn1 = layer.bn1 13 | new_layer.conv2 = layer.conv2 14 | new_layer.bn2 = layer.bn2 15 | new_layer.conv3 = layer.conv3 16 | new_layer.bn3 = layer.bn3 17 | new_layer.shortcut = layer.shortcut 18 | new_model.conv2 = model.conv2 19 | new_model.bn2 = model.bn2 20 | new_model.linear = model.linear 21 | else: 22 | new_model = ShuffleNetV2(1) 23 | new_model.conv1 = model.conv1 24 | new_model.bn1 = model.bn1 25 | for new_layer, layer in [(new_model.layer1, model.layer1), (new_model.layer2, model.layer2), (new_model.layer3, model.layer3)]: 26 | new_layer[0].conv1 = layer[0].conv1 27 | new_layer[0].bn1 = layer[0].bn1 28 | new_layer[0].conv2 = layer[0].conv2 29 | new_layer[0].bn2 = layer[0].bn2 30 | new_layer[0].conv3 = layer[0].conv3 31 | new_layer[0].bn3 = layer[0].bn3 32 | new_layer[0].conv4 = 
layer[0].conv4 33 | new_layer[0].bn4 = layer[0].bn4 34 | new_layer[0].conv5 = layer[0].conv5 35 | new_layer[0].bn5 = layer[0].bn5 36 | new_layer[0].shuffle = layer[0].shuffle 37 | for i in range(1, len(new_layer)): 38 | new_layer[i].split = layer[i].split 39 | new_layer[i].conv1 = layer[i].conv1 40 | new_layer[i].bn1 = layer[i].bn1 41 | new_layer[i].conv2 = layer[i].conv2 42 | new_layer[i].bn2 = layer[i].bn2 43 | new_layer[i].conv3 = layer[i].conv3 44 | new_layer[i].bn3 = layer[i].bn3 45 | new_layer[i].shuffle = layer[i].shuffle 46 | new_model.conv2 = model.conv2 47 | new_model.bn2 = model.bn2 48 | new_model.linear = model.linear 49 | return new_model 50 | 51 | 52 | 53 | 54 | class Block(nn.Module): 55 | '''expand + depthwise + pointwise''' 56 | def __init__(self, in_planes, out_planes, expansion, stride): 57 | super(Block, self).__init__() 58 | self.stride = stride 59 | 60 | planes = expansion * in_planes 61 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.relu1 = nn.ReLU(inplace=True) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.relu2 = nn.ReLU(inplace=True) 67 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 68 | self.bn3 = nn.BatchNorm2d(out_planes) 69 | 70 | self.shortcut = nn.Sequential() 71 | if stride == 1 and in_planes != out_planes: 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 74 | nn.BatchNorm2d(out_planes), 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.relu1(self.bn1(self.conv1(x))) 79 | out = self.relu2(self.bn2(self.conv2(out))) 80 | out = self.bn3(self.conv3(out)) 81 | out = out + self.shortcut(x) if self.stride==1 else out 82 | return out 83 | 84 | 85 | class MobileNetV2(nn.Module): 86 | # (expansion, out_planes, num_blocks, stride) 87 | cfg = [(1, 16, 1, 1), 88 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 89 | (6, 32, 3, 2), 90 | (6, 64, 4, 2), 91 | (6, 96, 3, 1), 92 | (6, 160, 3, 2), 93 | (6, 320, 1, 1)] 94 | 95 | def __init__(self, num_classes=10): 96 | super(MobileNetV2, self).__init__() 97 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 98 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(32) 100 | self.relu1 = nn.ReLU(inplace=True) 101 | self.layers = self._make_layers(in_planes=32) 102 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 103 | self.bn2 = nn.BatchNorm2d(1280) 104 | self.relu2 = nn.ReLU(inplace=True) 105 | self.linear = nn.Linear(1280, num_classes) 106 | 107 | def _make_layers(self, in_planes): 108 | layers = [] 109 | for expansion, out_planes, num_blocks, stride in self.cfg: 110 | strides = [stride] + [1]*(num_blocks-1) 111 | for stride in strides: 112 | layers.append(Block(in_planes, out_planes, expansion, stride)) 113 | in_planes = out_planes 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, features_only=False): 117 | out = self.relu1(self.bn1(self.conv1(x))) 118 | out = self.layers(out) 119 | out = self.relu2(self.bn2(self.conv2(out))) 120 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 121 | out = F.avg_pool2d(out, 4) 122 | out = out.view(out.size(0), -1) 123 | if not features_only: 124 | out = self.linear(out) 125 | return out 126 | 127 | 128 | class ShuffleBlock(nn.Module): 129 | 
def __init__(self, groups=2): 130 | super(ShuffleBlock, self).__init__() 131 | self.groups = groups 132 | 133 | def forward(self, x): 134 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 135 | N, C, H, W = x.size() 136 | g = self.groups 137 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 138 | 139 | 140 | class SplitBlock(nn.Module): 141 | def __init__(self, ratio): 142 | super(SplitBlock, self).__init__() 143 | self.ratio = ratio 144 | 145 | def forward(self, x): 146 | c = int(x.size(1) * self.ratio) 147 | return x[:, :c, :, :], x[:, c:, :, :] 148 | 149 | 150 | class BasicBlock(nn.Module): 151 | def __init__(self, in_channels, split_ratio=0.5): 152 | super(BasicBlock, self).__init__() 153 | self.split = SplitBlock(split_ratio) 154 | in_channels = int(in_channels * split_ratio) 155 | self.conv1 = nn.Conv2d(in_channels, in_channels, 156 | kernel_size=1, bias=False) 157 | self.bn1 = nn.BatchNorm2d(in_channels) 158 | self.relu1 = nn.ReLU(inplace=True) 159 | self.conv2 = nn.Conv2d(in_channels, in_channels, 160 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 161 | self.bn2 = nn.BatchNorm2d(in_channels) 162 | self.conv3 = nn.Conv2d(in_channels, in_channels, 163 | kernel_size=1, bias=False) 164 | self.bn3 = nn.BatchNorm2d(in_channels) 165 | self.relu3 = nn.ReLU(inplace=True) 166 | self.shuffle = ShuffleBlock() 167 | 168 | def forward(self, x): 169 | x1, x2 = self.split(x) 170 | out = self.relu1(self.bn1(self.conv1(x2))) 171 | out = self.bn2(self.conv2(out)) 172 | out = self.relu3(self.bn3(self.conv3(out))) 173 | out = torch.cat([x1, out], 1) 174 | out = self.shuffle(out) 175 | return out 176 | 177 | 178 | class DownBlock(nn.Module): 179 | def __init__(self, in_channels, out_channels): 180 | super(DownBlock, self).__init__() 181 | mid_channels = out_channels // 2 182 | # left 183 | self.conv1 = nn.Conv2d(in_channels, in_channels, 184 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 185 | self.bn1 = nn.BatchNorm2d(in_channels) 186 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 187 | kernel_size=1, bias=False) 188 | self.bn2 = nn.BatchNorm2d(mid_channels) 189 | self.relu2 = nn.ReLU(inplace=True) 190 | # right 191 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 192 | kernel_size=1, bias=False) 193 | self.bn3 = nn.BatchNorm2d(mid_channels) 194 | self.relu3 = nn.ReLU(inplace=True) 195 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 196 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 197 | self.bn4 = nn.BatchNorm2d(mid_channels) 198 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 199 | kernel_size=1, bias=False) 200 | self.bn5 = nn.BatchNorm2d(mid_channels) 201 | self.relu5 = nn.ReLU(inplace=True) 202 | 203 | self.shuffle = ShuffleBlock() 204 | 205 | def forward(self, x): 206 | # left 207 | out1 = self.bn1(self.conv1(x)) 208 | out1 = self.relu2(self.bn2(self.conv2(out1))) 209 | # right 210 | out2 = self.relu3(self.bn3(self.conv3(x))) 211 | out2 = self.bn4(self.conv4(out2)) 212 | out2 = self.relu5(self.bn5(self.conv5(out2))) 213 | # concat 214 | out = torch.cat([out1, out2], 1) 215 | out = self.shuffle(out) 216 | return out 217 | 218 | 219 | class ShuffleNetV2(nn.Module): 220 | def __init__(self, net_size): 221 | super(ShuffleNetV2, self).__init__() 222 | out_channels = configs[net_size]['out_channels'] 223 | num_blocks = configs[net_size]['num_blocks'] 224 | 225 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 226 | stride=1, padding=1, bias=False) 227 | 
self.bn1 = nn.BatchNorm2d(24) 228 | self.relu1 = nn.ReLU(inplace=True) 229 | self.in_channels = 24 230 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 231 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 232 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 233 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 234 | kernel_size=1, stride=1, padding=0, bias=False) 235 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 236 | self.relu2 = nn.ReLU(inplace=True) 237 | self.linear = nn.Linear(out_channels[3], 10) 238 | 239 | def _make_layer(self, out_channels, num_blocks): 240 | layers = [DownBlock(self.in_channels, out_channels)] 241 | for i in range(num_blocks): 242 | layers.append(BasicBlock(out_channels)) 243 | self.in_channels = out_channels 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x, features_only=False): 247 | out = self.relu1(self.bn1(self.conv1(x))) 248 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 249 | out = self.layer1(out) 250 | out = self.layer2(out) 251 | out = self.layer3(out) 252 | out = self.relu2(self.bn2(self.conv2(out))) 253 | out = F.avg_pool2d(out, 4) 254 | out = out.view(out.size(0), -1) 255 | if not features_only: 256 | out = self.linear(out) 257 | return out 258 | 259 | 260 | configs = { 261 | 0.5: { 262 | 'out_channels': (48, 96, 192, 1024), 263 | 'num_blocks': (3, 7, 3) 264 | }, 265 | 266 | 1: { 267 | 'out_channels': (116, 232, 464, 1024), 268 | 'num_blocks': (3, 7, 3) 269 | }, 270 | 1.5: { 271 | 'out_channels': (176, 352, 704, 1024), 272 | 'num_blocks': (3, 7, 3) 273 | }, 274 | 2: { 275 | 'out_channels': (224, 488, 976, 2048), 276 | 'num_blocks': (3, 7, 3) 277 | } 278 | } -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /prune_utils/prune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torchvision import models 4 | import sys 5 | import numpy as np 6 | from prune_utils.layer_prune import ( 7 | prune_output_linear_layer_, 8 | prune_contiguous_conv2d_, 9 | prune_conv2d_out_channels_, 10 | prune_batchnorm2d_, 11 | prune_linear_in_features_, 12 | prune_contiguous_conv2d_last) 13 | 14 | def replace_layers(model, i, indexes, layers): 15 | if i in indexes: 16 | return layers[indexes.index(i)] 17 | return model[i] 18 | 19 | def prune_vgg16_conv_layer(model, layer_index, filter_index, use_batch_norm=False): 20 | _, conv = list(model.features._modules.items())[layer_index] 21 | next_conv = None 22 | offset = 1 23 | 24 | while layer_index + offset < len(model.features._modules.items()): 25 | res = list(model.features._modules.items())[layer_index+offset] 26 | if isinstance(res[1], torch.nn.modules.conv.Conv2d): 27 | next_name, next_conv = res 28 | break 29 | offset = offset + 1 30 | 31 | new_conv = \ 32 | torch.nn.Conv2d(in_channels = conv.in_channels, \ 33 | out_channels = conv.out_channels - 1, 34 | kernel_size = conv.kernel_size, \ 35 | stride = conv.stride, 36 | padding = conv.padding, 37 | dilation = conv.dilation, 38 | groups = conv.groups, 39 | bias = True)#conv.bias) 40 | 41 | old_weights = conv.weight.data.cpu().numpy() 42 | new_weights = new_conv.weight.data.cpu().numpy() 43 | new_weights[:filter_index, :, :, :] = old_weights[:filter_index, :, :, :] 44 | new_weights[filter_index : , :, :, :] = old_weights[filter_index + 1 :, :, :, :] 45 | new_conv.weight.data = torch.from_numpy(new_weights).cuda() 46 | 47 | if conv.bias is not None: 48 | bias_numpy = conv.bias.data.cpu().numpy() 49 | bias = np.zeros(shape = (bias_numpy.shape[0] - 1), dtype = np.float32) 50 | bias[:filter_index] = bias_numpy[:filter_index] 51 | bias[filter_index : ] = bias_numpy[filter_index + 1 :] 52 | new_conv.bias.data = torch.from_numpy(bias).cuda() 53 | 54 | if use_batch_norm: 55 | _, bn = list(model.features._modules.items())[layer_index + 1] 56 | new_bn = torch.nn.BatchNorm2d(conv.out_channels - 1) 57 | 58 | old_weights = bn.weight.data.cpu().numpy() 59 | new_weights = new_bn.weight.data.cpu().numpy() 60 | new_weights[:filter_index] = old_weights[:filter_index] 61 | new_weights[filter_index:] = old_weights[filter_index+1:] 62 | 63 | 64 | old_bias = bn.bias.data.cpu().numpy() 65 | new_bias = new_bn.bias.data.cpu().numpy() 66 | new_bias[:filter_index] = old_bias[:filter_index] 67 | new_bias[filter_index:] = old_bias[filter_index+1:] 68 | 69 | 70 | 71 | old_running_mean = bn.running_mean.data.cpu().numpy() 72 | new_running_mean = new_bn.running_mean.data.cpu().numpy() 73 | new_running_mean[:filter_index] = old_running_mean[:filter_index] 74 | new_running_mean[filter_index:] = old_running_mean[filter_index+1:] 75 | 76 | 77 | old_running_var = bn.running_var.data.cpu().numpy() 78 | new_running_var = new_bn.running_var.data.cpu().numpy() 79 | new_running_var[:filter_index] = old_running_var[:filter_index] 80 | new_running_var[filter_index:] = old_running_var[filter_index+1:] 81 | 82 | new_bn.weight.data = torch.from_numpy(new_weights).cuda() 83 | new_bn.bias.data = torch.from_numpy(new_bias).cuda() 84 | new_bn.running_mean.data = torch.from_numpy(new_running_mean).cuda() 85 | new_bn.running_var.data = torch.from_numpy(new_running_var).cuda() 86 | 87 | 
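    # If a later conv layer consumes this layer's output, it is rebuilt below with
    # one fewer input channel so its weights stay aligned with the pruned filters;
    # otherwise (pruning the last conv layer) the first linear layer of the
    # classifier is resized instead.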
88 | if not next_conv is None: 89 | next_new_conv = \ 90 | torch.nn.Conv2d(in_channels = next_conv.in_channels - 1,\ 91 | out_channels = next_conv.out_channels, \ 92 | kernel_size = next_conv.kernel_size, \ 93 | stride = next_conv.stride, 94 | padding = next_conv.padding, 95 | dilation = next_conv.dilation, 96 | groups = next_conv.groups, 97 | bias = True)#next_conv.bias) 98 | 99 | old_weights = next_conv.weight.data.cpu().numpy() 100 | new_weights = next_new_conv.weight.data.cpu().numpy() 101 | 102 | new_weights[:, : filter_index, :, :] = old_weights[:, : filter_index, :, :] 103 | new_weights[:, filter_index : , :, :] = old_weights[:, filter_index + 1 :, :, :] 104 | next_new_conv.weight.data = torch.from_numpy(new_weights).cuda() 105 | 106 | if next_conv.bias is not None: 107 | next_new_conv.bias.data = torch.from_numpy(next_conv.bias.data.cpu().numpy().copy()).cuda() 108 | 109 | if not next_conv is None: 110 | features = torch.nn.Sequential( 111 | *(replace_layers(model.features, i, [layer_index, layer_index + 1, layer_index+offset], \ 112 | [new_conv, new_bn, next_new_conv]) for i, _ in enumerate(model.features))) 113 | del model.features 114 | del conv 115 | 116 | model.features = features 117 | else: 118 | #Prunning the last conv layer. This affects the first linear layer of the classifier. 119 | model.features = torch.nn.Sequential( 120 | *(replace_layers(model.features, i, [layer_index, layer_index+1], \ 121 | [new_conv, new_bn]) for i, _ in enumerate(model.features))) 122 | layer_index = 0 123 | old_linear_layer = None 124 | if len(model.classifier._modules): 125 | for _, module in model.classifier._modules.items(): 126 | if isinstance(module, torch.nn.Linear): 127 | old_linear_layer = module 128 | break 129 | layer_index = layer_index + 1 130 | else: 131 | old_linear_layer = model.classifier 132 | 133 | if old_linear_layer is None: 134 | raise BaseException("No linear layer found in classifier") 135 | params_per_input_channel = old_linear_layer.in_features / conv.out_channels 136 | 137 | new_linear_layer = \ 138 | torch.nn.Linear(int(old_linear_layer.in_features - params_per_input_channel), 139 | int(old_linear_layer.out_features)) 140 | 141 | old_weights = old_linear_layer.weight.data.cpu().numpy() 142 | new_weights = new_linear_layer.weight.data.cpu().numpy() 143 | 144 | new_weights[:, : int(filter_index * params_per_input_channel)] = \ 145 | old_weights[:, : int(filter_index * params_per_input_channel)] 146 | new_weights[:, int(filter_index * params_per_input_channel) :] = \ 147 | old_weights[:, int((filter_index + 1) * params_per_input_channel) :] 148 | 149 | new_linear_layer.bias.data = torch.from_numpy(old_linear_layer.bias.data.cpu().numpy()).cuda() 150 | 151 | new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda() 152 | 153 | if len(model.classifier._modules): 154 | classifier = torch.nn.Sequential( 155 | *(replace_layers(model.classifier, i, [layer_index], \ 156 | [new_linear_layer]) for i, _ in enumerate(model.classifier))) 157 | else: 158 | classifier = torch.nn.Sequential(new_linear_layer) 159 | 160 | del model.classifier 161 | del next_conv 162 | del conv 163 | model.classifier = classifier 164 | 165 | return model 166 | 167 | def prune_last_fc_layers(model, class_indices, filter_indices = None, use_bce=False): 168 | layer_index = 0 169 | old_linear_layer = None 170 | counter = 0 171 | out_dim_prev = None 172 | filter_idx_mask = None 173 | linear_count = 0 174 | 175 | for idx, module in enumerate(model.classifier.modules()): 176 | if linear_count >= 
len(filter_indices): 177 | break 178 | 179 | if isinstance(module, torch.nn.Linear): 180 | old_linear_layer = module 181 | old_weights = old_linear_layer.weight.data 182 | # The new in dimension is the out dimensio of the last layer pruned, 183 | # if counter == 1, then the last layer is the the last conv layer, 184 | # otherwise, it is the previous linear layer 185 | in_dim = int(old_linear_layer.in_features) if counter == 1 else out_dim 186 | prev_filter_idx_mask = filter_idx_mask 187 | # The channel mask has the number of channels as the out dim - pruning candidates 188 | filter_idx_mask = [i for i in range(old_weights.shape[0]) if i not in filter_indices[linear_count]] 189 | out_dim = len(filter_idx_mask) 190 | 191 | new_linear_layer = \ 192 | torch.nn.Linear(in_dim, out_dim) 193 | 194 | # The new bias has the shape of the out dimension 195 | new_linear_layer.bias.data = old_linear_layer.bias.data[filter_idx_mask] 196 | # The weight format is out_dim x in_dim, so we first selectively index the out dim, using the channel mask 197 | # Then selectively index the in dim, by the previous layer's filter mask 198 | # If the last layer was the last conv layer, prev_filter_idx_mask is None, in which case it indexes everything (no mask) 199 | new_linear_layer.weight.data = old_weights[filter_idx_mask, :][:, prev_filter_idx_mask].squeeze() 200 | 201 | # Set the new linear layer with the model 202 | model.classifier[idx - 1] = new_linear_layer 203 | 204 | linear_count += 1 205 | counter += 1 206 | 207 | 208 | counter = 0 209 | layer_index = 0 210 | if len(model.classifier._modules): 211 | for _, module in model.classifier._modules.items(): 212 | if isinstance(module, torch.nn.Linear): 213 | old_linear_layer = module 214 | layer_index = counter 215 | counter += 1 216 | else: 217 | old_linear_layer = model.classifier 218 | 219 | if old_linear_layer is None: 220 | raise BaseException("No linear layer found in classifier") 221 | 222 | # If using bce, we don't need a negative out 223 | bce_offset = 0 if use_bce else 1 224 | # Create a new linear layer, in dimension is the out dimension of previous layer 225 | # out dimension is the number of classes with the pruned model 226 | new_linear_layer = \ 227 | torch.nn.Linear(out_dim, 228 | len(class_indices) + bce_offset) 229 | 230 | old_weights = old_linear_layer.weight.data.cpu().numpy() 231 | new_weights = new_linear_layer.weight.data.cpu().numpy() 232 | 233 | new_weights[bce_offset:, :] = old_weights[class_indices][:,filter_idx_mask] 234 | 235 | 236 | 237 | new_linear_layer.bias.data[bce_offset:] = torch.from_numpy(np.asarray(old_linear_layer.bias.data.cpu().numpy()[class_indices])).cuda() 238 | 239 | new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda() 240 | 241 | if len(model.classifier._modules): 242 | classifier = torch.nn.Sequential( 243 | *(replace_layers(model.classifier, i, [layer_index], \ 244 | [new_linear_layer]) for i, _ in enumerate(model.classifier))) 245 | else: 246 | classifier = torch.nn.Sequential(new_linear_layer) 247 | 248 | del model.classifier 249 | model.classifier = classifier 250 | 251 | return model 252 | 253 | def prune_resnet50(model, candidates, group_indices): 254 | layers = list(model.children()) 255 | # layer[0] : Conv2d 256 | # layer[1] : BatchNorm2e 257 | # layer[2] : ReLU 258 | layer_index = 1 259 | for stage in (layers[4], layers[5], layers[6], layers[7]): 260 | for index, block in enumerate(stage.children()): 261 | assert isinstance(block, models.resnet.Bottleneck), "only support bottleneck block" 262 | 
children_dict = dict(block.named_children()) 263 | conv1 = children_dict['conv1'] 264 | conv2 = children_dict['conv2'] 265 | conv3 = children_dict['conv3'] 266 | prune_contiguous_conv2d_( 267 | conv1, conv2, candidates[layer_index], bn=children_dict['bn1']) 268 | layer_index += 1 269 | prune_contiguous_conv2d_( 270 | conv2, conv3, candidates[layer_index], bn=children_dict['bn2']) 271 | layer_index += 2 272 | # because we are using the output of the ReLU, the output of 273 | # the downsample is merged before ReLU, so we do not need to 274 | # increase the layer index 275 | prune_output_linear_layer_(model.fc, group_indices, use_bce=False) 276 | 277 | if __name__ == '__main__': 278 | import time 279 | model = models.vgg16(pretrained=True) 280 | model.train() 281 | t0 = time.time() 282 | model = prune_vgg16_conv_layer(model, 28, 10) 283 | print("The pruning took", time.time() - t0) 284 | -------------------------------------------------------------------------------- /prune_and_get_model.py: -------------------------------------------------------------------------------- 1 | import re 2 | import glob 3 | import models.cifar as models 4 | import os 5 | import sys 6 | import argparse 7 | import pathlib 8 | import pickle 9 | import copy 10 | import numpy as np 11 | import re 12 | import torch 13 | from torch import nn 14 | import load_model 15 | import torch.multiprocessing as mp 16 | 17 | from regularize_model import standard 18 | from prune_utils.prune import prune_vgg16_conv_layer, prune_last_fc_layers, prune_resnet50 19 | from prune_utils.layer_prune import ( 20 | prune_output_linear_layer_, 21 | prune_contiguous_conv2d_, 22 | prune_conv2d_out_channels_, 23 | prune_batchnorm2d_, 24 | prune_linear_in_features_, 25 | prune_mobile_block, 26 | prune_shuffle_layer) 27 | from models.cifar.resnet import Bottleneck 28 | import torchvision.models as imagenet_models 29 | 30 | parser = argparse.ArgumentParser(description='VGG with mask layer on cifar10') 31 | parser.add_argument('-d', '--dataset', required=True, type=str) 32 | parser.add_argument('-c', '--prune-candidates', default="./prune_candidate_logs/", 33 | type=str, help='Directory which stores the prune candidates for each model') 34 | parser.add_argument('-a', '--arch', default='vgg19_bn', 35 | type=str, help='The architecture of the trained model') 36 | parser.add_argument('-r', '--resume', default='', type=str, 37 | help='The path to the checkpoints') 38 | parser.add_argument('-s', '--save', default='./pruned_models', 39 | type=str, help='The path to store the pruned models') 40 | parser.add_argument('--bce', default=False, type=bool, 41 | help='Prune according to binary cross entropy loss, i.e. 
no additional negative output for classifer') 42 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 43 | help='use pre-trained model') 44 | args = parser.parse_args() 45 | 46 | 47 | def prune_vgg(model, pruned_candidates, group_indices): 48 | features = model.features 49 | conv_indices = [i for i, layer in enumerate(features) if isinstance(layer, nn.Conv2d)] 50 | conv_bn_indices = [i for i, layer in enumerate(features) if isinstance(layer, (nn.Conv2d, nn.BatchNorm2d))] 51 | assert len(conv_indices) == len(pruned_candidates) 52 | assert len(conv_indices) * 2 == len(conv_bn_indices) 53 | 54 | for i, conv_index in enumerate(conv_indices[:-1]): 55 | next_conv = None 56 | for j in range(conv_index + 1, len(features)): 57 | l = features[j] 58 | if isinstance(l, nn.Conv2d): 59 | next_conv = l 60 | break 61 | if next_conv is None: 62 | break 63 | bn = model.features[conv_index + 1] 64 | assert isinstance(bn, nn.BatchNorm2d) 65 | prune_contiguous_conv2d_( 66 | features[conv_index], 67 | next_conv, 68 | pruned_candidates[i], 69 | bn=bn) 70 | 71 | # Prunning the last conv layer. This affects the first linear layer of the classifier. 72 | last_conv = features[conv_indices[-1]] 73 | classifier = model.classifier 74 | assert classifier.in_features % last_conv.out_channels == 0 75 | params_per_input_channel = classifier.in_features // last_conv.out_channels 76 | 77 | pruned_indices = pruned_candidates[-1] 78 | prune_conv2d_out_channels_(last_conv, pruned_indices) 79 | prune_batchnorm2d_(features[conv_bn_indices[-1]], pruned_indices) 80 | 81 | linear_pruned_indices = [] 82 | for i in pruned_indices: 83 | linear_pruned_indices += list(range(i * params_per_input_channel, (i + 1) * params_per_input_channel)) 84 | 85 | prune_linear_in_features_(classifier, linear_pruned_indices) 86 | # prune the output of the classifier 87 | prune_output_linear_layer_(classifier, group_indices, use_bce=args.bce) 88 | 89 | 90 | def prune_resnet(model, candidates, group_indices): 91 | layers = list(model.children()) 92 | # layer[0] : Conv2d 93 | # layer[1] : BatchNorm2e 94 | # layer[2] : ReLU 95 | layer_index = 1 96 | for stage in (layers[3], layers[4], layers[5]): 97 | for block in stage.children(): 98 | assert isinstance(block, Bottleneck), "only support bottleneck block" 99 | children_dict = dict(block.named_children()) 100 | conv1 = children_dict['conv1'] 101 | conv2 = children_dict['conv2'] 102 | conv3 = children_dict['conv3'] 103 | 104 | prune_contiguous_conv2d_( 105 | conv1, conv2, candidates[layer_index], bn=children_dict['bn1']) 106 | layer_index += 1 107 | prune_contiguous_conv2d_( 108 | conv2, conv3, candidates[layer_index], bn=children_dict['bn2']) 109 | layer_index += 2 110 | # because we are using the output of the ReLU, the output of 111 | # the downsample is merged before ReLU, so we do not need to 112 | # increase the layer index 113 | prune_output_linear_layer_(model.fc, group_indices, use_bce=args.bce) 114 | 115 | def prune_mobilenetv2(model, candidates, group_indices): 116 | layers = list(model.layers) 117 | layer_index = 1 118 | for block in layers: 119 | conv1 = block.conv1 120 | bn1 = block.bn1 121 | conv2 = block.conv2 122 | bn2 = block.bn2 123 | conv3 = block.conv3 124 | prune_1 = candidates[layer_index] 125 | prune_2 = candidates[layer_index+1] 126 | prune_mobile_block(conv1, conv2, conv3, prune_1, prune_2, bn1, bn2) 127 | layer_index += 2 128 | prune_output_linear_layer_(model.linear, group_indices, use_bce=args.bce) 129 | 130 | def prune_shufflenetv2(model, candidates, 
group_indices): 131 | layer1, layer2, layer3 = model.layer1, model.layer2, model.layer3 132 | layer1_candidates = candidates[1:15] 133 | layer2_candidates = candidates[15:41] 134 | layer3_candidates = candidates[41:55] 135 | prune_shuffle_layer(layer1, layer1_candidates) 136 | prune_shuffle_layer(layer2, layer2_candidates) 137 | prune_shuffle_layer(layer3, layer3_candidates) 138 | prune_output_linear_layer_(model.linear, group_indices, use_bce=args.bce) 139 | 140 | def filename_to_index(filename): 141 | filename = filename[6+len(args.prune_candidates):] 142 | return int(filename[:filename.index('_')]) 143 | 144 | def update_list(l): 145 | for i in range(len(l)): 146 | l[i] -= 1 147 | 148 | def prune_cifar_worker(proc_ind, i, new_model, candidates, group_indices, arch, model_save_directory): 149 | num_gpus = torch.cuda.device_count() 150 | new_model.cuda(i % num_gpus) 151 | group_indices = group_indices.tolist() 152 | if args.arch.startswith('vgg'): 153 | prune_vgg(new_model, candidates, group_indices) 154 | elif args.arch.startswith('resnet'): 155 | prune_resnet(new_model, candidates, group_indices) 156 | elif args.arch.startswith('mobile'): 157 | prune_mobilenetv2(new_model, candidates, group_indices) 158 | elif args.arch.startswith('shuffle'): 159 | prune_shufflenetv2(new_model, candidates, group_indices) 160 | else: 161 | raise NotImplementedError 162 | 163 | # save the pruned model 164 | pruned_model_name = f"{arch}_{i}_pruned_model.pth" 165 | torch.save(new_model, os.path.join( 166 | model_save_directory, pruned_model_name)) 167 | print('Pruned model saved at', model_save_directory) 168 | 169 | def prune_imagenet_worker(proc_ind, model, candidates, group_indices, group_id, model_save_directory): 170 | num_gpus = torch.cuda.device_count() 171 | torch.cuda.set_device(group_id % num_gpus) 172 | model.cuda(group_id % num_gpus) 173 | if args.arch != "resnet50": 174 | conv_indices = [idx for idx, (n, p) in enumerate(model.features._modules.items()) if isinstance(p, nn.modules.conv.Conv2d)] 175 | offset = 0 176 | for layer_index, filter_list in zip(conv_indices, candidates): 177 | offset += 1 178 | filters_to_remove = list(filter_list) 179 | sorted(filters_to_remove) 180 | 181 | while len(filters_to_remove): 182 | filter_index = filters_to_remove.pop(0) 183 | model = prune_vgg16_conv_layer(model, layer_index, filter_index, use_batch_norm=True) 184 | update_list(filters_to_remove) 185 | 186 | # save the pruned model 187 | # The input dimension of the first fc layer is pruned from above 188 | model = prune_last_fc_layers(model, \ 189 | group_indices, \ 190 | filter_indices = candidates[offset:], \ 191 | use_bce = args.bce) 192 | else: 193 | prune_resnet50(model, candidates, group_indices) 194 | 195 | pruned_model_name = args.arch + '_{}'.format(group_id) + '_pruned_model.pth' 196 | print('Grouped mode %s Total params: %.2fM' % (group_id ,sum(p.numel() for p in model.parameters())/1000000.0)) 197 | torch.save(model, os.path.join(model_save_directory, pruned_model_name)) 198 | print('Pruned model saved at', model_save_directory) 199 | 200 | def main(): 201 | use_cuda = torch.cuda.is_available() 202 | # load groups 203 | file_names = [f for f in glob.glob(args.prune_candidates + "group_*.npy", recursive=False)] 204 | file_names.sort(key=filename_to_index) 205 | groups = np.load(open(args.prune_candidates + "grouping_config.npy", "rb")) 206 | 207 | # create pruned model save path 208 | model_save_directory = os.path.join(args.save, args.arch) 209 | 
pathlib.Path(model_save_directory).mkdir(parents=True, exist_ok=True) 210 | np.save(open(os.path.join(args.save, "grouping_config.npy"), "wb"), groups) 211 | if len(groups[0]) == 1: 212 | args.bce = True 213 | print(f'==> Preparing dataset {args.dataset}') 214 | if args.dataset in ['cifar10', 'cifar100']: 215 | if args.dataset == 'cifar10': 216 | num_classes = 10 217 | elif args.dataset == 'cifar100': 218 | num_classes = 100 219 | 220 | processes = [] 221 | # for each class 222 | for i, (group_indices, file_name) in enumerate(zip(groups, file_names)): 223 | # load pruning candidates 224 | with open(file_name, 'rb') as f: 225 | candidates = pickle.load(f) 226 | # load checkpoints 227 | model = load_model.load_pretrain_model( 228 | args.arch, args.dataset, args.resume, num_classes, use_cuda) 229 | new_model = copy.deepcopy(model) 230 | if args.arch in ["mobilenetv2", "shufflenetv2"]: 231 | new_model = standard(new_model, args.arch, num_classes) 232 | p = mp.spawn(prune_cifar_worker, args=(i, new_model, candidates, group_indices, args.arch, model_save_directory), join=False) 233 | processes.append(p) 234 | for p in processes: 235 | p.join() 236 | 237 | 238 | elif args.dataset == 'imagenet': 239 | num_classes = len(groups) 240 | processes = [] 241 | # for each class 242 | for group_id, file_name in enumerate(file_names): 243 | print('Pruning classes {} from candidates in {}'.format(group_id, file_name)) 244 | group_indices = groups[group_id] 245 | # load pruning candidates 246 | print(file_name) 247 | candidates = np.load(open(file_name, 'rb'), allow_pickle=True).tolist() 248 | 249 | num_gpus = torch.cuda.device_count() 250 | # load checkpoints 251 | if args.pretrained: 252 | print("=> using pre-trained model '{}'".format(args.arch)) 253 | model = imagenet_models.__dict__[args.arch](pretrained=True) 254 | # model = torch.nn.DataParallel(model).cuda() #TODO use DataParallel 255 | model = model.cuda(group_id % num_gpus) 256 | else: 257 | checkpoint = torch.load(args.resume) 258 | model = imagenet_models.__dict__[args.arch](num_classes=num_classes) 259 | # model = torch.nn.DataParallel(model).cuda() #TODO use DataParallel 260 | model = model.cuda(group_id % num_gpus) 261 | model.load_state_dict(checkpoint['state_dict']) 262 | 263 | # join existing num_gpus processes, to make sure only num_gpus processes are running at a time 264 | if group_id % num_gpus == 0: 265 | for p in processes: 266 | p.join() 267 | processes = [] 268 | 269 | # model = model.module #TODO use DataParallel 270 | p = mp.spawn(prune_imagenet_worker, args=(model, candidates, group_indices, group_id, model_save_directory), join=False) 271 | processes.append(p) 272 | 273 | for p in processes: 274 | p.join() 275 | else: 276 | raise NotImplementedError 277 | 278 | if __name__ == '__main__': 279 | main() 280 | -------------------------------------------------------------------------------- /imagenet_activations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import sys 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import imagenet_dataset as datasets 20 | import torchvision.models as models 21 | 
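A note on the scoring used below: apoz_policy_imagenet, imported a few lines further down, provides the apoz_scoring and avg_scoring helpers that this script applies to ReLU feature maps. That file is not reproduced in this listing; as a rough sketch under that assumption, an APoZ (Average Percentage of Zeros) score is simply the percentage of zero activations per channel of an [N, C, H, W] feature map:

    def apoz_scoring(feature_map):
        # percentage of zeros per output channel, averaged over batch and spatial dims
        return (feature_map == 0).float().mean(dim=(0, 2, 3)) * 100

parse_activation below sums these scores over batches and generate_candidates divides by the batch count; channels whose averaged APoZ exceeds the 90% threshold become pruning candidates.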
22 | import numpy as np 23 | from apoz_policy_imagenet import * 24 | import pdb 25 | 26 | model_names = sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name])) 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('--data', metavar='DIR', 32 | help='path to imagenet dataset') 33 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: resnet18)') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=64, type=int, 45 | metavar='N', 46 | help='mini-batch size (default: 256), this is the total ' 47 | 'batch size of all GPUs on the current node when ' 48 | 'using Data Parallel or Distributed Data Parallel') 49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 50 | metavar='LR', help='initial learning rate', dest='lr') 51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 52 | help='momentum') 53 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 54 | metavar='W', help='weight decay (default: 1e-4)', 55 | dest='weight_decay') 56 | parser.add_argument('-p', '--print-freq', default=1, type=int, 57 | metavar='N', help='print frequency (default: 10)') 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 59 | help='path to latest checkpoint (default: none)') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 63 | help='use pre-trained model') 64 | parser.add_argument('--world-size', default=-1, type=int, 65 | help='number of nodes for distributed training') 66 | parser.add_argument('--rank', default=-1, type=int, 67 | help='node rank for distributed training') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='nccl', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--seed', default=None, type=int, 73 | help='seed for initializing training. ') 74 | parser.add_argument('--gpu', default=None, type=int, 75 | help='GPU id to use.') 76 | parser.add_argument('--multiprocessing-distributed', default=False, action='store_true', 77 | help='Use multi-processing distributed training to launch ' 78 | 'N processes per node, which has N GPUs. 
This is the '
79 | 'fastest way to use PyTorch for either single node or '
80 | 'multi node data parallel training')
81 |
82 | parser.add_argument('--group', type=int, nargs='+', default=[],
83 | help='Generate activations based on these class indices')
84 | parser.add_argument('--name', type=str, default='Name', help='Set the name id of the group')
85 |
86 |
87 | global num_layers
88 | num_layers = sys.maxsize
89 | global layer_idx
90 | layer_idx = 0
91 | num_batches = 0
92 | best_acc1 = 0
93 |
94 | def main():
95 | args = parser.parse_args()
96 | if args.seed is not None:
97 | random.seed(args.seed)
98 | torch.manual_seed(args.seed)
99 | cudnn.deterministic = True
100 | warnings.warn('You have chosen to seed training. '
101 | 'This will turn on the CUDNN deterministic setting, '
102 | 'which can slow down your training considerably! '
103 | 'You may see unexpected behavior when restarting '
104 | 'from checkpoints.')
105 |
106 | if args.gpu is not None:
107 | warnings.warn('You have chosen a specific GPU. This will completely '
108 | 'disable data parallelism.')
109 |
110 | main_worker(args.gpu, args)
111 |
112 |
113 | def main_worker(gpu, args):
114 | global best_acc1
115 | global num_layers
116 | global apoz_scores_by_layer
117 | global avg_scores_by_layer
118 | args.gpu = gpu
119 |
120 | if args.gpu is not None:
121 | print("Use GPU: {} for training".format(args.gpu))
122 |
123 | # create model
124 | if args.pretrained:
125 | print("=> using pre-trained model '{}'".format(args.arch))
126 | model = models.__dict__[args.arch](pretrained=True)
127 | else:
128 | print("=> creating model '{}'".format(args.arch))
129 | model = models.__dict__[args.arch]()
130 |
131 | if args.gpu is not None:
132 | print("checkpoint 1...")
133 | torch.cuda.set_device(args.gpu)
134 | model = model.cuda(args.gpu)
135 |
136 | # define loss function (criterion) and optimizer
137 | criterion = nn.CrossEntropyLoss()
138 |
139 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
140 | momentum=args.momentum,
141 | weight_decay=args.weight_decay)
142 |
143 | # optionally resume from a checkpoint
144 | if args.resume:
145 | if os.path.isfile(args.resume):
146 | print("=> loading checkpoint '{}'".format(args.resume))
147 | checkpoint = torch.load(args.resume)
148 | args.start_epoch = checkpoint['epoch']
149 | best_acc1 = checkpoint['best_acc1']
150 | optimizer.load_state_dict(checkpoint['optimizer'])  # only optimizer state is restored here; model weights come from --pretrained or the architecture default
151 | print("=> loaded checkpoint '{}' (epoch {})"
152 | .format(args.resume, checkpoint['epoch']))
153 | else:
154 | print("=> no checkpoint found at '{}'".format(args.resume))
155 |
156 | print("checkpoint 2...")
157 | apoz_scores_by_layer = []
158 | avg_scores_by_layer = []
159 | model = model.cuda(args.gpu)
160 |
161 | # Data loading code
162 | traindir = os.path.join(args.data, 'train')
163 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
164 | std=[0.229, 0.224, 0.225])
165 |
166 | train_dataset = datasets.ImageFolder(
167 | traindir,
168 | transforms.Compose([
169 | transforms.RandomResizedCrop(224),
170 | transforms.RandomHorizontalFlip(),
171 | transforms.ToTensor(),
172 | normalize,
173 | ]), activations=True, group=args.group)
174 |
175 | val_loader = torch.utils.data.DataLoader(
176 | train_dataset, batch_size=args.batch_size, shuffle=False,
177 | num_workers=args.workers, pin_memory=True, sampler=None)  # the activation statistics are computed from the training images
178 |
179 | print("checkpoint 3...")
180 | if args.evaluate:
181 | validate(val_loader, model, criterion, args)
182 | generate_candidates(args.name)
183 | return
184 |
185 |
186 | def 
validate(val_loader, model, criterion, args):
187 | global layer_idx
188 | global num_batches
189 | global num_layers
190 | batch_time = AverageMeter('Time', ':6.3f')
191 | losses = AverageMeter('Loss', ':.4e')
192 | top1 = AverageMeter('Acc@1', ':6.2f')
193 | top5 = AverageMeter('Acc@5', ':6.2f')
194 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
195 | prefix='Test: ')
196 |
197 | # register the activation hooks and switch to evaluate mode
198 | model.apply(apply_hook)
199 | model.eval()
200 |
201 | with torch.no_grad():
202 | end = time.time()
203 | for i, (input, target) in enumerate(val_loader):
204 | num_batches += 1
205 | layer_idx = 0
206 | if args.gpu is not None:
207 | input = input.cuda(args.gpu)
208 | target = target.cuda(args.gpu)
209 | # compute output
210 | output = model(input)
211 | loss = criterion(output, target)
212 |
213 | # measure accuracy and record loss
214 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
215 | losses.update(loss.item(), input.size(0))
216 | top1.update(acc1[0], input.size(0))
217 | top5.update(acc5[0], input.size(0))
218 |
219 | # measure elapsed time
220 | batch_time.update(time.time() - end)
221 | end = time.time()
222 |
223 | if i % args.print_freq == 0:
224 | progress.print(i)
225 |
226 | num_layers = layer_idx
227 |
228 | # TODO: this should also be done with the ProgressMeter
229 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
230 | .format(top1=top1, top5=top5))
231 |
232 | return top1.avg
233 |
234 |
235 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
236 | torch.save(state, filename)
237 | if is_best:
238 | shutil.copyfile(filename, 'model_best.pth.tar')
239 |
240 | def parse_activation(relu_reference, feature_map):
241 | global layer_idx
242 | global num_layers
243 | apoz_score = apoz_scoring(feature_map)
244 | avg_score = avg_scoring(feature_map)
245 |
246 | if len(apoz_scores_by_layer) < num_layers:  # first pass: start one running score per ReLU layer
247 | apoz_scores_by_layer.append(apoz_score)
248 | avg_scores_by_layer.append(avg_score)
249 | else:  # afterwards: accumulate onto the running totals
250 | apoz_scores_by_layer[layer_idx] = torch.add(apoz_scores_by_layer[layer_idx], apoz_score)
251 | avg_scores_by_layer[layer_idx] = torch.add(avg_scores_by_layer[layer_idx], avg_score)
252 |
253 | layer_idx += 1
254 |
255 |
256 | """
257 | Forward hook: record APoZ and average-activation scores for ReLU layers only.
258 | """
259 | def hook(self, input, output):
260 | if self.__class__.__name__ == 'ReLU':
261 | parse_activation(self, output.data)
262 |
263 | def apply_hook(m):
264 | m.register_forward_hook(hook)  # registered on every module; the hook itself filters for ReLU
265 |
266 | def generate_candidates(name):
267 | global num_batches
268 | global apoz_scores_by_layer
269 | global avg_scores_by_layer
270 | global num_layers
271 | group_id_string = name
272 | apoz_thresholds = [90] * num_layers  # prune channels whose average APoZ exceeds 90%
273 | avg_thresholds = [sys.maxsize] * num_layers  # sys.maxsize to disable avg
274 | candidates_by_layer = []
275 |
276 | for layer_idx, (apoz_scores, avg_scores) in enumerate(zip(apoz_scores_by_layer, avg_scores_by_layer)):
277 | apoz_scores *= 1 / float(num_batches)  # average the accumulated per-batch scores
278 | apoz_scores = apoz_scores.cpu()
279 |
280 | avg_scores *= 1 / float(num_batches)
281 | avg_scores = avg_scores.cpu()
282 |
283 | avg_candidates = [idx for idx, score in enumerate(avg_scores) if score >= avg_thresholds[layer_idx]] if avg_scores.dim() != 0 else []
284 | candidates = [x[0] for x in apoz_scores.gt(apoz_thresholds[layer_idx]).nonzero().tolist()]
285 |
286 | difference_candidates = list(set(candidates).difference(set(avg_candidates)))
287 | candidates_by_layer.append(difference_candidates)
288 | print("Total candidates: {}".format(sum([len(l) for l in 
candidates_by_layer]))) 289 | np.save(open("prune_candidate_logs/group_{}_apoz_layer_thresholds.npy".format( group_id_string), "wb"), candidates_by_layer) 290 | print(candidates_by_layer) 291 | 292 | class AverageMeter(object): 293 | """Computes and stores the average and current value""" 294 | def __init__(self, name, fmt=':f'): 295 | self.name = name 296 | self.fmt = fmt 297 | self.reset() 298 | 299 | def reset(self): 300 | self.val = 0 301 | self.avg = 0 302 | self.sum = 0 303 | self.count = 0 304 | 305 | def update(self, val, n=1): 306 | self.val = val 307 | self.sum += val * n 308 | self.count += n 309 | self.avg = self.sum / self.count 310 | 311 | def __str__(self): 312 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 313 | return fmtstr.format(**self.__dict__) 314 | 315 | class ProgressMeter(object): 316 | def __init__(self, num_batches, *meters, prefix=""): 317 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 318 | self.meters = meters 319 | self.prefix = prefix 320 | 321 | def print(self, batch): 322 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 323 | entries += [str(meter) for meter in self.meters] 324 | print('\t'.join(entries)) 325 | 326 | def _get_batch_fmtstr(self, num_batches): 327 | num_digits = len(str(num_batches // 1)) 328 | fmt = '{:' + str(num_digits) + 'd}' 329 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 330 | 331 | def adjust_learning_rate(optimizer, epoch, args): 332 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 333 | lr = args.lr * (0.1 ** (epoch // 30)) 334 | for param_group in optimizer.param_groups: 335 | param_group['lr'] = lr 336 | 337 | def accuracy(output, target, topk=(1,)): 338 | """Computes the accuracy over the k top predictions for the specified values of k""" 339 | with torch.no_grad(): 340 | maxk = max(topk) 341 | batch_size = target.size(0) 342 | 343 | _, pred = output.topk(maxk, 1, True, True) 344 | pred = pred.t() 345 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 346 | 347 | res = [] 348 | for k in topk: 349 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 350 | res.append(correct_k.mul_(100.0 / batch_size)) 351 | return res 352 | 353 | if __name__ == '__main__': 354 | main() 355 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import random 6 | import warnings 7 | 8 | from tqdm import tqdm 9 | import torch 10 | from torch import nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.utils.data as data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import numpy as np 17 | from utils import Logger, AverageMeter, accuracy, savefig 18 | from torch.utils.data import Dataset, DataLoader 19 | import glob 20 | import re 21 | import itertools 22 | from compute_flops import print_model_param_flops 23 | import torchvision.models as models 24 | from imagenet_evaluate_grouped import main_worker 25 | import torch.multiprocessing as mp 26 | 27 | model_names = sorted(name for name in models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(models.__dict__[name])) 30 | model_names += ["resnet110", "resnet164", "mobilenetv2", "shufflenetv2"] 31 | 32 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100/ImageNet Testing') 33 | # 
Checkpoints 34 | parser.add_argument('--retrained_dir', type=str, metavar='PATH', 35 | help='path to the directory of pruned models (default: none)') 36 | # Datasets 37 | parser.add_argument('-d', '--dataset', required=True, type=str) 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--test-batch', default=128, type=int, metavar='N', 41 | help='test batchsize') 42 | parser.add_argument('--data', metavar='DIR', required=False, 43 | help='path to imagenet dataset') 44 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 45 | choices=model_names, 46 | help='model architecture: ' + 47 | ' | '.join(model_names) + 48 | ' (default: resnet18)') 49 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 50 | help='number of total epochs to run') 51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 52 | help='manual epoch number (useful on restarts)') 53 | parser.add_argument('-b', '--batch-size', default=64, type=int, 54 | metavar='N', 55 | help='mini-batch size (default: 256), this is the total ' 56 | 'batch size of all GPUs on the current node when ' 57 | 'using Data Parallel or Distributed Data Parallel') 58 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 59 | metavar='LR', help='initial learning rate', dest='lr') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 63 | metavar='W', help='weight decay (default: 1e-4)', 64 | dest='weight_decay') 65 | parser.add_argument('-p', '--print-freq', default=10, type=int, 66 | metavar='N', help='print frequency (default: 10)') 67 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 68 | help='evaluate model on validation set') 69 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 70 | help='use pre-trained model') 71 | parser.add_argument('--world-size', default=-1, type=int, 72 | help='number of nodes for distributed training') 73 | parser.add_argument('--rank', default=-1, type=int, 74 | help='node rank for distributed training') 75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 76 | help='url used to set up distributed training') 77 | parser.add_argument('--dist-backend', default='nccl', type=str, 78 | help='distributed backend') 79 | parser.add_argument('--gpu', default=None, type=int, 80 | help='GPU id to use.') 81 | parser.add_argument('--multiprocessing-distributed', action='store_true', 82 | help='Use multi-processing distributed training to launch ' 83 | 'N processes per node, which has N GPUs. This is the ' 84 | 'fastest way to use PyTorch for either single node or ' 85 | 'multi node data parallel training') 86 | parser.add_argument('--bce', default=False, action='store_true', 87 | help='Use binary cross entropy loss') 88 | best_acc1 = 0 89 | 90 | # Miscs 91 | parser.add_argument('--seed', type=int, default=42, help='manual seed') 92 | args = parser.parse_args() 93 | state = {k: v for k, v in args._get_kwargs()} 94 | # Validate dataset 95 | assert args.dataset == 'cifar10' or args.dataset == 'cifar100' or args.dataset == 'imagenet', 'Dataset can only be cifar10, cifar100 or imagenet.' 
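The GroupedModel wrapper defined further down in this file concatenates each group model's softmax outputs and multiplies the result by a 0/1 permutation matrix built from grouping_config.npy, which puts the scores back into the original class order. A small illustration of that mapping with a hypothetical 4-class, 2-group configuration (the numbers are made up):

    import torch

    group_info = [[2, 0], [3, 1]]                    # hypothetical grouping of classes 0..3
    flat = [c for group in group_info for c in group]
    perm = torch.eye(4)[flat]                        # row k routes column k to class flat[k]
    concat = torch.tensor([[0.1, 0.4, 0.3, 0.2]])    # group-ordered scores for classes 2, 0, 3, 1
    print(concat @ perm)                             # tensor([[0.4000, 0.2000, 0.1000, 0.3000]])

Column k of the concatenated output holds the score of class flat[k], so multiplying by torch.eye(n)[flat] moves it to column flat[k] of the final prediction, which is exactly what GroupedModel.forward does with torch.mm.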
96 | 97 | # Use CUDA 98 | use_cuda = torch.cuda.is_available() 99 | 100 | # Random seed 101 | torch.manual_seed(args.seed) 102 | if use_cuda: 103 | torch.cuda.manual_seed_all(args.seed) 104 | 105 | torch.set_printoptions(threshold=10000) 106 | 107 | def main(): 108 | # imagenet evaluation 109 | if args.dataset == 'imagenet': 110 | imagenet_evaluate() 111 | return 112 | 113 | # cifar 10/100 evaluation 114 | print('==> Preparing dataset %s' % args.dataset) 115 | if args.dataset == 'cifar10': 116 | dataset_loader = datasets.CIFAR10 117 | elif args.dataset == 'cifar100': 118 | dataset_loader = datasets.CIFAR100 119 | else: 120 | raise NotImplementedError 121 | 122 | testloader = data.DataLoader( 123 | dataset_loader( 124 | root='./data', 125 | download=False, 126 | train=False, 127 | transform=transforms.Compose([ 128 | transforms.ToTensor(), 129 | transforms.Normalize( 130 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 131 | ])), 132 | batch_size = args.test_batch, 133 | shuffle = True, 134 | num_workers = args.workers) 135 | 136 | cudnn.benchmark = True 137 | criterion = nn.CrossEntropyLoss() 138 | model = load_pruned_models(args.retrained_dir+'/'+args.arch+'/') 139 | 140 | if len(model.group_info) == 10 and args.dataset == 'cifar10': 141 | args.bce = True 142 | 143 | test_acc = test_list(testloader, model, criterion, use_cuda) 144 | 145 | def imagenet_evaluate(): 146 | if args.seed is not None: 147 | random.seed(args.seed) 148 | torch.manual_seed(args.seed) 149 | cudnn.deterministic = True 150 | warnings.warn('You have chosen to seed training. ' 151 | 'This will turn on the CUDNN deterministic setting, ' 152 | 'which can slow down your training considerably! ' 153 | 'You may see unexpected behavior when restarting ' 154 | 'from checkpoints.') 155 | 156 | if args.gpu is not None: 157 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 158 | 'disable data parallelism.') 159 | 160 | if args.dist_url == "env://" and args.world_size == -1: 161 | args.world_size = int(os.environ["WORLD_SIZE"]) 162 | 163 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 164 | 165 | ngpus_per_node = torch.cuda.device_count() 166 | if args.multiprocessing_distributed: 167 | # Since we have ngpus_per_node processes per node, the total world_size 168 | # needs to be adjusted accordingly 169 | args.world_size = ngpus_per_node * args.world_size 170 | # Use torch.multiprocessing.spawn to launch distributed processes: the 171 | # main_worker process function 172 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 173 | else: 174 | # Simply call main_worker function 175 | main_worker(args.gpu, ngpus_per_node, args) 176 | 177 | def test_list(testloader, model, criterion, use_cuda): 178 | batch_time = AverageMeter() 179 | data_time = AverageMeter() 180 | losses = AverageMeter() 181 | top1 = AverageMeter() 182 | top5 = AverageMeter() 183 | 184 | if use_cuda: 185 | model.cuda() 186 | model.eval() 187 | end = time.time() 188 | 189 | if args.dataset == 'cifar10': 190 | confusion_matrix = np.zeros((10, 10)) 191 | elif args.dataset == 'cifar100': 192 | confusion_matrix = np.zeros((100, 100)) 193 | else: 194 | raise NotImplementedError 195 | 196 | bar = tqdm(total=len(testloader)) 197 | # pdb.set_trace() 198 | for batch_idx, (inputs, targets) in enumerate(testloader): 199 | bar.update(1) 200 | # measure data loading time 201 | data_time.update(time.time() - end) 202 | if use_cuda: 203 | inputs, targets = inputs.cuda(), targets.cuda() 204 | with torch.no_grad(): 205 | outputs = model(inputs) 206 | loss = criterion(outputs, targets) 207 | for output, target in zip(outputs, targets): 208 | gt = target.item() 209 | dt = np.argmax(output.cpu().numpy()) 210 | confusion_matrix[gt, dt] += 1 211 | # measure accuracy and record loss 212 | prec1, prec5 = accuracy(outputs, targets, topk = (1, 5)) 213 | losses.update(loss.item(), inputs.size(0)) 214 | top1.update(prec1.item(), inputs.size(0)) 215 | top5.update(prec5.item(), inputs.size(0)) 216 | 217 | # measure elapsed time 218 | batch_time.update(time.time() - end) 219 | end = time.time() 220 | 221 | # plot progress 222 | bar.set_description('({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 223 | batch=batch_idx + 1, 224 | size=len(testloader), 225 | data=data_time.avg, 226 | bt=batch_time.avg, 227 | total='N/A' or bar.elapsed_td, 228 | eta='N/A' or bar.eta_td, 229 | loss=losses.avg, 230 | top1=top1.avg, 231 | top5=top5.avg, 232 | )) 233 | bar.close() 234 | 235 | np.set_printoptions(precision=3, linewidth=96) 236 | 237 | print("\n===== Full Confusion Matrix ==================================\n") 238 | if confusion_matrix.shape[0] < 20: 239 | print(confusion_matrix) 240 | else: 241 | print("Warning: The original confusion matrix is too big to fit into the screen. 
" 242 | "Skip printing the matrix.") 243 | 244 | if all([len(group) > 1 for group in model.group_info]): 245 | print("\n===== Inter-group Confusion Matrix ===========================\n") 246 | print(f"Group info: {[group for group in model.group_info]}") 247 | n_groups = len(model.group_info) 248 | group_confusion_matrix = np.zeros((n_groups, n_groups)) 249 | for i in range(n_groups): 250 | for j in range(n_groups): 251 | cols = model.group_info[i] 252 | rows = model.group_info[j] 253 | group_confusion_matrix[i, j] += confusion_matrix[cols[0], rows[0]] 254 | group_confusion_matrix[i, j] += confusion_matrix[cols[0], rows[1]] 255 | group_confusion_matrix[i, j] += confusion_matrix[cols[1], rows[0]] 256 | group_confusion_matrix[i, j] += confusion_matrix[cols[1], rows[1]] 257 | group_confusion_matrix /= group_confusion_matrix.sum(axis=-1)[:, np.newaxis] 258 | print(group_confusion_matrix) 259 | 260 | print("\n===== In-group Confusion Matrix ==============================\n") 261 | for group in model.group_info: 262 | print(f"group {group}") 263 | inter_group_matrix = confusion_matrix[group, :][:, group] 264 | inter_group_matrix /= inter_group_matrix.sum(axis=-1)[:, np.newaxis] 265 | print(inter_group_matrix) 266 | return (losses.avg, top1.avg) 267 | 268 | class GroupedModel(nn.Module): 269 | def __init__(self, model_list, group_info): 270 | super().__init__() 271 | self.group_info = group_info 272 | # flatten list of list 273 | permutation_indices = list(itertools.chain.from_iterable(group_info)) 274 | self.permutation_indices = torch.eye(len(permutation_indices))[permutation_indices] 275 | if use_cuda: 276 | self.permutation_indices = self.permutation_indices.cuda() 277 | self.model_list = nn.ModuleList(model_list) 278 | 279 | def forward(self, inputs): 280 | output_list = [] 281 | if args.bce: 282 | for model_idx, model in enumerate(self.model_list): 283 | output = model(inputs)[:, 0] 284 | output_list.append(output) 285 | output_list = torch.softmax(torch.stack(output_list, dim=1).squeeze(), dim=1) 286 | else: 287 | for model_idx, model in enumerate(self.model_list): 288 | output = torch.softmax(model(inputs), dim=1)[:, 1:] 289 | output_list.append(output) 290 | output_list = torch.cat(output_list, 1) 291 | return torch.mm(output_list, self.permutation_indices) 292 | 293 | def print_statistics(self): 294 | num_params = [] 295 | num_flops = [] 296 | 297 | print("\n===== Metrics for grouped model ==========================\n") 298 | 299 | for group_id, model in zip(self.group_info, self.model_list): 300 | n_params = sum(p.numel() for p in model.parameters()) / 10**6 301 | num_params.append(n_params) 302 | print(f'Grouped model for Class {group_id} ' 303 | f'Total params: {n_params:2f}M') 304 | num_flops.append(print_model_param_flops(model, 32)) 305 | 306 | print(f"Average number of flops: {sum(num_flops) / len(num_flops) / 10**9 :3f} G") 307 | print(f"Average number of param: {sum(num_params) / len(num_params)} M") 308 | 309 | 310 | def load_pruned_models(model_dir): 311 | group_dir = model_dir[:-(len(args.arch)+1)] 312 | if not model_dir.endswith('/'): 313 | model_dir += '/' 314 | file_names = [f for f in glob.glob(model_dir + "*.pth", recursive=False)] 315 | model_list = [torch.load(file_name, map_location=lambda storage, loc: storage.cuda(0)) for file_name in file_names] 316 | groups = np.load(open(group_dir + "grouping_config.npy", "rb")) 317 | group_info = [] 318 | for file in file_names: 319 | group_id = filename_to_index(file) 320 | print(f"Group number is: {group_id}") 321 | 
class_indices = groups[group_id] 322 | group_info.append(class_indices.tolist()[0]) 323 | model = GroupedModel(model_list, group_info) 324 | model.print_statistics() 325 | return model 326 | 327 | 328 | def filename_to_index(filename): 329 | filename = [int(s) for s in filename.split('_') if s.isdigit()] 330 | return filename 331 | 332 | if __name__ == '__main__': 333 | main() 334 | 335 | 336 | --------------------------------------------------------------------------------
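A closing note on the pruning primitive used throughout prune_and_get_model.py: prune_contiguous_conv2d_ (defined in prune_utils/layer_prune.py, which is only partially reproduced above) removes a set of output channels from one Conv2d together with the matching BatchNorm2d entries and the corresponding input channels of the following Conv2d. A simplified sketch of that idea, assuming bias-free, non-grouped convolutions (the real helper also covers biases and the other prune_* variants):

    import torch
    from torch import nn

    def prune_conv_pair_sketch(conv1, bn1, conv2, channels_to_drop):
        keep = torch.tensor([c for c in range(conv1.out_channels)
                             if c not in set(channels_to_drop)])
        # shrink conv1 along its output-channel axis
        conv1.weight = nn.Parameter(conv1.weight.data[keep].clone())
        conv1.out_channels = len(keep)
        # keep only the matching BatchNorm affine parameters and running statistics
        bn1.weight = nn.Parameter(bn1.weight.data[keep].clone())
        bn1.bias = nn.Parameter(bn1.bias.data[keep].clone())
        bn1.running_mean = bn1.running_mean[keep].clone()
        bn1.running_var = bn1.running_var[keep].clone()
        bn1.num_features = len(keep)
        # shrink conv2 along its input-channel axis
        conv2.weight = nn.Parameter(conv2.weight.data[:, keep].clone())
        conv2.in_channels = len(keep)

Read together, the intended flow of the files above is: collect per-group APoZ prune candidates into prune_candidate_logs/ (imagenet_activations.py for ImageNet), carve out one small model per class group with prune_and_get_model.py (saved under ./pruned_models/<arch>/ by default), retrain the pruned models, and finally score the grouped ensemble with evaluate.py --retrained_dir.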