├── .gitignore
├── README.md
├── cal_flops_params.py
├── data
├── __init__.py
├── cifar10.py
└── imagenet.py
├── evaluate.py
├── get_flops.py
├── img
├── framework.jpeg
└── framework.png
├── main.py
├── mask.py
├── models
├── __init__.py
├── densenet_cifar.py
├── googlenet_cifar.py
├── resnet_cifar.py
├── resnet_imagenet.py
└── vgg.py
├── rank_conv
├── densenet_40
│ ├── rank_conv1.npy
│ ├── rank_conv10.npy
│ ├── rank_conv11.npy
│ ├── rank_conv12.npy
│ ├── rank_conv13.npy
│ ├── rank_conv14.npy
│ ├── rank_conv15.npy
│ ├── rank_conv16.npy
│ ├── rank_conv17.npy
│ ├── rank_conv18.npy
│ ├── rank_conv19.npy
│ ├── rank_conv2.npy
│ ├── rank_conv20.npy
│ ├── rank_conv21.npy
│ ├── rank_conv22.npy
│ ├── rank_conv23.npy
│ ├── rank_conv24.npy
│ ├── rank_conv25.npy
│ ├── rank_conv26.npy
│ ├── rank_conv27.npy
│ ├── rank_conv28.npy
│ ├── rank_conv29.npy
│ ├── rank_conv3.npy
│ ├── rank_conv30.npy
│ ├── rank_conv31.npy
│ ├── rank_conv32.npy
│ ├── rank_conv33.npy
│ ├── rank_conv34.npy
│ ├── rank_conv35.npy
│ ├── rank_conv36.npy
│ ├── rank_conv37.npy
│ ├── rank_conv38.npy
│ ├── rank_conv39.npy
│ ├── rank_conv4.npy
│ ├── rank_conv5.npy
│ ├── rank_conv6.npy
│ ├── rank_conv7.npy
│ ├── rank_conv8.npy
│ └── rank_conv9.npy
├── googlenet
│ ├── rank_conv1.npy
│ ├── rank_conv10_n1x1.npy
│ ├── rank_conv10_n3x3.npy
│ ├── rank_conv10_n5x5.npy
│ ├── rank_conv10_pool_planes.npy
│ ├── rank_conv2_n1x1.npy
│ ├── rank_conv2_n3x3.npy
│ ├── rank_conv2_n5x5.npy
│ ├── rank_conv2_pool_planes.npy
│ ├── rank_conv3_n1x1.npy
│ ├── rank_conv3_n3x3.npy
│ ├── rank_conv3_n5x5.npy
│ ├── rank_conv3_pool_planes.npy
│ ├── rank_conv4_n1x1.npy
│ ├── rank_conv4_n3x3.npy
│ ├── rank_conv4_n5x5.npy
│ ├── rank_conv4_pool_planes.npy
│ ├── rank_conv5_n1x1.npy
│ ├── rank_conv5_n3x3.npy
│ ├── rank_conv5_n5x5.npy
│ ├── rank_conv5_pool_planes.npy
│ ├── rank_conv6_n1x1.npy
│ ├── rank_conv6_n3x3.npy
│ ├── rank_conv6_n5x5.npy
│ ├── rank_conv6_pool_planes.npy
│ ├── rank_conv7_n1x1.npy
│ ├── rank_conv7_n3x3.npy
│ ├── rank_conv7_n5x5.npy
│ ├── rank_conv7_pool_planes.npy
│ ├── rank_conv8_n1x1.npy
│ ├── rank_conv8_n3x3.npy
│ ├── rank_conv8_n5x5.npy
│ ├── rank_conv8_pool_planes.npy
│ ├── rank_conv9_n1x1.npy
│ ├── rank_conv9_n3x3.npy
│ ├── rank_conv9_n5x5.npy
│ └── rank_conv9_pool_planes.npy
├── resnet_110
│ ├── rank_conv1.npy
│ ├── rank_conv10.npy
│ ├── rank_conv100.npy
│ ├── rank_conv101.npy
│ ├── rank_conv102.npy
│ ├── rank_conv103.npy
│ ├── rank_conv104.npy
│ ├── rank_conv105.npy
│ ├── rank_conv106.npy
│ ├── rank_conv107.npy
│ ├── rank_conv108.npy
│ ├── rank_conv109.npy
│ ├── rank_conv11.npy
│ ├── rank_conv12.npy
│ ├── rank_conv13.npy
│ ├── rank_conv14.npy
│ ├── rank_conv15.npy
│ ├── rank_conv16.npy
│ ├── rank_conv17.npy
│ ├── rank_conv18.npy
│ ├── rank_conv19.npy
│ ├── rank_conv2.npy
│ ├── rank_conv20.npy
│ ├── rank_conv21.npy
│ ├── rank_conv22.npy
│ ├── rank_conv23.npy
│ ├── rank_conv24.npy
│ ├── rank_conv25.npy
│ ├── rank_conv26.npy
│ ├── rank_conv27.npy
│ ├── rank_conv28.npy
│ ├── rank_conv29.npy
│ ├── rank_conv3.npy
│ ├── rank_conv30.npy
│ ├── rank_conv31.npy
│ ├── rank_conv32.npy
│ ├── rank_conv33.npy
│ ├── rank_conv34.npy
│ ├── rank_conv35.npy
│ ├── rank_conv36.npy
│ ├── rank_conv37.npy
│ ├── rank_conv38.npy
│ ├── rank_conv39.npy
│ ├── rank_conv4.npy
│ ├── rank_conv40.npy
│ ├── rank_conv41.npy
│ ├── rank_conv42.npy
│ ├── rank_conv43.npy
│ ├── rank_conv44.npy
│ ├── rank_conv45.npy
│ ├── rank_conv46.npy
│ ├── rank_conv47.npy
│ ├── rank_conv48.npy
│ ├── rank_conv49.npy
│ ├── rank_conv5.npy
│ ├── rank_conv50.npy
│ ├── rank_conv51.npy
│ ├── rank_conv52.npy
│ ├── rank_conv53.npy
│ ├── rank_conv54.npy
│ ├── rank_conv55.npy
│ ├── rank_conv56.npy
│ ├── rank_conv57.npy
│ ├── rank_conv58.npy
│ ├── rank_conv59.npy
│ ├── rank_conv6.npy
│ ├── rank_conv60.npy
│ ├── rank_conv61.npy
│ ├── rank_conv62.npy
│ ├── rank_conv63.npy
│ ├── rank_conv64.npy
│ ├── rank_conv65.npy
│ ├── rank_conv66.npy
│ ├── rank_conv67.npy
│ ├── rank_conv68.npy
│ ├── rank_conv69.npy
│ ├── rank_conv7.npy
│ ├── rank_conv70.npy
│ ├── rank_conv71.npy
│ ├── rank_conv72.npy
│ ├── rank_conv73.npy
│ ├── rank_conv74.npy
│ ├── rank_conv75.npy
│ ├── rank_conv76.npy
│ ├── rank_conv77.npy
│ ├── rank_conv78.npy
│ ├── rank_conv79.npy
│ ├── rank_conv8.npy
│ ├── rank_conv80.npy
│ ├── rank_conv81.npy
│ ├── rank_conv82.npy
│ ├── rank_conv83.npy
│ ├── rank_conv84.npy
│ ├── rank_conv85.npy
│ ├── rank_conv86.npy
│ ├── rank_conv87.npy
│ ├── rank_conv88.npy
│ ├── rank_conv89.npy
│ ├── rank_conv9.npy
│ ├── rank_conv90.npy
│ ├── rank_conv91.npy
│ ├── rank_conv92.npy
│ ├── rank_conv93.npy
│ ├── rank_conv94.npy
│ ├── rank_conv95.npy
│ ├── rank_conv96.npy
│ ├── rank_conv97.npy
│ ├── rank_conv98.npy
│ └── rank_conv99.npy
├── resnet_50
│ ├── rank_conv1.npy
│ ├── rank_conv10.npy
│ ├── rank_conv11.npy
│ ├── rank_conv12.npy
│ ├── rank_conv13.npy
│ ├── rank_conv14.npy
│ ├── rank_conv15.npy
│ ├── rank_conv16.npy
│ ├── rank_conv17.npy
│ ├── rank_conv18.npy
│ ├── rank_conv19.npy
│ ├── rank_conv2.npy
│ ├── rank_conv20.npy
│ ├── rank_conv21.npy
│ ├── rank_conv22.npy
│ ├── rank_conv23.npy
│ ├── rank_conv24.npy
│ ├── rank_conv25.npy
│ ├── rank_conv26.npy
│ ├── rank_conv27.npy
│ ├── rank_conv28.npy
│ ├── rank_conv29.npy
│ ├── rank_conv3.npy
│ ├── rank_conv30.npy
│ ├── rank_conv31.npy
│ ├── rank_conv32.npy
│ ├── rank_conv33.npy
│ ├── rank_conv34.npy
│ ├── rank_conv35.npy
│ ├── rank_conv36.npy
│ ├── rank_conv37.npy
│ ├── rank_conv38.npy
│ ├── rank_conv39.npy
│ ├── rank_conv4.npy
│ ├── rank_conv40.npy
│ ├── rank_conv41.npy
│ ├── rank_conv42.npy
│ ├── rank_conv43.npy
│ ├── rank_conv44.npy
│ ├── rank_conv45.npy
│ ├── rank_conv46.npy
│ ├── rank_conv47.npy
│ ├── rank_conv48.npy
│ ├── rank_conv49.npy
│ ├── rank_conv5.npy
│ ├── rank_conv50.npy
│ ├── rank_conv51.npy
│ ├── rank_conv52.npy
│ ├── rank_conv53.npy
│ ├── rank_conv6.npy
│ ├── rank_conv7.npy
│ ├── rank_conv8.npy
│ └── rank_conv9.npy
├── resnet_56
│ ├── rank_conv1.npy
│ ├── rank_conv10.npy
│ ├── rank_conv11.npy
│ ├── rank_conv12.npy
│ ├── rank_conv13.npy
│ ├── rank_conv14.npy
│ ├── rank_conv15.npy
│ ├── rank_conv16.npy
│ ├── rank_conv17.npy
│ ├── rank_conv18.npy
│ ├── rank_conv19.npy
│ ├── rank_conv2.npy
│ ├── rank_conv20.npy
│ ├── rank_conv21.npy
│ ├── rank_conv22.npy
│ ├── rank_conv23.npy
│ ├── rank_conv24.npy
│ ├── rank_conv25.npy
│ ├── rank_conv26.npy
│ ├── rank_conv27.npy
│ ├── rank_conv28.npy
│ ├── rank_conv29.npy
│ ├── rank_conv3.npy
│ ├── rank_conv30.npy
│ ├── rank_conv31.npy
│ ├── rank_conv32.npy
│ ├── rank_conv33.npy
│ ├── rank_conv34.npy
│ ├── rank_conv35.npy
│ ├── rank_conv36.npy
│ ├── rank_conv37.npy
│ ├── rank_conv38.npy
│ ├── rank_conv39.npy
│ ├── rank_conv4.npy
│ ├── rank_conv40.npy
│ ├── rank_conv41.npy
│ ├── rank_conv42.npy
│ ├── rank_conv43.npy
│ ├── rank_conv44.npy
│ ├── rank_conv45.npy
│ ├── rank_conv46.npy
│ ├── rank_conv47.npy
│ ├── rank_conv48.npy
│ ├── rank_conv49.npy
│ ├── rank_conv5.npy
│ ├── rank_conv50.npy
│ ├── rank_conv51.npy
│ ├── rank_conv52.npy
│ ├── rank_conv53.npy
│ ├── rank_conv54.npy
│ ├── rank_conv55.npy
│ ├── rank_conv6.npy
│ ├── rank_conv7.npy
│ ├── rank_conv8.npy
│ └── rank_conv9.npy
└── vgg_16_bn
│ ├── rank_conv1.npy
│ ├── rank_conv10.npy
│ ├── rank_conv11.npy
│ ├── rank_conv12.npy
│ ├── rank_conv2.npy
│ ├── rank_conv3.npy
│ ├── rank_conv4.npy
│ ├── rank_conv5.npy
│ ├── rank_conv6.npy
│ ├── rank_conv7.npy
│ ├── rank_conv8.npy
│ └── rank_conv9.npy
├── rank_generation.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */.DS_Store
3 | .idea
4 | __pycache__
5 | command_cifar.md
6 | models/__pycache__
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HRank: Filter Pruning using High-Rank Feature Map ([Link](https://arxiv.org/abs/2002.10179)).
2 |
3 | Pytorch implementation of HRank (CVPR 2020, Oral).
4 |
5 | This repository contains the code we used while preparing our CVPR 2020 paper. Enjoy!
6 |
7 | A better version of HRank (in both fine-tuning efficiency and accuracy) is released at [HRankPlus](https://github.com/lmbxmu/HRankPlus). We highly recommend it!
8 |
9 |
10 |

11 |
12 | Framework of HRank. In the left column, we first use images to run through the convolutional layers to get the feature maps. In the middle column, we then estimate the rank of each feature map, which is used as the criteria for pruning. The right column shows the pruning (the red filters), and fine-tuning where the green filters are updated and the blue filters are frozen.
13 |
14 |
15 | ## Tips
16 |
17 | If you have any problems, please contact the authors via email: lmbxmu@stu.xmu.edu.cn or ethan.zhangyc@gmail.com. Please avoid posting GitHub issues where possible, since we may not receive GitHub's notification emails and could miss the posted issues.
18 |
19 |
20 | ## Citation
21 | If you find HRank useful in your research, please consider citing:
22 |
23 | ```
24 | @inproceedings{lin2020hrank,
25 | title={HRank: Filter Pruning using High-Rank Feature Map},
26 | author={Lin, Mingbao and Ji, Rongrong and Wang, Yan and Zhang, Yichen and Zhang, Baochang and Tian, Yonghong and Shao, Ling},
27 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
28 | pages={1529--1538},
29 | year={2020}
30 | }
31 | ```
32 |
33 | ## Running Code
34 |
35 | In this code, you can run our models on the CIFAR-10 and ImageNet datasets. The code has been tested with Python 3.6, PyTorch 1.0 and CUDA 9.0 on Ubuntu 16.04.
36 |
37 |
38 | ### Rank Generation
39 |
40 | ```shell
41 | python rank_generation.py \
42 | --resume [pre-trained model dir] \
43 | --arch [model arch name] \
44 | --limit [batch numbers] \
45 | --gpu [gpu_id]
46 |
47 | ```
48 |
49 |
50 |
51 | ### Model Training
52 |
53 | For ease of reproducibility, we provide some of the experimental results and the corresponding pruning rate of every layer below:
54 | #### Attention! The actual pruning rates are much higher than those presented in the paper since we do not count the next-layer channel removal (for example, if 50 filters are removed in the first layer, then the corresponding 50 channels in the second-layer filters should be removed as well).
55 |
56 | ##### 1. VGG-16
57 |
58 | | Params | Flops | Accuracy |
59 | |--------------|---------------|----------|
60 | | 2.64M(82.1%) | 108.61M(65.3%)| 92.34% |
61 |
62 | ```shell
63 | python main.py \
64 | --job_dir ./result/vgg_16_bn/[folder name] \
65 | --resume [pre-trained model dir] \
66 | --arch vgg_16_bn \
67 | --compress_rate [0.95]+[0.5]*6+[0.9]*4+[0.8]*2 \
68 | --gpu [gpu_id]
69 | ```
70 | ##### 2. ResNet56
71 |
72 | | Params | Flops | Accuracy |
73 | |--------------|--------------|----------|
74 | | 0.49M(42.4%) | 62.72M(50.0%)| 93.17% |
75 |
76 | ```shell
77 | python main.py \
78 | --job_dir ./result/resnet_56/[folder name] \
79 | --resume [pre-trained model dir] \
80 | --arch resnet_56 \
81 | --compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4] \
82 | --gpu [gpu_id]
83 | ```
84 | ##### 3. ResNet110
85 | Note that in the paper we mistakenly reported the FLOPs as 148.70M(41.2%). We apologize for this and will update the arXiv version as soon as possible.
86 |
87 | | Params | Flops | Accuracy |
88 | |--------------|--------------|----------|
89 | | 1.04M(38.7%) |156.90M(37.9%)| 94.23% |
90 |
91 | ```shell
92 | python main.py \
93 | --job_dir ./result/resnet_110/[folder name] \
94 | --resume [pre-trained model dir] \
95 | --arch resnet_110 \
96 | --compress_rate [0.1]+[0.40]*36+[0.40]*36+[0.4]*36 \
97 | --gpu [gpu_id]
98 | ```
99 | ##### 4. DenseNet40
100 |
101 | | Params | Flops | Accuracy |
102 | |--------------|--------------|----------|
103 | | 0.66M(36.5%) |167.41M(40.8%)| 94.24% |
104 |
105 | ```shell
106 | python main.py \
107 | --job_dir ./result/densenet_40/[folder name] \
108 | --resume [pre-trained model dir] \
109 | --arch densenet_40 \
110 | --compress_rate [0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*5+[0.0] \
111 | --gpu [gpu_id]
112 | ```
113 | ##### 5. GoogLeNet
114 |
115 | | Params | Flops | Accuracy |
116 | |--------------|--------------|----------|
117 | | 1.86M(69.8%) | 0.45B(70.4%)| 94.07% |
118 |
119 | ```shell
120 | python main.py \
121 | --job_dir ./result/googlenet/[folder name] \
122 | --resume [pre-trained model dir] \
123 | --arch googlenet \
124 | --compress_rate [0.10]+[0.8]*5+[0.85]+[0.8]*3 \
125 | --gpu [gpu_id]
126 | ```
127 | ##### 6. ResNet50
128 |
129 | | Params | Flops| Acc Top1 |Acc Top5 |
130 | |---------|------|----------|----------|
131 | | 13.77M |1.55B | 71.98%| 91.01% |
132 |
133 | ```shell
134 | python main.py \
135 | --dataset imagenet \
136 | --data_dir [ImageNet dataset dir] \
137 | --job_dir ./result/resnet_50/[folder name] \
138 | --resume [pre-trained model dir] \
139 | --arch resnet_50 \
140 | --compress_rate [0.2]+[0.8]*10+[0.8]*13+[0.55]*19+[0.45]*10 \
141 | --gpu [gpu_id]
142 | ```
143 |
144 | After training, checkpoints and logs can be found in the `job_dir`. The pruned model will be named `[arch]_cov[i]` for stage i, and therefore the final pruned model is the one with the largest `i`.
145 |
146 | ### Get FLOPS & Params
147 | ```shell
148 | python cal_flops_params.py \
149 | --arch resnet_56 \
150 | --compress_rate [0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]
151 | ```
152 |
153 | ### Evaluate Final Performance
154 | ```shell
155 | python evaluate.py \
156 | --dataset [dataset name] \
157 | --data_dir [dataset dir] \
158 | --test_model_dir [job dir of test model] \
159 | --arch [arch name] \
160 | --gpu [gpu id]
161 | ```
162 |
163 |
164 | ## Other optional arguments
165 | ```
166 | optional arguments:
167 | --data_dir dataset directory
168 | default='./data'
169 | --dataset dataset name
170 | default: cifar10
171 | Optional: cifar10, imagenet
172 | --lr initial learning rate
173 | default: 0.01
174 | --lr_decay_step learning rate decay step
175 | default: 5,10
176 | --resume load the model from the specified checkpoint
177 | --resume_mask mask loading directory
178 | --gpu Select gpu to use
179 | default: 0
180 | --job_dir The directory where the summaries will be stored.
181 | --epochs The num of epochs to train.
182 | default: 15
183 | --train_batch_size Batch size for training.
184 | default: 128
185 | --eval_batch_size Batch size for validation.
186 | default: 100
187 | --start_cov The num of conv to start prune
188 | default: 0
189 | --compress_rate compress rate of each conv
190 | --arch The architecture to prune
191 | default: vgg_16_bn
192 | Optional: resnet_50, vgg_16_bn, resnet_56, resnet_110, densenet_40, googlenet
193 | ```
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 | ## Pre-trained Models
203 |
204 | Additionally, we provide the pre-trained models used in our experiments.
205 |
206 |
207 | ### CIFAR-10:
208 | [Vgg-16](https://drive.google.com/open?id=1i3ifLh70y1nb8d4mazNzyC4I27jQcHrE)
209 | | [ResNet56](https://drive.google.com/open?id=1f1iSGvYFjSKIvzTko4fXFCbS-8dw556T)
210 | | [ResNet110](https://drive.google.com/open?id=1uENM3S5D_IKvXB26b1BFwMzUpkOoA26m)
211 | | [DenseNet-40](https://drive.google.com/open?id=12rInJ0YpGwZd_k76jctQwrfzPubsfrZH)
212 | | [GoogLeNet](https://drive.google.com/open?id=1rYMazSyMbWwkCGCLvofNKwl58W6mmg5c)
213 |
214 | ### ImageNet:
215 | [ResNet50](https://drive.google.com/open?id=1OYpVB84BMU0y-KU7PdEPhbHwODmFvPbB)
216 |
--------------------------------------------------------------------------------
/cal_flops_params.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import argparse
4 | import get_flops
5 | from models import *
6 |
import re

# One "+"-separated term of --compress_rate: "[rate]" optionally followed by
# "*count", e.g. "[0.60]*35". Brackets are optional and rate may be an integer.
_CPRATE_TERM = re.compile(r'\[?\s*(\d+(?:\.\d*)?|\.\d+)\s*\]?\s*(?:\*\s*(\d+))?')


def parse_compress_rate(cprate_str):
    """Expand a ``--compress_rate`` string into a per-layer list of floats.

    Terms are joined by ``+``; ``[r]*n`` repeats rate ``r`` ``n`` times and a
    bare ``[r]`` counts once, so ``"[0.1]+[0.6]*3"`` -> ``[0.1, 0.6, 0.6, 0.6]``.

    Raises:
        ValueError: if a term is not of the form ``[rate]`` or ``[rate]*count``.
    """
    rates = []
    for term in cprate_str.split('+'):
        match = _CPRATE_TERM.fullmatch(term.strip())
        if match is None:
            raise ValueError('invalid compress_rate term: %r' % term)
        rate = float(match.group(1))
        count = int(match.group(2)) if match.group(2) else 1
        rates += [rate] * count
    return rates


def main():
    """Build the requested architecture on CPU and print its FLOPs and params."""
    parser = argparse.ArgumentParser(description='Calculating flops and params')

    parser.add_argument(
        '--input_image_size',
        type=int,
        default=32,
        help='The input_image_size')
    parser.add_argument(
        '--arch',
        type=str,
        default='vgg_16_bn',
        choices=('vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet','resnet_50'),
        help='The architecture to prune')
    parser.add_argument(
        '--compress_rate',
        type=str,
        default=None,
        help='compress rate of each conv')
    args = parser.parse_args()

    device = torch.device("cpu")

    if args.compress_rate:
        compress_rate = parse_compress_rate(args.compress_rate)
    else:
        # Without --compress_rate the old script hit a NameError below; fall
        # back to all-zero rates (the same placeholder evaluate.py passes).
        compress_rate = [0.] * 200
    print(compress_rate)

    if args.arch == 'vgg_16_bn':
        # The 13th conv's rate is always forced to 0 for VGG-16.
        compress_rate[12] = 0.

    print('==> Building model..')
    # args.arch is restricted by ``choices``; the constructor is looked up by
    # name in this module's namespace, which ``from models import *`` fills.
    net = globals()[args.arch](compress_rate=compress_rate)
    print(net.compress_rate)
    net.eval()

    if args.arch == 'googlenet' or args.arch == 'resnet_50':
        flops, params = get_flops.measure_model(net, device, 3, args.input_image_size, args.input_image_size, True)
    else:
        flops, params = get_flops.measure_model(net, device, 3, args.input_image_size, args.input_image_size)

    print('Params: %.2f' % (params))
    print('Flops: %.2f' % (flops))


if __name__ == '__main__':
    main()
65 |
66 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/cifar10.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10
2 | from torch.utils.data import Dataset, DataLoader
3 | import torchvision.transforms as transforms
4 |
class Data:
    """CIFAR-10 train/test data loaders.

    Reads ``args.data_dir``, ``args.train_batch_size``, ``args.eval_batch_size``
    and, when present, ``args.gpu``. Both splits are downloaded into
    ``args.data_dir`` if missing (``download=True``).

    Attributes:
        loader_train: shuffled loader with random-crop / horizontal-flip augmentation.
        loader_test: unshuffled loader with normalization only.
    """

    def __init__(self, args):
        # Pin host memory only when a GPU is selected, matching data/imagenet.py
        # (this check previously sat commented out and pinning was forced on).
        pin_memory = getattr(args, 'gpu', None) is not None

        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        trainset = CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train)
        self.loader_train = DataLoader(
            trainset, batch_size=args.train_batch_size, shuffle=True,
            num_workers=2, pin_memory=pin_memory
        )

        testset = CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test)
        self.loader_test = DataLoader(
            testset, batch_size=args.eval_batch_size, shuffle=False,
            num_workers=2, pin_memory=pin_memory)
32 |
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision.transforms as transforms
3 | import torchvision.datasets as datasets
4 | from torch.utils.data import DataLoader
5 |
6 |
class Data:
    """ImageNet train/validation loaders built on ``datasets.ImageFolder``.

    Reads images from ``<args.data_dir>/ILSVRC2012_img_train`` (training) and
    ``<args.data_dir>/val`` (validation). Also reads ``args.gpu``,
    ``args.eval_batch_size`` and, unless ``is_evaluate`` is true,
    ``args.train_batch_size``.

    Args:
        args: parsed command-line namespace.
        is_evaluate: when true, only ``loader_test`` is built.
    """

    def __init__(self, args, is_evaluate=False):
        pin_memory = args.gpu is not None

        # With the default of 224 the Resize(scale_size) steps below keep the
        # 224x224 crops at 224x224; they only matter if this value is changed.
        scale_size = 224

        traindir = os.path.join(args.data_dir, 'ILSVRC2012_img_train')
        valdir = os.path.join(args.data_dir, 'val')
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        if not is_evaluate:
            trainset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.Resize(scale_size),
                    transforms.ToTensor(),
                    normalize,
                ]))

            self.loader_train = DataLoader(
                trainset,
                batch_size=args.train_batch_size,
                shuffle=True,
                num_workers=2,
                pin_memory=pin_memory)

        testset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.Resize(scale_size),
                transforms.ToTensor(),
                normalize,
            ]))

        # Use the same pinning decision as the training loader (was hard-coded True).
        self.loader_test = DataLoader(
            testset,
            batch_size=args.eval_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=pin_memory)
54 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 |
11 | import argparse
12 |
13 | from data import imagenet
14 | from models import *
15 | from mask import *
16 | import utils
17 |
18 |
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Evaluate')
parser.add_argument(
    '--data_dir',
    type=str,
    default='./data',
    help='dataset path')
parser.add_argument(
    '--dataset',
    type=str,
    default='cifar10',
    choices=('cifar10','imagenet'),
    help='dataset')
parser.add_argument(
    '--arch',
    type=str,
    default='vgg_16_bn',
    choices=('resnet_50','vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet'),
    help='The architecture to prune')
parser.add_argument(
    '--test_model_dir',
    type=str,
    default='./result/tmp/',
    help='The job directory of the model to evaluate (contains pruned_checkpoint/).')
parser.add_argument(
    '--eval_batch_size',
    type=int,
    default=100,
    help='Batch size for validation.')
parser.add_argument(
    '--gpu',
    type=str,
    default='0',
    help='Select gpu to use')

args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
cudnn.benchmark = True
cudnn.enabled = True

# Data
print('==> Preparing data..')
if args.dataset == 'cifar10':
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2)
    print_freq = 3000 // args.eval_batch_size
