├── datasets ├── __init__.py ├── cifar10.py └── cifar100.py ├── .gitattributes ├── models ├── __init__.py ├── extension.py ├── vgg_m.py ├── resnet_m.py └── shufflenet_m.py ├── record.py ├── acquisition.py ├── .gitignore ├── stat.py ├── options.py ├── README.md ├── layer.py ├── kernel.py ├── compression.py ├── architecture.py ├── graph.py ├── training.py └── LICENSE /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import * 2 | from .cifar100 import * -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .extension import * 2 | from .resnet_m import * 3 | from .vgg_m import * 4 | from .shufflenet_m import * -------------------------------------------------------------------------------- /record.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import options as opt 4 | 5 | class Record(object): 6 | 7 | def __init__(self): 8 | super(Record, self).__init__() 9 | self.n = 0 10 | self.x = [] 11 | self.y = torch.tensor([], dtype=torch.float, device=opt.device, 12 | requires_grad=False) 13 | self.reward_best = 0.0 14 | 15 | def add_sample(self, xn, yn): 16 | self.x.append(xn) 17 | self.y = torch.cat((self.y, torch.tensor([yn], dtype=torch.float, 18 | device=opt.device, 19 | requires_grad=False))) 20 | if yn > self.reward_best: 21 | self.reward_best = yn 22 | self.n += 1 23 | 24 | def save(self, save_path): 25 | path = os.path.dirname(save_path) 26 | if not os.path.exists(path): 27 | os.makedirs(path) 28 | torch.save(self, save_path) -------------------------------------------------------------------------------- /models/extension.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Identity(nn.Module): 6 | 7 | def __init__(self): 8 | super(Identity, self).__init__() 9 | 10 | def forward(self, x): 11 | return x 12 | 13 | class Flatten(nn.Module): 14 | 15 | def __init__(self): 16 | super(Flatten, self).__init__() 17 | 18 | def forward(self, x): 19 | return x.view(x.size(0), -1) 20 | 21 | class Concat(nn.Module): 22 | 23 | def __init__(self, dim): 24 | super(Concat, self).__init__() 25 | self.dim = dim 26 | 27 | def forward(self, x): 28 | return torch.cat(x, dim=self.dim) 29 | 30 | class Shuffle(nn.Module): 31 | 32 | def __init__(self, groups): 33 | super(Shuffle, self).__init__() 34 | self.groups = groups 35 | 36 | def forward(self, x): 37 | N, C, H, W = x.size() 38 | g = self.groups 39 | x = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4) 40 | x = x.contiguous().view(N, C, H, W) 41 | return x -------------------------------------------------------------------------------- /acquisition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from architecture import Architecture 3 | from kernel import Kernel 4 | import options as opt 5 | 6 | def get_rep_acq(teacher, kernel, action): 7 | rep = teacher.comp_rep(action) 8 | acq = kernel.acquisition(rep) 9 | return rep, acq 10 | 11 | def random_search(teacher, kernel, search_n=opt.ac_search_n): 12 | action_best, rep_best, acq_best = None, None, -1.0 13 | for i in range(search_n): 14 | action = teacher.comp_action_rand() 15 | rep, acq = get_rep_acq(teacher, kernel, action) 16 | if acq > acq_best: 17 | action_best, rep_best, acq_best = action, rep, acq 18 | return action_best, rep_best, acq_best 19 | 20 | def random_search_sfn(teacher, kernel, search_n=opt.ac_search_n): 21 | arch_best, rep_best, acq_best = None, None, -1.0 22 | for i in range(search_n): 23 | arch = teacher.comp_arch_rand_sfn() 24 | rep = arch.rep 25 | acq = kernel.acquisition(rep) 26 | if acq > acq_best: 27 | arch_best, rep_best, acq_best = arch, rep, acq 28 | return arch_best, rep_best, acq_best -------------------------------------------------------------------------------- /models/vgg_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .extension import * 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | class VGGM(nn.Module): 14 | 15 | def __init__(self, vgg_name, num_classes=100): 16 | super(VGGM, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.flatten = Flatten() 19 | self.classifier = nn.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | x = self.features(x) 23 | x = self.flatten(x) 24 | x = self.classifier(x) 25 | return x 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU()] 37 | in_channels = x 38 | return nn.Sequential(*layers) 39 | 40 | def vgg16(**kwargs): 41 | return VGGM('VGG16', **kwargs) 42 | 43 | def vgg19(**kwargs): 44 | return VGGM('VGG19', **kwargs) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | -------------------------------------------------------------------------------- /stat.py: -------------------------------------------------------------------------------- 1 | from compression import * 2 | 3 | if __name__ == '__main__': 4 | teacher_accs = { 5 | 'vgg19_cifar100_' : 73.71, 6 | 'resnet18_cifar100_' : 78.68, 7 | 'resnet34_cifar100_' : 78.71, 8 | 'shufflenet_cifar100_' : 71.14, 9 | 'vgg19_cifar10_' : 93.91, 10 | 'resnet18_cifar10_' : 95.24, 11 | 'resnet34_cifar10_' : 95.57, 12 | 'shufflenet_cifar10_' : 90.87, 13 | } 14 | path_list = [ 15 | './save/resnet34_cifar100_0', 16 | ] 17 | 18 | for path in path_list: 19 | if not(os.path.exists(path)): 20 | print('Path \'%s\' does not exist!' % (path)) 21 | continue 22 | print(path) 23 | teacher_acc = 0.0 24 | for key, val in teacher_accs.items(): 25 | if path.find(key) != -1: 26 | teacher_acc = val 27 | if teacher_acc == 0.0: 28 | print('Teacher acc not given!') 29 | continue 30 | else: 31 | print('Teacher acc:', teacher_acc) 32 | 33 | # architecture index, number of parameters, compression ratio, compression times, accuracy before & after fully training, f(x) before & after fully training 34 | print('Index\t#Params\tRatio\tTimes\tAcc before\tAcc after\tf(x) before\tf(x) after') 35 | best_index = 0 36 | best_reward2 = 0 37 | 38 | for i in range(opt.co_best_n): 39 | arch_path = '%s/arch_%d.pth' % (path, i) 40 | if not(os.path.exists(arch_path)): 41 | continue 42 | arch = torch.load(arch_path) 43 | param_n = arch.param_n() 44 | reward1 = arch.reward 45 | comp1 = arch.comp 46 | comp2 = 1.0 / (1.0 - comp1) 47 | acc1 = arch.acc 48 | 49 | arch_path = '%s/fully_%d.pth' % (path, i) 50 | if not(os.path.exists(arch_path)): 51 | continue 52 | arch = torch.load(arch_path) 53 | acc2 = arch.acc 54 | reward2 = arch.comp * (2 - arch.comp) * acc2 / teacher_acc 55 | 56 | print('%d\t%d\t%.4f\t%.2f\t%.2f\t\t%.2f\t\t%.4f\t\t%.4f' % (i, param_n, comp1, comp2, acc1, acc2, reward1, reward2)) 57 | if reward2 > best_reward2: 58 | best_index = i 59 | best_reward2 = reward2 60 | 61 | print('The best is arch %d with f(x) = %.4f' % (best_index, best_reward2)) -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | class CIFAR10(object): 8 | 9 | def __init__(self, batch_size=128, num_workers=4): 10 | train_transform = transforms.Compose([ 11 | transforms.RandomCrop(32, padding=4), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.4914, 0.4822, 0.4465), 15 | (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | 18 | test_transform = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.4914, 0.4822, 0.4465), 21 | (0.2023, 0.1994, 0.2010)), 22 | ]) 23 | 24 | train_dataset = torchvision.datasets.CIFAR10(root='./datasets/data', 25 | train=True, download=True, transform=train_transform) 26 | test_dataset = torchvision.datasets.CIFAR10(root='./datasets/data', 27 | train=False, download=True, transform=test_transform) 28 | 29 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, 30 | shuffle=True, num_workers=num_workers, pin_memory=True) 31 | self.test_loader = DataLoader(test_dataset, batch_size=batch_size, 32 | shuffle=False, num_workers=num_workers, pin_memory=True) 33 | 34 | class CIFAR10Val(object): 35 | 36 | def __init__(self, batch_size=128, num_workers=4, val_size=5000): 37 | train_transform = transforms.Compose([ 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), 42 | (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | val_transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), 48 | (0.2023, 0.1994, 0.2010)), 49 | ]) 50 | 51 | train_dataset = torchvision.datasets.CIFAR10(root='./datasets/data', 52 | train=True, download=True, transform=train_transform) 53 | val_dataset = torchvision.datasets.CIFAR10(root='./datasets/data', 54 | train=True, download=True, transform=val_transform) 55 | 56 | total_size = len(train_dataset) 57 | indices = list(range(total_size)) 58 | train_size = total_size - val_size 59 | train_sampler = SubsetRandomSampler(indices[:train_size]) 60 | val_sampler = SubsetRandomSampler(indices[train_size:]) 61 | 62 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, 63 | sampler=train_sampler, num_workers=num_workers, pin_memory=True) 64 | self.val_loader = DataLoader(val_dataset, batch_size=batch_size, 65 | sampler=val_sampler, num_workers=num_workers, pin_memory=True) -------------------------------------------------------------------------------- /datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | class CIFAR100(object): 8 | 9 | def __init__(self, batch_size=128, num_workers=4): 10 | train_transform = transforms.Compose([ 11 | transforms.RandomCrop(32, padding=4), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.4914, 0.4822, 0.4465), 15 | (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | 18 | test_transform = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.4914, 0.4822, 0.4465), 21 | (0.2023, 0.1994, 0.2010)), 22 | ]) 23 | 24 | train_dataset = torchvision.datasets.CIFAR100(root='./datasets/data', 25 | train=True, download=True, transform=train_transform) 26 | test_dataset = torchvision.datasets.CIFAR100(root='./datasets/data', 27 | train=False, download=True, transform=test_transform) 28 | 29 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, 30 | shuffle=True, num_workers=num_workers, pin_memory=True) 31 | self.test_loader = DataLoader(test_dataset, batch_size=batch_size, 32 | shuffle=False, num_workers=num_workers, pin_memory=True) 33 | 34 | class CIFAR100Val(object): 35 | 36 | def __init__(self, batch_size=128, num_workers=4, val_size=5000): 37 | train_transform = transforms.Compose([ 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), 42 | (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | val_transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), 48 | (0.2023, 0.1994, 0.2010)), 49 | ]) 50 | 51 | train_dataset = torchvision.datasets.CIFAR100(root='./datasets/data', 52 | train=True, download=True, transform=train_transform) 53 | val_dataset = torchvision.datasets.CIFAR100(root='./datasets/data', 54 | train=True, download=True, transform=val_transform) 55 | 56 | total_size = len(train_dataset) 57 | indices = list(range(total_size)) 58 | train_size = total_size - val_size 59 | train_sampler = SubsetRandomSampler(indices[:train_size]) 60 | val_sampler = SubsetRandomSampler(indices[train_size:]) 61 | 62 | self.train_loader = DataLoader(train_dataset, batch_size=batch_size, 63 | sampler=train_sampler, num_workers=num_workers, pin_memory=True) 64 | self.val_loader = DataLoader(val_dataset, batch_size=batch_size, 65 | sampler=val_sampler, num_workers=num_workers, pin_memory=True) -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | import models 4 | 5 | # global 6 | device = 'cuda' # used device, which can be 'cpu' or 'cuda' 7 | model = './models/pretrained/resnet34_cifar100.pth' # pretrained teacher model 8 | savedir = './save/resnet34_cifar100_0' # save directory 9 | writer = None # record writer for tensorboardX 10 | i = 0 # sample index in search 11 | 12 | # acquisition.py 13 | ac_search_n = 1000 # number of randomly sampled architectures when optimizing acquisition function (see 3.2) 14 | 15 | # architecture.py 16 | ar_max_layers = 128 # maximum number of layers of the original architecture 17 | ar_channel_mul = 2 # numbers of channels in layers should be divisible by this parameter 18 | # necessary for ShuffleNet which has group conv, not used in VGG/ResNet 19 | # hyper-params for random sampling in search space (see 3.2 & 6.5) 20 | ar_p1 = [0.3, 0.4, 0.5, 0.6, 0.7] # for layer removal 21 | ar_p2 = [0.0, 1.0] # for layer shrinkage 22 | ar_p3 = [0.003, 0.005, 0.01, 0.03, 0.05] # for adding skip connections 23 | 24 | # compression.py 25 | # hyper-params for multiple kernel strategy (see 3.3 & 6.3) 26 | co_step_n = 20 # number of search steps 27 | co_kernel_n = 8 # number of kernels, as well as evaluated architectures in each search step 28 | co_best_n = 4 # number of saved best architectures during search, all of which will be fully trained 29 | co_graph_gen = 'get_graph_resnet' # how to generate computation graph of original architecture 30 | # hyper-params for stopping criterion of kernel optimization 31 | co_alpha = 0.5 32 | co_beta = 0.001 33 | co_gamma = 0.5 34 | 35 | # kernel.py 36 | # hyper-params for kernels (see 3.1) 37 | ke_alpha = 0.01 38 | ke_beta = 0.05 39 | ke_gamma = 1 40 | ke_input_size = 16 + ar_max_layers * 2 41 | ke_hidden_size = 64 42 | ke_num_layers = 4 43 | ke_bidirectional = True 44 | ke_lr = 0.001 45 | ke_weight_decay = 5e-4 46 | 47 | # training.py 48 | # hyper-params for training during *search* 49 | tr_se_optimization = 'Adam' 50 | tr_se_epochs = 10 51 | tr_se_lr = 0.001 52 | tr_se_momentum = 0.9 53 | tr_se_weight_decay = 5e-4 54 | tr_se_lr_schedule = None 55 | tr_se_loss_criterion = 'KD' # 'KD': knowledge distillation using teacher outputs 56 | # 'CE': cross entropy using original labels 57 | # hyper-params for *fully* training after search 58 | tr_fu_optimization = 'SGD' 59 | tr_fu_epochs = 300 60 | tr_fu_lr = 0.01 61 | tr_fu_momentum = 0.9 62 | tr_fu_weight_decay = 5e-4 63 | tr_fu_lr_schedule = 'step' 64 | tr_fu_from_scratch = False # False: continue training based on simple training results during search 65 | # True: re-initialize weights and train from scratch -------------------------------------------------------------------------------- /models/resnet_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .extension import * 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | class BasicBlockM(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, inplanes, planes, stride=1, downsample=None): 14 | super(BasicBlockM, self).__init__() 15 | self.downsample = downsample 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu1 = nn.ReLU() 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.relu2 = nn.ReLU() 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu1(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu2(out) 39 | 40 | return out 41 | 42 | class ResNetM(nn.Module): 43 | 44 | def __init__(self, block, layers, num_classes=100): 45 | self.inplanes = 64 46 | super(ResNetM, self).__init__() 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 48 | bias=False) 49 | self.bn1 = nn.BatchNorm2d(64) 50 | self.relu1 = nn.ReLU() 51 | self.layer1 = self._make_layer(block, 64, layers[0]) 52 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 53 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 54 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 55 | self.flatten = Flatten() 56 | self.avgpool = nn.AvgPool2d(4, stride=1) 57 | self.fc = nn.Linear(512 * block.expansion, num_classes) 58 | 59 | def _make_layer(self, block, planes, blocks, stride=1): 60 | downsample = None 61 | if stride != 1 or self.inplanes != planes * block.expansion: 62 | downsample = nn.Sequential( 63 | nn.Conv2d(self.inplanes, planes * block.expansion, 64 | kernel_size=1, stride=stride, bias=False), 65 | nn.BatchNorm2d(planes * block.expansion), 66 | ) 67 | 68 | layers = [] 69 | layers.append(block(self.inplanes, planes, stride, downsample)) 70 | self.inplanes = planes * block.expansion 71 | for i in range(1, blocks): 72 | layers.append(block(self.inplanes, planes)) 73 | 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | x = self.conv1(x) 78 | x = self.bn1(x) 79 | x = self.relu1(x) 80 | 81 | x = self.layer1(x) 82 | x = self.layer2(x) 83 | x = self.layer3(x) 84 | x = self.layer4(x) 85 | 86 | x = self.avgpool(x) 87 | x = self.flatten(x) 88 | x = self.fc(x) 89 | 90 | return x 91 | 92 | def resnet18(**kwargs): 93 | return ResNetM(BasicBlockM, [2, 2, 2, 2], **kwargs) 94 | 95 | def resnet34(**kwargs): 96 | return ResNetM(BasicBlockM, [3, 4, 6, 3], **kwargs) -------------------------------------------------------------------------------- /models/shufflenet_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .extension import * 5 | 6 | class BottleneckM(nn.Module): 7 | 8 | def __init__(self, in_planes, out_planes, stride, groups): 9 | super(BottleneckM, self).__init__() 10 | self.stride = stride 11 | mid_planes = out_planes // 4 12 | g = 1 if in_planes == 24 else groups 13 | 14 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, 15 | bias=False) 16 | self.bn1 = nn.BatchNorm2d(mid_planes) 17 | self.relu1 = nn.ReLU() 18 | self.shuffle = Shuffle(groups=g) 19 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, 20 | stride=stride, padding=1, groups=mid_planes, 21 | bias=False) 22 | self.bn2 = nn.BatchNorm2d(mid_planes) 23 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, 24 | groups=groups, bias=False) 25 | self.bn3 = nn.BatchNorm2d(out_planes) 26 | self.relu3 = nn.ReLU() 27 | 28 | if stride == 2: 29 | self.conv4 = nn.Conv2d(in_planes, in_planes, kernel_size=1, 30 | groups=2, bias=False) 31 | self.avgpool = nn.AvgPool2d(3, stride=2, padding=1) 32 | self.concat = Concat(dim=1) 33 | 34 | def forward(self, x): 35 | out = self.relu1(self.bn1(self.conv1(x))) 36 | out = self.shuffle(out) 37 | out = self.bn2(self.conv2(out)) 38 | out = self.bn3(self.conv3(out)) 39 | 40 | if self.stride == 2: 41 | res = self.avgpool(self.conv4(x)) 42 | out = self.relu3(self.concat([out, res])) 43 | else: 44 | res = x 45 | out = self.relu3(out + res) 46 | return out 47 | 48 | class ShuffleNetM(nn.Module): 49 | 50 | def __init__(self, cfg, num_classes=100): 51 | super(ShuffleNetM, self).__init__() 52 | out_planes = cfg['out_planes'] 53 | num_blocks = cfg['num_blocks'] 54 | groups = cfg['groups'] 55 | 56 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(24) 58 | self.relu1 = nn.ReLU() 59 | self.in_planes = 24 60 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 61 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 62 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 63 | self.avgpool = nn.AvgPool2d(4) 64 | self.flatten = Flatten() 65 | self.fc = nn.Linear(out_planes[2], num_classes) 66 | 67 | def _make_layer(self, out_planes, num_blocks, groups): 68 | layers = [] 69 | for i in range(num_blocks): 70 | stride = 2 if i == 0 else 1 71 | cat_planes = self.in_planes if i == 0 else 0 72 | layers.append(BottleneckM(self.in_planes, out_planes - cat_planes, 73 | stride=stride, groups=groups)) 74 | self.in_planes = out_planes 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | x = self.relu1(self.bn1(self.conv1(x))) 79 | 80 | x = self.layer1(x) 81 | x = self.layer2(x) 82 | x = self.layer3(x) 83 | 84 | x = self.fc(self.flatten(self.avgpool(x))) 85 | return x 86 | 87 | def shufflenet(**kwargs): 88 | cfg = { 89 | 'out_planes': [200, 400, 800], 90 | 'num_blocks': [4, 8, 4], 91 | 'groups': 2 92 | } 93 | return ShuffleNetM(cfg, **kwargs) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESNAC: Embedding Space for Neural Architecture Compression 2 | 3 | This is the PyTorch implementation of our paper: 4 | 5 | Learnable Embedding Space for Efficient Neural Architecture Compression.
Shengcao Cao\*, Xiaofang Wang\*, and Kris M. Kitani. ICLR 2019. \[[OpenReview](https://openreview.net/forum?id=S1xLN3C9YX)\] \[[arXiv](https://arxiv.org/abs/1902.00383)\]. 6 | 7 | ## Requirements 8 | 9 | We recommend you to use this repository with [Anaconda Python 3.7](https://www.anaconda.com/distribution/) and the following libraries: 10 | 11 | - [PyTorch 1.0](https://pytorch.org/) 12 | - [tensorboardX](https://github.com/lanpa/tensorboardX) 13 | - [TensorFlow](https://www.tensorflow.org/) (Optional, only necessary if you would like to use TensorBoard to monitor the running of the job.) 14 | 15 | ## Usage 16 | 17 | - Before running `compression.py`, you need to prepare the pretrained teacher models and put them at the folder `./models/pretrained`. You can choose to train them on your own with `train_model_teacher()` in `training.py`, or download them at: 18 | 19 | - [Google Drive](https://drive.google.com/open?id=1RgeUljIs5WeRuHYjWnWAZf_qkNa3O-IR) 20 | - [百度网盘 (BaiduYun)](https://pan.baidu.com/s/1p0_2YycHoau-wN5xw9xTuA) (Code: 9aru) 21 | 22 | We would like to point out that these provided pretrained teacher models are not trained on the full training set of CIFAR-10 or CIFAR-100. For both CIFAR-10 and CIFAR-100, we sample 5K images from the full training set as the validation set. The provided pretrained teacher models are trained on the remaining training images and are only used during the search process. The teacher accuracy reported in our paper refers to the accuracy of teacher models trained on the full training set of CIFAR-10 or CIFAR-100. 23 | 24 | - Then run the main program: 25 | 26 | ``` 27 | python compression.py [-h] [--network NETWORK] [--dataset DATASET] 28 | [--suffix SUFFIX] [--device DEVICE] 29 | ``` 30 | 31 | For example, run 32 | 33 | ``` 34 | python compression.py --network resnet34 --dataset cifar100 --suffix 0 --device cuda 35 | ``` 36 | 37 | and you will see how the ResNet-34 architecture is compressed on the CIFAR-100 dataset using your GPU. The results will be saved at `./save/resnet34_cifar100_0` and the TensorBoard log will be saved at `./runs/resnet34_cifar100_0`. 38 | 39 | Other hyper-parameters can be adjusted in `options.py`. 40 | 41 | - The whole process includes two stages: searching for desired compressed architectures, and fully train them. `compression.py` will do them both. Optionally, you can use TensorBoard to monitor the process through the log files. 42 | 43 | - After the compression, you can use the script `stat.py` to get the statistics of the compression results. 44 | 45 | ## Random Seed and Reproducibility 46 | 47 | To ensure reproducibility, we provide the compression results on CIFAR-100 with random seed 127. This seed value is randomly picked. You can try other seed values or comment out the call of `seed_everything()` in `compression.py` to obtain different results. Here are the compression results on CIFAR-100 when fixing the seed value to 127: 48 | 49 | | Teacher | Accuracy | #Params | Ratio | Times | f(x) | 50 | | --- | :---: | :---: | :---: | :---: | :---: | 51 | | VGG-19 | 71.64% | 3.07M | 0.8470 | 6.54× | 0.9492 | 52 | | ResNet-18 | 71.91% | 1.26M | 0.8876 | 8.90× | 0.9024 | 53 | | ResNet-34 | 75.47% | 2.85M | 0.8664 | 7.48× | 0.9417 | 54 | | ShuffleNet | 68.17% | 0.18M | 0.8298 | 5.88× | 0.9305 | 55 | 56 | ## Citation 57 | 58 | If you find our work useful in your research, please consider citing our paper [Learnable Embedding Space for Efficient Neural Architecture Compression](https://openreview.net/forum?id=S1xLN3C9YX): 59 | 60 | ``` 61 | @inproceedings{ 62 | cao2018learnable, 63 | title={Learnable Embedding Space for Efficient Neural Architecture Compression}, 64 | author={Shengcao Cao and Xiaofang Wang and Kris M. Kitani}, 65 | booktitle={International Conference on Learning Representations}, 66 | year={2019}, 67 | url={https://openreview.net/forum?id=S1xLN3C9YX}, 68 | } 69 | ``` 70 | 71 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.extension import * 5 | 6 | class Layer(nn.Module): 7 | supported_base = (Identity, Flatten, Concat, Shuffle, 8 | nn.Conv2d, nn.MaxPool2d, nn.AvgPool2d, nn.ReLU, 9 | nn.BatchNorm2d, nn.Linear) 10 | 11 | def __init__(self, base, in_shape=None, out_shape=None): 12 | super(Layer, self).__init__() 13 | self.base = base 14 | self.base_type = base.__class__.__name__ 15 | if not isinstance(base, Layer.supported_base): 16 | raise NotImplementedError('Unknown base layer!') 17 | self.in_shape = torch.Size([-1] + list(in_shape[1:])) 18 | self.out_shape = torch.Size([-1] + list(out_shape[1:])) 19 | self.init_rep() 20 | 21 | def replace(self, base): 22 | if not isinstance(base, Layer.supported_base): 23 | raise NotImplementedError('Unknown base layer!') 24 | self.base = base 25 | self.base_type = base.__class__.__name__ 26 | self.init_rep() 27 | 28 | def shrink(self, Fi, Fo): 29 | in_shape = list(self.in_shape) 30 | in_shape[1] = Fi 31 | self.in_shape = torch.Size(in_shape) 32 | out_shape = list(self.out_shape) 33 | out_shape[1] = Fo 34 | self.out_shape = torch.Size(out_shape) 35 | 36 | b = self.base 37 | if isinstance(b, nn.Conv2d): 38 | groups = b.groups 39 | if (groups == b.in_channels and b.in_channels == b.out_channels and 40 | Fi == Fo): 41 | groups = Fi 42 | conv = nn.Conv2d(Fi, Fo, b.kernel_size, stride=b.stride, 43 | padding=b.padding, dilation=b.dilation, 44 | groups=groups, bias=(b.bias is not None)) 45 | conv.weight = nn.Parameter(b.weight[:Fo, :(Fi // groups)].clone()) 46 | if b.bias is not None: 47 | conv.bias = nn.Parameter(b.bias[:Fo].clone()) 48 | self.replace(conv) 49 | elif isinstance(b, nn.BatchNorm2d): 50 | bn = nn.BatchNorm2d(Fi, eps=b.eps, momentum=b.momentum, 51 | affine=b.affine, 52 | track_running_stats=b.track_running_stats) 53 | bn.weight = nn.Parameter(b.weight[:Fi].clone()) 54 | bn.bias = nn.Parameter(b.bias[:Fi].clone()) 55 | self.replace(bn) 56 | elif isinstance(b, nn.Linear): 57 | ln = nn.Linear(Fi, Fo, bias=(b.bias is not None)) 58 | ln.weight = nn.Parameter(b.weight[:Fo, :Fi].clone()) 59 | if b.bias is not None: 60 | ln.bias = nn.Parameter(b.bias[:Fo].clone()) 61 | self.replace(ln) 62 | else: 63 | self.init_rep() 64 | 65 | def forward(self, x): 66 | return self.base(x) 67 | 68 | def init_param(self): 69 | b = self.base 70 | if isinstance(b, nn.Conv2d): 71 | nn.init.kaiming_normal_(b.weight, mode='fan_out', 72 | nonlinearity='relu') 73 | if b.bias is not None: 74 | nn.init.constant_(b.bias, 0) 75 | elif isinstance(b, nn.BatchNorm2d): 76 | nn.init.constant_(b.weight, 1) 77 | nn.init.constant_(b.bias, 0) 78 | elif isinstance(b, nn.Linear): 79 | nn.init.normal_(b.weight, 0, 0.01) 80 | nn.init.constant_(b.bias, 0) 81 | 82 | def init_rep(self): 83 | b = self.base 84 | lt = Layer.supported_base.index(type(b)) 85 | lr = [0] * 10 86 | lr[lt] = 1 87 | k = getattr(b, 'kernel_size', 0) 88 | k = k[0] if type(k) is tuple else k 89 | s = getattr(b, 'stride', 0) 90 | s = s[0] if type(s) is tuple else s 91 | p = getattr(b, 'padding', 0) 92 | p = p[0] if type(p) is tuple else p 93 | g = getattr(b, 'groups', 0) 94 | i = 0 95 | o = 0 96 | if isinstance(b, (nn.Conv2d, nn.Linear)): 97 | i = list(self.in_shape)[1] 98 | o = list(self.out_shape)[1] 99 | self.rep = lr + [k, s, p, g, i, o] 100 | 101 | def param_n(self): 102 | return sum([len(w.view(-1)) for w in self.base.parameters()]) 103 | 104 | class LayerGroup(object): 105 | 106 | def __init__(self, F, in_layers, out_layers): 107 | self.F = F 108 | self.in_layers = set(in_layers) 109 | self.out_layers = set(out_layers) 110 | self.union = self.in_layers | self.out_layers 111 | self.inter = self.in_layers & self.out_layers 112 | self.in_only = self.union - self.out_layers 113 | self.out_only = self.union - self.in_layers -------------------------------------------------------------------------------- /kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | import options as opt 7 | from math import sqrt 8 | from scipy.stats import norm 9 | import os 10 | 11 | class Kernel(object): 12 | 13 | def __init__(self, x0, y0, alpha=opt.ke_alpha, beta=opt.ke_beta, 14 | input_size=opt.ke_input_size, hidden_size=opt.ke_hidden_size, 15 | num_layers=opt.ke_num_layers, bidirectional=opt.ke_bidirectional, 16 | lr=opt.ke_lr, weight_decay=opt.ke_weight_decay): 17 | 18 | super(Kernel, self).__init__() 19 | self.alpha = alpha 20 | self.beta = beta 21 | self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, 22 | num_layers=num_layers, bidirectional=bidirectional) 23 | self.lstm = self.lstm.to(opt.device) 24 | self.input_size = input_size 25 | self.hidden_size = hidden_size 26 | self.num_layers = num_layers 27 | self.bidirectional = bidirectional 28 | self.bi = 2 if bidirectional else 1 29 | 30 | self.x = [x0] 31 | self.y = torch.tensor([y0], dtype=torch.float, device=opt.device, 32 | requires_grad=False) 33 | self.x_best = x0 34 | self.y_best = y0 35 | self.i_best = 0 36 | 37 | self.n = 1 38 | self.E = self.embedding(x0).view(1, -1) 39 | self.K = self.kernel(self.E[0], self.E[0]).view(1, 1) 40 | self.K_inv = torch.inverse(self.K + self.beta * 41 | torch.eye(self.n, device=opt.device)) 42 | self.optimizer = optim.Adam(self.lstm.parameters(), lr=lr, 43 | weight_decay=weight_decay) 44 | 45 | def embedding(self, xi): 46 | inputs = xi.view(-1, 1, self.input_size) 47 | outputs, (hn, cn) = self.lstm(inputs) 48 | outputs = torch.mean(outputs.squeeze(1), dim=0) 49 | outputs = outputs / torch.norm(outputs) 50 | return outputs 51 | 52 | def kernel(self, ei, ej): 53 | d = ei - ej 54 | d = torch.sum(d * d) 55 | k = torch.exp(-d / (2 * self.alpha)) 56 | return k 57 | 58 | def kernel_batch(self, en): 59 | n = self.n 60 | k = torch.zeros(n, device=opt.device) 61 | for i in range(n): 62 | k[i] = self.kernel(self.E[i], en) 63 | return k 64 | 65 | def predict(self, xn): 66 | n = self.n 67 | en = self.embedding(xn) 68 | k = self.kernel_batch(en) 69 | kn = self.kernel(en, en) 70 | t = torch.mm(k.view(1, n), self.K_inv) 71 | mu = torch.mm(t, self.y.view(n, 1)) 72 | sigma = kn - torch.mm(t, k.view(n, 1)) 73 | sigma = torch.sqrt(sigma + self.beta) 74 | return mu, sigma 75 | 76 | def acquisition(self, xn): 77 | with torch.no_grad(): 78 | mu, sigma = self.predict(xn) 79 | mu = mu.item() 80 | sigma = sigma.item() 81 | y_best = self.y_best 82 | z = (mu - y_best) / sigma 83 | ei = (mu - y_best) * norm.cdf(z) + sigma * norm.pdf(z) 84 | return ei 85 | 86 | def kernel_batch_ex(self, t): 87 | n = self.n 88 | k = torch.zeros(n - 1, device=opt.device) 89 | for i in range(t): 90 | k[i] = self.kernel(self.E[i], self.E[t]) 91 | for i in range(t + 1, n): 92 | k[i - 1] = self.kernel(self.E[t], self.E[i]) 93 | return k 94 | 95 | def predict_ex(self, t): 96 | n = self.n 97 | k = self.kernel_batch_ex(t) 98 | kt = self.kernel(self.E[t], self.E[t]) 99 | indices = list(range(t)) + list(range(t + 1, n)) 100 | indices = torch.tensor(indices, dtype=torch.long, device=opt.device) 101 | K = self.K 102 | K = torch.index_select(K, 0, indices) 103 | K = torch.index_select(K, 1, indices) 104 | K_inv = torch.inverse(K + self.beta * 105 | torch.eye(n - 1, device=opt.device)) 106 | y = torch.index_select(self.y, 0, indices) 107 | 108 | t = torch.mm(k.view(1, n - 1), K_inv) 109 | mu = torch.mm(t, y.view(n - 1, 1)) 110 | sigma = kt - torch.mm(t, k.view(n - 1, 1)) 111 | sigma = torch.sqrt(sigma + self.beta) 112 | return mu, sigma 113 | 114 | def add_sample(self, xn, yn): 115 | self.x.append(xn) 116 | self.y = torch.cat((self.y, torch.tensor([yn], dtype=torch.float, 117 | device=opt.device, 118 | requires_grad=False))) 119 | n = self.n 120 | if yn > self.y_best: 121 | self.x_best = xn 122 | self.y_best = yn 123 | self.i_best = n 124 | en = self.embedding(xn) 125 | k = self.kernel_batch(en) 126 | kn = self.kernel(en, en) 127 | self.E = torch.cat((self.E, en.view(1, -1)), 0) 128 | self.K = torch.cat((torch.cat((self.K, k.view(n, 1)), 1), 129 | torch.cat((k.view(1, n), kn.view(1, 1)), 1)), 0) 130 | self.n += 1 131 | self.K_inv = torch.inverse(self.K + self.beta * 132 | torch.eye(self.n, device=opt.device)) 133 | 134 | def add_batch(self, x, y): 135 | self.x.extend(x) 136 | self.y = torch.cat((self.y, y)) 137 | m = len(x) 138 | for i in range(m): 139 | n = self.n 140 | if y[i].item() > self.y_best: 141 | self.x_best = x[i] 142 | self.y_best = y[i].item() 143 | self.i_best = n 144 | en = self.embedding(x[i]) 145 | k = self.kernel_batch(en) 146 | kn = self.kernel(en, en) 147 | self.E = torch.cat((self.E, en.view(1, -1)), 0) 148 | self.K = torch.cat((torch.cat((self.K, k.view(n, 1)), 1), 149 | torch.cat((k.view(1, n), kn.view(1, 1)), 1)), 0) 150 | self.n += 1 151 | self.K_inv = torch.inverse(self.K + self.beta * 152 | torch.eye(self.n, device=opt.device)) 153 | 154 | def update_EK(self): 155 | n = self.n 156 | E_ = torch.zeros((n, self.E.size(1)), device=opt.device) 157 | for i in range(n): 158 | E_[i] = self.embedding(self.x[i]) 159 | self.E = E_ 160 | K_ = torch.zeros((n, n), device=opt.device) 161 | for i in range(n): 162 | for j in range(i, n): 163 | k = self.kernel(self.E[i], self.E[j]) 164 | K_[i, j] = k 165 | K_[j, i] = k 166 | self.K = K_ 167 | self.K_inv = torch.inverse(self.K + self.beta * 168 | torch.eye(self.n, device=opt.device)) 169 | 170 | def loss(self): 171 | n = self.n 172 | l = torch.zeros(n, device=opt.device) 173 | for i in range(n): 174 | mu, sigma = self.predict_ex(i) 175 | d = self.y[i] - mu 176 | l[i] = -(0.918939 + torch.log(sigma) + d * d / (2 * sigma * sigma)) 177 | l = -torch.mean(l) 178 | return l 179 | 180 | def opt_step(self): 181 | if self.n < 2: 182 | return 0.0 183 | self.optimizer.zero_grad() 184 | l = self.loss() 185 | ll = -l.item() 186 | l.backward() 187 | self.optimizer.step() 188 | self.update_EK() 189 | return ll 190 | 191 | def save(self, save_path): 192 | path = os.path.dirname(save_path) 193 | if not os.path.exists(path): 194 | os.makedirs(path) 195 | torch.save(self, save_path) -------------------------------------------------------------------------------- /compression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | import models 4 | from architecture import Architecture 5 | from kernel import Kernel 6 | from record import Record 7 | import acquisition as ac 8 | import graph as gr 9 | import options as opt 10 | import training as tr 11 | import numpy as np 12 | import argparse 13 | from operator import attrgetter 14 | import os 15 | import random 16 | import time 17 | from tensorboardX import SummaryWriter 18 | 19 | def seed_everything(seed=127): 20 | random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | def new_kernels(teacher, record, kernel_n, alpha=opt.co_alpha, 28 | beta=opt.co_beta, gamma=opt.co_gamma): 29 | start_time = time.time() 30 | kernels = [] 31 | for i in range(kernel_n): 32 | kernel = Kernel(teacher.rep, 0.0) 33 | indices = [] 34 | for j in range(record.n): 35 | if random.random() < gamma: 36 | indices.append(j) 37 | if len(indices) > 0: 38 | x = [record.x[i] for i in indices] 39 | indices = torch.tensor(indices, dtype=torch.long, device=opt.device) 40 | y = torch.index_select(record.y, 0, indices) 41 | kernel.add_batch(x, y) 42 | ma = 0.0 43 | for j in range(100): 44 | ll = kernel.opt_step() 45 | opt.writer.add_scalar('step_%d/kernel_%d_loglikelihood' % (opt.i, i), 46 | ll, j) 47 | ma = (alpha * ll + (1 - alpha) * ma) if j > 0 else ll 48 | if j > 5 and abs(ma - ll) < beta: 49 | break 50 | kernels.append(kernel) 51 | opt.writer.add_scalar('compression/kernel_time', 52 | time.time() - start_time, opt.i) 53 | return kernels 54 | 55 | def next_samples(teacher, kernels, kernel_n): 56 | start_time = time.time() 57 | n = kernel_n 58 | reps_best, acqs_best, archs_best = [], [], [] 59 | 60 | if opt.co_graph_gen == 'get_graph_shufflenet': 61 | for i in range(n): 62 | arch, rep, acq = ac.random_search_sfn(teacher, kernels[i]) 63 | archs_best.append(arch) 64 | reps_best.append(rep) 65 | acqs_best.append(acq) 66 | opt.writer.add_scalar('compression/acq', acq, opt.i * n + i - n + 1) 67 | opt.writer.add_scalar('compression/sampling_time', 68 | time.time() - start_time, opt.i) 69 | return archs_best, reps_best 70 | 71 | else: 72 | for i in range(n): 73 | action, rep, acq = ac.random_search(teacher, kernels[i]) 74 | reps_best.append(rep) 75 | acqs_best.append(acq) 76 | archs_best.append(teacher.comp_arch(action)) 77 | opt.writer.add_scalar('compression/acq', acq, opt.i * n + i - n + 1) 78 | opt.writer.add_scalar('compression/sampling_time', 79 | time.time() - start_time, opt.i) 80 | return archs_best, reps_best 81 | 82 | def reward(teacher, teacher_acc, students, dataset): 83 | start_time = time.time() 84 | n = len(students) 85 | students_best, students_acc = tr.train_model_search(teacher, students, dataset) 86 | rs = [] 87 | for j in range(n): 88 | c = 1.0 - 1.0 * students_best[j].param_n() / teacher.param_n() 89 | a = 1.0 * students_acc[j] / teacher_acc 90 | r = c * (2 - c) * a 91 | opt.writer.add_scalar('compression/compression_score', c, 92 | opt.i * n - n + 1 + j) 93 | opt.writer.add_scalar('compression/accuracy_score', a, 94 | opt.i * n - n + 1 + j) 95 | opt.writer.add_scalar('compression/reward', r, 96 | opt.i * n - n + 1 + j) 97 | rs.append(r) 98 | students_best[j].comp = c 99 | students_best[j].acc = students_acc[j] 100 | students_best[j].reward = r 101 | opt.writer.add_scalar('compression/evaluating_time', 102 | time.time() - start_time, opt.i) 103 | return students_best, rs 104 | 105 | def compression(teacher, dataset, record, step_n=opt.co_step_n, 106 | kernel_n=opt.co_kernel_n, best_n=opt.co_best_n): 107 | 108 | teacher_acc = tr.test_model(teacher, dataset) 109 | archs_best = [] 110 | for i in range(1, step_n + 1): 111 | print ('Search step %d/%d' %(i, step_n)) 112 | start_time = time.time() 113 | opt.i = i 114 | kernels = new_kernels(teacher, record, kernel_n) 115 | students_best, xi = next_samples(teacher, kernels, kernel_n) 116 | students_best, yi = reward(teacher, teacher_acc, students_best, dataset) 117 | for j in range(kernel_n): 118 | record.add_sample(xi[j], yi[j]) 119 | if yi[j] == record.reward_best: 120 | opt.writer.add_scalar('compression/reward_best', yi[j], i) 121 | students_best = [student.to('cpu') for student in students_best] 122 | archs_best.extend(students_best) 123 | archs_best.sort(key=attrgetter('reward'), reverse=True) 124 | archs_best = archs_best[:best_n] 125 | for j, arch in enumerate(archs_best): 126 | arch.save('%s/arch_%d.pth' % (opt.savedir, j)) 127 | record.save(opt.savedir + '/record.pth') 128 | opt.writer.add_scalar('compression/step_time', 129 | time.time() - start_time, i) 130 | 131 | def fully_train(dataset, best_n=opt.co_best_n): 132 | dataset = getattr(datasets, dataset)() 133 | for i in range(best_n): 134 | print ('Fully train student architecture %d/%d' %(i+1, best_n)) 135 | model = torch.load('%s/arch_%d.pth' % (opt.savedir, i)) 136 | tr.train_model_student(model, dataset, 137 | '%s/fully_%d.pth' % (opt.savedir, i), i) 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser(description='Learnable Embedding Space for Efficient Neural Architecture Compression') 141 | 142 | parser.add_argument('--network', type=str, default='resnet34', 143 | help='resnet18/resnet34/vgg19/shufflenet') 144 | parser.add_argument('--dataset', type=str, default='cifar100', 145 | help='cifar10/cifar100') 146 | parser.add_argument('--suffix', type=str, default='0', help='0/1/2/3...') 147 | parser.add_argument('--device', type=str, default='cuda', help='cpu/cuda') 148 | 149 | args = parser.parse_args() 150 | 151 | seed_everything() 152 | 153 | assert args.network in ['resnet18', 'resnet34', 'vgg19', 'shufflenet'] 154 | assert args.dataset in ['cifar10', 'cifar100'] 155 | 156 | if args.network in ['resnet18', 'resnet34']: 157 | opt.co_graph_gen = 'get_graph_resnet' 158 | elif args.network == 'vgg19': 159 | opt.co_graph_gen = 'get_graph_vgg' 160 | elif args.network == 'shufflenet': 161 | opt.co_graph_gen = 'get_graph_shufflenet' 162 | 163 | if args.dataset == 'cifar10': 164 | opt.dataset = 'CIFAR10Val' 165 | elif args.dataset == 'cifar100': 166 | opt.dataset = 'CIFAR100Val' 167 | 168 | opt.device = args.device 169 | 170 | opt.model = './models/pretrained/%s_%s.pth' % (args.network, args.dataset) 171 | opt.savedir = './save/%s_%s_%s' % (args.network, args.dataset, args.suffix) 172 | opt.writer = SummaryWriter('./runs/%s_%s_%s' % (args.network, args.dataset, 173 | args.suffix)) 174 | assert not(os.path.exists(opt.savedir)), 'Overwriting existing files!' 175 | 176 | print ('Start compression. Please check the TensorBoard log in the folder ./runs/%s_%s_%s.'% 177 | (args.network, args.dataset, args.suffix)) 178 | 179 | model = torch.load(opt.model).to(opt.device) 180 | teacher = Architecture(*(getattr(gr, opt.co_graph_gen)(model))) 181 | dataset = getattr(datasets, opt.dataset)() 182 | record = Record() 183 | compression(teacher, dataset, record) 184 | fully_train(dataset=opt.dataset[:-3]) 185 | -------------------------------------------------------------------------------- /architecture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import models 6 | import graph as gr 7 | import layer 8 | import options as opt 9 | import copy 10 | import os 11 | import random 12 | from math import ceil 13 | 14 | class Architecture(nn.Module): 15 | 16 | def __init__(self, n, V, E): 17 | super(Architecture, self).__init__() 18 | self.n = n 19 | self.V = V 20 | for i in range(n): 21 | self.add_module('layer_%d' % (i), V[i]) 22 | self.groups = gr.get_groups(V) 23 | self.E = E 24 | self.in_links, self.out_links = gr.get_links(E) 25 | self.init_rep() 26 | 27 | def get_layer(self, i): 28 | return getattr(self, 'layer_%d' % (i)) 29 | 30 | def forward(self, x): 31 | y = [None] * self.n 32 | y[0] = self.get_layer(0)(x) 33 | for j in range(1, self.n): 34 | x = [] 35 | for i in self.in_links[j]: 36 | x.append(y[i]) 37 | if j == self.out_links[i][-1]: 38 | y[i] = None 39 | if not x: 40 | y[j] = None 41 | else: 42 | layer = self.get_layer(j) 43 | if isinstance(layer.base, models.Concat): 44 | y[j] = layer(x) 45 | else: 46 | x = sum(x) 47 | y[j] = layer(x) 48 | return y[-1] 49 | 50 | def init_param(self): 51 | V = self.V 52 | for i in range(self.n): 53 | V[i].init_param() 54 | 55 | def param_n(self): 56 | V = self.V 57 | cnt = 0 58 | for i in range(self.n): 59 | cnt += V[i].param_n() 60 | return cnt 61 | 62 | def init_rep(self): 63 | n = self.n 64 | V = self.V 65 | base_mat = [(V[i].rep) for i in range(n)] 66 | in_mat = [([0] * opt.ar_max_layers) for i in range(n)] 67 | out_mat = [([0] * opt.ar_max_layers) for i in range(n)] 68 | for i in range(n): 69 | for j in self.in_links[i]: 70 | in_mat[i][i - j] = 1 71 | for j in self.out_links[i]: 72 | out_mat[i][j - i] = 1 73 | self.rep = [(base_mat[i] + in_mat[i] + out_mat[i]) for i in range(n)] 74 | self.rep = torch.tensor(self.rep, dtype=torch.float, device=opt.device) 75 | 76 | def comp_action_rand(self): 77 | n = self.n 78 | V = self.V 79 | p1 = random.choice(opt.ar_p1) 80 | action = [] 81 | for i in range(n): 82 | if random.random() < p1 and V[i].in_shape == V[i].out_shape: 83 | action.append(1.0) 84 | else: 85 | action.append(0.0) 86 | for i in range(len(self.groups)): 87 | action.append(random.uniform(*opt.ar_p2)) 88 | p3 = random.choice(opt.ar_p3) 89 | for i in range(n): 90 | for j in range(i + 1, n): 91 | if V[i].out_shape == V[j].in_shape: 92 | if random.random() < p3 and not action[j]: 93 | action.append(1.0) 94 | else: 95 | action.append(0.0) 96 | return np.array(action) 97 | 98 | def comp_rep(self, action): 99 | n = self.n 100 | V = self.V 101 | p = 0 102 | base_mat = [(V[i].rep.copy()) for i in range(n)] 103 | in_mat = [([0] * opt.ar_max_layers) for i in range(n)] 104 | out_mat = [([0] * opt.ar_max_layers) for i in range(n)] 105 | for i in range(n): 106 | if action[p]: 107 | base_mat[i][:16] = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 108 | p += 1 109 | for g in self.groups: 110 | F = max(1, int((1.0 - action[p]) * g.F)) 111 | for j in g.in_layers: 112 | if (isinstance(V[j].base, (nn.Conv2d, nn.Linear)) and 113 | not action[j]): 114 | base_mat[j][14] = F 115 | for j in g.out_layers: 116 | if (isinstance(V[j].base, (nn.Conv2d, nn.Linear)) and 117 | not action[j]): 118 | base_mat[j][15] = F 119 | p += 1 120 | for i in range(n): 121 | for j in range(i + 1, n): 122 | if V[i].out_shape == V[j].in_shape: 123 | if self.E[i][j] or action[p]: 124 | in_mat[j][j - i] = 1 125 | out_mat[i][j - i] = 1 126 | p += 1 127 | rep = [(base_mat[i] + in_mat[i] + out_mat[i]) for i in range(n)] 128 | rep = torch.tensor(rep, dtype=torch.float, device=opt.device) 129 | return rep 130 | 131 | def comp_arch(self, action): 132 | arch = copy.deepcopy(self) 133 | arch.action = action 134 | n = arch.n 135 | V = arch.V 136 | p = 0 137 | for i in range(n): 138 | if action[p]: 139 | V[i].replace(layer.Identity()) 140 | p += 1 141 | in_shapes = [V[i].in_shape for i in range(n)] 142 | out_shapes = [V[i].out_shape for i in range(n)] 143 | for g in self.groups: 144 | F = max(1, int((1.0 - action[p]) * g.F)) 145 | for j in g.inter: 146 | V[j].shrink(F, F) 147 | for j in g.in_only: 148 | Fo = list(V[j].out_shape)[1] 149 | V[j].shrink(F, Fo) 150 | for j in g.out_only: 151 | Fi = list(V[j].in_shape)[1] 152 | V[j].shrink(Fi, F) 153 | p += 1 154 | for i in range(n): 155 | for j in range(i + 1, n): 156 | if out_shapes[i] == in_shapes[j]: 157 | if action[p]: 158 | arch.E[i][j] = True 159 | p += 1 160 | arch.in_links, arch.out_links = gr.get_links(arch.E) 161 | arch.init_rep() 162 | arch.to(opt.device) 163 | return arch 164 | 165 | def comp_arch_rand_sfn(self): 166 | 167 | def shrink_n(F, ratio): 168 | m = opt.ar_channel_mul 169 | return max(1, int(ceil((1.0 - ratio) * F / m))) * m 170 | 171 | arch = copy.deepcopy(self) 172 | n = arch.n 173 | V = arch.V 174 | 175 | p1 = random.choice(opt.ar_p1) 176 | for i in range(n): 177 | if (random.random() < p1 and V[i].in_shape == V[i].out_shape and 178 | i not in [11, 50, 125]): 179 | V[i].replace(models.Identity()) 180 | 181 | opt.ar_p2[1] = min(0.9, opt.ar_p2[1]) 182 | for g in self.groups: 183 | p2 = random.uniform(*opt.ar_p2) 184 | for j in g.inter: 185 | Fi = shrink_n(list(V[j].in_shape)[1], p2) 186 | Fo = shrink_n(list(V[j].out_shape)[1], p2) 187 | V[j].shrink(Fi, Fo) 188 | for j in g.in_only: 189 | Fi = shrink_n(list(V[j].in_shape)[1], p2) 190 | Fo = list(V[j].out_shape)[1] 191 | V[j].shrink(Fi, Fo) 192 | for j in g.out_only: 193 | Fi = list(V[j].in_shape)[1] 194 | Fo = shrink_n(list(V[j].out_shape)[1], p2) 195 | V[j].shrink(Fi, Fo) 196 | 197 | F1 = list(V[2].out_shape)[1] 198 | F4 = list(V[13].out_shape)[1] 199 | F3 = list(V[13].in_shape)[1] 200 | F2 = F4 - F3 201 | V[11].shrink(F1, F2) 202 | V[12].shrink(F2, F2) 203 | 204 | F1 = list(V[41].out_shape)[1] 205 | F4 = list(V[52].out_shape)[1] 206 | F3 = list(V[52].in_shape)[1] 207 | F2 = F4 - F3 208 | V[50].shrink(F1, F2) 209 | V[51].shrink(F2, F2) 210 | 211 | F1 = list(V[116].out_shape)[1] 212 | F4 = list(V[127].out_shape)[1] 213 | F3 = list(V[127].in_shape)[1] 214 | F2 = F4 - F3 215 | V[125].shrink(F1, F2) 216 | V[126].shrink(F2, F2) 217 | 218 | p3 = random.choice(opt.ar_p3) 219 | for i in range(n): 220 | for j in range(i + 1, n): 221 | if (random.random() < p3 and V[i].out_shape == V[j].in_shape and 222 | not isinstance(V[j].base, (models.Concat, models.Identity))): 223 | arch.E[i][j] = True 224 | 225 | arch.in_links, arch.out_links = gr.get_links(arch.E) 226 | arch.init_rep() 227 | arch.to(opt.device) 228 | return arch 229 | 230 | def save(self, save_path): 231 | path = os.path.dirname(save_path) 232 | if not os.path.exists(path): 233 | os.makedirs(path) 234 | torch.save(self, save_path) -------------------------------------------------------------------------------- /graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models 3 | from layer import Layer, LayerGroup 4 | import options as opt 5 | 6 | def get_graph_vgg(vgg): 7 | D = dict() 8 | n = 0 9 | V = [] 10 | E = [[]] 11 | 12 | def record_hook(module, input, output): 13 | key = id(module) 14 | if key not in D: 15 | D[key] = len(V) 16 | V.append(Layer(module, input[0].shape, output.shape)) 17 | 18 | hooks = [] 19 | for module in vgg.modules(): 20 | if isinstance(module, Layer.supported_base): 21 | hooks.append(module.register_forward_hook(record_hook)) 22 | input = torch.rand(1, 3, 32, 32, device=opt.device) 23 | output = vgg(input) 24 | for hook in hooks: 25 | hook.remove() 26 | 27 | n = len(V) 28 | E = [([False] * n) for i in range(n)] 29 | for i in range(n - 1): 30 | E[i][i + 1] = True 31 | return n, V, E 32 | 33 | def get_graph_resnet(resnet): 34 | D = dict() 35 | n = 0 36 | V = [] 37 | E = [[]] 38 | 39 | def record_hook(module, input, output): 40 | key = id(module) 41 | if key not in D: 42 | D[key] = len(V) 43 | V.append(Layer(module, input[0].shape, output.shape)) 44 | 45 | def add_edge(src, dst): 46 | i = D[id(src)] 47 | j = D[id(dst)] 48 | E[i][j] = True 49 | 50 | def add_chain(ls): 51 | for i in range(len(ls) - 1): 52 | add_edge(ls[i], ls[i + 1]) 53 | 54 | hooks = [] 55 | for module in resnet.modules(): 56 | if isinstance(module, Layer.supported_base): 57 | hooks.append(module.register_forward_hook(record_hook)) 58 | input = torch.rand(1, 3, 32, 32, device=opt.device) 59 | output = resnet(input) 60 | for hook in hooks: 61 | hook.remove() 62 | 63 | n = len(V) 64 | E = [([False] * n) for i in range(n)] 65 | 66 | chain = [resnet.conv1, resnet.bn1, resnet.relu1] 67 | add_chain(chain) 68 | 69 | src = [resnet.relu1] 70 | for module in resnet.modules(): 71 | if isinstance(module, models.BasicBlockM): 72 | chain = [module.conv1, module.bn1, module.relu1, 73 | module.conv2, module.bn2, module.relu2] 74 | add_chain(chain) 75 | dst = [module.conv1] 76 | src_ = [module.relu2] 77 | 78 | if module.downsample is not None: 79 | chain = list(module.downsample.children()) 80 | add_chain(chain) 81 | dst.append(chain[0]) 82 | add_edge(chain[-1], module.relu2) 83 | else: 84 | dst.append(module.relu2) 85 | 86 | for s in src: 87 | for d in dst: 88 | add_edge(s, d) 89 | src = src_ 90 | dst = [] 91 | 92 | chain = [resnet.avgpool, resnet.flatten, resnet.fc] 93 | for s in src: 94 | add_edge(s, chain[0]) 95 | add_chain(chain) 96 | 97 | return n, V, E 98 | 99 | def get_graph_shufflenet(shufflenet): 100 | D = dict() 101 | n = 0 102 | V = [] 103 | E = [[]] 104 | 105 | def record_hook(module, input, output): 106 | key = id(module) 107 | if key not in D: 108 | D[key] = len(V) 109 | in_shape = input[0][0].shape if isinstance(input[0], list) else input[0].shape 110 | out_shape = output.shape 111 | V.append(Layer(module, in_shape, out_shape)) 112 | 113 | def add_edge(src, dst): 114 | i = D[id(src)] 115 | j = D[id(dst)] 116 | E[i][j] = True 117 | 118 | def add_chain(ls): 119 | for i in range(len(ls) - 1): 120 | add_edge(ls[i], ls[i + 1]) 121 | 122 | hooks = [] 123 | for module in shufflenet.modules(): 124 | if isinstance(module, Layer.supported_base): 125 | hooks.append(module.register_forward_hook(record_hook)) 126 | input = torch.rand(1, 3, 32, 32, device=opt.device) 127 | output = shufflenet(input) 128 | for hook in hooks: 129 | hook.remove() 130 | 131 | n = len(V) 132 | E = [([False] * n) for i in range(n)] 133 | 134 | chain = [shufflenet.conv1, shufflenet.bn1, shufflenet.relu1] 135 | add_chain(chain) 136 | 137 | src = [shufflenet.relu1] 138 | for module in shufflenet.modules(): 139 | if isinstance(module, models.BottleneckM): 140 | chain = [module.conv1, module.bn1, module.relu1, module.shuffle, 141 | module.conv2, module.bn2, 142 | module.conv3, module.bn3] 143 | add_chain(chain) 144 | dst = [module.conv1] 145 | src_ = [module.relu3] 146 | 147 | if module.stride == 2: 148 | dst.append(module.conv4) 149 | add_edge(module.conv4, module.avgpool) 150 | add_edge(module.avgpool, module.concat) 151 | add_edge(module.bn3, module.concat) 152 | add_edge(module.concat, module.relu3) 153 | else: 154 | add_edge(module.bn3, module.relu3) 155 | dst.append(module.relu3) 156 | 157 | for s in src: 158 | for d in dst: 159 | add_edge(s, d) 160 | src = src_ 161 | dst = [] 162 | 163 | chain = [shufflenet.avgpool, shufflenet.flatten, shufflenet.fc] 164 | for s in src: 165 | add_edge(s, chain[0]) 166 | add_chain(chain) 167 | 168 | return n, V, E 169 | 170 | def get_groups(V): 171 | if opt.co_graph_gen == 'get_graph_shufflenet': 172 | groups = [] 173 | in_layers = list(range(1, 4)) 174 | out_layers = list(range(0, 3)) 175 | groups.append(LayerGroup(-1, in_layers, out_layers)) 176 | in_layers = list(range(4, 11)) + list(range(13, 43)) 177 | out_layers = list(range(3, 11)) + list(range(13, 42)) 178 | groups.append(LayerGroup(-1, in_layers, out_layers)) 179 | in_layers = list(range(43, 50)) + list(range(52, 118)) 180 | out_layers = list(range(42, 50)) + list(range(52, 117)) 181 | groups.append(LayerGroup(-1, in_layers, out_layers)) 182 | in_layers = list(range(118, 125)) + list(range(127, 159)) 183 | out_layers = list(range(117, 125)) + list(range(127, 158)) 184 | groups.append(LayerGroup(-1, in_layers, out_layers)) 185 | return groups 186 | 187 | else: 188 | n = len(V) 189 | vis = [([False] * 2) for i in range(n)] 190 | vis[0][0] = True 191 | vis[-1][1] = True 192 | groups = [] 193 | for i in range(n): 194 | for j in range(2): 195 | if not vis[i][j]: 196 | F = V[i].out_shape[1] if j else V[i].in_shape[1] 197 | in_layers = [] 198 | out_layers = [] 199 | for k in range(n): 200 | if not vis[k][0] and V[k].in_shape[1] == F: 201 | in_layers.append(k) 202 | vis[k][0] = True 203 | if not vis[k][1] and V[k].out_shape[1] == F: 204 | out_layers.append(k) 205 | vis[k][1] = True 206 | groups.append(LayerGroup(F, in_layers, out_layers)) 207 | return groups 208 | 209 | def get_links(E): 210 | n = len(E) 211 | in_links = [[] for i in range(n)] 212 | out_links = [[] for i in range(n)] 213 | for i in range(n): 214 | for j in range(n): 215 | if E[i][j]: 216 | in_links[j].append(i) 217 | out_links[i].append(j) 218 | return in_links, out_links 219 | 220 | def get_plot(name, n, V, E, reduced=False): 221 | from graphviz import Digraph 222 | dot = Digraph(name=name) 223 | for i, v in enumerate(V): 224 | node_name = '%d %s %s->%s' % (i, v.base_type, 225 | str(list(v.in_shape)[1:]), str(list(v.out_shape)[1:])) 226 | colors = ['gray', 'gray', 'gray', 'gray', 'red', 'yellow', 'yellow', 'green', 'cyan', 'blue'] 227 | if v.base_type != 'Identity' or not reduced: 228 | color = colors[Layer.supported_base.index(type(v.base))] 229 | dot.node(str(i), node_name, shape='box', color=color) 230 | if reduced: 231 | for i in range(n): 232 | if V[i].base_type == 'Identity': 233 | in_links = [] 234 | out_links = [] 235 | for j in range(n): 236 | if E[j][i]: 237 | in_links.append(j) 238 | E[j][i] = False 239 | if E[i][j]: 240 | out_links.append(j) 241 | E[i][j] = False 242 | for u in in_links: 243 | for v in out_links: 244 | E[u][v] = True 245 | for i in range(n): 246 | for j in range(n): 247 | if E[i][j]: 248 | dot.edge(str(i), str(j)) 249 | dot.view() -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import torch.backends.cudnn as cudnn 6 | import options as opt 7 | import os 8 | import time 9 | 10 | def init_model(model): 11 | for module in model.modules(): 12 | if isinstance(module, nn.Conv2d): 13 | nn.init.kaiming_normal_(module.weight, mode='fan_out', 14 | nonlinearity='relu') 15 | if module.bias is not None: 16 | nn.init.constant_(module.bias, 0) 17 | elif isinstance(module, nn.BatchNorm2d): 18 | nn.init.constant_(module.weight, 1) 19 | nn.init.constant_(module.bias, 0) 20 | elif isinstance(module, nn.Linear): 21 | nn.init.normal_(module.weight, 0, 0.01) 22 | nn.init.constant_(module.bias, 0) 23 | return model 24 | 25 | def test_model(model, dataset): 26 | model.eval() 27 | correct = 0 28 | total = 0 29 | loader = None 30 | if hasattr(dataset, 'test_loader'): 31 | loader = dataset.test_loader 32 | elif hasattr(dataset, 'val_loader'): 33 | loader = dataset.val_loader 34 | else: 35 | raise NotImplementedError('Unknown dataset!') 36 | with torch.no_grad(): 37 | for batch_idx, (inputs, targets) in enumerate(loader): 38 | inputs = inputs.to(opt.device) 39 | targets = targets.to(opt.device) 40 | outputs = model(inputs) 41 | _, predicted = outputs.max(1) 42 | total += targets.size(0) 43 | correct += predicted.eq(targets).sum().item() 44 | acc = 100.0 * correct / total 45 | return acc 46 | 47 | def train_model_teacher(model_, dataset, save_path, epochs=400, lr=0.1, 48 | momentum=0.9, weight_decay=5e-4): 49 | acc_best = 0 50 | model_best = None 51 | model = torch.nn.DataParallel(model_.to(opt.device)) 52 | criterion = nn.CrossEntropyLoss() 53 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, 54 | weight_decay=weight_decay) 55 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) 56 | 57 | for i in range(1, epochs + 1): 58 | model.train() 59 | scheduler.step() 60 | loss_total = 0 61 | batch_cnt = 0 62 | for batch_idx, (inputs, targets) in enumerate(dataset.train_loader): 63 | inputs = inputs.to(opt.device) 64 | targets = targets.to(opt.device) 65 | optimizer.zero_grad() 66 | outputs = model(inputs) 67 | loss = criterion(outputs, targets) 68 | loss.backward() 69 | optimizer.step() 70 | loss_total += loss.item() 71 | batch_cnt += 1 72 | opt.writer.add_scalar('training/loss', loss_total / batch_cnt, i) 73 | acc = test_model(model, dataset) 74 | opt.writer.add_scalar('training/acc', acc, i) 75 | if acc > acc_best: 76 | acc_best = acc 77 | model.module.acc = acc 78 | model_best = model.module 79 | torch.save(model_best, save_path) 80 | return model_best, acc_best 81 | 82 | def train_model_student(model_, dataset, save_path, idx, 83 | optimization=opt.tr_fu_optimization, 84 | epochs=opt.tr_fu_epochs, lr=opt.tr_fu_lr, 85 | momentum=opt.tr_fu_momentum, 86 | weight_decay=opt.tr_fu_weight_decay, 87 | lr_schedule=opt.tr_fu_lr_schedule, 88 | from_scratch=opt.tr_fu_from_scratch): 89 | acc_best = 0 90 | model_best = None 91 | model = torch.nn.DataParallel(model_.to(opt.device)) 92 | criterion = nn.CrossEntropyLoss() 93 | 94 | if optimization == 'SGD': 95 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, 96 | weight_decay=weight_decay) 97 | elif optimization == 'Adam': 98 | optimizer = optim.Adam(model.parameters(), lr=lr, 99 | weight_decay=weight_decay) 100 | if lr_schedule == 'step': 101 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, 102 | gamma=0.1) 103 | elif lr_schedule == 'linear': 104 | batch_cnt = len(dataset.train_loader) 105 | n_total_exp = epochs * batch_cnt 106 | lr_lambda = lambda n_exp_seen: 1 - n_exp_seen/n_total_exp 107 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 108 | if from_scratch: 109 | init_model(model) 110 | 111 | for i in range(1, epochs + 1): 112 | model.train() 113 | if lr_schedule == 'step': 114 | scheduler.step() 115 | loss_total = 0 116 | batch_cnt = 0 117 | for batch_idx, (inputs, targets) in enumerate(dataset.train_loader): 118 | inputs = inputs.to(opt.device) 119 | targets = targets.to(opt.device) 120 | if lr_schedule == 'linear': 121 | scheduler.step() 122 | optimizer.zero_grad() 123 | outputs = model(inputs) 124 | loss = criterion(outputs, targets) 125 | loss.backward() 126 | optimizer.step() 127 | loss_total += loss.item() 128 | batch_cnt += 1 129 | opt.writer.add_scalar('training_%d/loss' % (idx), loss_total / batch_cnt, i) 130 | acc = test_model(model, dataset) 131 | opt.writer.add_scalar('training_%d/acc' % (idx), acc, i) 132 | if acc > acc_best: 133 | acc_best = acc 134 | model.module.acc = acc 135 | model_best = model.module 136 | torch.save(model_best, save_path) 137 | return model_best, acc_best 138 | 139 | def train_model_search(teacher_, students_, dataset, 140 | optimization=opt.tr_se_optimization, 141 | epochs=opt.tr_se_epochs, lr=opt.tr_se_lr, 142 | momentum=opt.tr_se_momentum, 143 | weight_decay=opt.tr_se_weight_decay, 144 | lr_schedule=opt.tr_se_lr_schedule, 145 | loss_criterion=opt.tr_se_loss_criterion): 146 | n = len(students_) 147 | accs_best = [0.0] * n 148 | students_best = [None] * n 149 | teacher = torch.nn.DataParallel(teacher_.to(opt.device)) 150 | students = [None] * n 151 | 152 | for j in range(n): 153 | students[j] = torch.nn.DataParallel(students_[j].to(opt.device)) 154 | if loss_criterion == 'KD': 155 | criterion = nn.MSELoss() 156 | elif loss_criterion == 'CE': 157 | criterion = nn.CrossEntropyLoss() 158 | if optimization == 'SGD': 159 | optimizers = [optim.SGD(students[j].parameters(), lr=lr, 160 | momentum=momentum, weight_decay=weight_decay) 161 | for j in range(n)] 162 | elif optimization == 'Adam': 163 | optimizers = [optim.Adam(students[j].parameters(), lr=lr, 164 | weight_decay=weight_decay) for j in range(n)] 165 | if lr_schedule == 'linear': 166 | batch_cnt = len(dataset.train_loader) 167 | n_total_exp = epochs * batch_cnt 168 | lr_lambda = lambda n_exp_seen: 1 - n_exp_seen/n_total_exp 169 | schedulers = [optim.lr_scheduler.LambdaLR(optimizers[j], 170 | lr_lambda=lr_lambda) 171 | for j in range(n)] 172 | 173 | for i in range(1, epochs + 1): 174 | teacher.eval() 175 | for j in range(n): 176 | students[j].train() 177 | loss_total = [0.0] * n 178 | batch_cnt = 0 179 | for batch_idx, (inputs, targets) in enumerate(dataset.train_loader): 180 | inputs = inputs.to(opt.device) 181 | 182 | if loss_criterion == 'KD': 183 | teacher_outputs = None 184 | with torch.no_grad(): 185 | teacher_outputs = teacher(inputs) 186 | elif loss_criterion == 'CE': 187 | targets = targets.to(opt.device) 188 | 189 | for j in range(n): 190 | if lr_schedule == 'linear': 191 | schedulers[j].step() 192 | optimizers[j].zero_grad() 193 | student_outputs = students[j](inputs) 194 | if loss_criterion == 'KD': 195 | loss = criterion(student_outputs, teacher_outputs) 196 | elif loss_criterion == 'CE': 197 | loss = criterion(student_outputs, targets) 198 | loss.backward() 199 | optimizers[j].step() 200 | loss_total[j] += loss.item() 201 | batch_cnt += 1 202 | for j in range(n): 203 | opt.writer.add_scalar('step_%d/sample_%d_loss' % (opt.i, j), 204 | loss_total[j] / batch_cnt, i) 205 | acc = test_model(students[j], dataset) 206 | opt.writer.add_scalar('step_%d/sample_%d_acc' % (opt.i, j), acc, i) 207 | if acc > accs_best[j]: 208 | accs_best[j] = acc 209 | students_best[j] = students[j].module 210 | return students_best, accs_best -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ESNAC 2 | SOFTWARE LICENSE AGREEMENT 3 | ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY 4 | 5 | BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. 6 | 7 | This is a license agreement ("Agreement") between your academic institution or non profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. 8 | 9 | RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: 10 | Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, 11 | non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). 12 | 13 | CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. 14 | 15 | COPYRIGHT: The Software is owned by Licensor and is protected by United States copyright laws and applicable international treaties and/or conventions. 16 | 17 | PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. 18 | 19 | DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 20 | 21 | BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. 22 | 23 | USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark "ESNAC", "Carnegie Mellon", or any renditions thereof without the prior written permission of Licensor. 24 | 25 | You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. 26 | 27 | ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. 28 | 29 | TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by clicking "I Agree" below or by using the Software until terminated as provided below. 30 | 31 | The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. 32 | 33 | FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 34 | 35 | DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. 36 | 37 | SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. 38 | 39 | EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. 40 | 41 | EXPORT REGULATION: Licensee agrees to comply with any and all applicable 42 | U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. 43 | 44 | SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. 45 | 46 | NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. 47 | 48 | GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. 49 | 50 | ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. 51 | 52 | ************************************************************************ 53 | 54 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 55 | 56 | This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserve all other rights not expressly granted, whether by implication, estoppel, or otherwise. 57 | 58 | 1. PyTorch, version 1.0, (https://github.com/pytorch/pytorch/) 59 | 60 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 61 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 62 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 63 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 64 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 65 | Copyright (c) 2011-2013 NYU (Clement Farabet) 66 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 67 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 68 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 69 | 70 | From Caffe2: 71 | 72 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 73 | 74 | All contributions by Facebook: 75 | Copyright (c) 2016 Facebook Inc. 76 | 77 | All contributions by Google: 78 | Copyright (c) 2015 Google Inc. 79 | All rights reserved. 80 | 81 | All contributions by Yangqing Jia: 82 | Copyright (c) 2015 Yangqing Jia 83 | All rights reserved. 84 | 85 | All contributions from Caffe: 86 | Copyright(c) 2013, 2014, 2015, the respective contributors 87 | All rights reserved. 88 | 89 | All other contributions: 90 | Copyright(c) 2015, 2016 the respective contributors 91 | All rights reserved. 92 | 93 | Caffe2 uses a copyright model similar to Caffe: each contributor holds copyright over their contributions to Caffe2. The project versioning records all such contribution and copyright details. If a contributor wants to further mark their specific copyright on a particular contribution, they should indicate their copyright solely in the commit message of the change when it is committed. 94 | 95 | All rights reserved. 96 | 97 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 98 | 99 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 100 | 101 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 102 | 103 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and IDIAP Research Institute nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 104 | 105 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 106 | 107 | 2. PyTorch Examples (https://github.com/pytorch/examples/) 108 | 109 | Copyright (c) 2017, 110 | All rights reserved 111 | 112 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 113 | 114 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 115 | 116 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 117 | 118 | * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 119 | 120 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 121 | 122 | 3. PyTorch Vision (https://github.com/pytorch/vision/) 123 | 124 | Copyright (c) Soumith Chintala 2016, 125 | All rights reserved 126 | 127 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 128 | 129 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 130 | 131 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 132 | 133 | * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 134 | 135 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 136 | 137 | 4. CIFAR10 with PyTorch (https://github.com/kuangliu/pytorch-cifar) 138 | 139 | Copyright (c) 2017 liukuang 140 | 141 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 142 | 143 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 144 | 145 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 146 | 147 | 148 | ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** 149 | 150 | --------------------------------------------------------------------------------