├── checkpoints
│   └── download.sh
├── scripts
│   ├── search_mobilenet_0.5flops.sh
│   ├── finetune_mobilenet_0.5flops.sh
│   └── export_mobilenet_0.5flops.sh
├── env
│   ├── rewards.py
│   └── channel_pruning_env.py
├── LICENSE
├── models
│   ├── mobilenet.py
│   └── mobilenet_v2.py
├── README.md
├── lib
│   ├── net_measure.py
│   ├── data.py
│   ├── agent.py
│   ├── utils.py
│   └── memory.py
├── eval_mobilenet.py
├── amc_fine_tune.py
└── amc_search.py
/checkpoints/download.sh: -------------------------------------------------------------------------------- 1 | # download original mobilenet 2 | wget -P checkpoints https://hanlab18.mit.edu/projects/amc/external/mobilenet_imagenet.pth.tar 3 | 4 | # download mobilenet 0.5FLOPs 5 | wget -P checkpoints https://hanlab18.mit.edu/projects/amc/external/mobilenet_imagenet_0.5flops_70.5.pth.tar 6 | -------------------------------------------------------------------------------- /scripts/search_mobilenet_0.5flops.sh: -------------------------------------------------------------------------------- 1 | python amc_search.py \ 2 | --job=train \ 3 | --model=mobilenet \ 4 | --dataset=imagenet \ 5 | --preserve_ratio=0.5 \ 6 | --lbound=0.2 \ 7 | --rbound=1 \ 8 | --reward=acc_reward \ 9 | --data_root=/dataset/imagenet \ 10 | --ckpt_path=./checkpoints/mobilenet_imagenet.pth.tar \ 11 | --seed=2018 12 | -------------------------------------------------------------------------------- /scripts/finetune_mobilenet_0.5flops.sh: -------------------------------------------------------------------------------- 1 | python -W ignore amc_fine_tune.py \ 2 | --model=mobilenet_0.5flops \ 3 | --dataset=imagenet \ 4 | --lr=0.05 \ 5 | --n_gpu=4 \ 6 | --batch_size=256 \ 7 | --n_worker=32 \ 8 | --lr_type=cos \ 9 | --n_epoch=150 \ 10 | --wd=4e-5 \ 11 | --seed=2018 \ 12 | --data_root=/dataset/imagenet \ 13 | --ckpt_path=./checkpoints/mobilenet_0.5flops_export.pth.tar 14 | -------------------------------------------------------------------------------- /env/rewards.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import numpy as np 6 | 7 | 8 | # for pruning 9 | def acc_reward(net, acc, flops): 10 | return acc * 0.01 11 | 12 | 13 | def acc_flops_reward(net, acc, flops): 14 | error = (100 - acc) * 0.01 15 | return -error * np.log(flops) 16 | -------------------------------------------------------------------------------- /scripts/export_mobilenet_0.5flops.sh: -------------------------------------------------------------------------------- 1 | python amc_search.py \ 2 | --job=export \ 3 | --model=mobilenet \ 4 | --dataset=imagenet \ 5 | --data_root=/dataset/imagenet \ 6 | --ckpt_path=./checkpoints/mobilenet_imagenet.pth.tar \ 7 | --seed=2018 \ 8 | --n_calibration_batches=300 \ 9 | --n_worker=32 \ 10 | --channels=3,24,48,96,80,192,200,328,352,368,360,328,400,736,752 \ 11 | --export_path=./checkpoints/mobilenet_0.5flops_export.pth.tar 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 MIT_Han_Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | def conv_bn(inp, oup, stride): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | 17 | def conv_dw(inp, oup, stride): 18 | return nn.Sequential( 19 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 20 | nn.BatchNorm2d(inp), 21 | nn.ReLU(inplace=True), 22 | 23 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 24 | nn.BatchNorm2d(oup), 25 | nn.ReLU(inplace=True), 26 | ) 27 | 28 | 29 | class MobileNet(nn.Module): 30 | def __init__(self, n_class, profile='normal'): 31 | super(MobileNet, self).__init__() 32 | 33 | # original 34 | if profile == 'normal': 35 | in_planes = 32 36 | cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024] 37 | # 0.5 AMC 38 | elif profile == '0.5flops': 39 | in_planes = 24 40 | cfg = [48, (96, 2), 80, (192, 2), 200, (328, 2), 352, 368, 360, 328, 400, (736, 2), 752] 41 | else: 42 | raise NotImplementedError 43 | 44 | self.conv1 = conv_bn(3, in_planes, stride=2) 45 | 46 | self.features = self._make_layers(in_planes, cfg, conv_dw) 47 | 48 | self.classifier = nn.Sequential( 49 | nn.Linear(cfg[-1], n_class), 50 | ) 51 | 52 | self._initialize_weights() 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = self.features(x) 57 | x = x.mean(3).mean(2) # global average pooling 58 | 59 | x = self.classifier(x) 60 | return x 61 | 62 | def _make_layers(self, in_planes, cfg, layer): 63 | layers = [] 64 | for x in cfg: 65 | out_planes = x if isinstance(x, int) else x[0] 66 | stride = 1 if isinstance(x, int) else x[1] 67 | layers.append(layer(in_planes, out_planes, stride)) 68 | in_planes = out_planes 69 | return nn.Sequential(*layers) 70 | 71 | def _initialize_weights(self): 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 75 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 76 | if m.bias is not None: 77 | m.bias.data.zero_() 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | n = m.weight.size(1) 83 | m.weight.data.normal_(0, 0.01) 84 | m.bias.data.zero_() 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoML for Model Compression (AMC) 2 | 3 | This repo contains the PyTorch implementation for the paper [**AMC: AutoML for Model Compression and Acceleration on Mobile Devices**](https://arxiv.org/abs/1802.03494). 4 | 5 | ![overview](https://hanlab.mit.edu/projects/amc/images/overview.png) 6 | 7 | 8 | 9 | ## Reference 10 | 11 | If you find the repo useful, please kindly cite our paper: 12 | 13 | ``` 14 | @inproceedings{he2018amc, 15 | title={AMC: AutoML for Model Compression and Acceleration on Mobile Devices}, 16 | author={He, Yihui and Lin, Ji and Liu, Zhijian and Wang, Hanrui and Li, Li-Jia and Han, Song}, 17 | booktitle={European Conference on Computer Vision (ECCV)}, 18 | year={2018} 19 | } 20 | ``` 21 | 22 | Other papers related to automated model design: 23 | 24 | - HAQ: Hardware-Aware Automated Quantization with Mixed Precision ([CVPR 2019](https://arxiv.org/abs/1811.08886)) 25 | 26 | - ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware ([ICLR 2019](https://arxiv.org/abs/1812.00332)) 27 | 28 | 29 | 30 | ## Training AMC 31 | 32 | The current code base supports automated pruning of **MobileNet** on **ImageNet**. The pruning of MobileNet consists of 3 steps: **1. strategy search; 2. export the pruned weights; 3. fine-tune from pruned weights**. 33 | 34 | To conduct the full pruning procedure, follow the instructions below (results might vary slightly from the paper due to a different random seed): 35 | 36 | 1. **Strategy Search** 37 | 38 | To search for a pruning strategy on the ImageNet MobileNet model, first get the pretrained MobileNet checkpoint on ImageNet by running: 39 | 40 | ``` 41 | bash ./checkpoints/download.sh 42 | ``` 43 | 44 | It will also download our 50% FLOPs compressed model. Then run the following script to search under the 50% FLOPs constraint: 45 | 46 | ```bash 47 | bash ./scripts/search_mobilenet_0.5flops.sh 48 | ``` 49 | 50 | Results may differ due to the random seed. The strategy we found and reported in the paper is: 51 | 52 | ``` 53 | [3, 24, 48, 96, 80, 192, 200, 328, 352, 368, 360, 328, 400, 736, 752] 54 | ``` 55 | 56 | 2. **Export the Pruned Weights** 57 | 58 | After searching, we need to export the pruned weights by running: 59 | 60 | ``` 61 | bash ./scripts/export_mobilenet_0.5flops.sh 62 | ``` 63 | 64 | We also need to modify the MobileNet definition to support the new pruned channel configuration (this is already done in `models/mobilenet.py`). 65 | 66 | 3. **Fine-tune from Pruned Weights** 67 | 68 | After exporting, we need to fine-tune from the pruned weights. For example, we can fine-tune with a cosine learning rate schedule for 150 epochs by running: 69 | 70 | ``` 71 | bash ./scripts/finetune_mobilenet_0.5flops.sh 72 | ``` 73 | 74 | 75 | 76 | ## AMC Compressed Model 77 | 78 | We also provide the models and weights compressed by our AMC method. We provide compressed MobileNet-V1 and MobileNet-V2 in both PyTorch and TensorFlow format [here](https://github.com/mit-han-lab/amc-compressed-models).
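For a quick sanity check of the released 0.5-FLOPs model outside the provided scripts, it can be loaded directly with the definition in `models/mobilenet.py`. Below is a minimal sketch (it assumes `bash ./checkpoints/download.sh` has already been run from the repo root; `eval_mobilenet.py` performs the full ImageNet evaluation):

```python
import torch

from models.mobilenet import MobileNet

# build the pruned 0.5-FLOPs architecture and load the released checkpoint
net = MobileNet(n_class=1000, profile='0.5flops')
ckpt = torch.load('./checkpoints/mobilenet_imagenet_0.5flops_70.5.pth.tar', map_location='cpu')
net.load_state_dict(ckpt['state_dict'])
net.eval()
```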
79 | 80 | Detailed statistics are as follows: 81 | 82 | | Models | Top1 Acc (%) | Top5 Acc (%) | 83 | | ------------------------ | ------------ | ------------ | 84 | | MobileNetV1-width*0.75 | 68.4 | 88.2 | 85 | | **MobileNetV1-50%FLOPs** | **70.494** | **89.306** | 86 | | **MobileNetV1-50%Time** | **70.200** | **89.430** | 87 | | MobileNetV2-width*0.75 | 69.8 | 89.6 | 88 | | **MobileNetV2-70%FLOPs** | **70.854** | **89.914** | 89 | 90 | 91 | 92 | ## Dependencies 93 | 94 | Current code base is tested under following environment: 95 | 96 | 1. Python 3.7.3 97 | 2. PyTorch 1.1.0 98 | 3. torchvision 0.2.1 99 | 4. NumPy 1.14.3 100 | 5. SciPy 1.1.0 101 | 6. scikit-learn 0.19.1 102 | 7. [tensorboardX](https://github.com/lanpa/tensorboardX) 103 | 8. ImageNet dataset 104 | 105 | 106 | 107 | ## Contact 108 | 109 | To contact the authors: 110 | 111 | Ji Lin, jilin@mit.edu 112 | 113 | Song Han, songhan@mit.edu 114 | -------------------------------------------------------------------------------- /lib/net_measure.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import torch 6 | 7 | # [reference] https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py 8 | 9 | 10 | def get_num_gen(gen): 11 | return sum(1 for _ in gen) 12 | 13 | 14 | def is_leaf(model): 15 | return get_num_gen(model.children()) == 0 16 | 17 | 18 | def get_layer_info(layer): 19 | layer_str = str(layer) 20 | type_name = layer_str[:layer_str.find('(')].strip() 21 | return type_name 22 | 23 | 24 | def get_layer_param(model): 25 | import operator 26 | import functools 27 | 28 | return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 29 | 30 | 31 | def measure_layer(layer, x): 32 | global count_ops, count_params 33 | delta_ops = 0 34 | delta_params = 0 35 | multi_add = 1 36 | type_name = get_layer_info(layer) 37 | 38 | # ops_conv 39 | if type_name in ['Conv2d']: 40 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 41 | layer.stride[0] + 1) 42 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 43 | layer.stride[1] + 1) 44 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 45 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 46 | delta_params = get_layer_param(layer) 47 | 48 | # ops_nonlinearity 49 | elif type_name in ['ReLU']: 50 | delta_ops = x.numel() / x.size(0) 51 | delta_params = get_layer_param(layer) 52 | 53 | # ops_pooling 54 | elif type_name in ['AvgPool2d']: 55 | in_w = x.size()[2] 56 | kernel_ops = layer.kernel_size * layer.kernel_size 57 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 58 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 59 | delta_ops = x.size()[1] * out_w * out_h * kernel_ops 60 | delta_params = get_layer_param(layer) 61 | 62 | elif type_name in ['AdaptiveAvgPool2d']: 63 | delta_ops = x.size()[1] * x.size()[2] * x.size()[3] 64 | delta_params = get_layer_param(layer) 65 | 66 | # ops_linear 67 | elif type_name in ['Linear']: 68 | weight_ops = layer.weight.numel() * multi_add 69 | bias_ops = layer.bias.numel() 70 | delta_ops = weight_ops + bias_ops 71 | delta_params = get_layer_param(layer) 72 | 73 | # ops_nothing 74 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: 75 | 
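# parameter count only: FLOPs of these layers are treated as negligible, so delta_ops stays 0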
delta_params = get_layer_param(layer) 76 | 77 | # unknown layer type 78 | else: 79 | delta_params = get_layer_param(layer) 80 | 81 | count_ops += delta_ops 82 | count_params += delta_params 83 | 84 | return 85 | 86 | 87 | def measure_model(model, H, W): 88 | global count_ops, count_params 89 | count_ops = 0 90 | count_params = 0 91 | data = torch.zeros(1, 3, H, W).cuda() 92 | 93 | def should_measure(x): 94 | return is_leaf(x) 95 | 96 | def modify_forward(model): 97 | for child in model.children(): 98 | if should_measure(child): 99 | def new_forward(m): 100 | def lambda_forward(x): 101 | measure_layer(m, x) 102 | return m.old_forward(x) 103 | return lambda_forward 104 | child.old_forward = child.forward 105 | child.forward = new_forward(child) 106 | else: 107 | modify_forward(child) 108 | 109 | def restore_forward(model): 110 | for child in model.children(): 111 | # leaf node 112 | if is_leaf(child) and hasattr(child, 'old_forward'): 113 | child.forward = child.old_forward 114 | child.old_forward = None 115 | else: 116 | restore_forward(child) 117 | 118 | modify_forward(model) 119 | model.forward(data) 120 | restore_forward(model) 121 | 122 | return count_ops, count_params -------------------------------------------------------------------------------- /eval_mobilenet.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import os 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | 11 | import argparse 12 | 13 | from torch.autograd import Variable 14 | 15 | from models.mobilenet import MobileNet 16 | from lib.utils import AverageMeter, progress_bar, accuracy 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 19 | parser.add_argument('--model', default='mobilenet_0.5flops', type=str, help='name of the model to test') 20 | parser.add_argument('--imagenet_path', default=None, type=str, help='Directory of ImageNet') 21 | parser.add_argument('--n_gpu', default=1, type=int, help='name of the job') 22 | parser.add_argument('--batch_size', default=100, type=int, help='batch size') 23 | parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker') 24 | 25 | args = parser.parse_args() 26 | 27 | use_cuda = torch.cuda.is_available() 28 | 29 | 30 | def get_dataset(): 31 | # lazy import 32 | import torchvision.datasets as datasets 33 | import torchvision.transforms as transforms 34 | if not args.imagenet_path: 35 | raise Exception('Please provide valid ImageNet path!') 36 | print('=> Preparing data..') 37 | valdir = os.path.join(args.imagenet_path, 'val') 38 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 39 | std=[0.229, 0.224, 0.225]) 40 | 41 | input_size = 224 42 | val_loader = torch.utils.data.DataLoader( 43 | datasets.ImageFolder(valdir, transforms.Compose([ 44 | transforms.Resize(int(input_size / 0.875)), 45 | transforms.CenterCrop(input_size), 46 | transforms.ToTensor(), 47 | normalize, 48 | ])), 49 | batch_size=args.batch_size, shuffle=False, 50 | num_workers=args.n_worker, pin_memory=True) 51 | n_class = 1000 52 | return val_loader, n_class 53 | 54 | 55 | def get_model(n_class): 56 | print('=> Building model {}...'.format(args.model)) 57 | if args.model == 'mobilenet_0.5flops': 58 | net = MobileNet(n_class, profile='0.5flops') 59 | 
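# checkpoint fetched by checkpoints/download.sh (~70.5% top-1 on ImageNet per the README)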
checkpoint_path = './checkpoints/mobilenet_imagenet_0.5flops_70.5.pth.tar' 60 | else: 61 | raise NotImplementedError 62 | 63 | print('=> Loading checkpoints..') 64 | checkpoint = torch.load(checkpoint_path) 65 | net.load_state_dict(checkpoint['state_dict']) # remove .module 66 | 67 | return net 68 | 69 | 70 | def evaluate(): 71 | # build dataset 72 | val_loader, n_class = get_dataset() 73 | # build model 74 | net = get_model(n_class) 75 | 76 | criterion = nn.CrossEntropyLoss() 77 | 78 | if use_cuda: 79 | net = net.cuda() 80 | net = torch.nn.DataParallel(net, list(range(args.n_gpu))) 81 | cudnn.benchmark = True 82 | 83 | # begin eval 84 | net.eval() 85 | 86 | batch_time = AverageMeter() 87 | losses = AverageMeter() 88 | top1 = AverageMeter() 89 | top5 = AverageMeter() 90 | end = time.time() 91 | 92 | with torch.no_grad(): 93 | for batch_idx, (inputs, targets) in enumerate(val_loader): 94 | if use_cuda: 95 | inputs, targets = inputs.cuda(), targets.cuda() 96 | inputs, targets = Variable(inputs), Variable(targets) 97 | outputs = net(inputs) 98 | loss = criterion(outputs, targets) 99 | 100 | # measure accuracy and record loss 101 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 102 | losses.update(loss.item(), inputs.size(0)) 103 | top1.update(prec1.item(), inputs.size(0)) 104 | top5.update(prec5.item(), inputs.size(0)) 105 | # timing 106 | batch_time.update(time.time() - end) 107 | end = time.time() 108 | 109 | progress_bar(batch_idx, len(val_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%' 110 | .format(losses.avg, top1.avg, top5.avg)) 111 | 112 | 113 | if __name__ == '__main__': 114 | evaluate() -------------------------------------------------------------------------------- /models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | class InvertedResidual(nn.Module): 22 | def __init__(self, inp, oup, stride, expand_ratio): 23 | super(InvertedResidual, self).__init__() 24 | self.stride = stride 25 | assert stride in [1, 2] 26 | 27 | hidden_dim = round(inp * expand_ratio) 28 | self.use_res_connect = self.stride == 1 and inp == oup 29 | 30 | if expand_ratio == 1: 31 | self.conv = nn.Sequential( 32 | # dw 33 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 34 | nn.BatchNorm2d(hidden_dim), 35 | nn.ReLU6(inplace=True), 36 | # pw-linear 37 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 38 | nn.BatchNorm2d(oup), 39 | ) 40 | else: 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(hidden_dim), 45 | nn.ReLU6(inplace=True), 46 | # dw 47 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 48 | nn.BatchNorm2d(hidden_dim), 49 | nn.ReLU6(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | 55 | def forward(self, x): 56 | if self.use_res_connect: 57 | return x + self.conv(x) 58 | else: 59 | return self.conv(x) 60 | 61 | 62 | class MobileNetV2(nn.Module): 63 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 64 | 
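# standard (unpruned) MobileNetV2 definition; width_mult uniformly scales the channel counts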
super(MobileNetV2, self).__init__() 65 | block = InvertedResidual 66 | input_channel = 32 67 | last_channel = 1280 68 | interverted_residual_setting = [ 69 | # t, c, n, s 70 | [1, 16, 1, 1], 71 | [6, 24, 2, 2], 72 | [6, 32, 3, 2], 73 | [6, 64, 4, 2], 74 | [6, 96, 3, 1], 75 | [6, 160, 3, 2], 76 | [6, 320, 1, 1], 77 | ] 78 | 79 | # building first layer 80 | assert input_size % 32 == 0 81 | input_channel = int(input_channel * width_mult) 82 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 83 | self.features = [conv_bn(3, input_channel, 2)] 84 | # building inverted residual blocks 85 | for t, c, n, s in interverted_residual_setting: 86 | output_channel = int(c * width_mult) 87 | for i in range(n): 88 | if i == 0: 89 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 90 | else: 91 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 92 | input_channel = output_channel 93 | # building last several layers 94 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 95 | # make it nn.Sequential 96 | self.features = nn.Sequential(*self.features) 97 | 98 | # building classifier 99 | self.classifier = nn.Sequential( 100 | nn.Dropout(0.2), 101 | nn.Linear(self.last_channel, n_class), 102 | ) 103 | 104 | self._initialize_weights() 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.mean(3).mean(2) 109 | x = self.classifier(x) 110 | return x 111 | 112 | def _initialize_weights(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | elif isinstance(m, nn.Linear): 123 | n = m.weight.size(1) 124 | m.weight.data.normal_(0, 0.01) 125 | m.bias.data.zero_() 126 | -------------------------------------------------------------------------------- /lib/data.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim 8 | import torch.utils.data 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | import numpy as np 14 | 15 | import os 16 | 17 | 18 | def get_dataset(dset_name, batch_size, n_worker, data_root='../../data'): 19 | cifar_tran_train = [ 20 | transforms.RandomCrop(32, padding=4), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 24 | ] 25 | cifar_tran_test = [ 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 28 | ] 29 | print('=> Preparing data..') 30 | if dset_name == 'cifar10': 31 | transform_train = transforms.Compose(cifar_tran_train) 32 | transform_test = transforms.Compose(cifar_tran_test) 33 | trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train) 34 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, 35 | num_workers=n_worker, 
pin_memory=True, sampler=None) 36 | testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test) 37 | val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, 38 | num_workers=n_worker, pin_memory=True) 39 | n_class = 10 40 | elif dset_name == 'imagenet': 41 | # get dir 42 | traindir = os.path.join(data_root, 'train') 43 | valdir = os.path.join(data_root, 'val') 44 | 45 | # preprocessing 46 | input_size = 224 47 | imagenet_tran_train = [ 48 | transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0)), 49 | transforms.RandomHorizontalFlip(), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 52 | ] 53 | imagenet_tran_test = [ 54 | transforms.Resize(int(input_size / 0.875)), 55 | transforms.CenterCrop(input_size), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 58 | ] 59 | 60 | train_loader = torch.utils.data.DataLoader( 61 | datasets.ImageFolder(traindir, transforms.Compose(imagenet_tran_train)), 62 | batch_size=batch_size, shuffle=True, 63 | num_workers=n_worker, pin_memory=True, sampler=None) 64 | 65 | val_loader = torch.utils.data.DataLoader( 66 | datasets.ImageFolder(valdir, transforms.Compose(imagenet_tran_test)), 67 | batch_size=batch_size, shuffle=False, 68 | num_workers=n_worker, pin_memory=True) 69 | n_class = 1000 70 | 71 | else: 72 | raise NotImplementedError 73 | 74 | return train_loader, val_loader, n_class 75 | 76 | 77 | def get_split_dataset(dset_name, batch_size, n_worker, val_size, data_root='../data', 78 | use_real_val=False, shuffle=True): 79 | ''' 80 | split the train set into train / val for rl search 81 | ''' 82 | if shuffle: 83 | index_sampler = SubsetRandomSampler 84 | else: # every time we use the same order for the split subset 85 | class SubsetSequentialSampler(SubsetRandomSampler): 86 | def __iter__(self): 87 | return (self.indices[i] for i in torch.arange(len(self.indices)).int()) 88 | index_sampler = SubsetSequentialSampler 89 | 90 | print('=> Preparing data: {}...'.format(dset_name)) 91 | if dset_name == 'cifar10': 92 | transform_train = transforms.Compose([ 93 | transforms.RandomCrop(32, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 97 | ]) 98 | transform_test = transforms.Compose([ 99 | transforms.ToTensor(), 100 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 101 | ]) 102 | trainset = torchvision.datasets.CIFAR100(root=data_root, train=True, download=True, transform=transform_train) 103 | if use_real_val: # split the actual val set 104 | valset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test) 105 | n_val = len(valset) 106 | assert val_size < n_val 107 | indices = list(range(n_val)) 108 | np.random.shuffle(indices) 109 | _, val_idx = indices[val_size:], indices[:val_size] 110 | train_idx = list(range(len(trainset))) # all train set for train 111 | else: # split the train set 112 | valset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_test) 113 | n_train = len(trainset) 114 | indices = list(range(n_train)) 115 | # now shuffle the indices 116 | np.random.shuffle(indices) 117 | assert val_size < n_train 118 | train_idx, val_idx = indices[val_size:], indices[:val_size] 119 | 120 | train_sampler = index_sampler(train_idx) 121 | 
val_sampler = index_sampler(val_idx) 122 | 123 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, sampler=train_sampler, 124 | num_workers=n_worker, pin_memory=True) 125 | val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, sampler=val_sampler, 126 | num_workers=n_worker, pin_memory=True) 127 | n_class = 10 128 | 129 | elif dset_name == 'imagenet': 130 | train_dir = os.path.join(data_root, 'train') 131 | val_dir = os.path.join(data_root, 'val') 132 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 133 | std=[0.229, 0.224, 0.225]) 134 | input_size = 224 135 | train_transform = transforms.Compose([ 136 | transforms.RandomResizedCrop(input_size), 137 | transforms.RandomHorizontalFlip(), 138 | transforms.ToTensor(), 139 | normalize, 140 | ]) 141 | test_transform = transforms.Compose([ 142 | transforms.Resize(int(input_size/0.875)), 143 | transforms.CenterCrop(input_size), 144 | transforms.ToTensor(), 145 | normalize, 146 | ]) 147 | 148 | trainset = datasets.ImageFolder(train_dir, train_transform) 149 | if use_real_val: 150 | valset = datasets.ImageFolder(val_dir, test_transform) 151 | n_val = len(valset) 152 | assert val_size < n_val 153 | indices = list(range(n_val)) 154 | np.random.shuffle(indices) 155 | _, val_idx = indices[val_size:], indices[:val_size] 156 | train_idx = list(range(len(trainset))) # all trainset 157 | else: 158 | valset = datasets.ImageFolder(train_dir, test_transform) 159 | n_train = len(trainset) 160 | indices = list(range(n_train)) 161 | np.random.shuffle(indices) 162 | assert val_size < n_train 163 | train_idx, val_idx = indices[val_size:], indices[:val_size] 164 | 165 | train_sampler = index_sampler(train_idx) 166 | val_sampler = index_sampler(val_idx) 167 | 168 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, 169 | num_workers=n_worker, pin_memory=True) 170 | val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler, 171 | num_workers=n_worker, pin_memory=True) 172 | 173 | n_class = 1000 174 | else: 175 | raise NotImplementedError 176 | 177 | return train_loader, val_loader, n_class 178 | -------------------------------------------------------------------------------- /lib/agent.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.optim import Adam 10 | 11 | from lib.memory import SequentialMemory 12 | from lib.utils import to_numpy, to_tensor 13 | 14 | criterion = nn.MSELoss() 15 | USE_CUDA = torch.cuda.is_available() 16 | 17 | 18 | class Actor(nn.Module): 19 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300): 20 | super(Actor, self).__init__() 21 | self.fc1 = nn.Linear(nb_states, hidden1) 22 | self.fc2 = nn.Linear(hidden1, hidden2) 23 | self.fc3 = nn.Linear(hidden2, nb_actions) 24 | self.relu = nn.ReLU() 25 | self.sigmoid = nn.Sigmoid() 26 | 27 | def forward(self, x): 28 | out = self.fc1(x) 29 | out = self.relu(out) 30 | out = self.fc2(out) 31 | out = self.relu(out) 32 | out = self.fc3(out) 33 | out = self.sigmoid(out) 34 | return out 35 | 36 | 37 | class Critic(nn.Module): 38 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300): 39 | super(Critic, 
self).__init__() 40 | self.fc11 = nn.Linear(nb_states, hidden1) 41 | self.fc12 = nn.Linear(nb_actions, hidden1) 42 | self.fc2 = nn.Linear(hidden1, hidden2) 43 | self.fc3 = nn.Linear(hidden2, 1) 44 | self.relu = nn.ReLU() 45 | 46 | def forward(self, xs): 47 | x, a = xs 48 | out = self.fc11(x) + self.fc12(a) 49 | out = self.relu(out) 50 | out = self.fc2(out) 51 | out = self.relu(out) 52 | out = self.fc3(out) 53 | return out 54 | 55 | 56 | class DDPG(object): 57 | def __init__(self, nb_states, nb_actions, args): 58 | 59 | self.nb_states = nb_states 60 | self.nb_actions = nb_actions 61 | 62 | # Create Actor and Critic Network 63 | net_cfg = { 64 | 'hidden1': args.hidden1, 65 | 'hidden2': args.hidden2, 66 | # 'init_w': args.init_w 67 | } 68 | self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg) 69 | self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) 70 | self.actor_optim = Adam(self.actor.parameters(), lr=args.lr_a) 71 | 72 | self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) 73 | self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg) 74 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr_c) 75 | 76 | self.hard_update(self.actor_target, self.actor) # Make sure target is with the same weight 77 | self.hard_update(self.critic_target, self.critic) 78 | 79 | # Create replay buffer 80 | self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) 81 | # self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, 82 | # sigma=args.ou_sigma) 83 | 84 | # Hyper-parameters 85 | self.batch_size = args.bsize 86 | self.tau = args.tau 87 | self.discount = args.discount 88 | self.depsilon = 1.0 / args.epsilon 89 | self.lbound = 0. # args.lbound 90 | self.rbound = 1. 
# args.rbound 91 | 92 | # noise 93 | self.init_delta = args.init_delta 94 | self.delta_decay = args.delta_decay 95 | self.warmup = args.warmup 96 | 97 | # 98 | self.epsilon = 1.0 99 | # self.s_t = None # Most recent state 100 | # self.a_t = None # Most recent action 101 | self.is_training = True 102 | 103 | # 104 | if USE_CUDA: self.cuda() 105 | 106 | # moving average baseline 107 | self.moving_average = None 108 | self.moving_alpha = 0.5 # based on batch, so small 109 | 110 | def update_policy(self): 111 | # Sample batch 112 | state_batch, action_batch, reward_batch, \ 113 | next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size) 114 | 115 | # normalize the reward 116 | batch_mean_reward = np.mean(reward_batch) 117 | if self.moving_average is None: 118 | self.moving_average = batch_mean_reward 119 | else: 120 | self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average) 121 | reward_batch -= self.moving_average 122 | # if reward_batch.std() > 0: 123 | # reward_batch /= reward_batch.std() 124 | 125 | # Prepare for the target q batch 126 | with torch.no_grad(): 127 | next_q_values = self.critic_target([ 128 | to_tensor(next_state_batch), 129 | self.actor_target(to_tensor(next_state_batch)), 130 | ]) 131 | 132 | target_q_batch = to_tensor(reward_batch) + \ 133 | self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values 134 | 135 | # Critic update 136 | self.critic.zero_grad() 137 | 138 | q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)]) 139 | 140 | value_loss = criterion(q_batch, target_q_batch) 141 | value_loss.backward() 142 | self.critic_optim.step() 143 | 144 | # Actor update 145 | self.actor.zero_grad() 146 | 147 | policy_loss = -self.critic([ 148 | to_tensor(state_batch), 149 | self.actor(to_tensor(state_batch)) 150 | ]) 151 | 152 | policy_loss = policy_loss.mean() 153 | policy_loss.backward() 154 | self.actor_optim.step() 155 | 156 | # Target update 157 | self.soft_update(self.actor_target, self.actor) 158 | self.soft_update(self.critic_target, self.critic) 159 | 160 | def eval(self): 161 | self.actor.eval() 162 | self.actor_target.eval() 163 | self.critic.eval() 164 | self.critic_target.eval() 165 | 166 | def cuda(self): 167 | self.actor.cuda() 168 | self.actor_target.cuda() 169 | self.critic.cuda() 170 | self.critic_target.cuda() 171 | 172 | def observe(self, r_t, s_t, s_t1, a_t, done): 173 | if self.is_training: 174 | self.memory.append(s_t, a_t, r_t, done) # save to memory 175 | # self.s_t = s_t1 176 | 177 | def random_action(self): 178 | action = np.random.uniform(self.lbound, self.rbound, self.nb_actions) 179 | # self.a_t = action 180 | return action 181 | 182 | def select_action(self, s_t, episode): 183 | # assert episode >= self.warmup, 'Episode: {} warmup: {}'.format(episode, self.warmup) 184 | action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0) 185 | delta = self.init_delta * (self.delta_decay ** (episode - self.warmup)) 186 | # action += self.is_training * max(self.epsilon, 0) * self.random_process.sample() 187 | action = self.sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta) 188 | action = np.clip(action, self.lbound, self.rbound) 189 | 190 | # self.a_t = action 191 | return action 192 | 193 | def reset(self, obs): 194 | pass 195 | # self.s_t = obs 196 | # self.random_process.reset_states() 197 | 198 | def load_weights(self, output): 199 | if output is None: return 200 | 201 | 
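# restore actor/critic weights previously written by save_model()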
self.actor.load_state_dict( 202 | torch.load('{}/actor.pkl'.format(output)) 203 | ) 204 | 205 | self.critic.load_state_dict( 206 | torch.load('{}/critic.pkl'.format(output)) 207 | ) 208 | 209 | def save_model(self, output): 210 | torch.save( 211 | self.actor.state_dict(), 212 | '{}/actor.pkl'.format(output) 213 | ) 214 | torch.save( 215 | self.critic.state_dict(), 216 | '{}/critic.pkl'.format(output) 217 | ) 218 | 219 | def soft_update(self, target, source): 220 | for target_param, param in zip(target.parameters(), source.parameters()): 221 | target_param.data.copy_( 222 | target_param.data * (1.0 - self.tau) + param.data * self.tau 223 | ) 224 | 225 | def hard_update(self, target, source): 226 | for target_param, param in zip(target.parameters(), source.parameters()): 227 | target_param.data.copy_(param.data) 228 | 229 | def sample_from_truncated_normal_distribution(self, lower, upper, mu, sigma, size=1): 230 | from scipy import stats 231 | return stats.truncnorm.rvs((lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma, size=size) 232 | 233 | 234 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import os 6 | import torch 7 | import time 8 | import sys 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | if self.count > 0: 27 | self.avg = self.sum / self.count 28 | 29 | def accumulate(self, val, n=1): 30 | self.sum += val 31 | self.count += n 32 | if self.count > 0: 33 | self.avg = self.sum / self.count 34 | 35 | 36 | class TextLogger(object): 37 | """Write log immediately to the disk""" 38 | def __init__(self, filepath): 39 | self.f = open(filepath, 'w') 40 | self.fid = self.f.fileno() 41 | self.filepath = filepath 42 | 43 | def close(self): 44 | self.f.close() 45 | 46 | def write(self, content): 47 | self.f.write(content) 48 | self.f.flush() 49 | os.fsync(self.fid) 50 | 51 | def write_buf(self, content): 52 | self.f.write(content) 53 | 54 | def print_and_write(self, content): 55 | print(content) 56 | self.write(content+'\n') 57 | 58 | 59 | def accuracy(output, target, topk=(1,)): 60 | """Computes the precision@k for the specified values of k""" 61 | batch_size = target.size(0) 62 | num = output.size(1) 63 | target_topk = [] 64 | appendices = [] 65 | for k in topk: 66 | if k <= num: 67 | target_topk.append(k) 68 | else: 69 | appendices.append([0.0]) 70 | topk = target_topk 71 | maxk = max(topk) 72 | _, pred = output.topk(maxk, 1, True, True) 73 | pred = pred.t() 74 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 75 | 76 | res = [] 77 | for k in topk: 78 | correct_k = correct[:k].view(-1).float().sum(0) 79 | res.append(correct_k.mul_(100.0 / batch_size)) 80 | return res + appendices 81 | 82 | 83 | def to_numpy(var): 84 | use_cuda = torch.cuda.is_available() 85 | return var.cpu().data.numpy() if use_cuda else var.data.numpy() 86 | 87 | 88 | def to_tensor(ndarray, requires_grad=False): # return a float tensor by default 89 | tensor = 
torch.from_numpy(ndarray).float() # by default does not require grad 90 | if requires_grad: 91 | tensor.requires_grad_() 92 | return tensor.cuda() if torch.cuda.is_available() else tensor 93 | 94 | 95 | def measure_layer_for_pruning(layer, x): 96 | def get_layer_type(layer): 97 | layer_str = str(layer) 98 | return layer_str[:layer_str.find('(')].strip() 99 | 100 | def get_layer_param(model): 101 | import operator 102 | import functools 103 | 104 | return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 105 | 106 | multi_add = 1 107 | type_name = get_layer_type(layer) 108 | 109 | # ops_conv 110 | if type_name in ['Conv2d']: 111 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 112 | layer.stride[0] + 1) 113 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 114 | layer.stride[1] + 1) 115 | layer.flops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 116 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 117 | layer.params = get_layer_param(layer) 118 | # ops_linear 119 | elif type_name in ['Linear']: 120 | weight_ops = layer.weight.numel() * multi_add 121 | bias_ops = layer.bias.numel() 122 | layer.flops = weight_ops + bias_ops 123 | layer.params = get_layer_param(layer) 124 | return 125 | 126 | 127 | def least_square_sklearn(X, Y): 128 | from sklearn.linear_model import LinearRegression 129 | reg = LinearRegression(fit_intercept=False) 130 | reg.fit(X, Y) 131 | return reg.coef_ 132 | 133 | 134 | def get_output_folder(parent_dir, env_name): 135 | """Return save folder. 136 | Assumes folders in the parent_dir have suffix -run{run 137 | number}. Finds the highest run number and sets the output folder 138 | to that number + 1. This is just convenient so that if you run the 139 | same script multiple times tensorboard can plot all of the results 140 | on the same plots with different names. 141 | Parameters 142 | ---------- 143 | parent_dir: str 144 | Path of the directory containing all experiment runs. 145 | Returns 146 | ------- 147 | parent_dir/run_dir 148 | Path to this run's save directory. 149 | """ 150 | os.makedirs(parent_dir, exist_ok=True) 151 | experiment_id = 0 152 | for folder_name in os.listdir(parent_dir): 153 | if not os.path.isdir(os.path.join(parent_dir, folder_name)): 154 | continue 155 | try: 156 | folder_name = int(folder_name.split('-run')[-1]) 157 | if folder_name > experiment_id: 158 | experiment_id = folder_name 159 | except: 160 | pass 161 | experiment_id += 1 162 | 163 | parent_dir = os.path.join(parent_dir, env_name) 164 | parent_dir = parent_dir + '-run{}'.format(experiment_id) 165 | os.makedirs(parent_dir, exist_ok=True) 166 | return parent_dir 167 | 168 | 169 | 170 | # Custom progress bar 171 | _, term_width = os.popen('stty size', 'r').read().split() 172 | term_width = int(term_width) 173 | TOTAL_BAR_LENGTH = 40. 
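# module-level timing state shared across progress_bar() calls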
174 | last_time = time.time() 175 | begin_time = last_time 176 | 177 | 178 | def progress_bar(current, total, msg=None): 179 | def format_time(seconds): 180 | days = int(seconds / 3600 / 24) 181 | seconds = seconds - days * 3600 * 24 182 | hours = int(seconds / 3600) 183 | seconds = seconds - hours * 3600 184 | minutes = int(seconds / 60) 185 | seconds = seconds - minutes * 60 186 | secondsf = int(seconds) 187 | seconds = seconds - secondsf 188 | millis = int(seconds * 1000) 189 | 190 | f = '' 191 | i = 1 192 | if days > 0: 193 | f += str(days) + 'D' 194 | i += 1 195 | if hours > 0 and i <= 2: 196 | f += str(hours) + 'h' 197 | i += 1 198 | if minutes > 0 and i <= 2: 199 | f += str(minutes) + 'm' 200 | i += 1 201 | if secondsf > 0 and i <= 2: 202 | f += str(secondsf) + 's' 203 | i += 1 204 | if millis > 0 and i <= 2: 205 | f += str(millis) + 'ms' 206 | i += 1 207 | if f == '': 208 | f = '0ms' 209 | return f 210 | 211 | global last_time, begin_time 212 | if current == 0: 213 | begin_time = time.time() # Reset for new bar. 214 | 215 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 216 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 217 | 218 | sys.stdout.write(' [') 219 | for i in range(cur_len): 220 | sys.stdout.write('=') 221 | sys.stdout.write('>') 222 | for i in range(rest_len): 223 | sys.stdout.write('.') 224 | sys.stdout.write(']') 225 | 226 | cur_time = time.time() 227 | step_time = cur_time - last_time 228 | last_time = cur_time 229 | tot_time = cur_time - begin_time 230 | 231 | L = [] 232 | L.append(' Step: %s' % format_time(step_time)) 233 | L.append(' | Tot: %s' % format_time(tot_time)) 234 | if msg: 235 | L.append(' | ' + msg) 236 | 237 | msg = ''.join(L) 238 | sys.stdout.write(msg) 239 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 240 | sys.stdout.write(' ') 241 | 242 | # Go back to the center of the bar. 
243 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 244 | sys.stdout.write('\b') 245 | sys.stdout.write(' %d/%d ' % (current+1, total)) 246 | 247 | if current < total-1: 248 | sys.stdout.write('\r') 249 | else: 250 | sys.stdout.write('\n') 251 | sys.stdout.flush() 252 | 253 | # logging 254 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) 255 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) 256 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) 257 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) 258 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) 259 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) 260 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) 261 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) -------------------------------------------------------------------------------- /amc_fine_tune.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import os 6 | import time 7 | import argparse 8 | import shutil 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import numpy as np 15 | 16 | from tensorboardX import SummaryWriter 17 | 18 | from lib.utils import accuracy, AverageMeter, progress_bar, get_output_folder 19 | from lib.data import get_dataset 20 | from lib.net_measure import measure_model 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description='AMC fine-tune script') 25 | parser.add_argument('--model', default='mobilenet', type=str, help='name of the model to train') 26 | parser.add_argument('--dataset', default='imagenet', type=str, help='name of the dataset to train') 27 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 28 | parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use') 29 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 30 | parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker') 31 | parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)') 32 | parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train') 33 | parser.add_argument('--wd', default=4e-5, type=float, help='weight decay') 34 | parser.add_argument('--seed', default=None, type=int, help='random seed to set') 35 | parser.add_argument('--data_root', default=None, type=str, help='dataset path') 36 | # resume 37 | parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to resume from') 38 | # run eval 39 | parser.add_argument('--eval', action='store_true', help='Simply run eval') 40 | 41 | return parser.parse_args() 42 | 43 | 44 | def get_model(): 45 | print('=> Building model..') 46 | if args.model == 'mobilenet': 47 | from models.mobilenet import MobileNet 48 | net = MobileNet(n_class=1000) 49 | elif args.model == 'mobilenet_0.5flops': 50 | from models.mobilenet import MobileNet 51 | net = MobileNet(n_class=1000, profile='0.5flops') 52 | else: 53 | raise NotImplementedError 54 | return net.cuda() if use_cuda else net 55 | 56 | 57 | def train(epoch, train_loader): 58 | print('\nEpoch: %d' % epoch) 59 | net.train() 60 | 61 | batch_time = AverageMeter() 62 | losses = 
AverageMeter() 63 | top1 = AverageMeter() 64 | top5 = AverageMeter() 65 | end = time.time() 66 | 67 | for batch_idx, (inputs, targets) in enumerate(train_loader): 68 | if use_cuda: 69 | inputs, targets = inputs.cuda(), targets.cuda() 70 | optimizer.zero_grad() 71 | outputs = net(inputs) 72 | loss = criterion(outputs, targets) 73 | 74 | loss.backward() 75 | optimizer.step() 76 | 77 | # measure accuracy and record loss 78 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 79 | losses.update(loss.item(), inputs.size(0)) 80 | top1.update(prec1.item(), inputs.size(0)) 81 | top5.update(prec5.item(), inputs.size(0)) 82 | # timing 83 | batch_time.update(time.time() - end) 84 | end = time.time() 85 | 86 | progress_bar(batch_idx, len(train_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%' 87 | .format(losses.avg, top1.avg, top5.avg)) 88 | writer.add_scalar('loss/train', losses.avg, epoch) 89 | writer.add_scalar('acc/train_top1', top1.avg, epoch) 90 | writer.add_scalar('acc/train_top5', top5.avg, epoch) 91 | 92 | 93 | def test(epoch, test_loader, save=True): 94 | global best_acc 95 | net.eval() 96 | 97 | batch_time = AverageMeter() 98 | losses = AverageMeter() 99 | top1 = AverageMeter() 100 | top5 = AverageMeter() 101 | end = time.time() 102 | 103 | with torch.no_grad(): 104 | for batch_idx, (inputs, targets) in enumerate(test_loader): 105 | if use_cuda: 106 | inputs, targets = inputs.cuda(), targets.cuda() 107 | outputs = net(inputs) 108 | loss = criterion(outputs, targets) 109 | 110 | # measure accuracy and record loss 111 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 112 | losses.update(loss.item(), inputs.size(0)) 113 | top1.update(prec1.item(), inputs.size(0)) 114 | top5.update(prec5.item(), inputs.size(0)) 115 | # timing 116 | batch_time.update(time.time() - end) 117 | end = time.time() 118 | 119 | progress_bar(batch_idx, len(test_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%' 120 | .format(losses.avg, top1.avg, top5.avg)) 121 | 122 | if save: 123 | writer.add_scalar('loss/test', losses.avg, epoch) 124 | writer.add_scalar('acc/test_top1', top1.avg, epoch) 125 | writer.add_scalar('acc/test_top5', top5.avg, epoch) 126 | 127 | is_best = False 128 | if top1.avg > best_acc: 129 | best_acc = top1.avg 130 | is_best = True 131 | 132 | print('Current best acc: {}'.format(best_acc)) 133 | save_checkpoint({ 134 | 'epoch': epoch, 135 | 'model': args.model, 136 | 'dataset': args.dataset, 137 | 'state_dict': net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict(), 138 | 'acc': top1.avg, 139 | 'optimizer': optimizer.state_dict(), 140 | }, is_best, checkpoint_dir=log_dir) 141 | 142 | 143 | def adjust_learning_rate(optimizer, epoch): 144 | if args.lr_type == 'cos': # cos without warm-up 145 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.n_epoch)) 146 | elif args.lr_type == 'exp': 147 | step = 1 148 | decay = 0.96 149 | lr = args.lr * (decay ** (epoch // step)) 150 | elif args.lr_type == 'fixed': 151 | lr = args.lr 152 | else: 153 | raise NotImplementedError 154 | print('=> lr: {}'.format(lr)) 155 | for param_group in optimizer.param_groups: 156 | param_group['lr'] = lr 157 | return lr 158 | 159 | 160 | def save_checkpoint(state, is_best, checkpoint_dir='.'): 161 | filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar') 162 | print('=> Saving checkpoint to {}'.format(filename)) 163 | torch.save(state, filename) 164 | if is_best: 165 | shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar')) 166 | 167 | 168 
| if __name__ == '__main__': 169 | args = parse_args() 170 | 171 | use_cuda = torch.cuda.is_available() 172 | if use_cuda: 173 | torch.backends.cudnn.benchmark = True 174 | 175 | best_acc = 0 # best test accuracy 176 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 177 | 178 | if args.seed is not None: 179 | np.random.seed(args.seed) 180 | torch.manual_seed(args.seed) 181 | torch.cuda.manual_seed(args.seed) 182 | 183 | print('=> Preparing data..') 184 | train_loader, val_loader, n_class = get_dataset(args.dataset, args.batch_size, args.n_worker, 185 | data_root=args.data_root) 186 | 187 | net = get_model() # for measure 188 | IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32 189 | n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE) 190 | print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6)) 191 | 192 | del net 193 | net = get_model() # real training 194 | 195 | if args.ckpt_path is not None: # assigned checkpoint path to resume from 196 | print('=> Resuming from checkpoint..') 197 | checkpoint = torch.load(args.ckpt_path) 198 | sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint 199 | net.load_state_dict(sd) 200 | if use_cuda and args.n_gpu > 1: 201 | net = torch.nn.DataParallel(net, list(range(args.n_gpu))) 202 | 203 | criterion = nn.CrossEntropyLoss() 204 | print('Using SGD...') 205 | print('weight decay = {}'.format(args.wd)) 206 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd) 207 | 208 | if args.eval: # just run eval 209 | print('=> Start evaluation...') 210 | test(0, val_loader, save=False) 211 | else: # train 212 | print('=> Start training...') 213 | print('Training {} on {}...'.format(args.model, args.dataset)) 214 | log_dir = get_output_folder('./logs', '{}_{}_finetune'.format(args.model, args.dataset)) 215 | print('=> Saving logs to {}'.format(log_dir)) 216 | # tf writer 217 | writer = SummaryWriter(logdir=log_dir) 218 | 219 | for epoch in range(start_epoch, start_epoch + args.n_epoch): 220 | lr = adjust_learning_rate(optimizer, epoch) 221 | train(epoch, train_loader) 222 | test(epoch, val_loader) 223 | 224 | writer.close() 225 | print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M, best top-1 acc: {}%'.format(n_params / 1e6, n_flops / 1e6, best_acc)) 226 | -------------------------------------------------------------------------------- /lib/memory.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | from __future__ import absolute_import 6 | from collections import deque, namedtuple 7 | import warnings 8 | import random 9 | 10 | import numpy as np 11 | 12 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py 13 | 14 | # This is to be understood as a transition: Given `state0`, performing `action` 15 | # yields `reward` and results in `state1`, which might be `terminal`. 16 | Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1') 17 | 18 | 19 | def sample_batch_indexes(low, high, size): 20 | if high - low >= size: 21 | # We have enough data. Draw without replacement, that is each index is unique in the 22 | # batch. We cannot use `np.random.choice` here because it is horribly inefficient as 23 | # the memory grows. 
See https://github.com/numpy/numpy/issues/2764 for a discussion. 24 | # `random.sample` does the same thing (drawing without replacement) and is way faster. 25 | r = range(low, high) 26 | batch_idxs = random.sample(r, size) 27 | else: 28 | # Not enough data. Help ourselves with sampling from the range, but the same index 29 | # can occur multiple times. This is not good and should be avoided by picking a 30 | # large enough warm-up phase. 31 | warnings.warn( 32 | 'Not enough entries to sample without replacement. ' 33 | 'Consider increasing your warm-up phase to avoid oversampling!') 34 | batch_idxs = np.random.random_integers(low, high - 1, size=size) 35 | assert len(batch_idxs) == size 36 | return batch_idxs 37 | 38 | 39 | class RingBuffer(object): 40 | def __init__(self, maxlen): 41 | self.maxlen = maxlen 42 | self.start = 0 43 | self.length = 0 44 | self.data = [None for _ in range(maxlen)] 45 | 46 | def __len__(self): 47 | return self.length 48 | 49 | def __getitem__(self, idx): 50 | if idx < 0 or idx >= self.length: 51 | raise KeyError() 52 | return self.data[(self.start + idx) % self.maxlen] 53 | 54 | def append(self, v): 55 | if self.length < self.maxlen: 56 | # We have space, simply increase the length. 57 | self.length += 1 58 | elif self.length == self.maxlen: 59 | # No space, "remove" the first item. 60 | self.start = (self.start + 1) % self.maxlen 61 | else: 62 | # This should never happen. 63 | raise RuntimeError() 64 | self.data[(self.start + self.length - 1) % self.maxlen] = v 65 | 66 | 67 | def zeroed_observation(observation): 68 | if hasattr(observation, 'shape'): 69 | return np.zeros(observation.shape) 70 | elif hasattr(observation, '__iter__'): 71 | out = [] 72 | for x in observation: 73 | out.append(zeroed_observation(x)) 74 | return out 75 | else: 76 | return 0. 77 | 78 | 79 | class Memory(object): 80 | def __init__(self, window_length, ignore_episode_boundaries=False): 81 | self.window_length = window_length 82 | self.ignore_episode_boundaries = ignore_episode_boundaries 83 | 84 | self.recent_observations = deque(maxlen=window_length) 85 | self.recent_terminals = deque(maxlen=window_length) 86 | 87 | def sample(self, batch_size, batch_idxs=None): 88 | raise NotImplementedError() 89 | 90 | def append(self, observation, action, reward, terminal, training=True): 91 | self.recent_observations.append(observation) 92 | self.recent_terminals.append(terminal) 93 | 94 | def get_recent_state(self, current_observation): 95 | # This code is slightly complicated by the fact that subsequent observations might be 96 | # from different episodes. We ensure that an experience never spans multiple episodes. 97 | # This is probably not that important in practice but it seems cleaner. 98 | state = [current_observation] 99 | idx = len(self.recent_observations) - 1 100 | for offset in range(0, self.window_length - 1): 101 | current_idx = idx - offset 102 | current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False 103 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 104 | # The previously handled observation was terminal, don't add the current one. 105 | # Otherwise we would leak into a different episode. 
106 | break 107 | state.insert(0, self.recent_observations[current_idx]) 108 | while len(state) < self.window_length: 109 | state.insert(0, zeroed_observation(state[0])) 110 | return state 111 | 112 | def get_config(self): 113 | config = { 114 | 'window_length': self.window_length, 115 | 'ignore_episode_boundaries': self.ignore_episode_boundaries, 116 | } 117 | return config 118 | 119 | 120 | class SequentialMemory(Memory): 121 | def __init__(self, limit, **kwargs): 122 | super(SequentialMemory, self).__init__(**kwargs) 123 | 124 | self.limit = limit 125 | 126 | # Do not use deque to implement the memory. This data structure may seem convenient but 127 | # it is way too slow on random access. Instead, we use our own ring buffer implementation. 128 | self.actions = RingBuffer(limit) 129 | self.rewards = RingBuffer(limit) 130 | self.terminals = RingBuffer(limit) 131 | self.observations = RingBuffer(limit) 132 | 133 | def sample(self, batch_size, batch_idxs=None): 134 | if batch_idxs is None: 135 | # Draw random indexes such that we have at least a single entry before each 136 | # index. 137 | batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size) 138 | batch_idxs = np.array(batch_idxs) + 1 139 | assert np.min(batch_idxs) >= 1 140 | assert np.max(batch_idxs) < self.nb_entries 141 | assert len(batch_idxs) == batch_size 142 | 143 | # Create experiences 144 | experiences = [] 145 | for idx in batch_idxs: 146 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 147 | while terminal0: 148 | # Skip this transition because the environment was reset here. Select a new, random 149 | # transition and use this instead. This may cause the batch to contain the same 150 | # transition twice. 151 | idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] 152 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 153 | assert 1 <= idx < self.nb_entries 154 | 155 | # This code is slightly complicated by the fact that subsequent observations might be 156 | # from different episodes. We ensure that an experience never spans multiple episodes. 157 | # This is probably not that important in practice but it seems cleaner. 158 | state0 = [self.observations[idx - 1]] 159 | for offset in range(0, self.window_length - 1): 160 | current_idx = idx - 2 - offset 161 | current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False 162 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 163 | # The previously handled observation was terminal, don't add the current one. 164 | # Otherwise we would leak into a different episode. 165 | break 166 | state0.insert(0, self.observations[current_idx]) 167 | while len(state0) < self.window_length: 168 | state0.insert(0, zeroed_observation(state0[0])) 169 | action = self.actions[idx - 1] 170 | reward = self.rewards[idx - 1] 171 | terminal1 = self.terminals[idx - 1] 172 | 173 | # Okay, now we need to create the follow-up state. This is state0 shifted on timestep 174 | # to the right. Again, we need to be careful to not include an observation from the next 175 | # episode if the last state is terminal. 
176 | state1 = [np.copy(x) for x in state0[1:]] 177 | state1.append(self.observations[idx]) 178 | 179 | assert len(state0) == self.window_length 180 | assert len(state1) == len(state0) 181 | experiences.append(Experience(state0=state0, action=action, reward=reward, 182 | state1=state1, terminal1=terminal1)) 183 | assert len(experiences) == batch_size 184 | return experiences 185 | 186 | def sample_and_split(self, batch_size, batch_idxs=None): 187 | experiences = self.sample(batch_size, batch_idxs) 188 | 189 | state0_batch = [] 190 | reward_batch = [] 191 | action_batch = [] 192 | terminal1_batch = [] 193 | state1_batch = [] 194 | for e in experiences: 195 | state0_batch.append(e.state0) 196 | state1_batch.append(e.state1) 197 | reward_batch.append(e.reward) 198 | action_batch.append(e.action) 199 | terminal1_batch.append(0. if e.terminal1 else 1.) 200 | 201 | # Prepare and validate parameters. 202 | state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1) 203 | state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1) 204 | terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1) 205 | reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1) 206 | action_batch = np.array(action_batch, 'double').reshape(batch_size, -1) 207 | 208 | return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch 209 | 210 | def append(self, observation, action, reward, terminal, training=True): 211 | super(SequentialMemory, self).append(observation, action, reward, terminal, training=training) 212 | 213 | # This needs to be understood as follows: in `observation`, take `action`, obtain `reward` 214 | # and weather the next state is `terminal` or not. 215 | if training: 216 | self.observations.append(observation) 217 | self.actions.append(action) 218 | self.rewards.append(reward) 219 | self.terminals.append(terminal) 220 | 221 | @property 222 | def nb_entries(self): 223 | return len(self.observations) 224 | 225 | def get_config(self): 226 | config = super(SequentialMemory, self).get_config() 227 | config['limit'] = self.limit 228 | return config 229 | -------------------------------------------------------------------------------- /amc_search.py: -------------------------------------------------------------------------------- 1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices" 2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han 3 | # {jilin, songhan}@mit.edu 4 | 5 | import os 6 | import numpy as np 7 | import argparse 8 | from copy import deepcopy 9 | import torch 10 | torch.backends.cudnn.deterministic = True 11 | 12 | from env.channel_pruning_env import ChannelPruningEnv 13 | from lib.agent import DDPG 14 | from lib.utils import get_output_folder 15 | 16 | from tensorboardX import SummaryWriter 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='AMC search script') 21 | 22 | parser.add_argument('--job', default='train', type=str, help='support option: train/export') 23 | parser.add_argument('--suffix', default=None, type=str, help='suffix to help you remember what experiment you ran') 24 | # env 25 | parser.add_argument('--model', default='mobilenet', type=str, help='model to prune') 26 | parser.add_argument('--dataset', default='imagenet', type=str, help='dataset to use (cifar/imagenet)') 27 | parser.add_argument('--data_root', default=None, type=str, help='dataset path') 28 | parser.add_argument('--preserve_ratio', default=0.5, type=float, 
help='preserve ratio of the model')
29 |     parser.add_argument('--lbound', default=0.2, type=float, help='minimum preserve ratio')
30 |     parser.add_argument('--rbound', default=1., type=float, help='maximum preserve ratio')
31 |     parser.add_argument('--reward', default='acc_reward', type=str, help='reward function to use')
32 |     parser.add_argument('--acc_metric', default='acc5', type=str, help='use acc1 or acc5')
33 |     parser.add_argument('--use_real_val', dest='use_real_val', action='store_true')
34 |     parser.add_argument('--ckpt_path', default=None, type=str, help='manual path of checkpoint')
35 |     # parser.add_argument('--pruning_method', default='cp', type=str,
36 |     #                     help='method to prune (fg/cp for fine-grained and channel pruning)')
37 |     # only for channel pruning
38 |     parser.add_argument('--n_calibration_batches', default=60, type=int,
39 |                         help='number of batches used to collect layer statistics for reconstruction')
40 |     parser.add_argument('--n_points_per_layer', default=10, type=int,
41 |                         help='number of feature map points sampled per layer for reconstruction')
42 |     parser.add_argument('--channel_round', default=8, type=int, help='round channels to a multiple of channel_round')
43 |     # ddpg
44 |     parser.add_argument('--hidden1', default=300, type=int, help='number of hidden units in the first fully connected layer')
45 |     parser.add_argument('--hidden2', default=300, type=int, help='number of hidden units in the second fully connected layer')
46 |     parser.add_argument('--lr_c', default=1e-3, type=float, help='learning rate for critic')
47 |     parser.add_argument('--lr_a', default=1e-4, type=float, help='learning rate for actor')
48 |     parser.add_argument('--warmup', default=100, type=int,
49 |                         help='time without training but only filling the replay memory')
50 |     parser.add_argument('--discount', default=1., type=float, help='')
51 |     parser.add_argument('--bsize', default=64, type=int, help='minibatch size')
52 |     parser.add_argument('--rmsize', default=100, type=int, help='memory size for each layer')
53 |     parser.add_argument('--window_length', default=1, type=int, help='')
54 |     parser.add_argument('--tau', default=0.01, type=float, help='moving average for target network')
55 |     # noise (truncated normal distribution)
56 |     parser.add_argument('--init_delta', default=0.5, type=float,
57 |                         help='initial variance of truncated normal distribution')
58 |     parser.add_argument('--delta_decay', default=0.95, type=float,
59 |                         help='delta decay during exploration')
60 |     # training
61 |     parser.add_argument('--max_episode_length', default=1e9, type=int, help='')
62 |     parser.add_argument('--output', default='./logs', type=str, help='')
63 |     parser.add_argument('--debug', dest='debug', action='store_true')
64 |     parser.add_argument('--init_w', default=0.003, type=float, help='')
65 |     parser.add_argument('--train_episode', default=800, type=int, help='number of search episodes to train the agent')
66 |     parser.add_argument('--epsilon', default=50000, type=int, help='linear decay of exploration policy')
67 |     parser.add_argument('--seed', default=None, type=int, help='random seed to set')
68 |     parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
69 |     parser.add_argument('--n_worker', default=16, type=int, help='number of data loader workers')
70 |     parser.add_argument('--data_bsize', default=50, type=int, help='batch size for the pruning environment data loader')
71 |     parser.add_argument('--resume', default='default', type=str, help='Resuming model path for testing')
72 |     # export
73 |     parser.add_argument('--ratios', default=None, type=str, help='ratios for pruning')
74 |     parser.add_argument('--channels', default=None,
type=str, help='channels after pruning') 75 | parser.add_argument('--export_path', default=None, type=str, help='path for exporting models') 76 | parser.add_argument('--use_new_input', dest='use_new_input', action='store_true', help='use new input feature') 77 | 78 | return parser.parse_args() 79 | 80 | 81 | def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1): 82 | if model == 'mobilenet' and dataset == 'imagenet': 83 | from models.mobilenet import MobileNet 84 | net = MobileNet(n_class=1000) 85 | elif model == 'mobilenetv2' and dataset == 'imagenet': 86 | from models.mobilenet_v2 import MobileNetV2 87 | net = MobileNetV2(n_class=1000) 88 | else: 89 | raise NotImplementedError 90 | sd = torch.load(checkpoint_path) 91 | if 'state_dict' in sd: # a checkpoint but not a state_dict 92 | sd = sd['state_dict'] 93 | sd = {k.replace('module.', ''): v for k, v in sd.items()} 94 | net.load_state_dict(sd) 95 | net = net.cuda() 96 | if n_gpu > 1: 97 | net = torch.nn.DataParallel(net, range(n_gpu)) 98 | 99 | return net, deepcopy(net.state_dict()) 100 | 101 | 102 | def train(num_episode, agent, env, output): 103 | agent.is_training = True 104 | step = episode = episode_steps = 0 105 | episode_reward = 0. 106 | observation = None 107 | T = [] # trajectory 108 | while episode < num_episode: # counting based on episode 109 | # reset if it is the start of episode 110 | if observation is None: 111 | observation = deepcopy(env.reset()) 112 | agent.reset(observation) 113 | 114 | # agent pick action ... 115 | if episode <= args.warmup: 116 | action = agent.random_action() 117 | # action = sample_from_truncated_normal_distribution(lower=0., upper=1., mu=env.preserve_ratio, sigma=0.5) 118 | else: 119 | action = agent.select_action(observation, episode=episode) 120 | 121 | # env response with next_observation, reward, terminate_info 122 | observation2, reward, done, info = env.step(action) 123 | observation2 = deepcopy(observation2) 124 | 125 | T.append([reward, deepcopy(observation), deepcopy(observation2), action, done]) 126 | 127 | # fix-length, never reach here 128 | # if max_episode_length and episode_steps >= max_episode_length - 1: 129 | # done = True 130 | 131 | # [optional] save intermideate model 132 | if episode % int(num_episode / 3) == 0: 133 | agent.save_model(output) 134 | 135 | # update 136 | step += 1 137 | episode_steps += 1 138 | episode_reward += reward 139 | observation = deepcopy(observation2) 140 | 141 | if done: # end of episode 142 | print('#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(episode, episode_reward, 143 | info['accuracy'], 144 | info['compress_ratio'])) 145 | text_writer.write( 146 | '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(episode, episode_reward, 147 | info['accuracy'], 148 | info['compress_ratio'])) 149 | final_reward = T[-1][0] 150 | # print('final_reward: {}'.format(final_reward)) 151 | # agent observe and update policy 152 | for r_t, s_t, s_t1, a_t, done in T: 153 | agent.observe(final_reward, s_t, s_t1, a_t, done) 154 | if episode > args.warmup: 155 | agent.update_policy() 156 | 157 | #agent.memory.append( 158 | # observation, 159 | # agent.select_action(observation, episode=episode), 160 | # 0., False 161 | #) 162 | 163 | # reset 164 | observation = None 165 | episode_steps = 0 166 | episode_reward = 0. 
167 | episode += 1 168 | T = [] 169 | 170 | tfwriter.add_scalar('reward/last', final_reward, episode) 171 | tfwriter.add_scalar('reward/best', env.best_reward, episode) 172 | tfwriter.add_scalar('info/accuracy', info['accuracy'], episode) 173 | tfwriter.add_scalar('info/compress_ratio', info['compress_ratio'], episode) 174 | tfwriter.add_text('info/best_policy', str(env.best_strategy), episode) 175 | # record the preserve rate for each layer 176 | for i, preserve_rate in enumerate(env.strategy): 177 | tfwriter.add_scalar('preserve_rate/{}'.format(i), preserve_rate, episode) 178 | 179 | text_writer.write('best reward: {}\n'.format(env.best_reward)) 180 | text_writer.write('best policy: {}\n'.format(env.best_strategy)) 181 | text_writer.close() 182 | 183 | 184 | def export_model(env, args): 185 | assert args.ratios is not None or args.channels is not None, 'Please provide a valid ratio list or pruned channels' 186 | assert args.export_path is not None, 'Please provide a valid export path' 187 | env.set_export_path(args.export_path) 188 | 189 | print('=> Original model channels: {}'.format(env.org_channels)) 190 | if args.ratios: 191 | ratios = args.ratios.split(',') 192 | ratios = [float(r) for r in ratios] 193 | assert len(ratios) == len(env.org_channels) 194 | channels = [int(r * c) for r, c in zip(ratios, env.org_channels)] 195 | else: 196 | channels = args.channels.split(',') 197 | channels = [int(r) for r in channels] 198 | ratios = [c2 / c1 for c2, c1 in zip(channels, env.org_channels)] 199 | print('=> Pruning with ratios: {}'.format(ratios)) 200 | print('=> Channels after pruning: {}'.format(channels)) 201 | 202 | for r in ratios: 203 | env.step(r) 204 | 205 | return 206 | 207 | 208 | if __name__ == "__main__": 209 | args = parse_args() 210 | 211 | if args.seed is not None: 212 | np.random.seed(args.seed) 213 | torch.manual_seed(args.seed) 214 | torch.cuda.manual_seed(args.seed) 215 | 216 | model, checkpoint = get_model_and_checkpoint(args.model, args.dataset, checkpoint_path=args.ckpt_path, 217 | n_gpu=args.n_gpu) 218 | 219 | env = ChannelPruningEnv(model, checkpoint, args.dataset, 220 | preserve_ratio=1. 
                            if args.job == 'export' else args.preserve_ratio,
221 |                             n_data_worker=args.n_worker, batch_size=args.data_bsize,
222 |                             args=args, export_model=args.job == 'export', use_new_input=args.use_new_input)
223 | 
224 |     if args.job == 'train':
225 |         # build folder and logs
226 |         base_folder_name = '{}_{}_r{}_search'.format(args.model, args.dataset, args.preserve_ratio)
227 |         if args.suffix is not None:
228 |             base_folder_name = base_folder_name + '_' + args.suffix
229 |         args.output = get_output_folder(args.output, base_folder_name)
230 |         print('=> Saving logs to {}'.format(args.output))
231 |         tfwriter = SummaryWriter(logdir=args.output)
232 |         text_writer = open(os.path.join(args.output, 'log.txt'), 'w')
233 |         print('=> Output path: {}...'.format(args.output))
234 | 
235 |         nb_states = env.layer_embedding.shape[1]
236 |         nb_actions = 1  # just 1 action here
237 | 
238 |         args.rmsize = args.rmsize * len(env.prunable_idx)  # for each layer
239 |         print('** Actual replay buffer size: {}'.format(args.rmsize))
240 | 
241 |         agent = DDPG(nb_states, nb_actions, args)
242 |         train(args.train_episode, agent, env, args.output)
243 |     elif args.job == 'export':
244 |         export_model(env, args)
245 |     else:
246 |         raise RuntimeError('Undefined job {}'.format(args.job))
247 | 
--------------------------------------------------------------------------------
/env/channel_pruning_env.py:
--------------------------------------------------------------------------------
1 | # Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices"
2 | # Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han
3 | # {jilin, songhan}@mit.edu
4 | 
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | from lib.utils import AverageMeter, accuracy, prGreen
9 | from lib.data import get_split_dataset
10 | from env.rewards import *
11 | import math
12 | 
13 | import numpy as np
14 | import copy
15 | 
16 | 
17 | class ChannelPruningEnv:
18 |     """
19 |     Env for channel pruning search
20 |     """
21 |     def __init__(self, model, checkpoint, data, preserve_ratio, args, n_data_worker=4,
22 |                  batch_size=256, export_model=False, use_new_input=False):
23 |         # default setting
24 |         self.prunable_layer_types = [torch.nn.modules.conv.Conv2d, torch.nn.modules.linear.Linear]
25 | 
26 |         # save options
27 |         self.model = model
28 |         self.checkpoint = checkpoint
29 |         self.n_data_worker = n_data_worker
30 |         self.batch_size = batch_size
31 |         self.data_type = data
32 |         self.preserve_ratio = preserve_ratio
33 | 
34 |         # options from args
35 |         self.args = args
36 |         self.lbound = args.lbound
37 |         self.rbound = args.rbound
38 | 
39 |         self.use_real_val = args.use_real_val
40 | 
41 |         self.n_calibration_batches = args.n_calibration_batches
42 |         self.n_points_per_layer = args.n_points_per_layer
43 |         self.channel_round = args.channel_round
44 |         self.acc_metric = args.acc_metric
45 |         self.data_root = args.data_root
46 | 
47 |         self.export_model = export_model
48 |         self.use_new_input = use_new_input
49 | 
50 |         # sanity check
51 |         assert self.preserve_ratio > self.lbound, 'Error! preserve_ratio must be larger than lbound!'
52 | 53 | # prepare data 54 | self._init_data() 55 | 56 | # build indexs 57 | self._build_index() 58 | self.n_prunable_layer = len(self.prunable_idx) 59 | 60 | # extract information for preparing 61 | self._extract_layer_information() 62 | 63 | # build embedding (static part) 64 | self._build_state_embedding() 65 | 66 | # build reward 67 | self.reset() # restore weight 68 | self.org_acc = self._validate(self.val_loader, self.model) 69 | print('=> original acc: {:.3f}%'.format(self.org_acc)) 70 | self.org_model_size = sum(self.wsize_list) 71 | print('=> original weight size: {:.4f} M param'.format(self.org_model_size * 1. / 1e6)) 72 | self.org_flops = sum(self.flops_list) 73 | print('=> FLOPs:') 74 | print([self.layer_info_dict[idx]['flops']/1e6 for idx in sorted(self.layer_info_dict.keys())]) 75 | print('=> original FLOPs: {:.4f} M'.format(self.org_flops * 1. / 1e6)) 76 | 77 | self.expected_preserve_computation = self.preserve_ratio * self.org_flops 78 | 79 | self.reward = eval(args.reward) 80 | 81 | self.best_reward = -math.inf 82 | self.best_strategy = None 83 | self.best_d_prime_list = None 84 | 85 | self.org_w_size = sum(self.wsize_list) 86 | 87 | def step(self, action): 88 | # Pseudo prune and get the corresponding statistics. The real pruning happens till the end of all pseudo pruning 89 | if self.visited[self.cur_ind]: 90 | action = self.strategy_dict[self.prunable_idx[self.cur_ind]][0] 91 | preserve_idx = self.index_buffer[self.cur_ind] 92 | else: 93 | action = self._action_wall(action) # percentage to preserve 94 | preserve_idx = None 95 | 96 | # prune and update action 97 | action, d_prime, preserve_idx = self.prune_kernel(self.prunable_idx[self.cur_ind], action, preserve_idx) 98 | 99 | if not self.visited[self.cur_ind]: 100 | for group in self.shared_idx: 101 | if self.cur_ind in group: # set the shared ones 102 | for g_idx in group: 103 | self.strategy_dict[self.prunable_idx[g_idx]][0] = action 104 | self.strategy_dict[self.prunable_idx[g_idx - 1]][1] = action 105 | self.visited[g_idx] = True 106 | self.index_buffer[g_idx] = preserve_idx.copy() 107 | 108 | if self.export_model: # export checkpoint 109 | print('# Pruning {}: ratio: {}, d_prime: {}'.format(self.cur_ind, action, d_prime)) 110 | 111 | self.strategy.append(action) # save action to strategy 112 | self.d_prime_list.append(d_prime) 113 | 114 | self.strategy_dict[self.prunable_idx[self.cur_ind]][0] = action 115 | if self.cur_ind > 0: 116 | self.strategy_dict[self.prunable_idx[self.cur_ind - 1]][1] = action 117 | 118 | # all the actions are made 119 | if self._is_final_layer(): 120 | assert len(self.strategy) == len(self.prunable_idx) 121 | current_flops = self._cur_flops() 122 | acc_t1 = time.time() 123 | acc = self._validate(self.val_loader, self.model) 124 | acc_t2 = time.time() 125 | self.val_time = acc_t2 - acc_t1 126 | compress_ratio = current_flops * 1. 
/ self.org_flops 127 | info_set = {'compress_ratio': compress_ratio, 'accuracy': acc, 'strategy': self.strategy.copy()} 128 | reward = self.reward(self, acc, current_flops) 129 | 130 | if reward > self.best_reward: 131 | self.best_reward = reward 132 | self.best_strategy = self.strategy.copy() 133 | self.best_d_prime_list = self.d_prime_list.copy() 134 | prGreen('New best reward: {:.4f}, acc: {:.4f}, compress: {:.4f}'.format(self.best_reward, acc, compress_ratio)) 135 | prGreen('New best policy: {}'.format(self.best_strategy)) 136 | prGreen('New best d primes: {}'.format(self.best_d_prime_list)) 137 | 138 | obs = self.layer_embedding[self.cur_ind, :].copy() # actually the same as the last state 139 | done = True 140 | if self.export_model: # export state dict 141 | torch.save(self.model.state_dict(), self.export_path) 142 | return None, None, None, None 143 | return obs, reward, done, info_set 144 | 145 | info_set = None 146 | reward = 0 147 | done = False 148 | self.visited[self.cur_ind] = True # set to visited 149 | self.cur_ind += 1 # the index of next layer 150 | # build next state (in-place modify) 151 | self.layer_embedding[self.cur_ind][-3] = self._cur_reduced() * 1. / self.org_flops # reduced 152 | self.layer_embedding[self.cur_ind][-2] = sum(self.flops_list[self.cur_ind + 1:]) * 1. / self.org_flops # rest 153 | self.layer_embedding[self.cur_ind][-1] = self.strategy[-1] # last action 154 | obs = self.layer_embedding[self.cur_ind, :].copy() 155 | 156 | return obs, reward, done, info_set 157 | 158 | def reset(self): 159 | # restore env by loading the checkpoint 160 | self.model.load_state_dict(self.checkpoint) 161 | self.cur_ind = 0 162 | self.strategy = [] # pruning strategy 163 | self.d_prime_list = [] 164 | self.strategy_dict = copy.deepcopy(self.min_strategy_dict) 165 | # reset layer embeddings 166 | self.layer_embedding[:, -1] = 1. 167 | self.layer_embedding[:, -2] = 0. 168 | self.layer_embedding[:, -3] = 0. 169 | obs = self.layer_embedding[0].copy() 170 | obs[-2] = sum(self.wsize_list[1:]) * 1. / sum(self.wsize_list) 171 | self.extract_time = 0 172 | self.fit_time = 0 173 | self.val_time = 0 174 | # for share index 175 | self.visited = [False] * len(self.prunable_idx) 176 | self.index_buffer = {} 177 | return obs 178 | 179 | def set_export_path(self, path): 180 | self.export_path = path 181 | 182 | def prune_kernel(self, op_idx, preserve_ratio, preserve_idx=None): 183 | '''Return the real ratio''' 184 | m_list = list(self.model.modules()) 185 | op = m_list[op_idx] 186 | assert (preserve_ratio <= 1.) 187 | 188 | if preserve_ratio == 1: # do not prune 189 | return 1., op.weight.size(1), None # TODO: should be a full index 190 | # n, c, h, w = op.weight.size() 191 | # mask = np.ones([c], dtype=bool) 192 | 193 | def format_rank(x): 194 | rank = int(np.around(x)) 195 | return max(rank, 1) 196 | 197 | n, c = op.weight.size(0), op.weight.size(1) 198 | d_prime = format_rank(c * preserve_ratio) 199 | d_prime = int(np.ceil(d_prime * 1. / self.channel_round) * self.channel_round) 200 | if d_prime > c: 201 | d_prime = int(np.floor(c * 1. 
/ self.channel_round) * self.channel_round) 202 | 203 | extract_t1 = time.time() 204 | if self.use_new_input: # this is slow and may lead to overfitting 205 | self._regenerate_input_feature() 206 | X = self.layer_info_dict[op_idx]['input_feat'] # input after pruning of previous ops 207 | Y = self.layer_info_dict[op_idx]['output_feat'] # fixed output from original model 208 | weight = op.weight.data.cpu().numpy() 209 | # conv [C_out, C_in, ksize, ksize] 210 | # fc [C_out, C_in] 211 | op_type = 'Conv2D' 212 | if len(weight.shape) == 2: 213 | op_type = 'Linear' 214 | weight = weight[:, :, None, None] 215 | extract_t2 = time.time() 216 | self.extract_time += extract_t2 - extract_t1 217 | fit_t1 = time.time() 218 | 219 | if preserve_idx is None: # not provided, generate new 220 | importance = np.abs(weight).sum((0, 2, 3)) 221 | sorted_idx = np.argsort(-importance) # sum magnitude along C_in, sort descend 222 | preserve_idx = sorted_idx[:d_prime] # to preserve index 223 | assert len(preserve_idx) == d_prime 224 | mask = np.zeros(weight.shape[1], bool) 225 | mask[preserve_idx] = True 226 | 227 | # reconstruct, X, Y <= [N, C] 228 | masked_X = X[:, mask] 229 | if weight.shape[2] == 1: # 1x1 conv or fc 230 | from lib.utils import least_square_sklearn 231 | rec_weight = least_square_sklearn(X=masked_X, Y=Y) 232 | rec_weight = rec_weight.reshape(-1, 1, 1, d_prime) # (C_out, K_h, K_w, C_in') 233 | rec_weight = np.transpose(rec_weight, (0, 3, 1, 2)) # (C_out, C_in', K_h, K_w) 234 | else: 235 | raise NotImplementedError('Current code only supports 1x1 conv now!') 236 | if not self.export_model: # pad, pseudo compress 237 | rec_weight_pad = np.zeros_like(weight) 238 | rec_weight_pad[:, mask, :, :] = rec_weight 239 | rec_weight = rec_weight_pad 240 | 241 | if op_type == 'Linear': 242 | rec_weight = rec_weight.squeeze() 243 | assert len(rec_weight.shape) == 2 244 | fit_t2 = time.time() 245 | self.fit_time += fit_t2 - fit_t1 246 | # now assign 247 | op.weight.data = torch.from_numpy(rec_weight).cuda() 248 | action = np.sum(mask) * 1. 
/ len(mask) # calculate the ratio 249 | if self.export_model: # prune previous buffer ops 250 | prev_idx = self.prunable_idx[self.prunable_idx.index(op_idx) - 1] 251 | for idx in range(prev_idx, op_idx): 252 | m = m_list[idx] 253 | if type(m) == nn.Conv2d: # depthwise 254 | m.weight.data = torch.from_numpy(m.weight.data.cpu().numpy()[mask, :, :, :]).cuda() 255 | if m.groups == m.in_channels: 256 | m.groups = int(np.sum(mask)) 257 | elif type(m) == nn.BatchNorm2d: 258 | m.weight.data = torch.from_numpy(m.weight.data.cpu().numpy()[mask]).cuda() 259 | m.bias.data = torch.from_numpy(m.bias.data.cpu().numpy()[mask]).cuda() 260 | m.running_mean.data = torch.from_numpy(m.running_mean.data.cpu().numpy()[mask]).cuda() 261 | m.running_var.data = torch.from_numpy(m.running_var.data.cpu().numpy()[mask]).cuda() 262 | return action, d_prime, preserve_idx 263 | 264 | def _is_final_layer(self): 265 | return self.cur_ind == len(self.prunable_idx) - 1 266 | 267 | def _action_wall(self, action): 268 | assert len(self.strategy) == self.cur_ind 269 | 270 | action = float(action) 271 | action = np.clip(action, 0, 1) 272 | 273 | other_comp = 0 274 | this_comp = 0 275 | for i, idx in enumerate(self.prunable_idx): 276 | flop = self.layer_info_dict[idx]['flops'] 277 | buffer_flop = self._get_buffer_flops(idx) 278 | 279 | if i == self.cur_ind - 1: # TODO: add other member in the set 280 | this_comp += flop * self.strategy_dict[idx][0] 281 | # add buffer (but not influenced by ratio) 282 | other_comp += buffer_flop * self.strategy_dict[idx][0] 283 | elif i == self.cur_ind: 284 | this_comp += flop * self.strategy_dict[idx][1] 285 | # also add buffer here (influenced by ratio) 286 | this_comp += buffer_flop 287 | else: 288 | other_comp += flop * self.strategy_dict[idx][0] * self.strategy_dict[idx][1] 289 | # add buffer 290 | other_comp += buffer_flop * self.strategy_dict[idx][0] # only consider input reduction 291 | 292 | self.expected_min_preserve = other_comp + this_comp * action 293 | max_preserve_ratio = (self.expected_preserve_computation - other_comp) * 1. 
/ this_comp 294 | 295 | action = np.minimum(action, max_preserve_ratio) 296 | action = np.maximum(action, self.strategy_dict[self.prunable_idx[self.cur_ind]][0]) # impossible (should be) 297 | 298 | return action 299 | 300 | def _get_buffer_flops(self, idx): 301 | buffer_idx = self.buffer_dict[idx] 302 | buffer_flop = sum([self.layer_info_dict[_]['flops'] for _ in buffer_idx]) 303 | return buffer_flop 304 | 305 | def _cur_flops(self): 306 | flops = 0 307 | for i, idx in enumerate(self.prunable_idx): 308 | c, n = self.strategy_dict[idx] # input, output pruning ratio 309 | flops += self.layer_info_dict[idx]['flops'] * c * n 310 | # add buffer computation 311 | flops += self._get_buffer_flops(idx) * c # only related to input channel reduction 312 | return flops 313 | 314 | def _cur_reduced(self): 315 | # return the reduced weight 316 | reduced = self.org_flops - self._cur_flops() 317 | return reduced 318 | 319 | def _init_data(self): 320 | # split the train set into train + val 321 | # for CIFAR, split 5k for val 322 | # for ImageNet, split 3k for val 323 | val_size = 5000 if 'cifar' in self.data_type else 3000 324 | self.train_loader, self.val_loader, n_class = get_split_dataset(self.data_type, self.batch_size, 325 | self.n_data_worker, val_size, 326 | data_root=self.data_root, 327 | use_real_val=self.use_real_val, 328 | shuffle=False) # same sampling 329 | if self.use_real_val: # use the real val set for eval, which is actually wrong 330 | print('*** USE REAL VALIDATION SET!') 331 | 332 | def _build_index(self): 333 | self.prunable_idx = [] 334 | self.prunable_ops = [] 335 | self.layer_type_dict = {} 336 | self.strategy_dict = {} 337 | self.buffer_dict = {} 338 | this_buffer_list = [] 339 | self.org_channels = [] 340 | # build index and the min strategy dict 341 | for i, m in enumerate(self.model.modules()): 342 | if type(m) in self.prunable_layer_types: 343 | if type(m) == nn.Conv2d and m.groups == m.in_channels: # depth-wise conv, buffer 344 | this_buffer_list.append(i) 345 | else: # really prunable 346 | self.prunable_idx.append(i) 347 | self.prunable_ops.append(m) 348 | self.layer_type_dict[i] = type(m) 349 | self.buffer_dict[i] = this_buffer_list 350 | this_buffer_list = [] # empty 351 | self.org_channels.append(m.in_channels if type(m) == nn.Conv2d else m.in_features) 352 | 353 | self.strategy_dict[i] = [self.lbound, self.lbound] 354 | 355 | self.strategy_dict[self.prunable_idx[0]][0] = 1 # modify the input 356 | self.strategy_dict[self.prunable_idx[-1]][1] = 1 # modify the output 357 | 358 | self.shared_idx = [] 359 | if self.args.model == 'mobilenetv2': # TODO: to be tested! 
Share index for residual connection 360 | connected_idx = [4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32] # to be partitioned 361 | last_ch = -1 362 | share_group = None 363 | for c_idx in connected_idx: 364 | if self.prunable_ops[c_idx].in_channels != last_ch: # new group 365 | last_ch = self.prunable_ops[c_idx].in_channels 366 | if share_group is not None: 367 | self.shared_idx.append(share_group) 368 | share_group = [c_idx] 369 | else: # same group 370 | share_group.append(c_idx) 371 | print('=> Conv layers to share channels: {}'.format(self.shared_idx)) 372 | 373 | self.min_strategy_dict = copy.deepcopy(self.strategy_dict) 374 | 375 | self.buffer_idx = [] 376 | for k, v in self.buffer_dict.items(): 377 | self.buffer_idx += v 378 | 379 | print('=> Prunable layer idx: {}'.format(self.prunable_idx)) 380 | print('=> Buffer layer idx: {}'.format(self.buffer_idx)) 381 | print('=> Initial min strategy dict: {}'.format(self.min_strategy_dict)) 382 | 383 | # added for supporting residual connections during pruning 384 | self.visited = [False] * len(self.prunable_idx) 385 | self.index_buffer = {} 386 | 387 | def _extract_layer_information(self): 388 | m_list = list(self.model.modules()) 389 | 390 | self.data_saver = [] 391 | self.layer_info_dict = dict() 392 | self.wsize_list = [] 393 | self.flops_list = [] 394 | 395 | from lib.utils import measure_layer_for_pruning 396 | 397 | # extend the forward fn to record layer info 398 | def new_forward(m): 399 | def lambda_forward(x): 400 | m.input_feat = x.clone() 401 | measure_layer_for_pruning(m, x) 402 | y = m.old_forward(x) 403 | m.output_feat = y.clone() 404 | return y 405 | 406 | return lambda_forward 407 | 408 | for idx in self.prunable_idx + self.buffer_idx: # get all 409 | m = m_list[idx] 410 | m.old_forward = m.forward 411 | m.forward = new_forward(m) 412 | 413 | # now let the image flow 414 | print('=> Extracting information...') 415 | with torch.no_grad(): 416 | for i_b, (input, target) in enumerate(self.train_loader): # use image from train set 417 | if i_b == self.n_calibration_batches: 418 | break 419 | self.data_saver.append((input.clone(), target.clone())) 420 | input_var = torch.autograd.Variable(input).cuda() 421 | 422 | # inference and collect stats 423 | _ = self.model(input_var) 424 | 425 | if i_b == 0: # first batch 426 | for idx in self.prunable_idx + self.buffer_idx: 427 | self.layer_info_dict[idx] = dict() 428 | self.layer_info_dict[idx]['params'] = m_list[idx].params 429 | self.layer_info_dict[idx]['flops'] = m_list[idx].flops 430 | self.wsize_list.append(m_list[idx].params) 431 | self.flops_list.append(m_list[idx].flops) 432 | for idx in self.prunable_idx: 433 | f_in_np = m_list[idx].input_feat.data.cpu().numpy() 434 | f_out_np = m_list[idx].output_feat.data.cpu().numpy() 435 | if len(f_in_np.shape) == 4: # conv 436 | if self.prunable_idx.index(idx) == 0: # first conv 437 | f_in2save, f_out2save = None, None 438 | elif m_list[idx].weight.size(3) > 1: # normal conv 439 | f_in2save, f_out2save = f_in_np, f_out_np 440 | else: # 1x1 conv 441 | # assert f_out_np.shape[2] == f_in_np.shape[2] # now support k=3 442 | randx = np.random.randint(0, f_out_np.shape[2] - 0, self.n_points_per_layer) 443 | randy = np.random.randint(0, f_out_np.shape[3] - 0, self.n_points_per_layer) 444 | # input: [N, C, H, W] 445 | self.layer_info_dict[idx][(i_b, 'randx')] = randx.copy() 446 | self.layer_info_dict[idx][(i_b, 'randy')] = randy.copy() 447 | 448 | f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\ 449 | 
.reshape(self.batch_size * self.n_points_per_layer, -1) 450 | 451 | f_out2save = f_out_np[:, :, randx, randy].copy().transpose(0, 2, 1) \ 452 | .reshape(self.batch_size * self.n_points_per_layer, -1) 453 | else: 454 | assert len(f_in_np.shape) == 2 455 | f_in2save = f_in_np.copy() 456 | f_out2save = f_out_np.copy() 457 | if 'input_feat' not in self.layer_info_dict[idx]: 458 | self.layer_info_dict[idx]['input_feat'] = f_in2save 459 | self.layer_info_dict[idx]['output_feat'] = f_out2save 460 | else: 461 | self.layer_info_dict[idx]['input_feat'] = np.vstack( 462 | (self.layer_info_dict[idx]['input_feat'], f_in2save)) 463 | self.layer_info_dict[idx]['output_feat'] = np.vstack( 464 | (self.layer_info_dict[idx]['output_feat'], f_out2save)) 465 | 466 | def _regenerate_input_feature(self): 467 | # only re-generate the input feature 468 | m_list = list(self.model.modules()) 469 | 470 | # delete old features 471 | for k, v in self.layer_info_dict.items(): 472 | if 'input_feat' in v: 473 | v.pop('input_feat') 474 | 475 | # now let the image flow 476 | print('=> Regenerate features...') 477 | 478 | with torch.no_grad(): 479 | for i_b, (input, target) in enumerate(self.data_saver): 480 | input_var = torch.autograd.Variable(input).cuda() 481 | 482 | # inference and collect stats 483 | _ = self.model(input_var) 484 | 485 | for idx in self.prunable_idx: 486 | f_in_np = m_list[idx].input_feat.data.cpu().numpy() 487 | if len(f_in_np.shape) == 4: # conv 488 | if self.prunable_idx.index(idx) == 0: # first conv 489 | f_in2save = None 490 | else: 491 | randx = self.layer_info_dict[idx][(i_b, 'randx')] 492 | randy = self.layer_info_dict[idx][(i_b, 'randy')] 493 | f_in2save = f_in_np[:, :, randx, randy].copy().transpose(0, 2, 1)\ 494 | .reshape(self.batch_size * self.n_points_per_layer, -1) 495 | else: # fc 496 | assert len(f_in_np.shape) == 2 497 | f_in2save = f_in_np.copy() 498 | if 'input_feat' not in self.layer_info_dict[idx]: 499 | self.layer_info_dict[idx]['input_feat'] = f_in2save 500 | else: 501 | self.layer_info_dict[idx]['input_feat'] = np.vstack( 502 | (self.layer_info_dict[idx]['input_feat'], f_in2save)) 503 | 504 | def _build_state_embedding(self): 505 | # build the static part of the state embedding 506 | layer_embedding = [] 507 | module_list = list(self.model.modules()) 508 | for i, ind in enumerate(self.prunable_idx): 509 | m = module_list[ind] 510 | this_state = [] 511 | if type(m) == nn.Conv2d: 512 | this_state.append(i) # index 513 | this_state.append(0) # layer type, 0 for conv 514 | this_state.append(m.in_channels) # in channels 515 | this_state.append(m.out_channels) # out channels 516 | this_state.append(m.stride[0]) # stride 517 | this_state.append(m.kernel_size[0]) # kernel size 518 | this_state.append(np.prod(m.weight.size())) # weight size 519 | elif type(m) == nn.Linear: 520 | this_state.append(i) # index 521 | this_state.append(1) # layer type, 1 for fc 522 | this_state.append(m.in_features) # in channels 523 | this_state.append(m.out_features) # out channels 524 | this_state.append(0) # stride 525 | this_state.append(1) # kernel size 526 | this_state.append(np.prod(m.weight.size())) # weight size 527 | 528 | # this 3 features need to be changed later 529 | this_state.append(0.) # reduced 530 | this_state.append(0.) # rest 531 | this_state.append(1.) 
# a_{t-1} 532 | layer_embedding.append(np.array(this_state)) 533 | 534 | # normalize the state 535 | layer_embedding = np.array(layer_embedding, 'float') 536 | print('=> shape of embedding (n_layer * n_dim): {}'.format(layer_embedding.shape)) 537 | assert len(layer_embedding.shape) == 2, layer_embedding.shape 538 | for i in range(layer_embedding.shape[1]): 539 | fmin = min(layer_embedding[:, i]) 540 | fmax = max(layer_embedding[:, i]) 541 | if fmax - fmin > 0: 542 | layer_embedding[:, i] = (layer_embedding[:, i] - fmin) / (fmax - fmin) 543 | 544 | self.layer_embedding = layer_embedding 545 | 546 | def _validate(self, val_loader, model, verbose=False): 547 | ''' 548 | Validate the performance on validation set 549 | :param val_loader: 550 | :param model: 551 | :param verbose: 552 | :return: 553 | ''' 554 | batch_time = AverageMeter() 555 | losses = AverageMeter() 556 | top1 = AverageMeter() 557 | top5 = AverageMeter() 558 | 559 | criterion = nn.CrossEntropyLoss().cuda() 560 | # switch to evaluate mode 561 | model.eval() 562 | end = time.time() 563 | 564 | t1 = time.time() 565 | with torch.no_grad(): 566 | for i, (input, target) in enumerate(val_loader): 567 | target = target.cuda(non_blocking=True) 568 | input_var = torch.autograd.Variable(input).cuda() 569 | target_var = torch.autograd.Variable(target).cuda() 570 | 571 | # compute output 572 | output = model(input_var) 573 | loss = criterion(output, target_var) 574 | 575 | # measure accuracy and record loss 576 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 577 | losses.update(loss.item(), input.size(0)) 578 | top1.update(prec1.item(), input.size(0)) 579 | top5.update(prec5.item(), input.size(0)) 580 | 581 | # measure elapsed time 582 | batch_time.update(time.time() - end) 583 | end = time.time() 584 | t2 = time.time() 585 | if verbose: 586 | print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' % 587 | (losses.avg, top1.avg, top5.avg, t2 - t1)) 588 | if self.acc_metric == 'acc1': 589 | return top1.avg 590 | elif self.acc_metric == 'acc5': 591 | return top5.avg 592 | else: 593 | raise NotImplementedError 594 | --------------------------------------------------------------------------------
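
A note on using the replay buffer above: `SequentialMemory.sample_and_split` returns flat double-precision arrays of shape `(batch_size, -1)`, and the last array is a 0/1 mask that is 0. for terminal next states and 1. otherwise. The following is a minimal sketch of driving `lib/memory.py` directly with `window_length=1`, the default in `amc_search.py`. It is not part of the repository: the toy observations, actions, and rewards below are invented purely for illustration, and the script is assumed to be run from the repository root so that `lib` is importable.

import numpy as np
from lib.memory import SequentialMemory

# window_length=1 matches the default used by amc_search.py
memory = SequentialMemory(limit=100, window_length=1)

# store a few fake (observation, action, reward, terminal) transitions
for t in range(20):
    obs = np.array([float(t)])   # stand-in for a layer-embedding state vector
    action = 0.5                 # stand-in for a preserve ratio
    reward = -1.0                # stand-in for an episode reward
    terminal = (t % 5 == 4)      # pretend every episode lasts 5 steps
    memory.append(obs, action, reward, terminal, training=True)

# sample a mini-batch: each array has shape (batch_size, -1) and dtype float64;
# the last one is 1. for non-terminal next states and 0. for terminal ones
s0, a, r, s1, not_done = memory.sample_and_split(batch_size=8)
print(s0.shape, a.shape, r.shape, s1.shape, not_done.shape)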