elif args.dataset == 'imagenet':
    data_tmp = imagenet.Data(args, is_evaluate=True)
    testloader = data_tmp.loader_test
    print_freq = 10000 // args.eval_batch_size
else:
    raise NotImplementedError
# A batch size above the numerator gives print_freq == 0, and the
# ``batch_idx % print_freq`` progress check would raise ZeroDivisionError.
print_freq = max(1, print_freq)

# Model
print('==> Building model..')
# All-zero compress rates; the pruned checkpoint is loaded into this model
# later. NOTE(review): assumes the checkpoint's tensor shapes match this build.
net = eval(args.arch)(compress_rate=[0.]*200)
net = net.cuda()
print(net)

if len(args.gpu) > 1 and torch.cuda.is_available():
    # args.gpu is a comma-separated id list (e.g. "0,1"); after
    # CUDA_VISIBLE_DEVICES remapping the visible devices are 0..n-1.
    # Counting ids by splitting also handles multi-digit ids such as "10,11".
    device_id = list(range(len(args.gpu.split(','))))
    net = torch.nn.DataParallel(net, device_ids=device_id)
87 |
def test():
    """Evaluate ``net`` on ``testloader`` and print top-1 / top-5 precision.

    Uses the module-level ``net``, ``testloader`` and ``print_freq``. Running
    averages are printed every ``print_freq`` batches, then once at the end.
    """
    meter_top1 = utils.AverageMeter()
    meter_top5 = utils.AverageMeter()

    net.eval()
    total_batches = len(testloader)
    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.cuda()
            labels = labels.cuda()
            logits = net(images)

            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))
            batch_size = images.size(0)
            meter_top1.update(acc1[0], batch_size)
            meter_top5.update(acc5[0], batch_size)

            if step % print_freq == 0:
                progress = '({0}/{1}): Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'
                print(progress.format(step, total_batches, top1=meter_top1, top5=meter_top5))

    summary = "Final Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}"
    print(summary.format(top1=meter_top1, top5=meter_top5))
110 |
111 |
# evaluate.py does not otherwise import OrderedDict by name.
from collections import OrderedDict

# Decide from the model itself whether it is wrapped in DataParallel, so the
# config lookup and the state-dict key prefix always agree with ``net``.
is_parallel = isinstance(net, torch.nn.DataParallel)
convcfg = net.module.covcfg if is_parallel else net.covcfg

# The final-stage checkpoint is named "<arch>_cov<i>" with i = number of convs.
cov_id = len(convcfg)
checkpoint_path = os.path.join(
    args.test_model_dir, 'pruned_checkpoint', args.arch + '_cov' + str(cov_id) + '.pt')
pruned_checkpoint = torch.load(checkpoint_path, map_location='cuda:0')
tmp_ckpt = pruned_checkpoint['state_dict']

# Strip any 'module.' prefix saved by a DataParallel run, then re-add it only
# when the current net is wrapped.
new_state_dict = OrderedDict()
for k, v in tmp_ckpt.items():
    bare_key = k.replace('module.', '')
    new_state_dict[('module.' + bare_key) if is_parallel else bare_key] = v
net.load_state_dict(new_state_dict)

