├── .gitignore ├── README.md ├── cal_flops_params.py ├── data ├── __init__.py ├── cifar10.py └── imagenet.py ├── evaluate.py ├── get_flops.py ├── img ├── framework.jpeg └── framework.png ├── main.py ├── mask.py ├── models ├── __init__.py ├── densenet_cifar.py ├── googlenet_cifar.py ├── resnet_cifar.py ├── resnet_imagenet.py └── vgg.py ├── rank_conv ├── densenet_40 │ ├── rank_conv1.npy │ ├── rank_conv10.npy │ ├── rank_conv11.npy │ ├── rank_conv12.npy │ ├── rank_conv13.npy │ ├── rank_conv14.npy │ ├── rank_conv15.npy │ ├── rank_conv16.npy │ ├── rank_conv17.npy │ ├── rank_conv18.npy │ ├── rank_conv19.npy │ ├── rank_conv2.npy │ ├── rank_conv20.npy │ ├── rank_conv21.npy │ ├── rank_conv22.npy │ ├── rank_conv23.npy │ ├── rank_conv24.npy │ ├── rank_conv25.npy │ ├── rank_conv26.npy │ ├── rank_conv27.npy │ ├── rank_conv28.npy │ ├── rank_conv29.npy │ ├── rank_conv3.npy │ ├── rank_conv30.npy │ ├── rank_conv31.npy │ ├── rank_conv32.npy │ ├── rank_conv33.npy │ ├── rank_conv34.npy │ ├── rank_conv35.npy │ ├── rank_conv36.npy │ ├── rank_conv37.npy │ ├── rank_conv38.npy │ ├── rank_conv39.npy │ ├── rank_conv4.npy │ ├── rank_conv5.npy │ ├── rank_conv6.npy │ ├── rank_conv7.npy │ ├── rank_conv8.npy │ └── rank_conv9.npy ├── googlenet │ ├── rank_conv1.npy │ ├── rank_conv10_n1x1.npy │ ├── rank_conv10_n3x3.npy │ ├── rank_conv10_n5x5.npy │ ├── rank_conv10_pool_planes.npy │ ├── rank_conv2_n1x1.npy │ ├── rank_conv2_n3x3.npy │ ├── rank_conv2_n5x5.npy │ ├── rank_conv2_pool_planes.npy │ ├── rank_conv3_n1x1.npy │ ├── rank_conv3_n3x3.npy │ ├── rank_conv3_n5x5.npy │ ├── rank_conv3_pool_planes.npy │ ├── rank_conv4_n1x1.npy │ ├── rank_conv4_n3x3.npy │ ├── rank_conv4_n5x5.npy │ ├── rank_conv4_pool_planes.npy │ ├── rank_conv5_n1x1.npy │ ├── rank_conv5_n3x3.npy │ ├── rank_conv5_n5x5.npy │ ├── rank_conv5_pool_planes.npy │ ├── rank_conv6_n1x1.npy │ ├── rank_conv6_n3x3.npy │ ├── rank_conv6_n5x5.npy │ ├── rank_conv6_pool_planes.npy │ ├── rank_conv7_n1x1.npy │ ├── rank_conv7_n3x3.npy │ ├── rank_conv7_n5x5.npy │ ├── rank_conv7_pool_planes.npy │ ├── rank_conv8_n1x1.npy │ ├── rank_conv8_n3x3.npy │ ├── rank_conv8_n5x5.npy │ ├── rank_conv8_pool_planes.npy │ ├── rank_conv9_n1x1.npy │ ├── rank_conv9_n3x3.npy │ ├── rank_conv9_n5x5.npy │ └── rank_conv9_pool_planes.npy ├── resnet_110 │ ├── rank_conv1.npy │ ├── rank_conv10.npy │ ├── rank_conv100.npy │ ├── rank_conv101.npy │ ├── rank_conv102.npy │ ├── rank_conv103.npy │ ├── rank_conv104.npy │ ├── rank_conv105.npy │ ├── rank_conv106.npy │ ├── rank_conv107.npy │ ├── rank_conv108.npy │ ├── rank_conv109.npy │ ├── rank_conv11.npy │ ├── rank_conv12.npy │ ├── rank_conv13.npy │ ├── rank_conv14.npy │ ├── rank_conv15.npy │ ├── rank_conv16.npy │ ├── rank_conv17.npy │ ├── rank_conv18.npy │ ├── rank_conv19.npy │ ├── rank_conv2.npy │ ├── rank_conv20.npy │ ├── rank_conv21.npy │ ├── rank_conv22.npy │ ├── rank_conv23.npy │ ├── rank_conv24.npy │ ├── rank_conv25.npy │ ├── rank_conv26.npy │ ├── rank_conv27.npy │ ├── rank_conv28.npy │ ├── rank_conv29.npy │ ├── rank_conv3.npy │ ├── rank_conv30.npy │ ├── rank_conv31.npy │ ├── rank_conv32.npy │ ├── rank_conv33.npy │ ├── rank_conv34.npy │ ├── rank_conv35.npy │ ├── rank_conv36.npy │ ├── rank_conv37.npy │ ├── rank_conv38.npy │ ├── rank_conv39.npy │ ├── rank_conv4.npy │ ├── rank_conv40.npy │ ├── rank_conv41.npy │ ├── rank_conv42.npy │ ├── rank_conv43.npy │ ├── rank_conv44.npy │ ├── rank_conv45.npy │ ├── rank_conv46.npy │ ├── rank_conv47.npy │ ├── rank_conv48.npy │ ├── rank_conv49.npy │ ├── rank_conv5.npy │ ├── rank_conv50.npy │ ├── rank_conv51.npy │ ├── rank_conv52.npy │ 
├── rank_conv53.npy │ ├── rank_conv54.npy │ ├── rank_conv55.npy │ ├── rank_conv56.npy │ ├── rank_conv57.npy │ ├── rank_conv58.npy │ ├── rank_conv59.npy │ ├── rank_conv6.npy │ ├── rank_conv60.npy │ ├── rank_conv61.npy │ ├── rank_conv62.npy │ ├── rank_conv63.npy │ ├── rank_conv64.npy │ ├── rank_conv65.npy │ ├── rank_conv66.npy │ ├── rank_conv67.npy │ ├── rank_conv68.npy │ ├── rank_conv69.npy │ ├── rank_conv7.npy │ ├── rank_conv70.npy │ ├── rank_conv71.npy │ ├── rank_conv72.npy │ ├── rank_conv73.npy │ ├── rank_conv74.npy │ ├── rank_conv75.npy │ ├── rank_conv76.npy │ ├── rank_conv77.npy │ ├── rank_conv78.npy │ ├── rank_conv79.npy │ ├── rank_conv8.npy │ ├── rank_conv80.npy │ ├── rank_conv81.npy │ ├── rank_conv82.npy │ ├── rank_conv83.npy │ ├── rank_conv84.npy │ ├── rank_conv85.npy │ ├── rank_conv86.npy │ ├── rank_conv87.npy │ ├── rank_conv88.npy │ ├── rank_conv89.npy │ ├── rank_conv9.npy │ ├── rank_conv90.npy │ ├── rank_conv91.npy │ ├── rank_conv92.npy │ ├── rank_conv93.npy │ ├── rank_conv94.npy │ ├── rank_conv95.npy │ ├── rank_conv96.npy │ ├── rank_conv97.npy │ ├── rank_conv98.npy │ └── rank_conv99.npy ├── resnet_50 │ ├── rank_conv1.npy │ ├── rank_conv10.npy │ ├── rank_conv11.npy │ ├── rank_conv12.npy │ ├── rank_conv13.npy │ ├── rank_conv14.npy │ ├── rank_conv15.npy │ ├── rank_conv16.npy │ ├── rank_conv17.npy │ ├── rank_conv18.npy │ ├── rank_conv19.npy │ ├── rank_conv2.npy │ ├── rank_conv20.npy │ ├── rank_conv21.npy │ ├── rank_conv22.npy │ ├── rank_conv23.npy │ ├── rank_conv24.npy │ ├── rank_conv25.npy │ ├── rank_conv26.npy │ ├── rank_conv27.npy │ ├── rank_conv28.npy │ ├── rank_conv29.npy │ ├── rank_conv3.npy │ ├── rank_conv30.npy │ ├── rank_conv31.npy │ ├── rank_conv32.npy │ ├── rank_conv33.npy │ ├── rank_conv34.npy │ ├── rank_conv35.npy │ ├── rank_conv36.npy │ ├── rank_conv37.npy │ ├── rank_conv38.npy │ ├── rank_conv39.npy │ ├── rank_conv4.npy │ ├── rank_conv40.npy │ ├── rank_conv41.npy │ ├── rank_conv42.npy │ ├── rank_conv43.npy │ ├── rank_conv44.npy │ ├── rank_conv45.npy │ ├── rank_conv46.npy │ ├── rank_conv47.npy │ ├── rank_conv48.npy │ ├── rank_conv49.npy │ ├── rank_conv5.npy │ ├── rank_conv50.npy │ ├── rank_conv51.npy │ ├── rank_conv52.npy │ ├── rank_conv53.npy │ ├── rank_conv6.npy │ ├── rank_conv7.npy │ ├── rank_conv8.npy │ └── rank_conv9.npy ├── resnet_56 │ ├── rank_conv1.npy │ ├── rank_conv10.npy │ ├── rank_conv11.npy │ ├── rank_conv12.npy │ ├── rank_conv13.npy │ ├── rank_conv14.npy │ ├── rank_conv15.npy │ ├── rank_conv16.npy │ ├── rank_conv17.npy │ ├── rank_conv18.npy │ ├── rank_conv19.npy │ ├── rank_conv2.npy │ ├── rank_conv20.npy │ ├── rank_conv21.npy │ ├── rank_conv22.npy │ ├── rank_conv23.npy │ ├── rank_conv24.npy │ ├── rank_conv25.npy │ ├── rank_conv26.npy │ ├── rank_conv27.npy │ ├── rank_conv28.npy │ ├── rank_conv29.npy │ ├── rank_conv3.npy │ ├── rank_conv30.npy │ ├── rank_conv31.npy │ ├── rank_conv32.npy │ ├── rank_conv33.npy │ ├── rank_conv34.npy │ ├── rank_conv35.npy │ ├── rank_conv36.npy │ ├── rank_conv37.npy │ ├── rank_conv38.npy │ ├── rank_conv39.npy │ ├── rank_conv4.npy │ ├── rank_conv40.npy │ ├── rank_conv41.npy │ ├── rank_conv42.npy │ ├── rank_conv43.npy │ ├── rank_conv44.npy │ ├── rank_conv45.npy │ ├── rank_conv46.npy │ ├── rank_conv47.npy │ ├── rank_conv48.npy │ ├── rank_conv49.npy │ ├── rank_conv5.npy │ ├── rank_conv50.npy │ ├── rank_conv51.npy │ ├── rank_conv52.npy │ ├── rank_conv53.npy │ ├── rank_conv54.npy │ ├── rank_conv55.npy │ ├── rank_conv6.npy │ ├── rank_conv7.npy │ ├── rank_conv8.npy │ └── rank_conv9.npy └── vgg_16_bn │ ├── rank_conv1.npy │ ├── 
rank_conv10.npy │ ├── rank_conv11.npy │ ├── rank_conv12.npy │ ├── rank_conv2.npy │ ├── rank_conv3.npy │ ├── rank_conv4.npy │ ├── rank_conv5.npy │ ├── rank_conv6.npy │ ├── rank_conv7.npy │ ├── rank_conv8.npy │ └── rank_conv9.npy ├── rank_generation.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | .idea 4 | __pycache__ 5 | command_cifar.md 6 | models/__pycache__ 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HRank: Filter Pruning using High-Rank Feature Map ([Link](https://128.84.21.199/abs/2002.10179))![]( https://visitor-badge.glitch.me/badge?page_id=lmbxmu.hrank). 2 | 3 | PyTorch implementation of HRank (CVPR 2020, Oral). 4 | 5 | This is the repository we used while preparing the CVPR 2020 paper. Enjoy! 6 | 7 | A better version of HRank (better in both fine-tuning efficiency and accuracy) is released at [HRankPlus](https://github.com/lmbxmu/HRankPlus). It is highly recommended! 8 | 9 | 10 |
11 | 12 | Framework of HRank. In the left column, we first run images through the convolutional layers to obtain the feature maps. In the middle column, we then estimate the rank of each feature map, which is used as the criterion for pruning. The right column shows the pruning (the red filters) and the fine-tuning, where the green filters are updated and the blue filters are frozen. 13 | 14 | 15 | ## Tips 16 | 17 | For any problem, please contact the authors via email: lmbxmu@stu.xmu.edu.cn or ethan.zhangyc@gmail.com. Please avoid posting issues on GitHub if possible, in case we miss the GitHub notifications and leave the posted issues unanswered. 18 | 19 | 20 | ## Citation 21 | If you find HRank useful in your research, please consider citing: 22 | 23 | ``` 24 | @inproceedings{lin2020hrank, 25 | title={HRank: Filter Pruning using High-Rank Feature Map}, 26 | author={Lin, Mingbao and Ji, Rongrong and Wang, Yan and Zhang, Yichen and Zhang, Baochang and Tian, Yonghong and Shao, Ling}, 27 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 28 | pages={1529--1538}, 29 | year={2020} 30 | } 31 | ``` 32 | 33 | ## Running Code 34 | 35 | With this code, you can run our models on the CIFAR-10 and ImageNet datasets. The code has been tested with Python 3.6, PyTorch 1.0 and CUDA 9.0 on Ubuntu 16.04. 36 | 37 | 38 | ### Rank Generation 39 | 40 | ```shell 41 | python rank_generation.py \ 42 | --resume [pre-trained model dir] \ 43 | --arch [model arch name] \ 44 | --limit [batch numbers] \ 45 | --gpu [gpu_id] 46 | 47 | ``` 48 | 49 | 50 | 51 | ### Model Training 52 | 53 | For ease of reproducibility, we provide some of the experimental results and the corresponding pruning rate of every layer below: 54 | #### Attention! The actual pruning rates are much higher than those presented in the paper, since we do not count the next-layer channel removal (for example, if 50 filters are removed in the first layer, the corresponding 50 input channels of the second-layer filters are removed as well). 55 | 56 | ##### 1. VGG-16 57 | 58 | | Params | FLOPs | Accuracy | 59 | |--------------|---------------|----------| 60 | | 2.64M(82.1%) | 108.61M(65.3%)| 92.34% | 61 | 62 | ```shell 63 | python main.py \ 64 | --job_dir ./result/vgg_16_bn/[folder name] \ 65 | --resume [pre-trained model dir] \ 66 | --arch vgg_16_bn \ 67 | --compress_rate [0.95]+[0.5]*6+[0.9]*4+[0.8]*2 \ 68 | --gpu [gpu_id] 69 | ``` 70 | ##### 2. ResNet56 71 | 72 | | Params | FLOPs | Accuracy | 73 | |--------------|--------------|----------| 74 | | 0.49M(42.4%) | 62.72M(50.0%)| 93.17% | 75 | 76 | ```shell 77 | python main.py \ 78 | --job_dir ./result/resnet_56/[folder name] \ 79 | --resume [pre-trained model dir] \ 80 | --arch resnet_56 \ 81 | --compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4] \ 82 | --gpu [gpu_id] 83 | ``` 84 | ##### 3. ResNet110 85 | Note that, in the paper, we mistakenly reported the FLOPs as 148.70M(41.2%). We apologize for this and will update the arXiv version as soon as possible. 86 | 87 | | Params | FLOPs | Accuracy | 88 | |--------------|--------------|----------| 89 | | 1.04M(38.7%) |156.90M(37.9%)| 94.23% | 90 | 91 | ```shell 92 | python main.py \ 93 | --job_dir ./result/resnet_110/[folder name] \ 94 | --resume [pre-trained model dir] \ 95 | --arch resnet_110 \ 96 | --compress_rate [0.1]+[0.40]*36+[0.40]*36+[0.4]*36 \ 97 | --gpu [gpu_id] 98 | ``` 99 | ##### 4.
DenseNet40 100 | 101 | | Params | FLOPs | Accuracy | 102 | |--------------|--------------|----------| 103 | | 0.66M(36.5%) |167.41M(40.8%)| 94.24% | 104 | 105 | ```shell 106 | python main.py \ 107 | --job_dir ./result/densenet_40/[folder name] \ 108 | --resume [pre-trained model dir] \ 109 | --arch densenet_40 \ 110 | --compress_rate [0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*5+[0.0] \ 111 | --gpu [gpu_id] 112 | ``` 113 | ##### 5. GoogLeNet 114 | 115 | | Params | FLOPs | Accuracy | 116 | |--------------|--------------|----------| 117 | | 1.86M(69.8%) | 0.45B(70.4%)| 94.07% | 118 | 119 | ```shell 120 | python main.py \ 121 | --job_dir ./result/googlenet/[folder name] \ 122 | --resume [pre-trained model dir] \ 123 | --arch googlenet \ 124 | --compress_rate [0.10]+[0.8]*5+[0.85]+[0.8]*3 \ 125 | --gpu [gpu_id] 126 | ``` 127 | ##### 6. ResNet50 128 | 129 | | Params | FLOPs | Acc Top1 | Acc Top5 | 130 | |---------|-------|----------|----------| 131 | | 13.77M | 1.55B | 71.98% | 91.01% | 132 | 133 | ```shell 134 | python main.py \ 135 | --dataset imagenet \ 136 | --data_dir [ImageNet dataset dir] \ 137 | --job_dir ./result/resnet_50/[folder name] \ 138 | --resume [pre-trained model dir] \ 139 | --arch resnet_50 \ 140 | --compress_rate [0.2]+[0.8]*10+[0.8]*13+[0.55]*19+[0.45]*10 \ 141 | --gpu [gpu_id] 142 | ``` 143 | 144 | After training, checkpoints and loggers can be found in the `job_dir`. The pruned model of stage `i` is named `[arch]_cov[i]`, so the final pruned model is the one with the largest `i`. 145 | 146 | ### Get FLOPs & Params 147 | ```shell 148 | python cal_flops_params.py \ 149 | --arch resnet_56 \ 150 | --compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4] 151 | ``` 152 | 153 | ### Evaluate Final Performance 154 | ```shell 155 | python evaluate.py \ 156 | --dataset [dataset name] \ 157 | --data_dir [dataset dir] \ 158 | --test_model_dir [job dir of test model] \ 159 | --arch [arch name] \ 160 | --gpu [gpu id] 161 | ``` 162 | 163 | 164 | ## Other optional arguments 165 | ``` 166 | optional arguments: 167 | --data_dir dataset directory 168 | default: './data' 169 | --dataset dataset name 170 | default: cifar10 171 | Optional: cifar10, imagenet 172 | --lr initial learning rate 173 | default: 0.01 174 | --lr_decay_step learning rate decay step 175 | default: 5,10 176 | --resume load the model from the specified checkpoint 177 | --resume_mask mask loading directory 178 | --gpu select gpu to use 179 | default: 0 180 | --job_dir the directory where the summaries will be stored 181 | --epochs the number of epochs to train 182 | default: 15 183 | --train_batch_size batch size for training 184 | default: 128 185 | --eval_batch_size batch size for validation 186 | default: 100 187 | --start_cov the index of the conv layer from which to start pruning 188 | default: 0 189 | --compress_rate compress rate of each conv 190 | --arch the architecture to prune 191 | default: vgg_16_bn 192 | Optional: resnet_50, vgg_16_bn, resnet_56, resnet_110, densenet_40, googlenet 193 | ``` 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | ## Pre-trained Models 203 | 204 | Additionally, we provide the pre-trained models used in our experiments.
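A downloaded checkpoint is simply passed to `rank_generation.py` / `main.py` through `--resume`. If you want to sanity-check a checkpoint before training, here is a minimal illustrative sketch (not part of the repository; the path below is a placeholder, and the two storage formats mirror how `main.py` loads them):

```python
import torch

# hypothetical location of a downloaded pre-trained model
ckpt = torch.load('./checkpoints/vgg_16_bn.pt', map_location='cpu')

# CIFAR-10 checkpoints wrap the weights in a dict under 'state_dict';
# the ImageNet ResNet50 checkpoint is stored as a raw state_dict (see main.py).
state_dict = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
print(len(state_dict), 'parameter tensors; first key:', next(iter(state_dict)))
```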
205 | 206 | 207 | ### CIFAR-10: 208 | [Vgg-16](https://drive.google.com/open?id=1i3ifLh70y1nb8d4mazNzyC4I27jQcHrE) 209 | | [ResNet56](https://drive.google.com/open?id=1f1iSGvYFjSKIvzTko4fXFCbS-8dw556T) 210 | | [ResNet110](https://drive.google.com/open?id=1uENM3S5D_IKvXB26b1BFwMzUpkOoA26m) 211 | | [DenseNet-40](https://drive.google.com/open?id=12rInJ0YpGwZd_k76jctQwrfzPubsfrZH) 212 | | [GoogLeNet](https://drive.google.com/open?id=1rYMazSyMbWwkCGCLvofNKwl58W6mmg5c) 213 | 214 | ### ImageNet: 215 | [ResNet50](https://drive.google.com/open?id=1OYpVB84BMU0y-KU7PdEPhbHwODmFvPbB) 216 | -------------------------------------------------------------------------------- /cal_flops_params.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import argparse 4 | import get_flops 5 | from models import * 6 | 7 | parser = argparse.ArgumentParser(description='Calculating flops and params') 8 | 9 | parser.add_argument( 10 | '--input_image_size', 11 | type=int, 12 | default=32, 13 | help='The input_image_size') 14 | parser.add_argument( 15 | '--arch', 16 | type=str, 17 | default='vgg_16_bn', 18 | choices=('vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet','resnet_50'), 19 | help='The architecture to prune') 20 | parser.add_argument( 21 | '--compress_rate', 22 | type=str, 23 | default=None, 24 | help='compress rate of each conv') 25 | args = parser.parse_args() 26 | 27 | device = torch.device("cpu") 28 | 29 | 30 | if args.compress_rate: 31 | import re 32 | cprate_str=args.compress_rate 33 | cprate_str_list=cprate_str.split('+') 34 | pat_cprate=re.compile(r'\d+\.\d*') 35 | pat_num = re.compile(r'\*\d+') 36 | cprate=[] 37 | for x in cprate_str_list: 38 | num=1 39 | find_num=re.findall(pat_num,x) 40 | if find_num: 41 | assert len(find_num) == 1 42 | num=int(find_num[0].replace('*','')) 43 | find_cprate = re.findall(pat_cprate, x) 44 | assert len(find_cprate)==1 45 | print(float(find_cprate[0]),num) 46 | cprate+=[float(find_cprate[0])]*num 47 | compress_rate=cprate 48 | print(compress_rate) 49 | 50 | if args.arch=='vgg_16_bn': 51 | compress_rate[12]=0. 
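# Example of the compress_rate notation parsed above: the VGG-16 string from the README,
# '[0.95]+[0.5]*6+[0.9]*4+[0.8]*2', expands to
# [0.95, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8],
# i.e. one pruning ratio per convolutional layer, in order; for vgg_16_bn the rate of
# the last conv (index 12) is then forced to 0 just above.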
52 | 53 | print('==> Building model..') 54 | net = eval(args.arch)(compress_rate=compress_rate) 55 | print(net.compress_rate) 56 | net.eval() 57 | 58 | if args.arch=='googlenet' or args.arch=='resnet_50': 59 | flops, params = get_flops.measure_model(net, device, 3, args.input_image_size, args.input_image_size, True) 60 | else: 61 | flops, params= get_flops.measure_model(net,device,3,args.input_image_size,args.input_image_size) 62 | 63 | print('Params: %.2f'%(params)) 64 | print('Flops: %.2f'%(flops)) 65 | 66 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision.transforms as transforms 4 | 5 | class Data: 6 | def __init__(self, args): 7 | # pin_memory = False 8 | # if args.gpu is not None: 9 | pin_memory = True 10 | 11 | transform_train = transforms.Compose([ 12 | transforms.RandomCrop(32, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | transform_test = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | trainset = CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train) 23 | self.loader_train = DataLoader( 24 | trainset, batch_size=args.train_batch_size, shuffle=True, 25 | num_workers=2, pin_memory=pin_memory 26 | ) 27 | 28 | testset = CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 29 | self.loader_test = DataLoader( 30 | testset, batch_size=args.eval_batch_size, shuffle=False, 31 | num_workers=2, pin_memory=pin_memory) 32 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.transforms as transforms 3 | import torchvision.datasets as datasets 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class Data: 8 | def __init__(self, args, is_evaluate=False): 9 | pin_memory = False 10 | if args.gpu is not None: 11 | pin_memory = True 12 | 13 | scale_size = 224 14 | 15 | traindir = os.path.join(args.data_dir, 'ILSVRC2012_img_train') 16 | valdir = os.path.join(args.data_dir, 'val') 17 | normalize = transforms.Normalize( 18 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 19 | 20 | if not is_evaluate: 21 | trainset = datasets.ImageFolder( 22 | traindir, 23 | transforms.Compose([ 24 | transforms.RandomResizedCrop(224), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.Resize(scale_size), 27 | transforms.ToTensor(), 28 | normalize, 29 | ])) 30 | 31 | self.loader_train = DataLoader( 32 | trainset, 33 | batch_size=args.train_batch_size, 34 | shuffle=True, 35 | num_workers=2, 36 | pin_memory=pin_memory) 37 | 38 | testset = datasets.ImageFolder( 39 | valdir, 40 | transforms.Compose([ 41 | transforms.Resize(256), 42 | transforms.CenterCrop(224), 43 | transforms.Resize(scale_size), 44 | transforms.ToTensor(), 45 | normalize, 46 | ])) 47 | 48 | self.loader_test = DataLoader( 49 | testset, 50 | 
batch_size=args.eval_batch_size, 51 | shuffle=False, 52 | num_workers=2, 53 | pin_memory=True) 54 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | 11 | import argparse 12 | 13 | from data import imagenet 14 | from models import * 15 | from mask import * 16 | import utils 17 | 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Evaluate') 20 | parser.add_argument( 21 | '--data_dir', 22 | type=str, 23 | default='./data', 24 | help='dataset path') 25 | parser.add_argument( 26 | '--dataset', 27 | type=str, 28 | default='cifar10', 29 | choices=('cifar10','imagenet'), 30 | help='dataset') 31 | parser.add_argument( 32 | '--arch', 33 | type=str, 34 | default='vgg_16_bn', 35 | choices=('resnet_50','vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet'), 36 | help='The architecture to prune') 37 | parser.add_argument( 38 | '--test_model_dir', 39 | type=str, 40 | default='./result/tmp/', 41 | help='The directory where the summaries will be stored.') 42 | parser.add_argument( 43 | '--eval_batch_size', 44 | type=int, 45 | default=100, 46 | help='Batch size for validation.') 47 | parser.add_argument( 48 | '--gpu', 49 | type=str, 50 | default='0', 51 | help='Select gpu to use') 52 | 53 | args = parser.parse_args() 54 | 55 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 56 | cudnn.benchmark = True 57 | cudnn.enabled = True 58 | 59 | # Data 60 | print('==> Preparing data..') 61 | if args.dataset=='cifar10': 62 | transform_test = transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 65 | ]) 66 | testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 67 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2) 68 | print_freq = 3000 // args.eval_batch_size 69 | elif args.dataset=='imagenet': 70 | data_tmp = imagenet.Data(args, is_evaluate=True) 71 | testloader = data_tmp.loader_test 72 | print_freq = 10000 // args.eval_batch_size 73 | else: 74 | raise NotImplementedError 75 | 76 | # Model 77 | print('==> Building model..') 78 | net = eval(args.arch)(compress_rate=[0.]*200) 79 | net = net.cuda() 80 | print(net) 81 | 82 | if len(args.gpu)>1 and torch.cuda.is_available(): 83 | device_id = [] 84 | for i in range((len(args.gpu) + 1) // 2): 85 | device_id.append(i) 86 | net = torch.nn.DataParallel(net, device_ids=device_id) 87 | 88 | def test(): 89 | top1 = utils.AverageMeter() 90 | top5 = utils.AverageMeter() 91 | 92 | net.eval() 93 | num_iterations = len(testloader) 94 | with torch.no_grad(): 95 | for batch_idx, (inputs, targets) in enumerate(testloader): 96 | inputs, targets = inputs.cuda(), targets.cuda() 97 | outputs = net(inputs) 98 | 99 | prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5)) 100 | top1.update(prec1[0], inputs.size(0)) 101 | top5.update(prec5[0], inputs.size(0)) 102 | 103 | if batch_idx%print_freq==0: 104 | print( 105 | '({0}/{1}): ' 106 | 'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format( 107 | batch_idx, num_iterations, top1=top1, top5=top5)) 108 | 109 | print("Final Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}".format(top1=top1, 
top5=top5)) 110 | 111 | 112 | if len(args.gpu)>1: 113 | convcfg = net.module.covcfg 114 | else: 115 | convcfg = net.covcfg 116 | 117 | cov_id=len(convcfg) 118 | new_state_dict = OrderedDict() 119 | pruned_checkpoint = torch.load(args.test_model_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(cov_id) + '.pt', 120 | map_location='cuda:0') 121 | tmp_ckpt = pruned_checkpoint['state_dict'] 122 | if len(args.gpu) == 1: 123 | for k, v in tmp_ckpt.items(): 124 | new_state_dict[k.replace('module.', '')] = v 125 | else: 126 | for k, v in tmp_ckpt.items(): 127 | new_state_dict['module.' + k.replace('module.', '')] = v 128 | net.load_state_dict(new_state_dict) 129 | 130 | test() 131 | 132 | -------------------------------------------------------------------------------- /get_flops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from functools import reduce 9 | import operator 10 | 11 | count_ops = 0 12 | count_params = 0 13 | 14 | 15 | def get_num_gen(gen): 16 | return sum(1 for x in gen) 17 | 18 | def is_pruned(layer): 19 | try: 20 | layer.mask 21 | return True 22 | except AttributeError: 23 | return False 24 | 25 | def is_leaf(model): 26 | return get_num_gen(model.children()) == 0 27 | 28 | def get_layer_info(layer): 29 | layer_str = str(layer) 30 | # print(layer_str) 31 | type_name = layer_str[:layer_str.find('(')].strip() 32 | return type_name 33 | 34 | def get_layer_param(model, is_conv=True): 35 | if is_conv: 36 | total=0. 37 | for idx, param in enumerate(model.parameters()): 38 | assert idx<2 39 | f = param.size()[0] 40 | pruned_num = int(model.cp_rate * f) 41 | if len(param.size())>1: 42 | c=param.size()[1] 43 | if hasattr(model,'last_prune_num'): 44 | last_prune_num=model.last_prune_num 45 | total += (f - pruned_num) * (c-last_prune_num) * param.numel() / f / c 46 | else: 47 | total += (f - pruned_num) * param.numel() / f 48 | else: 49 | total += (f - pruned_num) * param.numel() / f 50 | return total 51 | else: 52 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 53 | 54 | ### The input batch size should be 1 to call this function 55 | def measure_layer(layer, x, print_name): 56 | global count_ops, count_params 57 | delta_ops = 0 58 | delta_params = 0 59 | multi_add = 1 60 | type_name = get_layer_info(layer) 61 | 62 | ### ops_conv 63 | if type_name in ['Conv2d']: 64 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 65 | layer.stride[0] + 1) 66 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 67 | layer.stride[1] + 1) 68 | pruned_num = int(layer.cp_rate * layer.out_channels) 69 | 70 | if hasattr(layer,'tmp_name') and 'trans' in layer.tmp_name: 71 | delta_ops = (layer.in_channels-layer.last_prune_num) * (layer.out_channels - pruned_num) * layer.kernel_size[0] * \ 72 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 73 | else: 74 | delta_ops = layer.in_channels * (layer.out_channels-pruned_num) * layer.kernel_size[0] * \ 75 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 76 | 77 | delta_ops_ori = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 78 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 79 | 80 | delta_params = get_layer_param(layer) 81 | 82 | if print_name: 83 | 
print(layer.tmp_name, layer.cp_rate, '| input:',x.size(),'| weight:',[layer.out_channels, layer.in_channels, layer.kernel_size[0], layer.kernel_size[1]], 84 | '| params:', delta_params, '| flops:', delta_ops_ori) 85 | else: 86 | print(layer.cp_rate, [layer.out_channels,layer.in_channels,layer.kernel_size[0],layer.kernel_size[1]], 87 | 'params:',delta_params, ' flops:',delta_ops_ori) 88 | 89 | ### ops_linear 90 | elif type_name in ['Linear']: 91 | weight_ops = layer.weight.numel() * multi_add 92 | bias_ops = layer.bias.numel() 93 | delta_ops = x.size()[0] * (weight_ops + bias_ops) 94 | delta_params = get_layer_param(layer, is_conv=False) 95 | 96 | print('linear:',layer, delta_ops, delta_params) 97 | 98 | elif type_name in ['DenseBasicBlock', 'ResBasicBlock']: 99 | measure_layer(layer.conv1, x) 100 | 101 | elif type_name in ['Inception']: 102 | measure_layer(layer.conv1, x) 103 | 104 | elif type_name in ['DenseBottleneck', 'SparseDenseBottleneck']: 105 | measure_layer(layer.conv1, x) 106 | 107 | elif type_name in ['Transition', 'SparseTransition']: 108 | measure_layer(layer.conv1, x) 109 | 110 | elif type_name in ['ReLU', 'BatchNorm1d','BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout', 'AdaptiveAvgPool2d', 'AvgPool2d', 'MaxPool2d', 'Mask', 'channel_selection', 'LambdaLayer', 'Sequential']: 111 | return 112 | ### unknown layer type 113 | else: 114 | raise TypeError('unknown layer type: %s' % type_name) 115 | 116 | count_ops += delta_ops 117 | count_params += delta_params 118 | return 119 | 120 | def measure_model(model, device, C, H, W, print_name=False): 121 | global count_ops, count_params 122 | count_ops = 0 123 | count_params = 0 124 | data = Variable(torch.zeros(1, C, H, W)).to(device) 125 | model = model.to(device) 126 | model.eval() 127 | 128 | def should_measure(x): 129 | return is_leaf(x) 130 | 131 | def modify_forward(model, print_name): 132 | for child in model.children(): 133 | if should_measure(child): 134 | def new_forward(m): 135 | def lambda_forward(x): 136 | measure_layer(m, x, print_name) 137 | return m.old_forward(x) 138 | return lambda_forward 139 | child.old_forward = child.forward 140 | child.forward = new_forward(child) 141 | else: 142 | modify_forward(child, print_name) 143 | 144 | def restore_forward(model): 145 | for child in model.children(): 146 | # leaf node 147 | if is_leaf(child) and hasattr(child, 'old_forward'): 148 | child.forward = child.old_forward 149 | child.old_forward = None 150 | else: 151 | restore_forward(child) 152 | 153 | modify_forward(model, print_name) 154 | model.forward(data) 155 | restore_forward(model) 156 | 157 | return count_ops, count_params 158 | -------------------------------------------------------------------------------- /img/framework.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/img/framework.jpeg -------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/img/framework.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | 6 | import torchvision 7 | import torchvision.transforms 
as transforms 8 | 9 | import os 10 | import argparse 11 | 12 | from data import imagenet 13 | from models import * 14 | from utils import progress_bar 15 | from mask import * 16 | import utils 17 | 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 20 | parser.add_argument( 21 | '--data_dir', 22 | type=str, 23 | default='./data', 24 | help='dataset path') 25 | parser.add_argument( 26 | '--dataset', 27 | type=str, 28 | default='cifar10', 29 | choices=('cifar10','imagenet'), 30 | help='dataset') 31 | parser.add_argument( 32 | '--lr', 33 | default=0.01, 34 | type=float, 35 | help='initial learning rate') 36 | parser.add_argument( 37 | '--lr_decay_step', 38 | default='5,10', 39 | type=str, 40 | help='learning rate decay step') 41 | parser.add_argument( 42 | '--resume', 43 | type=str, 44 | default=None, 45 | help='load the model from the specified checkpoint') 46 | parser.add_argument( 47 | '--resume_mask', 48 | type=str, 49 | default=None, 50 | help='mask loading') 51 | parser.add_argument( 52 | '--gpu', 53 | type=str, 54 | default='0', 55 | help='Select gpu to use') 56 | parser.add_argument( 57 | '--job_dir', 58 | type=str, 59 | default='./result/tmp/', 60 | help='The directory where the summaries will be stored.') 61 | parser.add_argument( 62 | '--epochs', 63 | type=int, 64 | default=15, 65 | help='The num of epochs to train.') 66 | parser.add_argument( 67 | '--train_batch_size', 68 | type=int, 69 | default=128, 70 | help='Batch size for training.') 71 | parser.add_argument( 72 | '--eval_batch_size', 73 | type=int, 74 | default=100, 75 | help='Batch size for validation.') 76 | parser.add_argument( 77 | '--start_cov', 78 | type=int, 79 | default=0, 80 | help='The num of conv to start prune') 81 | parser.add_argument( 82 | '--compress_rate', 83 | type=str, 84 | default=None, 85 | help='compress rate of each conv') 86 | parser.add_argument( 87 | '--arch', 88 | type=str, 89 | default='vgg_16_bn', 90 | choices=('resnet_50','vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet'), 91 | help='The architecture to prune') 92 | 93 | args = parser.parse_args() 94 | 95 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 96 | 97 | if len(args.gpu)==1: 98 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 99 | else: 100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | 102 | best_acc = 0 # best test accuracy 103 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 104 | lr_decay_step = list(map(int, args.lr_decay_step.split(','))) 105 | 106 | ckpt = utils.checkpoint(args) 107 | print_logger = utils.get_logger(os.path.join(args.job_dir, "logger.log")) 108 | utils.print_params(vars(args), print_logger.info) 109 | 110 | # Data 111 | print_logger.info('==> Preparing data..') 112 | 113 | if args.dataset=='cifar10': 114 | transform_train = transforms.Compose([ 115 | transforms.RandomCrop(32, padding=4), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.ToTensor(), 118 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 119 | ]) 120 | 121 | transform_test = transforms.Compose([ 122 | transforms.ToTensor(), 123 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 124 | ]) 125 | 126 | trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train) 127 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) 128 | 129 | testset = 
torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 130 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2) 131 | elif args.dataset=='imagenet': 132 | data_tmp = imagenet.Data(args) 133 | trainloader = data_tmp.loader_train 134 | testloader = data_tmp.loader_test 135 | else: 136 | assert 1==0 137 | 138 | if args.compress_rate: 139 | import re 140 | cprate_str=args.compress_rate 141 | cprate_str_list=cprate_str.split('+') 142 | pat_cprate = re.compile(r'\d+\.\d*') 143 | pat_num = re.compile(r'\*\d+') 144 | cprate=[] 145 | for x in cprate_str_list: 146 | num=1 147 | find_num=re.findall(pat_num,x) 148 | if find_num: 149 | assert len(find_num) == 1 150 | num=int(find_num[0].replace('*','')) 151 | find_cprate = re.findall(pat_cprate, x) 152 | assert len(find_cprate)==1 153 | cprate+=[float(find_cprate[0])]*num 154 | 155 | compress_rate=cprate 156 | 157 | # Model 158 | device_ids=list(map(int, args.gpu.split(','))) 159 | print_logger.info('==> Building model..') 160 | net = eval(args.arch)(compress_rate=compress_rate) 161 | net = net.to(device) 162 | 163 | if len(args.gpu)>1 and torch.cuda.is_available(): 164 | device_id = [] 165 | for i in range((len(args.gpu) + 1) // 2): 166 | device_id.append(i) 167 | net = torch.nn.DataParallel(net, device_ids=device_id) 168 | 169 | cudnn.benchmark = True 170 | print(net) 171 | 172 | if len(args.gpu)>1: 173 | m = eval('mask_'+args.arch)(model=net, compress_rate=net.module.compress_rate, job_dir=args.job_dir, device=device) 174 | else: 175 | m = eval('mask_' + args.arch)(model=net, compress_rate=net.compress_rate, job_dir=args.job_dir, device=device) 176 | 177 | criterion = nn.CrossEntropyLoss() 178 | 179 | # Training 180 | def train(epoch, cov_id, optimizer, scheduler, pruning=True): 181 | print_logger.info('\nEpoch: %d' % epoch) 182 | net.train() 183 | 184 | train_loss = 0 185 | correct = 0 186 | total = 0 187 | for batch_idx, (inputs, targets) in enumerate(trainloader): 188 | with torch.cuda.device(device): 189 | inputs = inputs.to(device) 190 | targets = targets.to(device) 191 | optimizer.zero_grad() 192 | outputs = net(inputs) 193 | loss = criterion(outputs, targets) 194 | loss.backward() 195 | 196 | optimizer.step() 197 | 198 | if pruning: 199 | m.grad_mask(cov_id) 200 | 201 | train_loss += loss.item() 202 | _, predicted = outputs.max(1) 203 | total += targets.size(0) 204 | correct += predicted.eq(targets).sum().item() 205 | 206 | progress_bar(batch_idx,len(trainloader), 207 | 'Cov: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)' 208 | % (cov_id, train_loss / (batch_idx + 1), 100. 
* correct / total, correct, total)) 209 | 210 | def test(epoch, cov_id, optimizer, scheduler): 211 | top1 = utils.AverageMeter() 212 | top5 = utils.AverageMeter() 213 | 214 | global best_acc 215 | net.eval() 216 | num_iterations = len(testloader) 217 | with torch.no_grad(): 218 | for batch_idx, (inputs, targets) in enumerate(testloader): 219 | inputs, targets = inputs.to(device), targets.to(device) 220 | outputs = net(inputs) 221 | loss = criterion(outputs, targets) 222 | 223 | prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5)) 224 | top1.update(prec1[0], inputs.size(0)) 225 | top5.update(prec5[0], inputs.size(0)) 226 | 227 | print_logger.info( 228 | 'Epoch[{0}]({1}/{2}): ' 229 | 'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format( 230 | epoch, batch_idx, num_iterations, top1=top1, top5=top5)) 231 | 232 | if top1.avg > best_acc: 233 | print_logger.info('Saving to '+args.arch+'_cov'+str(cov_id)+'.pt') 234 | state = { 235 | 'state_dict': net.state_dict(), 236 | 'best_prec1': top1.avg, 237 | 'epoch': epoch, 238 | 'scheduler':scheduler.state_dict(), 239 | 'optimizer': optimizer.state_dict() 240 | } 241 | if not os.path.isdir(args.job_dir+'/pruned_checkpoint'): 242 | os.mkdir(args.job_dir+'/pruned_checkpoint') 243 | best_acc = top1.avg 244 | torch.save(state, args.job_dir+'/pruned_checkpoint/'+args.arch+'_cov'+str(cov_id)+'.pt') 245 | 246 | print_logger.info("=>Best accuracy {:.3f}".format(best_acc)) 247 | 248 | 249 | if len(args.gpu)>1: 250 | convcfg = net.module.covcfg 251 | else: 252 | convcfg = net.covcfg 253 | 254 | param_per_cov_dic={ 255 | 'vgg_16_bn': 4, 256 | 'densenet_40': 3, 257 | 'googlenet': 28, 258 | 'resnet_50':3, 259 | 'resnet_56':3, 260 | 'resnet_110':3 261 | } 262 | 263 | if len(args.gpu)>1: 264 | print_logger.info('compress rate: ' + str(net.module.compress_rate)) 265 | else: 266 | print_logger.info('compress rate: ' + str(net.compress_rate)) 267 | 268 | for cov_id in range(args.start_cov, len(convcfg)): 269 | # Load pruned_checkpoint 270 | print_logger.info("cov-id: %d ====> Resuming from pruned_checkpoint..." % (cov_id)) 271 | 272 | m.layer_mask(cov_id + 1, resume=args.resume_mask, param_per_cov=param_per_cov_dic[args.arch], arch=args.arch) 273 | 274 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 275 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_decay_step, gamma=0.1) 276 | 277 | if cov_id == 0: 278 | 279 | pruned_checkpoint = torch.load(args.resume, map_location='cuda:0') 280 | from collections import OrderedDict 281 | new_state_dict = OrderedDict() 282 | if args.arch == 'resnet_50': 283 | tmp_ckpt = pruned_checkpoint 284 | else: 285 | tmp_ckpt = pruned_checkpoint['state_dict'] 286 | 287 | if len(args.gpu) > 1: 288 | for k, v in tmp_ckpt.items(): 289 | new_state_dict['module.' 
+ k.replace('module.', '')] = v 290 | else: 291 | for k, v in tmp_ckpt.items(): 292 | new_state_dict[k.replace('module.', '')] = v 293 | 294 | net.load_state_dict(new_state_dict)#''' 295 | else: 296 | if args.arch=='resnet_50': 297 | skip_list=[1,5,8,11,15,18,21,24,28,31,34,37,40,43,47,50,53] 298 | if cov_id+1 not in skip_list: 299 | continue 300 | else: 301 | pruned_checkpoint = torch.load( 302 | args.job_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(skip_list[skip_list.index(cov_id+1)-1]) + '.pt') 303 | net.load_state_dict(pruned_checkpoint['state_dict']) 304 | else: 305 | if len(args.gpu) == 1: 306 | pruned_checkpoint = torch.load(args.job_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(cov_id) + '.pt', map_location='cuda:' + args.gpu) 307 | else: 308 | pruned_checkpoint = torch.load(args.job_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(cov_id) + '.pt') 309 | net.load_state_dict(pruned_checkpoint['state_dict']) 310 | 311 | best_acc=0. 312 | for epoch in range(0, args.epochs): 313 | train(epoch, cov_id + 1, optimizer, scheduler) 314 | scheduler.step() 315 | test(epoch, cov_id + 1, optimizer, scheduler) 316 | -------------------------------------------------------------------------------- /mask.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | import pickle 6 | 7 | 8 | class mask_vgg_16_bn: 9 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 10 | self.model = model 11 | self.compress_rate = compress_rate 12 | self.mask = {} 13 | self.job_dir=job_dir 14 | self.device = device 15 | 16 | def layer_mask(self, cov_id, resume=None, param_per_cov=4, arch="vgg_16_bn"): 17 | params = self.model.parameters() 18 | prefix = "rank_conv/"+arch+"/rank_conv" 19 | subfix = ".npy" 20 | 21 | if resume: 22 | with open(resume, 'rb') as f: 23 | self.mask = pickle.load(f) 24 | else: 25 | resume=self.job_dir+'/mask' 26 | 27 | self.param_per_cov=param_per_cov 28 | 29 | for index, item in enumerate(params): 30 | 31 | if index == cov_id * param_per_cov: 32 | break 33 | if index == (cov_id - 1) * param_per_cov: 34 | f, c, w, h = item.size() 35 | rank = np.load(prefix + str(cov_id) + subfix) 36 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 37 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 38 | 39 | zeros = torch.zeros(f, 1, 1, 1).to(self.device) 40 | for i in range(len(ind)): 41 | zeros[ind[i], 0, 0, 0] = 1. 
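# At this point `zeros` is a per-filter 0/1 mask: the pruned_num filters with the
# lowest ranks (loaded from rank_conv/vgg_16_bn/rank_conv<cov_id>.npy) stay 0 and the
# preserved filters are set to 1; it is broadcast over the conv weight below and,
# squeezed, over the following per-channel parameters of the same layer.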
42 | self.mask[index] = zeros # covolutional weight 43 | item.data = item.data * self.mask[index] 44 | 45 | if index > (cov_id - 1) * param_per_cov and index <= (cov_id - 1) * param_per_cov + param_per_cov-1: 46 | self.mask[index] = torch.squeeze(zeros) 47 | item.data = item.data * self.mask[index] 48 | 49 | with open(resume, "wb") as f: 50 | pickle.dump(self.mask, f) 51 | 52 | def grad_mask(self, cov_id): 53 | params = self.model.parameters() 54 | for index, item in enumerate(params): 55 | if index == cov_id * self.param_per_cov: 56 | break 57 | item.data = item.data * self.mask[index]#prune certain weight 58 | 59 | 60 | class mask_resnet_56: 61 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 62 | self.model = model 63 | self.compress_rate = compress_rate 64 | self.mask = {} 65 | self.job_dir=job_dir 66 | self.device = device 67 | 68 | def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_56"): 69 | params = self.model.parameters() 70 | prefix = "rank_conv/"+arch+"/rank_conv" 71 | subfix = ".npy" 72 | 73 | if resume: 74 | with open(resume, 'rb') as f: 75 | self.mask = pickle.load(f) 76 | else: 77 | resume=self.job_dir+'/mask' 78 | 79 | self.param_per_cov=param_per_cov 80 | 81 | for index, item in enumerate(params): 82 | 83 | if index == cov_id*param_per_cov: 84 | break 85 | 86 | if index == (cov_id - 1) * param_per_cov: 87 | f, c, w, h = item.size() 88 | rank = np.load(prefix + str(cov_id) + subfix) 89 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 90 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 91 | 92 | zeros = torch.zeros(f, 1, 1, 1).to(self.device) 93 | for i in range(len(ind)): 94 | zeros[ind[i], 0, 0, 0] = 1. 95 | self.mask[index] = zeros # covolutional weight 96 | item.data = item.data * self.mask[index] 97 | 98 | elif index > (cov_id-1)*param_per_cov and index < cov_id*param_per_cov: 99 | self.mask[index] = torch.squeeze(zeros) 100 | item.data = item.data * self.mask[index].to(self.device) 101 | 102 | with open(resume, "wb") as f: 103 | pickle.dump(self.mask, f) 104 | 105 | def grad_mask(self, cov_id): 106 | params = self.model.parameters() 107 | for index, item in enumerate(params): 108 | if index == cov_id*self.param_per_cov: 109 | break 110 | item.data = item.data * self.mask[index].to(self.device)#prune certain weight 111 | 112 | 113 | class mask_densenet_40: 114 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 115 | self.model = model 116 | self.compress_rate = compress_rate 117 | self.job_dir=job_dir 118 | self.device=device 119 | self.mask = {} 120 | 121 | def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="densenet_40"): 122 | params = self.model.parameters() 123 | prefix = "rank_conv/"+arch+"/rank_conv" 124 | subfix = ".npy" 125 | 126 | if resume: 127 | with open(resume, 'rb') as f: 128 | self.mask = pickle.load(f) 129 | else: 130 | resume=self.job_dir+'/mask' 131 | 132 | self.param_per_cov=param_per_cov 133 | 134 | for index, item in enumerate(params): 135 | 136 | if index == cov_id * param_per_cov: 137 | break 138 | if index == (cov_id - 1) * param_per_cov: 139 | f, c, w, h = item.size() 140 | rank = np.load(prefix + str(cov_id) + subfix) 141 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 142 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 143 | 144 | zeros = torch.zeros(f, 1, 1, 1).to(self.device) 145 | for i in range(len(ind)): 146 | zeros[ind[i], 0, 0, 0] = 1. 
147 | self.mask[index] = zeros # covolutional weight 148 | item.data = item.data * self.mask[index] 149 | 150 | # prune BN's parameter 151 | if index > (cov_id - 1) * param_per_cov and index <= (cov_id - 1) * param_per_cov + param_per_cov-1: 152 | # if this BN not belong to 1st conv or transition conv --> add pre-BN mask to this mask 153 | if cov_id>=2 and cov_id!=14 and cov_id!=27: 154 | self.mask[index] = torch.cat([self.mask[index-param_per_cov], torch.squeeze(zeros)], 0).to(self.device) 155 | else: 156 | self.mask[index] = torch.squeeze(zeros).to(self.device) 157 | item.data = item.data * self.mask[index] 158 | 159 | with open(resume, "wb") as f: 160 | pickle.dump(self.mask, f) 161 | 162 | def grad_mask(self, cov_id): 163 | params = self.model.parameters() 164 | for index, item in enumerate(params): 165 | if index == cov_id * self.param_per_cov: 166 | break 167 | item.data = item.data * self.mask[index].to(self.device) 168 | 169 | 170 | class mask_googlenet: 171 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 172 | self.model = model 173 | self.compress_rate = compress_rate 174 | self.mask = {} 175 | self.job_dir=job_dir 176 | self.device = device 177 | 178 | def layer_mask(self, cov_id, resume=None, param_per_cov=28, arch="googlenet"): 179 | params = self.model.parameters() 180 | prefix = "rank_conv/"+arch+"/rank_conv" 181 | subfix = ".npy" 182 | 183 | if resume: 184 | with open(resume, 'rb') as f: 185 | self.mask = pickle.load(f) 186 | else: 187 | resume=self.job_dir+'/mask' 188 | 189 | self.param_per_cov=param_per_cov 190 | 191 | for index, item in enumerate(params): 192 | 193 | if index == (cov_id-1) * param_per_cov + 4: 194 | break 195 | if (cov_id==1 and index==0)\ 196 | or index == (cov_id - 1) * param_per_cov - 24 \ 197 | or index == (cov_id - 1) * param_per_cov - 16 \ 198 | or index == (cov_id - 1) * param_per_cov - 8 \ 199 | or index == (cov_id - 1) * param_per_cov - 4 \ 200 | or index == (cov_id - 1) * param_per_cov: 201 | 202 | if index == (cov_id - 1) * param_per_cov - 24: 203 | rank = np.load(prefix + str(cov_id)+'_'+'n1x1' + subfix) 204 | elif index == (cov_id - 1) * param_per_cov - 16: 205 | rank = np.load(prefix + str(cov_id)+'_'+'n3x3' + subfix) 206 | elif index == (cov_id - 1) * param_per_cov - 8 \ 207 | or index == (cov_id - 1) * param_per_cov - 4: 208 | rank = np.load(prefix + str(cov_id)+'_'+'n5x5' + subfix) 209 | elif cov_id==1 and index==0: 210 | rank = np.load(prefix + str(cov_id) + subfix) 211 | else: 212 | rank = np.load(prefix + str(cov_id) + '_' + 'pool_planes' + subfix) 213 | 214 | f, c, w, h = item.size() 215 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 216 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 217 | 218 | zeros = torch.zeros(f, 1, 1, 1).to(self.device) 219 | for i in range(len(ind)): 220 | zeros[ind[i], 0, 0, 0] = 1. 
221 | self.mask[index] = zeros # covolutional weight 222 | item.data = item.data * self.mask[index] 223 | 224 | elif cov_id==1 and index > 0 and index <= 3: 225 | self.mask[index] = torch.squeeze(zeros) 226 | item.data = item.data * self.mask[index] 227 | 228 | elif (index>=(cov_id - 1) * param_per_cov - 20 and index< (cov_id - 1) * param_per_cov - 16) \ 229 | or (index>=(cov_id - 1) * param_per_cov - 12 and index< (cov_id - 1) * param_per_cov - 8): 230 | continue 231 | 232 | elif index > (cov_id-1)*param_per_cov-24 and index < (cov_id-1)*param_per_cov+4: 233 | self.mask[index] = torch.squeeze(zeros) 234 | item.data = item.data * self.mask[index] 235 | 236 | with open(resume, "wb") as f: 237 | pickle.dump(self.mask, f) 238 | 239 | def grad_mask(self, cov_id): 240 | params = self.model.parameters() 241 | for index, item in enumerate(params): 242 | if index == (cov_id-1) * self.param_per_cov + 4: 243 | break 244 | if index not in self.mask: 245 | continue 246 | item.data = item.data * self.mask[index].to(self.device)#prune certain weight 247 | 248 | 249 | class mask_resnet_110: 250 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 251 | self.model = model 252 | self.compress_rate = compress_rate 253 | self.mask = {} 254 | self.job_dir=job_dir 255 | self.device = device 256 | 257 | def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_110_convwise"): 258 | params = self.model.parameters() 259 | prefix = "rank_conv/"+arch+"/rank_conv" 260 | subfix = ".npy" 261 | 262 | if resume: 263 | with open(resume, 'rb') as f: 264 | self.mask = pickle.load(f) 265 | else: 266 | resume=self.job_dir+'/mask' 267 | 268 | self.param_per_cov=param_per_cov 269 | 270 | for index, item in enumerate(params): 271 | 272 | if index == cov_id*param_per_cov: 273 | break 274 | 275 | if index == (cov_id - 1) * param_per_cov: 276 | f, c, w, h = item.size() 277 | rank = np.load(prefix + str(cov_id) + subfix) 278 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 279 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 280 | 281 | zeros = torch.zeros(f, 1, 1, 1).to(self.device) 282 | for i in range(len(ind)): 283 | zeros[ind[i], 0, 0, 0] = 1. 
284 | 285 | self.mask[index] = zeros # covolutional weight 286 | item.data = item.data * self.mask[index] 287 | 288 | elif index > (cov_id-1)*param_per_cov and index < cov_id*param_per_cov: 289 | self.mask[index] = torch.squeeze(zeros) 290 | item.data = item.data * self.mask[index] 291 | 292 | with open(resume, "wb") as f: 293 | pickle.dump(self.mask, f) 294 | 295 | def grad_mask(self, cov_id): 296 | params = self.model.parameters() 297 | for index, item in enumerate(params): 298 | if index == cov_id*self.param_per_cov: 299 | break 300 | item.data = item.data * self.mask[index].to(self.device)#prune certain weight 301 | 302 | 303 | class mask_resnet_50: 304 | def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None): 305 | self.model = model 306 | self.compress_rate = compress_rate 307 | self.mask = {} 308 | self.job_dir=job_dir 309 | self.device = device 310 | 311 | def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_50_convwise"): 312 | params = self.model.parameters() 313 | prefix = "rank_conv/"+arch+"/rank_conv" 314 | subfix = ".npy" 315 | 316 | if resume: 317 | with open(resume, 'rb') as f: 318 | self.mask = pickle.load(f) 319 | else: 320 | resume=self.job_dir+'/mask' 321 | 322 | self.param_per_cov=param_per_cov 323 | 324 | for index, item in enumerate(params): 325 | 326 | if index == cov_id * param_per_cov: 327 | break 328 | 329 | if index == (cov_id-1) * param_per_cov: 330 | f, c, w, h = item.size() 331 | rank = np.load(prefix + str(cov_id) + subfix) 332 | pruned_num = int(self.compress_rate[cov_id - 1] * f) 333 | ind = np.argsort(rank)[pruned_num:] # preserved filter id 334 | zeros = torch.zeros(f, 1, 1, 1).to(self.device)#.cuda(self.device[0])#.to(self.device) 335 | for i in range(len(ind)): 336 | zeros[ind[i], 0, 0, 0] = 1. 
337 | self.mask[index] = zeros # covolutional weight 338 | item.data = item.data * self.mask[index] 339 | 340 | elif index > (cov_id-1) * param_per_cov and index < cov_id * param_per_cov: 341 | self.mask[index] = torch.squeeze(zeros) 342 | item.data = item.data * self.mask[index] 343 | 344 | with open(resume, "wb") as f: 345 | pickle.dump(self.mask, f) 346 | 347 | def grad_mask(self, cov_id): 348 | params = self.model.parameters() 349 | for index, item in enumerate(params): 350 | if index == cov_id * self.param_per_cov: 351 | break 352 | item.data = item.data * self.mask[index]#prune certain weight 353 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .densenet_cifar import * 3 | from .googlenet_cifar import * 4 | from .resnet_imagenet import * 5 | from .resnet_cifar import * 6 | -------------------------------------------------------------------------------- /models/densenet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | import numpy as np 7 | 8 | 9 | norm_mean, norm_var = 0.0, 1.0 10 | 11 | cov_cfg=[(3*i+1) for i in range(12*3+2+1)] 12 | 13 | 14 | class DenseBasicBlock(nn.Module): 15 | def __init__(self, inplanes, filters, index, expansion=1, growthRate=12, dropRate=0, compress_rate=0., tmp_name=None): 16 | super(DenseBasicBlock, self).__init__() 17 | 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.conv1 = nn.Conv2d(filters, growthRate, kernel_size=3, 21 | padding=1, bias=False) 22 | self.conv1.cp_rate = compress_rate 23 | self.conv1.tmp_name = tmp_name 24 | 25 | self.dropRate = dropRate 26 | 27 | def forward(self, x): 28 | out = self.bn1(x) 29 | out = self.relu(out) 30 | out = self.conv1(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | 39 | class Transition(nn.Module): 40 | def __init__(self, inplanes, outplanes, filters, index, compress_rate, tmp_name, last_prune_num): 41 | super(Transition, self).__init__() 42 | self.bn1 = nn.BatchNorm2d(inplanes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv1 = nn.Conv2d(filters, outplanes, kernel_size=1, 45 | bias=False) 46 | self.conv1.cp_rate = compress_rate 47 | self.conv1.tmp_name = tmp_name 48 | self.conv1.last_prune_num=last_prune_num 49 | 50 | def forward(self, x): 51 | out = self.bn1(x) 52 | out = self.relu(out) 53 | out = self.conv1(out) 54 | out = F.avg_pool2d(out, 2) 55 | return out 56 | 57 | 58 | class DenseNet(nn.Module): 59 | 60 | def __init__(self, depth=40, block=DenseBasicBlock, 61 | dropRate=0, num_classes=10, growthRate=12, compressionRate=1, filters=None, indexes=None,compress_rate=None): 62 | super(DenseNet, self).__init__() 63 | 64 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 65 | n = (depth - 4) // 3 if 'DenseBasicBlock' in str(block) else (depth - 4) // 6 66 | 67 | transition = Transition 68 | if filters == None: 69 | filters = [] 70 | start = growthRate*2 71 | for i in range(3): 72 | filters.append([start + growthRate*i for i in range(n+1)]) 73 | start = (start + growthRate*n) // compressionRate 74 | filters = [item for sub_list in filters for item in sub_list] 75 | 76 | indexes = [] 77 | for f in filters: 78 | indexes.append(np.arange(f)) 
79 | 80 | self.covcfg=cov_cfg 81 | self.compress_rate=compress_rate 82 | 83 | self.growthRate = growthRate 84 | self.dropRate = dropRate 85 | 86 | self.inplanes = growthRate * 2 87 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 88 | bias=False) 89 | self.conv1.cp_rate=compress_rate[0] 90 | self.conv1.tmp_name = 'conv1' 91 | self.last_prune_num=self.inplanes*compress_rate[0] 92 | 93 | self.dense1 = self._make_denseblock(block, n, filters[0:n], indexes[0:n], compress_rate[1:n+1],'dense1', self.last_prune_num) 94 | self.trans1 = self._make_transition(transition, compressionRate, filters[n], indexes[n], compress_rate[n+1],'trans1', self.last_prune_num) 95 | self.dense2 = self._make_denseblock(block, n, filters[n+1:2*n+1], indexes[n+1:2*n+1], compress_rate[n+2:2*n+2],'dense2', self.last_prune_num) 96 | self.trans2 = self._make_transition(transition, compressionRate, filters[2*n+1], indexes[2*n+1], compress_rate[2*n+2],'trans2', self.last_prune_num) 97 | self.dense3 = self._make_denseblock(block, n, filters[2*n+2:3*n+2], indexes[2*n+2:3*n+2], compress_rate[2*n+3:3*n+3],'dense3', self.last_prune_num) 98 | self.bn = nn.BatchNorm2d(self.inplanes) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.avgpool = nn.AvgPool2d(8) 101 | 102 | self.fc = nn.Linear(self.inplanes, num_classes) 103 | 104 | # Weight initialization 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_denseblock(self, block, blocks, filters, indexes, compress_rate, tmp_name, last_prune_num): 114 | layers = [] 115 | assert blocks == len(filters), 'Length of the filters parameter is not right.' 116 | assert blocks == len(indexes), 'Length of the indexes parameter is not right.' 
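# Each loop iteration appends one dense block whose conv takes filters[i] input
# channels and produces growthRate new ones, so self.inplanes grows by growthRate per
# block; last_prune_num accumulates how many of those new channels are pruned away.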
117 | for i in range(blocks): 118 | self.last_prune_num+=int(compress_rate[i]*self.growthRate) 119 | layers.append(block(self.inplanes, filters=filters[i], index=indexes[i], 120 | growthRate=self.growthRate, dropRate=self.dropRate, compress_rate=compress_rate[i], tmp_name=tmp_name+'_'+str(i))) 121 | self.inplanes += self.growthRate 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def _make_transition(self, transition, compressionRate, filters, index, compress_rate, tmp_name, last_prune_num): 126 | inplanes = self.inplanes 127 | outplanes = int(math.floor(self.inplanes // compressionRate)) 128 | self.inplanes = outplanes 129 | self.last_prune_num=int(compress_rate*filters) 130 | return transition(inplanes, outplanes, filters, index, compress_rate, tmp_name, last_prune_num) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | 135 | x = self.dense1(x) 136 | x = self.trans1(x) 137 | x = self.dense2(x) 138 | x = self.trans2(x) 139 | x = self.dense3(x) 140 | x = self.bn(x) 141 | x = self.relu(x) 142 | 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | 146 | x = self.fc(x) 147 | 148 | return x 149 | 150 | 151 | def densenet_40(compress_rate=None): 152 | return DenseNet(depth=40, block=DenseBasicBlock, compressionRate=1, compress_rate=compress_rate) 153 | -------------------------------------------------------------------------------- /models/googlenet_cifar.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | norm_mean, norm_var = 0.0, 1.0 8 | 9 | cov_cfg=[(22*i+2) for i in range(1+2+5+2)] 10 | 11 | 12 | class Inception(nn.Module): 13 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes, cp_rate, tmp_name): 14 | super(Inception, self).__init__() 15 | self.cp_rate=cp_rate 16 | self.tmp_name=tmp_name 17 | 18 | self.n1x1 = n1x1 19 | self.n3x3 = n3x3 20 | self.n5x5 = n5x5 21 | self.pool_planes = pool_planes 22 | 23 | # 1x1 conv branch 24 | if self.n1x1: 25 | conv1x1 = nn.Conv2d(in_planes, n1x1, kernel_size=1) 26 | conv1x1.cp_rate = self.cp_rate 27 | conv1x1.tmp_name = self.tmp_name 28 | 29 | self.branch1x1 = nn.Sequential( 30 | conv1x1, 31 | nn.BatchNorm2d(n1x1), 32 | nn.ReLU(True), 33 | ) 34 | 35 | # 1x1 conv -> 3x3 conv branch 36 | if self.n3x3: 37 | conv3x3_1=nn.Conv2d(in_planes, n3x3red, kernel_size=1) 38 | conv3x3_2=nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1) 39 | conv3x3_1.cp_rate = 0. 40 | conv3x3_1.tmp_name = self.tmp_name 41 | conv3x3_2.cp_rate = self.cp_rate 42 | conv3x3_2.tmp_name = self.tmp_name 43 | 44 | self.branch3x3 = nn.Sequential( 45 | conv3x3_1, 46 | nn.BatchNorm2d(n3x3red), 47 | nn.ReLU(True), 48 | conv3x3_2, 49 | nn.BatchNorm2d(n3x3), 50 | nn.ReLU(True), 51 | ) 52 | 53 | # 1x1 conv -> 5x5 conv branch 54 | if self.n5x5 > 0: 55 | conv5x5_1 = nn.Conv2d(in_planes, n5x5red, kernel_size=1) 56 | conv5x5_2 = nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1) 57 | conv5x5_3 = nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1) 58 | conv5x5_1.cp_rate = 0. 
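# As in the 3x3 branch above, the 1x1 reduction conv is assigned cp_rate = 0. so it
# is left unpruned, while the remaining convs of the branch inherit the Inception
# block's cp_rate. Note that the "5x5" branch is implemented as two stacked 3x3
# convolutions (conv5x5_2 and conv5x5_3) rather than a literal 5x5 kernel.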
59 | conv5x5_1.tmp_name = self.tmp_name 60 | conv5x5_2.cp_rate = self.cp_rate 61 | conv5x5_2.tmp_name = self.tmp_name 62 | conv5x5_3.cp_rate = self.cp_rate 63 | conv5x5_3.tmp_name = self.tmp_name 64 | 65 | self.branch5x5 = nn.Sequential( 66 | conv5x5_1, 67 | nn.BatchNorm2d(n5x5red), 68 | nn.ReLU(True), 69 | conv5x5_2, 70 | nn.BatchNorm2d(n5x5), 71 | nn.ReLU(True), 72 | conv5x5_3, 73 | nn.BatchNorm2d(n5x5), 74 | nn.ReLU(True), 75 | ) 76 | 77 | # 3x3 pool -> 1x1 conv branch 78 | if self.pool_planes > 0: 79 | conv_pool = nn.Conv2d(in_planes, pool_planes, kernel_size=1) 80 | conv_pool.cp_rate = self.cp_rate 81 | conv_pool.tmp_name = self.tmp_name 82 | 83 | self.branch_pool = nn.Sequential( 84 | nn.MaxPool2d(3, stride=1, padding=1), 85 | conv_pool, 86 | nn.BatchNorm2d(pool_planes), 87 | nn.ReLU(True), 88 | ) 89 | 90 | def forward(self, x): 91 | out = [] 92 | y1 = self.branch1x1(x) 93 | out.append(y1) 94 | 95 | y2 = self.branch3x3(x) 96 | out.append(y2) 97 | 98 | y3 = self.branch5x5(x) 99 | out.append(y3) 100 | 101 | y4 = self.branch_pool(x) 102 | out.append(y4) 103 | return torch.cat(out, 1) 104 | 105 | 106 | class GoogLeNet(nn.Module): 107 | def __init__(self, block=Inception, filters=None, compress_rate=None): 108 | super(GoogLeNet, self).__init__() 109 | 110 | self.covcfg=cov_cfg 111 | self.compress_rate=compress_rate 112 | 113 | conv_pre = nn.Conv2d(3, 192, kernel_size=3, padding=1) 114 | conv_pre.cp_rate=compress_rate[0] 115 | conv_pre.tmp_name='pre_layer' 116 | self.pre_layers = nn.Sequential( 117 | conv_pre, 118 | nn.BatchNorm2d(192), 119 | nn.ReLU(True), 120 | ) 121 | if filters is None: 122 | filters = [ 123 | [64, 128, 32, 32], 124 | [128, 192, 96, 64], 125 | [192, 208, 48, 64], 126 | [160, 224, 64, 64], 127 | [128, 256, 64, 64], 128 | [112, 288, 64, 64], 129 | [256, 320, 128, 128], 130 | [256, 320, 128, 128], 131 | [384, 384, 128, 128] 132 | ] 133 | 134 | self.filters=filters 135 | 136 | self.inception_a3 = block(192, filters[0][0], 96, filters[0][1], 16, filters[0][2], filters[0][3], self.compress_rate[1], 'a3') 137 | self.inception_b3 = block(sum(filters[0]), filters[1][0], 128, filters[1][1], 32, filters[1][2], filters[1][3], self.compress_rate[2], 'a4') 138 | 139 | self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1) 140 | self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1) 141 | 142 | self.inception_a4 = block(sum(filters[1]), filters[2][0], 96, filters[2][1], 16, filters[2][2], filters[2][3], self.compress_rate[3], 'a4') 143 | self.inception_b4 = block(sum(filters[2]), filters[3][0], 112, filters[3][1], 24, filters[3][2], filters[3][3], self.compress_rate[4], 'b4') 144 | self.inception_c4 = block(sum(filters[3]), filters[4][0], 128, filters[4][1], 24, filters[4][2], filters[4][3], self.compress_rate[6], 'c4') 145 | self.inception_d4 = block(sum(filters[4]), filters[5][0], 144, filters[5][1], 32, filters[5][2], filters[5][3], self.compress_rate[6], 'd4') 146 | self.inception_e4 = block(sum(filters[5]), filters[6][0], 160, filters[6][1], 32, filters[6][2], filters[6][3], self.compress_rate[7], 'e4') 147 | 148 | self.inception_a5 = block(sum(filters[6]), filters[7][0], 160, filters[7][1], 32, filters[7][2], filters[7][3], self.compress_rate[8], 'a5') 149 | self.inception_b5 = block(sum(filters[7]), filters[8][0], 192, filters[8][1], 48, filters[8][2], filters[8][3], self.compress_rate[9], 'b5') 150 | 151 | self.avgpool = nn.AvgPool2d(8, stride=1) 152 | self.linear = nn.Linear(sum(filters[-1]), 10) 153 | 154 | def forward(self, x): 155 | 156 | out = self.pre_layers(x) 157 | # 192 
x 32 x 32 158 | out = self.inception_a3(out) 159 | 160 | # 256 x 32 x 32 161 | out = self.inception_b3(out) 162 | # 480 x 32 x 32 163 | out = self.maxpool1(out) 164 | 165 | # 480 x 16 x 16 166 | out = self.inception_a4(out) 167 | 168 | # 512 x 16 x 16 169 | out = self.inception_b4(out) 170 | 171 | # 512 x 16 x 16 172 | out = self.inception_c4(out) 173 | 174 | # 512 x 16 x 16 175 | out = self.inception_d4(out) 176 | 177 | # 528 x 16 x 16 178 | out = self.inception_e4(out) 179 | # 832 x 16 x 16 180 | out = self.maxpool2(out) 181 | 182 | # 832 x 8 x 8 183 | out = self.inception_a5(out) 184 | 185 | # 832 x 8 x 8 186 | out = self.inception_b5(out) 187 | 188 | # 1024 x 8 x 8 189 | out = self.avgpool(out) 190 | out = out.view(out.size(0), -1) 191 | out = self.linear(out) 192 | 193 | return out 194 | 195 | 196 | def googlenet(compress_rate=None): 197 | return GoogLeNet(block=Inception, compress_rate=compress_rate) 198 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | norm_mean, norm_var = 0.0, 1.0 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class LambdaLayer(nn.Module): 14 | def __init__(self, lambd): 15 | super(LambdaLayer, self).__init__() 16 | self.lambd = lambd 17 | 18 | def forward(self, x): 19 | return self.lambd(x) 20 | 21 | 22 | class ResBasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, compress_rate=[0.]): 26 | super(ResBasicBlock, self).__init__() 27 | self.inplanes = inplanes 28 | self.planes = planes 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.conv1.cp_rate = compress_rate[0] 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu1 = nn.ReLU(inplace=True) 33 | 34 | self.conv2 = conv3x3(planes, planes) 35 | self.conv2.cp_rate = compress_rate[1] 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.relu2 = nn.ReLU(inplace=True) 38 | self.stride = stride 39 | self.shortcut = nn.Sequential() 40 | if stride != 1 or inplanes != planes: 41 | self.shortcut = LambdaLayer( 42 | lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0)) 43 | 44 | def forward(self, x): 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu1(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | out += self.shortcut(x) 53 | out = self.relu2(out) 54 | 55 | return out 56 | 57 | 58 | class ResNet(nn.Module): 59 | def __init__(self, block, num_layers, covcfg, compress_rate, num_classes=10): 60 | super(ResNet, self).__init__() 61 | assert (num_layers - 2) % 6 == 0, 'depth should be 6n+2' 62 | n = (num_layers - 2) // 6 63 | self.covcfg = covcfg 64 | self.compress_rate = compress_rate 65 | self.num_layers = num_layers 66 | 67 | self.inplanes = 16 68 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 69 | self.conv1.cp_rate = compress_rate[0] 70 | 71 | self.bn1 = nn.BatchNorm2d(self.inplanes) 72 | self.relu = nn.ReLU(inplace=True) 73 | 74 | self.layer1 = self._make_layer(block, 16, blocks=n, stride=1, 75 | compress_rate=compress_rate[1:2 * n + 1]) 76 | self.layer2 = self._make_layer(block, 32, blocks=n, stride=2, 77 | compress_rate=compress_rate[2 * n + 1:4 * n + 1]) 78 | self.layer3 = self._make_layer(block, 64, blocks=n,
stride=2, 79 | compress_rate=compress_rate[4 * n + 1:6 * n + 1]) 80 | self.avgpool = nn.AdaptiveAvgPool2d(1) 81 | 82 | if num_layers == 110: 83 | self.linear = nn.Linear(64 * block.expansion, num_classes) 84 | else: 85 | self.fc = nn.Linear(64 * block.expansion, num_classes) 86 | 87 | self.initialize() 88 | 89 | def initialize(self): 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_normal_(m.weight) 93 | elif isinstance(m, nn.BatchNorm2d): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def _make_layer(self, block, planes, blocks, stride, compress_rate): 98 | layers = [] 99 | 100 | layers.append(block(self.inplanes, planes, stride, compress_rate=compress_rate[0:2])) 101 | 102 | self.inplanes = planes * block.expansion 103 | for i in range(1, blocks): 104 | layers.append(block(self.inplanes, planes, compress_rate=compress_rate[2 * i:2 * i + 2])) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = self.relu(x) 113 | x = self.layer1(x) 114 | x = self.layer2(x) 115 | x = self.layer3(x) 116 | x = self.avgpool(x) 117 | x = x.view(x.size(0), -1) 118 | 119 | if self.num_layers == 110: 120 | x = self.linear(x) 121 | else: 122 | x = self.fc(x) 123 | 124 | return x 125 | 126 | 127 | def resnet_56(compress_rate=None): 128 | cov_cfg = [(3 * i + 2) for i in range(9 * 3 * 2 + 1)] 129 | return ResNet(ResBasicBlock, 56, cov_cfg, compress_rate=compress_rate) 130 | 131 | 132 | def resnet_110(compress_rate=None): 133 | cov_cfg = [(3 * i + 2) for i in range(18 * 3 * 2 + 1)] 134 | return ResNet(ResBasicBlock, 110, cov_cfg, compress_rate=compress_rate) 135 | 136 | -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | 5 | norm_mean, norm_var = 1.0, 0.1 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class ResBottleneck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None, cp_rate=[0.], tmp_name=None): 18 | super(ResBottleneck, self).__init__() 19 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 20 | self.conv1.cp_rate = cp_rate[0] 21 | self.conv1.tmp_name = tmp_name 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2.cp_rate = cp_rate[1] 29 | self.conv2.tmp_name = tmp_name 30 | self.relu2 = nn.ReLU(inplace=True) 31 | 32 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 33 | self.conv3.cp_rate = cp_rate[2] 34 | self.conv3.tmp_name = tmp_name 35 | 36 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 37 | self.relu3 = nn.ReLU(inplace=True) 38 | 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu1(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | out = self.relu2(out) 52 | 53 | out = self.conv3(out) 54 | out = self.bn3(out) 55 | 56 | if self.downsample is not None: 57 | residual = 
self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu3(out) 61 | 62 | return out 63 | 64 | 65 | class Downsample(nn.Module): 66 | def __init__(self, downsample): 67 | super(Downsample, self).__init__() 68 | self.downsample = downsample 69 | 70 | def forward(self, x): 71 | out = self.downsample(x) 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, num_classes=1000, covcfg=None, compress_rate=None): 77 | self.inplanes = 64 78 | super(ResNet, self).__init__() 79 | 80 | self.covcfg = covcfg 81 | self.compress_rate = compress_rate 82 | self.num_blocks = num_blocks 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 85 | self.conv1.cp_rate = compress_rate[0] 86 | self.conv1.tmp_name = 'conv1' 87 | self.bn1 = nn.BatchNorm2d(64) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 90 | 91 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, 92 | cp_rate=compress_rate[1:3*num_blocks[0]+2], 93 | tmp_name='layer1') 94 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, 95 | cp_rate=compress_rate[3*num_blocks[0]+2:3*num_blocks[0]+3*num_blocks[1]+3], 96 | tmp_name='layer2') 97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, 98 | cp_rate=compress_rate[3*num_blocks[0]+3*num_blocks[1]+3:3*num_blocks[0]+3*num_blocks[1]+3*num_blocks[2]+4], 99 | tmp_name='layer3') 100 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, 101 | cp_rate=compress_rate[3*num_blocks[0]+3*num_blocks[1]+3*num_blocks[2]+4:], 102 | tmp_name='layer4') 103 | 104 | self.avgpool = nn.AvgPool2d(7, stride=1) 105 | self.fc = nn.Linear(512 * block.expansion, num_classes) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 110 | elif isinstance(m, nn.BatchNorm2d): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride, cp_rate, tmp_name): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | conv_short = nn.Conv2d(self.inplanes, planes * block.expansion, 118 | kernel_size=1, stride=stride, bias=False) 119 | conv_short.cp_rate = cp_rate[0] 120 | conv_short.tmp_name = tmp_name + '_shortcut' 121 | downsample = nn.Sequential( 122 | conv_short, 123 | nn.BatchNorm2d(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append(block(self.inplanes, planes, stride, downsample, cp_rate=cp_rate[1:4], 128 | tmp_name=tmp_name + '_block' + str(1))) 129 | 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes, cp_rate=cp_rate[3 * i + 1:3 * i + 4], 133 | tmp_name=tmp_name + '_block' + str(i + 1))) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | 146 | # 256 x 56 x 56 147 | x = self.layer2(x) 148 | 149 | # 512 x 28 x 28 150 | x = self.layer3(x) 151 | 152 | # 1024 x 14 x 14 153 | x = self.layer4(x) 154 | 155 | # 2048 x 7 x 7 156 | x = self.avgpool(x) 157 | x = x.view(x.size(0), -1) 158 | x = self.fc(x) 159 | 160 | return x 161 | 162 | 163 | def resnet_50(compress_rate=None): 164 | cov_cfg = [(3*i + 3) for i in range(3*3 + 1 + 4*3 + 1 + 6*3 + 1 + 3*3 + 1 + 1)] 165 | model = 
ResNet(ResBottleneck, [3, 4, 6, 3], covcfg=cov_cfg, compress_rate=compress_rate) 166 | return model 167 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch.nn as nn 4 | from collections import OrderedDict 5 | 6 | 7 | norm_mean, norm_var = 0.0, 1.0 8 | 9 | defaultcfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 512] 10 | relucfg = [2, 6, 9, 13, 16, 19, 23, 26, 29, 33, 36, 39] 11 | convcfg = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37] 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, num_classes=10, init_weights=True, cfg=None, compress_rate=None): 16 | super(VGG, self).__init__() 17 | self.features = nn.Sequential() 18 | 19 | if cfg is None: 20 | cfg = defaultcfg 21 | 22 | self.relucfg = relucfg 23 | self.covcfg = convcfg 24 | self.compress_rate = compress_rate 25 | self.features = self.make_layers(cfg[:-1], True, compress_rate) 26 | self.classifier = nn.Sequential(OrderedDict([ 27 | ('linear1', nn.Linear(cfg[-2], cfg[-1])), 28 | ('norm1', nn.BatchNorm1d(cfg[-1])), 29 | ('relu1', nn.ReLU(inplace=True)), 30 | ('linear2', nn.Linear(cfg[-1], num_classes)), 31 | ])) 32 | 33 | if init_weights: 34 | self._initialize_weights() 35 | 36 | def make_layers(self, cfg, batch_norm=True, compress_rate=None): 37 | layers = nn.Sequential() 38 | in_channels = 3 39 | cnt = 0 40 | for i, v in enumerate(cfg): 41 | if v == 'M': 42 | layers.add_module('pool%d' % i, nn.MaxPool2d(kernel_size=2, stride=2)) 43 | else: 44 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 45 | conv2d.cp_rate = compress_rate[cnt] 46 | cnt += 1 47 | 48 | layers.add_module('conv%d' % i, conv2d) 49 | layers.add_module('norm%d' % i, nn.BatchNorm2d(v)) 50 | layers.add_module('relu%d' % i, nn.ReLU(inplace=True)) 51 | in_channels = v 52 | 53 | return layers 54 | 55 | def forward(self, x): 56 | x = self.features(x) 57 | 58 | x = nn.AvgPool2d(2)(x) 59 | x = x.view(x.size(0), -1) 60 | x = self.classifier(x) 61 | return x 62 | 63 | def _initialize_weights(self): 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 67 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 68 | if m.bias is not None: 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(0.5) 72 | m.bias.data.zero_() 73 | elif isinstance(m, nn.Linear): 74 | m.weight.data.normal_(0, 0.01) 75 | m.bias.data.zero_() 76 | 77 | 78 | def vgg_16_bn(compress_rate=None): 79 | return VGG(compress_rate=compress_rate) 80 | 81 | -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv10.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv11.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv12.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv13.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv14.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv15.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv16.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv17.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv18.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv18.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv19.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv2.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv20.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv21.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv21.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv22.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv22.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv23.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv23.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv24.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv24.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv25.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv26.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv26.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv27.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv28.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv28.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv29.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv29.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv3.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv30.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv31.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv32.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv33.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv33.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv34.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv34.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv35.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv35.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv36.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv36.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv37.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv37.npy 
-------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv38.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv38.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv39.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv39.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv4.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv5.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv6.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv7.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv8.npy -------------------------------------------------------------------------------- /rank_conv/densenet_40/rank_conv9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv9.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv10_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv10_n3x3.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv10_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv10_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv2_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv2_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv2_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv2_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv3_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv3_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv3_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv3_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_pool_planes.npy 
-------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv4_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv4_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv4_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv4_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv5_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv5_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv5_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv5_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv6_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv6_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv6_n5x5.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv6_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv7_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv7_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv7_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv7_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv8_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n1x1.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv8_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv8_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv8_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv9_n1x1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n1x1.npy 
-------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv9_n3x3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n3x3.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv9_n5x5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n5x5.npy -------------------------------------------------------------------------------- /rank_conv/googlenet/rank_conv9_pool_planes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_pool_planes.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv10.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv100.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv101.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv101.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv102.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv102.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv103.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv103.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv104.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv104.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv105.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv105.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv106.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv106.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv107.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv107.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv108.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv108.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv109.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv109.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv11.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv12.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv13.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv14.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv15.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv16.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv17.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv17.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv18.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv19.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv2.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv20.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv21.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv21.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv22.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv22.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv23.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv23.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv24.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv24.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv25.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv26.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv26.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_110/rank_conv27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv27.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv28.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv28.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv29.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv29.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv3.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv30.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv31.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv32.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv33.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv33.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv34.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv34.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv35.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv35.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv36.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv36.npy 
-------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv37.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv37.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv38.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv38.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv39.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv39.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv4.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv40.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv40.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv41.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv41.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv42.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv42.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv43.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv43.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv44.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv44.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv45.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv45.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv46.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv46.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv47.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv47.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv48.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv48.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv49.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv49.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv5.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv50.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv50.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv51.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv51.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv52.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv52.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv53.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv53.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv54.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv54.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv55.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv55.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv56.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv56.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv57.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv57.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv58.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv58.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv59.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv59.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv6.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv60.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv60.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv61.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv61.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv62.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv62.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv63.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv63.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv64.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv64.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv65.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv65.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_110/rank_conv66.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv66.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv67.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv67.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv68.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv68.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv69.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv69.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv7.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv70.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv70.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv71.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv71.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv72.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv72.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv73.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv73.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv74.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv74.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv75.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv75.npy 
-------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv76.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv76.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv77.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv77.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv78.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv78.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv79.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv79.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv8.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv80.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv80.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv81.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv81.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv82.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv82.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv83.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv83.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv84.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv84.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv85.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv85.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv86.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv86.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv87.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv87.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv88.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv88.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv89.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv89.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv9.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv90.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv90.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv91.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv91.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv92.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv92.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv93.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv93.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv94.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv94.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv95.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv95.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv96.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv96.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv97.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv97.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv98.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv98.npy -------------------------------------------------------------------------------- /rank_conv/resnet_110/rank_conv99.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv99.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv10.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv11.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv12.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv13.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv14.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_50/rank_conv15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv15.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv16.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv17.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv18.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv19.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv2.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv20.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv21.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv21.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv22.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv22.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv23.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv23.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv24.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv24.npy 
-------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv25.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv26.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv26.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv27.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv28.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv28.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv29.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv29.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv3.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv30.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv31.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv32.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv33.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv33.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv34.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv34.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv35.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv35.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv36.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv36.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv37.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv37.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv38.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv38.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv39.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv39.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv4.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv40.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv40.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv41.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv41.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv42.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv42.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv43.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv43.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv44.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv44.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv45.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv45.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv46.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv46.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv47.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv47.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv48.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv48.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv49.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv49.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv5.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv50.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv50.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv51.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv51.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv52.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv52.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv53.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv53.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_50/rank_conv6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv6.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv7.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv8.npy -------------------------------------------------------------------------------- /rank_conv/resnet_50/rank_conv9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv9.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv10.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv11.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv12.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv13.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv14.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv15.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_56/rank_conv16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv16.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv17.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv18.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv19.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv2.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv20.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv20.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv21.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv21.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv22.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv22.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv23.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv23.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv24.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv24.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv25.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv25.npy 
-------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv26.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv26.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv27.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv27.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv28.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv28.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv29.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv29.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv3.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv30.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv30.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv31.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv31.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv32.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv33.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv33.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv34.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv34.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv35.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv35.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv36.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv36.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv37.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv37.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv38.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv38.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv39.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv39.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv4.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv40.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv40.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv41.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv41.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv42.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv42.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv43.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv43.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv44.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv44.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv45.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv45.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv46.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv46.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv47.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv47.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv48.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv48.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv49.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv49.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv5.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv50.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv50.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv51.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv51.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv52.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv52.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv53.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv53.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv54.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv54.npy -------------------------------------------------------------------------------- 
/rank_conv/resnet_56/rank_conv55.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv55.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv6.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv7.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv8.npy -------------------------------------------------------------------------------- /rank_conv/resnet_56/rank_conv9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv9.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv1.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv10.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv11.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv12.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv2.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv3.npy -------------------------------------------------------------------------------- 
/rank_conv/vgg_16_bn/rank_conv4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv4.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv5.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv6.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv7.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv8.npy -------------------------------------------------------------------------------- /rank_conv/vgg_16_bn/rank_conv9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv9.npy -------------------------------------------------------------------------------- /rank_generation.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | 8 | import os 9 | import argparse 10 | 11 | import data.imagenet as imagenet 12 | from models import * 13 | from utils import progress_bar 14 | import numpy as np 15 | 16 | parser = argparse.ArgumentParser(description='Rank extraction') 17 | 18 | parser.add_argument( 19 | '--data_dir', 20 | type=str, 21 | default='./data', 22 | help='dataset path') 23 | parser.add_argument( 24 | '--dataset', 25 | type=str, 26 | default='cifar10', 27 | choices=('cifar10','imagenet'), 28 | help='dataset') 29 | parser.add_argument( 30 | '--job_dir', 31 | type=str, 32 | default='result/tmp', 33 | help='The directory where the summaries will be stored.') 34 | parser.add_argument( 35 | '--arch', 36 | type=str, 37 | default='vgg_16_bn', 38 | choices=('resnet_50','vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet'), 39 | help='The architecture to prune') 40 | parser.add_argument( 41 | '--resume', 42 | type=str, 43 | default=None, 44 | help='load the model from the specified checkpoint') 45 | parser.add_argument( 46 | '--limit', 47 | type=int, 48 | default=5, 49 | help='The num of batch to get rank.') 50 | parser.add_argument( 51 | '--train_batch_size', 52 | type=int, 53 | default=128, 54 | help='Batch size for training.') 55 | parser.add_argument( 56 | '--eval_batch_size', 57 | type=int, 58 | default=100, 59 | help='Batch size for validation.') 60 | parser.add_argument( 
61 | '--start_idx', 62 | type=int, 63 | default=0, 64 | help='The index of the conv layer from which to start extracting ranks.') 65 | parser.add_argument( 66 | '--gpu', 67 | type=str, 68 | default='0', 69 | help='Select gpu to use') 70 | parser.add_argument( 71 | '--adjust_ckpt', 72 | action='store_true', 73 | help='adjust ckpt from pruned checkpoint') 74 | parser.add_argument( 75 | '--compress_rate', 76 | type=str, 77 | default=None, 78 | help='compress rate of each conv') 79 | 80 | 81 | args = parser.parse_args() 82 | 83 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 84 | # GPU visibility is taken from the --gpu flag above 85 | cudnn.benchmark = True 86 | 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | 89 | # Data 90 | print('==> Preparing data..') 91 | if args.dataset=='cifar10': 92 | transform_train = transforms.Compose([ 93 | transforms.RandomCrop(32, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 97 | ]) 98 | 99 | transform_test = transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 102 | ]) 103 | trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train) 104 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2) 105 | 106 | testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test) 107 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 108 | elif args.dataset=='imagenet': 109 | data_tmp = imagenet.Data(args) 110 | trainloader = data_tmp.loader_train 111 | testloader = data_tmp.loader_test 112 | 113 | if args.compress_rate: 114 | import re 115 | cprate_str=args.compress_rate 116 | cprate_str_list=cprate_str.split('+') 117 | pat_cprate = re.compile(r'\d+\.\d*') 118 | pat_num = re.compile(r'\*\d+') 119 | cprate=[] 120 | for x in cprate_str_list: 121 | num=1 122 | find_num=re.findall(pat_num,x) 123 | if find_num: 124 | assert len(find_num) == 1 125 | num=int(find_num[0].replace('*','')) 126 | find_cprate = re.findall(pat_cprate, x) 127 | assert len(find_cprate)==1 128 | cprate+=[float(find_cprate[0])]*num 129 | 130 | compress_rate=cprate 131 | else: 132 | default_cprate={ 133 | 'vgg_16_bn': [0.7]*7+[0.1]*6, 134 | 'densenet_40': [0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*5+[0.0], 135 | 'googlenet': [0.10]+[0.7]+[0.5]+[0.8]*4+[0.5]+[0.6]*2, 136 | 'resnet_50':[0.2]+[0.8]*10+[0.8]*13+[0.55]*19+[0.45]*10, 137 | 'resnet_56':[0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4], 138 | 'resnet_110':[0.1]+[0.40]*36+[0.40]*36+[0.4]*36 139 | } 140 | compress_rate=default_cprate[args.arch] 141 | 142 | # Model 143 | print('==> Building model..') 144 | print(compress_rate) 145 | net = eval(args.arch)(compress_rate=compress_rate) 146 | net = net.to(device) 147 | 148 | if len(args.gpu)>1 and torch.cuda.is_available(): 149 | device_id = [] 150 | for i in range((len(args.gpu) + 1) // 2): 151 | device_id.append(i) 152 | net = torch.nn.DataParallel(net, device_ids=device_id) 153 | 154 | 155 | if args.resume: 156 | # Load checkpoint. 
157 | print('==> Resuming from checkpoint..') 158 | checkpoint = torch.load(args.resume, map_location='cuda:'+args.gpu) 159 | from collections import OrderedDict 160 | new_state_dict = OrderedDict() 161 | if args.adjust_ckpt: 162 | for k, v in checkpoint.items(): 163 | new_state_dict[k.replace('module.', '')] = v 164 | else: 165 | for k, v in checkpoint['state_dict'].items(): 166 | new_state_dict[k.replace('module.', '')] = v 167 | net.load_state_dict(new_state_dict) 168 | 169 | 170 | criterion = nn.CrossEntropyLoss() 171 | feature_result = torch.tensor(0.) 172 | total = torch.tensor(0.) 173 | def get_feature_hook(self, input, output): 174 | global feature_result 175 | global entropy 176 | global total 177 | a = output.shape[0] 178 | b = output.shape[1] 179 | c = torch.tensor([torch.matrix_rank(output[i,j,:,:]).item() for i in range(a) for j in range(b)]) 180 | 181 | c = c.view(a, -1).float() 182 | c = c.sum(0) 183 | feature_result = feature_result * total + c 184 | total = total + a 185 | feature_result = feature_result / total 186 | 187 | def get_feature_hook_densenet(self, input, output): 188 | global feature_result 189 | global total 190 | a = output.shape[0] 191 | b = output.shape[1] 192 | c = torch.tensor([torch.matrix_rank(output[i,j,:,:]).item() for i in range(a) for j in range(b-12,b)]) 193 | 194 | c = c.view(a, -1).float() 195 | c = c.sum(0) 196 | feature_result = feature_result * total + c 197 | total = total + a 198 | feature_result = feature_result / total 199 | 200 | def get_feature_hook_googlenet(self, input, output): 201 | global feature_result 202 | global total 203 | a = output.shape[0] 204 | b = output.shape[1] 205 | c = torch.tensor([torch.matrix_rank(output[i,j,:,:]).item() for i in range(a) for j in range(b-12,b)]) 206 | 207 | c = c.view(a, -1).float() 208 | c = c.sum(0) 209 | feature_result = feature_result * total + c 210 | total = total + a 211 | feature_result = feature_result / total 212 | 213 | 214 | def test(): 215 | global best_acc 216 | net.eval() 217 | test_loss = 0 218 | correct = 0 219 | total = 0 220 | limit = args.limit 221 | 222 | with torch.no_grad(): 223 | for batch_idx, (inputs, targets) in enumerate(trainloader): 224 | if batch_idx >= limit: # use the first 6 batches to estimate the rank. 225 | break 226 | inputs, targets = inputs.to(device), targets.to(device) 227 | outputs = net(inputs) 228 | loss = criterion(outputs, targets) 229 | 230 | test_loss += loss.item() 231 | _, predicted = outputs.max(1) 232 | total += targets.size(0) 233 | correct += predicted.eq(targets).sum().item() 234 | 235 | progress_bar(batch_idx, limit, 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 236 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))#''' 237 | 238 | 239 | if args.arch=='vgg_16_bn': 240 | 241 | if len(args.gpu) > 1: 242 | relucfg = net.module.relucfg 243 | else: 244 | relucfg = net.relucfg 245 | 246 | for i, cov_id in enumerate(relucfg): 247 | cov_layer = net.features[cov_id] 248 | handler = cov_layer.register_forward_hook(get_feature_hook) 249 | test() 250 | handler.remove() 251 | 252 | if not os.path.isdir('rank_conv/'+args.arch+'_limit%d'%(args.limit)): 253 | os.mkdir('rank_conv/'+args.arch+'_limit%d'%(args.limit)) 254 | np.save('rank_conv/'+args.arch+'_limit%d'%(args.limit)+'/rank_conv' + str(i + 1) + '.npy', feature_result.numpy()) 255 | 256 | feature_result = torch.tensor(0.) 257 | total = torch.tensor(0.) 
258 | 259 | elif args.arch=='resnet_56': 260 | 261 | cov_layer = eval('net.relu') 262 | handler = cov_layer.register_forward_hook(get_feature_hook) 263 | test() 264 | handler.remove() 265 | 266 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)): 267 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)) 268 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit)+ '/rank_conv%d' % (1) + '.npy', feature_result.numpy()) 269 | feature_result = torch.tensor(0.) 270 | total = torch.tensor(0.) 271 | 272 | # ResNet56 per block 273 | cnt=1 274 | for i in range(3): 275 | block = eval('net.layer%d' % (i + 1)) 276 | for j in range(9): 277 | cov_layer = block[j].relu1 278 | handler = cov_layer.register_forward_hook(get_feature_hook) 279 | test() 280 | handler.remove() 281 | np.save('rank_conv/' + args.arch +'_limit%d'%(args.limit)+ '/rank_conv%d'%(cnt + 1)+'.npy', feature_result.numpy()) 282 | cnt+=1 283 | feature_result = torch.tensor(0.) 284 | total = torch.tensor(0.) 285 | 286 | cov_layer = block[j].relu2 287 | handler = cov_layer.register_forward_hook(get_feature_hook) 288 | test() 289 | handler.remove() 290 | np.save('rank_conv/' + args.arch +'_limit%d'%(args.limit)+ '/rank_conv%d'%(cnt + 1)+'.npy', feature_result.numpy()) 291 | cnt += 1 292 | feature_result = torch.tensor(0.) 293 | total = torch.tensor(0.) 294 | 295 | elif args.arch=='densenet_40': 296 | 297 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)): 298 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)) 299 | 300 | feature_result = torch.tensor(0.) 301 | total = torch.tensor(0.) 302 | 303 | # Densenet per block & transition 304 | for i in range(3): 305 | dense = eval('net.dense%d' % (i + 1)) 306 | for j in range(12): 307 | cov_layer = dense[j].relu 308 | if j==0: 309 | handler = cov_layer.register_forward_hook(get_feature_hook) 310 | else: 311 | handler = cov_layer.register_forward_hook(get_feature_hook_densenet) 312 | test() 313 | handler.remove() 314 | 315 | np.save('rank_conv/' + args.arch +'_limit%d'%(args.limit) + '/rank_conv%d'%(13*i+j+1)+'.npy', feature_result.numpy()) 316 | feature_result = torch.tensor(0.) 317 | total = torch.tensor(0.) 318 | 319 | if i<2: 320 | trans=eval('net.trans%d' % (i + 1)) 321 | cov_layer = trans.relu 322 | 323 | handler = cov_layer.register_forward_hook(get_feature_hook_densenet) 324 | test() 325 | handler.remove() 326 | 327 | np.save('rank_conv/' + args.arch +'_limit%d'%(args.limit) + '/rank_conv%d' % (13 * (i+1)) + '.npy', feature_result.numpy()) 328 | feature_result = torch.tensor(0.) 329 | total = torch.tensor(0.)#''' 330 | 331 | cov_layer = net.relu 332 | handler = cov_layer.register_forward_hook(get_feature_hook_densenet) 333 | test() 334 | handler.remove() 335 | np.save('rank_conv/' + args.arch +'_limit%d'%(args.limit) + '/rank_conv%d' % (39) + '.npy', feature_result.numpy()) 336 | feature_result = torch.tensor(0.) 337 | total = torch.tensor(0.) 338 | 339 | elif args.arch=='googlenet': 340 | 341 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)): 342 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)) 343 | feature_result = torch.tensor(0.) 344 | total = torch.tensor(0.) 
345 | 346 | cov_list=['pre_layers', 347 | 'inception_a3', 348 | 'maxpool1', 349 | 'inception_a4', 350 | 'inception_b4', 351 | 'inception_c4', 352 | 'inception_d4', 353 | 'maxpool2', 354 | 'inception_a5', 355 | 'inception_b5', 356 | ] 357 | 358 | # branch type 359 | tp_list=['n1x1','n3x3','n5x5','pool_planes'] 360 | for idx, cov in enumerate(cov_list): 361 | 362 | cov_layer = eval('net.' + cov) 363 | 364 | handler = cov_layer.register_forward_hook(get_feature_hook) 365 | test() 366 | handler.remove() 367 | 368 | # idx==0 is pre_layers (a single conv); inception outputs are saved per branch 369 | 370 | if idx>0: 371 | for idx1,tp in enumerate(tp_list): 372 | if idx1==3: 373 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d_'%(idx+1)+tp+'.npy', 374 | feature_result[sum(net.filters[idx-1][:-1]) : sum(net.filters[idx-1][:])].numpy()) 375 | #elif idx1==0: 376 | # np.save('rank_conv1/' + args.arch + '/rank_conv%d_'%(idx+1)+tp+'.npy', 377 | # feature_result[0 : sum(net.filters[idx-1][:1])].numpy()) 378 | else: 379 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d_' % (idx + 1) + tp + '.npy', 380 | feature_result[sum(net.filters[idx-1][:idx1]) : sum(net.filters[idx-1][:idx1+1])].numpy()) 381 | else: 382 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d' % (idx + 1) + '.npy',feature_result.numpy()) 383 | feature_result = torch.tensor(0.) 384 | total = torch.tensor(0.) 385 | 386 | elif args.arch=='resnet_110': 387 | 388 | cov_layer = eval('net.relu') 389 | handler = cov_layer.register_forward_hook(get_feature_hook) 390 | test() 391 | handler.remove() 392 | 393 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)): 394 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)) 395 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d' % (1) + '.npy', feature_result.numpy()) 396 | feature_result = torch.tensor(0.) 397 | total = torch.tensor(0.) 398 | 399 | cnt = 1 400 | # ResNet110 per block 401 | for i in range(3): 402 | block = eval('net.layer%d' % (i + 1)) 403 | for j in range(18): 404 | cov_layer = block[j].relu1 405 | handler = cov_layer.register_forward_hook(get_feature_hook) 406 | test() 407 | handler.remove() 408 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % ( 409 | cnt + 1) + '.npy', feature_result.numpy()) 410 | cnt += 1 411 | feature_result = torch.tensor(0.) 412 | total = torch.tensor(0.) 413 | 414 | cov_layer = block[j].relu2 415 | handler = cov_layer.register_forward_hook(get_feature_hook) 416 | test() 417 | handler.remove() 418 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % ( 419 | cnt + 1) + '.npy', feature_result.numpy()) 420 | cnt += 1 421 | feature_result = torch.tensor(0.) 422 | total = torch.tensor(0.) 423 | 424 | elif args.arch=='resnet_50': 425 | 426 | cov_layer = eval('net.maxpool') 427 | handler = cov_layer.register_forward_hook(get_feature_hook) 428 | test() 429 | handler.remove() 430 | 431 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)): 432 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)) 433 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d' % (1) + '.npy', feature_result.numpy()) 434 | feature_result = torch.tensor(0.) 435 | total = torch.tensor(0.) 
436 | 437 | # ResNet50 per bottleneck 438 | cnt=1 439 | for i in range(4): 440 | block = eval('net.layer%d' % (i + 1)) 441 | for j in range(net.num_blocks[i]): 442 | cov_layer = block[j].relu1 443 | handler = cov_layer.register_forward_hook(get_feature_hook) 444 | test() 445 | handler.remove() 446 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d'%(cnt+1)+'.npy', feature_result.numpy()) 447 | cnt+=1 448 | feature_result = torch.tensor(0.) 449 | total = torch.tensor(0.) 450 | 451 | cov_layer = block[j].relu2 452 | handler = cov_layer.register_forward_hook(get_feature_hook) 453 | test() 454 | handler.remove() 455 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy', 456 | feature_result.numpy()) 457 | cnt += 1 458 | feature_result = torch.tensor(0.) 459 | total = torch.tensor(0.) 460 | 461 | cov_layer = block[j].relu3 462 | handler = cov_layer.register_forward_hook(get_feature_hook) 463 | test() 464 | handler.remove() 465 | if j==0: 466 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy', 467 | feature_result.numpy())#shortcut conv 468 | cnt += 1 469 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy', 470 | feature_result.numpy())#conv3 471 | cnt += 1 472 | feature_result = torch.tensor(0.) 473 | total = torch.tensor(0.) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from __future__ import absolute_import 4 | import os 5 | import sys 6 | import time 7 | import logging 8 | import datetime 9 | import torch 10 | from pathlib import Path 11 | 12 | 13 | def get_logger(file_path): 14 | """ Make python logger """ 15 | # [!] Since tensorboardX use default logger (e.g. 
logging.info()), we should use custom logger 16 | logger = logging.getLogger('kd') 17 | log_format = '%(asctime)s | %(message)s' 18 | formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p') 19 | file_handler = logging.FileHandler(file_path) 20 | file_handler.setFormatter(formatter) 21 | stream_handler = logging.StreamHandler() 22 | stream_handler.setFormatter(formatter) 23 | 24 | logger.addHandler(file_handler) 25 | logger.addHandler(stream_handler) 26 | logger.setLevel(logging.INFO) 27 | 28 | return logger 29 | 30 | 31 | class AverageMeter(object): 32 | """Computes and stores the average and current value""" 33 | def __init__(self): 34 | self.reset() 35 | 36 | def reset(self): 37 | self.val = 0.0 38 | self.avg = 0.0 39 | self.sum = 0.0 40 | self.count = 0 41 | 42 | def update(self, val, n=1): 43 | self.val = val 44 | self.sum += val * n 45 | self.count += n 46 | self.avg = self.sum / self.count 47 | 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | """Computes the precision@k for the specified values of k""" 51 | with torch.no_grad(): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 62 | res.append(correct_k.mul_(100.0 / batch_size)) 63 | return res 64 | 65 | 66 | class checkpoint(): 67 | def __init__(self, args): 68 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 69 | today = datetime.date.today() 70 | 71 | self.args = args 72 | self.job_dir = Path(args.job_dir) 73 | self.run_dir = self.job_dir / 'run' 74 | print(args.job_dir) 75 | 76 | def _make_dir(path): 77 | if not os.path.exists(path): os.makedirs(path) 78 | 79 | _make_dir(self.job_dir) 80 | _make_dir(self.run_dir) 81 | 82 | config_dir = self.job_dir / 'config.txt' 83 | with open(config_dir, 'w') as f: 84 | f.write(now + '\n\n') 85 | for arg in vars(args): 86 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 87 | f.write('\n') 88 | 89 | 90 | def print_params(config, prtf=print): 91 | prtf("") 92 | prtf("Parameters:") 93 | for attr, value in sorted(config.items()): 94 | prtf("{}={}".format(attr.upper(), value)) 95 | prtf("") 96 | 97 | 98 | _, term_width = os.popen('stty size', 'r').read().split() 99 | term_width = int(term_width) 100 | 101 | TOTAL_BAR_LENGTH = 65. 102 | last_time = time.time() 103 | begin_time = last_time 104 | def progress_bar(current, total, msg=None): 105 | global last_time, begin_time 106 | if current == 0: 107 | begin_time = time.time() # Reset for new bar. 108 | 109 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 110 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 111 | 112 | sys.stdout.write(' [') 113 | for i in range(cur_len): 114 | sys.stdout.write('=') 115 | sys.stdout.write('>') 116 | for i in range(rest_len): 117 | sys.stdout.write('.') 118 | sys.stdout.write(']') 119 | 120 | cur_time = time.time() 121 | step_time = cur_time - last_time 122 | last_time = cur_time 123 | tot_time = cur_time - begin_time 124 | 125 | L = [] 126 | L.append(' Step: %s' % format_time(step_time)) 127 | L.append(' | Tot: %s' % format_time(tot_time)) 128 | if msg: 129 | L.append(' | ' + msg) 130 | 131 | msg = ''.join(L) 132 | sys.stdout.write(msg) 133 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 134 | sys.stdout.write(' ') 135 | 136 | # Go back to the center of the bar. 
137 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 138 | sys.stdout.write('\b') 139 | sys.stdout.write(' %d/%d ' % (current+1, total)) 140 | 141 | if current < total-1: 142 | sys.stdout.write('\r') 143 | else: 144 | sys.stdout.write('\n') 145 | sys.stdout.flush() 146 | 147 | 148 | def format_time(seconds): 149 | days = int(seconds / 3600/24) 150 | seconds = seconds - days*3600*24 151 | hours = int(seconds / 3600) 152 | seconds = seconds - hours*3600 153 | minutes = int(seconds / 60) 154 | seconds = seconds - minutes*60 155 | secondsf = int(seconds) 156 | seconds = seconds - secondsf 157 | millis = int(seconds*1000) 158 | 159 | f = '' 160 | i = 1 161 | if days > 0: 162 | f += str(days) + 'D' 163 | i += 1 164 | if hours > 0 and i <= 2: 165 | f += str(hours) + 'h' 166 | i += 1 167 | if minutes > 0 and i <= 2: 168 | f += str(minutes) + 'm' 169 | i += 1 170 | if secondsf > 0 and i <= 2: 171 | f += str(secondsf) + 's' 172 | i += 1 173 | if millis > 0 and i <= 2: 174 | f += str(millis) + 'ms' 175 | i += 1 176 | if f == '': 177 | f = '0ms' 178 | return f 179 | --------------------------------------------------------------------------------
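
A minimal sketch of how the saved rank arrays can be read, assuming only numpy and one of the files listed above. Each rank_conv*.npy produced by rank_generation.py holds one value per output filter of the hooked layer: the batch-averaged matrix rank of that filter's feature map, accumulated by get_feature_hook into feature_result. Note that rank_generation.py writes to rank_conv/<arch>_limit<limit>/ while the files listed above sit directly under rank_conv/<arch>/, so the path may need adjusting. The file name and the 0.7 rate below are example values (the first vgg_16_bn entry in default_cprate), not fixed inputs.

import numpy as np

# Example path; rank_generation.py itself writes to
# 'rank_conv/<arch>_limit<limit>/rank_convN.npy'.
rank = np.load('rank_conv/vgg_16_bn/rank_conv1.npy')   # shape: (num_output_filters,)
compress_rate = 0.7                                    # fraction of filters pruned in this layer

num_pruned = int(rank.size * compress_rate)
num_keep = rank.size - num_pruned

# HRank's criterion: keep the filters whose feature maps have the highest average rank.
keep_idx = np.sort(np.argsort(rank)[::-1][:num_keep])
prune_idx = np.setdiff1d(np.arange(rank.size), keep_idx)

print('keeping %d of %d filters' % (num_keep, rank.size))
print('pruned filter indices:', prune_idx.tolist())

The repository's own pruning code reads these arrays back when building per-layer masks; the snippet above only illustrates the stored format and the keep/prune split implied by a compress rate.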