├── sensAI-logo.png
├── utils
│   ├── images
│   │   ├── cifar.png
│   │   └── imagenet.png
│   ├── __init__.py
│   ├── eval.py
│   ├── misc.py
│   ├── visualize.py
│   └── logger.py
├── models
│   ├── imagenet
│   │   ├── __init__.py
│   │   └── resnext.py
│   └── cifar
│       ├── __init__.py
│       ├── mobilenetv2.py
│       ├── wrn.py
│       ├── vgg.py
│       ├── densenet.py
│       ├── resnet.py
│       ├── resnext.py
│       └── shufflenetv2.py
├── requirements.txt
├── scripts
│   ├── activations_grouped_5_5_vgg19.sh
│   ├── train_pruned_grouped.sh
│   ├── activations_grouped_vgg19.sh
│   ├── activations_grouped_vgg19_cifar100.sh
│   ├── activations_grouped_resnet110_cifar100.sh
│   ├── activations_grouped_resnet164_cifar100.sh
│   ├── train_pruned_grouped.py
│   └── training_scheduler.py
├── load_model.py
├── datasets
│   ├── utils.py
│   └── cifar.py
├── .gitignore
├── apoz_policy_imagenet.py
├── retrain_grouped_model.py
├── get_prune_candidates.py
├── compute_flops.py
├── README.md
├── logger.py
├── apoz_policy.py
├── group_selection.py
├── imagenet_evaluate_grouped.py
├── imagenet_dataset.py
├── prune_utils
│   ├── layer_prune.py
│   └── prune.py
├── even_k_means.py
├── regularize_model.py
├── LICENSE.md
├── prune_and_get_model.py
├── imagenet_activations.py
└── evaluate.py
/sensAI-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/sensAI-logo.png
--------------------------------------------------------------------------------
/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/utils/images/cifar.png
--------------------------------------------------------------------------------
/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnext import *
4 |
--------------------------------------------------------------------------------
/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanhuaWang/sensAI/HEAD/utils/images/imagenet.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | numpy
3 | torch==1.5.0
4 | torchvision==0.6.0
5 | tqdm==4.46.1
6 | scikit-learn==0.21.3
7 |
--------------------------------------------------------------------------------
/models/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .vgg import *
4 | from .resnet import *
5 | from .mobilenetv2 import *
6 | from .shufflenetv2 import *
7 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
--------------------------------------------------------------------------------
/scripts/activations_grouped_5_5_vgg19.sh:
--------------------------------------------------------------------------------
1 | MODEL=./checkpoint_bearclaw.pth.tar
2 | rm -r prune_candidate_logs
3 | mkdir prune_candidate_logs
4 |
5 | python3 get_prune_candidates.py -a vgg19_bn --resume $MODEL --evaluate --grouped 1 3 5 7 9
6 | python3 get_prune_candidates.py -a vgg19_bn --resume $MODEL --evaluate --grouped 2 4 6 8 0
7 |
--------------------------------------------------------------------------------
/scripts/train_pruned_grouped.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | EPOCHS=2
3 | FROM=./pruned_models
4 | SAVE=${FROM}_retrained
5 | mkdir ${SAVE}
6 | rm ${SAVE}/* -r
7 | mkdir $SAVE/resnet164/
8 | mkdir $SAVE/logs/
9 | i=0
10 | group_idx=0
11 | for file in ${FROM}/resnet164/*
12 | do
13 | CUDA_VISIBLE_DEVICES=$i python3 cifar_group.py -a resnet164 --epochs ${EPOCHS} --pruned --schedule 40 60 --gamma 0.1 --resume $file --checkpoint $SAVE/ --train-batch 256 --dataset cifar100 > ${SAVE}/logs/log${group_idx}.txt &
14 | group_idx=$((group_idx+1))
15 | i=$((i+1))
16 | i=$(( $i % 4 ))
17 | if [ $i -eq 0 ] ; then
18 | wait
19 | fi
20 | done
21 |
22 |
--------------------------------------------------------------------------------
/scripts/activations_grouped_vgg19.sh:
--------------------------------------------------------------------------------
1 | MODEL=./vgg19bn-cifar100.pth.tar
2 | rm -r prune_candidate_logs
3 | mkdir prune_candidate_logs
4 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 1 2 13 32 46 51 62 77 91 93
5 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 20 23 24 29 30 58 69 72 73 95
6 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 33 47 49 52 56 59 66 67 76 96
7 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 5 11 31 37 38 39 64 75 84 97
8 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 16 21 28 41 48 81 86 87 94 99
9 |
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import pdb
3 |
4 | __all__ = ['accuracy', 'accuracy_binary']
5 |
6 | def accuracy(output, target, topk=(1,)):
7 | """Computes the precision@k for the specified values of k"""
8 | maxk = max(topk)
9 | batch_size = target.size(0)
10 |
11 | _, pred = output.topk(maxk, 1, True, True)
12 | pred = pred.t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |
15 | res = []
16 | for k in topk:
17 | correct_k = correct[:k].view(-1).float().sum(0)
18 | res.append(correct_k.mul_(100.0 / batch_size))
19 | return res
20 |
21 | def accuracy_binary(output, target):
22 | pred = output >= 0.0
23 | pred = pred.flatten().long()
24 | acc = pred.eq(target).sum().float() / target.numel()
25 | return acc.data
26 |
--------------------------------------------------------------------------------
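A minimal usage sketch for `accuracy`, assuming a batch of raw logits and integer targets and that the repository root is on `PYTHONPATH`:

```python
import torch
from utils.eval import accuracy

# Hypothetical batch: 8 samples, 10 classes (e.g. CIFAR-10 logits).
outputs = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

# Returns one percentage per requested k, already scaled by 100 / batch_size.
top1, top5 = accuracy(outputs, targets, topk=(1, 5))
print(f"top-1: {top1.item():.2f}%  top-5: {top5.item():.2f}%")
```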
/scripts/activations_grouped_vgg19_cifar100.sh:
--------------------------------------------------------------------------------
1 | MODEL=./vgg19bn-cifar100.pth.tar
2 | rm -r prune_candidate_logs
3 | mkdir prune_candidate_logs
4 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 1 2 13 32 46 51 62 77 91 93
5 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 20 23 24 29 30 58 69 72 73 95
6 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 33 47 49 52 56 59 66 67 76 96
7 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 5 11 31 37 38 39 64 75 84 97
8 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 16 21 28 41 48 81 86 87 94 99
9 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 12 15 17 25 60 68 71 85 89 90
10 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 3 6 19 34 35 36 43 65 80 88
11 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 0 9 14 54 57 63 82 83 92 98
12 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 8 10 22 26 40 50 53 61 70 79
13 | python3 get_prune_candidates.py -a vgg19_bn -d cifar100 --resume $MODEL --grouped 4 7 18 27 42 44 45 55 74 78
14 |
--------------------------------------------------------------------------------
/scripts/activations_grouped_resnet110_cifar100.sh:
--------------------------------------------------------------------------------
1 | MODEL=/home/ubuntu/baseModel/pytorch-classification/checkpoints/cifar100/resnet-110/model_best.pth.tar
2 | rm -r prune_candidate_logs
3 | mkdir prune_candidate_logs
4 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 0 10 53 54 57 62 70 82 83 92
5 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 23 30 32 49 61 67 71 73 91 95
6 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 5 16 20 25 28 40 84 86 87 94
7 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 15 19 34 38 42 43 66 75 88 97
8 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 9 11 12 17 37 39 68 69 76 98
9 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 18 26 27 29 44 45 78 79 93 99
10 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 8 13 41 46 48 58 81 85 89 90
11 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 14 22 33 47 51 52 56 59 60 96
12 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 3 4 21 31 55 63 64 72 74 80
13 | python3 get_prune_candidates.py -a resnet110 -d cifar100 --resume $MODEL --grouped 1 2 6 7 24 35 36 50 65 77
14 |
--------------------------------------------------------------------------------
/scripts/activations_grouped_resnet164_cifar100.sh:
--------------------------------------------------------------------------------
1 | MODEL=/home/ubuntu/baseModel/pytorch-classification/checkpoints/cifar100/resnet-164/model_best.pth.tar
2 | rm -r prune_candidate_logs
3 | mkdir prune_candidate_logs
4 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 0 10 53 54 57 61 62 70 83 92
5 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 28 30 39 67 69 71 73 91 95 99
6 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 5 9 16 20 22 25 84 86 87 94
7 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 34 35 36 38 50 65 66 88 97 98
8 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 6 7 14 15 19 24 40 51 75 79
9 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 23 33 47 49 52 56 59 60 82 96
10 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 18 26 27 29 42 44 74 77 78 93
11 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 2 8 11 41 45 46 48 58 85 89
12 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 3 4 21 31 43 55 63 64 72 80
13 | python3 get_prune_candidates.py -a resnet164 -d cifar100 --resume $MODEL --grouped 1 12 13 17 32 37 68 76 81 90
14 |
--------------------------------------------------------------------------------
/scripts/train_pruned_grouped.py:
--------------------------------------------------------------------------------
1 | from training_scheduler import train
2 | import os
3 | import shutil
4 | '''
5 | #!/bin/sh
6 | EPOCHS=80
7 | FROM=./pruned_models
8 | SAVE=${FROM}_retrained
9 | mkdir ${SAVE}
10 | rm ${SAVE}/* -r
11 | mkdir $SAVE/vgg19_bn/
12 | mkdir $SAVE/logs/
13 | i=0
14 | group_idx=0
15 | for file in ${FROM}/vgg19_bn/*
16 | do
17 | CUDA_VISIBLE_DEVICES=$i python3 cifar_group.py -a vgg19_bn --epochs ${EPOCHS} --pruned --schedule 40 60 --gamma 0.1 --resume $file --checkpoint $SAVE/ --train-batch 256 --dataset cifar100 > ${SAVE}/logs/log${group_idx}.txt &
18 | group_idx=$((group_idx+1))
19 | i=$((i+1))
20 | i=$(( $i % 4 ))
21 | if [ $i -eq 0 ] ; then
22 | wait
23 | fi
24 | done
25 | '''
26 |
27 | num_epochs = 80
28 | model_dir = "./pruned_models"
29 | save_dir = model_dir + "_retrained"
30 | if os.path.isdir(save_dir):
31 | shutil.rmtree(save_dir)
32 | os.mkdir(save_dir)
33 | os.mkdir(save_dir+"/vgg19_bn/")
34 | os.mkdir(save_dir+"/logs/")
35 |
36 | i = 0
37 | group_idx = 0
38 | commands = []
39 | for file in os.listdir(model_dir+"/vgg19_bn/"):
40 | command = "python3 cifar_group.py -a vgg19_bn --epochs " + str(num_epochs) + " --pruned --schedule 40 60 --gamma 0.1 --resume " + model_dir + "/vgg19_bn/" + file + " --checkpoint " + save_dir + "/ --train-batch 256 --dataset cifar100 > " + save_dir + "/logs/log" + str(group_idx) + ".txt"
41 | group_idx += 1
42 | i = (i + 1) % 4
43 | commands.append(command)
44 | print(commands)
45 | # train(executables=commands)
--------------------------------------------------------------------------------
/load_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import models.cifar as cifar_models
5 |
6 |
7 | def model_arches(dataset):
8 | if dataset == 'cifar':
9 | return sorted(name for name in cifar_models.__dict__
10 | if name.islower() and not name.startswith("__")
11 | and callable(cifar_models.__dict__[name]))
12 | else:
13 | raise NotImplementedError
14 |
15 |
16 |
17 | def load_pretrain_model(arch, dataset, resume_checkpoint, num_classes, use_cuda):
18 | print('==> Resuming from checkpoint..')
19 | assert os.path.isfile(resume_checkpoint), 'Error: no checkpoint found!'
20 | if use_cuda:
21 | checkpoint = torch.load(resume_checkpoint)
22 | else:
23 | checkpoint = torch.load(
24 | resume_checkpoint, map_location=torch.device('cpu'))
25 | if dataset.startswith('cifar'):
26 | model = cifar_models.__dict__[arch](num_classes=num_classes)
27 | else:
28 | raise NotImplementedError(f"Unsupported dataset: {dataset}.")
29 |
30 | if use_cuda:
31 | model.cuda()
32 | state_dict = {}
33 |     # mobilenetv2/shufflenetv2 checkpoints store weights under 'net' instead of 'state_dict'
34 | if arch != 'mobilenetv2' and arch != 'shufflenetv2':
35 | for k, v in checkpoint['state_dict'].items():
36 | state_dict[k.replace('module.', '')] = v
37 | model.load_state_dict(state_dict)
38 | else:
39 | for k, v in checkpoint['net'].items():
40 | state_dict[k.replace('module.', '')] = v
41 | model.load_state_dict(state_dict)
42 | return model
43 |
--------------------------------------------------------------------------------
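A minimal sketch of calling `load_pretrain_model`; the checkpoint path below is hypothetical and must match the chosen architecture and dataset:

```python
import torch
from load_model import load_pretrain_model

use_cuda = torch.cuda.is_available()
# Hypothetical checkpoint path; any architecture listed by model_arches('cifar') works.
ckpt = './checkpoints/cifar100/vgg19_bn/model_best.pth.tar'

model = load_pretrain_model('vgg19_bn', 'cifar100', ckpt, num_classes=100, use_cuda=use_cuda)
model.eval()
```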
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import numpy as np
3 | import math
4 |
5 |
6 | class DataSetWrapper(object):
7 | def __init__(self, dataset, class_group: Tuple[int], negative_samples=False):
8 | # The original dataset has been shuffled. Skip shuffling this dataset
9 | # for consistency.
10 | self.dataset = dataset
11 | self.class_group = class_group
12 | self.negative_samples = negative_samples
13 | self.targets = np.asarray(self.dataset.targets)
14 | # This is the bool mask for all classes in the given group.
15 | positive_mask = np.zeros_like(self.targets, dtype=bool)
16 | for class_index in class_group:
17 | positive_mask |= (self.targets == class_index)
18 | positive_class_indices = np.where(positive_mask)[0]
19 | if negative_samples:
20 | # For N negative samples, P positive samples, we need to append
21 | # (k * N - P) positive samples.
22 | k = len(class_group)
23 | P = len(positive_class_indices)
24 | N = len(self.targets) - P
25 |             assert N >= P, "expected at least as many negative samples as positive ones"
26 | ext_P = k * N - P
27 | repeat_n = math.ceil(ext_P / P)
28 |             extended_indices = np.repeat(
29 | positive_class_indices, repeat_n)[:ext_P]
30 | # fuse and shuffle
31 | all_indices = np.arange(len(self.targets))
32 |             fullset = np.concatenate([all_indices, extended_indices])
33 | np.random.shuffle(fullset)
34 | self.mapping = fullset
35 | else:
36 | self.mapping = positive_class_indices
37 |
38 | def __getitem__(self, i):
39 | index = self.mapping[i]
40 | data, label = self.dataset[index]
41 | if label in self.class_group:
42 | label = list(self.class_group).index(label) + 1
43 | else:
44 | label = 0
45 | return data, label
46 |
47 | def __len__(self):
48 | return len(self.mapping)
49 |
50 | @property
51 | def num_classes(self):
52 | return len(self.class_group) + 1
53 |
--------------------------------------------------------------------------------
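To illustrate the label remapping performed by `DataSetWrapper`, a small self-contained sketch with a toy stand-in dataset (classes in the group are relabeled 1..k; label 0 is reserved for negatives when `negative_samples=True`):

```python
from datasets.utils import DataSetWrapper

class ToyDataset:
    """Minimal stand-in exposing the two things DataSetWrapper relies on:
    indexing and a .targets list."""
    def __init__(self):
        self.targets = [0, 3, 7, 3, 1, 7, 2, 5]
    def __getitem__(self, i):
        return f"sample{i}", self.targets[i]
    def __len__(self):
        return len(self.targets)

wrapped = DataSetWrapper(ToyDataset(), class_group=(3, 7))
print(len(wrapped))                                   # 4 samples of classes 3 and 7
print([wrapped[i][1] for i in range(len(wrapped))])   # [1, 2, 1, 2]
```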
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets
3 | from torchvision import transforms
4 | from typing import List, Tuple
5 |
6 | from datasets import utils
7 |
8 |
9 | # Transformations
10 | RC = transforms.RandomCrop(32, padding=4)
11 | RHF = transforms.RandomHorizontalFlip()
12 | RVF = transforms.RandomVerticalFlip()
13 | NRM = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
14 | TT = transforms.ToTensor()
15 | TPIL = transforms.ToPILImage()
16 |
17 | # Transforms object for trainset with augmentation
18 | transform_with_aug = transforms.Compose([RC, RHF, TT, NRM])
19 | # Transforms object for testset with NO augmentation
20 | transform_no_aug = transforms.Compose([TT, NRM])
21 |
22 |
23 | DATASET_ROOT = './data/'
24 |
25 |
26 | class CIFAR10TrainingSetWrapper(utils.DataSetWrapper):
27 | def __init__(self, class_group: Tuple[int], negative_samples=False):
28 | dataset = datasets.CIFAR10(root=DATASET_ROOT, train=True,
29 | download=True, transform=transform_with_aug)
30 | super().__init__(dataset, class_group, negative_samples)
31 |
32 |
33 | class CIFAR10TestingSetWrapper(utils.DataSetWrapper):
34 | def __init__(self, class_group: Tuple[int], negative_samples=False):
35 | dataset = datasets.CIFAR10(root=DATASET_ROOT, train=False,
36 | download=True, transform=transform_no_aug)
37 | super().__init__(dataset, class_group, negative_samples)
38 |
39 |
40 | class CIFAR100TrainingSetWrapper(utils.DataSetWrapper):
41 | def __init__(self, class_group: Tuple[int], negative_samples=False):
42 | dataset = datasets.CIFAR100(root=DATASET_ROOT, train=True,
43 | download=True, transform=transform_with_aug)
44 | super().__init__(dataset, class_group, negative_samples)
45 |
46 |
47 | class CIFAR100TestingSetWrapper(utils.DataSetWrapper):
48 | def __init__(self, class_group: Tuple[int], negative_samples=False):
49 | dataset = datasets.CIFAR100(root=DATASET_ROOT, train=False,
50 | download=True, transform=transform_no_aug)
51 | super().__init__(dataset, class_group, negative_samples)
52 |
--------------------------------------------------------------------------------
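A usage sketch mirroring how `get_prune_candidates.py` builds its pruning loader, using the first ten-class group from the VGG-19/CIFAR-100 activation script (downloads CIFAR-100 into `./data/` on first use):

```python
import torch
from datasets import cifar

group = (1, 2, 13, 32, 46, 51, 62, 77, 91, 93)
dataset = cifar.CIFAR100TrainingSetWrapper(group, negative_samples=False)

loader = torch.utils.data.DataLoader(dataset, batch_size=1000,
                                     num_workers=4, pin_memory=False)
print(len(dataset), dataset.num_classes)   # 5000 samples, 11 labels (10 classes + reserved negative label)
```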
/.gitignore:
--------------------------------------------------------------------------------
1 | limo/vgg-pruning/pruned_models_*
2 | pytorch-classification/data/
3 | limo/vgg-pruning/pruned_models/
4 | pytorch-classification/checkpoints/
5 | limo/vgg-pruning/pruned_models_20_epochs/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .dmypy.json
119 | dmypy.json
120 |
121 | # Pyre type checker
122 | .pyre/
123 |
124 | *.pth.tar
125 | pruned_models/
126 | pruned_models_retrained/
127 | prune_candidate_logs/
128 | data/
129 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | - msr_init: net parameter initialization.
4 | - progress_bar: progress bar mimic xlua.progress.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | import torch
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
19 | def get_mean_and_std(dataset):
20 | '''Compute the mean and std value of dataset.'''
21 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
22 |
23 | mean = torch.zeros(3)
24 | std = torch.zeros(3)
25 | print('==> Computing mean and std..')
26 | for inputs, targets in dataloader:
27 | for i in range(3):
28 | mean[i] += inputs[:,i,:,:].mean()
29 | std[i] += inputs[:,i,:,:].std()
30 | mean.div_(len(dataset))
31 | std.div_(len(dataset))
32 | return mean, std
33 |
34 | def init_params(net):
35 | '''Init layer parameters.'''
36 | for m in net.modules():
37 | if isinstance(m, nn.Conv2d):
38 |             init.kaiming_normal_(m.weight, mode='fan_out')
39 |             if m.bias is not None:
40 |                 init.constant_(m.bias, 0)
41 |         elif isinstance(m, nn.BatchNorm2d):
42 |             init.constant_(m.weight, 1)
43 |             init.constant_(m.bias, 0)
44 |         elif isinstance(m, nn.Linear):
45 |             init.normal_(m.weight, std=1e-3)
46 |             if m.bias is not None:
47 |                 init.constant_(m.bias, 0)
48 |
49 | def mkdir_p(path):
50 | '''make dir if not exist'''
51 | try:
52 | os.makedirs(path)
53 | except OSError as exc: # Python >2.5
54 | if exc.errno == errno.EEXIST and os.path.isdir(path):
55 | pass
56 | else:
57 | raise
58 |
59 | class AverageMeter(object):
60 | """Computes and stores the average and current value
61 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
62 | """
63 | def __init__(self):
64 | self.reset()
65 |
66 | def reset(self):
67 | self.val = 0
68 | self.avg = 0
69 | self.sum = 0
70 | self.count = 0
71 |
72 | def update(self, val, n=1):
73 | self.val = val
74 | self.sum += val * n
75 | self.count += n
76 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
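A small sketch of `AverageMeter`, the running-statistics helper used throughout the training and evaluation scripts; the loss values below are made up:

```python
from utils.misc import AverageMeter

losses = AverageMeter()
# Hypothetical per-batch losses and batch sizes.
for loss, batch_size in [(0.9, 128), (0.7, 128), (0.5, 64)]:
    losses.update(loss, batch_size)

print(losses.val)   # 0.5  (last value seen)
print(losses.avg)   # 0.74 (sample-weighted running average)
```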
/apoz_policy_imagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import io
4 |
5 | """
6 | Calculate the Average Percentage of Zeros Score of the feature map activation layer output
7 | """
8 | def apoz_scoring(activation):
9 | if activation.dim() == 4:
10 | view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channels) x (h*w)
11 | featuremap_apoz = view_2d.abs().gt(0.005).sum(dim=1).float() / (activation.size(2) * activation.size(3)) # (batch*channels) x 1
12 | featuremap_apoz_mat = featuremap_apoz.view(activation.size(0), activation.size(1)) # batch x channels
13 | elif activation.dim() == 2 and activation.shape[1] == 1:
14 | featuremap_apoz_mat = activation.abs().gt(0.005).sum(dim=1).float() / activation.size(1)
15 | elif activation.dim() == 2: # FC Case: (batch x channels)
16 |         featuremap_apoz_mat = activation.abs().gt(0.005).sum(dim=0).float() / activation.size(0)
17 | return 100 - featuremap_apoz_mat.mul(100)
18 | else:
19 |         raise ValueError("activation_channels_apoz: Unsupported shape: {}".format(activation.shape))
20 | return 100 - featuremap_apoz_mat.mean(dim=0).mul(100)
21 |
22 |
23 | def avg_scoring(activation):
24 | if activation.dim() == 4:
25 | view_2d = activation.view(-1, activation.size(2) * activation.size(3))
26 | featuremap_avg = view_2d.abs().sum(dim = 1).float() / (activation.size(2) * activation.size(3))
27 | featuremap_avg_mat = featuremap_avg.view(activation.size(0), activation.size(1))
28 | elif activation.dim() == 2 and activation.shape[1] == 1:
29 | featuremap_avg_mat = activation.abs().sum(dim = 1).float() / activation.size(1)
30 | elif activation.dim() == 2:
31 | featuremap_avg_mat = activation.abs().float()
32 | else:
33 |         raise ValueError("activation_channels_avg: Unsupported shape: {}".format(activation.shape))
34 | return featuremap_avg_mat.mean(dim = 0)
35 |
36 | def pruning_candidates(group_id, thresholds, file_name):
37 | layers_channels = []
38 | fmap_file = open(file_name, "rb")
39 | data_buffer = io.BytesIO(fmap_file.read())
40 | for _ in range(16):
41 | layers_channels.append(torch.load(data_buffer))
42 |
43 | candidates_by_layer = []
44 |     print("Calculating pruning candidates for class(es) {}".format(group_id))
45 | for index, layer in enumerate(layers_channels):
46 | apoz_score = apoz_scoring(layer)
47 | print(apoz_score.mean())
48 |
49 | curr_threshold = thresholds[index]
50 | while True:
51 | num_candidates = apoz_score.gt(curr_threshold).sum()
52 | print("Greater than {} %".format(curr_threshold), num_candidates)
53 | if num_candidates < apoz_score.size()[0]:
54 | candidates = [x[0] for x in apoz_score.gt(curr_threshold).nonzero().tolist()]
55 | break
56 | curr_threshold += 5
57 |
58 | print("Class Index: {}, Layer {}, Number of neurons with apoz > {}%: {}/{}".format(group_id, index, curr_threshold, len(candidates), apoz_score.size()[0]))
59 | candidates_by_layer.append(candidates)
60 |         print("Zero channels out of total in layer {}: {}/{}".format(index, len(candidates), len(layer)))
61 | return candidates_by_layer
62 |
--------------------------------------------------------------------------------
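A quick sketch of what `apoz_scoring` returns for a convolutional feature map; the input is random, so the per-channel APoZ hovers around 50%:

```python
import torch
from apoz_policy_imagenet import apoz_scoring

# Hypothetical post-ReLU feature map: batch 32, 64 channels, 8x8 spatial.
activation = torch.relu(torch.randn(32, 64, 8, 8))

scores = apoz_scoring(activation)   # one APoZ percentage per channel
print(scores.shape)                 # torch.Size([64])
print(scores.mean())                # roughly 50 for random activations
```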
/models/cifar/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Block(nn.Module):
7 | '''expand + depthwise + pointwise'''
8 | def __init__(self, in_planes, out_planes, expansion, stride):
9 | super(Block, self).__init__()
10 | self.stride = stride
11 |
12 | planes = expansion * in_planes
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
14 | self.bn1 = nn.BatchNorm2d(planes)
15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
16 | self.bn2 = nn.BatchNorm2d(planes)
17 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
18 | self.bn3 = nn.BatchNorm2d(out_planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride == 1 and in_planes != out_planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
24 | nn.BatchNorm2d(out_planes),
25 | )
26 |
27 | def forward(self, x):
28 | out = F.relu(self.bn1(self.conv1(x)))
29 | out = F.relu(self.bn2(self.conv2(out)))
30 | out = self.bn3(self.conv3(out))
31 | out = out + self.shortcut(x) if self.stride==1 else out
32 | return out
33 |
34 |
35 | class MobileNetV2(nn.Module):
36 | # (expansion, out_planes, num_blocks, stride)
37 | cfg = [(1, 16, 1, 1),
38 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
39 | (6, 32, 3, 2),
40 | (6, 64, 4, 2),
41 | (6, 96, 3, 1),
42 | (6, 160, 3, 2),
43 | (6, 320, 1, 1)]
44 |
45 | def __init__(self, num_classes=10):
46 | super(MobileNetV2, self).__init__()
47 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
48 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(32)
50 | self.layers = self._make_layers(in_planes=32)
51 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
52 | self.bn2 = nn.BatchNorm2d(1280)
53 | self.linear = nn.Linear(1280, num_classes)
54 |
55 | def _make_layers(self, in_planes):
56 | layers = []
57 | for expansion, out_planes, num_blocks, stride in self.cfg:
58 | strides = [stride] + [1]*(num_blocks-1)
59 | for stride in strides:
60 | layers.append(Block(in_planes, out_planes, expansion, stride))
61 | in_planes = out_planes
62 | return nn.Sequential(*layers)
63 |
64 | def forward(self, x, features_only=False):
65 | out = F.relu(self.bn1(self.conv1(x)))
66 | out = self.layers(out)
67 | out = F.relu(self.bn2(self.conv2(out)))
68 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
69 | out = F.avg_pool2d(out, 4)
70 | out = out.view(out.size(0), -1)
71 | if not features_only:
72 | out = self.linear(out)
73 | return out
74 |
75 | def mobilenetv2(**kwargs):
76 |     model = MobileNetV2(**kwargs)
77 | return model
--------------------------------------------------------------------------------
/retrain_grouped_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | import subprocess as sp
5 | import numpy as np
6 | parser = argparse.ArgumentParser(description='retrain pruned model')
7 | parser.add_argument('-d', '--dataset', required=True, type=str)
8 | parser.add_argument('--epochs', required=True, type=int)
9 | parser.add_argument('-a', '--arch', default='vgg19_bn',
10 | type=str, help='The architecture of the trained model')
11 | parser.add_argument('-r', '--resume', default='', type=str,
12 | help='The path to the checkpoints') ### pruned models are saved here
13 | parser.add_argument('--num_gpus', default=4, type=int)
14 | parser.add_argument('--train_batch', default=256, type=int)
15 | parser.add_argument('--data', default='/home/ubuntu/imagenet', required=False, type=str,
16 | help='location of the imagenet dataset that includes train/val')
17 |
18 | args = parser.parse_args()
19 |
20 |
21 | def main():
22 | save = args.resume[:-1] +'_retrained/'
23 | groups = np.load(open(args.resume + "grouping_config.npy", "rb"))
24 | resultExist = os.path.exists(save)
25 | if resultExist:
26 | rm_cmd = 'rm -rf ' + save
27 |         sp.Popen(rm_cmd, shell=True).wait()
28 | os.mkdir(save)
29 | np.save(open(os.path.join(save[:-1], "grouping_config.npy"), "wb"), groups)
30 | save += args.arch
31 | os.mkdir(save)
32 | files = [f for f in glob.glob(args.resume + args.arch+"/*.pth", recursive=False)]
33 | process_list = [None for _ in range(args.num_gpus)]
34 | if args.dataset in ['cifar10', 'cifar100']:
35 | for i, file in enumerate(files):
36 | if process_list[i % args.num_gpus]:
37 | process_list[i % args.num_gpus].wait()
38 | exec_cmd = 'python3 cifar_group.py' +\
39 | ' --arch %s' % args.arch +\
40 | ' --resume %s' % file +\
41 | ' --schedule 40 60' +\
42 | ' --gamma 0.1' +\
43 | ' --epochs %d' % args.epochs +\
44 | ' --checkpoint %s' % save +\
45 | ' --train-batch %d' % args.train_batch +\
46 | ' --dataset %s' % args.dataset +\
47 | ' --grouping_dir %s' % args.resume +\
48 | ' --pruned' +\
49 | ' --gpu_id %d' % (i % args.num_gpus)
50 | process_list[i % args.num_gpus] = sp.Popen(exec_cmd, shell=True)
51 |     elif args.dataset == 'imagenet':
52 | for i, file in enumerate(files):
53 | if process_list[i % args.num_gpus]:
54 | process_list[i % args.num_gpus].wait()
55 | exec_cmd = 'python3 imagenet_official_retrain.py' +\
56 | ' --data %s' % args.data +\
57 | ' --arch %s' % args.arch +\
58 | ' --resume %s' % file +\
59 | ' --schedule 10 15' +\
60 | ' --config %s' % args.resume + '/grouping_config.npy' +\
61 | ' --gamma 0.1 ' +\
62 | ' --batch_size %d' % args.train_batch +\
63 | ' --epochs %d' % args.epochs +\
64 | ' --checkpoint %s' % save +\
65 | ' --gpu %s' % (i % args.num_gpus)
66 | process_list[i % args.num_gpus] = sp.Popen(exec_cmd, shell=True)
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/get_prune_candidates.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import torch
5 | from torch import nn
6 | import torch.backends.cudnn as cudnn
7 |
8 | from apoz_policy import ActivationRecord
9 | from datasets import cifar
10 | import load_model
11 | from tqdm import tqdm
12 | import os
13 | from regularize_model import standard
14 |
15 |
16 | parser = argparse.ArgumentParser(
17 | description='PyTorch CIFAR10/100 Generate Class Specific Information')
18 | # Datasets
19 | parser.add_argument('-d', '--dataset', required=True, type=str)
20 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
21 | help='number of data loading workers (default: 4)')
22 | parser.add_argument('--resume', required=True, default='', type=str, metavar='PATH',
23 | help='path to latest checkpoint (default: none)')
24 | # Architecture
25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20',
26 | choices=load_model.model_arches('cifar'),
27 | help='model architecture: ' +
28 | ' | '.join(load_model.model_arches('cifar')) +
29 |                     ' (default: resnet20)')
30 | # Miscs
31 | parser.add_argument('--seed', type=int, default=42, help='manual seed')
32 | parser.add_argument('--grouped', required=True, type=int, nargs='+', default=[],
33 |                     help='Generate activations based on these class indices')
34 | parser.add_argument('--group_number', required=True, type=int,
35 | help='Group number')
36 | parser.add_argument('--gpu_num', default='0', type=str,
37 | help='GPU number')
38 |
39 |
40 | args = parser.parse_args()
41 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
42 | use_cuda = torch.cuda.is_available()
43 |
44 | # Random seed
45 | torch.manual_seed(args.seed)
46 | if use_cuda:
47 | torch.cuda.manual_seed_all(args.seed)
48 |
49 | assert args.grouped
50 |
51 |
52 | def main():
53 | if args.dataset == 'cifar10':
54 | dataset = cifar.CIFAR10TrainingSetWrapper(args.grouped, False)
55 | num_classes = 10
56 | elif args.dataset == 'cifar100':
57 | dataset = cifar.CIFAR100TrainingSetWrapper(args.grouped, False)
58 | num_classes = 100
59 | else:
60 | raise NotImplementedError(
61 | f"There's no support for '{args.dataset}' dataset.")
62 |
63 | pruning_loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_size=1000,
66 | num_workers=args.workers,
67 | pin_memory=False)
68 |
69 | model = load_model.load_pretrain_model(
70 | args.arch, 'cifar', args.resume, num_classes, use_cuda)
71 |
72 | if args.arch in ["mobilenetv2", "shufflenetv2"]:
73 | model = standard(model, args.arch, num_classes)
74 |
75 | if use_cuda:
76 | model.cuda()
77 | print('\nMake a test run to generate activations. \n Using training set.\n')
78 | with ActivationRecord(model, args.arch) as recorder:
79 | # collect pruning data
80 | #bar = tqdm(total=len(pruning_loader))
81 | for batch_idx, (inputs, _) in enumerate(pruning_loader):
82 | #bar.update(1)
83 | if use_cuda:
84 | inputs = inputs.cuda()
85 | recorder.record_batch(inputs)
86 | candidates_by_layer = recorder.generate_pruned_candidates()
87 |
88 | with open(f"prune_candidate_logs/group_{args.group_number}_apoz_layer_thresholds.npy", "wb") as f:
89 | pickle.dump(candidates_by_layer, f)
90 | print(candidates_by_layer)
91 |
92 |
93 | if __name__ == '__main__':
94 | main()
95 |
--------------------------------------------------------------------------------
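The file written above is a pickled list with one entry of channel indices per layer; a sketch for inspecting it, assuming a group numbered 0 has already been generated:

```python
import pickle

with open("prune_candidate_logs/group_0_apoz_layer_thresholds.npy", "rb") as f:
    candidates_by_layer = pickle.load(f)

for layer_idx, channels in enumerate(candidates_by_layer):
    print(f"layer {layer_idx}: {len(channels)} channels marked for pruning")
```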
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
12 | def make_image(img, mean=(0,0,0), std=(1,1,1)):
13 | for i in range(0, 3):
14 | img[i] = img[i] * std[i] + mean[i] # unnormalize
15 | npimg = img.numpy()
16 | return np.transpose(npimg, (1, 2, 0))
17 |
18 | def gauss(x,a,b,c):
19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a)
20 |
21 | def colorize(x):
22 | ''' Converts a one-channel grayscale image to a color heatmap image '''
23 | if x.dim() == 2:
24 | torch.unsqueeze(x, 0, out=x)
25 | if x.dim() == 3:
26 | cl = torch.zeros([3, x.size(1), x.size(2)])
27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
28 | cl[1] = gauss(x,1,.5,.3)
29 | cl[2] = gauss(x,1,.2,.3)
30 | cl[cl.gt(1)] = 1
31 | elif x.dim() == 4:
32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
34 | cl[:,1,:,:] = gauss(x,1,.5,.3)
35 | cl[:,2,:,:] = gauss(x,1,.2,.3)
36 | return cl
37 |
38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
40 | plt.imshow(images)
41 | plt.show()
42 |
43 |
44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
45 | im_size = images.size(2)
46 |
47 | # save for adding mask
48 | im_data = images.clone()
49 | for i in range(0, 3):
50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize
51 |
52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
53 | plt.subplot(2, 1, 1)
54 | plt.imshow(images)
55 | plt.axis('off')
56 |
57 | # for b in range(mask.size(0)):
58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
59 | mask_size = mask.size(2)
60 | # print('Max %f Min %f' % (mask.max(), mask.min()))
61 |     mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)
62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
63 | # for c in range(3):
64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
65 |
66 | # print(mask.size())
67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
69 | plt.subplot(2, 1, 2)
70 | plt.imshow(mask)
71 | plt.axis('off')
72 |
73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
74 | im_size = images.size(2)
75 |
76 | # save for adding mask
77 | im_data = images.clone()
78 | for i in range(0, 3):
79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize
80 |
81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
82 | plt.subplot(1+len(masklist), 1, 1)
83 | plt.imshow(images)
84 | plt.axis('off')
85 |
86 | for i in range(len(masklist)):
87 | mask = masklist[i].data.cpu()
88 | # for b in range(mask.size(0)):
89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
90 | mask_size = mask.size(2)
91 | # print('Max %f Min %f' % (mask.max(), mask.min()))
92 |         mask = nn.functional.interpolate(mask, scale_factor=im_size/mask_size)
93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
94 | # for c in range(3):
95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
96 |
97 | # print(mask.size())
98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
100 | plt.subplot(1+len(masklist), 1, i+2)
101 | plt.imshow(mask)
102 | plt.axis('off')
103 |
--------------------------------------------------------------------------------
/models/cifar/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | __all__ = ['wrn']
7 |
8 | class BasicBlock(nn.Module):
9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
10 | super(BasicBlock, self).__init__()
11 | self.bn1 = nn.BatchNorm2d(in_planes)
12 | self.relu1 = nn.ReLU(inplace=True)
13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(out_planes)
16 | self.relu2 = nn.ReLU(inplace=True)
17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
18 | padding=1, bias=False)
19 | self.droprate = dropRate
20 | self.equalInOut = (in_planes == out_planes)
21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
22 | padding=0, bias=False) or None
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 | class NetworkBlock(nn.Module):
35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
36 | super(NetworkBlock, self).__init__()
37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
39 | layers = []
40 | for i in range(nb_layers):
41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
42 | return nn.Sequential(*layers)
43 | def forward(self, x):
44 | return self.layer(x)
45 |
46 | class WideResNet(nn.Module):
47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
48 | super(WideResNet, self).__init__()
49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
51 | n = (depth - 4) // 6
52 | block = BasicBlock
53 | # 1st conv before any network block
54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
55 | padding=1, bias=False)
56 | # 1st block
57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
58 | # 2nd block
59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
60 | # 3rd block
61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
62 | # global average pooling and classifier
63 | self.bn1 = nn.BatchNorm2d(nChannels[3])
64 | self.relu = nn.ReLU(inplace=True)
65 | self.fc = nn.Linear(nChannels[3], num_classes)
66 | self.nChannels = nChannels[3]
67 |
68 | for m in self.modules():
69 | if isinstance(m, nn.Conv2d):
70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
71 | m.weight.data.normal_(0, math.sqrt(2. / n))
72 | elif isinstance(m, nn.BatchNorm2d):
73 | m.weight.data.fill_(1)
74 | m.bias.data.zero_()
75 | elif isinstance(m, nn.Linear):
76 | m.bias.data.zero_()
77 |
78 | def forward(self, x):
79 | out = self.conv1(x)
80 | out = self.block1(out)
81 | out = self.block2(out)
82 | out = self.block3(out)
83 | out = self.relu(self.bn1(out))
84 | out = F.avg_pool2d(out, 8)
85 | out = out.view(-1, self.nChannels)
86 | return self.fc(out)
87 |
88 | def wrn(**kwargs):
89 | """
90 | Constructs a Wide Residual Networks.
91 | """
92 | model = WideResNet(**kwargs)
93 | return model
94 |
--------------------------------------------------------------------------------
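A construction sketch for the `wrn` factory (a WRN-28-10 here); per the assertion above, `depth` must satisfy `(depth - 4) % 6 == 0`:

```python
import torch
from models.cifar.wrn import wrn

model = wrn(depth=28, num_classes=10, widen_factor=10, dropRate=0.3)
logits = model(torch.randn(2, 3, 32, 32))   # CIFAR-sized inputs
print(logits.shape)                         # torch.Size([2, 10])
```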
/scripts/training_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import threading
3 | import subprocess
4 | import multiprocessing as mp
5 | import os
6 |
7 | pruned_model_path="./pruned_models/vgg19_bn/"
8 | retrained_model_path="./retrained_model/vgg19_bn/"
9 | '''
10 | 1) initialize bounded producer/consumer queue of size max(num_devices (param), output from torch.cuda.device_count())
11 | '''
12 | def train(executables, allowable_devices=range(torch.cuda.device_count())):
13 | free_devices = mp.Queue(maxsize=len(allowable_devices))
14 | for i in allowable_devices:
15 | free_devices.put(i)
16 | for executable in executables:
17 | assigned_device = free_devices.get()
18 | print("script: '" + str(executable) + "' assigned to GPU: " + str(assigned_device))
19 | mp.Process(target=execute_on_device, args=(assigned_device, executable, free_devices)).start()
20 |
21 | def execute_on_device(GPU_ID, executable, free_devices):
22 | # train the model
23 | os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
24 | executable_tokens = executable.split(" ")
25 | stdout_file = None
26 | if ">" in executable_tokens:
27 | idx = executable_tokens.index(">")
28 | stdout_file = open(executable_tokens[idx+1], "w")
29 | executable_tokens = executable_tokens[:idx]
30 | print(stdout_file)
31 | subprocess.run(executable_tokens, stdout=stdout_file)
32 | # mark this GPU as free
33 | free_devices.put(GPU_ID)
34 | if stdout_file is not None:
35 | stdout_file.close()
36 |
37 | def get_stdout(executable_tokens):
38 | if '>' in executable_tokens:
39 |         return executable_tokens[executable_tokens.index('>') + 1]
40 | else:
41 | return None
42 |
43 | if __name__ == '__main__':
44 | to_train = [
45 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_0_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_0_pruned_model --train-batch 64 --class-index 0",
46 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_1_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_1_pruned_model --train-batch 64 --class-index 1",
47 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_2_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_2_pruned_model --train-batch 64 --class-index 2",
48 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_3_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_3_pruned_model --train-batch 64 --class-index 3",
49 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_4_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_4_pruned_model --train-batch 64 --class-index 4",
50 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_5_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_5_pruned_model --train-batch 64 --class-index 5",
51 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_6_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_6_pruned_model --train-batch 64 --class-index 6",
52 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_7_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_7_pruned_model --train-batch 64 --class-index 7",
53 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_8_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_8_pruned_model --train-batch 64 --class-index 8",
54 | "python3 cifar_binary.py --pruned -a vgg19_bn --lr 0.01 --epochs 40 --schedule 20 30 --gamma 0.1 --resume "+pruned_model_path+"vgg19_bn_9_pruned_model.pth --checkpoint "+retrained_model_path+"vgg19_bn_9_pruned_model --train-batch 64 --class-index 9",
55 | ]
56 | train(to_train)
57 |
--------------------------------------------------------------------------------
/compute_flops.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/simochen/model-tools.
2 | import numpy as np
3 |
4 | import torch
5 | import torchvision
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 |
9 |
10 | def print_model_param_nums(model=None, multiply_adds=True):
11 | if model == None:
12 | model = torchvision.models.alexnet()
13 | total = sum([param.nelement() for param in model.parameters()])
14 | print(' + Number of params: %.2fM' % (total / 1e6))
15 |
16 | def print_model_param_flops(model=None, input_res=224, multiply_adds=True):
17 |
18 | prods = {}
19 | def save_hook(name):
20 | def hook_per(self, input, output):
21 | prods[name] = np.prod(input[0].shape)
22 | return hook_per
23 |
24 | list_1=[]
25 | def simple_hook(self, input, output):
26 | list_1.append(np.prod(input[0].shape))
27 | list_2={}
28 | def simple_hook2(self, input, output):
29 | list_2['names'] = np.prod(input[0].shape)
30 |
31 | list_conv=[]
32 | def conv_hook(self, input, output):
33 | batch_size, input_channels, input_height, input_width = input[0].size()
34 | output_channels, output_height, output_width = output[0].size()
35 |
36 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
37 | bias_ops = 1 if self.bias is not None else 0
38 |
39 | params = output_channels * (kernel_ops + bias_ops)
40 | flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
41 |
42 | list_conv.append(flops)
43 |
44 | list_linear=[]
45 | def linear_hook(self, input, output):
46 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1
47 |
48 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
49 | bias_ops = self.bias.nelement()
50 |
51 | flops = batch_size * (weight_ops + bias_ops)
52 | list_linear.append(flops)
53 |
54 | list_bn=[]
55 | def bn_hook(self, input, output):
56 | list_bn.append(input[0].nelement() * 2)
57 |
58 | list_relu=[]
59 | def relu_hook(self, input, output):
60 | list_relu.append(input[0].nelement())
61 |
62 | list_pooling=[]
63 | def pooling_hook(self, input, output):
64 | batch_size, input_channels, input_height, input_width = input[0].size()
65 | output_channels, output_height, output_width = output[0].size()
66 |
67 | kernel_ops = self.kernel_size * self.kernel_size
68 | bias_ops = 0
69 | params = 0
70 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size
71 |
72 | list_pooling.append(flops)
73 |
74 | list_upsample=[]
75 | # For bilinear upsample
76 | def upsample_hook(self, input, output):
77 | batch_size, input_channels, input_height, input_width = input[0].size()
78 | output_channels, output_height, output_width = output[0].size()
79 |
80 | flops = output_height * output_width * output_channels * batch_size * 12
81 | list_upsample.append(flops)
82 |
83 | def foo(net):
84 | childrens = list(net.children())
85 | if not childrens:
86 | if isinstance(net, torch.nn.Conv2d):
87 | net.register_forward_hook(conv_hook)
88 | if isinstance(net, torch.nn.Linear):
89 | net.register_forward_hook(linear_hook)
90 | if isinstance(net, torch.nn.BatchNorm2d):
91 | net.register_forward_hook(bn_hook)
92 | if isinstance(net, torch.nn.ReLU):
93 | net.register_forward_hook(relu_hook)
94 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
95 | net.register_forward_hook(pooling_hook)
96 | if isinstance(net, torch.nn.Upsample):
97 | net.register_forward_hook(upsample_hook)
98 | return
99 | for c in childrens:
100 | foo(c)
101 |
102 | if model == None:
103 | model = torchvision.models.alexnet()
104 | foo(model)
105 | input = torch.rand(3, 3, input_res, input_res)
106 | if input.is_cuda:
107 | model.cuda()
108 | else:
109 | model.cpu()
110 | with torch.no_grad():
111 | _ = model(input)
112 |
113 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))
114 |
115 | print(' + Number of FLOPs: %.5fG' % (total_flops / 1e9))
116 |
117 | return total_flops
118 |
--------------------------------------------------------------------------------
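A usage sketch for the parameter and FLOPs counters on a CIFAR-sized model (32x32 inputs); `input_res` defaults to 224 for ImageNet-sized models:

```python
import models.cifar as cifar_models
from compute_flops import print_model_param_nums, print_model_param_flops

model = cifar_models.vgg19_bn(num_classes=100)
print_model_param_nums(model)                                 # parameter count in millions
total_flops = print_model_param_flops(model, input_res=32)    # printed in GFLOPs, returned as a raw count
```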
/models/cifar/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import math
4 |
5 |
6 | __all__ = [
7 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
8 | 'vgg19_bn', 'vgg19',
9 | ]
10 |
11 |
12 | model_urls = {
13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
17 | }
18 |
19 |
20 | class VGG(nn.Module):
21 |
22 | def __init__(self, features, num_classes=1000):
23 | super(VGG, self).__init__()
24 | self.features = features
25 | self.classifier = nn.Linear(512, num_classes)
26 | self._initialize_weights()
27 |
28 | def forward(self, x, features_only=False):
29 | x = self.features(x)
30 | x = x.view(x.size(0), -1)
31 | if not features_only:
32 | x = self.classifier(x)
33 | return x
34 |
35 | def _initialize_weights(self):
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d):
38 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
39 | m.weight.data.normal_(0, math.sqrt(2. / n))
40 | if m.bias is not None:
41 | m.bias.data.zero_()
42 | elif isinstance(m, nn.BatchNorm2d):
43 | m.weight.data.fill_(1)
44 | m.bias.data.zero_()
45 | elif isinstance(m, nn.Linear):
46 | n = m.weight.size(1)
47 | m.weight.data.normal_(0, 0.01)
48 | m.bias.data.zero_()
49 |
50 |
51 | def make_layers(cfg, batch_norm=False):
52 | layers = []
53 | in_channels = 3
54 | for v in cfg:
55 | if v == 'M':
56 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
57 | else:
58 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
59 | if batch_norm:
60 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
61 | else:
62 | layers += [conv2d, nn.ReLU(inplace=True)]
63 | in_channels = v
64 | return nn.Sequential(*layers)
65 |
66 |
67 | cfg = {
68 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
69 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
70 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
71 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
72 | }
73 |
74 |
75 | def vgg11(**kwargs):
76 | """VGG 11-layer model (configuration "A")
77 |
78 | Args:
79 | pretrained (bool): If True, returns a model pre-trained on ImageNet
80 | """
81 | model = VGG(make_layers(cfg['A']), **kwargs)
82 | return model
83 |
84 |
85 | def vgg11_bn(**kwargs):
86 | """VGG 11-layer model (configuration "A") with batch normalization"""
87 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
88 | return model
89 |
90 |
91 | def vgg13(**kwargs):
92 | """VGG 13-layer model (configuration "B")
93 |
94 | Args:
95 | pretrained (bool): If True, returns a model pre-trained on ImageNet
96 | """
97 | model = VGG(make_layers(cfg['B']), **kwargs)
98 | return model
99 |
100 |
101 | def vgg13_bn(**kwargs):
102 | """VGG 13-layer model (configuration "B") with batch normalization"""
103 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
104 | return model
105 |
106 |
107 | def vgg16(**kwargs):
108 | """VGG 16-layer model (configuration "D")
109 |
110 | Args:
111 | pretrained (bool): If True, returns a model pre-trained on ImageNet
112 | """
113 | model = VGG(make_layers(cfg['D']), **kwargs)
114 | return model
115 |
116 |
117 | def vgg16_bn(**kwargs):
118 | """VGG 16-layer model (configuration "D") with batch normalization"""
119 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
120 | return model
121 |
122 |
123 | def vgg19(**kwargs):
124 | """VGG 19-layer model (configuration "E")
125 |
126 | Args:
127 | pretrained (bool): If True, returns a model pre-trained on ImageNet
128 | """
129 | model = VGG(make_layers(cfg['E']), **kwargs)
130 | return model
131 |
132 |
133 | def vgg19_bn(**kwargs):
134 | """VGG 19-layer model (configuration 'E') with batch normalization"""
135 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
136 | return model
137 |
--------------------------------------------------------------------------------
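A minimal sketch of the CIFAR VGG interface, including the `features_only` flag exposed by the forward pass:

```python
import torch
from models.cifar.vgg import vgg19_bn

model = vgg19_bn(num_classes=100)
x = torch.randn(4, 3, 32, 32)            # CIFAR-sized inputs

logits = model(x)                        # (4, 100) class scores
feats = model(x, features_only=True)     # (4, 512) pre-classifier features
print(logits.shape, feats.shape)
```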
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data
6 |
7 | ## Environment
8 |
9 | Linux, python 3.6+
10 |
11 | ## Setup
12 |
13 | ```bash
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ## Instructions
18 |
19 | Supported CNN architectures and datasets:
20 |
21 | | Dataset | Architecture (`ARCH`) |
22 | | ------------- |:-------------:|
23 | | CIFAR-10 | vgg19_bn, resnet110, resnet164, mobilenetv2, shufflenetv2 |
24 | | CIFAR-100 | vgg19_bn, resnet110, resnet164 |
25 | | ImageNet-1K | vgg19_bn, resnet50 |
26 |
27 |
28 | ### 1. Generate class groups
29 |
30 | For CIFAR-10/CIFAR-100:
31 | ```bash
32 | python3 group_selection.py \
33 | --arch $ARCH \
34 | --resume $pretrained_model \
35 | --dataset $DATASET \
36 | --ngroups $number_of_groups \
37 | --gpu_num $number_of_gpu
38 | ```
39 | For ImageNet-1K:
40 | ```bash
41 | python3 group_selection.py \
42 | --arch $ARCH \
43 | --dataset imagenet \
44 | --ngroups $number_of_groups \
45 | --gpu_num $number_of_gpu \
46 | --data /{path_to_imagenet_dataset}/
47 | ```
48 |
49 | Pruning candidates are now stored in `./prune_candidate_logs/`
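
For example, a concrete CIFAR-10 run with a pretrained `vgg19_bn` checkpoint might look like the sketch below; the checkpoint path, group count, and GPU count are placeholders, not values shipped with this repository:

```bash
python3 group_selection.py \
    --arch vgg19_bn \
    --resume ./checkpoints/vgg19_bn_cifar10.pth.tar \
    --dataset cifar10 \
    --ngroups 5 \
    --gpu_num 1
```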
50 |
51 | ### 2. Prune models
52 |
53 | For CIFAR-10/CIFAR-100:
54 | ```bash
55 | python3 prune_and_get_model.py \
56 | -a $ARCH \
57 | --dataset $DATASET \
58 | --resume $pretrained_model \
59 | -c ./prune_candidate_logs/ \
60 | -s ./{TO_SAVE_PRUNED_MODEL_DIR}/
61 | ```
62 | For ImageNet-1K:
63 | ```bash
64 | python3 prune_and_get_model.py \
65 | -a $ARCH \
66 | --dataset imagenet \
67 | -c ./prune_candidate_logs/ \
68 | -s ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
69 | --pretrained
70 | ```
71 |
72 | Pruned models are now saved in `./{TO_SAVE_PRUNED_MODEL_DIR}/$ARCH/`
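
Continuing the CIFAR-10 example above (the checkpoint path and the save directory `./pruned_models/` are placeholders):

```bash
python3 prune_and_get_model.py \
    -a vgg19_bn \
    --dataset cifar10 \
    --resume ./checkpoints/vgg19_bn_cifar10.pth.tar \
    -c ./prune_candidate_logs/ \
    -s ./pruned_models/
```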
73 |
74 | ### 3. Retrain pruned models
75 |
76 | For CIFAR-10/CIFAR-100:
77 | ```bash
78 | python3 retrain_grouped_model.py \
79 | -a $ARCH \
80 | --dataset $DATASET \
81 | --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
82 | --train_batch $batch_size \
83 | --epochs $number_of_epochs \
84 | --num_gpus $number_of_gpus
85 | ```
86 | For ImageNet-1K:
87 | ```bash
88 | python3 retrain_grouped_model.py \
89 | -a $ARCH \
90 | --dataset imagenet \
91 | --resume ./{TO_SAVE_PRUNED_MODEL_DIR}/ \
92 | --epochs $number_of_epochs \
93 | --num_gpus $number_of_gpus \
94 | --train_batch $batch_size \
95 | --data /{path_to_imagenet_dataset}/
96 | ```
97 |
98 | Retrained models are now saved in `./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/$ARCH/`
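
Continuing the CIFAR-10 example (the batch size, epoch count, and GPU count below are illustrative only, not recommended settings):

```bash
python3 retrain_grouped_model.py \
    -a vgg19_bn \
    --dataset cifar10 \
    --resume ./pruned_models/ \
    --train_batch 128 \
    --epochs 40 \
    --num_gpus 1
```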
99 |
100 | ### 4. Evaluate
101 |
102 | For CIFAR-10/CIFAR-100:
103 | ```bash
104 | python3 evaluate.py \
105 | -a $ARCH \
106 | --dataset=$DATASET \
107 | --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \
108 | --test-batch $batch_size
109 | ```
110 | For ImageNet-1K:
111 | ```bash
112 | python3 evaluate.py \
113 | -d imagenet \
114 | -a $ARCH \
115 | --retrained_dir ./{TO_SAVE_PRUNED_MODEL_DIR}_retrained/ \
116 | --data /{path_to_imagenet_dataset}/
117 | ```
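
To finish the CIFAR-10 example (the retrained-model directory and test batch size are placeholders):

```bash
python3 evaluate.py \
    -a vgg19_bn \
    --dataset=cifar10 \
    --retrained_dir ./pruned_models_retrained/ \
    --test-batch 256
```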
118 |
119 | ## Contributors
120 |
121 | Thanks to all of the main contributors to this repository:
122 |
123 | * [Brandon Hsieh](https://github.com/hsiehbrandon)
124 |
125 | * [Zhuang Liu](https://github.com/liuzhuang13)
126 |
127 | * [Kenan Jiang](https://github.com/Kenan-Jiang)
128 |
129 | * [Kehan Wang](https://github.com/Jason-Khan)
130 |
131 | * [Siyuan Zhuang](https://github.com/suquark)
132 |
133 | And many other contributors: [Zihao Fan](https://github.com/zihao-fan), [Hank O'Brien](https://github.com/hjobrien), [Yaoqing Yang](https://github.com/nsfzyzz), [Adarsh Karnati](https://github.com/akarnati11), [Jichan Chung](https://github.com/jichan3751), [Yingxin Kang](https://github.com/Miiira),
134 | [Balaji Veeramani](https://github.com/bveeramani), [Sahil Rao](https://github.com/sahilrao21).
135 |
136 |
137 |
138 |
139 | ## Citation
140 |
141 | ```text
142 | @inproceedings{wang2021sensAI,
143 | author = {Guanhua Wang and Zhuang Liu and Brandon Hsieh and Siyuan Zhuang and Joseph Gonzalez and Trevor Darrell and Ion Stoica},
144 | title = {{sensAI: ConvNets Decomposition via Class Parallelism for Fast Inference on Live Data}},
145 | booktitle = {Proceedings of Fourth Conference on Machine Learning and Systems (MLSys'21)},
146 | year = {2021}
147 | }
148 | ```
149 |
150 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import matplotlib.pyplot as plt
3 | import os
4 | import sys
5 | import numpy as np
6 |
7 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
8 |
9 | def savefig(fname, dpi=None):
10 | dpi = 150 if dpi == None else dpi
11 | plt.savefig(fname, dpi=dpi)
12 |
13 | def plot_overlap(logger, names=None):
14 | names = logger.names if names == None else names
15 | numbers = logger.numbers
16 | for _, name in enumerate(names):
17 | x = np.arange(len(numbers[name]))
18 | plt.plot(x, np.asarray(numbers[name]))
19 | return [logger.title + '(' + name + ')' for name in names]
20 |
21 | class Logger(object):
22 | '''Save training process to log file with simple plot function.'''
23 | def __init__(self, fpath, title=None, resume=False):
24 | self.file = None
25 | self.resume = resume
26 | self.title = '' if title == None else title
27 | if fpath is not None:
28 | if resume:
29 | self.file = open(fpath, 'r')
30 | name = self.file.readline()
31 | self.names = name.rstrip().split('\t')
32 | self.numbers = {}
33 | for _, name in enumerate(self.names):
34 | self.numbers[name] = []
35 |
36 | for numbers in self.file:
37 | numbers = numbers.rstrip().split('\t')
38 | for i in range(0, len(numbers)):
39 | self.numbers[self.names[i]].append(numbers[i])
40 | self.file.close()
41 | self.file = open(fpath, 'a')
42 | else:
43 | self.file = open(fpath, 'w')
44 |
45 | def set_names(self, names):
46 | if self.resume:
47 | pass
48 | # initialize numbers as empty list
49 | self.numbers = {}
50 | self.names = names
51 | for _, name in enumerate(self.names):
52 | self.file.write(name)
53 | self.file.write('\t')
54 | self.numbers[name] = []
55 | self.file.write('\n')
56 | self.file.flush()
57 |
58 |
59 | def append(self, numbers):
60 | assert len(self.names) == len(numbers), 'Numbers do not match names'
61 | for index, num in enumerate(numbers):
62 | self.file.write("{0:.6f}".format(num))
63 | self.file.write('\t')
64 | self.numbers[self.names[index]].append(num)
65 | self.file.write('\n')
66 | self.file.flush()
67 |
68 | def plot(self, names=None):
69 | names = self.names if names == None else names
70 | numbers = self.numbers
71 | for _, name in enumerate(names):
72 | x = np.arange(len(numbers[name]))
73 | plt.plot(x, np.asarray(numbers[name]))
74 | plt.legend([self.title + '(' + name + ')' for name in names])
75 | plt.grid(True)
76 |
77 | def close(self):
78 | if self.file is not None:
79 | self.file.close()
80 |
81 | class LoggerMonitor(object):
82 | '''Load and visualize multiple logs.'''
83 | def __init__ (self, paths):
84 |         '''paths is a dictionary of {name: filepath} pairs'''
85 | self.loggers = []
86 | for title, path in paths.items():
87 | logger = Logger(path, title=title, resume=True)
88 | self.loggers.append(logger)
89 |
90 | def plot(self, names=None):
91 | plt.figure()
92 | plt.subplot(121)
93 | legend_text = []
94 | for logger in self.loggers:
95 | legend_text += plot_overlap(logger, names)
96 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
97 | plt.grid(True)
98 |
99 | if __name__ == '__main__':
100 | # # Example
101 | # logger = Logger('test.txt')
102 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
103 |
104 | # length = 100
105 | # t = np.arange(length)
106 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
107 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
108 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
109 |
110 | # for i in range(0, length):
111 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
112 | # logger.plot()
113 |
114 | # Example: logger monitor
115 | paths = {
116 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
117 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
118 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
119 | }
120 |
121 | field = ['Valid Acc.']
122 |
123 | monitor = LoggerMonitor(paths)
124 | monitor.plot(names=field)
125 | savefig('test.eps')
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import matplotlib.pyplot as plt
3 | import os
4 | import sys
5 | import numpy as np
6 |
7 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
8 |
9 | def savefig(fname, dpi=None):
10 | dpi = 150 if dpi == None else dpi
11 | plt.savefig(fname, dpi=dpi)
12 |
13 | def plot_overlap(logger, names=None):
14 | names = logger.names if names == None else names
15 | numbers = logger.numbers
16 | for _, name in enumerate(names):
17 | x = np.arange(len(numbers[name]))
18 | plt.plot(x, np.asarray(numbers[name]))
19 | return [logger.title + '(' + name + ')' for name in names]
20 |
21 | class Logger(object):
22 | '''Save training process to log file with simple plot function.'''
23 | def __init__(self, fpath, title=None, resume=False):
24 | self.file = None
25 | self.resume = resume
26 | self.title = '' if title == None else title
27 | if fpath is not None:
28 | if resume:
29 | self.file = open(fpath, 'r')
30 | name = self.file.readline()
31 | self.names = name.rstrip().split('\t')
32 | self.numbers = {}
33 | for _, name in enumerate(self.names):
34 | self.numbers[name] = []
35 |
36 | for numbers in self.file:
37 | numbers = numbers.rstrip().split('\t')
38 | for i in range(0, len(numbers)):
39 | self.numbers[self.names[i]].append(numbers[i])
40 | self.file.close()
41 | self.file = open(fpath, 'a')
42 | else:
43 | self.file = open(fpath, 'w')
44 |
45 | def set_names(self, names):
46 | if self.resume:
47 | pass
48 | # initialize numbers as empty list
49 | self.numbers = {}
50 | self.names = names
51 | for _, name in enumerate(self.names):
52 | self.file.write(name)
53 | self.file.write('\t')
54 | self.numbers[name] = []
55 | self.file.write('\n')
56 | self.file.flush()
57 |
58 |
59 | def append(self, numbers):
60 | assert len(self.names) == len(numbers), 'Numbers do not match names'
61 | for index, num in enumerate(numbers):
62 | self.file.write("{0:.6f}".format(num))
63 | self.file.write('\t')
64 | self.numbers[self.names[index]].append(num)
65 | self.file.write('\n')
66 | self.file.flush()
67 |
68 | def plot(self, names=None):
69 | names = self.names if names == None else names
70 | numbers = self.numbers
71 | for _, name in enumerate(names):
72 | x = np.arange(len(numbers[name]))
73 | plt.plot(x, np.asarray(numbers[name]))
74 | plt.legend([self.title + '(' + name + ')' for name in names])
75 | plt.grid(True)
76 |
77 | def close(self):
78 | if self.file is not None:
79 | self.file.close()
80 |
81 | class LoggerMonitor(object):
82 | '''Load and visualize multiple logs.'''
83 | def __init__ (self, paths):
84 |         '''paths is a dictionary of {name: filepath} pairs'''
85 | self.loggers = []
86 | for title, path in paths.items():
87 | logger = Logger(path, title=title, resume=True)
88 | self.loggers.append(logger)
89 |
90 | def plot(self, names=None):
91 | plt.figure()
92 | plt.subplot(121)
93 | legend_text = []
94 | for logger in self.loggers:
95 | legend_text += plot_overlap(logger, names)
96 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
97 | plt.grid(True)
98 |
99 | if __name__ == '__main__':
100 | # # Example
101 | # logger = Logger('test.txt')
102 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
103 |
104 | # length = 100
105 | # t = np.arange(length)
106 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
107 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
108 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
109 |
110 | # for i in range(0, length):
111 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
112 | # logger.plot()
113 |
114 | # Example: logger monitor
115 | paths = {
116 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
117 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
118 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
119 | }
120 |
121 | field = ['Valid Acc.']
122 |
123 | monitor = LoggerMonitor(paths)
124 | monitor.plot(names=field)
125 | savefig('test.eps')
--------------------------------------------------------------------------------
/apoz_policy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import contextlib
4 | import torch.nn.functional as F
5 |
6 | def apoz_scoring(activation):
7 | """
8 | Calculate the Average Percentage of Zeros Score of the feature map activation layer output
9 | """
10 | activation = (activation.abs() <= 0.005).float()
11 | if activation.dim() == 4:
12 | featuremap_apoz_mat = activation.mean(dim=(0, 2, 3))
13 | elif activation.dim() == 2:
14 | featuremap_apoz_mat = activation.mean(dim=(0, 1))
15 | else:
16 | raise ValueError(
17 | f"activation_channels_avg: Unsupported shape: {activation.shape}")
18 | return featuremap_apoz_mat.mul(100).cpu()
19 |
20 |
21 | def avg_scoring(activation):
22 | activation = activation.abs()
23 | if activation.dim() == 4:
24 | featuremap_avg_mat = activation.mean(dim=(0, 2, 3))
25 | elif activation.dim() == 2:
26 | featuremap_avg_mat = activation.mean(dim=(0, 1))
27 | else:
28 | raise ValueError(
29 | f"activation_channels_avg: Unsupported shape: {activation.shape}")
30 | return featuremap_avg_mat.cpu()
31 |
32 |
33 | class ActivationRecord:
34 | def __init__(self, model, arch):
35 | self.apoz_scores_by_layer = []
36 | self.avg_scores_by_layer = []
37 | self.num_batches = 0
38 | self.layer_idx = 0
39 | self._candidates_by_layer = None
40 | self._model = model
41 | # switch to evaluate mode
42 | self._model.eval()
43 | self._model.apply(lambda m: m.register_forward_hook(self._hook))
44 | self.arch = arch
45 |
46 | def parse_activation(self, feature_map):
47 | apoz_score = apoz_scoring(feature_map).numpy()
48 | avg_score = avg_scoring(feature_map).numpy()
49 |
50 | if self.num_batches == 0:
51 | self.apoz_scores_by_layer.append(apoz_score)
52 | self.avg_scores_by_layer.append(avg_score)
53 | else:
54 | self.apoz_scores_by_layer[self.layer_idx] += apoz_score
55 | self.avg_scores_by_layer[self.layer_idx] += avg_score
56 | self.layer_idx += 1
57 |
58 | def __enter__(self):
59 | return self
60 |
61 | def __exit__(self, exception_type, exception_value, traceback):
62 | for score in self.apoz_scores_by_layer:
63 | score /= self.num_batches
64 | for score in self.avg_scores_by_layer:
65 | score /= self.num_batches
66 |
67 | def record_batch(self, *args, **kwargs):
68 | # reset layer index
69 | self.layer_idx = 0
70 | with torch.no_grad():
71 | # output is not used
72 | _ = self._model(*args, **kwargs)
73 | self.num_batches += 1
74 |
75 | def _hook(self, module, input, output):
76 |         """Forward hook: record ReLU activations (for shufflenetv2, ReLU applied to BatchNorm2d outputs)."""
77 | if self.arch == "shufflenetv2":
78 | if module.__class__.__name__ == 'BatchNorm2d':
79 | self.parse_activation(F.relu(output))
80 | else:
81 | if module.__class__.__name__ == 'ReLU':
82 | self.parse_activation(output)
83 |
84 | def generate_pruned_candidates(self):
85 | num_layers = len(self.apoz_scores_by_layer)
86 | thresholds = [73] * num_layers
87 | avg_thresholds = [0.01] * num_layers
88 |
89 | candidates_by_layer = []
90 | for layer_idx, (apoz_scores, avg_scores) in enumerate(zip(self.apoz_scores_by_layer, self.avg_scores_by_layer)):
91 | if self.arch == "mobilenetv2":
92 | apoz_scores = torch.Tensor(apoz_scores)
93 | avg_scores = torch.Tensor(avg_scores)
94 | avg_candidates = [idx for idx, score in enumerate(
95 | avg_scores) if score >= avg_thresholds[layer_idx]]
96 | candidates = [(idx,float(score)) for idx, score in enumerate(apoz_scores) if score >= thresholds[layer_idx]]
97 | candidates = sorted(candidates, key = lambda x: x[1])[:int(len(candidates)/2)]
98 | candidates = [x[0] for x in candidates]
99 | else:
100 | apoz_scores = torch.Tensor(apoz_scores)
101 | avg_scores = torch.Tensor(avg_scores)
102 | avg_candidates = [idx for idx, score in enumerate(
103 | avg_scores) if score >= avg_thresholds[layer_idx]]
104 | candidates = [x[0] for x in apoz_scores.gt(
105 | thresholds[layer_idx]).nonzero().tolist()]
106 | difference_candidates = list(
107 | set(candidates).difference(set(avg_candidates)))
108 | candidates_by_layer.append(difference_candidates)
109 | """
110 | DEBUG: Printing out remaining neuron IDs
111 | all_neuron = [idx for idx, score in enumerate(avg_scores)]
112 | remaining = list(set(all_neuron)-set(difference_candidates))
113 | print("\nThose remaining neuron index for layer ", layer_idx)
114 | print(remaining)
115 | """
116 | print(
117 | f"Total pruned candidates: {sum(len(l) for l in candidates_by_layer)}")
118 | return candidates_by_layer
119 |
--------------------------------------------------------------------------------
/models/cifar/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | __all__ = ['densenet']
8 |
9 |
10 | from torch.autograd import Variable
11 |
12 | class Bottleneck(nn.Module):
13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
14 | super(Bottleneck, self).__init__()
15 | planes = expansion * growthRate
16 | self.bn1 = nn.BatchNorm2d(inplanes)
17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3,
20 | padding=1, bias=False)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.dropRate = dropRate
23 |
24 | def forward(self, x):
25 | out = self.bn1(x)
26 | out = self.relu(out)
27 | out = self.conv1(out)
28 | out = self.bn2(out)
29 | out = self.relu(out)
30 | out = self.conv2(out)
31 | if self.dropRate > 0:
32 | out = F.dropout(out, p=self.dropRate, training=self.training)
33 |
34 | out = torch.cat((x, out), 1)
35 |
36 | return out
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
41 | super(BasicBlock, self).__init__()
42 | planes = expansion * growthRate
43 | self.bn1 = nn.BatchNorm2d(inplanes)
44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
45 | padding=1, bias=False)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.dropRate = dropRate
48 |
49 | def forward(self, x):
50 | out = self.bn1(x)
51 | out = self.relu(out)
52 | out = self.conv1(out)
53 | if self.dropRate > 0:
54 | out = F.dropout(out, p=self.dropRate, training=self.training)
55 |
56 | out = torch.cat((x, out), 1)
57 |
58 | return out
59 |
60 |
61 | class Transition(nn.Module):
62 | def __init__(self, inplanes, outplanes):
63 | super(Transition, self).__init__()
64 | self.bn1 = nn.BatchNorm2d(inplanes)
65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1,
66 | bias=False)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | def forward(self, x):
70 | out = self.bn1(x)
71 | out = self.relu(out)
72 | out = self.conv1(out)
73 | out = F.avg_pool2d(out, 2)
74 | return out
75 |
76 |
77 | class DenseNet(nn.Module):
78 |
79 | def __init__(self, depth=22, block=Bottleneck,
80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2):
81 | super(DenseNet, self).__init__()
82 |
83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
84 |         n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6
85 |
86 | self.growthRate = growthRate
87 | self.dropRate = dropRate
88 |
89 | # self.inplanes is a global variable used across multiple
90 | # helper functions
91 | self.inplanes = growthRate * 2
92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
93 | bias=False)
94 | self.dense1 = self._make_denseblock(block, n)
95 | self.trans1 = self._make_transition(compressionRate)
96 | self.dense2 = self._make_denseblock(block, n)
97 | self.trans2 = self._make_transition(compressionRate)
98 | self.dense3 = self._make_denseblock(block, n)
99 | self.bn = nn.BatchNorm2d(self.inplanes)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.avgpool = nn.AvgPool2d(8)
102 | self.fc = nn.Linear(self.inplanes, num_classes)
103 |
104 | # Weight initialization
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
108 | m.weight.data.normal_(0, math.sqrt(2. / n))
109 | elif isinstance(m, nn.BatchNorm2d):
110 | m.weight.data.fill_(1)
111 | m.bias.data.zero_()
112 |
113 | def _make_denseblock(self, block, blocks):
114 | layers = []
115 | for i in range(blocks):
116 | # Currently we fix the expansion ratio as the default value
117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate))
118 | self.inplanes += self.growthRate
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def _make_transition(self, compressionRate):
123 | inplanes = self.inplanes
124 | outplanes = int(math.floor(self.inplanes // compressionRate))
125 | self.inplanes = outplanes
126 | return Transition(inplanes, outplanes)
127 |
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 |
132 | x = self.trans1(self.dense1(x))
133 | x = self.trans2(self.dense2(x))
134 | x = self.dense3(x)
135 | x = self.bn(x)
136 | x = self.relu(x)
137 |
138 | x = self.avgpool(x)
139 | x = x.view(x.size(0), -1)
140 | x = self.fc(x)
141 |
142 | return x
143 |
144 |
145 | def densenet(**kwargs):
146 | """
147 |     Constructs a DenseNet model.
148 | """
149 | return DenseNet(**kwargs)
--------------------------------------------------------------------------------
/models/cifar/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.conv1 = conv3x3(inplanes, planes, stride)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.relu = nn.ReLU(inplace=True)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.downsample = downsample
24 | self.stride = stride
25 |
26 | def forward(self, x):
27 | residual = x
28 |
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * 4)
57 | self.relu = nn.ReLU(inplace=True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x):
62 | residual = x
63 |
64 | out = self.conv1(x)
65 | out = self.bn1(out)
66 | out = self.relu(out)
67 |
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv3(out)
73 | out = self.bn3(out)
74 |
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 |
84 | class ResNet(nn.Module):
85 |
86 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
87 | super(ResNet, self).__init__()
88 | # Model type specifies number of layers for CIFAR-10 model
89 | if block_name.lower() == 'basicblock':
90 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
91 | n = (depth - 2) // 6
92 | block = BasicBlock
93 | elif block_name.lower() == 'bottleneck':
94 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
95 | n = (depth - 2) // 9
96 | block = Bottleneck
97 | else:
98 |         raise ValueError('block_name should be BasicBlock or Bottleneck')
99 |
100 |
101 | self.inplanes = 16
102 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
103 | bias=False)
104 | self.bn1 = nn.BatchNorm2d(16)
105 | self.relu = nn.ReLU(inplace=True)
106 | self.layer1 = self._make_layer(block, 16, n)
107 | self.layer2 = self._make_layer(block, 32, n, stride=2)
108 | self.layer3 = self._make_layer(block, 64, n, stride=2)
109 | self.avgpool = nn.AvgPool2d(8)
110 | self.fc = nn.Linear(64 * block.expansion, num_classes)
111 |
112 | for m in self.modules():
113 | if isinstance(m, nn.Conv2d):
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
115 | m.weight.data.normal_(0, math.sqrt(2. / n))
116 | elif isinstance(m, nn.BatchNorm2d):
117 | m.weight.data.fill_(1)
118 | m.bias.data.zero_()
119 |
120 | def _make_layer(self, block, planes, blocks, stride=1):
121 | downsample = None
122 | if stride != 1 or self.inplanes != planes * block.expansion:
123 | downsample = nn.Sequential(
124 | nn.Conv2d(self.inplanes, planes * block.expansion,
125 | kernel_size=1, stride=stride, bias=False),
126 | nn.BatchNorm2d(planes * block.expansion),
127 | )
128 |
129 | layers = []
130 | layers.append(block(self.inplanes, planes, stride, downsample))
131 | self.inplanes = planes * block.expansion
132 | for i in range(1, blocks):
133 | layers.append(block(self.inplanes, planes))
134 |
135 | return nn.Sequential(*layers)
136 |
137 | def forward(self, x, features_only=False):
138 | x = self.conv1(x)
139 | x = self.bn1(x)
140 | x = self.relu(x) # 32x32
141 |
142 | x = self.layer1(x) # 32x32
143 | x = self.layer2(x) # 16x16
144 | x = self.layer3(x) # 8x8
145 |
146 | x = self.avgpool(x)
147 | x = x.view(x.size(0), -1)
148 | if features_only:
149 | return x
150 | x = self.fc(x)
151 |
152 | return x
153 |
154 |
155 | def resnet110(**kwargs):
156 | """
157 | Constructs a ResNet-110 model.
158 | """
159 | return ResNet(depth=110, block_name='bottleneck', **kwargs)
160 |
161 |
162 | def resnet164(**kwargs):
163 | """
164 | Constructs a ResNet-164 model.
165 | """
166 | return ResNet(depth=164, block_name='bottleneck', **kwargs)
167 |
168 |
169 | __all__ = ['resnet110', 'resnet164']
170 |
--------------------------------------------------------------------------------
/models/cifar/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | """
3 | Creates a ResNeXt Model as defined in:
4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
5 | Aggregated residual transformations for deep neural networks.
6 | arXiv preprint arXiv:1611.05431.
7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
8 | """
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn import init
12 |
13 | __all__ = ['resnext']
14 |
15 | class ResNeXtBottleneck(nn.Module):
16 | """
17 |     ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
18 | """
19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
20 | """ Constructor
21 | Args:
22 | in_channels: input channel dimensionality
23 | out_channels: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | cardinality: num of convolution groups.
26 | widen_factor: factor to reduce the input dimensionality before convolution.
27 | """
28 | super(ResNeXtBottleneck, self).__init__()
29 | D = cardinality * out_channels // widen_factor
30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
31 | self.bn_reduce = nn.BatchNorm2d(D)
32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
33 | self.bn = nn.BatchNorm2d(D)
34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
35 | self.bn_expand = nn.BatchNorm2d(out_channels)
36 |
37 | self.shortcut = nn.Sequential()
38 | if in_channels != out_channels:
39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))
41 |
42 | def forward(self, x):
43 | bottleneck = self.conv_reduce.forward(x)
44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True)
45 | bottleneck = self.conv_conv.forward(bottleneck)
46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True)
47 | bottleneck = self.conv_expand.forward(bottleneck)
48 | bottleneck = self.bn_expand.forward(bottleneck)
49 | residual = self.shortcut.forward(x)
50 | return F.relu(residual + bottleneck, inplace=True)
51 |
52 |
53 | class CifarResNeXt(nn.Module):
54 | """
55 | ResNext optimized for the Cifar dataset, as specified in
56 | https://arxiv.org/pdf/1611.05431.pdf
57 | """
58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
59 | """ Constructor
60 | Args:
61 | cardinality: number of convolution groups.
62 | depth: number of layers.
63 | num_classes: number of classes
64 | widen_factor: factor to adjust the channel dimensionality
65 | """
66 | super(CifarResNeXt, self).__init__()
67 | self.cardinality = cardinality
68 | self.depth = depth
69 | self.block_depth = (self.depth - 2) // 9
70 | self.widen_factor = widen_factor
71 | self.num_classes = num_classes
72 | self.output_size = 64
73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]
74 |
75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
76 | self.bn_1 = nn.BatchNorm2d(64)
77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
80 | self.classifier = nn.Linear(1024, num_classes)
81 | init.kaiming_normal(self.classifier.weight)
82 |
83 | for key in self.state_dict():
84 | if key.split('.')[-1] == 'weight':
85 | if 'conv' in key:
86 | init.kaiming_normal(self.state_dict()[key], mode='fan_out')
87 | if 'bn' in key:
88 | self.state_dict()[key][...] = 1
89 | elif key.split('.')[-1] == 'bias':
90 | self.state_dict()[key][...] = 0
91 |
92 | def block(self, name, in_channels, out_channels, pool_stride=2):
93 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
94 | Args:
95 | name: string name of the current block.
96 | in_channels: number of input channels
97 | out_channels: number of output channels
98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
99 | Returns: a Module consisting of n sequential bottlenecks.
100 | """
101 | block = nn.Sequential()
102 | for bottleneck in range(self.block_depth):
103 | name_ = '%s_bottleneck_%d' % (name, bottleneck)
104 | if bottleneck == 0:
105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
106 | self.widen_factor))
107 | else:
108 | block.add_module(name_,
109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor))
110 | return block
111 |
112 | def forward(self, x):
113 | x = self.conv_1_3x3.forward(x)
114 | x = F.relu(self.bn_1.forward(x), inplace=True)
115 | x = self.stage_1.forward(x)
116 | x = self.stage_2.forward(x)
117 | x = self.stage_3.forward(x)
118 | x = F.avg_pool2d(x, 8, 1)
119 | x = x.view(-1, 1024)
120 | return self.classifier(x)
121 |
122 | def resnext(**kwargs):
123 | """Constructs a ResNeXt.
124 | """
125 | model = CifarResNeXt(**kwargs)
126 | return model
--------------------------------------------------------------------------------
/models/cifar/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 |
3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ShuffleBlock(nn.Module):
11 | def __init__(self, groups=2):
12 | super(ShuffleBlock, self).__init__()
13 | self.groups = groups
14 |
15 | def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 | N, C, H, W = x.size()
18 | g = self.groups
19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
20 |
21 |
22 | class SplitBlock(nn.Module):
23 | def __init__(self, ratio):
24 | super(SplitBlock, self).__init__()
25 | self.ratio = ratio
26 |
27 | def forward(self, x):
28 | c = int(x.size(1) * self.ratio)
29 | return x[:, :c, :, :], x[:, c:, :, :]
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | def __init__(self, in_channels, split_ratio=0.5):
34 | super(BasicBlock, self).__init__()
35 | self.split = SplitBlock(split_ratio)
36 | in_channels = int(in_channels * split_ratio)
37 | self.conv1 = nn.Conv2d(in_channels, in_channels,
38 | kernel_size=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(in_channels)
40 | self.conv2 = nn.Conv2d(in_channels, in_channels,
41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 | self.bn2 = nn.BatchNorm2d(in_channels)
43 | self.conv3 = nn.Conv2d(in_channels, in_channels,
44 | kernel_size=1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(in_channels)
46 | self.shuffle = ShuffleBlock()
47 |
48 | def forward(self, x):
49 | x1, x2 = self.split(x)
50 | out = F.relu(self.bn1(self.conv1(x2)))
51 | out = self.bn2(self.conv2(out))
52 | out = F.relu(self.bn3(self.conv3(out)))
53 | out = torch.cat([x1, out], 1)
54 | out = self.shuffle(out)
55 | return out
56 |
57 |
58 | class DownBlock(nn.Module):
59 | def __init__(self, in_channels, out_channels):
60 | super(DownBlock, self).__init__()
61 | mid_channels = out_channels // 2
62 | # left
63 | self.conv1 = nn.Conv2d(in_channels, in_channels,
64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
65 | self.bn1 = nn.BatchNorm2d(in_channels)
66 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
67 | kernel_size=1, bias=False)
68 | self.bn2 = nn.BatchNorm2d(mid_channels)
69 | # right
70 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
71 | kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(mid_channels)
73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
75 | self.bn4 = nn.BatchNorm2d(mid_channels)
76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
77 | kernel_size=1, bias=False)
78 | self.bn5 = nn.BatchNorm2d(mid_channels)
79 |
80 | self.shuffle = ShuffleBlock()
81 |
82 | def forward(self, x):
83 | # left
84 | out1 = self.bn1(self.conv1(x))
85 | out1 = F.relu(self.bn2(self.conv2(out1)))
86 | # right
87 | out2 = F.relu(self.bn3(self.conv3(x)))
88 | out2 = self.bn4(self.conv4(out2))
89 | out2 = F.relu(self.bn5(self.conv5(out2)))
90 | # concat
91 | out = torch.cat([out1, out2], 1)
92 | out = self.shuffle(out)
93 | return out
94 |
95 |
96 | class ShuffleNetV2(nn.Module):
97 | def __init__(self, net_size):
98 | super(ShuffleNetV2, self).__init__()
99 | out_channels = configs[net_size]['out_channels']
100 | num_blocks = configs[net_size]['num_blocks']
101 |
102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
103 | stride=1, padding=1, bias=False)
104 | self.bn1 = nn.BatchNorm2d(24)
105 | self.in_channels = 24
106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
110 | kernel_size=1, stride=1, padding=0, bias=False)
111 | self.bn2 = nn.BatchNorm2d(out_channels[3])
112 | self.linear = nn.Linear(out_channels[3], 10)
113 |
114 | def _make_layer(self, out_channels, num_blocks):
115 | layers = [DownBlock(self.in_channels, out_channels)]
116 | for i in range(num_blocks):
117 | layers.append(BasicBlock(out_channels))
118 | self.in_channels = out_channels
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x, features_only=False):
122 | out = F.relu(self.bn1(self.conv1(x)))
123 | # out = F.max_pool2d(out, 3, stride=2, padding=1)
124 | out = self.layer1(out)
125 | out = self.layer2(out)
126 | out = self.layer3(out)
127 | out = F.relu(self.bn2(self.conv2(out)))
128 | out = F.avg_pool2d(out, 4)
129 | out = out.view(out.size(0), -1)
130 | if not features_only:
131 | out = self.linear(out)
132 | return out
133 |
134 |
135 | configs = {
136 | 0.5: {
137 | 'out_channels': (48, 96, 192, 1024),
138 | 'num_blocks': (3, 7, 3)
139 | },
140 |
141 | 1: {
142 | 'out_channels': (116, 232, 464, 1024),
143 | 'num_blocks': (3, 7, 3)
144 | },
145 | 1.5: {
146 | 'out_channels': (176, 352, 704, 1024),
147 | 'num_blocks': (3, 7, 3)
148 | },
149 | 2: {
150 | 'out_channels': (224, 488, 976, 2048),
151 | 'num_blocks': (3, 7, 3)
152 | }
153 | }
154 |
155 |
156 | def shufflenetv2(**kwargs):
157 | model = ShuffleNetV2(1)
158 | return model
159 |
160 |
--------------------------------------------------------------------------------
/models/imagenet/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | """
3 | Creates a ResNeXt Model as defined in:
4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
5 | Aggregated residual transformations for deep neural networks.
6 | arXiv preprint arXiv:1611.05431.
7 | import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
8 | """
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | import torch
14 |
15 | __all__ = ['resnext50', 'resnext101', 'resnext152']
16 |
17 | class Bottleneck(nn.Module):
18 | """
19 |     ResNeXt bottleneck type C
20 | """
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
24 | """ Constructor
25 | Args:
26 | inplanes: input channel dimensionality
27 | planes: output channel dimensionality
28 | baseWidth: base width.
29 | cardinality: num of convolution groups.
30 | stride: conv stride. Replaces pooling layer.
31 | """
32 | super(Bottleneck, self).__init__()
33 |
34 | D = int(math.floor(planes * (baseWidth / 64)))
35 | C = cardinality
36 |
37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
38 | self.bn1 = nn.BatchNorm2d(D*C)
39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
40 | self.bn2 = nn.BatchNorm2d(D*C)
41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
42 | self.bn3 = nn.BatchNorm2d(planes * 4)
43 | self.relu = nn.ReLU(inplace=True)
44 |
45 | self.downsample = downsample
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv3(out)
59 | out = self.bn3(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class ResNeXt(nn.Module):
71 | """
72 | ResNext optimized for the ImageNet dataset, as specified in
73 | https://arxiv.org/pdf/1611.05431.pdf
74 | """
75 | def __init__(self, baseWidth, cardinality, layers, num_classes):
76 | """ Constructor
77 | Args:
78 | baseWidth: baseWidth for ResNeXt.
79 | cardinality: number of convolution groups.
80 | layers: config of layers, e.g., [3, 4, 6, 3]
81 | num_classes: number of classes
82 | """
83 | super(ResNeXt, self).__init__()
84 | block = Bottleneck
85 |
86 | self.cardinality = cardinality
87 | self.baseWidth = baseWidth
88 | self.num_classes = num_classes
89 | self.inplanes = 64
90 | self.output_size = 64
91 |
92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(block, 64, layers[0])
97 | self.layer2 = self._make_layer(block, 128, layers[1], 2)
98 | self.layer3 = self._make_layer(block, 256, layers[2], 2)
99 | self.layer4 = self._make_layer(block, 512, layers[3], 2)
100 | self.avgpool = nn.AvgPool2d(7)
101 | self.fc = nn.Linear(512 * block.expansion, num_classes)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
113 | Args:
114 | block: block type used to construct ResNext
115 | planes: number of output channels (need to multiply by block.expansion)
116 | blocks: number of blocks to be built
117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
118 | Returns: a Module consisting of n sequential bottlenecks.
119 | """
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool1(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | x = self.fc(x)
148 |
149 | return x
150 |
151 |
152 | def resnext50(baseWidth, cardinality):
153 | """
154 | Construct ResNeXt-50.
155 | """
156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000)
157 | return model
158 |
159 |
160 | def resnext101(baseWidth, cardinality):
161 | """
162 | Construct ResNeXt-101.
163 | """
164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000)
165 | return model
166 |
167 |
168 | def resnext152(baseWidth, cardinality):
169 | """
170 | Construct ResNeXt-152.
171 | """
172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000)
173 | return model
174 |
--------------------------------------------------------------------------------
/group_selection.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import torch
5 | from torch import nn
6 | import torch.backends.cudnn as cudnn
7 | import load_model
8 | from tqdm import tqdm
9 | import torchvision.transforms as transforms
10 | import torchvision.datasets as datasets
11 | import numpy as np
12 | import subprocess as sp
13 | import os
14 |
15 | from even_k_means import kmeans_lloyd
16 |
17 | parser = argparse.ArgumentParser(
18 | description='PyTorch CIFAR10/100/Imagenet Generate Group Info')
19 | # Datasets
20 | parser.add_argument('-d', '--dataset', required=True, type=str)
21 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
22 | help='number of data loading workers (default: 4)')
23 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
24 | help='path to latest checkpoint (default: none)')
25 | parser.add_argument('--data', default='/home/ubuntu/imagenet', required=False, type=str,
26 | help='location of the imagenet dataset that includes train/val')
27 | # Architecture
28 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet20',
29 | #choices=load_model.model_arches('cifar'),
30 | help='model architecture: ' +
31 | ' | '.join(load_model.model_arches('cifar')) +
32 |                     ' (default: resnet20)')
33 | parser.add_argument('-n', '--ngroups', required=True, type=int, metavar='N',
34 | help='number of groups')
35 | parser.add_argument('-g', '--gpu_num', default=1, type=int,
36 | help='number of gpus')
37 |
38 | # Miscs
39 | parser.add_argument('--seed', type=int, default=42, help='manual seed')
40 | args = parser.parse_args()
41 | use_cuda = torch.cuda.is_available()
42 |
43 | # Random seed
44 | torch.manual_seed(args.seed)
45 | if use_cuda:
46 | torch.cuda.manual_seed_all(args.seed)
47 |
48 | def main():
49 | print('==> Preparing dataset %s' % args.dataset)
50 | resultExist = os.path.exists("./prune_candidate_logs")
51 | if resultExist:
52 | rm_cmd = 'rm -rf ./prune_candidate_logs'
53 |         sp.Popen(rm_cmd, shell=True).wait()  # wait so the mkdir below cannot race with the rm
54 |     mkdir_cmd = 'mkdir ./prune_candidate_logs'
55 |     sp.Popen(mkdir_cmd, shell=True).wait()
56 | # cifar10/100 group selection
57 | if args.dataset in ['cifar10', 'cifar100']:
58 | if args.dataset == 'cifar10':
59 | dataset_loader = datasets.CIFAR10
60 | elif args.dataset == 'cifar100':
61 | dataset_loader = datasets.CIFAR100
62 |
63 | dataset = dataset_loader(
64 | root='./data',
65 | download=True,
66 | train=True,
67 | transform=transforms.Compose([
68 | transforms.ToTensor(),
69 | transforms.Normalize(
70 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
71 | ]))
72 | data_loader = torch.utils.data.DataLoader(
73 | dataset,
74 | batch_size=1000,
75 | num_workers=args.workers,
76 | pin_memory=False)
77 |
78 | model = load_model.load_pretrain_model(
79 | args.arch, 'cifar', args.resume, len(dataset.classes), use_cuda)
80 |
81 | all_features = []
82 | all_targets = []
83 |
84 | model.eval()
85 | print('\nMake a test run to generate groups. \n Using training set.\n')
86 | with tqdm(total=len(data_loader)) as bar:
87 | for batch_idx, (inputs, targets) in enumerate(data_loader):
88 | bar.update()
89 | if use_cuda:
90 | inputs = inputs.cuda()
91 | with torch.no_grad():
92 | features = model(inputs, features_only=True)
93 | all_features.append(features)
94 | all_targets.append(targets)
95 |
96 | all_features = torch.cat(all_features)
97 | all_targets = torch.cat(all_targets)
98 |
99 | groups = kmeans_grouping(all_features, all_targets,
100 | args.ngroups, same_group_size=True)
101 | print("groups: ", groups)
102 | print("\n====================== Grouping Result ========================\n")
103 | process_list = [None for _ in range(args.gpu_num)]
104 | for i, group in enumerate(groups):
105 | if process_list[i % args.gpu_num]:
106 | process_list[i % args.gpu_num].wait()
107 | print(f"Group #{i}: {' '.join(str(idx) for idx in group)}")
108 | exec_cmd = 'python3 get_prune_candidates.py' +\
109 | ' -a %s' % args.arch + ' -d %s' % args.dataset + ' --resume ./%s' % args.resume + \
110 | ' --grouped ' + str(group)[1:-1].replace(",", "") + ' --group_number %d' % i + ' --gpu_num %d' % (i % args.gpu_num)
111 | process_list[i % args.gpu_num] = sp.Popen(exec_cmd, shell=True)
112 |
113 | np.save(open("prune_candidate_logs/grouping_config.npy", "wb"), groups)
114 |
115 | # imagenet group selection
116 | elif args.dataset == 'imagenet':
117 | num_gpus = args.gpu_num
118 | num_groups = args.ngroups
119 | group_size = 1000 // num_groups
120 | groups = [[i for i in range((j) * group_size, (j+1) * group_size)] for j in range(num_groups) ]
121 | process_list = [None for _ in range(num_gpus)]
122 | for i, group in enumerate(groups):
123 | if process_list[i % num_gpus]:
124 | process_list[i % num_gpus].wait()
125 | exec_cmd = 'python3 imagenet_activations.py ' +\
126 | ' --data %s' % args.data +\
127 | ' --gpu %d' % (i % num_gpus) +\
128 | ' --arch %s' % args.arch + ' --evaluate --pretrained --group %s' % (' '.join(str(digit) for digit in group)) + \
129 | ' --name %s' % (str(i))
130 | process_list[i % num_gpus] = sp.Popen(exec_cmd, shell=True)
131 | # Save the grouping class index partition information
132 | np.save(open("prune_candidate_logs/grouping_config.npy", "wb"), groups)
133 | else:
134 | raise NotImplementedError(f"There's no support for '{args.dataset}' dataset.")
135 |
136 | def kmeans_grouping(features, targets, n_groups, same_group_size=True):
137 | class_indices = targets.unique().sort().values
138 | mean_vectors = []
139 | for t in class_indices:
140 | mean_vec = features[targets == t.item(), :].mean(dim=0)
141 | mean_vectors.append(mean_vec.cpu().numpy())
142 | X = np.asarray(mean_vectors)
143 | class_indices = class_indices.cpu().numpy()
144 | assert X.ndim == 2
145 | best_labels, best_inertia, best_centers, _ = kmeans_lloyd(
146 | X, None, n_groups, verbose=True,
147 | same_cluster_size=same_group_size,
148 | random_state=args.seed,
149 | tol=1e-6)
150 | groups = []
151 | for i in range(n_groups):
152 | groups.append(class_indices[best_labels == i].tolist())
153 | return groups
154 |
155 | if __name__ == '__main__':
156 | main()
157 |
--------------------------------------------------------------------------------
/imagenet_evaluate_grouped.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 | import sys
6 | import glob
7 | import re
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.distributed as dist
15 | import torch.optim
16 | import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | import torchvision.transforms as transforms
20 | import imagenet_dataset as datasets
21 | import torchvision.models as models
22 |
23 | from compute_flops import print_model_param_flops
24 |
25 | def main_worker(gpu, ngpus_per_node, args):
26 | global best_acc1
27 | args.gpu = gpu
28 | args.evaluate = True
29 |
30 | cudnn.benchmark = True
31 | model_list = []
32 | num_flops = []
33 | avg_num_param = 0.0
34 | args.checkpoint = os.path.dirname(args.retrained_dir)
35 | criterion = nn.CrossEntropyLoss()
36 |
37 | # load groups
38 | file_names = [f for f in glob.glob(args.retrained_dir + "/" + args.arch + "/*.pth", recursive=False)]
39 | group_id_list = [filename_to_index(filename) for filename in file_names]
40 | group_config = np.load(open(args.retrained_dir + '/grouping_config.npy', "rb"))
41 |
42 | permutation_indices = [] # To allow for arbitrary grouping
43 | for group_id in group_id_list:
44 | permutation_indices.extend(group_config[int(group_id[0])])
45 | permutation_indices = torch.eye(1000)[permutation_indices].cuda(args.gpu)
46 |
47 | # load models
48 | for index, (group_id, file_name) in enumerate(zip(group_id_list, file_names)):
49 | model = torch.load(file_name)
50 | model = model.cuda(index % ngpus_per_node)
51 | avg_num_param += sum(p.numel() for p in model.parameters())/1000000.0
52 |         print('Group {} model has total params: {:.2f}M'.format(group_id, sum(p.numel() for p in model.parameters())/1000000.0))
53 | model_list.append(model)
54 |
55 | # generate dataloader
56 | valdir = os.path.join(args.data, 'val')
57 | traindir = os.path.join(args.data, 'train')
58 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
59 | std=[0.229, 0.224, 0.225])
60 |
61 | val_loader = torch.utils.data.DataLoader(
62 | datasets.ImageFolder(valdir, transforms.Compose([
63 | transforms.Resize(256),
64 | transforms.CenterCrop(224),
65 | transforms.ToTensor(),
66 | normalize,
67 | ])),
68 | batch_size=args.batch_size, shuffle=False,
69 | num_workers=args.workers, pin_memory=True)
70 |
71 | if args.evaluate:
72 | validate(val_loader, model_list, criterion, args, permutation_indices, ngpus_per_node)
73 | return
74 |
75 | def validate(val_loader, model_list, criterion, args, p_indices, gpu_nums):
76 | batch_time = AverageMeter('Time', ':6.3f')
77 | losses = AverageMeter('Loss', ':.4e')
78 | top1 = AverageMeter('Acc@1', ':6.2f')
79 | top5 = AverageMeter('Acc@5', ':6.2f')
80 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
81 | prefix='Test: ')
82 |
83 | # switch to evaluate mode
84 | for model in model_list:
85 | model.eval()
86 |
87 | with torch.no_grad():
88 | end = time.time()
89 | for i, (input, target) in enumerate(val_loader):
90 | input_list = []
91 | for index in range(gpu_nums):
92 | input = input.cuda(index)
93 | input_list.append(input)
94 | target = target.cuda(0) ### send same input and target to each gpu
95 |
96 | # compute output
97 | output_list = torch.Tensor().cuda(0)
98 | for index, model in enumerate(model_list):
99 | temp = model(input_list[index%gpu_nums])
100 | output = nn.Softmax(dim=1)(temp)[:, 1:]
101 | output_list= torch.cat((output_list, output), 1)
102 | output = torch.mm(output_list, p_indices)
103 |
104 | loss = criterion(output, target)
105 |
106 | # measure accuracy and record loss
107 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
108 | losses.update(loss.item(), input.size(0))
109 | top1.update(acc1[0], input.size(0))
110 | top5.update(acc5[0], input.size(0))
111 |
112 | # measure elapsed time
113 | batch_time.update(time.time() - end)
114 | end = time.time()
115 |
116 | if i % args.print_freq == 0:
117 | progress.print(i)
118 | # TODO: this should also be done with the ProgressMeter
119 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
120 | .format(top1=top1, top5=top5))
121 |
122 | return top1.avg
123 |
124 |
125 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
126 | torch.save(state, filename)
127 | if is_best:
128 | shutil.copyfile(filename, 'model_best.pth.tar')
129 |
130 | def filename_to_index(filename):
131 | filename = [int(s) for s in filename.split('_') if s.isdigit()]
132 | return filename
133 |
134 |
135 | class AverageMeter(object):
136 | """Computes and stores the average and current value"""
137 | def __init__(self, name, fmt=':f'):
138 | self.name = name
139 | self.fmt = fmt
140 | self.reset()
141 |
142 | def reset(self):
143 | self.val = 0
144 | self.avg = 0
145 | self.sum = 0
146 | self.count = 0
147 |
148 | def update(self, val, n=1):
149 | self.val = val
150 | self.sum += val * n
151 | self.count += n
152 | self.avg = self.sum / self.count
153 |
154 | def __str__(self):
155 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
156 | return fmtstr.format(**self.__dict__)
157 |
158 |
159 | class ProgressMeter(object):
160 | def __init__(self, num_batches, *meters, prefix=""):
161 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
162 | self.meters = meters
163 | self.prefix = prefix
164 |
165 | def print(self, batch):
166 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
167 | entries += [str(meter) for meter in self.meters]
168 | print('\t'.join(entries))
169 |
170 | def _get_batch_fmtstr(self, num_batches):
171 | num_digits = len(str(num_batches // 1))
172 | fmt = '{:' + str(num_digits) + 'd}'
173 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
174 |
175 |
176 | def accuracy(output, target, topk=(1,)):
177 | """Computes the accuracy over the k top predictions for the specified values of k"""
178 | with torch.no_grad():
179 | maxk = max(topk)
180 | batch_size = target.size(0)
181 |
182 | _, pred = output.topk(maxk, 1, True, True)
183 | pred = pred.t()
184 | correct = pred.eq(target.view(1, -1).expand_as(pred))
185 |
186 | res = []
187 | for k in topk:
188 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
189 | res.append(correct_k.mul_(100.0 / batch_size))
190 | return res
--------------------------------------------------------------------------------
/imagenet_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | import random
8 |
9 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
10 |
11 |
12 | def is_image_file(filename):
13 | """Checks if a file is an image.
14 |
15 | Args:
16 | filename (string): path to a file
17 |
18 | Returns:
19 | bool: True if the filename ends with a known image extension
20 | """
21 | filename_lower = filename.lower()
22 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
23 |
24 | def find_classes(dir):
25 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
26 | classes.sort()
27 | class_to_idx = {classes[i]: i for i in range(len(classes))}
28 | return classes, class_to_idx
29 |
30 | def make_dataset(dir, class_to_idx, group = None, target_abs_index = None):
31 | images = []
32 | dir = os.path.expanduser(dir)
33 | for target in sorted(os.listdir(dir)):
34 | # pdb.set_trace()
35 | if target not in class_to_idx:
36 | continue
37 | if int(class_to_idx[target]) not in group:
38 | continue
39 |
40 | d = os.path.join(dir, target)
41 | if not os.path.isdir(d):
42 | continue
43 | for root, _, fnames in sorted(os.walk(d)):
44 | for fname in sorted(fnames):
45 | if is_image_file(fname):
46 | path = os.path.join(root, fname)
47 |                         if target_abs_index is not None:
48 | item = (path, target_abs_index)
49 | else:
50 | item = (path, class_to_idx[target])
51 | images.append(item)
52 |
53 | return images # random.sample(images, 5000) # Used for debug
54 |
55 | def pil_loader(path):
56 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
57 | with open(path, 'rb') as f:
58 | with Image.open(f) as img:
59 | return img.convert('RGB')
60 |
61 | def accimage_loader(path):
62 | import accimage
63 | try:
64 | return accimage.Image(path)
65 | except IOError:
66 | # Potentially a decoding problem, fall back to PIL.Image
67 | return pil_loader(path)
68 |
69 | def default_loader(path):
70 | from torchvision import get_image_backend
71 | if get_image_backend() == 'accimage':
72 | return accimage_loader(path)
73 | else:
74 | return pil_loader(path)
75 |
76 | class ImageFolder(data.Dataset):
77 | """A generic data loader where the images are arranged in this way: ::
78 |
79 | root/dog/xxx.png
80 | root/dog/xxy.png
81 | root/dog/xxz.png
82 |
83 | root/cat/123.png
84 | root/cat/nsdf3.png
85 | root/cat/asd932_.png
86 |
87 | Args:
88 | root (string): Root directory path.
89 |         transform (callable, optional): A function/transform that takes in a PIL image
90 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
91 | target_transform (callable, optional): A function/transform that takes in the
92 | target and transforms it.
93 | loader (callable, optional): A function to load an image given its path.
94 |
95 | Attributes:
96 | classes (list): List of the class names.
97 | class_to_idx (dict): Dict with items (class_name, class_index).
98 | imgs (list): List of (image path, class_index) tuples
99 | """
100 |
101 | def __init__(self, root, transform=None, target_transform=None,
102 | loader=default_loader, activations = False, group = None, retrain = False):
103 | classes, class_to_idx = find_classes(root)
104 |
105 | # Case: Evaluate but pull from training set
106 | if activations and group:
107 | imgs = make_dataset(root, class_to_idx, group)
108 | elif group is not None: # Case: Train / Evaluate: pos/neg according to group
109 | if retrain: # Subcase: Retraining (Training Set Creation)
110 | imgs = []
111 | for abs_index, class_index in enumerate(group):
112 | pos_imgs = make_dataset(root, \
113 | class_to_idx, \
114 | group=[class_index], \
115 | target_abs_index=abs_index + 1)
116 |                     multiplier = max(1, 0)  # Multiplier to balance positive samples if wanted; currently fixed at 1 and unused
117 | imgs.extend(pos_imgs)
118 | negative_numbers = len(imgs)
119 | negative_indices = [i for i in range(1000) if i not in group]
120 | neg_imgs = make_dataset(root, \
121 | class_to_idx, \
122 | group=negative_indices, \
123 | target_abs_index=0)
124 | neg_imgs = random.sample(neg_imgs, negative_numbers)
125 | imgs.extend(neg_imgs)
126 | print("Num images in training set: {}".format(len(imgs)))
127 | # print("Added {} positive images with target index {}".format(len(pos_imgs)*multiplier, abs_index))
128 | else: # Subcase: Evaluation (Validation Set Creation)
129 | imgs = []
130 | for abs_index, class_index in enumerate(group):
131 | pos_imgs = make_dataset(root, \
132 | class_to_idx, \
133 | group=[class_index], \
134 | target_abs_index=abs_index + 1)
135 | imgs.extend(pos_imgs)
136 | negative_numbers = len(imgs)
137 | print("positive images in val loader: ", negative_numbers)
138 | negative_indices = [i for i in range(1000) if i not in group]
139 | neg_imgs = make_dataset(root, \
140 | class_to_idx, \
141 | group=negative_indices, \
142 | target_abs_index=0)
143 |
144 | neg_imgs = random.sample(neg_imgs, negative_numbers)
145 | imgs.extend(neg_imgs)
146 | print("Num images in validation set {}".format(len(imgs)))
147 | else: # Case: Default
148 | imgs = make_dataset(root, class_to_idx, group = [i for i in range(1000)])
149 | if len(imgs) == 0:
150 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
151 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
152 |
153 | self.root = root
154 | self.imgs = imgs
155 | self.classes = classes
156 | self.class_to_idx = class_to_idx
157 | self.transform = transform
158 | self.target_transform = target_transform
159 | self.loader = loader
160 |
161 | def __getitem__(self, index):
162 | """
163 | Args:
164 | index (int): Index
165 |
166 | Returns:
167 | tuple: (image, target) where target is class_index of the target class.
168 | """
169 | path, target = self.imgs[index]
170 | img = self.loader(path)
171 | if self.transform is not None:
172 | img = self.transform(img)
173 | if self.target_transform is not None:
174 | target = self.target_transform(target)
175 |
176 | return img, target
177 |
178 | def __len__(self):
179 | return len(self.imgs)
180 |
181 |
182 |
183 |
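184 | # Minimal usage sketch (added for illustration, not part of the original module;
185 | # the dataset path and the class group below are assumptions): build the binary
186 | # "group vs. rest" training split used when retraining a grouped model.
187 | #
188 | #   from torchvision import transforms
189 | #   train_set = ImageFolder('/path/to/imagenet/train',
190 | #                           transform=transforms.Compose([transforms.RandomResizedCrop(224),
191 | #                                                         transforms.ToTensor()]),
192 | #                           group=[3, 17, 42], retrain=True)
193 | #   img, target = train_set[0]  # target is 0 for negatives, 1..len(group) for group classes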
--------------------------------------------------------------------------------
/prune_utils/layer_prune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 |
5 | def prune_output_linear_layer_(linear_layer, class_indices, use_bce=False):
6 | if use_bce:
7 | assert len(class_indices) == 1
8 | else:
9 | # use 0 as the placeholder of the negative class
10 | class_indices = [0] + list(class_indices)
11 | linear_layer.bias.data = linear_layer.bias.data[class_indices]
12 | linear_layer.weight.data = linear_layer.weight.data[class_indices, :]
13 | if not use_bce:
14 | # reinitialize the negative sample class
15 | linear_layer.weight.data[0].normal_(0, 0.01)
16 | linear_layer.out_features = len(class_indices)
17 |
18 |
19 | def prune_linear_in_features(fc, pruned_indices):
20 | new_fc = nn.Linear(fc.in_features - len(pruned_indices), fc.out_features)
21 | new_fc.bias.data = fc.bias.data.clone()
22 | new_fc.weight.data = prune_tensor(fc.weight.data, 1, pruned_indices)
23 | return new_fc
24 |
25 |
26 | def prune_linear_in_features_(fc, pruned_indices):
27 | fc.in_features -= len(pruned_indices)
28 | fc.weight.data = prune_tensor(fc.weight.data, 1, pruned_indices)
29 |
30 |
31 | def prune_tensor(tensor, dim, pruned_indices):
32 | if tensor.shape[dim] == 1:
33 | return tensor
34 | included_indices = [i for i in range(
35 | tensor.shape[dim]) if i not in pruned_indices]
36 | indexer = []
37 | for i in range(tensor.ndim):
38 | indexer.append(slice(None) if i != dim else included_indices)
39 | return tensor[indexer]
40 |
41 | def prune_batchnorm2d(bn, pruned_indices):
42 | new_bn = nn.BatchNorm2d(bn.num_features - len(pruned_indices))
43 | new_bn.weight.data = prune_tensor(bn.weight.data, 0, pruned_indices)
44 | new_bn.bias.data = prune_tensor(bn.bias.data, 0, pruned_indices)
45 | new_bn.running_mean.data = prune_tensor(
46 | bn.running_mean.data, 0, pruned_indices)
47 | new_bn.running_var.data = prune_tensor(
48 | bn.running_var.data, 0, pruned_indices)
49 | return new_bn
50 |
51 |
52 | def prune_batchnorm2d_(bn, pruned_indices):
53 | bn.num_features -= len(pruned_indices)
54 | bn.weight.data = prune_tensor(bn.weight.data, 0, pruned_indices)
55 | bn.bias.data = prune_tensor(bn.bias.data, 0, pruned_indices)
56 | bn.running_mean.data = prune_tensor(
57 | bn.running_mean.data, 0, pruned_indices)
58 | bn.running_var.data = prune_tensor(bn.running_var.data, 0, pruned_indices)
59 | return bn
60 |
61 |
62 | def prune_conv2d_out_channels(conv, pruned_indices):
63 | new_conv = nn.Conv2d(in_channels=conv.in_channels,
64 | out_channels=conv.out_channels - len(pruned_indices),
65 | kernel_size=conv.kernel_size,
66 | stride=conv.stride,
67 | padding=conv.padding,
68 | dilation=conv.dilation,
69 | groups=conv.groups,
70 | bias=conv.bias is not None)
71 |
72 | new_conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices)
73 |
74 | if conv.bias is not None:
75 | new_conv.bias.data = prune_tensor(conv.bias.data, 0, pruned_indices)
76 | return new_conv
77 |
78 |
79 | def prune_conv2d_out_channels_(conv, pruned_indices):
80 | conv.out_channels -= len(pruned_indices)
81 | conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices)
82 | if conv.bias is not None:
83 | conv.bias.data = prune_tensor(conv.bias.data, 0, pruned_indices)
84 | return conv
85 |
86 |
87 | def prune_conv2d_in_channels(conv, pruned_indices):
88 | new_conv = nn.Conv2d(in_channels=conv.in_channels - len(pruned_indices),
89 | out_channels=conv.out_channels,
90 | kernel_size=conv.kernel_size,
91 | stride=conv.stride,
92 | padding=conv.padding,
93 | dilation=conv.dilation,
94 | groups=conv.groups,
95 | bias=conv.bias is not None)
96 |
97 | new_conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices)
98 |
99 | if conv.bias is not None:
100 | new_conv.bias.data = conv.bias.data.clone()
101 | return new_conv
102 |
103 |
104 | def prune_conv2d_in_channels_(conv, pruned_indices):
105 | conv.in_channels -= len(pruned_indices)
106 | conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices)
107 | return conv
108 |
109 |
110 | def prune_contiguous_conv2d_(conv_p, conv_n, pruned_indices, bn=None):
111 | prune_conv2d_out_channels_(conv_p, pruned_indices)
112 | prune_conv2d_in_channels_(conv_n, pruned_indices)
113 | if bn is not None:
114 | prune_batchnorm2d_(bn, pruned_indices)
115 |
116 | def prune_contiguous_conv2d_last(conv_p, conv_n, pruned_indices, bn=None):
117 | prune_conv2d_out_channels_(conv_p, pruned_indices)
118 | if bn is not None:
119 | prune_batchnorm2d_(bn, pruned_indices)
120 |
121 | def prune_mobile_conv2d_in_channels(conv, pruned_indices):
122 | conv.in_channels -= len(pruned_indices)
123 | conv.groups = conv.in_channels
124 |
125 | conv.weight.data = prune_tensor(conv.weight.data, 1, pruned_indices)
126 | return conv
127 |
128 | def prune_mobile_conv2d_out_channels(conv, pruned_indices):
129 | if conv.groups != 1:
130 | pruned_indices = pruned_indices[:(conv.out_channels - conv.in_channels)]
131 | conv.out_channels -= len(pruned_indices)
132 | conv.groups = conv.out_channels
133 | conv.weight.data = prune_tensor(conv.weight.data, 0, pruned_indices)
134 | return conv
135 |
136 | def prune_contiguous_conv2d_mobile_a(conv_p, conv_n, pruned_indices, bn=None):
137 | prune_conv2d_out_channels_(conv_p, pruned_indices)
138 | prune_mobile_conv2d_in_channels(conv_n, pruned_indices)
139 | if bn is not None:
140 | prune_batchnorm2d_(bn, pruned_indices)
141 |
142 | def prune_contiguous_conv2d_mobile_b(conv_p, conv_n, pruned_indices, bn=None):
143 | prune_mobile_conv2d_out_channels(conv_p, pruned_indices)
144 | prune_conv2d_in_channels_(conv_n, pruned_indices[:(conv_n.in_channels-conv_p.out_channels)])
145 | if bn is not None:
146 | prune_batchnorm2d_(bn, pruned_indices[:(bn.num_features-conv_p.out_channels)])
147 |
148 | def prune_mobile_block(conv_1, conv_2, conv_3, pruned_indices_1, pruned_indices_2, bn_1, bn_2):
149 | small_len = min(len(pruned_indices_1), len(pruned_indices_2))
150 | if len(pruned_indices_2) < len(pruned_indices_1):
151 | pruned_indices_1 = pruned_indices_1[:small_len]
152 | prune_contiguous_conv2d_mobile_a(conv_1, conv_2, pruned_indices_1, bn=bn_1)
153 | prune_contiguous_conv2d_mobile_b(conv_2, conv_3, pruned_indices_2, bn=bn_2)
154 |
155 | def prune_downblock(block, layer_candidates):
156 | conv3 = block.conv3
157 | bn3 = block.bn3
158 | conv4 = block.conv4
159 | bn4 = block.bn4
160 | conv5 = block.conv5
161 | pruned_indices_3_4 = layer_candidates[2]
162 | pruned_indices_4_5 = layer_candidates[3]
163 | small_len = min(len(pruned_indices_3_4), len(pruned_indices_4_5))
164 | if len(pruned_indices_4_5) < len(pruned_indices_3_4):
165 | pruned_indices_3_4 = pruned_indices_4_5[:small_len]
166 | prune_contiguous_conv2d_mobile_a(conv3, conv4, pruned_indices_3_4, bn=bn3)
167 | prune_contiguous_conv2d_mobile_b(conv4, conv5, pruned_indices_4_5, bn=bn4)
168 |
169 | def prune_basicblock(block, layer_candidates):
170 | conv_1 = block.conv1
171 | bn_1 = block.bn1
172 | conv_2 = block.conv2
173 | bn_2 = block.bn2
174 | conv_3 = block.conv3
175 | pruned_indices_1 = layer_candidates[0]
176 | pruned_indices_2 = layer_candidates[1]
177 | small_len = min(len(pruned_indices_1), len(pruned_indices_2))
178 | if len(pruned_indices_2) < len(pruned_indices_1):
179 | pruned_indices_1 = pruned_indices_1[:small_len]
180 | prune_contiguous_conv2d_mobile_a(conv_1, conv_2, pruned_indices_1, bn=bn_1)
181 | prune_contiguous_conv2d_mobile_b(conv_2, conv_3, pruned_indices_2, bn=bn_2)
182 |
183 | def prune_shuffle_layer(layer, layer_candidates):
184 | for idx, block in enumerate(layer):
185 | if idx == 0:
186 | prune_downblock(block, layer_candidates[:5])
187 | else:
188 | candidates = layer_candidates[idx*3+2:idx*3+5]
189 | prune_basicblock(block, candidates)
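189 | 
190 | 
191 | if __name__ == '__main__':
192 |     # Minimal self-check added for illustration (channel sizes are arbitrary):
193 |     # prune output channels {1, 3} of a conv, the matching BN features, and the
194 |     # input channels of the following conv, then run a forward pass.
195 |     conv_a = nn.Conv2d(3, 8, kernel_size=3, padding=1)
196 |     bn_a = nn.BatchNorm2d(8)
197 |     conv_b = nn.Conv2d(8, 16, kernel_size=3, padding=1)
198 |     prune_contiguous_conv2d_(conv_a, conv_b, [1, 3], bn=bn_a)
199 |     out = conv_b(bn_a(conv_a(torch.randn(2, 3, 32, 32))))
200 |     print(out.shape)  # expected: torch.Size([2, 16, 32, 32]), with 6 hidden channels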
--------------------------------------------------------------------------------
/even_k_means.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster.k_means_ import check_random_state, _check_sample_weight, _init_centroids
2 | from sklearn.metrics.pairwise import pairwise_distances_argmin_min, euclidean_distances
3 | from sklearn.utils.extmath import row_norms, squared_norm
4 | import numpy as np
5 |
6 |
7 | def _labels_inertia(X, sample_weight, x_squared_norms, centers, distances, same_cluster_size=False):
8 | """E step of the K-means EM algorithm.
9 | Compute the labels and the inertia of the given samples and centers.
10 | This will compute the distances in-place.
11 | Parameters
12 | ----------
13 | X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features)
14 | The input samples to assign to the labels.
15 | sample_weight : array-like, shape (n_samples,)
16 | The weights for each observation in X.
17 | x_squared_norms : array, shape (n_samples,)
18 | Precomputed squared euclidean norm of each data point, to speed up
19 | computations.
20 | centers : float array, shape (k, n_features)
21 | The cluster centers.
22 | distances : float array, shape (n_samples,)
23 | Pre-allocated array to be filled in with each sample's distance
24 | to the closest center.
25 | Returns
26 | -------
27 | labels : int array of shape(n)
28 | The resulting assignment
29 | inertia : float
30 | Sum of squared distances of samples to their closest cluster center.
31 | """
32 | sample_weight = _check_sample_weight(X, sample_weight)
33 | n_samples = X.shape[0]
34 | n_clusters = centers.shape[0]
35 |
36 | # See http://jmonlong.github.io/Hippocamplus/2018/06/09/cluster-same-size/#same-size-k-means-variation
37 | if same_cluster_size:
38 | cluster_size = n_samples // n_clusters
39 | labels = np.zeros(n_samples, dtype=np.int32)
40 | mindist = np.zeros(n_samples, dtype=np.float32)
41 | # count how many samples have been labeled in a cluster
42 | counters = np.zeros(n_clusters, dtype=np.int32)
43 | # dist: (n_samples, n_clusters)
44 | dist = euclidean_distances(X, centers, squared=False)
45 | closeness = dist.min(axis=-1) - dist.max(axis=-1)
46 | ranking = np.argsort(closeness)
47 | for r in ranking:
48 | while True:
49 | label = dist[r].argmin()
50 | if counters[label] < cluster_size:
51 | labels[r] = label
52 | counters[label] += 1
53 | # squared distances are used for inertia in this function
54 | mindist[r] = dist[r, label] ** 2
55 | break
56 | else:
57 | dist[r, label] = np.inf
58 | else:
59 | # Breakup nearest neighbor distance computation into batches to prevent
60 | # memory blowup in the case of a large number of samples and clusters.
61 | # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs.
62 | labels, mindist = pairwise_distances_argmin_min(
63 | X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
64 |
65 | # cython k-means code assumes int32 inputs
66 | labels = labels.astype(np.int32, copy=False)
67 | if n_samples == distances.shape[0]:
68 | # distances will be changed in-place
69 | distances[:] = mindist
70 | inertia = (mindist * sample_weight).sum()
71 | return labels, inertia
72 |
73 |
74 | def _centers_dense(X, sample_weight, labels, n_clusters, distances):
75 | """M step of the K-means EM algorithm
76 | Computation of cluster centers / means.
77 | Parameters
78 | ----------
79 | X : array-like, shape (n_samples, n_features)
80 | sample_weight : array-like, shape (n_samples,)
81 | The weights for each observation in X.
82 | labels : array of integers, shape (n_samples)
83 | Current label assignment
84 | n_clusters : int
85 | Number of desired clusters
86 | distances : array-like, shape (n_samples)
87 | Distance to closest cluster for each sample.
88 | Returns
89 | -------
90 | centers : array, shape (n_clusters, n_features)
91 | The resulting centers
92 | """
93 | # TODO: add support for CSR input
94 | n_samples = X.shape[0]
95 | n_features = X.shape[1]
96 |
97 | dtype = np.float32
98 | centers = np.zeros((n_clusters, n_features), dtype=dtype)
99 | weight_in_cluster = np.zeros((n_clusters,), dtype=dtype)
100 |
101 | for i in range(n_samples):
102 | c = labels[i]
103 | weight_in_cluster[c] += sample_weight[i]
104 | empty_clusters = np.where(weight_in_cluster == 0)[0]
105 | # maybe also relocate small clusters?
106 |
107 | if len(empty_clusters):
108 | # find points to reassign empty clusters to
109 | far_from_centers = distances.argsort()[::-1]
110 |
111 | for i, cluster_id in enumerate(empty_clusters):
112 | # XXX two relocated clusters could be close to each other
113 | far_index = far_from_centers[i]
114 | new_center = X[far_index] * sample_weight[far_index]
115 | centers[cluster_id] = new_center
116 | weight_in_cluster[cluster_id] = sample_weight[far_index]
117 |
118 | for i in range(n_samples):
119 | for j in range(n_features):
120 | centers[labels[i], j] += X[i, j] * sample_weight[i]
121 |
122 | centers /= weight_in_cluster[:, np.newaxis]
123 |
124 | return centers
125 |
126 |
127 | def kmeans_lloyd(X, sample_weight, n_clusters, max_iter=300,
128 | init='k-means++', verbose=False, x_squared_norms=None,
129 | random_state=None, tol=1e-4, same_cluster_size=False):
130 | """A single run of k-means, assumes preparation completed prior.
131 | Parameters
132 | ----------
133 | X : array-like of floats, shape (n_samples, n_features)
134 | The observations to cluster.
135 | n_clusters : int
136 | The number of clusters to form as well as the number of
137 | centroids to generate.
138 | sample_weight : array-like, shape (n_samples,)
139 | The weights for each observation in X.
140 | max_iter : int, optional, default 300
141 | Maximum number of iterations of the k-means algorithm to run.
142 | init : {'k-means++', 'random', or ndarray, or a callable}, optional
143 | Method for initialization, default to 'k-means++':
144 |         'k-means++' : selects initial cluster centers for k-means
145 | clustering in a smart way to speed up convergence. See section
146 | Notes in k_init for more details.
147 | 'random': choose k observations (rows) at random from data for
148 | the initial centroids.
149 | If an ndarray is passed, it should be of shape (k, p) and gives
150 | the initial centers.
151 |         If a callable is passed, it should take arguments X, k and
152 |         a random state and return an initialization.
153 | tol : float, optional
154 | The relative increment in the results before declaring convergence.
155 | verbose : boolean, optional
156 | Verbosity mode
157 | x_squared_norms : array
158 | Precomputed x_squared_norms.
159 | precompute_distances : boolean, default: True
160 | Precompute distances (faster but takes more memory).
161 | random_state : int, RandomState instance or None (default)
162 | Determines random number generation for centroid initialization. Use
163 | an int to make the randomness deterministic.
164 |         See :term:`Glossary <random_state>`.
165 | Returns
166 | -------
167 | centroid : float ndarray with shape (k, n_features)
168 | Centroids found at the last iteration of k-means.
169 | label : integer ndarray with shape (n_samples,)
170 | label[i] is the code or index of the centroid the
171 | i'th observation is closest to.
172 | inertia : float
173 | The final value of the inertia criterion (sum of squared distances to
174 | the closest centroid for all observations in the training set).
175 | n_iter : int
176 | Number of iterations run.
177 | """
178 | random_state = check_random_state(random_state)
179 | if same_cluster_size:
180 | assert len(X) % n_clusters == 0, "#samples is not divisible by #clusters"
181 |
182 | if verbose:
183 | print("\n==> Starting k-means clustering...\n")
184 |
185 | sample_weight = _check_sample_weight(X, sample_weight)
186 | x_squared_norms = row_norms(X, squared=True)
187 |
188 | best_labels, best_inertia, best_centers = None, None, None
189 | # init
190 | centers = _init_centroids(X, n_clusters, init, random_state=random_state,
191 | x_squared_norms=x_squared_norms)
192 | if verbose:
193 | print("Initialization complete")
194 |
195 | # Allocate memory to store the distances for each sample to its
196 | # closer center for reallocation in case of ties
197 | distances = np.zeros(shape=(X.shape[0],), dtype=X.dtype)
198 |
199 | # iterations
200 | for i in range(max_iter):
201 | centers_old = centers.copy()
202 | # labels assignment is also called the E-step of EM
203 | labels, inertia = \
204 | _labels_inertia(X, sample_weight, x_squared_norms,
205 | centers, distances=distances, same_cluster_size=same_cluster_size)
206 |
207 | # computation of the means is also called the M-step of EM
208 | centers = _centers_dense(
209 | X, sample_weight, labels, n_clusters, distances)
210 |
211 | if verbose:
212 | print("Iteration %2d, inertia %.3f" % (i, inertia))
213 |
214 | if best_inertia is None or inertia < best_inertia:
215 | best_labels = labels.copy()
216 | best_centers = centers.copy()
217 | best_inertia = inertia
218 |
219 | center_shift_total = squared_norm(centers_old - centers)
220 | if center_shift_total <= tol:
221 | if verbose:
222 | print("Converged at iteration %d: "
223 | "center shift %e within tolerance %e"
224 | % (i, center_shift_total, tol))
225 | break
226 |
227 | if center_shift_total > 0:
228 | # rerun E-step in case of non-convergence so that predicted labels
229 | # match cluster centers
230 | best_labels, best_inertia = \
231 | _labels_inertia(X, sample_weight, x_squared_norms,
232 | best_centers, distances=distances, same_cluster_size=same_cluster_size)
233 |
234 | return best_labels, best_inertia, best_centers, i + 1
235 |
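236 | 
237 | if __name__ == '__main__':
238 |     # Small smoke test added for illustration (random data, not from the project):
239 |     # cluster 20 points into 4 clusters of exactly 5 points each. Assumes the
240 |     # scikit-learn version pinned in requirements.txt, since private k-means
241 |     # helpers are imported above.
242 |     rng = np.random.RandomState(0)
243 |     X = rng.rand(20, 2)
244 |     labels, inertia, centers, n_iter = kmeans_lloyd(
245 |         X, None, n_clusters=4, random_state=0, same_cluster_size=True)
246 |     print("cluster sizes:", np.bincount(labels), "inertia:", inertia)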
--------------------------------------------------------------------------------
/regularize_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def standard(model, arch, num_classes):
6 | if arch == "mobilenetv2":
7 | new_model = MobileNetV2(num_classes=num_classes)
8 | new_model.conv1 = model.conv1
9 | new_model.bn1 = model.bn1
10 | for new_layer, layer in zip(new_model.layers, model.layers):
11 | new_layer.conv1 = layer.conv1
12 | new_layer.bn1 = layer.bn1
13 | new_layer.conv2 = layer.conv2
14 | new_layer.bn2 = layer.bn2
15 | new_layer.conv3 = layer.conv3
16 | new_layer.bn3 = layer.bn3
17 | new_layer.shortcut = layer.shortcut
18 | new_model.conv2 = model.conv2
19 | new_model.bn2 = model.bn2
20 | new_model.linear = model.linear
21 | else:
22 | new_model = ShuffleNetV2(1)
23 | new_model.conv1 = model.conv1
24 | new_model.bn1 = model.bn1
25 | for new_layer, layer in [(new_model.layer1, model.layer1), (new_model.layer2, model.layer2), (new_model.layer3, model.layer3)]:
26 | new_layer[0].conv1 = layer[0].conv1
27 | new_layer[0].bn1 = layer[0].bn1
28 | new_layer[0].conv2 = layer[0].conv2
29 | new_layer[0].bn2 = layer[0].bn2
30 | new_layer[0].conv3 = layer[0].conv3
31 | new_layer[0].bn3 = layer[0].bn3
32 | new_layer[0].conv4 = layer[0].conv4
33 | new_layer[0].bn4 = layer[0].bn4
34 | new_layer[0].conv5 = layer[0].conv5
35 | new_layer[0].bn5 = layer[0].bn5
36 | new_layer[0].shuffle = layer[0].shuffle
37 | for i in range(1, len(new_layer)):
38 | new_layer[i].split = layer[i].split
39 | new_layer[i].conv1 = layer[i].conv1
40 | new_layer[i].bn1 = layer[i].bn1
41 | new_layer[i].conv2 = layer[i].conv2
42 | new_layer[i].bn2 = layer[i].bn2
43 | new_layer[i].conv3 = layer[i].conv3
44 | new_layer[i].bn3 = layer[i].bn3
45 | new_layer[i].shuffle = layer[i].shuffle
46 | new_model.conv2 = model.conv2
47 | new_model.bn2 = model.bn2
48 | new_model.linear = model.linear
49 | return new_model
50 |
51 |
52 |
53 |
54 | class Block(nn.Module):
55 | '''expand + depthwise + pointwise'''
56 | def __init__(self, in_planes, out_planes, expansion, stride):
57 | super(Block, self).__init__()
58 | self.stride = stride
59 |
60 | planes = expansion * in_planes
61 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
62 | self.bn1 = nn.BatchNorm2d(planes)
63 | self.relu1 = nn.ReLU(inplace=True)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
65 | self.bn2 = nn.BatchNorm2d(planes)
66 | self.relu2 = nn.ReLU(inplace=True)
67 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
68 | self.bn3 = nn.BatchNorm2d(out_planes)
69 |
70 | self.shortcut = nn.Sequential()
71 | if stride == 1 and in_planes != out_planes:
72 | self.shortcut = nn.Sequential(
73 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
74 | nn.BatchNorm2d(out_planes),
75 | )
76 |
77 | def forward(self, x):
78 | out = self.relu1(self.bn1(self.conv1(x)))
79 | out = self.relu2(self.bn2(self.conv2(out)))
80 | out = self.bn3(self.conv3(out))
81 | out = out + self.shortcut(x) if self.stride==1 else out
82 | return out
83 |
84 |
85 | class MobileNetV2(nn.Module):
86 | # (expansion, out_planes, num_blocks, stride)
87 | cfg = [(1, 16, 1, 1),
88 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
89 | (6, 32, 3, 2),
90 | (6, 64, 4, 2),
91 | (6, 96, 3, 1),
92 | (6, 160, 3, 2),
93 | (6, 320, 1, 1)]
94 |
95 | def __init__(self, num_classes=10):
96 | super(MobileNetV2, self).__init__()
97 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
98 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
99 | self.bn1 = nn.BatchNorm2d(32)
100 | self.relu1 = nn.ReLU(inplace=True)
101 | self.layers = self._make_layers(in_planes=32)
102 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
103 | self.bn2 = nn.BatchNorm2d(1280)
104 | self.relu2 = nn.ReLU(inplace=True)
105 | self.linear = nn.Linear(1280, num_classes)
106 |
107 | def _make_layers(self, in_planes):
108 | layers = []
109 | for expansion, out_planes, num_blocks, stride in self.cfg:
110 | strides = [stride] + [1]*(num_blocks-1)
111 | for stride in strides:
112 | layers.append(Block(in_planes, out_planes, expansion, stride))
113 | in_planes = out_planes
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x, features_only=False):
117 | out = self.relu1(self.bn1(self.conv1(x)))
118 | out = self.layers(out)
119 | out = self.relu2(self.bn2(self.conv2(out)))
120 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
121 | out = F.avg_pool2d(out, 4)
122 | out = out.view(out.size(0), -1)
123 | if not features_only:
124 | out = self.linear(out)
125 | return out
126 |
127 |
128 | class ShuffleBlock(nn.Module):
129 | def __init__(self, groups=2):
130 | super(ShuffleBlock, self).__init__()
131 | self.groups = groups
132 |
133 | def forward(self, x):
134 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
135 | N, C, H, W = x.size()
136 | g = self.groups
137 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
138 |
139 |
140 | class SplitBlock(nn.Module):
141 | def __init__(self, ratio):
142 | super(SplitBlock, self).__init__()
143 | self.ratio = ratio
144 |
145 | def forward(self, x):
146 | c = int(x.size(1) * self.ratio)
147 | return x[:, :c, :, :], x[:, c:, :, :]
148 |
149 |
150 | class BasicBlock(nn.Module):
151 | def __init__(self, in_channels, split_ratio=0.5):
152 | super(BasicBlock, self).__init__()
153 | self.split = SplitBlock(split_ratio)
154 | in_channels = int(in_channels * split_ratio)
155 | self.conv1 = nn.Conv2d(in_channels, in_channels,
156 | kernel_size=1, bias=False)
157 | self.bn1 = nn.BatchNorm2d(in_channels)
158 | self.relu1 = nn.ReLU(inplace=True)
159 | self.conv2 = nn.Conv2d(in_channels, in_channels,
160 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
161 | self.bn2 = nn.BatchNorm2d(in_channels)
162 | self.conv3 = nn.Conv2d(in_channels, in_channels,
163 | kernel_size=1, bias=False)
164 | self.bn3 = nn.BatchNorm2d(in_channels)
165 | self.relu3 = nn.ReLU(inplace=True)
166 | self.shuffle = ShuffleBlock()
167 |
168 | def forward(self, x):
169 | x1, x2 = self.split(x)
170 | out = self.relu1(self.bn1(self.conv1(x2)))
171 | out = self.bn2(self.conv2(out))
172 | out = self.relu3(self.bn3(self.conv3(out)))
173 | out = torch.cat([x1, out], 1)
174 | out = self.shuffle(out)
175 | return out
176 |
177 |
178 | class DownBlock(nn.Module):
179 | def __init__(self, in_channels, out_channels):
180 | super(DownBlock, self).__init__()
181 | mid_channels = out_channels // 2
182 | # left
183 | self.conv1 = nn.Conv2d(in_channels, in_channels,
184 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
185 | self.bn1 = nn.BatchNorm2d(in_channels)
186 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
187 | kernel_size=1, bias=False)
188 | self.bn2 = nn.BatchNorm2d(mid_channels)
189 | self.relu2 = nn.ReLU(inplace=True)
190 | # right
191 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
192 | kernel_size=1, bias=False)
193 | self.bn3 = nn.BatchNorm2d(mid_channels)
194 | self.relu3 = nn.ReLU(inplace=True)
195 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
196 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
197 | self.bn4 = nn.BatchNorm2d(mid_channels)
198 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
199 | kernel_size=1, bias=False)
200 | self.bn5 = nn.BatchNorm2d(mid_channels)
201 | self.relu5 = nn.ReLU(inplace=True)
202 |
203 | self.shuffle = ShuffleBlock()
204 |
205 | def forward(self, x):
206 | # left
207 | out1 = self.bn1(self.conv1(x))
208 | out1 = self.relu2(self.bn2(self.conv2(out1)))
209 | # right
210 | out2 = self.relu3(self.bn3(self.conv3(x)))
211 | out2 = self.bn4(self.conv4(out2))
212 | out2 = self.relu5(self.bn5(self.conv5(out2)))
213 | # concat
214 | out = torch.cat([out1, out2], 1)
215 | out = self.shuffle(out)
216 | return out
217 |
218 |
219 | class ShuffleNetV2(nn.Module):
220 | def __init__(self, net_size):
221 | super(ShuffleNetV2, self).__init__()
222 | out_channels = configs[net_size]['out_channels']
223 | num_blocks = configs[net_size]['num_blocks']
224 |
225 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
226 | stride=1, padding=1, bias=False)
227 | self.bn1 = nn.BatchNorm2d(24)
228 | self.relu1 = nn.ReLU(inplace=True)
229 | self.in_channels = 24
230 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
231 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
232 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
233 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
234 | kernel_size=1, stride=1, padding=0, bias=False)
235 | self.bn2 = nn.BatchNorm2d(out_channels[3])
236 | self.relu2 = nn.ReLU(inplace=True)
237 | self.linear = nn.Linear(out_channels[3], 10)
238 |
239 | def _make_layer(self, out_channels, num_blocks):
240 | layers = [DownBlock(self.in_channels, out_channels)]
241 | for i in range(num_blocks):
242 | layers.append(BasicBlock(out_channels))
243 | self.in_channels = out_channels
244 | return nn.Sequential(*layers)
245 |
246 | def forward(self, x, features_only=False):
247 | out = self.relu1(self.bn1(self.conv1(x)))
248 | # out = F.max_pool2d(out, 3, stride=2, padding=1)
249 | out = self.layer1(out)
250 | out = self.layer2(out)
251 | out = self.layer3(out)
252 | out = self.relu2(self.bn2(self.conv2(out)))
253 | out = F.avg_pool2d(out, 4)
254 | out = out.view(out.size(0), -1)
255 | if not features_only:
256 | out = self.linear(out)
257 | return out
258 |
259 |
260 | configs = {
261 | 0.5: {
262 | 'out_channels': (48, 96, 192, 1024),
263 | 'num_blocks': (3, 7, 3)
264 | },
265 |
266 | 1: {
267 | 'out_channels': (116, 232, 464, 1024),
268 | 'num_blocks': (3, 7, 3)
269 | },
270 | 1.5: {
271 | 'out_channels': (176, 352, 704, 1024),
272 | 'num_blocks': (3, 7, 3)
273 | },
274 | 2: {
275 | 'out_channels': (224, 488, 976, 2048),
276 | 'num_blocks': (3, 7, 3)
277 | }
278 | }
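279 | 
280 | 
281 | if __name__ == '__main__':
282 |     # Quick shape check added for illustration: both CIFAR-sized backbones defined
283 |     # above take 32x32 inputs and emit 10 logits per image.
284 |     x = torch.randn(2, 3, 32, 32)
285 |     print(MobileNetV2(num_classes=10)(x).shape)  # torch.Size([2, 10])
286 |     print(ShuffleNetV2(1)(x).shape)              # torch.Size([2, 10])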
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/prune_utils/prune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torchvision import models
4 | import sys
5 | import numpy as np
6 | from prune_utils.layer_prune import (
7 | prune_output_linear_layer_,
8 | prune_contiguous_conv2d_,
9 | prune_conv2d_out_channels_,
10 | prune_batchnorm2d_,
11 | prune_linear_in_features_,
12 | prune_contiguous_conv2d_last)
13 |
14 | def replace_layers(model, i, indexes, layers):
15 | if i in indexes:
16 | return layers[indexes.index(i)]
17 | return model[i]
18 |
19 | def prune_vgg16_conv_layer(model, layer_index, filter_index, use_batch_norm=False):
20 | _, conv = list(model.features._modules.items())[layer_index]
21 | next_conv = None
22 | offset = 1
23 |
24 | while layer_index + offset < len(model.features._modules.items()):
25 | res = list(model.features._modules.items())[layer_index+offset]
26 | if isinstance(res[1], torch.nn.modules.conv.Conv2d):
27 | next_name, next_conv = res
28 | break
29 | offset = offset + 1
30 |
31 | new_conv = \
32 | torch.nn.Conv2d(in_channels = conv.in_channels, \
33 | out_channels = conv.out_channels - 1,
34 | kernel_size = conv.kernel_size, \
35 | stride = conv.stride,
36 | padding = conv.padding,
37 | dilation = conv.dilation,
38 | groups = conv.groups,
39 | bias = True)#conv.bias)
40 |
41 | old_weights = conv.weight.data.cpu().numpy()
42 | new_weights = new_conv.weight.data.cpu().numpy()
43 | new_weights[:filter_index, :, :, :] = old_weights[:filter_index, :, :, :]
44 | new_weights[filter_index : , :, :, :] = old_weights[filter_index + 1 :, :, :, :]
45 | new_conv.weight.data = torch.from_numpy(new_weights).cuda()
46 |
47 | if conv.bias is not None:
48 | bias_numpy = conv.bias.data.cpu().numpy()
49 | bias = np.zeros(shape = (bias_numpy.shape[0] - 1), dtype = np.float32)
50 | bias[:filter_index] = bias_numpy[:filter_index]
51 | bias[filter_index : ] = bias_numpy[filter_index + 1 :]
52 | new_conv.bias.data = torch.from_numpy(bias).cuda()
53 |
54 | if use_batch_norm:
55 | _, bn = list(model.features._modules.items())[layer_index + 1]
56 | new_bn = torch.nn.BatchNorm2d(conv.out_channels - 1)
57 |
58 | old_weights = bn.weight.data.cpu().numpy()
59 | new_weights = new_bn.weight.data.cpu().numpy()
60 | new_weights[:filter_index] = old_weights[:filter_index]
61 | new_weights[filter_index:] = old_weights[filter_index+1:]
62 |
63 |
64 | old_bias = bn.bias.data.cpu().numpy()
65 | new_bias = new_bn.bias.data.cpu().numpy()
66 | new_bias[:filter_index] = old_bias[:filter_index]
67 | new_bias[filter_index:] = old_bias[filter_index+1:]
68 |
69 |
70 |
71 | old_running_mean = bn.running_mean.data.cpu().numpy()
72 | new_running_mean = new_bn.running_mean.data.cpu().numpy()
73 | new_running_mean[:filter_index] = old_running_mean[:filter_index]
74 | new_running_mean[filter_index:] = old_running_mean[filter_index+1:]
75 |
76 |
77 | old_running_var = bn.running_var.data.cpu().numpy()
78 | new_running_var = new_bn.running_var.data.cpu().numpy()
79 | new_running_var[:filter_index] = old_running_var[:filter_index]
80 | new_running_var[filter_index:] = old_running_var[filter_index+1:]
81 |
82 | new_bn.weight.data = torch.from_numpy(new_weights).cuda()
83 | new_bn.bias.data = torch.from_numpy(new_bias).cuda()
84 | new_bn.running_mean.data = torch.from_numpy(new_running_mean).cuda()
85 | new_bn.running_var.data = torch.from_numpy(new_running_var).cuda()
86 |
87 |
88 | if not next_conv is None:
89 | next_new_conv = \
90 | torch.nn.Conv2d(in_channels = next_conv.in_channels - 1,\
91 | out_channels = next_conv.out_channels, \
92 | kernel_size = next_conv.kernel_size, \
93 | stride = next_conv.stride,
94 | padding = next_conv.padding,
95 | dilation = next_conv.dilation,
96 | groups = next_conv.groups,
97 | bias = True)#next_conv.bias)
98 |
99 | old_weights = next_conv.weight.data.cpu().numpy()
100 | new_weights = next_new_conv.weight.data.cpu().numpy()
101 |
102 | new_weights[:, : filter_index, :, :] = old_weights[:, : filter_index, :, :]
103 | new_weights[:, filter_index : , :, :] = old_weights[:, filter_index + 1 :, :, :]
104 | next_new_conv.weight.data = torch.from_numpy(new_weights).cuda()
105 |
106 | if next_conv.bias is not None:
107 | next_new_conv.bias.data = torch.from_numpy(next_conv.bias.data.cpu().numpy().copy()).cuda()
108 |
109 | if not next_conv is None:
110 | features = torch.nn.Sequential(
111 | *(replace_layers(model.features, i, [layer_index, layer_index + 1, layer_index+offset], \
112 | [new_conv, new_bn, next_new_conv]) for i, _ in enumerate(model.features)))
113 | del model.features
114 | del conv
115 |
116 | model.features = features
117 | else:
118 |         # Pruning the last conv layer. This affects the first linear layer of the classifier.
119 | model.features = torch.nn.Sequential(
120 | *(replace_layers(model.features, i, [layer_index, layer_index+1], \
121 | [new_conv, new_bn]) for i, _ in enumerate(model.features)))
122 | layer_index = 0
123 | old_linear_layer = None
124 | if len(model.classifier._modules):
125 | for _, module in model.classifier._modules.items():
126 | if isinstance(module, torch.nn.Linear):
127 | old_linear_layer = module
128 | break
129 | layer_index = layer_index + 1
130 | else:
131 | old_linear_layer = model.classifier
132 |
133 | if old_linear_layer is None:
134 | raise BaseException("No linear layer found in classifier")
135 | params_per_input_channel = old_linear_layer.in_features / conv.out_channels
136 |
137 | new_linear_layer = \
138 | torch.nn.Linear(int(old_linear_layer.in_features - params_per_input_channel),
139 | int(old_linear_layer.out_features))
140 |
141 | old_weights = old_linear_layer.weight.data.cpu().numpy()
142 | new_weights = new_linear_layer.weight.data.cpu().numpy()
143 |
144 | new_weights[:, : int(filter_index * params_per_input_channel)] = \
145 | old_weights[:, : int(filter_index * params_per_input_channel)]
146 | new_weights[:, int(filter_index * params_per_input_channel) :] = \
147 | old_weights[:, int((filter_index + 1) * params_per_input_channel) :]
148 |
149 | new_linear_layer.bias.data = torch.from_numpy(old_linear_layer.bias.data.cpu().numpy()).cuda()
150 |
151 | new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda()
152 |
153 | if len(model.classifier._modules):
154 | classifier = torch.nn.Sequential(
155 | *(replace_layers(model.classifier, i, [layer_index], \
156 | [new_linear_layer]) for i, _ in enumerate(model.classifier)))
157 | else:
158 | classifier = torch.nn.Sequential(new_linear_layer)
159 |
160 | del model.classifier
161 | del next_conv
162 | del conv
163 | model.classifier = classifier
164 |
165 | return model
166 |
167 | def prune_last_fc_layers(model, class_indices, filter_indices = None, use_bce=False):
168 | layer_index = 0
169 | old_linear_layer = None
170 | counter = 0
171 | out_dim_prev = None
172 | filter_idx_mask = None
173 | linear_count = 0
174 |
175 | for idx, module in enumerate(model.classifier.modules()):
176 | if linear_count >= len(filter_indices):
177 | break
178 |
179 | if isinstance(module, torch.nn.Linear):
180 | old_linear_layer = module
181 | old_weights = old_linear_layer.weight.data
182 |             # The new in dimension is the out dimension of the last layer pruned:
183 |             # if counter == 1, the last layer pruned was the last conv layer,
184 |             # otherwise it is the previous linear layer
185 | in_dim = int(old_linear_layer.in_features) if counter == 1 else out_dim
186 | prev_filter_idx_mask = filter_idx_mask
187 | # The channel mask has the number of channels as the out dim - pruning candidates
188 | filter_idx_mask = [i for i in range(old_weights.shape[0]) if i not in filter_indices[linear_count]]
189 | out_dim = len(filter_idx_mask)
190 |
191 | new_linear_layer = \
192 | torch.nn.Linear(in_dim, out_dim)
193 |
194 | # The new bias has the shape of the out dimension
195 | new_linear_layer.bias.data = old_linear_layer.bias.data[filter_idx_mask]
196 | # The weight format is out_dim x in_dim, so we first selectively index the out dim, using the channel mask
197 | # Then selectively index the in dim, by the previous layer's filter mask
198 | # If the last layer was the last conv layer, prev_filter_idx_mask is None, in which case it indexes everything (no mask)
199 | new_linear_layer.weight.data = old_weights[filter_idx_mask, :][:, prev_filter_idx_mask].squeeze()
200 |
201 | # Set the new linear layer with the model
202 | model.classifier[idx - 1] = new_linear_layer
203 |
204 | linear_count += 1
205 | counter += 1
206 |
207 |
208 | counter = 0
209 | layer_index = 0
210 | if len(model.classifier._modules):
211 | for _, module in model.classifier._modules.items():
212 | if isinstance(module, torch.nn.Linear):
213 | old_linear_layer = module
214 | layer_index = counter
215 | counter += 1
216 | else:
217 | old_linear_layer = model.classifier
218 |
219 | if old_linear_layer is None:
220 | raise BaseException("No linear layer found in classifier")
221 |
222 | # If using bce, we don't need a negative out
223 | bce_offset = 0 if use_bce else 1
224 | # Create a new linear layer, in dimension is the out dimension of previous layer
225 | # out dimension is the number of classes with the pruned model
226 | new_linear_layer = \
227 | torch.nn.Linear(out_dim,
228 | len(class_indices) + bce_offset)
229 |
230 | old_weights = old_linear_layer.weight.data.cpu().numpy()
231 | new_weights = new_linear_layer.weight.data.cpu().numpy()
232 |
233 | new_weights[bce_offset:, :] = old_weights[class_indices][:,filter_idx_mask]
234 |
235 |
236 |
237 | new_linear_layer.bias.data[bce_offset:] = torch.from_numpy(np.asarray(old_linear_layer.bias.data.cpu().numpy()[class_indices])).cuda()
238 |
239 | new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda()
240 |
241 | if len(model.classifier._modules):
242 | classifier = torch.nn.Sequential(
243 | *(replace_layers(model.classifier, i, [layer_index], \
244 | [new_linear_layer]) for i, _ in enumerate(model.classifier)))
245 | else:
246 | classifier = torch.nn.Sequential(new_linear_layer)
247 |
248 | del model.classifier
249 | model.classifier = classifier
250 |
251 | return model
252 |
253 | def prune_resnet50(model, candidates, group_indices):
254 | layers = list(model.children())
255 | # layer[0] : Conv2d
256 |     # layer[1] : BatchNorm2d
257 | # layer[2] : ReLU
258 | layer_index = 1
259 | for stage in (layers[4], layers[5], layers[6], layers[7]):
260 | for index, block in enumerate(stage.children()):
261 | assert isinstance(block, models.resnet.Bottleneck), "only support bottleneck block"
262 | children_dict = dict(block.named_children())
263 | conv1 = children_dict['conv1']
264 | conv2 = children_dict['conv2']
265 | conv3 = children_dict['conv3']
266 | prune_contiguous_conv2d_(
267 | conv1, conv2, candidates[layer_index], bn=children_dict['bn1'])
268 | layer_index += 1
269 | prune_contiguous_conv2d_(
270 | conv2, conv3, candidates[layer_index], bn=children_dict['bn2'])
271 | layer_index += 2
272 | # because we are using the output of the ReLU, the output of
273 | # the downsample is merged before ReLU, so we do not need to
274 | # increase the layer index
275 | prune_output_linear_layer_(model.fc, group_indices, use_bce=False)
276 |
277 | if __name__ == '__main__':
278 |     import time  # only needed for this quick timing check
279 |     model = models.vgg16_bn(pretrained=True)  # prune_vgg16_conv_layer expects a conv followed by BatchNorm
280 |     model.train()
281 |     t0 = time.time()
282 |     model = prune_vgg16_conv_layer(model, 0, 10, use_batch_norm=True)
283 |     print("The pruning took", time.time() - t0)
284 |
--------------------------------------------------------------------------------
/prune_and_get_model.py:
--------------------------------------------------------------------------------
1 | import re
2 | import glob
3 | import models.cifar as models
4 | import os
5 | import sys
6 | import argparse
7 | import pathlib
8 | import pickle
9 | import copy
10 | import numpy as np
11 | import re
12 | import torch
13 | from torch import nn
14 | import load_model
15 | import torch.multiprocessing as mp
16 |
17 | from regularize_model import standard
18 | from prune_utils.prune import prune_vgg16_conv_layer, prune_last_fc_layers, prune_resnet50
19 | from prune_utils.layer_prune import (
20 | prune_output_linear_layer_,
21 | prune_contiguous_conv2d_,
22 | prune_conv2d_out_channels_,
23 | prune_batchnorm2d_,
24 | prune_linear_in_features_,
25 | prune_mobile_block,
26 | prune_shuffle_layer)
27 | from models.cifar.resnet import Bottleneck
28 | import torchvision.models as imagenet_models
29 |
30 | parser = argparse.ArgumentParser(description='Prune a trained model into per-group pruned models')
31 | parser.add_argument('-d', '--dataset', required=True, type=str)
32 | parser.add_argument('-c', '--prune-candidates', default="./prune_candidate_logs/",
33 | type=str, help='Directory which stores the prune candidates for each model')
34 | parser.add_argument('-a', '--arch', default='vgg19_bn',
35 | type=str, help='The architecture of the trained model')
36 | parser.add_argument('-r', '--resume', default='', type=str,
37 | help='The path to the checkpoints')
38 | parser.add_argument('-s', '--save', default='./pruned_models',
39 | type=str, help='The path to store the pruned models')
40 | parser.add_argument('--bce', default=False, action='store_true',
41 |                     help='Prune according to binary cross entropy loss, i.e. no additional negative output for the classifier')
42 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
43 | help='use pre-trained model')
44 | args = parser.parse_args()
45 |
46 |
47 | def prune_vgg(model, pruned_candidates, group_indices):
48 | features = model.features
49 | conv_indices = [i for i, layer in enumerate(features) if isinstance(layer, nn.Conv2d)]
50 | conv_bn_indices = [i for i, layer in enumerate(features) if isinstance(layer, (nn.Conv2d, nn.BatchNorm2d))]
51 | assert len(conv_indices) == len(pruned_candidates)
52 | assert len(conv_indices) * 2 == len(conv_bn_indices)
53 |
54 | for i, conv_index in enumerate(conv_indices[:-1]):
55 | next_conv = None
56 | for j in range(conv_index + 1, len(features)):
57 | l = features[j]
58 | if isinstance(l, nn.Conv2d):
59 | next_conv = l
60 | break
61 | if next_conv is None:
62 | break
63 | bn = model.features[conv_index + 1]
64 | assert isinstance(bn, nn.BatchNorm2d)
65 | prune_contiguous_conv2d_(
66 | features[conv_index],
67 | next_conv,
68 | pruned_candidates[i],
69 | bn=bn)
70 |
71 |     # Pruning the last conv layer. This affects the first linear layer of the classifier.
72 | last_conv = features[conv_indices[-1]]
73 | classifier = model.classifier
74 | assert classifier.in_features % last_conv.out_channels == 0
75 | params_per_input_channel = classifier.in_features // last_conv.out_channels
76 |
77 | pruned_indices = pruned_candidates[-1]
78 | prune_conv2d_out_channels_(last_conv, pruned_indices)
79 | prune_batchnorm2d_(features[conv_bn_indices[-1]], pruned_indices)
80 |
81 | linear_pruned_indices = []
82 | for i in pruned_indices:
83 | linear_pruned_indices += list(range(i * params_per_input_channel, (i + 1) * params_per_input_channel))
84 |
85 | prune_linear_in_features_(classifier, linear_pruned_indices)
86 | # prune the output of the classifier
87 | prune_output_linear_layer_(classifier, group_indices, use_bce=args.bce)
88 |
89 |
90 | def prune_resnet(model, candidates, group_indices):
91 | layers = list(model.children())
92 | # layer[0] : Conv2d
93 |     # layer[1] : BatchNorm2d
94 | # layer[2] : ReLU
95 | layer_index = 1
96 | for stage in (layers[3], layers[4], layers[5]):
97 | for block in stage.children():
98 | assert isinstance(block, Bottleneck), "only support bottleneck block"
99 | children_dict = dict(block.named_children())
100 | conv1 = children_dict['conv1']
101 | conv2 = children_dict['conv2']
102 | conv3 = children_dict['conv3']
103 |
104 | prune_contiguous_conv2d_(
105 | conv1, conv2, candidates[layer_index], bn=children_dict['bn1'])
106 | layer_index += 1
107 | prune_contiguous_conv2d_(
108 | conv2, conv3, candidates[layer_index], bn=children_dict['bn2'])
109 | layer_index += 2
110 | # because we are using the output of the ReLU, the output of
111 | # the downsample is merged before ReLU, so we do not need to
112 | # increase the layer index
113 | prune_output_linear_layer_(model.fc, group_indices, use_bce=args.bce)
114 |
115 | def prune_mobilenetv2(model, candidates, group_indices):
116 | layers = list(model.layers)
117 | layer_index = 1
118 | for block in layers:
119 | conv1 = block.conv1
120 | bn1 = block.bn1
121 | conv2 = block.conv2
122 | bn2 = block.bn2
123 | conv3 = block.conv3
124 | prune_1 = candidates[layer_index]
125 | prune_2 = candidates[layer_index+1]
126 | prune_mobile_block(conv1, conv2, conv3, prune_1, prune_2, bn1, bn2)
127 | layer_index += 2
128 | prune_output_linear_layer_(model.linear, group_indices, use_bce=args.bce)
129 |
130 | def prune_shufflenetv2(model, candidates, group_indices):
131 | layer1, layer2, layer3 = model.layer1, model.layer2, model.layer3
132 | layer1_candidates = candidates[1:15]
133 | layer2_candidates = candidates[15:41]
134 | layer3_candidates = candidates[41:55]
135 | prune_shuffle_layer(layer1, layer1_candidates)
136 | prune_shuffle_layer(layer2, layer2_candidates)
137 | prune_shuffle_layer(layer3, layer3_candidates)
138 | prune_output_linear_layer_(model.linear, group_indices, use_bce=args.bce)
139 |
140 | def filename_to_index(filename):
141 | filename = filename[6+len(args.prune_candidates):]
142 | return int(filename[:filename.index('_')])
143 |
144 | def update_list(l):
145 | for i in range(len(l)):
146 | l[i] -= 1
147 |
148 | def prune_cifar_worker(proc_ind, i, new_model, candidates, group_indices, arch, model_save_directory):
149 | num_gpus = torch.cuda.device_count()
150 | new_model.cuda(i % num_gpus)
151 | group_indices = group_indices.tolist()
152 | if args.arch.startswith('vgg'):
153 | prune_vgg(new_model, candidates, group_indices)
154 | elif args.arch.startswith('resnet'):
155 | prune_resnet(new_model, candidates, group_indices)
156 | elif args.arch.startswith('mobile'):
157 | prune_mobilenetv2(new_model, candidates, group_indices)
158 | elif args.arch.startswith('shuffle'):
159 | prune_shufflenetv2(new_model, candidates, group_indices)
160 | else:
161 | raise NotImplementedError
162 |
163 | # save the pruned model
164 | pruned_model_name = f"{arch}_{i}_pruned_model.pth"
165 | torch.save(new_model, os.path.join(
166 | model_save_directory, pruned_model_name))
167 | print('Pruned model saved at', model_save_directory)
168 |
169 | def prune_imagenet_worker(proc_ind, model, candidates, group_indices, group_id, model_save_directory):
170 | num_gpus = torch.cuda.device_count()
171 | torch.cuda.set_device(group_id % num_gpus)
172 | model.cuda(group_id % num_gpus)
173 | if args.arch != "resnet50":
174 | conv_indices = [idx for idx, (n, p) in enumerate(model.features._modules.items()) if isinstance(p, nn.modules.conv.Conv2d)]
175 | offset = 0
176 | for layer_index, filter_list in zip(conv_indices, candidates):
177 | offset += 1
178 | filters_to_remove = list(filter_list)
179 |             filters_to_remove.sort()
180 |
181 | while len(filters_to_remove):
182 | filter_index = filters_to_remove.pop(0)
183 | model = prune_vgg16_conv_layer(model, layer_index, filter_index, use_batch_norm=True)
184 | update_list(filters_to_remove)
185 |
186 | # save the pruned model
187 | # The input dimension of the first fc layer is pruned from above
188 | model = prune_last_fc_layers(model, \
189 | group_indices, \
190 | filter_indices = candidates[offset:], \
191 | use_bce = args.bce)
192 | else:
193 | prune_resnet50(model, candidates, group_indices)
194 |
195 | pruned_model_name = args.arch + '_{}'.format(group_id) + '_pruned_model.pth'
196 |     print('Grouped model %s Total params: %.2fM' % (group_id, sum(p.numel() for p in model.parameters()) / 1000000.0))
197 | torch.save(model, os.path.join(model_save_directory, pruned_model_name))
198 | print('Pruned model saved at', model_save_directory)
199 |
200 | def main():
201 | use_cuda = torch.cuda.is_available()
202 | # load groups
203 | file_names = [f for f in glob.glob(args.prune_candidates + "group_*.npy", recursive=False)]
204 | file_names.sort(key=filename_to_index)
205 | groups = np.load(open(args.prune_candidates + "grouping_config.npy", "rb"))
206 |
207 | # create pruned model save path
208 | model_save_directory = os.path.join(args.save, args.arch)
209 | pathlib.Path(model_save_directory).mkdir(parents=True, exist_ok=True)
210 | np.save(open(os.path.join(args.save, "grouping_config.npy"), "wb"), groups)
211 | if len(groups[0]) == 1:
212 | args.bce = True
213 | print(f'==> Preparing dataset {args.dataset}')
214 | if args.dataset in ['cifar10', 'cifar100']:
215 | if args.dataset == 'cifar10':
216 | num_classes = 10
217 | elif args.dataset == 'cifar100':
218 | num_classes = 100
219 |
220 | processes = []
221 |         # for each group of classes
222 | for i, (group_indices, file_name) in enumerate(zip(groups, file_names)):
223 | # load pruning candidates
224 | with open(file_name, 'rb') as f:
225 | candidates = pickle.load(f)
226 | # load checkpoints
227 | model = load_model.load_pretrain_model(
228 | args.arch, args.dataset, args.resume, num_classes, use_cuda)
229 | new_model = copy.deepcopy(model)
230 | if args.arch in ["mobilenetv2", "shufflenetv2"]:
231 | new_model = standard(new_model, args.arch, num_classes)
232 | p = mp.spawn(prune_cifar_worker, args=(i, new_model, candidates, group_indices, args.arch, model_save_directory), join=False)
233 | processes.append(p)
234 | for p in processes:
235 | p.join()
236 |
237 |
238 | elif args.dataset == 'imagenet':
239 | num_classes = len(groups)
240 | processes = []
241 |         # for each group of classes
242 | for group_id, file_name in enumerate(file_names):
243 |             print('Pruning group {} from candidates in {}'.format(group_id, file_name))
244 | group_indices = groups[group_id]
245 | # load pruning candidates
246 | print(file_name)
247 | candidates = np.load(open(file_name, 'rb'), allow_pickle=True).tolist()
248 |
249 | num_gpus = torch.cuda.device_count()
250 | # load checkpoints
251 | if args.pretrained:
252 | print("=> using pre-trained model '{}'".format(args.arch))
253 | model = imagenet_models.__dict__[args.arch](pretrained=True)
254 | # model = torch.nn.DataParallel(model).cuda() #TODO use DataParallel
255 | model = model.cuda(group_id % num_gpus)
256 | else:
257 | checkpoint = torch.load(args.resume)
258 | model = imagenet_models.__dict__[args.arch](num_classes=num_classes)
259 | # model = torch.nn.DataParallel(model).cuda() #TODO use DataParallel
260 | model = model.cuda(group_id % num_gpus)
261 | model.load_state_dict(checkpoint['state_dict'])
262 |
263 | # join existing num_gpus processes, to make sure only num_gpus processes are running at a time
264 | if group_id % num_gpus == 0:
265 | for p in processes:
266 | p.join()
267 | processes = []
268 |
269 | # model = model.module #TODO use DataParallel
270 | p = mp.spawn(prune_imagenet_worker, args=(model, candidates, group_indices, group_id, model_save_directory), join=False)
271 | processes.append(p)
272 |
273 | for p in processes:
274 | p.join()
275 | else:
276 | raise NotImplementedError
277 |
278 | if __name__ == '__main__':
279 | main()
280 |
--------------------------------------------------------------------------------
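The architecture-specific helpers above (prune_vgg, prune_resnet, prune_mobilenetv2,
prune_shufflenetv2) all reduce to the same channel surgery: drop a set of output channels
from one Conv2d and its BatchNorm, and drop the matching input channels from the next
Conv2d. A rough, self-contained sketch of that step follows; the layer sizes are
hypothetical, and the repository's prune_contiguous_conv2d_ performs the equivalent
operation in place rather than rebuilding the layers.

import torch
from torch import nn

def prune_conv_pair(conv1, bn1, conv2, channels_to_prune):
    # Indices of conv1 output channels (== conv2 input channels) that survive.
    keep = [c for c in range(conv1.out_channels) if c not in set(channels_to_prune)]

    new_conv1 = nn.Conv2d(conv1.in_channels, len(keep), kernel_size=conv1.kernel_size,
                          stride=conv1.stride, padding=conv1.padding,
                          bias=conv1.bias is not None)
    new_conv1.weight.data = conv1.weight.data[keep].clone()
    if conv1.bias is not None:
        new_conv1.bias.data = conv1.bias.data[keep].clone()

    new_bn1 = nn.BatchNorm2d(len(keep))
    new_bn1.weight.data = bn1.weight.data[keep].clone()
    new_bn1.bias.data = bn1.bias.data[keep].clone()
    new_bn1.running_mean = bn1.running_mean[keep].clone()
    new_bn1.running_var = bn1.running_var[keep].clone()

    new_conv2 = nn.Conv2d(len(keep), conv2.out_channels, kernel_size=conv2.kernel_size,
                          stride=conv2.stride, padding=conv2.padding,
                          bias=conv2.bias is not None)
    new_conv2.weight.data = conv2.weight.data[:, keep].clone()
    if conv2.bias is not None:
        new_conv2.bias.data = conv2.bias.data.clone()
    return new_conv1, new_bn1, new_conv2

# Hypothetical 64 -> 128 channel pair; prune the first 16 channels at the junction.
c1, b1, c2 = nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.Conv2d(64, 128, 3, padding=1)
c1, b1, c2 = prune_conv_pair(c1, b1, c2, list(range(16)))
x = torch.randn(1, 3, 32, 32)
print(c2(torch.relu(b1(c1(x)))).shape)  # torch.Size([1, 128, 32, 32])
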
/imagenet_activations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | import sys
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.optim
15 | import torch.multiprocessing as mp
16 | import torch.utils.data
17 | import torch.utils.data.distributed
18 | import torchvision.transforms as transforms
19 | import imagenet_dataset as datasets
20 | import torchvision.models as models
21 |
22 | import numpy as np
23 | from apoz_policy_imagenet import *
24 | import pdb
25 |
26 | model_names = sorted(name for name in models.__dict__
27 | if name.islower() and not name.startswith("__")
28 | and callable(models.__dict__[name]))
29 |
30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
31 | parser.add_argument('--data', metavar='DIR',
32 | help='path to imagenet dataset')
33 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 | ' (default: resnet18)')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
41 | help='number of total epochs to run')
42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
43 | help='manual epoch number (useful on restarts)')
44 | parser.add_argument('-b', '--batch-size', default=64, type=int,
45 | metavar='N',
46 |                     help='mini-batch size (default: 64), this is the total '
47 | 'batch size of all GPUs on the current node when '
48 | 'using Data Parallel or Distributed Data Parallel')
49 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
50 | metavar='LR', help='initial learning rate', dest='lr')
51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
52 | help='momentum')
53 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
54 | metavar='W', help='weight decay (default: 1e-4)',
55 | dest='weight_decay')
56 | parser.add_argument('-p', '--print-freq', default=1, type=int,
57 |                     metavar='N', help='print frequency (default: 1)')
58 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
59 | help='path to latest checkpoint (default: none)')
60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
61 | help='evaluate model on validation set')
62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
63 | help='use pre-trained model')
64 | parser.add_argument('--world-size', default=-1, type=int,
65 | help='number of nodes for distributed training')
66 | parser.add_argument('--rank', default=-1, type=int,
67 | help='node rank for distributed training')
68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
69 | help='url used to set up distributed training')
70 | parser.add_argument('--dist-backend', default='nccl', type=str,
71 | help='distributed backend')
72 | parser.add_argument('--seed', default=None, type=int,
73 | help='seed for initializing training. ')
74 | parser.add_argument('--gpu', default=None, type=int,
75 | help='GPU id to use.')
76 | parser.add_argument('--multiprocessing-distributed', default=False, action='store_true',
77 | help='Use multi-processing distributed training to launch '
78 | 'N processes per node, which has N GPUs. This is the '
79 | 'fastest way to use PyTorch for either single node or '
80 | 'multi node data parallel training')
81 |
82 | parser.add_argument('--group', type=int, nargs='+', default=[],
83 |                     help='Generate activations based on these class indices')
84 | parser.add_argument('--name', type=str, default='Name', help='Set the name id of the group')
85 |
86 |
87 | global num_layers
88 | num_layers = sys.maxsize
89 | global layer_idx
90 | layer_idx = 0
91 | num_batches = 0
92 | best_acc1 = 0
93 |
94 | def main():
95 | args = parser.parse_args()
96 | if args.seed is not None:
97 | random.seed(args.seed)
98 | torch.manual_seed(args.seed)
99 | cudnn.deterministic = True
100 | warnings.warn('You have chosen to seed training. '
101 | 'This will turn on the CUDNN deterministic setting, '
102 | 'which can slow down your training considerably! '
103 | 'You may see unexpected behavior when restarting '
104 | 'from checkpoints.')
105 |
106 | if args.gpu is not None:
107 | warnings.warn('You have chosen a specific GPU. This will completely '
108 | 'disable data parallelism.')
109 |
110 | main_worker(args.gpu, args)
111 |
112 |
113 | def main_worker(gpu, args):
114 | global best_acc1
115 | global num_layers
116 | global apoz_scores_by_layer
117 | global avg_scores_by_layer
118 | args.gpu = gpu
119 |
120 | if args.gpu is not None:
121 | print("Use GPU: {} for training".format(args.gpu))
122 |
123 | # create model
124 | if args.pretrained:
125 | print("=> using pre-trained model '{}'".format(args.arch))
126 | model = models.__dict__[args.arch](pretrained=True)
127 | else:
128 | print("=> creating model '{}'".format(args.arch))
129 | model = models.__dict__[args.arch]()
130 |
131 | if args.gpu is not None:
132 | print("checkpoint 1...")
133 | torch.cuda.set_device(args.gpu)
134 | model = model.cuda(args.gpu)
135 |
136 | # define loss function (criterion) and optimizer
137 | criterion = nn.CrossEntropyLoss()
138 |
139 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
140 | momentum=args.momentum,
141 | weight_decay=args.weight_decay)
142 |
143 | # optionally resume from a checkpoint
144 | if args.resume:
145 | if os.path.isfile(args.resume):
146 | print("=> loading checkpoint '{}'".format(args.resume))
147 | checkpoint = torch.load(args.resume)
148 | args.start_epoch = checkpoint['epoch']
149 | best_acc1 = checkpoint['best_acc1']
150 |             model.load_state_dict(checkpoint['state_dict']); optimizer.load_state_dict(checkpoint['optimizer'])
151 | print("=> loaded checkpoint '{}' (epoch {})"
152 | .format(args.resume, checkpoint['epoch']))
153 | else:
154 | print("=> no checkpoint found at '{}'".format(args.resume))
155 |
156 | print("checkpoint 2...")
157 | apoz_scores_by_layer = []
158 | avg_scores_by_layer = []
159 | model = model.cuda(args.gpu)
160 |
161 | # Data loading code
162 | traindir = os.path.join(args.data, 'train')
163 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
164 | std=[0.229, 0.224, 0.225])
165 |
166 | train_dataset = datasets.ImageFolder(
167 | traindir,
168 | transforms.Compose([
169 | transforms.RandomResizedCrop(224),
170 | transforms.RandomHorizontalFlip(),
171 | transforms.ToTensor(),
172 | normalize,
173 | ]), activations=True, group=args.group)
174 |
175 | val_loader = torch.utils.data.DataLoader(
176 | train_dataset, batch_size=args.batch_size, shuffle=False,
177 | num_workers=args.workers, pin_memory=True, sampler=None)
178 |
179 | print("checkpoint 3...")
180 | if args.evaluate:
181 | validate(val_loader, model, criterion, args)
182 | generate_candidates(args.name)
183 | return
184 |
185 |
186 | def validate(val_loader, model, criterion, args):
187 | global layer_idx
188 | global num_batches
189 | global num_layers
190 | batch_time = AverageMeter('Time', ':6.3f')
191 | losses = AverageMeter('Loss', ':.4e')
192 | top1 = AverageMeter('Acc@1', ':6.2f')
193 | top5 = AverageMeter('Acc@5', ':6.2f')
194 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
195 | prefix='Test: ')
196 |
197 | # switch to evaluate mode
198 | model.apply(apply_hook)
199 | model.eval()
200 |
201 | with torch.no_grad():
202 | end = time.time()
203 | for i, (input, target) in enumerate(val_loader):
204 | num_batches += 1
205 | layer_idx = 0
206 | if args.gpu is not None:
207 | input = input.cuda(args.gpu)
208 | target = target.cuda(args.gpu)
209 | # compute output
210 | output = model(input)
211 | loss = criterion(output, target)
212 |
213 | # measure accuracy and record loss
214 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
215 | losses.update(loss.item(), input.size(0))
216 | top1.update(acc1[0], input.size(0))
217 | top5.update(acc5[0], input.size(0))
218 |
219 | # measure elapsed time
220 | batch_time.update(time.time() - end)
221 | end = time.time()
222 |
223 | if i % args.print_freq == 0:
224 | progress.print(i)
225 |
226 | num_layers = layer_idx
227 |
228 | # TODO: this should also be done with the ProgressMeter
229 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
230 | .format(top1=top1, top5=top5))
231 |
232 | return top1.avg
233 |
234 |
235 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
236 | torch.save(state, filename)
237 | if is_best:
238 | shutil.copyfile(filename, 'model_best.pth.tar')
239 |
240 | def parse_activation(relu_reference, feature_map):
241 | global layer_idx
242 | global num_layers
243 | apoz_score = apoz_scoring(feature_map)
244 | avg_score = avg_scoring(feature_map)
245 |
246 | if len(apoz_scores_by_layer) < num_layers:
247 | apoz_scores_by_layer.append(apoz_score)
248 | avg_scores_by_layer.append(avg_score)
249 | else:
250 | apoz_scores_by_layer[layer_idx] = torch.add(apoz_scores_by_layer[layer_idx], apoz_score)
251 | avg_scores_by_layer[layer_idx] = torch.add(avg_scores_by_layer[layer_idx], avg_score)
252 |
253 | layer_idx += 1
254 |
255 |
256 | """
257 | Apply a hook to ReLU layers
258 | """
259 | def hook(self, input, output):
260 | if self.__class__.__name__ == 'ReLU':
261 | parse_activation(self, output.data)
262 |
263 | def apply_hook(m):
264 | m.register_forward_hook(hook)
265 |
266 | def generate_candidates(name):
267 | global num_batches
268 | global apoz_scores_by_layer
269 | global avg_scores_by_layer
270 | global num_layers
271 | group_id_string = name
272 | apoz_thresholds = [90] * num_layers
273 | avg_thresholds = [sys.maxsize] * num_layers #sys.maxsize to disable avg
274 | candidates_by_layer = []
275 |
276 | for layer_idx, (apoz_scores, avg_scores) in enumerate(zip(apoz_scores_by_layer, avg_scores_by_layer)):
277 | apoz_scores *= 1/ float(num_batches)
278 | apoz_scores = apoz_scores.cpu()
279 |
280 | avg_scores *= 1/ float(num_batches)
281 | avg_scores = avg_scores.cpu()
282 |
283 | avg_candidates = [idx for idx, score in enumerate(avg_scores) if score >= avg_thresholds[layer_idx]] if avg_scores.dim() != 0 else []
284 | candidates = [x[0] for x in apoz_scores.gt(apoz_thresholds[layer_idx]).nonzero().tolist()]
285 |
286 | difference_candidates = list(set(candidates).difference(set(avg_candidates)))
287 | candidates_by_layer.append(difference_candidates)
288 | print("Total candidates: {}".format(sum([len(l) for l in candidates_by_layer])))
289 | np.save(open("prune_candidate_logs/group_{}_apoz_layer_thresholds.npy".format( group_id_string), "wb"), candidates_by_layer)
290 | print(candidates_by_layer)
291 |
292 | class AverageMeter(object):
293 | """Computes and stores the average and current value"""
294 | def __init__(self, name, fmt=':f'):
295 | self.name = name
296 | self.fmt = fmt
297 | self.reset()
298 |
299 | def reset(self):
300 | self.val = 0
301 | self.avg = 0
302 | self.sum = 0
303 | self.count = 0
304 |
305 | def update(self, val, n=1):
306 | self.val = val
307 | self.sum += val * n
308 | self.count += n
309 | self.avg = self.sum / self.count
310 |
311 | def __str__(self):
312 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
313 | return fmtstr.format(**self.__dict__)
314 |
315 | class ProgressMeter(object):
316 | def __init__(self, num_batches, *meters, prefix=""):
317 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
318 | self.meters = meters
319 | self.prefix = prefix
320 |
321 | def print(self, batch):
322 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
323 | entries += [str(meter) for meter in self.meters]
324 | print('\t'.join(entries))
325 |
326 | def _get_batch_fmtstr(self, num_batches):
327 | num_digits = len(str(num_batches // 1))
328 | fmt = '{:' + str(num_digits) + 'd}'
329 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
330 |
331 | def adjust_learning_rate(optimizer, epoch, args):
332 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
333 | lr = args.lr * (0.1 ** (epoch // 30))
334 | for param_group in optimizer.param_groups:
335 | param_group['lr'] = lr
336 |
337 | def accuracy(output, target, topk=(1,)):
338 | """Computes the accuracy over the k top predictions for the specified values of k"""
339 | with torch.no_grad():
340 | maxk = max(topk)
341 | batch_size = target.size(0)
342 |
343 | _, pred = output.topk(maxk, 1, True, True)
344 | pred = pred.t()
345 | correct = pred.eq(target.view(1, -1).expand_as(pred))
346 |
347 | res = []
348 | for k in topk:
349 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
350 | res.append(correct_k.mul_(100.0 / batch_size))
351 | return res
352 |
353 | if __name__ == '__main__':
354 | main()
355 |
--------------------------------------------------------------------------------
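generate_candidates above compares per-channel scores against an APoZ threshold of 90;
the scoring itself comes from apoz_policy_imagenet. As a reference for what such a score
measures, here is a minimal sketch of the usual APoZ (Average Percentage of Zeros)
definition: the per-channel fraction of zero entries in a post-ReLU feature map, expressed
as a percentage. The module's actual apoz_scoring may differ in details.

import torch

def apoz_per_channel(feature_map):
    # feature_map: post-ReLU activations of shape (N, C, H, W)
    n, c = feature_map.shape[:2]
    zeros = (feature_map == 0).float().reshape(n, c, -1)
    return zeros.mean(dim=(0, 2)) * 100.0  # shape (C,), values in [0, 100]

# Channels whose APoZ exceeds the 90% threshold become prune candidates.
fmap = torch.relu(torch.randn(8, 16, 4, 4))
scores = apoz_per_channel(fmap)
candidates = [i for i, s in enumerate(scores.tolist()) if s > 90]
print(scores, candidates)
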
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 | import random
6 | import warnings
7 |
8 | from tqdm import tqdm
9 | import torch
10 | from torch import nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.utils.data as data
14 | import torchvision.transforms as transforms
15 | import torchvision.datasets as datasets
16 | import numpy as np
17 | from utils import Logger, AverageMeter, accuracy, savefig
18 | from torch.utils.data import Dataset, DataLoader
19 | import glob
20 | import re
21 | import itertools
22 | from compute_flops import print_model_param_flops
23 | import torchvision.models as models
24 | from imagenet_evaluate_grouped import main_worker
25 | import torch.multiprocessing as mp
26 |
27 | model_names = sorted(name for name in models.__dict__
28 | if name.islower() and not name.startswith("__")
29 | and callable(models.__dict__[name]))
30 | model_names += ["resnet110", "resnet164", "mobilenetv2", "shufflenetv2"]
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100/ImageNet Testing')
33 | # Checkpoints
34 | parser.add_argument('--retrained_dir', type=str, metavar='PATH',
35 | help='path to the directory of pruned models (default: none)')
36 | # Datasets
37 | parser.add_argument('-d', '--dataset', required=True, type=str)
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--test-batch', default=128, type=int, metavar='N',
41 |                     help='test batch size')
42 | parser.add_argument('--data', metavar='DIR', required=False,
43 | help='path to imagenet dataset')
44 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
45 | choices=model_names,
46 | help='model architecture: ' +
47 | ' | '.join(model_names) +
48 | ' (default: resnet18)')
49 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
50 | help='number of total epochs to run')
51 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
52 | help='manual epoch number (useful on restarts)')
53 | parser.add_argument('-b', '--batch-size', default=64, type=int,
54 | metavar='N',
55 |                     help='mini-batch size (default: 64), this is the total '
56 | 'batch size of all GPUs on the current node when '
57 | 'using Data Parallel or Distributed Data Parallel')
58 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
59 | metavar='LR', help='initial learning rate', dest='lr')
60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
61 | help='momentum')
62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
63 | metavar='W', help='weight decay (default: 1e-4)',
64 | dest='weight_decay')
65 | parser.add_argument('-p', '--print-freq', default=10, type=int,
66 | metavar='N', help='print frequency (default: 10)')
67 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
68 | help='evaluate model on validation set')
69 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
70 | help='use pre-trained model')
71 | parser.add_argument('--world-size', default=-1, type=int,
72 | help='number of nodes for distributed training')
73 | parser.add_argument('--rank', default=-1, type=int,
74 | help='node rank for distributed training')
75 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
76 | help='url used to set up distributed training')
77 | parser.add_argument('--dist-backend', default='nccl', type=str,
78 | help='distributed backend')
79 | parser.add_argument('--gpu', default=None, type=int,
80 | help='GPU id to use.')
81 | parser.add_argument('--multiprocessing-distributed', action='store_true',
82 | help='Use multi-processing distributed training to launch '
83 | 'N processes per node, which has N GPUs. This is the '
84 | 'fastest way to use PyTorch for either single node or '
85 | 'multi node data parallel training')
86 | parser.add_argument('--bce', default=False, action='store_true',
87 | help='Use binary cross entropy loss')
88 | best_acc1 = 0
89 |
90 | # Miscs
91 | parser.add_argument('--seed', type=int, default=42, help='manual seed')
92 | args = parser.parse_args()
93 | state = {k: v for k, v in args._get_kwargs()}
94 | # Validate dataset
95 | assert args.dataset == 'cifar10' or args.dataset == 'cifar100' or args.dataset == 'imagenet', 'Dataset can only be cifar10, cifar100 or imagenet.'
96 |
97 | # Use CUDA
98 | use_cuda = torch.cuda.is_available()
99 |
100 | # Random seed
101 | torch.manual_seed(args.seed)
102 | if use_cuda:
103 | torch.cuda.manual_seed_all(args.seed)
104 |
105 | torch.set_printoptions(threshold=10000)
106 |
107 | def main():
108 | # imagenet evaluation
109 | if args.dataset == 'imagenet':
110 | imagenet_evaluate()
111 | return
112 |
113 | # cifar 10/100 evaluation
114 | print('==> Preparing dataset %s' % args.dataset)
115 | if args.dataset == 'cifar10':
116 | dataset_loader = datasets.CIFAR10
117 | elif args.dataset == 'cifar100':
118 | dataset_loader = datasets.CIFAR100
119 | else:
120 | raise NotImplementedError
121 |
122 | testloader = data.DataLoader(
123 | dataset_loader(
124 | root='./data',
125 | download=False,
126 | train=False,
127 | transform=transforms.Compose([
128 | transforms.ToTensor(),
129 | transforms.Normalize(
130 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
131 | ])),
132 | batch_size = args.test_batch,
133 | shuffle = True,
134 | num_workers = args.workers)
135 |
136 | cudnn.benchmark = True
137 | criterion = nn.CrossEntropyLoss()
138 | model = load_pruned_models(args.retrained_dir+'/'+args.arch+'/')
139 |
140 | if len(model.group_info) == 10 and args.dataset == 'cifar10':
141 | args.bce = True
142 |
143 | test_acc = test_list(testloader, model, criterion, use_cuda)
144 |
145 | def imagenet_evaluate():
146 | if args.seed is not None:
147 | random.seed(args.seed)
148 | torch.manual_seed(args.seed)
149 | cudnn.deterministic = True
150 | warnings.warn('You have chosen to seed training. '
151 | 'This will turn on the CUDNN deterministic setting, '
152 | 'which can slow down your training considerably! '
153 | 'You may see unexpected behavior when restarting '
154 | 'from checkpoints.')
155 |
156 | if args.gpu is not None:
157 | warnings.warn('You have chosen a specific GPU. This will completely '
158 | 'disable data parallelism.')
159 |
160 | if args.dist_url == "env://" and args.world_size == -1:
161 | args.world_size = int(os.environ["WORLD_SIZE"])
162 |
163 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
164 |
165 | ngpus_per_node = torch.cuda.device_count()
166 | if args.multiprocessing_distributed:
167 | # Since we have ngpus_per_node processes per node, the total world_size
168 | # needs to be adjusted accordingly
169 | args.world_size = ngpus_per_node * args.world_size
170 | # Use torch.multiprocessing.spawn to launch distributed processes: the
171 | # main_worker process function
172 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
173 | else:
174 | # Simply call main_worker function
175 | main_worker(args.gpu, ngpus_per_node, args)
176 |
177 | def test_list(testloader, model, criterion, use_cuda):
178 | batch_time = AverageMeter()
179 | data_time = AverageMeter()
180 | losses = AverageMeter()
181 | top1 = AverageMeter()
182 | top5 = AverageMeter()
183 |
184 | if use_cuda:
185 | model.cuda()
186 | model.eval()
187 | end = time.time()
188 |
189 | if args.dataset == 'cifar10':
190 | confusion_matrix = np.zeros((10, 10))
191 | elif args.dataset == 'cifar100':
192 | confusion_matrix = np.zeros((100, 100))
193 | else:
194 | raise NotImplementedError
195 |
196 | bar = tqdm(total=len(testloader))
197 | # pdb.set_trace()
198 | for batch_idx, (inputs, targets) in enumerate(testloader):
199 | bar.update(1)
200 | # measure data loading time
201 | data_time.update(time.time() - end)
202 | if use_cuda:
203 | inputs, targets = inputs.cuda(), targets.cuda()
204 | with torch.no_grad():
205 | outputs = model(inputs)
206 | loss = criterion(outputs, targets)
207 | for output, target in zip(outputs, targets):
208 | gt = target.item()
209 | dt = np.argmax(output.cpu().numpy())
210 | confusion_matrix[gt, dt] += 1
211 | # measure accuracy and record loss
212 | prec1, prec5 = accuracy(outputs, targets, topk = (1, 5))
213 | losses.update(loss.item(), inputs.size(0))
214 | top1.update(prec1.item(), inputs.size(0))
215 | top5.update(prec5.item(), inputs.size(0))
216 |
217 | # measure elapsed time
218 | batch_time.update(time.time() - end)
219 | end = time.time()
220 |
221 | # plot progress
222 | bar.set_description('({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
223 | batch=batch_idx + 1,
224 | size=len(testloader),
225 | data=data_time.avg,
226 | bt=batch_time.avg,
227 |                 total='N/A',
228 |                 eta='N/A',
229 | loss=losses.avg,
230 | top1=top1.avg,
231 | top5=top5.avg,
232 | ))
233 | bar.close()
234 |
235 | np.set_printoptions(precision=3, linewidth=96)
236 |
237 | print("\n===== Full Confusion Matrix ==================================\n")
238 | if confusion_matrix.shape[0] < 20:
239 | print(confusion_matrix)
240 | else:
241 | print("Warning: The original confusion matrix is too big to fit into the screen. "
242 | "Skip printing the matrix.")
243 |
244 | if all([len(group) > 1 for group in model.group_info]):
245 | print("\n===== Inter-group Confusion Matrix ===========================\n")
246 | print(f"Group info: {[group for group in model.group_info]}")
247 | n_groups = len(model.group_info)
248 | group_confusion_matrix = np.zeros((n_groups, n_groups))
249 | for i in range(n_groups):
250 | for j in range(n_groups):
251 | cols = model.group_info[i]
252 | rows = model.group_info[j]
253 |                 # accumulate over every class pair in the two groups
254 |                 for col in cols:
255 |                     for row in rows:
256 |                         group_confusion_matrix[i, j] += confusion_matrix[col, row]
257 | group_confusion_matrix /= group_confusion_matrix.sum(axis=-1)[:, np.newaxis]
258 | print(group_confusion_matrix)
259 |
260 | print("\n===== In-group Confusion Matrix ==============================\n")
261 | for group in model.group_info:
262 | print(f"group {group}")
263 | inter_group_matrix = confusion_matrix[group, :][:, group]
264 | inter_group_matrix /= inter_group_matrix.sum(axis=-1)[:, np.newaxis]
265 | print(inter_group_matrix)
266 | return (losses.avg, top1.avg)
267 |
268 | class GroupedModel(nn.Module):
269 | def __init__(self, model_list, group_info):
270 | super().__init__()
271 | self.group_info = group_info
272 | # flatten list of list
273 | permutation_indices = list(itertools.chain.from_iterable(group_info))
274 | self.permutation_indices = torch.eye(len(permutation_indices))[permutation_indices]
275 | if use_cuda:
276 | self.permutation_indices = self.permutation_indices.cuda()
277 | self.model_list = nn.ModuleList(model_list)
278 |
279 | def forward(self, inputs):
280 | output_list = []
281 | if args.bce:
282 | for model_idx, model in enumerate(self.model_list):
283 | output = model(inputs)[:, 0]
284 | output_list.append(output)
285 | output_list = torch.softmax(torch.stack(output_list, dim=1).squeeze(), dim=1)
286 | else:
287 | for model_idx, model in enumerate(self.model_list):
288 | output = torch.softmax(model(inputs), dim=1)[:, 1:]
289 | output_list.append(output)
290 | output_list = torch.cat(output_list, 1)
291 | return torch.mm(output_list, self.permutation_indices)
292 |
293 | def print_statistics(self):
294 | num_params = []
295 | num_flops = []
296 |
297 | print("\n===== Metrics for grouped model ==========================\n")
298 |
299 | for group_id, model in zip(self.group_info, self.model_list):
300 | n_params = sum(p.numel() for p in model.parameters()) / 10**6
301 | num_params.append(n_params)
302 |             print(f'Grouped model for classes {group_id} '
303 |                   f'Total params: {n_params:.2f}M')
304 | num_flops.append(print_model_param_flops(model, 32))
305 |
306 | print(f"Average number of flops: {sum(num_flops) / len(num_flops) / 10**9 :3f} G")
307 | print(f"Average number of param: {sum(num_params) / len(num_params)} M")
308 |
309 |
310 | def load_pruned_models(model_dir):
311 | group_dir = model_dir[:-(len(args.arch)+1)]
312 | if not model_dir.endswith('/'):
313 | model_dir += '/'
314 | file_names = [f for f in glob.glob(model_dir + "*.pth", recursive=False)]
315 | model_list = [torch.load(file_name, map_location=lambda storage, loc: storage.cuda(0)) for file_name in file_names]
316 | groups = np.load(open(group_dir + "grouping_config.npy", "rb"))
317 | group_info = []
318 | for file in file_names:
319 | group_id = filename_to_index(file)
320 | print(f"Group number is: {group_id}")
321 | class_indices = groups[group_id]
322 | group_info.append(class_indices.tolist()[0])
323 | model = GroupedModel(model_list, group_info)
324 | model.print_statistics()
325 | return model
326 |
327 |
328 | def filename_to_index(filename):
329 | filename = [int(s) for s in filename.split('_') if s.isdigit()]
330 | return filename
331 |
332 | if __name__ == '__main__':
333 | main()
334 |
335 |
336 |
--------------------------------------------------------------------------------
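A closing note on GroupedModel.forward above: per-group scores are concatenated in group
order and multiplied by a one-hot permutation matrix, which places each score back in its
original class column. The self-check below exercises that trick with a hypothetical
two-group split of four classes; the scores are made-up numbers.

import itertools
import torch

group_info = [[2, 0], [3, 1]]                 # hypothetical grouping of 4 classes
order = list(itertools.chain.from_iterable(group_info))
perm = torch.eye(len(order))[order]           # same construction as GroupedModel.__init__

# Scores concatenated in group order, i.e. for classes 2, 0, 3, 1.
grouped_scores = torch.tensor([[0.9, 0.1, 0.7, 0.3]])
print(grouped_scores @ perm)                  # tensor([[0.1000, 0.3000, 0.9000, 0.7000]]) -> classes 0..3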