test()
131 |
132 |
--------------------------------------------------------------------------------
/get_flops.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | from torch.autograd import Variable
8 | from functools import reduce
9 | import operator
10 |
# Running totals: reset by measure_model(), incremented by measure_layer().
count_ops, count_params = 0, 0
13 |
14 |
def get_num_gen(gen):
    """Return how many items ``gen`` yields (the iterable is consumed)."""
    total = 0
    for _ in gen:
        total += 1
    return total
17 |
def is_pruned(layer):
    """Return True when ``layer`` exposes a ``mask`` attribute."""
    return hasattr(layer, 'mask')
24 |
def is_leaf(model):
    """Return True when ``model`` has no direct child modules."""
    for _ in model.children():
        return False
    return True
27 |
def get_layer_info(layer):
    """Return the layer's type name: the text of ``str(layer)`` before the first ``(``.

    nn.Module reprs look like ``Conv2d(3, 16, ...)``, so this yields the class
    name. If the string contains no ``(`` the whole stripped string is
    returned. Previously ``find`` returned -1 in that case and the slice
    silently dropped the last character.
    """
    return str(layer).split('(', 1)[0].strip()
33 |
def get_layer_param(model, is_conv=True):
    """Return the parameter count of ``model``, accounting for pruning.

    With ``is_conv=True`` each parameter tensor (at most two: weight, then
    optional bias) is scaled by the kept output filters,
    ``f - int(model.cp_rate * f)`` out of ``f``. If the layer also defines
    ``last_prune_num``, multi-dimensional tensors are further scaled for that
    many removed input channels. The result is a float.

    With ``is_conv=False`` the plain element count of all parameters is returned.
    """
    if not is_conv:
        return sum([reduce(operator.mul, p.size(), 1) for p in model.parameters()])

    kept = 0.
    for position, tensor in enumerate(model.parameters()):
        assert position < 2
        n_out = tensor.size()[0]
        n_out_kept = n_out - int(model.cp_rate * n_out)
        if len(tensor.size()) > 1 and hasattr(model, 'last_prune_num'):
            n_in = tensor.size()[1]
            kept += n_out_kept * (n_in - model.last_prune_num) * tensor.numel() / n_out / n_in
        else:
            kept += n_out_kept * tensor.numel() / n_out
    return kept
51 | else:
52 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])
53 |
def measure_layer(layer, x, print_name):
    """Add one layer's (pruned) FLOPs and parameter count to the module totals.

    The input batch size must be 1. Updates the globals ``count_ops`` and
    ``count_params``.

    Args:
        layer: module to measure. ``Conv2d`` layers must carry ``cp_rate``,
            the fraction of output filters pruned. Layers whose ``tmp_name``
            contains 'trans' also need ``last_prune_num``. When
            ``print_name`` is true, ``Conv2d`` layers also need ``tmp_name``.
        x: the tensor being fed to ``layer``.
        print_name: include ``layer.tmp_name`` in the per-conv summary line.

    Raises:
        TypeError: for a layer type not handled below.
    """
    global count_ops, count_params
    delta_ops = 0
    delta_params = 0
    multi_add = 1
    type_name = get_layer_info(layer)

    ### ops_conv
    if type_name in ['Conv2d']:
        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
                    layer.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
                    layer.stride[1] + 1)
        pruned_num = int(layer.cp_rate * layer.out_channels)

        if hasattr(layer, 'tmp_name') and 'trans' in layer.tmp_name:
            # Also discount ``last_prune_num`` input channels (presumably those
            # pruned from the preceding layer).
            delta_ops = (layer.in_channels - layer.last_prune_num) * (layer.out_channels - pruned_num) * \
                layer.kernel_size[0] * layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
        else:
            delta_ops = layer.in_channels * (layer.out_channels - pruned_num) * layer.kernel_size[0] * \
                layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add

        # Unpruned cost; only used in the printed summary below.
        delta_ops_ori = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
            layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add

        delta_params = get_layer_param(layer)

        if print_name:
            print(layer.tmp_name, layer.cp_rate, '| input:', x.size(), '| weight:', [layer.out_channels, layer.in_channels, layer.kernel_size[0], layer.kernel_size[1]],
                  '| params:', delta_params, '| flops:', delta_ops_ori)
        else:
            print(layer.cp_rate, [layer.out_channels, layer.in_channels, layer.kernel_size[0], layer.kernel_size[1]],
                  'params:', delta_params, ' flops:', delta_ops_ori)

    ### ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        # nn.Linear(..., bias=False) stores bias as None.
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = x.size()[0] * (weight_ops + bias_ops)
        delta_params = get_layer_param(layer, is_conv=False)

        print('linear:', layer, delta_ops, delta_params)

    elif type_name in ['DenseBasicBlock', 'ResBasicBlock', 'Inception',
                       'DenseBottleneck', 'SparseDenseBottleneck',
                       'Transition', 'SparseTransition']:
        # Composite blocks: only their first conv is measured. The recursive
        # call used to omit ``print_name`` and raised TypeError.
        measure_layer(layer.conv1, x, print_name)

    elif type_name in ['ReLU', 'BatchNorm1d', 'BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout', 'AdaptiveAvgPool2d', 'AvgPool2d', 'MaxPool2d', 'Mask', 'channel_selection', 'LambdaLayer', 'Sequential']:
        return
    ### unknown layer type
    else:
        raise TypeError('unknown layer type: %s' % type_name)

    count_ops += delta_ops
    count_params += delta_params
    return
