├── README.md ├── lib ├── __init__.py └── util.py ├── main.py ├── models ├── __init__.py ├── channel_selection.py └── vgg.py ├── requirement.txt └── script └── vgg19.sh /README.md: -------------------------------------------------------------------------------- 1 | # Prune from Scratch 2 | Unofficial implementation of the paper ["Pruning from Scratch"](https://arxiv.org/abs/1909.12579). 3 | 4 | To verify the claims made in this paper, I implemented a simplified version myself. 5 | 6 | 7 | ## Accuracy 8 | | Model |Prune ratio | Acc. | 9 | | ----------------- | ----------------- |-------------| 10 | | [VGG19](https://arxiv.org/abs/1409.1556) | 50% | 93.24% | 11 | 12 | ## Insight 13 | At first glance, "Pruning from scratch" does not seem to make sense: pruning the network architecture according to the initial random weights does not sound reasonable. However, the experimental results show that you can indeed prune a network from scratch. I therefore think the key point of "Pruning from scratch" is that the "winning ticket" subnetwork (in the sense of the Lottery Ticket Hypothesis, LTH) of an over-parameterized network already performs better than random on the data, without any training. In other words, when we prune a network N from scratch based on its random weights, we are actually looking for the "winning tickets" of N. The "Supermask" of the paper ["Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask"](https://arxiv.org/abs/1905.01067) does something similar. 
14 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .util import * -------------------------------------------------------------------------------- /lib/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | 6 | def format_time(seconds): 7 | days = int(seconds / 3600/24) 8 | seconds = seconds - days*3600*24 9 | hours = int(seconds / 3600) 10 | seconds = seconds - hours*3600 11 | minutes = int(seconds / 60) 12 | seconds = seconds - minutes*60 13 | secondsf = int(seconds) 14 | seconds = seconds - secondsf 15 | millis = int(seconds*1000) 16 | 17 | f = '' 18 | i = 1 19 | if days > 0: 20 | f += str(days) + 'D' 21 | i += 1 22 | if hours > 0 and i <= 2: 23 | f += str(hours) + 'h' 24 | i += 1 25 | if minutes > 0 and i <= 2: 26 | f += str(minutes) + 'm' 27 | i += 1 28 | if secondsf > 0 and i <= 2: 29 | f += str(secondsf) + 's' 30 | i += 1 31 | if millis > 0 and i <= 2: 32 | f += str(millis) + 'ms' 33 | i += 1 34 | if f == '': 35 | f = '0ms' 36 | return f 37 | 38 | _, term_width = os.popen('stty size', 'r').read().split() 39 | term_width = int(term_width) 40 | 41 | TOTAL_BAR_LENGTH = 65. 42 | last_time = time.time() 43 | begin_time = last_time 44 | def progress_bar(current, total, msg=None): 45 | global last_time, begin_time 46 | if current == 0: 47 | begin_time = time.time() # Reset for new bar. 
48 | 49 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 50 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 51 | 52 | sys.stdout.write(' [') 53 | for i in range(cur_len): 54 | sys.stdout.write('=') 55 | sys.stdout.write('>') 56 | for i in range(rest_len): 57 | sys.stdout.write('.') 58 | sys.stdout.write(']') 59 | 60 | cur_time = time.time() 61 | step_time = cur_time - last_time 62 | last_time = cur_time 63 | tot_time = cur_time - begin_time 64 | 65 | L = [] 66 | L.append(' Step: %s' % format_time(step_time)) 67 | L.append(' | Tot: %s' % format_time(tot_time)) 68 | if msg: 69 | L.append(' | ' + msg) 70 | 71 | msg = ''.join(L) 72 | sys.stdout.write(msg) 73 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 74 | sys.stdout.write(' ') 75 | 76 | # Go back to the center of the bar. 77 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 78 | sys.stdout.write('\b') 79 | sys.stdout.write(' %d/%d ' % (current+1, total)) 80 | 81 | if current < total-1: 82 | sys.stdout.write('\r') 83 | else: 84 | sys.stdout.write('\n') 85 | sys.stdout.flush() 86 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/bin/env python 2 | import os 3 | import sys 4 | import copy 5 | import torch 6 | import models 7 | import argparse 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torchvision 13 | import torchvision.datasets as datasets 14 | import torchvision.transforms as transforms 15 | import lib 16 | from lib.util import progress_bar 17 | from torch.autograd import Variable 18 | import thop 19 | from thop import profile 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset name(cifar10, cifar100).') 24 | parser.add_argument('--model', type=str, default='vgg', help='Model type to 
use.') 25 | parser.add_argument('--outdir', type=str, default='./log', help='Output path.') 26 | parser.add_argument('--aepoch', type=int, default=10, help='The number of epochs for arch learning.') 27 | parser.add_argument('--wepoch', type=int, default=200, help='The number of epochs for weight learning.') 28 | parser.add_argument('--alr', type=float, default=0.1, help='Learning rate of the architecture learning.') 29 | parser.add_argument('--batchsize', type=int, default=256, help='Batchsize of dataloader.') 30 | parser.add_argument('--expansion', type=float, default=1.0, help='The expansion ratio for the model.') 31 | parser.add_argument('--ratio', type=float, default=0.5, help='The prune ratio used in sparsity regularzation.') 32 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate for weight training.') 33 | parser.add_argument('--lr_decay', action='store_true', default=False, help='If use the learning rate decay.') 34 | parser.add_argument('--balance', type=float, default=0.5, help='The balance constant of the sparsity regularization.') 35 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay (default 1e-4).') 36 | return parser.parse_args() 37 | 38 | 39 | def prepare_data(args): 40 | cifar_train_trans = transforms.Compose([ 41 | transforms.RandomCrop(32, padding=4), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 45 | ]) 46 | cifar_val_trans = transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 49 | ]) 50 | if args.dataset == 'cifar10': 51 | train_data = datasets.CIFAR10('./data/cifar10', train=True, download=True, transform=cifar_train_trans) 52 | val_data = datasets.CIFAR10('./data/cifar10', train=False, download=False, transform=cifar_val_trans) 53 | elif args.dataset == 'cifar100': 54 | train_data = 
datasets.CIFAR100('./data/cifar100', train=True, download=True, transform=cifar_train_trans) 55 | val_data = datasets.CIFAR100('./data/cifar100', train=False, download=False, transform=cifar_val_trans) 56 | else: 57 | raise NotImplementedError 58 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batchsize, shuffle=True, num_workers=8) 59 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batchsize, shuffle=False, num_workers=8) 60 | return train_loader, val_loader 61 | 62 | 63 | def regularzation_update(model, args): 64 | if not args.sum_channel: 65 | args.sum_channel = 0 66 | for layer in model.modules(): 67 | if isinstance(layer, nn.BatchNorm2d): 68 | args.sum_channel += layer.weight.size()[0] 69 | sumc = args.sum_channel 70 | for layer in model.modules(): 71 | if isinstance(layer, nn.BatchNorm2d): 72 | layer.weight.grad.data.add(args.balance * 2.0 * torch.sign(layer.weight.data)*(layer.weight.data/sumc-args.ratio)) 73 | 74 | 75 | def arch_train(model, args, train_loader, val_loader): 76 | '''First Train the architecture parameters without updating the other weights''' 77 | # Freeze the weights 78 | for para in model.parameters(): 79 | para.requires_grad = False 80 | # Enable the parameters of network architecture 81 | for layer in model.modules(): 82 | if isinstance(layer, nn.BatchNorm2d): 83 | for para in layer.parameters(): 84 | para.requires_grad = True 85 | model.train() 86 | criterion = nn.CrossEntropyLoss() 87 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.alr) 88 | print('Training the Architecture') 89 | 90 | for epochid in range(args.aepoch): 91 | print('==> Epoch: %d' % epochid) 92 | train_loss = 0.0 93 | total = 0 94 | correct = 0 95 | for batchid, (data, target) in enumerate(train_loader): 96 | if args.Use_Cuda: 97 | data, target = data.cuda(), target.cuda() 98 | optimizer.zero_grad() 99 | output = model(data) 100 | loss = criterion(output, target) 101 | loss.backward() 
102 | regularzation_update(model, args) 103 | optimizer.step() 104 | 105 | train_loss += loss.item() 106 | _, predicted = output.max(1) 107 | total += target.size(0) 108 | correct += predicted.eq(target).sum().item() 109 | avg_loss = train_loss / (batchid+1) 110 | acc = correct / total 111 | progress_bar(batchid, len(train_loader), 'Loss: %.3f | Acc: %.3f'% (avg_loss, acc)) 112 | 113 | 114 | def binary_search(model, gates, args, data_loader): 115 | # Get single batch data to profile the flops of the model 116 | model = copy.deepcopy(model).cpu() 117 | data, target = next(iter(data_loader)) 118 | ori_macs, ori_params = profile(model, inputs=(data,)) 119 | #pos = min(int(len(gates) * args.ratio), len(gates)-1) 120 | sorted_gates, _ = torch.sort(gates) 121 | # TODO: use binary search to find the threshold for the pruning 122 | lpos, rpos = 0, len(sorted_gates) - 1 123 | input = torch.randn((args.batchsize, 3)) 124 | eps = 1 125 | macs, params = None, None 126 | while lpos < rpos: 127 | midpos = int((lpos + rpos) / 2) 128 | cur_thres = sorted_gates[midpos] 129 | cfg = [] 130 | for layer in model.modules(): 131 | if isinstance(layer, nn.BatchNorm2d): 132 | weight_copy = layer.weight.data.abs().clone() 133 | mask = weight_copy.gt(cur_thres) 134 | cfg.append(int(torch.sum(mask).item())) 135 | elif isinstance(layer, nn.MaxPool2d): 136 | cfg.append('M') 137 | pruned_model = models.__dict__[args.model](args.num_class, cfg=cfg) 138 | pruned_model(data) 139 | macs, params = profile(pruned_model, inputs=(data,)) 140 | if abs(macs - ori_macs * args.ratio) < eps: 141 | lpos = midpos 142 | break 143 | elif macs > ori_macs * args.ratio: 144 | lpos = midpos + 1 145 | else: 146 | #macs < ori_macs * args.ratio: 147 | rpos = midpos - 1 148 | print('==>Original Model:') 149 | print(' Flops: {}G Parameters: {}M'.format(ori_macs/(10**9), ori_params/(10**6))) 150 | print('==>Pruned Model:') 151 | print(' Flops: {}G Parameters: {}M'.format(macs/(10**9), params/(10**6))) 152 | return 
sorted_gates[lpos] 153 | 154 | 155 | def prune(model, args, data_loader): 156 | if not os.path.exists(args.outdir): 157 | os.makedirs(args.outdir) 158 | print('Pruning the network according to the architecture parameters.') 159 | gates = torch.zeros(args.sum_channel) 160 | index = 0 161 | pruned = 0 162 | cfg = [] 163 | cfg_mask = [] 164 | for lid, layer in enumerate(model.modules()): 165 | if isinstance(layer, nn.BatchNorm2d): 166 | nchannel = layer.weight.data.shape[0] 167 | gates[index:index+nchannel] = layer.weight.data.abs().clone() 168 | index += nchannel 169 | threshold = binary_search(model, gates, args, data_loader) 170 | for lid, layer in enumerate(model.modules()): 171 | if isinstance(layer, nn.BatchNorm2d): 172 | weight_copy = layer.weight.data.abs().clone() 173 | mask = weight_copy.gt(threshold) 174 | mask = mask.float().cuda() 175 | layer.weight.data.mul_(mask) 176 | layer.bias.data.mul_(mask) 177 | pruned += mask.shape[0] - sum(mask) 178 | cfg.append(int(torch.sum(mask).item())) 179 | cfg_mask.append(mask) 180 | elif isinstance(layer, nn.MaxPool2d): 181 | cfg.append('M') 182 | print('Original channel number: ',args.sum_channel) 183 | print(cfg) 184 | print('After pruned channel number: ', sum(filter(lambda x: x!='M', cfg))) 185 | new_model = models.__dict__[args.model](args.num_class, cfg=cfg) 186 | logfile = os.path.join(args.outdir, 'log.txt') 187 | with open(logfile, 'w') as logf: 188 | logf.write('Configuration of the pruned model\n') 189 | logf.write(str(cfg)) 190 | return new_model 191 | 192 | 193 | def validation(model, val_loader, criterion, Use_Cuda): 194 | model.eval() 195 | test_loss = 0.0 196 | correct = 0 197 | total = 0 198 | with torch.no_grad(): 199 | for batchid, (data, target) in enumerate(val_loader): 200 | if Use_Cuda: 201 | data, target = data.cuda(), target.cuda() 202 | output = model(data) 203 | loss = criterion(output, target) 204 | test_loss += loss 205 | _, predicted = output.max(1) 206 | total += target.size(0) 207 | 
correct += predicted.eq(target).sum().item() 208 | avg_acc = correct / total 209 | avg_loss = test_loss / (batchid + 1) 210 | progress_bar(batchid, len(val_loader), 'Loss: %.3f | Acc: %.3f' % (avg_loss, avg_acc)) 211 | return correct/total 212 | 213 | 214 | def weight_train(model, train_loader, val_loader, args): 215 | best_acc = 0.0 216 | criterion = nn.CrossEntropyLoss() 217 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 218 | #lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) 219 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [int(args.wepoch*0.5), int(args.wepoch*0.75)], gamma=0.1) 220 | for i in range(args.wepoch): 221 | print('==>Epoch %d' % (i+1)) 222 | print('==>Training') 223 | model.train() 224 | train_loss = 0.0 225 | correct = 0 226 | total = 0 227 | for batchid, (data, target) in enumerate(train_loader): 228 | if args.Use_Cuda: 229 | data, target = data.cuda(), target.cuda() 230 | optimizer.zero_grad() 231 | output = model(data) 232 | loss = criterion(output, target) 233 | loss.backward() 234 | optimizer.step() 235 | 236 | train_loss += loss.item() 237 | _, predicted = output.max(1) 238 | total += output.size(0) 239 | correct += predicted.eq(target).sum().item() 240 | avg_loss = train_loss / (batchid + 1) 241 | avg_acc = correct / total 242 | progress_bar(batchid, len(train_loader), 'Loss: %.3f | Acc: %.3f' % (avg_loss, avg_acc)) 243 | # Validation 244 | print('==>Validating') 245 | val_acc = validation(model, val_loader, criterion, args.Use_Cuda) 246 | if val_acc > best_acc: 247 | best_acc = val_acc 248 | best_checkpoint = {'state_dict':model.state_dict(), 'Acc':best_acc} 249 | fname = os.path.join(args.outdir, 'best.pth.tar') 250 | torch.save(best_checkpoint, fname) 251 | print('==>Best validation accuracy', best_acc) 252 | # Save checkpoint 253 | if (i + 1) % 10 == 0: 254 | torch.save(model.state_dict(), os.path.join(args.outdir, 
'checkpoint.pth.tar')) 255 | # Lr_scheduler 256 | if args.lr_decay: 257 | lr_scheduler.step() 258 | 259 | def main(): 260 | args = parse_args() 261 | train_loader, val_loader = prepare_data(args) 262 | args.num_class = 10 if args.dataset == 'cifar10' else 100 263 | model = models.__dict__[args.model](num_classes=args.num_class, expansion=args.expansion) 264 | args.Use_Cuda = torch.cuda.is_available() 265 | args.sum_channel = None 266 | if args.Use_Cuda: 267 | model.cuda() 268 | arch_train(model, args, train_loader, val_loader) 269 | new_model = prune(model, args, train_loader) 270 | if args.Use_Cuda: 271 | new_model.cuda() 272 | weight_train(new_model, train_loader, val_loader, args) 273 | 274 | if __name__ == '__main__': 275 | main() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .vgg import * 4 | #from .preresnet import * 5 | #from .densenet import * 6 | from .channel_selection import * -------------------------------------------------------------------------------- /models/channel_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class channel_selection(nn.Module): 7 | """ 8 | Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer. 9 | The output shape of this layer is determined by the number of 1 in `self.indexes`. 10 | """ 11 | def __init__(self, num_channels): 12 | """ 13 | Initialize the `indexes` with all one vector with the length same as the number of channels. 14 | During pruning, the places in `indexes` which correpond to the channels to be pruned will be set to 0. 
15 | """ 16 | super(channel_selection, self).__init__() 17 | self.indexes = nn.Parameter(torch.ones(num_channels)) 18 | 19 | def forward(self, input_tensor): 20 | """ 21 | Parameter 22 | --------- 23 | input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer. 24 | """ 25 | selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy())) 26 | if selected_index.size == 1: 27 | selected_index = np.resize(selected_index, (1,)) 28 | output = input_tensor[:, selected_index, :, :] 29 | return output -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | __all__ = ['vgg'] 8 | 9 | defaultcfg = { 10 | 11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 11 | 13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 12 | 16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 13 | 19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 14 | } 15 | 16 | class vgg(nn.Module): 17 | def __init__(self, num_classes=10, depth=19, init_weights=True, cfg=None, expansion=1.0): 18 | super(vgg, self).__init__() 19 | if cfg is None: 20 | cfg = defaultcfg[depth] 21 | 22 | self.feature = self.make_layers(cfg, True) 23 | self.classifier = nn.Linear(cfg[-1], num_classes) 24 | if init_weights: 25 | self._initialize_weights() 26 | 27 | def make_layers(self, cfg, batch_norm=False): 28 | layers = [] 29 | in_channels = 3 30 | for v in cfg: 31 | if v == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 35 | if batch_norm: 36 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 37 | else: 38 | layers += [conv2d, 
nn.ReLU(inplace=True)] 39 | in_channels = v 40 | return nn.Sequential(*layers) 41 | 42 | def forward(self, x): 43 | x = self.feature(x) 44 | x = nn.AvgPool2d(2)(x) 45 | x = x.view(x.size(0), -1) 46 | y = self.classifier(x) 47 | return y 48 | 49 | def _initialize_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | elif isinstance(m, nn.BatchNorm2d): 57 | m.weight.data.fill_(0.5) 58 | m.bias.data.zero_() 59 | elif isinstance(m, nn.Linear): 60 | m.weight.data.normal_(0, 0.01) 61 | m.bias.data.zero_() 62 | 63 | if __name__ == '__main__': 64 | net = vgg() 65 | x = Variable(torch.FloatTensor(16, 3, 40, 40)) 66 | y = net(x) 67 | print(y.data.shape) -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | astor==0.8.1 2 | atari-py==0.2.6 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.0 6 | certifi==2019.11.28 7 | chardet==3.0.4 8 | cloudpickle==1.2.2 9 | colorama==0.4.3 10 | contextlib2==0.5.5 11 | coverage==5.0.3 12 | cycler==0.10.0 13 | decorator==4.4.1 14 | defusedxml==0.6.0 15 | entrypoints==0.3 16 | future==0.18.2 17 | gym==0.16.0 18 | hyperopt==0.1.2 19 | idna==2.9 20 | imageio==2.6.1 21 | importlib-metadata==1.5.0 22 | ipdb==0.12.3 23 | ipykernel==5.1.4 24 | ipython==7.12.0 25 | ipython-genutils==0.2.0 26 | ipywidgets==7.5.1 27 | jedi==0.16.0 28 | Jinja2==2.11.1 29 | joblib==0.14.1 30 | json-tricks==3.14.0 31 | jsonschema==3.2.0 32 | jupyter==1.0.0 33 | jupyter-client==5.3.4 34 | jupyter-console==6.1.0 35 | jupyter-core==4.6.2 36 | kiwisolver==1.1.0 37 | MarkupSafe==1.1.1 38 | matplotlib==3.1.3 39 | mistune==0.8.4 40 | mkl-fft==1.0.15 41 | mkl-random==1.1.0 42 | mkl-service==2.3.0 43 | nbconvert==5.6.1 44 | nbformat==5.0.4 45 | networkx==2.4 
46 | nni==1.4 47 | notebook==6.0.3 48 | numpy==1.18.1 49 | opencv-contrib-python==4.2.0.32 50 | opencv-python==4.2.0.32 51 | pandocfilters==1.4.2 52 | parso==0.6.1 53 | pexpect==4.8.0 54 | pickleshare==0.7.5 55 | Pillow==7.0.0 56 | prometheus-client==0.7.1 57 | prompt-toolkit==3.0.3 58 | protobuf==3.11.3 59 | psutil==5.7.0 60 | ptyprocess==0.6.0 61 | pyglet==1.5.0 62 | Pygments==2.5.2 63 | pygraphviz==1.5 64 | pymongo==3.10.1 65 | pyparsing==2.4.6 66 | pyrsistent==0.15.7 67 | python-dateutil==2.8.1 68 | PythonWebHDFS==0.2.3 69 | pyzmq==18.1.1 70 | qtconsole==4.6.0 71 | requests==2.23.0 72 | ruamel.yaml==0.16.10 73 | ruamel.yaml.clib==0.2.0 74 | schema==0.7.1 75 | scikit-learn==0.21.3 76 | scipy==1.1.0 77 | Send2Trash==1.5.0 78 | simplejson==3.17.0 79 | six==1.14.0 80 | tensorboardX==2.0 81 | terminado==0.8.3 82 | testpath==0.4.4 83 | thop==0.0.31.post2001170342 84 | torch==1.4.0 85 | torchvision==0.5.0 86 | tornado==6.0.3 87 | tqdm==4.42.1 88 | traitlets==4.3.3 89 | urllib3==1.25.8 90 | wcwidth==0.1.8 91 | webencodings==0.5.1 92 | widgetsnbextension==3.5.1 93 | zipp==3.0.0 94 | -------------------------------------------------------------------------------- /script/vgg19.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python main.py --aepoch 10 --alr 0.01 --wepoch 160 --lr 0.1 --lr_decay --batchsize 64 --outdir log/vgg19_5 --model vgg --ratio 0.5 --dataset cifar10 3 | --------------------------------------------------------------------------------