119 |
def measure_model(model, device, C, H, W, print_name=False):
    """Count FLOPs and parameters of ``model`` for one ``(1, C, H, W)`` zero input.

    Every leaf module's ``forward`` is temporarily wrapped so that
    ``measure_layer`` sees its input during a single forward pass. The original
    ``forward`` methods are restored afterwards, even if the pass raises.

    Args:
        model: network to measure (moved to ``device`` and put in eval mode).
        device: torch device for the dummy input and the model.
        C, H, W: input channels, height and width.
        print_name: forwarded to ``measure_layer`` for its per-conv printout.

    Returns:
        ``(count_ops, count_params)`` accumulated over the pass.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    # Variable is a no-op wrapper in modern PyTorch; a plain tensor suffices.
    data = torch.zeros(1, C, H, W).to(device)
    model = model.to(device)
    model.eval()

    def should_measure(x):
        return is_leaf(x)

    def modify_forward(model, print_name):
        for child in model.children():
            if should_measure(child):
                def new_forward(m):
                    def lambda_forward(x):
                        measure_layer(m, x, print_name)
                        return m.old_forward(x)
                    return lambda_forward
                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child, print_name)

    def restore_forward(model):
        for child in model.children():
            # leaf node
            if is_leaf(child) and hasattr(child, 'old_forward'):
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    try:
        modify_forward(model, print_name)
        # Only tensor shapes matter here, so skip building the autograd graph.
        with torch.no_grad():
            model.forward(data)
    finally:
        # Never leave the caller's model with patched forward methods.
        restore_forward(model)

    return count_ops, count_params
158 |
--------------------------------------------------------------------------------
/img/framework.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/img/framework.jpeg
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/img/framework.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.optim as optim
4 | import torch.backends.cudnn as cudnn
5 |
6 | import torchvision
7 | import torchvision.transforms as transforms
8 |
9 | import os
10 | import argparse
11 |
12 | from data import imagenet
13 | from models import *
14 | from utils import progress_bar
15 | from mask import *
16 | import utils
17 |
18 |
# ---------------------------------------------------------------------------
# Command-line interface. Options are registered in table order, which is
# also the order ``--help`` lists them in.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')

_CLI_OPTIONS = (
    # Data source.
    ('--data_dir', dict(type=str, default='./data', help='dataset path')),
    ('--dataset', dict(type=str, default='cifar10', choices=('cifar10','imagenet'), help='dataset')),
    # Optimisation schedule.
    ('--lr', dict(default=0.01, type=float, help='initial learning rate')),
    ('--lr_decay_step', dict(default='5,10', type=str, help='learning rate decay step')),
    # Resuming weights / masks.
    ('--resume', dict(type=str, default=None, help='load the model from the specified checkpoint')),
    ('--resume_mask', dict(type=str, default=None, help='mask loading')),
    # Hardware and output location.
    ('--gpu', dict(type=str, default='0', help='Select gpu to use')),
    ('--job_dir', dict(type=str, default='./result/tmp/', help='The directory where the summaries will be stored.')),
    # Loop sizes.
    ('--epochs', dict(type=int, default=15, help='The num of epochs to train.')),
    ('--train_batch_size', dict(type=int, default=128, help='Batch size for training.')),
    ('--eval_batch_size', dict(type=int, default=100, help='Batch size for validation.')),
    # Pruning configuration.
    ('--start_cov', dict(type=int, default=0, help='The num of conv to start prune')),
    ('--compress_rate', dict(type=str, default=None, help='compress rate of each conv')),
    ('--arch', dict(type=str, default='vgg_16_bn',
                    choices=('resnet_50','vgg_16_bn','resnet_56','resnet_110','densenet_40','googlenet'),
                    help='The architecture to prune')),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
94 |
# Expose only the requested GPUs. This must run before CUDA is initialised;
# afterwards the selected GPUs are visible inside this process as 0..k-1.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# NOTE(review): `len(args.gpu)` is the length of the id *string* (e.g. '0,1'
# has length 3); main.py uses this same single-vs-multi GPU test elsewhere.
if len(args.gpu)==1:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
lr_decay_step = list(map(int, args.lr_decay_step.split(',')))

ckpt = utils.checkpoint(args)
print_logger = utils.get_logger(os.path.join(args.job_dir, "logger.log"))
utils.print_params(vars(args), print_logger.info)

# Data
print_logger.info('==> Preparing data..')

if args.dataset=='cifar10':
    # Per-channel normalisation shared by the train and test transforms.
    _cifar_mean = (0.4914, 0.4822, 0.4465)
    _cifar_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size, shuffle=False, num_workers=2)
elif args.dataset=='imagenet':
    data_tmp = imagenet.Data(args)
    trainloader = data_tmp.loader_train
    testloader = data_tmp.loader_test
else:
    # Unreachable through the CLI (argparse `choices`), but fail loudly here
    # instead of relying on an assert that `python -O` would strip.
    raise ValueError('Unsupported dataset: %s' % args.dataset)
137 |
import re


def _parse_compress_rate(spec):
    r"""Expand a ``--compress_rate`` spec into a flat list of floats.

    ``spec`` is a ``+``-separated list of terms. Each term must contain exactly
    one rate matching ``\d+\.\d*`` (a decimal point is required, e.g. ``0.``
    or ``0.5``) and may contain one ``*N`` repeat count, e.g.
    ``'[0.1]+[0.4]*3'`` -> ``[0.1, 0.4, 0.4, 0.4]``.

    Raises:
        ValueError: if a term has more than one ``*N`` repeat count or does
            not contain exactly one rate.
    """
    pat_cprate = re.compile(r'\d+\.\d*')
    pat_num = re.compile(r'\*\d+')
    rates = []
    for term in spec.split('+'):
        repeats = pat_num.findall(term)
        if len(repeats) > 1:
            raise ValueError('compress_rate term %r has more than one "*N" repeat' % term)
        num = int(repeats[0].replace('*', '')) if repeats else 1
        found = pat_cprate.findall(term)
        if len(found) != 1:
            raise ValueError('compress_rate term %r must contain exactly one rate such as 0.5' % term)
        rates += [float(found[0])] * num
    return rates


if args.compress_rate:
    try:
        compress_rate = _parse_compress_rate(args.compress_rate)
    except ValueError as err:
        parser.error(str(err))
else:
    # The model factory below is always called with `compress_rate`; report a
    # clear CLI error instead of failing on an undefined name.
    parser.error('--compress_rate is required')

# Model
device_ids=list(map(int, args.gpu.split(',')))
print_logger.info('==> Building model..')
# `args.arch` is restricted by argparse `choices`, so eval only resolves one
# of the model factories imported from `models`.
net = eval(args.arch)(compress_rate=compress_rate)
net = net.to(device)

if len(args.gpu)>1 and torch.cuda.is_available():
    # CUDA_VISIBLE_DEVICES renumbers the selected GPUs as 0..k-1, so use one
    # local index per requested id (not the length of the id string).
    device_id = list(range(len(device_ids)))
    net = torch.nn.DataParallel(net, device_ids=device_id)

cudnn.benchmark = True
print(net)

# Read model attributes from the wrapped module when DataParallel is used;
# the mask helper itself operates on `net` (wrapped or not).
_unwrapped_net = net.module if isinstance(net, torch.nn.DataParallel) else net
m = eval('mask_' + args.arch)(model=net, compress_rate=_unwrapped_net.compress_rate, job_dir=args.job_dir, device=device)

criterion = nn.CrossEntropyLoss()
178 |
# Training
def train(epoch, cov_id, optimizer, scheduler, pruning=True):
    """Run one training epoch of the global ``net`` over ``trainloader``.

    After every optimizer step the stored pruning masks are re-applied to the
    parameters via ``m.grad_mask(cov_id)`` (unless ``pruning`` is False).

    Args:
        epoch: epoch number, used only for logging.
        cov_id: 1-based index of the conv currently being pruned; passed to
            ``m.grad_mask`` and shown in the progress bar.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: unused here (the caller steps it); kept for interface
            compatibility.
        pruning: whether to re-apply the masks after each step.
    """
    print_logger.info('\nEpoch: %d' % epoch)
    net.train()

    train_loss = 0
    correct = 0
    total = 0
    num_batches = len(trainloader)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Tensors are moved explicitly, so no torch.cuda.device(...) context
        # is needed (that context raises ValueError for a CPU device).
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        optimizer.step()

        if pruning:
            m.grad_mask(cov_id)

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, num_batches,
                     'Cov: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (cov_id, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
209 |
def test(epoch, cov_id, optimizer, scheduler):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best top-1.

    Logs top-1/top-5 accuracy. When top-1 exceeds the global ``best_acc``, the
    weights, best top-1, epoch, optimizer and scheduler state are saved to
    ``<job_dir>/pruned_checkpoint/<arch>_cov<cov_id>.pt`` (the path the main
    pruning loop resumes from) and ``best_acc`` is updated.
    """
    global best_acc

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    net.eval()
    num_iterations = len(testloader)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # Only accuracy is reported, so the loss is not computed here.
            prec1, prec5 = utils.accuracy(outputs, targets, topk=(1, 5))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

        print_logger.info(
            'Epoch[{0}]({1}/{2}): '
            'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                epoch, batch_idx, num_iterations, top1=top1, top5=top5))

    if top1.avg > best_acc:
        print_logger.info('Saving to '+args.arch+'_cov'+str(cov_id)+'.pt')
        state = {
            'state_dict': net.state_dict(),
            'best_prec1': top1.avg,
            'epoch': epoch,
            'scheduler':scheduler.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        ckpt_dir = args.job_dir + '/pruned_checkpoint'
        # makedirs(exist_ok=True) avoids the check-then-create race and also
        # creates missing parent directories.
        os.makedirs(ckpt_dir, exist_ok=True)
        best_acc = top1.avg
        torch.save(state, ckpt_dir + '/' + args.arch + '_cov' + str(cov_id) + '.pt')

    print_logger.info("=>Best accuracy {:.3f}".format(best_acc))
247 |
248 |
from collections import OrderedDict

# `net` may be wrapped in torch.nn.DataParallel; model attributes live on the
# wrapped module.
_net_is_parallel = isinstance(net, torch.nn.DataParallel)
_base_net = net.module if _net_is_parallel else net
convcfg = _base_net.covcfg

# Number of `net.parameters()` entries belonging to each prunable conv,
# forwarded to the mask helper's `layer_mask`.
param_per_cov_dic={
    'vgg_16_bn': 4,
    'densenet_40': 3,
    'googlenet': 28,
    'resnet_50':3,
    'resnet_56':3,
    'resnet_110':3
}

print_logger.info('compress rate: ' + str(_base_net.compress_rate))


def _pruned_checkpoint_path(conv_idx):
    """Path where `test` saves the best checkpoint for conv ``conv_idx``."""
    return args.job_dir + "/pruned_checkpoint/" + args.arch + "_cov" + str(conv_idx) + '.pt'


for cov_id in range(args.start_cov, len(convcfg)):
    # Load pruned_checkpoint
    print_logger.info("cov-id: %d ====> Resuming from pruned_checkpoint..." % (cov_id))

    # Build/apply the mask for conv `cov_id + 1` first; the weights loaded
    # below are re-masked by `m.grad_mask` after each training step.
    m.layer_mask(cov_id + 1, resume=args.resume_mask, param_per_cov=param_per_cov_dic[args.arch], arch=args.arch)

    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_decay_step, gamma=0.1)

    if cov_id == 0:
        # Start from the unpruned model given by --resume.
        if not args.resume:
            parser.error('--resume is required when pruning starts at conv 0')
        # map_location=device: after CUDA_VISIBLE_DEVICES the process sees the
        # selected GPUs renumbered from 0 (or none at all), so never map to a
        # physical GPU id here.
        pruned_checkpoint = torch.load(args.resume, map_location=device)
        # The resnet_50 checkpoint is a bare state dict; the others wrap it.
        tmp_ckpt = pruned_checkpoint if args.arch == 'resnet_50' else pruned_checkpoint['state_dict']

        # Normalise the DataParallel 'module.' key prefix to match `net`.
        key_prefix = 'module.' if _net_is_parallel else ''
        new_state_dict = OrderedDict(
            (key_prefix + k.replace('module.', ''), v) for k, v in tmp_ckpt.items())
        net.load_state_dict(new_state_dict)
    else:
        if args.arch == 'resnet_50':
            # ResNet-50 only retrains at these conv ids, resuming from the
            # previous id in the list; other ids only get their mask updated.
            skip_list=[1,5,8,11,15,18,21,24,28,31,34,37,40,43,47,50,53]
            if cov_id+1 not in skip_list:
                continue
            prev_cov = skip_list[skip_list.index(cov_id+1)-1]
            pruned_checkpoint = torch.load(_pruned_checkpoint_path(prev_cov), map_location=device)
        else:
            pruned_checkpoint = torch.load(_pruned_checkpoint_path(cov_id), map_location=device)
        net.load_state_dict(pruned_checkpoint['state_dict'])

    best_acc=0.
    for epoch in range(0, args.epochs):
        train(epoch, cov_id + 1, optimizer, scheduler)
        scheduler.step()
        test(epoch, cov_id + 1, optimizer, scheduler)
316 |
--------------------------------------------------------------------------------
/mask.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 |
5 | import pickle
6 |
7 |
class mask_vgg_16_bn:
    """Rank-based filter masks for ``vgg_16_bn``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor. Conv ``k`` (1-based) owns the ``param_per_cov`` consecutive
    positions starting at ``(k - 1) * param_per_cov``. The first is its 4-D
    weight, masked per output filter with an ``(f, 1, 1, 1)`` tensor. The
    following positions are masked with the same mask squeezed to ``(f,)``.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for conv k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.mask = {}
        self.job_dir=job_dir
        self.device = device

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for conv ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=4, arch="vgg_16_bn"):
        """Build and apply the mask for conv ``cov_id`` (1-based), then pickle all masks.

        Args:
            cov_id: 1-based conv index. Ranks are read from
                ``rank_conv/<arch>/rank_conv<cov_id>.npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per conv.
            arch: sub-directory of ``rank_conv/`` holding the rank files.
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        first = (cov_id - 1) * param_per_cov  # position of this conv's weight
        end = cov_id * param_per_cov  # position of the next conv's weight

        for index, item in enumerate(params):
            if index == end:
                break
            if index == first:
                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                rank = np.load(prefix + str(cov_id) + subfix)
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]
            elif first < index < end:
                # Per-channel parameters that follow this conv's weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index]

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to all parameters of convs ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = cov_id * self.param_per_cov
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            item.data = item.data * self.mask[index].to(self.device)  # prune certain weight
58 |
59 |
class mask_resnet_56:
    """Rank-based filter masks for ``resnet_56``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor. Conv ``k`` (1-based) owns the ``param_per_cov`` consecutive
    positions starting at ``(k - 1) * param_per_cov``. The first is its 4-D
    weight, masked per output filter with an ``(f, 1, 1, 1)`` tensor. The
    following positions are masked with the same mask squeezed to ``(f,)``.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for conv k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.mask = {}
        self.job_dir=job_dir
        self.device = device

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for conv ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_56"):
        """Build and apply the mask for conv ``cov_id`` (1-based), then pickle all masks.

        Args:
            cov_id: 1-based conv index. Ranks are read from
                ``rank_conv/<arch>/rank_conv<cov_id>.npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per conv.
            arch: sub-directory of ``rank_conv/`` holding the rank files.
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        first = (cov_id - 1) * param_per_cov  # position of this conv's weight
        end = cov_id * param_per_cov  # position of the next conv's weight

        for index, item in enumerate(params):
            if index == end:
                break
            if index == first:
                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                rank = np.load(prefix + str(cov_id) + subfix)
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]
            elif first < index < end:
                # Per-channel parameters that follow this conv's weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index].to(self.device)

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to all parameters of convs ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = cov_id * self.param_per_cov
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            item.data = item.data * self.mask[index].to(self.device)  # prune certain weight
111 |
112 |
class mask_densenet_40:
    """Rank-based filter masks for ``densenet_40``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor. Conv ``k`` (1-based) owns the ``param_per_cov`` consecutive
    positions starting at ``(k - 1) * param_per_cov``: its 4-D weight (mask
    ``(f, 1, 1, 1)``) followed by the BN parameters after it.

    In a dense layer the output is concatenated after the input, so the BN that
    follows conv ``k`` (for k >= 2, except convs 14 and 27) sees the previous
    BN's channels plus conv ``k``'s new filters. Its mask is therefore the
    previous BN mask concatenated with conv ``k``'s filter mask. Conv 1 and
    convs 14/27 (which, for DenseNet-40, appear to be the stem and the two
    transitions — confirm if the layout changes) start a fresh mask.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for conv k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.job_dir=job_dir
        self.device=device
        self.mask = {}

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for conv ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="densenet_40"):
        """Build and apply the mask for conv ``cov_id`` (1-based), then pickle all masks.

        Args:
            cov_id: 1-based conv index. Ranks are read from
                ``rank_conv/<arch>/rank_conv<cov_id>.npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per conv.
            arch: sub-directory of ``rank_conv/`` holding the rank files.
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        first = (cov_id - 1) * param_per_cov  # position of this conv's weight
        end = cov_id * param_per_cov  # position of the next conv's weight

        for index, item in enumerate(params):
            if index == end:
                break
            if index == first:
                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                rank = np.load(prefix + str(cov_id) + subfix)
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]
            elif first < index < end:
                # prune BN's parameter
                if cov_id >= 2 and cov_id != 14 and cov_id != 27:
                    # Previous BN channels + this conv's new filters.
                    self.mask[index] = torch.cat([self.mask[index-param_per_cov], torch.squeeze(zeros)], 0).to(self.device)
                else:
                    self.mask[index] = torch.squeeze(zeros).to(self.device)
                item.data = item.data * self.mask[index]

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to all parameters of convs ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = cov_id * self.param_per_cov
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            item.data = item.data * self.mask[index].to(self.device)
168 |
169 |
class mask_googlenet:
    """Rank-based filter masks for ``googlenet``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor: ``(f, 1, 1, 1)`` for 4-D conv weights, squeezed ``(f,)`` for
    the parameters that follow them. ``cov_id`` 1 is the stem conv; each
    later ``cov_id`` is one Inception module of ``param_per_cov`` parameters.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for unit k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.mask = {}
        self.job_dir=job_dir
        self.device = device

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for unit ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=28, arch="googlenet"):
        """Build and apply the masks for unit ``cov_id`` (1-based), then pickle all masks.

        With ``base = (cov_id - 1) * param_per_cov``:

        * ``cov_id == 1``: the stem weight is position 0 (ranks
          ``rank_conv1.npy``), and positions 1..3 follow it.
        * ``cov_id >= 2``: the Inception module spans ``base - 24 .. base + 3``.
          Masked convs sit at ``base - 24`` (``n1x1`` ranks), ``base - 16``
          (``n3x3``), ``base - 8`` and ``base - 4`` (both ``n5x5``), and
          ``base`` (``pool_planes``).
        * The reduction convs' positions ``base-20 .. base-17`` and
          ``base-12 .. base-9`` are left unmasked.

        Args:
            cov_id: 1-based unit index; rank files are
                ``rank_conv/<arch>/rank_conv<cov_id>[_<branch>].npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per module.
            arch: sub-directory of ``rank_conv/`` holding the rank files.
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        base = (cov_id - 1) * param_per_cov

        for index, item in enumerate(params):

            if index == base + 4:
                break
            if (cov_id==1 and index==0) \
                    or index == base - 24 \
                    or index == base - 16 \
                    or index == base - 8 \
                    or index == base - 4 \
                    or index == base:

                if index == base - 24:
                    rank = np.load(prefix + str(cov_id)+'_'+'n1x1' + subfix)
                elif index == base - 16:
                    rank = np.load(prefix + str(cov_id)+'_'+'n3x3' + subfix)
                elif index == base - 8 or index == base - 4:
                    rank = np.load(prefix + str(cov_id)+'_'+'n5x5' + subfix)
                elif cov_id==1 and index==0:
                    rank = np.load(prefix + str(cov_id) + subfix)
                else:
                    rank = np.load(prefix + str(cov_id) + '_' + 'pool_planes' + subfix)

                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]

            elif cov_id==1 and 0 < index <= 3:
                # Parameters following the stem conv weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index]

            elif (base - 20 <= index < base - 16) or (base - 12 <= index < base - 8):
                # Reduction (1x1) convs and their BNs are not pruned.
                continue

            elif base - 24 < index < base + 4:
                # Parameters following a masked conv weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index]

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to masked parameters of units ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Positions with
        no stored mask (the reduction convs) are skipped. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = (cov_id - 1) * self.param_per_cov + 4
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            if index in self.mask:
                item.data = item.data * self.mask[index].to(self.device)  # prune certain weight
247 |
248 |
class mask_resnet_110:
    """Rank-based filter masks for ``resnet_110``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor. Conv ``k`` (1-based) owns the ``param_per_cov`` consecutive
    positions starting at ``(k - 1) * param_per_cov``. The first is its 4-D
    weight, masked per output filter with an ``(f, 1, 1, 1)`` tensor. The
    following positions are masked with the same mask squeezed to ``(f,)``.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for conv k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.mask = {}
        self.job_dir=job_dir
        self.device = device

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for conv ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_110_convwise"):
        """Build and apply the mask for conv ``cov_id`` (1-based), then pickle all masks.

        Args:
            cov_id: 1-based conv index. Ranks are read from
                ``rank_conv/<arch>/rank_conv<cov_id>.npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per conv.
            arch: sub-directory of ``rank_conv/`` holding the rank files
                (main.py passes the ``--arch`` value).
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        first = (cov_id - 1) * param_per_cov  # position of this conv's weight
        end = cov_id * param_per_cov  # position of the next conv's weight

        for index, item in enumerate(params):
            if index == end:
                break
            if index == first:
                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                rank = np.load(prefix + str(cov_id) + subfix)
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]
            elif first < index < end:
                # Per-channel parameters that follow this conv's weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index]

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to all parameters of convs ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = cov_id * self.param_per_cov
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            item.data = item.data * self.mask[index].to(self.device)  # prune certain weight
301 |
302 |
class mask_resnet_50:
    """Rank-based filter masks for ``resnet_50``.

    ``self.mask`` maps a parameter's position in ``model.parameters()`` to a
    0/1 tensor. Conv ``k`` (1-based) owns the ``param_per_cov`` consecutive
    positions starting at ``(k - 1) * param_per_cov``. The first is its 4-D
    weight, masked per output filter with an ``(f, 1, 1, 1)`` tensor. The
    following positions are masked with the same mask squeezed to ``(f,)``.
    """

    def __init__(self, model=None, compress_rate=[0.50], job_dir='',device=None):
        # model: network whose parameters are masked in place.
        # compress_rate: fraction of filters to prune; entry k-1 is for conv k
        #   (only read, never mutated, so the list default is safe).
        # job_dir: where the mask dict is pickled when no resume path is given.
        # device: device the masks are created on.
        self.model = model
        self.compress_rate = compress_rate
        self.mask = {}
        self.job_dir=job_dir
        self.device = device

    def _filter_mask(self, rank, num_filters, cov_id):
        """Return a ``(num_filters, 1, 1, 1)`` 0/1 mask for conv ``cov_id``.

        The ``int(compress_rate[cov_id - 1] * num_filters)`` filters with the
        smallest ``rank`` values get 0; all others get 1.
        """
        pruned_num = int(self.compress_rate[cov_id - 1] * num_filters)
        keep = np.argsort(rank)[pruned_num:]  # preserved filter ids
        mask = torch.zeros(num_filters, 1, 1, 1).to(self.device)
        # One indexed assignment instead of one device write per filter.
        mask[torch.as_tensor(keep, dtype=torch.long, device=mask.device), 0, 0, 0] = 1.
        return mask

    def layer_mask(self, cov_id, resume=None, param_per_cov=3, arch="resnet_50_convwise"):
        """Build and apply the mask for conv ``cov_id`` (1-based), then pickle all masks.

        Args:
            cov_id: 1-based conv index. Ranks are read from
                ``rank_conv/<arch>/rank_conv<cov_id>.npy``.
            resume: optional path of a pickled mask dict that replaces
                ``self.mask``. The updated dict is written back to this path,
                or to ``<job_dir>/mask`` when no path is given.
            param_per_cov: number of ``model.parameters()`` entries per conv.
            arch: sub-directory of ``rank_conv/`` holding the rank files
                (main.py passes the ``--arch`` value).
        """
        params = self.model.parameters()
        prefix = "rank_conv/"+arch+"/rank_conv"
        subfix = ".npy"

        if resume:
            # SECURITY: pickle.load can execute arbitrary code; only resume
            # from mask files you produced yourself.
            with open(resume, 'rb') as fh:
                self.mask = pickle.load(fh)
        else:
            resume=self.job_dir+'/mask'

        self.param_per_cov=param_per_cov
        first = (cov_id - 1) * param_per_cov  # position of this conv's weight
        end = cov_id * param_per_cov  # position of the next conv's weight

        for index, item in enumerate(params):
            if index == end:
                break
            if index == first:
                f, _, _, _ = item.size()  # 4-D weight; dim 0 = filters
                rank = np.load(prefix + str(cov_id) + subfix)
                zeros = self._filter_mask(rank, f, cov_id)
                self.mask[index] = zeros  # convolutional weight
                item.data = item.data * self.mask[index]
            elif first < index < end:
                # Per-channel parameters that follow this conv's weight.
                self.mask[index] = torch.squeeze(zeros)
                item.data = item.data * self.mask[index]

        with open(resume, "wb") as fh:
            pickle.dump(self.mask, fh)

    def grad_mask(self, cov_id):
        """Re-apply stored masks to all parameters of convs ``1..cov_id``.

        Despite its name this rescales parameter values (``item.data``), not
        gradients; main.py calls it after each optimizer step. Requires a prior
        ``layer_mask`` call, which sets ``param_per_cov``.
        """
        end = cov_id * self.param_per_cov
        for index, item in enumerate(self.model.parameters()):
            if index == end:
                break
            item.data = item.data * self.mask[index].to(self.device)  # prune certain weight
353 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .densenet_cifar import *
3 | from .googlenet_cifar import *
4 | from .resnet_imagenet import *
5 | from .resnet_cifar import *
6 |
--------------------------------------------------------------------------------
/models/densenet_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 | import numpy as np
7 |
8 |
# NOTE(review): not referenced in this module; presumably initialisation
# statistics used elsewhere in the project.
norm_mean, norm_var = 0.0, 1.0

# One entry per prunable conv of DenseNet-40 (12 * 3 + 2 + 1 = 39); main.py
# iterates over len(covcfg). The meaning of the values 3*i+1 is not visible
# here — TODO confirm.
cov_cfg = [3 * i + 1 for i in range(12 * 3 + 2 + 1)]


class DenseBasicBlock(nn.Module):
    """Dense layer: BN -> ReLU -> 3x3 conv, output concatenated after the input.

    ``inplanes`` sizes the BN and ``filters`` the conv's input channels; the
    conv emits ``growthRate`` channels and carries ``cp_rate`` / ``tmp_name``
    metadata. ``index`` and ``expansion`` are accepted but unused.
    """

    def __init__(self, inplanes, filters, index, expansion=1, growthRate=12, dropRate=0, compress_rate=0., tmp_name=None):
        super(DenseBasicBlock, self).__init__()
        # Registration order (bn1 before conv1) fixes the parameters() order
        # that the pruning masks index into.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        conv = nn.Conv2d(filters, growthRate, kernel_size=3, padding=1, bias=False)
        conv.cp_rate = compress_rate
        conv.tmp_name = tmp_name
        self.conv1 = conv
        self.dropRate = dropRate

    def forward(self, x):
        new_features = self.conv1(self.relu(self.bn1(x)))
        if self.dropRate > 0:
            new_features = F.dropout(new_features, p=self.dropRate, training=self.training)
        # Dense connectivity: keep the input and append the new channels.
        return torch.cat((x, new_features), 1)
37 |
38 |
class Transition(nn.Module):
    """DenseNet transition: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    ``inplanes`` sizes the BN and ``filters`` the conv's input channels.
    ``compress_rate``, ``tmp_name`` and ``last_prune_num`` are stored on the
    conv as ``cp_rate``, ``tmp_name`` and ``last_prune_num``. ``index`` is
    accepted but unused.
    """

    def __init__(self, inplanes, outplanes, filters, index, compress_rate, tmp_name, last_prune_num):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(filters, outplanes, kernel_size=1, bias=False)
        conv = self.conv1
        conv.cp_rate, conv.tmp_name, conv.last_prune_num = compress_rate, tmp_name, last_prune_num

    def forward(self, x):
        reduced = self.conv1(self.relu(self.bn1(x)))
        return F.avg_pool2d(reduced, 2)
56 |
57 |
class DenseNet(nn.Module):
    """CIFAR DenseNet whose convs carry pruning metadata.

    Every conv gets ``cp_rate`` (its entry of ``compress_rate``) and
    ``tmp_name`` attributes; transition convs also get ``last_prune_num``.
    ``compress_rate`` needs ``3 * n + 3`` entries, where ``n`` is the number of
    layers per dense block (39 entries for depth 40). Index 0 is the stem conv,
    followed in order by dense block 1, transition 1, dense block 2,
    transition 2 and dense block 3.
    """

    def __init__(self, depth=40, block=DenseBasicBlock,
                 dropRate=0, num_classes=10, growthRate=12, compressionRate=1, filters=None, indexes=None,compress_rate=None):
        """
        Args:
            depth: network depth; must satisfy ``(depth - 4) % 3 == 0``.
            block: dense-layer class.
            dropRate: dropout probability inside dense layers.
            num_classes: output size of the final linear layer.
            growthRate: channels added by each dense layer.
            compressionRate: channel reduction factor of the transitions.
            filters: flat list of per-conv input-channel counts; derived from
                ``growthRate`` when None.
            indexes: ignored; always rebuilt from ``filters``.
            compress_rate: per-conv prune fractions; defaults to all zeros
                (no pruning) when None.
        """
        super(DenseNet, self).__init__()

        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        n = (depth - 4) // 3 if 'DenseBasicBlock' in str(block) else (depth - 4) // 6

        transition = Transition
        if filters is None:
            filters = []
            start = growthRate*2
            for _stage in range(3):
                filters.append([start + growthRate*k for k in range(n+1)])
                start = (start + growthRate*n) // compressionRate
            filters = [item for sub_list in filters for item in sub_list]

        # One index range per entry of `filters` (the passed `indexes` is ignored).
        indexes = [np.arange(f) for f in filters]

        if compress_rate is None:
            # No pruning: stem + 3 dense blocks of n layers + 2 transitions.
            compress_rate = [0.] * (3 * n + 3)

        self.covcfg=cov_cfg
        self.compress_rate=compress_rate

        self.growthRate = growthRate
        self.dropRate = dropRate

        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
                               bias=False)
        self.conv1.cp_rate=compress_rate[0]
        self.conv1.tmp_name = 'conv1'
        self.last_prune_num=self.inplanes*compress_rate[0]

        self.dense1 = self._make_denseblock(block, n, filters[0:n], indexes[0:n], compress_rate[1:n+1],'dense1', self.last_prune_num)
        self.trans1 = self._make_transition(transition, compressionRate, filters[n], indexes[n], compress_rate[n+1],'trans1', self.last_prune_num)
        self.dense2 = self._make_denseblock(block, n, filters[n+1:2*n+1], indexes[n+1:2*n+1], compress_rate[n+2:2*n+2],'dense2', self.last_prune_num)
        self.trans2 = self._make_transition(transition, compressionRate, filters[2*n+1], indexes[2*n+1], compress_rate[2*n+2],'trans2', self.last_prune_num)
        self.dense3 = self._make_denseblock(block, n, filters[2*n+2:3*n+2], indexes[2*n+2:3*n+2], compress_rate[2*n+3:3*n+3],'dense3', self.last_prune_num)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        self.fc = nn.Linear(self.inplanes, num_classes)

        # Weight initialization (He-normal with fan_out for convs).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks, filters, indexes, compress_rate, tmp_name, last_prune_num):
        """Stack ``blocks`` dense layers, growing ``self.inplanes`` by ``growthRate`` each.

        Also accumulates ``int(rate * growthRate)`` into ``self.last_prune_num``
        per layer; the ``last_prune_num`` argument itself is unused.
        """
        layers = []
        assert blocks == len(filters), 'Length of the filters parameter is not right.'
        assert blocks == len(indexes), 'Length of the indexes parameter is not right.'
        for i in range(blocks):
            self.last_prune_num+=int(compress_rate[i]*self.growthRate)
            layers.append(block(self.inplanes, filters=filters[i], index=indexes[i],
                                growthRate=self.growthRate, dropRate=self.dropRate, compress_rate=compress_rate[i], tmp_name=tmp_name+'_'+str(i)))
            self.inplanes += self.growthRate

        return nn.Sequential(*layers)

    def _make_transition(self, transition, compressionRate, filters, index, compress_rate, tmp_name, last_prune_num):
        """Build a transition and reset ``self.last_prune_num`` for the next block.

        The transition conv receives the ``last_prune_num`` value passed in
        (i.e. the count accumulated before this reset).
        """
        inplanes = self.inplanes
        outplanes = int(math.floor(self.inplanes // compressionRate))
        self.inplanes = outplanes
        self.last_prune_num=int(compress_rate*filters)
        return transition(inplanes, outplanes, filters, index, compress_rate, tmp_name, last_prune_num)

    def forward(self, x):
        x = self.conv1(x)

        x = self.dense1(x)
        x = self.trans1(x)
        x = self.dense2(x)
        x = self.trans2(x)
        x = self.dense3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x
149 |
150 |
def densenet_40(compress_rate=None):
    """Build the 40-layer CIFAR DenseNet (DenseNet's default growth rate, no channel compression)."""
    model = DenseNet(
        depth=40,
        block=DenseBasicBlock,
        compressionRate=1,
        compress_rate=compress_rate,
    )
    return model
153 |
--------------------------------------------------------------------------------
/models/googlenet_cifar.py:
--------------------------------------------------------------------------------
1 | '''GoogLeNet with PyTorch.'''
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
# NOTE(review): not referenced in this module; presumably initialisation
# statistics used elsewhere in the project.
norm_mean, norm_var = 0.0, 1.0

# One entry per prunable unit (stem conv + 9 Inception modules); main.py
# iterates over len(covcfg). The meaning of the values 22*i+2 is not visible
# here — TODO confirm.
cov_cfg=[(22*i+2) for i in range(1+2+5+2)]


class Inception(nn.Module):
    """Inception module; outputs of the present branches are concatenated on dim 1.

    Branches, created only when their width is non-zero/positive:

    * 1x1 conv -> ``n1x1`` channels
    * 1x1 reduce (``n3x3red``) -> 3x3 conv -> ``n3x3`` channels
    * 1x1 reduce (``n5x5red``) -> 3x3 conv -> 3x3 conv -> ``n5x5`` channels
      (two stacked 3x3 convs stand in for a 5x5 conv)
    * 3x3 max-pool -> 1x1 conv -> ``pool_planes`` channels

    Every conv gets ``cp_rate``/``tmp_name`` attributes; the 1x1 reduction
    convs always get ``cp_rate = 0.``. A missing branch is stored as ``None``
    and skipped in ``forward``.
    """

    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes, cp_rate, tmp_name):
        super(Inception, self).__init__()
        self.cp_rate=cp_rate
        self.tmp_name=tmp_name

        self.n1x1 = n1x1
        self.n3x3 = n3x3
        self.n5x5 = n5x5
        self.pool_planes = pool_planes

        # 1x1 conv branch
        if self.n1x1:
            conv1x1 = nn.Conv2d(in_planes, n1x1, kernel_size=1)
            conv1x1.cp_rate = self.cp_rate
            conv1x1.tmp_name = self.tmp_name

            self.branch1x1 = nn.Sequential(
                conv1x1,
                nn.BatchNorm2d(n1x1),
                nn.ReLU(True),
            )
        else:
            self.branch1x1 = None

        # 1x1 conv -> 3x3 conv branch
        if self.n3x3:
            conv3x3_1=nn.Conv2d(in_planes, n3x3red, kernel_size=1)
            conv3x3_2=nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1)
            conv3x3_1.cp_rate = 0.
            conv3x3_1.tmp_name = self.tmp_name
            conv3x3_2.cp_rate = self.cp_rate
            conv3x3_2.tmp_name = self.tmp_name

            self.branch3x3 = nn.Sequential(
                conv3x3_1,
                nn.BatchNorm2d(n3x3red),
                nn.ReLU(True),
                conv3x3_2,
                nn.BatchNorm2d(n3x3),
                nn.ReLU(True),
            )
        else:
            self.branch3x3 = None

        # 1x1 conv -> "5x5" branch, built from two 3x3 convs
        if self.n5x5 > 0:
            conv5x5_1 = nn.Conv2d(in_planes, n5x5red, kernel_size=1)
            conv5x5_2 = nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1)
            conv5x5_3 = nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1)
            conv5x5_1.cp_rate = 0.
            conv5x5_1.tmp_name = self.tmp_name
            conv5x5_2.cp_rate = self.cp_rate
            conv5x5_2.tmp_name = self.tmp_name
            conv5x5_3.cp_rate = self.cp_rate
            conv5x5_3.tmp_name = self.tmp_name

            self.branch5x5 = nn.Sequential(
                conv5x5_1,
                nn.BatchNorm2d(n5x5red),
                nn.ReLU(True),
                conv5x5_2,
                nn.BatchNorm2d(n5x5),
                nn.ReLU(True),
                conv5x5_3,
                nn.BatchNorm2d(n5x5),
                nn.ReLU(True),
            )
        else:
            self.branch5x5 = None

        # 3x3 pool -> 1x1 conv branch
        if self.pool_planes > 0:
            conv_pool = nn.Conv2d(in_planes, pool_planes, kernel_size=1)
            conv_pool.cp_rate = self.cp_rate
            conv_pool.tmp_name = self.tmp_name

            self.branch_pool = nn.Sequential(
                nn.MaxPool2d(3, stride=1, padding=1),
                conv_pool,
                nn.BatchNorm2d(pool_planes),
                nn.ReLU(True),
            )
        else:
            self.branch_pool = None

    def forward(self, x):
        # Concatenate in fixed branch order; skipped branches contribute nothing.
        branches = (self.branch1x1, self.branch3x3, self.branch5x5, self.branch_pool)
        out = [branch(x) for branch in branches if branch is not None]
        return torch.cat(out, 1)
104 |
105 |
class GoogLeNet(nn.Module):
    """CIFAR GoogLeNet: a 3x3 stem conv followed by 9 Inception modules.

    ``compress_rate`` has one entry per prunable unit in parameter order:
    index 0 is the stem conv, and indices 1..9 are the Inception modules
    a3, b3, a4, b4, c4, d4, e4, a5, b5. This matches the
    ``compress_rate[cov_id - 1]`` indexing used by ``mask_googlenet``.
    """

    def __init__(self, block=Inception, filters=None, compress_rate=None, num_classes=10):
        """
        Args:
            block: Inception module class.
            filters: per-module ``[n1x1, n3x3, n5x5, pool_planes]``; the
                standard configuration is used when None.
            compress_rate: per-unit prune fractions; defaults to all zeros
                (no pruning) when None.
            num_classes: output size of the final linear layer.
        """
        super(GoogLeNet, self).__init__()

        if compress_rate is None:
            # No pruning: one zero rate per entry of cov_cfg.
            compress_rate = [0.] * len(cov_cfg)

        self.covcfg=cov_cfg
        self.compress_rate=compress_rate

        conv_pre = nn.Conv2d(3, 192, kernel_size=3, padding=1)
        conv_pre.cp_rate=compress_rate[0]
        conv_pre.tmp_name='pre_layer'
        self.pre_layers = nn.Sequential(
            conv_pre,
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        if filters is None:
            filters = [
                [64, 128, 32, 32],
                [128, 192, 96, 64],
                [192, 208, 48, 64],
                [160, 224, 64, 64],
                [128, 256, 64, 64],
                [112, 288, 64, 64],
                [256, 320, 128, 128],
                [256, 320, 128, 128],
                [384, 384, 128, 128]
            ]

        self.filters=filters

        # NOTE(review): inception_b3 is tagged 'a4' (the same tag as
        # inception_a4); this looks like a typo for 'b3', but the consumers of
        # tmp_name are not visible here, so it is left unchanged.
        self.inception_a3 = block(192, filters[0][0], 96, filters[0][1], 16, filters[0][2], filters[0][3], self.compress_rate[1], 'a3')
        self.inception_b3 = block(sum(filters[0]), filters[1][0], 128, filters[1][1], 32, filters[1][2], filters[1][3], self.compress_rate[2], 'a4')

        self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)

        self.inception_a4 = block(sum(filters[1]), filters[2][0], 96, filters[2][1], 16, filters[2][2], filters[2][3], self.compress_rate[3], 'a4')
        self.inception_b4 = block(sum(filters[2]), filters[3][0], 112, filters[3][1], 24, filters[3][2], filters[3][3], self.compress_rate[4], 'b4')
        # c4 is the 6th prunable unit -> compress_rate[5] (index 6 belongs to d4).
        self.inception_c4 = block(sum(filters[3]), filters[4][0], 128, filters[4][1], 24, filters[4][2], filters[4][3], self.compress_rate[5], 'c4')
        self.inception_d4 = block(sum(filters[4]), filters[5][0], 144, filters[5][1], 32, filters[5][2], filters[5][3], self.compress_rate[6], 'd4')
        self.inception_e4 = block(sum(filters[5]), filters[6][0], 160, filters[6][1], 32, filters[6][2], filters[6][3], self.compress_rate[7], 'e4')

        self.inception_a5 = block(sum(filters[6]), filters[7][0], 160, filters[7][1], 32, filters[7][2], filters[7][3], self.compress_rate[8], 'a5')
        self.inception_b5 = block(sum(filters[7]), filters[8][0], 192, filters[8][1], 48, filters[8][2], filters[8][3], self.compress_rate[9], 'b5')

        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.linear = nn.Linear(sum(filters[-1]), num_classes)

    def forward(self, x):
        # Shape comments assume the default filters and 32x32 inputs.
        out = self.pre_layers(x)      # 192 x 32 x 32
        out = self.inception_a3(out)  # 256 x 32 x 32
        out = self.inception_b3(out)  # 480 x 32 x 32
        out = self.maxpool1(out)      # 480 x 16 x 16

        out = self.inception_a4(out)  # 512 x 16 x 16
        out = self.inception_b4(out)  # 512 x 16 x 16
        out = self.inception_c4(out)  # 512 x 16 x 16
        out = self.inception_d4(out)  # 528 x 16 x 16
        out = self.inception_e4(out)  # 832 x 16 x 16
        out = self.maxpool2(out)      # 832 x 8 x 8

        out = self.inception_a5(out)  # 832 x 8 x 8
        out = self.inception_b5(out)  # 1024 x 8 x 8

        out = self.avgpool(out)       # 1024 x 1 x 1
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out
194 |
195 |
def googlenet(compress_rate=None):
    """Build the CIFAR-10 GoogLeNet with the default Inception block and widths.

    ``compress_rate`` is forwarded unchanged to ``GoogLeNet``.
    """
    model = GoogLeNet(block=Inception, compress_rate=compress_rate)
    return model
198 |
--------------------------------------------------------------------------------
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
# NOTE(review): neither name is referenced inside this module; presumably read
# by other code in the project — confirm before changing.
norm_mean = 0.0
norm_var = 1.0
6 |
7 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
11 |
12 |
class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Args:
        lambd: callable applied to the input in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
20 |
21 |
class ResBasicBlock(nn.Module):
    """Two-conv 3x3 residual block for CIFAR ResNets, with per-conv pruning rates.

    Args:
        inplanes: number of input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution; the shortcut subsamples by
            the same factor.
        compress_rate: sequence whose first two entries become
            ``conv1.cp_rate`` and ``conv2.cp_rate``; ``None`` means ``[0., 0.]``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, compress_rate=None):
        super(ResBasicBlock, self).__init__()
        if compress_rate is None:
            # The former default ``[0.]`` had one entry while two are read below.
            compress_rate = [0., 0.]
        self.inplanes = inplanes
        self.planes = planes
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.conv1.cp_rate = compress_rate[0]
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes, planes)
        self.conv2.cp_rate = compress_rate[1]
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.stride = stride
        self.shortcut = nn.Sequential()
        if stride != 1 or inplanes != planes:
            # Parameter-free shortcut: subsample spatially by `stride` (matching
            # conv1's output size) and zero-pad the channel dim from `inplanes`
            # to `planes`. For planes == 2 * inplanes this equals the previous
            # `planes // 4` padding on each side.
            pad_front = (planes - inplanes) // 2
            pad_back = planes - inplanes - pad_front
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::stride, ::stride],
                                (0, 0, 0, 0, pad_front, pad_back), "constant", 0))

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += self.shortcut(x)
        out = self.relu2(out)

        return out
56 |
57 |
class ResNet(nn.Module):
    """CIFAR ResNet of depth ``6n + 2``: a stem conv plus three stages of ``n`` blocks.

    Args:
        block: residual block class; needs an ``expansion`` attribute and
            accepts ``(inplanes, planes, stride, compress_rate=...)``.
        num_layers: total depth; must satisfy ``(num_layers - 2) % 6 == 0``.
        covcfg: stored as ``self.covcfg``; not otherwise used in this class.
        compress_rate: ``6n + 1`` pruning rates (index 0 for the stem conv,
            then two per block, stage by stage); ``None`` means all zeros.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, num_layers, covcfg, compress_rate, num_classes=10):
        super(ResNet, self).__init__()
        assert (num_layers - 2) % 6 == 0, 'depth should be 6n+2'
        n = (num_layers - 2) // 6
        if compress_rate is None:
            # Factory functions default to None; treat it as "no pruning".
            compress_rate = [0.] * (6 * n + 1)
        self.covcfg = covcfg
        self.compress_rate = compress_rate
        self.num_layers = num_layers

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1.cp_rate = compress_rate[0]

        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Each stage consumes 2 * n consecutive rates (two per block).
        self.layer1 = self._make_layer(block, 16, blocks=n, stride=1,
                                       compress_rate=compress_rate[1:2 * n + 1])
        self.layer2 = self._make_layer(block, 32, blocks=n, stride=2,
                                       compress_rate=compress_rate[2 * n + 1:4 * n + 1])
        self.layer3 = self._make_layer(block, 64, blocks=n, stride=2,
                                       compress_rate=compress_rate[4 * n + 1:6 * n + 1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # NOTE(review): the classifier attribute name depends on depth; presumably
        # this matches parameter names in existing checkpoints — keep as is.
        if num_layers == 110:
            self.linear = nn.Linear(64 * block.expansion, num_classes)
        else:
            self.fc = nn.Linear(64 * block.expansion, num_classes)

        self.initialize()

    def initialize(self):
        """Kaiming-normal conv weights; BatchNorm weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride, compress_rate):
        """Build one stage; block ``i`` receives ``compress_rate[2*i:2*i+2]``.

        Only the first block uses ``stride``. Afterwards ``self.inplanes`` is
        updated to ``planes * block.expansion`` for the next stage.
        """
        layers = []

        layers.append(block(self.inplanes, planes, stride, compress_rate=compress_rate[0:2]))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, compress_rate=compress_rate[2 * i:2 * i + 2]))

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        if self.num_layers == 110:
            x = self.linear(x)
        else:
            x = self.fc(x)

        return x
125 |
126 |
def resnet_56(compress_rate=None):
    """Build the 56-layer CIFAR ResNet (9 basic blocks per stage).

    ``compress_rate`` is forwarded unchanged to ``ResNet``.
    """
    n_entries = 9 * 3 * 2 + 1
    # Arithmetic progression 2, 5, 8, ... with n_entries terms.
    cov_cfg = list(range(2, 3 * n_entries + 2, 3))
    return ResNet(ResBasicBlock, 56, cov_cfg, compress_rate=compress_rate)
130 |
131 |
def resnet_110(compress_rate=None):
    """Build the 110-layer CIFAR ResNet (18 basic blocks per stage).

    ``compress_rate`` is forwarded unchanged to ``ResNet``.
    """
    n_entries = 18 * 3 * 2 + 1
    # Arithmetic progression 2, 5, 8, ... with n_entries terms.
    cov_cfg = list(range(2, 3 * n_entries + 2, 3))
    return ResNet(ResBasicBlock, 110, cov_cfg, compress_rate=compress_rate)
135 |
136 |
--------------------------------------------------------------------------------
/models/resnet_imagenet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 |
4 |
# NOTE(review): neither name is referenced inside this module; presumably read
# by other code in the project — confirm before changing.
norm_mean = 1.0
norm_var = 0.1
6 |
7 |
def conv3x3(in_planes, out_planes, stride=1):
    """Create a 3x3 convolution (padding 1, no bias) from ``in_planes`` to ``out_planes``."""
    conv = nn.Conv2d(in_planes, out_planes, 3, stride=stride, padding=1, bias=False)
    return conv
12 |
13 |
class ResBottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with per-conv pruning rates.

    Args:
        inplanes: input channels.
        planes: width of the inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the residual.
            When ``None`` the input is added unchanged, so it must already have
            the output's channel count and spatial size.
        cp_rate: sequence whose first three entries become ``conv1.cp_rate``,
            ``conv2.cp_rate`` and ``conv3.cp_rate``; ``None`` means ``[0., 0., 0.]``.
        tmp_name: label copied onto each conv as ``tmp_name``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, cp_rate=None, tmp_name=None):
        super(ResBottleneck, self).__init__()
        if cp_rate is None:
            # The former default ``[0.]`` had one entry while three are read below.
            cp_rate = [0., 0., 0.]
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.conv1.cp_rate = cp_rate[0]
        self.conv1.tmp_name = tmp_name
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2.cp_rate = cp_rate[1]
        self.conv2.tmp_name = tmp_name
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.conv3.cp_rate = cp_rate[2]
        self.conv3.tmp_name = tmp_name

        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu3(out)

        return out
63 |
64 |
class Downsample(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its input through ``downsample``.

    Args:
        downsample: callable (typically a module) applied in ``forward``.
    """

    def __init__(self, downsample):
        super().__init__()
        self.downsample = downsample

    def forward(self, x):
        return self.downsample(x)
73 |
74 |
class ResNet(nn.Module):
    """ImageNet-style ResNet whose convolutions carry per-layer pruning rates.

    ``compress_rate`` is a flat list. Index 0 is the stem conv. Each of the
    four stages then consumes ``3 * num_blocks[k] + 1`` consecutive entries:
    one for the stage's shortcut conv, then three per block.

    Args:
        block: residual block class; needs ``expansion`` and accepts
            ``(inplanes, planes, stride, downsample, cp_rate=..., tmp_name=...)``.
        num_blocks: blocks per stage; the first four entries are used.
        num_classes: classifier output size.
        covcfg: stored as ``self.covcfg``; not otherwise used in this class.
        compress_rate: flat list of rates as above; ``None`` means all zeros.
    """

    def __init__(self, block, num_blocks, num_classes=1000, covcfg=None, compress_rate=None):
        self.inplanes = 64
        super(ResNet, self).__init__()

        if compress_rate is None:
            # Factory functions default to None; treat it as "no pruning".
            compress_rate = [0.] * (1 + sum(3 * nb + 1 for nb in num_blocks[:4]))

        self.covcfg = covcfg
        self.compress_rate = compress_rate
        self.num_blocks = num_blocks

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1.cp_rate = compress_rate[0]
        self.conv1.tmp_name = 'conv1'
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Start index of each stage's slice of compress_rate.
        start1 = 1
        start2 = start1 + 3 * num_blocks[0] + 1
        start3 = start2 + 3 * num_blocks[1] + 1
        start4 = start3 + 3 * num_blocks[2] + 1

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1,
                                       cp_rate=compress_rate[start1:start2],
                                       tmp_name='layer1')
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2,
                                       cp_rate=compress_rate[start2:start3],
                                       tmp_name='layer2')
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2,
                                       cp_rate=compress_rate[start3:start4],
                                       tmp_name='layer3')
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2,
                                       cp_rate=compress_rate[start4:],
                                       tmp_name='layer4')

        # Fixed 7x7 pool: assumes the last stage yields a 7x7 map
        # (e.g. 224x224 input) — confirm for other input sizes.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride, cp_rate, tmp_name):
        """Build one stage.

        ``cp_rate[0]`` is the shortcut conv's rate. It stays reserved even when
        no shortcut is needed, so the slot layout is fixed. Block ``i`` then
        receives ``cp_rate[3*i+1:3*i+4]``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            conv_short = nn.Conv2d(self.inplanes, planes * block.expansion,
                                   kernel_size=1, stride=stride, bias=False)
            conv_short.cp_rate = cp_rate[0]
            conv_short.tmp_name = tmp_name + '_shortcut'
            downsample = nn.Sequential(
                conv_short,
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, cp_rate=cp_rate[1:4],
                            tmp_name=tmp_name + '_block' + str(1)))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cp_rate=cp_rate[3 * i + 1:3 * i + 4],
                                tmp_name=tmp_name + '_block' + str(i + 1)))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Shape comments assume ResNet-50 widths and a 224x224 input.

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)

        # 256 x 56 x 56
        x = self.layer2(x)

        # 512 x 28 x 28
        x = self.layer3(x)

        # 1024 x 14 x 14
        x = self.layer4(x)

        # 2048 x 7 x 7
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
161 |
162 |
def resnet_50(compress_rate=None):
    """Build ResNet-50: bottleneck blocks in stages of 3, 4, 6 and 3.

    ``compress_rate`` is forwarded unchanged to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    # Three slots per block plus one per stage, plus one extra: 53 for ResNet-50.
    n_entries = sum(3 * depth + 1 for depth in stage_depths) + 1
    # Arithmetic progression 3, 6, 9, ... with n_entries terms.
    cov_cfg = list(range(3, 3 * n_entries + 3, 3))
    return ResNet(ResBottleneck, stage_depths, covcfg=cov_cfg, compress_rate=compress_rate)
167 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch.nn as nn
4 | from collections import OrderedDict
5 |
6 |
# NOTE(review): neither name is referenced inside this module; presumably read
# by other code in the project — confirm before changing.
norm_mean = 0.0
norm_var = 1.0
8 |
# VGG-16 layout: ints are conv widths, 'M' is a 2x2 max-pool; the final entry
# is the hidden width of the classifier.
defaultcfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 512]
# Positions inside ``VGG.features`` for ``defaultcfg`` with batch norm.
# ``convcfg`` lists the first 12 Conv2d layers. ``relucfg`` lists the ReLU after
# each of those convs, or the MaxPool2d when one immediately follows (6, 13, 23, 33).
relucfg = [2, 6, 9, 13, 16, 19, 23, 26, 29, 33, 36, 39]
convcfg = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37]


class VGG(nn.Module):
    """VGG with batch norm whose convolutions carry per-layer pruning rates.

    ``cfg[:-1]`` describes the feature extractor. ``cfg[-2]`` (the last conv
    width) is the classifier input size and ``cfg[-1]`` its hidden size.
    With ``defaultcfg`` the four pools plus the 2x2 average pool in
    ``forward`` reduce a 32x32 input to 1x1.

    Args:
        num_classes: classifier output size.
        init_weights: run ``_initialize_weights`` after construction.
        cfg: layer description; defaults to ``defaultcfg``.
        compress_rate: one pruning rate per conv in ``cfg[:-1]``; ``None``
            means 0.0 for every conv.
    """

    def __init__(self, num_classes=10, init_weights=True, cfg=None, compress_rate=None):
        super(VGG, self).__init__()

        if cfg is None:
            cfg = defaultcfg
        if compress_rate is None:
            compress_rate = [0.] * sum(1 for v in cfg[:-1] if v != 'M')

        self.relucfg = relucfg
        self.covcfg = convcfg
        self.compress_rate = compress_rate
        self.features = self.make_layers(cfg[:-1], True, compress_rate)
        self.classifier = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cfg[-2], cfg[-1])),
            ('norm1', nn.BatchNorm1d(cfg[-1])),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(cfg[-1], num_classes)),
        ]))

        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=True, compress_rate=None):
        """Build the convolutional feature extractor described by ``cfg``.

        Args:
            cfg: ints are 3x3 conv output widths; ``'M'`` inserts a 2x2 max-pool.
            batch_norm: insert ``BatchNorm2d`` after every conv when true.
            compress_rate: one rate per conv, stored as ``conv.cp_rate``;
                ``None`` means 0.0 for every conv.

        Returns:
            ``nn.Sequential`` whose children are named ``conv{i}``, ``norm{i}``,
            ``relu{i}`` or ``pool{i}``, with ``i`` the index into ``cfg``.
        """
        if compress_rate is None:
            compress_rate = [0.] * sum(1 for v in cfg if v != 'M')
        layers = nn.Sequential()
        in_channels = 3
        cnt = 0
        for i, v in enumerate(cfg):
            if v == 'M':
                layers.add_module('pool%d' % i, nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                conv2d.cp_rate = compress_rate[cnt]
                cnt += 1

                layers.add_module('conv%d' % i, conv2d)
                if batch_norm:
                    layers.add_module('norm%d' % i, nn.BatchNorm2d(v))
                layers.add_module('relu%d' % i, nn.ReLU(inplace=True))
                in_channels = v

        return layers

    def forward(self, x):
        x = self.features(x)

        # Functional form: same pooling as nn.AvgPool2d(2) without building a module per call.
        x = nn.functional.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise convs, 2-D batch norms and linear layers in place.

        Convs get normal weights with std ``sqrt(2 / (k_h * k_w * out_channels))``.
        ``BatchNorm1d`` is not a ``BatchNorm2d`` subclass, so the classifier's
        norm keeps PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
76 |
77 |
def vgg_16_bn(compress_rate=None):
    """Build VGG-16 with batch norm using the default configuration.

    ``compress_rate`` is forwarded unchanged to ``VGG``.
    """
    model = VGG(compress_rate=compress_rate)
    return model
80 |
81 |
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv10.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv11.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv12.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv13.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv14.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv15.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv16.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv17.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv18.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv19.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv2.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv20.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv20.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv21.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv21.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv22.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv22.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv23.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv23.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv24.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv24.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv25.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv25.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv26.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv26.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv27.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv27.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv28.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv28.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv29.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv29.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv3.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv30.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv30.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv31.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv31.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv32.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv32.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv33.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv33.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv34.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv34.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv35.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv35.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv36.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv36.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv37.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv37.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv38.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv38.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv39.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv39.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv4.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv5.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv6.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv7.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv8.npy
--------------------------------------------------------------------------------
/rank_conv/densenet_40/rank_conv9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/densenet_40/rank_conv9.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv10_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv10_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv10_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv10_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv10_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv2_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv2_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv2_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv2_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv2_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv3_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv3_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv3_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv3_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv3_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv4_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv4_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv4_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv4_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv4_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv5_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv5_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv5_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv5_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv5_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv6_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv6_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv6_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv6_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv6_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv7_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv7_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv7_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv7_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv7_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv8_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv8_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv8_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv8_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv8_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv9_n1x1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n1x1.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv9_n3x3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n3x3.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv9_n5x5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_n5x5.npy
--------------------------------------------------------------------------------
/rank_conv/googlenet/rank_conv9_pool_planes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/googlenet/rank_conv9_pool_planes.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv10.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv100.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv100.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv101.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv101.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv102.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv102.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv103.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv103.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv104.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv104.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv105.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv105.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv106.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv106.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv107.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv107.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv108.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv108.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv109.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv109.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv11.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv12.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv13.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv14.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv15.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv16.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv17.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv18.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv19.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv2.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv20.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv20.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv21.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv21.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv22.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv22.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv23.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv23.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv24.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv24.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv25.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv25.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv26.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv26.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv27.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv27.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv28.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv28.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv29.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv29.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv3.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv30.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv30.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv31.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv31.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv32.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv32.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv33.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv33.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv34.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv34.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv35.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv35.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv36.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv36.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv37.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv37.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv38.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv38.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv39.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv39.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv4.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv40.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv40.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv41.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv41.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv42.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv42.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv43.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv43.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv44.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv44.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv45.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv45.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv46.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv46.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv47.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv47.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv48.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv48.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv49.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv49.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv5.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv50.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv50.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv51.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv51.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv52.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv52.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv53.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv53.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv54.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv54.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv55.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv55.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv56.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv56.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv57.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv57.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv58.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv58.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv59.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv59.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv6.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv60.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv60.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv61.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv61.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv62.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv62.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv63.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv63.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv64.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv64.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv65.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv65.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv66.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv66.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv67.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv67.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv68.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv68.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv69.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv69.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv7.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv70.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv70.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv71.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv71.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv72.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv72.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv73.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv73.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv74.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv74.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv75.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv75.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv76.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv76.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv77.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv77.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv78.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv78.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv79.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv79.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv8.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv80.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv80.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv81.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv81.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv82.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv82.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv83.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv83.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv84.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv84.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv85.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv85.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv86.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv86.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv87.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv87.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv88.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv88.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv89.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv89.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv9.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv90.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv90.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv91.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv91.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv92.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv92.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv93.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv93.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv94.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv94.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv95.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv95.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv96.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv96.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv97.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv97.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv98.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv98.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_110/rank_conv99.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_110/rank_conv99.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv10.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv11.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv12.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv13.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv14.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv15.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv16.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv17.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv18.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv19.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv2.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv20.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv20.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv21.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv21.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv22.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv22.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv23.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv23.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv24.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv24.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv25.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv25.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv26.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv26.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv27.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv27.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv28.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv28.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv29.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv29.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv3.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv30.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv30.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv31.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv31.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv32.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv32.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv33.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv33.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv34.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv34.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv35.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv35.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv36.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv36.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv37.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv37.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv38.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv38.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv39.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv39.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv4.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv40.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv40.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv41.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv41.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv42.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv42.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv43.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv43.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv44.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv44.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv45.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv45.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv46.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv46.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv47.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv47.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv48.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv48.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv49.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv49.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv5.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv50.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv50.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv51.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv51.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv52.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv52.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv53.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv53.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv6.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv7.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv8.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_50/rank_conv9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_50/rank_conv9.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv10.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv11.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv12.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv13.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv14.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv15.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv16.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv17.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv18.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv19.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv2.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv20.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv20.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv21.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv21.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv22.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv22.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv23.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv23.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv24.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv24.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv25.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv25.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv26.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv26.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv27.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv27.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv28.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv28.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv29.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv29.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv3.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv30.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv30.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv31.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv31.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv32.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv32.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv33.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv33.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv34.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv34.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv35.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv35.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv36.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv36.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv37.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv37.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv38.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv38.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv39.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv39.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv4.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv40.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv40.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv41.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv41.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv42.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv42.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv43.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv43.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv44.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv44.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv45.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv45.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv46.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv46.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv47.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv47.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv48.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv48.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv49.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv49.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv5.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv50.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv50.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv51.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv51.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv52.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv52.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv53.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv53.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv54.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv54.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv55.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv55.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv6.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv7.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv8.npy
--------------------------------------------------------------------------------
/rank_conv/resnet_56/rank_conv9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/resnet_56/rank_conv9.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv1.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv10.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv11.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv12.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv2.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv3.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv4.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv5.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv6.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv7.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv8.npy
--------------------------------------------------------------------------------
/rank_conv/vgg_16_bn/rank_conv9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmbxmu/HRank/33050a16c11b5e0f105b268be8c2a42087d11c9d/rank_conv/vgg_16_bn/rank_conv9.npy
--------------------------------------------------------------------------------
/rank_generation.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.backends.cudnn as cudnn
4 |
5 | import torchvision
6 | import torchvision.transforms as transforms
7 |
8 | import os
9 | import argparse
10 |
11 | import data.imagenet as imagenet
12 | from models import *
13 | from utils import progress_bar
14 | import numpy as np
15 |
parser = argparse.ArgumentParser(description='Rank extraction')

# Every command-line option as (flag, add_argument keyword arguments),
# registered in this order.
_CLI_OPTIONS = (
    ('--data_dir', dict(type=str, default='./data', help='dataset path')),
    ('--dataset', dict(type=str, default='cifar10',
                       choices=('cifar10', 'imagenet'),
                       help='dataset')),
    ('--job_dir', dict(type=str, default='result/tmp',
                       help='The directory where the summaries will be stored.')),
    ('--arch', dict(type=str, default='vgg_16_bn',
                    choices=('resnet_50', 'vgg_16_bn', 'resnet_56', 'resnet_110',
                             'densenet_40', 'googlenet'),
                    help='The architecture to prune')),
    ('--resume', dict(type=str, default=None,
                      help='load the model from the specified checkpoint')),
    ('--limit', dict(type=int, default=5, help='The num of batch to get rank.')),
    ('--train_batch_size', dict(type=int, default=128, help='Batch size for training.')),
    ('--eval_batch_size', dict(type=int, default=100, help='Batch size for validation.')),
    ('--start_idx', dict(type=int, default=0,
                         help='The index of conv to start extract rank.')),
    ('--gpu', dict(type=str, default='0', help='Select gpu to use')),
    ('--adjust_ckpt', dict(action='store_true', help='adjust ckpt from pruned checkpoint')),
    ('--compress_rate', dict(type=str, default=None, help='compress rate of each conv')),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed command line; read by everything below.
args = parser.parse_args()
82 |
# Expose only the GPU(s) selected with --gpu. The DataParallel device count and
# the checkpoint map_location below are both derived from args.gpu, so this must
# follow the flag rather than a hard-coded device list.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88 |
# Data
print('==> Preparing data..')
if args.dataset == 'cifar10':
    # Per-channel mean / std normalisation shared by the train and test pipelines.
    cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        cifar_normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        cifar_normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True,
                                           transform=transform_test)
    # Honour --eval_batch_size (its default is 100).
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.eval_batch_size,
                                             shuffle=False, num_workers=2)
elif args.dataset == 'imagenet':
    data_tmp = imagenet.Data(args)
    trainloader = data_tmp.loader_train
    testloader = data_tmp.loader_test
112 |
def _parse_compress_rate(cprate_str):
    """Expand a '+'-separated compress-rate spec into a flat list of floats.

    Each '+'-separated term must contain exactly one decimal rate (digits, a
    dot, optional digits, e.g. ``0.6``) and may contain one ``*N`` repeat
    count. Other characters such as brackets are ignored, so
    ``'[0.1]+[0.6]*3'`` expands to ``[0.1, 0.6, 0.6, 0.6]``.

    Raises:
        ValueError: a term has no rate, more than one rate, or more than one
            repeat count. Raised explicitly (not via ``assert``) so the check
            survives ``python -O``.
    """
    import re

    pat_cprate = re.compile(r'\d+\.\d*')
    pat_num = re.compile(r'\*\d+')
    cprate = []
    for term in cprate_str.split('+'):
        find_num = pat_num.findall(term)
        if len(find_num) > 1:
            raise ValueError('multiple repeat counts in compress rate term %r' % term)
        # Drop the leading '*' of the matched repeat count.
        num = int(find_num[0][1:]) if find_num else 1
        find_cprate = pat_cprate.findall(term)
        if len(find_cprate) != 1:
            raise ValueError('expected exactly one rate in compress rate term %r' % term)
        cprate += [float(find_cprate[0])] * num
    return cprate


if args.compress_rate:
    compress_rate = _parse_compress_rate(args.compress_rate)
else:
    # Per-conv default compress rates for each supported architecture.
    default_cprate = {
        'vgg_16_bn': [0.7]*7+[0.1]*6,
        'densenet_40': [0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*6+[0.0]+[0.1]*6+[0.7]*5+[0.0],
        'googlenet': [0.10]+[0.7]+[0.5]+[0.8]*4+[0.5]+[0.6]*2,
        'resnet_50':[0.2]+[0.8]*10+[0.8]*13+[0.55]*19+[0.45]*10,
        'resnet_56':[0.1]+[0.60]*35+[0.0]*2+[0.6]*6+[0.4]*3+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4]+[0.1]+[0.4],
        'resnet_110':[0.1]+[0.40]*36+[0.40]*36+[0.4]*36
    }
    compress_rate = default_cprate[args.arch]
141 |
# Model
print('==> Building model..')
print(compress_rate)
# args.arch is restricted by argparse `choices` to constructor names brought in
# by `from models import *`, so look it up directly instead of eval().
net = globals()[args.arch](compress_rate=compress_rate)
net = net.to(device)

if len(args.gpu) > 1 and torch.cuda.is_available():
    # One logical device per comma-separated id in --gpu; logical ids start at 0
    # once CUDA_VISIBLE_DEVICES has been applied.
    device_id = list(range(len(args.gpu.split(','))))
    net = torch.nn.DataParallel(net, device_ids=device_id)


if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # Map tensors onto the selected device. 'cuda:' + args.gpu is not a valid
    # device string for multi-GPU values such as '0,1', nor on CPU-only hosts.
    checkpoint = torch.load(args.resume, map_location=device)
    from collections import OrderedDict
    # With --adjust_ckpt the file itself is the state dict; otherwise the state
    # dict is stored under 'state_dict'.
    source_state = checkpoint if args.adjust_ckpt else checkpoint['state_dict']
    # Strip any DataParallel 'module.' prefix from the saved keys.
    new_state_dict = OrderedDict(
        (k.replace('module.', ''), v) for k, v in source_state.items()
    )
    # The keys are now prefix-free, so load into the unwrapped network.
    _target_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    _target_net.load_state_dict(new_state_dict)
168 |
169 |
import threading

criterion = torch.nn.CrossEntropyLoss()

# Running per-channel mean rank for the currently hooked layer, and the number
# of images folded into it so far. The driver code resets both to scalar zeros
# after saving each layer's result.
feature_result = torch.tensor(0.)
total = torch.tensor(0.)

# Forward hooks also fire on DataParallel replicas, each in its own thread;
# serialise the read-modify-write of the two globals above.
_rank_lock = threading.Lock()


def _accumulate_rank(output, num_channels=None):
    """Fold one batch of activations into the running mean ``feature_result``.

    ``output`` is indexed as ``output[image, channel, :, :]`` (N, C, H, W). The
    rank of every H x W map is computed, summed over the batch and merged into
    the per-channel running mean, weighted by the image count ``total``.

    Args:
        output: 4-D activation tensor emitted by the hooked module.
        num_channels: if given, only the last ``num_channels`` channels are
            ranked; otherwise every channel is.
    """
    global feature_result
    global total
    if num_channels is not None:
        output = output[:, max(output.shape[1] - num_channels, 0):]
    batch = output.shape[0]
    # torch.matrix_rank is deprecated in favour of torch.linalg.matrix_rank,
    # which ranks the whole (N, C) batch of matrices in one call with the same
    # default tolerance (sigma_max * max(H, W) * eps). .cpu() keeps the
    # accumulator on the host so callers can save it with .numpy().
    ranks = torch.linalg.matrix_rank(output).float().sum(0).cpu()
    with _rank_lock:
        feature_result = feature_result * total + ranks
        total = total + batch
        feature_result = feature_result / total


def get_feature_hook(self, input, output):
    """Forward hook: accumulate the rank of every output channel."""
    _accumulate_rank(output)


def get_feature_hook_densenet(self, input, output):
    """Forward hook: accumulate the rank of the last 12 output channels only.

    12 is the count the script hard-wires (presumably the DenseNet growth
    rate -- confirm against the model definition).
    """
    _accumulate_rank(output, 12)


def get_feature_hook_googlenet(self, input, output):
    """Forward hook for GoogLeNet; same behaviour as ``get_feature_hook_densenet``."""
    _accumulate_rank(output, 12)
212 |
213 |
def test():
    """Push up to ``args.limit`` training batches through ``net``.

    Runs in eval mode without gradients and returns nothing. Its purpose is to
    fire whatever forward hook the caller registered, which accumulates one
    layer's feature-map ranks. Batches come from ``trainloader``. Loss and
    accuracy are only reported through ``progress_bar``.
    """
    import itertools

    net.eval()
    test_loss = 0
    correct = 0
    total = 0  # local image counter; distinct from the module-level `total`
    limit = args.limit

    with torch.no_grad():
        # Use only the first `limit` batches to estimate the rank; islice also
        # avoids fetching one extra batch just to break out of the loop.
        batches = itertools.islice(trainloader, max(limit, 0))
        for batch_idx, (inputs, targets) in enumerate(batches):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, limit, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
237 |
238 |
if args.arch == 'vgg_16_bn':
    # Arch values are mutually exclusive, so each branch is a plain `if`.
    # Unwrap DataParallel (applied only for multi-GPU CUDA runs): the wrapper
    # does not expose `relucfg` or `features` itself.
    base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    out_dir = 'rank_conv/' + args.arch + '_limit%d' % (args.limit)
    os.makedirs(out_dir, exist_ok=True)

    # Each entry of relucfg indexes a module in `features`; the ranks of that
    # module's output are saved as rank_conv<i+1>.npy.
    for i, cov_id in enumerate(base_net.relucfg):
        cov_layer = base_net.features[cov_id]
        handler = cov_layer.register_forward_hook(get_feature_hook)
        test()
        handler.remove()
        np.save(out_dir + '/rank_conv' + str(i + 1) + '.npy', feature_result.numpy())
        # Reset the running mean before probing the next layer.
        feature_result = torch.tensor(0.)
        total = torch.tensor(0.)
258 |
if args.arch == 'resnet_56':
    # Unwrap DataParallel (if applied): the wrapper does not expose
    # `relu` / `layerN` itself.
    base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    out_dir = 'rank_conv/' + args.arch + '_limit%d' % (args.limit)
    os.makedirs(out_dir, exist_ok=True)

    # Probe order == file numbering: rank_conv1 is the network-level `relu`,
    # then relu1 / relu2 of each of the first 9 blocks in layer1..layer3.
    probes = [base_net.relu]
    for i in range(3):
        block = getattr(base_net, 'layer%d' % (i + 1))
        for j in range(9):
            probes += [block[j].relu1, block[j].relu2]

    for cnt, cov_layer in enumerate(probes, start=1):
        handler = cov_layer.register_forward_hook(get_feature_hook)
        test()
        handler.remove()
        np.save(out_dir + '/rank_conv%d' % (cnt) + '.npy', feature_result.numpy())
        # Reset the running mean before probing the next layer.
        feature_result = torch.tensor(0.)
        total = torch.tensor(0.)
294 |
if args.arch == 'densenet_40':
    # Unwrap DataParallel (if applied): the wrapper does not expose
    # `denseN` / `transN` / `relu` itself.
    base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    out_dir = 'rank_conv/' + args.arch + '_limit%d' % (args.limit)
    os.makedirs(out_dir, exist_ok=True)

    feature_result = torch.tensor(0.)
    total = torch.tensor(0.)

    # (file index, module to hook, hook) in execution order: the 12 layers of
    # dense block i (indices 13*i+1 .. 13*i+12), the transition after the
    # first two blocks (13, 26), and finally the network-level relu (39).
    probes = []
    for i in range(3):
        dense = getattr(base_net, 'dense%d' % (i + 1))
        for j in range(12):
            # The first layer of a block is ranked over all output channels;
            # later layers only over their last 12 channels.
            hook = get_feature_hook if j == 0 else get_feature_hook_densenet
            probes.append((13 * i + j + 1, dense[j].relu, hook))
        if i < 2:
            trans = getattr(base_net, 'trans%d' % (i + 1))
            probes.append((13 * (i + 1), trans.relu, get_feature_hook_densenet))
    probes.append((39, base_net.relu, get_feature_hook_densenet))

    for conv_idx, cov_layer, hook in probes:
        handler = cov_layer.register_forward_hook(hook)
        test()
        handler.remove()
        np.save(out_dir + '/rank_conv%d' % (conv_idx) + '.npy', feature_result.numpy())
        # Reset the running mean before probing the next layer.
        feature_result = torch.tensor(0.)
        total = torch.tensor(0.)
338 |
339 | elif args.arch=='googlenet':  # GoogLeNet: saves rank_conv<idx+1>[_<branch>].npy for each entry of cov_list
340 |
341 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)):  # output dir: rank_conv/<arch>_limit<limit>
342 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit))
343 | feature_result = torch.tensor(0.)  # reset the running per-channel mean rank and its image count
344 | total = torch.tensor(0.)
345 |
346 | cov_list=['pre_layers',  # names of the probed stages; entry idx is saved as rank_conv<idx+1>
347 | 'inception_a3',
348 | 'maxpool1',
349 | 'inception_a4',
350 | 'inception_b4',
351 | 'inception_c4',
352 | 'inception_d4',
353 | 'maxpool2',
354 | 'inception_a5',
355 | 'inception_b5',
356 | ]
357 |
358 | # branch type
359 | tp_list=['n1x1','n3x3','n5x5','pool_planes']  # branch names used as file suffixes; sliced from feature_result in this order
360 | for idx, cov in enumerate(cov_list):
361 |
362 | if idx0:  # NOTE(review): garbled in this copy -- original lines 362-370 are missing and this condition is truncated (text between a '<' and a '>' appears stripped). The lost lines presumably hook `cov`, call test() and remove the hook; the condition is presumably `idx>0` (cf. `filters[idx-1]` and the `else` below). As shown, `idx0` is undefined and this branch cannot run -- restore from the upstream rank_generation.py.
371 | for idx1,tp in enumerate(tp_list):
372 | if idx1==3:  # pool_planes: channels [sum(filters[idx-1][:-1]), sum(filters[idx-1]))
373 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d_'%(idx+1)+tp+'.npy',
374 | feature_result[sum(net.filters[idx-1][:-1]) : sum(net.filters[idx-1][:])].numpy())  # NOTE(review): net.filters raises AttributeError if net was wrapped in DataParallel above; use net.module there
375 | #elif idx1==0:
376 | # np.save('rank_conv1/' + args.arch + '/rank_conv%d_'%(idx+1)+tp+'.npy',
377 | # feature_result[0 : sum(net.filters[idx-1][:1])].numpy())
378 | else:  # branch idx1 covers channels [sum(filters[idx-1][:idx1]), sum(filters[idx-1][:idx1+1]))
379 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d_' % (idx + 1) + tp + '.npy',
380 | feature_result[sum(net.filters[idx-1][:idx1]) : sum(net.filters[idx-1][:idx1+1])].numpy())
381 | else:  # entry not split per branch: save the whole per-channel rank vector
382 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d' % (idx + 1) + '.npy',feature_result.numpy())
383 | feature_result = torch.tensor(0.)  # reset before the next entry of cov_list
384 | total = torch.tensor(0.)
385 |
if args.arch == 'resnet_110':
    # Unwrap DataParallel (if applied): the wrapper does not expose
    # `relu` / `layerN` itself.
    base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
    out_dir = 'rank_conv/' + args.arch + '_limit%d' % (args.limit)
    os.makedirs(out_dir, exist_ok=True)

    # Probe order == file numbering: rank_conv1 is the network-level `relu`,
    # then relu1 / relu2 of each of the first 18 blocks in layer1..layer3.
    probes = [base_net.relu]
    for i in range(3):
        block = getattr(base_net, 'layer%d' % (i + 1))
        for j in range(18):
            probes += [block[j].relu1, block[j].relu2]

    for cnt, cov_layer in enumerate(probes, start=1):
        handler = cov_layer.register_forward_hook(get_feature_hook)
        test()
        handler.remove()
        np.save(out_dir + '/rank_conv%d' % (cnt) + '.npy', feature_result.numpy())
        # Reset the running mean before probing the next layer.
        feature_result = torch.tensor(0.)
        total = torch.tensor(0.)
423 |
424 | elif args.arch=='resnet_50':
425 |
426 | cov_layer = eval('net.maxpool')
427 | handler = cov_layer.register_forward_hook(get_feature_hook)
428 | test()
429 | handler.remove()
430 |
431 | if not os.path.isdir('rank_conv/' + args.arch+'_limit%d'%(args.limit)):
432 | os.mkdir('rank_conv/' + args.arch+'_limit%d'%(args.limit))
433 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d' % (1) + '.npy', feature_result.numpy())
434 | feature_result = torch.tensor(0.)
435 | total = torch.tensor(0.)
436 |
437 | # ResNet50 per bottleneck
438 | cnt=1
439 | for i in range(4):
440 | block = eval('net.layer%d' % (i + 1))
441 | for j in range(net.num_blocks[i]):
442 | cov_layer = block[j].relu1
443 | handler = cov_layer.register_forward_hook(get_feature_hook)
444 | test()
445 | handler.remove()
446 | np.save('rank_conv/' + args.arch+'_limit%d'%(args.limit) + '/rank_conv%d'%(cnt+1)+'.npy', feature_result.numpy())
447 | cnt+=1
448 | feature_result = torch.tensor(0.)
449 | total = torch.tensor(0.)
450 |
451 | cov_layer = block[j].relu2
452 | handler = cov_layer.register_forward_hook(get_feature_hook)
453 | test()
454 | handler.remove()
455 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy',
456 | feature_result.numpy())
457 | cnt += 1
458 | feature_result = torch.tensor(0.)
459 | total = torch.tensor(0.)
460 |
461 | cov_layer = block[j].relu3
462 | handler = cov_layer.register_forward_hook(get_feature_hook)
463 | test()
464 | handler.remove()
465 | if j==0:
466 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy',
467 | feature_result.numpy())#shortcut conv
468 | cnt += 1
469 | np.save('rank_conv/' + args.arch + '_limit%d' % (args.limit) + '/rank_conv%d' % (cnt + 1) + '.npy',
470 | feature_result.numpy())#conv3
471 | cnt += 1
472 | feature_result = torch.tensor(0.)
473 | total = torch.tensor(0.)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
from __future__ import absolute_import

import datetime
import logging
import os
import shutil
import sys
import time
from pathlib import Path

import torch
11 |
12 |
def get_logger(file_path):
    """Return the shared 'kd' logger, writing to ``file_path`` and to stderr.

    The logger is a process-wide singleton (``logging.getLogger('kd')``), so
    handlers are only attached if an equivalent one is not already present.
    Calling this repeatedly (e.g. once per run/phase) therefore does not
    duplicate every log line.

    Args:
        file_path: path of the log file to append to.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    # [!] Since tensorboardX use default logger (e.g. logging.info()), we should use custom logger
    logger = logging.getLogger('kd')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    # FileHandler stores its target as an absolute path in `baseFilename`.
    abs_path = os.path.abspath(file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
        for h in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so match the exact type here.
    has_stream_handler = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    logger.setLevel(logging.INFO)

    return logger
29 |
30 |
class AverageMeter(object):
    """Keep the most recent value of a metric plus its running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: the new observation (e.g. a batch-mean loss).
            n: how many samples ``val`` stands for.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
47 |
48 |
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k is taken along dim 1
            (presumably (batch, num_classes) -- confirm with callers).
        target: ground-truth class indices, one per row of ``output``.
        topk: k values to evaluate; ``max(topk)`` must not exceed dim 1 of
            ``output``.

    Returns:
        A list with one 1-element float tensor per k, holding the percentage
        of rows whose target is among that row's top-k predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row r holds every sample's r-th guess
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` keeps the transposed (non-contiguous) layout of `pred`,
            # so `.view(-1)` fails on recent PyTorch for k > 1; `reshape`
            # copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
64 |
65 |
class checkpoint():
    """Prepare a job directory and record the run's arguments.

    On construction this prints ``args.job_dir`` and creates ``args.job_dir``
    and ``args.job_dir/run``, including any missing parents. It then writes
    ``config.txt`` into the job directory, containing the current timestamp
    followed by one ``name: value`` line per attribute of ``args``.

    Attributes:
        args: the arguments object passed in.
        job_dir: ``Path`` of the job directory.
        run_dir: ``Path`` of ``job_dir / 'run'``.
    """

    def __init__(self, args):
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        self.args = args
        self.job_dir = Path(args.job_dir)
        self.run_dir = self.job_dir / 'run'
        print(args.job_dir)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.job_dir, exist_ok=True)
        os.makedirs(self.run_dir, exist_ok=True)

        config_dir = self.job_dir / 'config.txt'
        with open(config_dir, 'w') as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')
88 |
89 |
def print_params(config, prtf=print):
    """Print a configuration mapping as ``KEY=value`` lines.

    Emits, one ``prtf`` call per line: a blank line, a ``Parameters:``
    header, each entry sorted by key with the key upper-cased, and a
    trailing blank line.

    Args:
        config: mapping of option names to values.
        prtf: line sink; defaults to ``print``.
    """
    for heading in ("", "Parameters:"):
        prtf(heading)
    for name, setting in sorted(config.items()):
        line = "{}={}".format(name.upper(), setting)
        prtf(line)
    prtf("")
96 |
97 |
# Terminal width used by progress_bar to pad and reposition the cursor.
# The previous `os.popen('stty size')` produced no output whenever stdin was
# not a TTY (nohup, cron, IDEs, pipes, test runners), making the tuple unpack
# raise and breaking `import utils`. shutil consults $COLUMNS / the terminal
# and falls back to 80 columns instead of failing.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """Draw a single-line text progress bar on stdout.

    Args:
        current: zero-based index of the step just completed.
        total: number of steps (must be non-zero; it is divided by).
        msg: optional text appended after the timing information.

    The line ends in ``'\\r'`` (redrawn in place) while
    ``current < total - 1`` and in ``'\\n'`` otherwise. Reads the
    module-level ``term_width`` and ``TOTAL_BAR_LENGTH`` and updates the
    module-level ``last_time`` / ``begin_time`` timers (``begin_time`` is
    reset when ``current == 0``).
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    # Negative repeat counts yield '' -- the same as the old empty range() loops.
    parts = [' [', '=' * cur_len, '>', '.' * rest_len, ']']

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    info = ' Step: %s' % format_time(step_time) + ' | Tot: %s' % format_time(tot_time)
    if msg:
        info += ' | ' + msg
    parts.append(info)

    # Pad the rest of the terminal line with spaces.
    parts.append(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(info) - 3))
    # Go back to the center of the bar.
    parts.append('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    parts.append(' %d/%d ' % (current + 1, total))
    parts.append('\r' if current < total - 1 else '\n')

    sys.stdout.write(''.join(parts))
    sys.stdout.flush()
146 |
147 |
def format_time(seconds):
    """Render a duration in seconds as a compact string of at most two units.

    Units are considered from largest to smallest: days ('D'), hours ('h'),
    minutes ('m'), whole seconds ('s') and milliseconds ('ms'). Only units
    with a positive amount are shown, and only the first two of those, so
    ``format_time(3725.5) == '1h2m'``. Returns ``'0ms'`` when no unit is
    positive.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_secs = int(remaining)
    remaining -= whole_secs
    millis = int(remaining * 1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_secs, 's'),
        (millis, 'ms'),
    )
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0][:2]
    return ''.join(shown) or '0ms'
179 |
--------------------------------------------------------------------------------