├── .gitignore
├── .idea
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pruning_tool.iml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __pycache__
│   ├── dataloader.cpython-36.pyc
│   ├── dataloader.cpython-37.pyc
│   ├── pruning.cpython-36.pyc
│   ├── pruning.cpython-37.pyc
│   ├── quantization.cpython-36.pyc
│   ├── quantization.cpython-37.pyc
│   ├── sensitivity_analysis.cpython-36.pyc
│   └── sensitivity_analysis.cpython-37.pyc
├── commom_utils
│   ├── __pycache__
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── my_to_tensor.py
│   └── utils.py
├── data
│   └── train_data
│       ├── __pycache__
│       │   ├── ms1m_10k_loader.cpython-36.pyc
│       │   └── verifacation.cpython-36.pyc
│       ├── ms1m_10k_loader.py
│       └── verifacation.py
├── dataloader.py
├── distiller
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── directives.cpython-36.pyc
│   │   ├── knowledge_distillation.cpython-36.pyc
│   │   ├── learning_rate.cpython-36.pyc
│   │   ├── model_summaries.cpython-36.pyc
│   │   ├── model_transforms.cpython-36.pyc
│   │   ├── policy.cpython-36.pyc
│   │   ├── scheduler.cpython-36.pyc
│   │   ├── sensitivity.cpython-36.pyc
│   │   ├── summary_graph.cpython-36.pyc
│   │   ├── thinning.cpython-36.pyc
│   │   ├── thresholding.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── apputils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── checkpoint.cpython-36.pyc
│   │   │   ├── data_loaders.cpython-36.pyc
│   │   │   ├── dataset_summaries.cpython-36.pyc
│   │   │   ├── execution_env.cpython-36.pyc
│   │   │   └── image_classifier.cpython-36.pyc
│   │   ├── checkpoint.py
│   │   ├── data_loaders.py
│   │   ├── dataset_summaries.py
│   │   ├── execution_env.py
│   │   └── image_classifier.py
│   ├── config.py
│   ├── data_loggers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── collector.cpython-36.pyc
│   │   │   ├── logger.cpython-36.pyc
│   │   │   └── tbbackend.cpython-36.pyc
│   │   ├── collector.py
│   │   ├── logger.py
│   │   └── tbbackend.py
│   ├── directives.py
│   ├── knowledge_distillation.py
│   ├── learning_rate.py
│   ├── model_summaries.py
│   ├── model_transforms.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-36.pyc
│   │   ├── cifar10
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── plain_cifar.cpython-36.pyc
│   │   │   │   ├── preresnet_cifar.cpython-36.pyc
│   │   │   │   ├── resnet_cifar.cpython-36.pyc
│   │   │   │   ├── resnet_cifar_earlyexit.cpython-36.pyc
│   │   │   │   ├── simplenet_cifar.cpython-36.pyc
│   │   │   │   └── vgg_cifar.cpython-36.pyc
│   │   │   ├── plain_cifar.py
│   │   │   ├── preresnet_cifar.py
│   │   │   ├── resnet_cifar.py
│   │   │   ├── resnet_cifar_earlyexit.py
│   │   │   ├── simplenet_cifar.py
│   │   │   └── vgg_cifar.py
│   │   ├── imagenet
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── alexnet_batchnorm.cpython-36.pyc
│   │   │   │   ├── mobilenet.cpython-36.pyc
│   │   │   │   ├── mobilenet_dropout.cpython-36.pyc
│   │   │   │   ├── preresnet_imagenet.cpython-36.pyc
│   │   │   │   ├── resnet.cpython-36.pyc
│   │   │   │   └── resnet_earlyexit.cpython-36.pyc
│   │   │   ├── alexnet_batchnorm.py
│   │   │   ├── mobilenet.py
│   │   │   ├── mobilenet_dropout.py
│   │   │   ├── preresnet_imagenet.py
│   │   │   ├── resnet.py
│   │   │   └── resnet_earlyexit.py
│   │   └── mnist
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-36.pyc
│   │       │   └── simplenet_mnist.cpython-36.pyc
│   │       └── simplenet_mnist.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── aggregate.cpython-36.pyc
│   │   │   ├── eltwise.cpython-36.pyc
│   │   │   ├── grouping.cpython-36.pyc
│   │   │   ├── matmul.cpython-36.pyc
│   │   │   ├── rnn.cpython-36.pyc
│   │   │   └── tsvd.cpython-36.pyc
│   │   ├── aggregate.py
│   │   ├── eltwise.py
│   │   ├── grouping.py
│   │   ├── matmul.py
│   │   ├── rnn.py
│   │   └── tsvd.py
│   ├── policy.py
│   ├── pruning
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── automated_gradual_pruner.cpython-36.pyc
│   │   │   ├── baidu_rnn_pruner.cpython-36.pyc
│   │   │   ├── greedy_filter_pruning.cpython-36.pyc
│   │   │   ├── level_pruner.cpython-36.pyc
│   │   │   ├── magnitude_pruner.cpython-36.pyc
│   │   │   ├── pruner.cpython-36.pyc
│   │   │   ├── ranked_structures_pruner.cpython-36.pyc
│   │   │   ├── sensitivity_pruner.cpython-36.pyc
│   │   │   ├── splicing_pruner.cpython-36.pyc
│   │   │   └── structure_pruner.cpython-36.pyc
│   │   ├── automated_gradual_pruner.py
│   │   ├── baidu_rnn_pruner.py
│   │   ├── greedy_filter_pruning.py
│   │   ├── level_pruner.py
│   │   ├── magnitude_pruner.py
│   │   ├── pruner.py
│   │   ├── ranked_structures_pruner.py
│   │   ├── sensitivity_pruner.py
│   │   ├── splicing_pruner.py
│   │   └── structure_pruner.py
│   ├── quantization
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── clipped_linear.cpython-36.pyc
│   │   │   ├── q_utils.cpython-36.pyc
│   │   │   ├── quantizer.cpython-36.pyc
│   │   │   ├── range_linear.cpython-36.pyc
│   │   │   └── sim_bn_fold.cpython-36.pyc
│   │   ├── clipped_linear.py
│   │   ├── q_utils.py
│   │   ├── quantizer.py
│   │   ├── range_linear.py
│   │   └── sim_bn_fold.py
│   ├── regularization
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── drop_filter.cpython-36.pyc
│   │   │   ├── group_regularizer.cpython-36.pyc
│   │   │   ├── l1_regularizer.cpython-36.pyc
│   │   │   └── regularizer.cpython-36.pyc
│   │   ├── drop_filter.py
│   │   ├── group_regularizer.py
│   │   ├── l1_regularizer.py
│   │   └── regularizer.py
│   ├── scheduler.py
│   ├── sensitivity.py
│   ├── summary_graph.py
│   ├── thinning.py
│   ├── thresholding.py
│   └── utils.py
├── example~
├── ljt_mobilefacenet_y2.sh
├── ljt_shufflefacenet_v2.sh
├── main.py
├── model_define
│   ├── MobileFaceNet.py
│   ├── MobileNetV3.py
│   ├── __pycache__
│   │   ├── MobileFaceNet.cpython-36.pyc
│   │   ├── MobileFaceNet.cpython-37.pyc
│   │   ├── MobileNetV3.cpython-36.pyc
│   │   ├── MobileNetV3.cpython-37.pyc
│   │   ├── load_state_dict.cpython-36.pyc
│   │   ├── load_state_dict.cpython-37.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model_resnet.cpython-36.pyc
│   │   ├── model_resnet.cpython-37.pyc
│   │   └── resnet50_imagenet.cpython-36.pyc
│   ├── load_state_dict.py
│   ├── mobilefacenet_y2_ljt
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── mobilefacenet_big.cpython-36.pyc
│   │   │   └── network_elems.cpython-36.pyc
│   │   ├── common_utility.py
│   │   ├── mobilefacenet_big.py
│   │   └── network_elems.py
│   ├── model.py
│   ├── model_resnet.py
│   ├── resnet50_imagenet.py
│   ├── resnet_100_ljt
│   │   ├── __pycache__
│   │   │   └── resnet_100.cpython-36.pyc
│   │   └── resnet_100.py
│   ├── resnet_50_ljt
│   │   ├── __pycache__
│   │   │   └── resnet_50.cpython-36.pyc
│   │   └── resnet_50.py
│   └── shufflefacenet_v2_ljt
│       ├── ShuffleFaceNetV2.py
│       ├── __pycache__
│       │   └── ShuffleFaceNetV2.cpython-36.pyc
│       └── blocks.py
├── mytest.py
├── prune.sh
├── pruning.py
├── pruning_analysis_tools
│   ├── auto_make_yaml.py
│   └── plot_csv.py
├── quantization.py
├── quantization.sh
├── resnet_ljt.sh
├── sensitivity_analysis.py
├── src
│   ├── __pycache__
│   │   ├── data_loader.cpython-36.pyc
│   │   ├── data_loader.cpython-37.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   └── dataset.cpython-37.pyc
│   ├── data_loader.py
│   ├── dataset.py
│   └── loader
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── autoaugment.cpython-36.pyc
│       │   ├── autoaugment.cpython-37.pyc
│       │   ├── functional.cpython-36.pyc
│       │   ├── functional_V2.cpython-36.pyc
│       │   ├── functional_V2.cpython-37.pyc
│       │   ├── read_image_list_io.cpython-36.pyc
│       │   ├── read_image_list_io.cpython-37.pyc
│       │   ├── transforms.cpython-36.pyc
│       │   ├── transforms_V2.cpython-36.pyc
│       │   ├── transforms_V2.cpython-37.pyc
│       │   ├── utility.cpython-36.pyc
│       │   └── utility.cpython-37.pyc
│       ├── autoaugment.py
│       ├── functional.py
│       ├── functional_V2.py
│       ├── read_image_list_io.py
│       ├── transforms.py
│       ├── transforms_V2.py
│       └── utility.py
├── test_class.py
├── test_module
│   ├── __pycache__
│   │   ├── test_on_diverse_dataset.cpython-36.pyc
│   │   ├── test_on_diverse_dataset.cpython-37.pyc
│   │   ├── test_on_face_classification.cpython-36.pyc
│   │   ├── test_on_face_classification.cpython-37.pyc
│   │   ├── test_on_face_recognition.cpython-36.pyc
│   │   ├── test_on_face_recognition.cpython-37.pyc
│   │   └── test_with_insight_face.cpython-36.pyc
│   ├── test_on_diverse_dataset.py
│   ├── test_on_face_classification.py
│   ├── test_on_face_recognition.py
│   └── test_with_insight_face.py
├── train_module
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── train_with_insight_face.cpython-36.pyc
│   └── train_with_insight_face.py
├── work_space
│   ├── finetune
│   │   └── 2020-09-08-09-18
│   │       └── log
│   │           └── 2020-09-08-09-18
│   │               └── events.out.tfevents.1599527940.yeluyue
│   ├── layers_out
│   │   └── out.txt
│   ├── pruned_define_model
│   │   ├── __pycache__
│   │   │   ├── make_pruned_mobilefacenet_y2.cpython-36.pyc
│   │   │   ├── make_pruned_resnet50.cpython-36.pyc
│   │   │   └── make_pruned_resnet50_imagenet.cpython-36.pyc
│   │   ├── make_prune_resnet100.py
│   │   ├── make_prune_resnet34_lzc.py
│   │   ├── make_pruned_mobilefacenet_lzc.py
│   │   ├── make_pruned_mobilefacenet_y2.py
│   │   ├── make_pruned_mobilefacenet_y2_ljt.py
│   │   ├── make_pruned_mobilefacenet_zkx.py
│   │   ├── make_pruned_resnet34.py
│   │   ├── make_pruned_resnet50.py
│   │   ├── make_pruned_resnet50_imagenet.py
│   │   ├── make_pruned_resnet_100_ljt.py
│   │   ├── make_pruned_resnet_50_ljt.py
│   │   └── make_pruned_shufflefacenet_v2_ljt.py
│   └── sensitivity_data
│       ├── mobilefacenet_0.8228
│       │   └── sensitivity.csv
│       ├── mobilefacenet_lzc_0.4755
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_lzc.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_lzc.csv
│       ├── mobilefacenet_y2_0.8651
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_1.csv
│       ├── mobilefacenet_y2_company_0.9867
│       │   └── sensitivity_mobilefacenet_y2_2020-09-04-13-08.csv
│       ├── mobilefacenet_y2_ljt_BUSID
│       │   ├── FPGM
│       │   │   └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-00-13.csv
│       │   └── L1Rank
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-15-15-52-BUSID.csv
│       ├── mobilefacenet_y2_ljt_TYLG
│       │   ├── FPGM
│       │   │   └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-00-48.csv
│       │   └── L1Rank
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-15-16-35-TYLG.csv
│       ├── mobilefacenet_y2_ljt_XCH
│       │   └── FPGM
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-14-18.csv
│       ├── mobilefacenet_y2_zkx_0.7889
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_2019-11-05-06-30.csv
│       ├── mobilefacenet_y2_zkx_0.8118
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2_2019-11-04-09-33.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_2019-11-05-05-59.csv
│       ├── mobilenetv3_0.6613
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilenetv3.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilenetv3.csv
│       ├── resnet100_0.7710
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet100.csv
│       │   └── fpgm
│       │       └── sensitivity_resnet100_2019-11-07-20-49.csv
│       ├── resnet34_0.7217
│       │   └── sensitivity_resnet34.csv
│       ├── resnet34_lzc_0.6335
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet34_lzc.csv
│       │   └── fpgm
│       │       └── sensitivity_resnet34_lzc.csv
│       ├── resnet50_0.7517
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet50.csv
│       │   └── fpgm
│       │       ├── sensitivity_resnet50_2019-11-07-13-42.csv
│       │       └── sensitivity_resnet50_2019-11-13-15-04.csv
│       ├── resnet50_imagenet_0.9773
│       │   ├── L1Rank
│       │   │   ├── auto_yaml.yaml
│       │   │   └── sensitivity_resnet50_imagenet_2020-09-02-12-37.csv
│       │   └── fpgm
│       │       ├── auto_yaml.yaml
│       │       └── sensitivity_resnet50_imagenet_2020-08-27-17-22.csv
│       ├── resnet_100_ljt
│       │   └── FPGM
│       │       └── sensitivity_resnet_100_ljt_2020-09-23-08-25.csv
│       ├── resnet_50_ljt
│       │   └── fpgm
│       │       └── sensitivity_resnet_50_ljt_2020-09-22-23-00.csv
│       ├── sensitivity_resnet50_2020-09-07-17-13.csv
│       ├── shufflefacenet_v2_ljt_BUSID
│       │   └── sensitivity_shufflefacenet_v2_ljt_2020-09-16-19-16.csv
│       ├── shufflefacenet_v2_ljt_TYLG
│       │   └── FPGM
│       │       └── sensitivity_shufflefacenet_v2_ljt_2020-09-18-15-53.csv
│       └── shufflefacenet_v2_ljt_XCH
│           └── sensitivity_shufflefacenet_v2_ljt_2020-09-17-02-47.csv
└── yaml_file
    └── auto_yaml.yaml
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.png
3 | *.minivision-linx
4 | *.pt
5 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/pruning_tool.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # compression_tool
2 | A compression tool for minivision
3 | * Supports pruning the following models
4 | - [x] VGG
5 | - [x] ResNet
6 | - [x] MobileFaceNet
7 | - [x] ShuffleNet_v2
8 | * Supports the following pruning algorithms
9 | - [x] L1Rank
10 | - [x] FPGM
11 | - [x] HRank
12 | - [ ] EagleEye (not yet released)
13 |
14 | ## Install
15 | git clone https://github.com/BlossomingL/compression_tool
16 |
17 | Environment: PyTorch 1.x
18 |
19 | ## Program files and their functions
20 | * commom_utils: commonly used utility functions
21 | * data: test and training datasets, plus the required pair lists
22 | * model_define: definition files of the models to be pruned
23 | * pruning_analysis_tools: two pruning-analysis tools: auto_make_yaml.py generates a yaml configuration file from the resulting csv file, and plot_csv.py plots the sensitivity curves from a csv file.
24 | * test_module: the different test modules (accuracy, ROC, etc.)
25 | * train_module: the different training modules
26 | * work_space: model files, pruned-model definition files, and sensitivity-analysis data (csv files)
27 | * yaml_file: the auto-generated yaml files
28 | * dataloader.py: multi-worker batch data loading
29 | * main.py: main entry point
30 | * prune.sh: pruning launch script
31 | * pruning.py: pruning code
32 | * quantization.py: quantization code
33 | * quantization.sh: quantization launch script
34 | * sensitivity_analysis.py: sensitivity analysis; generates the csv files
35 |
36 | ## Pruning workflow
37 | * Preparation: this tool requires distiller. The installer lives in the Distiller directory: open a shell there and run python setup.py install. Put the model definition file under model_define, put the pt file of the best-accuracy trained model under work_space/model_train_best, put the test set under data, and put the test-set evaluation code under test_module.
38 |
39 | * Sensitivity analysis: run pruning.py. For example, to run pruning sensitivity analysis on resnet100, the sh script is as follows:
40 |
41 | ```bash
42 | # --mode sa             : 'sa' (sensitivity analysis) mode
43 | # --model               : model name; must be one of the predefined choices
44 | # --best_model_path     : path of the trained model file
45 | # --from_data_parallel  : set if the model file was trained on multiple GPUs
46 | # --test_root_path      : root path of the test set
47 | # --img_list_label_path : path of the pair list
48 | # --fpgm                : prune with the FPGM algorithm
49 | # --data_source         : source of the dataset
50 | python pruning.py --mode sa \
51 |                   --model resnet100 \
52 |                   --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \
53 |                   --from_data_parallel \
54 |                   --test_root_path data/test_data/fc_0.4_112x112 \
55 |                   --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \
56 |                   --fpgm \
57 |                   --data_source company
58 | ```
59 | This run writes a csv file under work_space/sensitivity_data.
60 | * Generate the yaml file: run auto_make_yaml.py under pruning_analysis_tools. The arguments of its config_yaml function must be configured by hand: argument 1 is the csv file path, argument 2 is the desired accuracy after pruning, argument 3 is the model name (resnet100 for the example above), and argument 4 is the input image size. The run writes a yaml file into the yaml_file directory (a call sketch follows this README).
61 |
62 | * Pruning: run pruning.py. For example, to prune resnet100, the sh script is as follows:
63 | ```bash
64 | # --save_model_pt : save the pruned model file
65 | python pruning.py --mode prune \
66 |                   --model resnet100 \
67 |                   --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \
68 |                   --from_data_parallel \
69 |                   --fpgm \
70 |                   --save_model_pt \
71 |                   --test_root_path data/test_data/fc_0.4_112x112 \
72 |                   --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \
73 |                   --data_source company
74 | ```
75 | This run writes the pruned model file to work_space/pruned_model, named after the model, and also prints each layer's out parameter, which is exactly the keep array used in the pruned-model definition file.
76 |
77 | ## About distiller
78 | The hard-pruning part of this tool (the part that actually removes channels) reuses code from the distiller framework. Because our models are unusual, the framework source had to be modified before they could be pruned. The modifications:
79 | * In the classification_get_input_shape function of distiller/apputils/data_loaders.py, dataset must have the same value as the dataset parameter in the yaml file; e.g., if the network input image size is 80x80, the dataset parameter in the yaml file is also 80x80. Supporting other input sizes requires changing this code.
80 | * distiller/policy.py: added an fpgm parameter option.
81 | * distiller/thinning.py: added two features:
82 |   * PReLU layers can now be pruned; see handle_prelu_layers and append_prelu_thinning_directive (note: if the PReLU layer uses the default single parameter, this code must be commented out, otherwise it will fail).
83 |   * The first layer of the company's Block is a BN layer, which the original code cannot prune; the modified code can. See handle_bn_layers_bn1.
84 | * distiller/pruning/ranked_structures_pruner.py: added the FPGM algorithm; the code is in rank_and_prune_filters, from "if fpgm" to the end of that if block.
85 | * distiller/summary_graph: fixed a bug in the original source by wrapping the body of add_footprint_attr in try/except.
86 | * Added the HRank pruning method (CVPR 2020).
87 |
88 | ## References
89 | [1] HRank: https://arxiv.org/abs/2002.10179
90 | [2] FPGM: https://arxiv.org/abs/1811.00250
91 | [3] Distiller: https://github.com/NervanaSystems/distiller
92 |
--------------------------------------------------------------------------------
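A minimal sketch of the yaml-generation step described in the README above. The argument order follows the README's description of config_yaml; the exact parameter names in pruning_analysis_tools/auto_make_yaml.py are assumptions, and the csv path is one of the files shipped under work_space/sensitivity_data:

```python
# Hypothetical call sketch -- argument order follows the README's description
# of config_yaml; the real signature in auto_make_yaml.py may differ.
from pruning_analysis_tools.auto_make_yaml import config_yaml

config_yaml(
    'work_space/sensitivity_data/resnet100_0.7710/fpgm/sensitivity_resnet100_2019-11-07-20-49.csv',  # 1: csv path
    0.76,         # 2: desired accuracy after pruning (illustrative value)
    'resnet100',  # 3: model name
    112,          # 4: input image size
)
```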
/__pycache__/dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/pruning.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/pruning.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/pruning.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/pruning.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/quantization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/quantization.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/quantization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/quantization.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/sensitivity_analysis.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/sensitivity_analysis.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/sensitivity_analysis.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/__pycache__/sensitivity_analysis.cpython-37.pyc
--------------------------------------------------------------------------------
/commom_utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/commom_utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/commom_utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/commom_utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/commom_utils/my_to_tensor.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/1 2:04 PM
4 | import torch
5 | import numpy as np
6 |
7 |
8 | def to_tensor(pic):
9 |
10 | if isinstance(pic, np.ndarray):
11 | # handle numpy array
12 | if pic.ndim == 2:
13 | pic = pic[:, :, None]
14 |
15 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
16 | # backward compatibility
17 | if isinstance(img, torch.ByteTensor):
18 | return img.float()
19 | else:
20 | return img
21 | # handle PIL Image
22 | if pic.mode == 'I':
23 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
24 | elif pic.mode == 'I;16':
25 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
26 | elif pic.mode == 'F':
27 | img = torch.from_numpy(np.array(pic, np.float32, copy=False))
28 | elif pic.mode == '1':
29 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
30 | else:
31 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
32 | # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
33 | if pic.mode == 'YCbCr':
34 | nchannel = 3
35 | elif pic.mode == 'I;16':
36 | nchannel = 1
37 | else:
38 | nchannel = len(pic.mode)
39 | img = img.view(pic.size[1], pic.size[0], nchannel)
40 | # put it from HWC to CHW format
41 | # yikes, this transpose takes 80% of the loading time/CPU
42 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
43 | if isinstance(img, torch.ByteTensor):
44 | return img.float()
45 | else:
46 | return img
47 |
48 |
49 | class ToTensor:
50 | def __init__(self):
51 | pass
52 |
53 | def __call__(self, pic):
54 | """
55 | Args:
56 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
57 |
58 | Returns:
59 | Tensor: Converted image.
60 | """
61 | return to_tensor(pic)
62 |
--------------------------------------------------------------------------------
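A usage sketch for the ToTensor class above. Unlike torchvision.transforms.ToTensor, this variant converts byte images to float without dividing by 255, so pixel values stay in [0, 255]; the normalization constants and image path below are illustrative assumptions, not values taken from this repo:

```python
from PIL import Image
from torchvision import transforms as trans
from commom_utils.my_to_tensor import ToTensor

transform = trans.Compose([
    ToTensor(),  # PIL/numpy HWC image -> float CHW tensor, no /255 scaling
    trans.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),  # assumed mean/std for [0, 255] input
])

img = Image.open('face.bmp').convert('RGB')  # placeholder image path
tensor = transform(img)                      # shape (3, H, W), dtype float32
```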
/commom_utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/10 2:22 PM
4 |
5 | import time
6 | from thop import profile
7 | from tqdm import tqdm
8 | from model_define.model import ResNet34
9 | from model_define.resnet50_imagenet import resnet_50
10 | from timeit import default_timer as timer
11 | import torch
12 | from datetime import datetime
13 | from torchvision import transforms as trans
14 | import matplotlib.pyplot as plt
15 | plt.switch_backend('agg')
16 | import io
17 |
18 |
19 | def l2_norm(input, axis=1):
20 | norm = torch.norm(input, 2, axis, True)
21 | output = torch.div(input, norm)
22 | return output
23 |
24 |
25 | def de_preprocess(tensor):
26 | return tensor*0.5 + 0.5
27 |
28 |
29 | def get_time():
30 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')
31 |
32 |
33 | def test_speed(model, shape=[1, 3, 122, 122], device='gpu', test_time=10000):
34 | model.eval()
35 | inputs = torch.rand(shape)
36 | if device == 'gpu':
37 | model = model.to('cuda')
38 | inputs = inputs.to('cuda')
39 | else:
40 | model = model.to('cpu')
41 | print('Testing forward time, this may take a few minutes')
42 | start_time = timer()
43 | with torch.no_grad():
44 | for i in tqdm(range(test_time)):
45 | model(inputs)
46 | count = timer() - start_time
47 | forward_time = (count / test_time) * 1000
48 | print('Average forward time: {} ms'.format(forward_time))
49 | return forward_time
50 |
51 |
52 | def cal_flops(model, input_shape, device='gpu'):
53 | input_random = torch.rand(input_shape)
54 | if device == 'gpu':
55 | input_random = input_random.to('cuda')
56 | model = model.to('cuda')
57 | else:
58 | model = model.to('cpu')
59 | flops, params = profile(model, inputs=(input_random, ), verbose=False)
60 | return flops / (1024 * 1024 * 1024), params / (1024 * 1024)
61 |
62 |
63 | hflip = trans.Compose([
64 | de_preprocess,
65 | trans.ToPILImage(),
66 | trans.functional.hflip,
67 | trans.ToTensor(),
68 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
69 | ])
70 |
71 |
72 | def hflip_batch(imgs_tensor):
73 | hfliped_imgs = torch.empty_like(imgs_tensor)
74 | for i, img_ten in enumerate(imgs_tensor):
75 | hfliped_imgs[i] = hflip(img_ten)
76 | return hfliped_imgs
77 |
78 |
79 | def separate_bn_paras(modules):
80 | if not isinstance(modules, list):
81 | modules = [*modules.modules()]
82 | paras_only_bn = []
83 | paras_wo_bn = []
84 | for layer in modules:
85 | if 'model' in str(layer.__class__):
86 | continue
87 | if 'container' in str(layer.__class__):
88 | continue
89 | else:
90 | if 'batchnorm' in str(layer.__class__):
91 | paras_only_bn.extend([*layer.parameters()])
92 | else:
93 | paras_wo_bn.extend([*layer.parameters()])
94 | return paras_only_bn, paras_wo_bn
95 |
96 |
97 | def gen_plot(fpr, tpr):
98 | """Create a pyplot plot and save to buffer."""
99 | plt.figure()
100 | plt.xlabel("FPR", fontsize=14)
101 | plt.ylabel("TPR", fontsize=14)
102 | plt.title("ROC Curve", fontsize=14)
103 | plot = plt.plot(fpr, tpr, linewidth=2)
104 | buf = io.BytesIO()
105 | plt.savefig(buf, format='jpeg')
106 | buf.seek(0)
107 | plt.close()
108 | return buf
109 |
110 |
111 | def main():
112 | # model = ResNet34()
113 | # state_dict = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/model_train_best'
114 | # '/resnet34_model_2019-10-12-19-20_accuracy:0.7216981_step:84816_lin.pth')
115 | # model.load_state_dict(state_dict)
116 | # test_speed(model)
117 |
118 | # model = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/pruned_model/model_resnet34.pkl')
119 | # state_dict = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/pruned_model'
120 | # '/resnet34_best_pruned_0.6556604.pt')
121 | # model.load_state_dict(state_dict)
122 | # test_speed(model)
123 |
124 | model = resnet_50()
125 | state_dict = torch.load('/home/linx/program/InsightFace_Pytorch/work_space/2020-08-20-10-32/models/model_accuracy'
126 | ':0.9708333333333334_step:163760_best_acc_lfw.pth')
127 | model.load_state_dict(state_dict)
128 | print(cal_flops(model, (1, 3, 112, 112)))
129 |
130 |
131 | if __name__ == '__main__':
132 | main()
133 |
134 |
--------------------------------------------------------------------------------
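A sketch of the intended use of separate_bn_paras: excluding BatchNorm parameters from weight decay when building the optimizer, a common pattern in face-recognition training loops. The optimizer hyperparameters here are illustrative, not taken from this repo:

```python
import torch
from commom_utils.utils import separate_bn_paras
from model_define.resnet50_imagenet import resnet_50

model = resnet_50()
paras_only_bn, paras_wo_bn = separate_bn_paras(model)

# Apply weight decay only to the non-BN parameters.
optimizer = torch.optim.SGD([
    {'params': paras_wo_bn, 'weight_decay': 5e-4},
    {'params': paras_only_bn, 'weight_decay': 0.0},
], lr=0.1, momentum=0.9)
```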
/data/train_data/__pycache__/ms1m_10k_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/data/train_data/__pycache__/ms1m_10k_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/train_data/__pycache__/verifacation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/data/train_data/__pycache__/verifacation.cpython-36.pyc
--------------------------------------------------------------------------------
/data/train_data/ms1m_10k_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/1 11:35 AM
4 | from pathlib import Path
5 | from torch.utils.data import Dataset, ConcatDataset, DataLoader
6 | from torchvision import transforms as trans
7 | from torchvision.datasets import ImageFolder
8 | from PIL import Image, ImageFile
9 |
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 | import numpy as np
12 | import cv2
13 | import bcolz
14 | import pickle
15 | import mxnet as mx
16 | from tqdm import tqdm
17 | import os
18 |
19 |
20 | def de_preprocess(tensor):
21 | return tensor * 0.5 + 0.5
22 |
23 |
24 | def get_train_dataset(imgs_folder):
25 | train_transform = trans.Compose([
26 | trans.RandomHorizontalFlip(),
27 | trans.ToTensor(),
28 | trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
29 | ])
30 | ds = ImageFolder(imgs_folder, train_transform)
31 | class_num = ds[-1][1] + 1
32 | return ds, class_num
33 |
34 |
35 | def get_train_loader(args):
36 | ms1m_ds, ms1m_class_num = get_train_dataset(os.path.join(args.train_data_path, 'imgs'))
37 | print('ms1m loader generated')
38 |
39 | ds = ms1m_ds
40 | class_num = ms1m_class_num
41 |
42 | loader = DataLoader(ds, batch_size=args.train_batch_size, shuffle=True, pin_memory=args.pin_memory,
43 | num_workers=args.num_workers)
44 | return loader, class_num
45 |
46 |
47 | def load_bin(path, rootdir, transform, image_size=[112, 112]):
48 | if not rootdir.exists():
49 | rootdir.mkdir()
50 | bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes')
51 | data = bcolz.fill([len(bins), 3, image_size[0], image_size[1]], dtype=np.float32, rootdir=rootdir, mode='w')
52 | for i in range(len(bins)):
53 | _bin = bins[i]
54 | img = mx.image.imdecode(_bin).asnumpy()
55 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
56 | img = Image.fromarray(img.astype(np.uint8))
57 | data[i, ...] = transform(img)
58 | # i is already the loop index; no manual increment is needed
59 | if (i + 1) % 1000 == 0:
60 | print('loading bin', i + 1)
61 | print(data.shape)
62 | np.save(str(rootdir) + '_list', np.array(issame_list))
63 | return data, issame_list
64 |
65 |
66 | def get_val_pair(path, name):
67 | carray = bcolz.carray(rootdir=os.path.join(path, name), mode='r')
68 | issame = np.load(os.path.join(path, '{}_list.npy'.format(name)))
69 | return carray, issame
70 |
71 |
72 | def get_val_data(data_path):
73 | agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')  # 12000 image pairs, 6000 corresponding labels
74 | cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
75 | lfw, lfw_issame = get_val_pair(data_path, 'lfw')
76 | return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame
77 |
78 |
79 | def load_mx_rec(rec_path):
80 | save_path = rec_path / 'imgs'
81 | if not save_path.exists():
82 | save_path.mkdir()
83 | imgrec = mx.recordio.MXIndexedRecordIO(str(rec_path / 'train.idx'), str(rec_path / 'train.rec'), 'r')
84 | img_info = imgrec.read_idx(0)
85 | header, _ = mx.recordio.unpack(img_info)
86 | max_idx = int(header.label[0])
87 | for idx in tqdm(range(1, max_idx)):
88 | img_info = imgrec.read_idx(idx)
89 | header, img = mx.recordio.unpack_img(img_info)
90 | label = int(header.label)
91 | img = Image.fromarray(img)
92 | label_path = save_path / str(label)
93 | if not label_path.exists():
94 | label_path.mkdir()
95 | img.save(label_path / '{}.jpg'.format(idx), quality=95)
96 |
--------------------------------------------------------------------------------
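A sketch of loading a validation split with the helpers above: load_bin converts an insightface-style .bin file into a bcolz carray plus an issame label array once, and get_val_pair re-reads the converted data on later runs. The paths are placeholders:

```python
from pathlib import Path
from torchvision import transforms as trans
from data.train_data.ms1m_10k_loader import load_bin, get_val_pair

transform = trans.Compose([
    trans.ToTensor(),
    trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# One-time conversion (writes 'data/val/lfw' and 'data/val/lfw_list.npy'):
data, issame = load_bin('data/val/lfw.bin', Path('data/val/lfw'), transform)

# Subsequent runs can read the converted arrays directly:
carray, issame = get_val_pair('data/val', 'lfw')
```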
/distiller/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch
18 | from .utils import *
19 | from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
20 | from .config import file_config, dict_config, config_component_from_file_by_class
21 | from .model_summaries import *
22 | from .scheduler import *
23 | from .sensitivity import *
24 | from .directives import *
25 | from .policy import *
26 | from .thinning import *
27 | from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
28 | from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name
29 |
30 | import logging
31 | logging.captureWarnings(True)
32 |
33 | del dict_config
34 | del thinning
35 |
36 | # Distiller version
37 | __version__ = "0.4.0-pre"
38 |
39 |
40 | def model_find_param_name(model, param_to_find):
41 | """Look up the name of a model parameter.
42 |
43 | Arguments:
44 | model: the model to search
45 | param_to_find: the parameter whose name we want to look up
46 |
47 | Returns:
48 | The parameter name (string) or None, if the parameter was not found.
49 | """
50 | for name, param in model.named_parameters():
51 | if param is param_to_find:
52 | return name
53 | return None
54 |
55 |
56 | def model_find_module_name(model, module_to_find):
57 | """Look up the name of a module in a model.
58 |
59 | Arguments:
60 | model: the model to search
61 | module_to_find: the module whose name we want to look up
62 |
63 | Returns:
64 | The module name (string) or None, if the module was not found.
65 | """
66 | for name, m in model.named_modules():
67 | if m == module_to_find:
68 | return name
69 | return None
70 |
71 |
72 | def model_find_param(model, param_to_find_name):
73 | """Look a model parameter by its name
74 |
75 | Arguments:
76 | model: the model to search
77 | param_to_find_name: the name of the parameter that we are searching for
78 |
79 | Returns:
80 | The parameter or None, if the parameter name was not found.
81 | """
82 | for name, param in model.named_parameters():
83 | if name == param_to_find_name:
84 | return param
85 | return None
86 |
87 |
88 | def model_find_module(model, module_to_find):
89 | """Given a module name, find the module in the provided model.
90 |
91 | Arguments:
92 | model: the model to search
93 | module_to_find: the name of the module we want to find
94 |
95 | Returns:
96 | The module or None, if the module was not found.
97 | """
98 | for name, m in model.named_modules():
99 | if name == module_to_find:
100 | return m
101 | return None
102 |
103 |
104 | def check_pytorch_version():
105 | from pkg_resources import parse_version
106 | if parse_version(torch.__version__) < parse_version('1.1.0'):
107 | msg = "\n\nWRONG PYTORCH VERSION\n"\
108 | "The Distiller \'master\' branch now requires at least PyTorch version 1.1.0 due to "\
109 | "PyTorch API changes which are not backward-compatible. Version detected is {}.\n"\
110 | "To make sure PyTorch and all other dependencies are installed with their correct versions, " \
111 | "go to the Distiller repo root directory and run:\n\n"\
112 | "pip install -e .\n".format(torch.__version__)
113 | raise RuntimeError(msg)
114 |
115 |
116 | check_pytorch_version()
117 |
--------------------------------------------------------------------------------
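A minimal sketch of the lookup helpers defined above, run on a toy model (module and parameter names follow standard torch.nn.Sequential naming):

```python
import torch.nn as nn
import distiller

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

conv = distiller.model_find_module(model, '0')          # module by name
weight = distiller.model_find_param(model, '0.weight')  # parameter by name
name = distiller.model_find_param_name(model, weight)   # back to '0.weight'
```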
/distiller/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/directives.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/directives.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/knowledge_distillation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/knowledge_distillation.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/learning_rate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/learning_rate.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/model_summaries.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/model_summaries.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/model_transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/model_transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/policy.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/policy.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/sensitivity.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/sensitivity.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/summary_graph.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/summary_graph.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/thinning.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/thinning.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/thresholding.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/thresholding.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """This package contains Python code and classes that are meant to make your life easier,
18 | when working with distiller.
19 |
20 | """
21 | from .data_loaders import *
22 | from .checkpoint import *
23 | from .execution_env import *
24 | from .dataset_summaries import *
25 |
26 | del data_loaders
27 | del checkpoint
28 | del execution_env
29 | del dataset_summaries
30 |
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/checkpoint.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/checkpoint.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/data_loaders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/data_loaders.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/dataset_summaries.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/dataset_summaries.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/execution_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/execution_env.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/__pycache__/image_classifier.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/apputils/__pycache__/image_classifier.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/apputils/dataset_summaries.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import distiller
18 | import numpy as np
19 | import logging
20 | msglogger = logging.getLogger()
21 |
22 | def dataset_summary(data_loader):
23 | """Create a histogram of class membership distribution within a dataset.
24 |
25 | It is important to examine our training, validation, and test
26 | datasets, to make sure that they are balanced.
27 | """
28 | msglogger.info("Analyzing dataset:")
29 | print_frequency = 50
30 | for batch, (input, label_batch) in enumerate(data_loader):
31 | try:
32 | all_labels = np.append(all_labels, distiller.to_np(label_batch))
33 | except NameError:
34 | all_labels = distiller.to_np(label_batch)
35 | if (batch+1) % print_frequency == 0:
36 | # progress indicator
37 | print("batch: %d" % batch)
38 |
39 | hist = np.histogram(all_labels, bins=np.arange(1000+1))  # assumes class labels in [0, 1000)
40 | nclasses = len(hist[0])
41 | for data_class, size in enumerate(hist[0]):
42 | msglogger.info("\tClass {} = {}".format(data_class, size))
43 | msglogger.info("Dataset contains {} items".format(len(data_loader.sampler)))
44 | msglogger.info("Found {} classes".format(nclasses))
45 | msglogger.info("Average: {} samples per class".format(np.mean(hist[0])))
46 |
--------------------------------------------------------------------------------
/distiller/data_loggers/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .collector import *
18 | from .logger import PythonLogger, TensorBoardLogger, CsvLogger
19 |
20 | del logger
21 | del collector
22 |
--------------------------------------------------------------------------------
/distiller/data_loggers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/data_loggers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/data_loggers/__pycache__/collector.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/data_loggers/__pycache__/collector.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/data_loggers/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/data_loggers/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/data_loggers/__pycache__/tbbackend.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/data_loggers/__pycache__/tbbackend.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/data_loggers/tbbackend.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | """ A TensorBoard backend.
17 |
18 | Writes logs to a file using Google's TensorBoard protobuf format.
19 | See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto
20 | """
21 | import os
22 | # Disable FutureWarning from TensorFlow
23 | import warnings
24 | warnings.simplefilter(action='ignore', category=FutureWarning)
25 | import tensorflow as tf
26 | import numpy as np
27 |
28 |
29 | class TBBackend(object):
30 | def __init__(self, log_dir):
31 | self.writers = []
32 | self.log_dir = log_dir
33 | self.writers.append(tf.summary.FileWriter(log_dir))
34 |
35 | def scalar_summary(self, tag, scalar, step):
36 | """From TF documentation:
37 | tag: name for the data. Used by TensorBoard plugins to organize data.
38 | value: value associated with the tag (a float).
39 | """
40 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
41 | self.writers[0].add_summary(summary, step)
42 |
43 | def list_summary(self, tag, scalar_list, step, multi_graphs):
44 | """Log a relatively small list of scalars.
45 |
46 | We want to track the progress of multiple scalar parameters in a single graph.
47 | The list provides a single value for each of the parameters we are tracking.
48 |
49 | NOTE: There are two ways to log multiple values in TB and neither one is optimal.
50 | 1. Use a single writer: in this case all of the parameters use the same color, and
51 | distinguishing between them is difficult.
52 | 2. Use multiple writers: in this case each parameter has its own color which helps
53 | to visually separate the parameters. However, each writer logs to a different
54 | file and this creates a lot of files which slow down the TB load.
55 | """
56 | for i, scalar in enumerate(scalar_list):
57 | if multi_graphs and (i+1 > len(self.writers)):
58 | self.writers.append(tf.summary.FileWriter(os.path.join(self.log_dir, str(i))))
59 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
60 | self.writers[0 if not multi_graphs else i].add_summary(summary, step)
61 |
62 | def histogram_summary(self, tag, tensor, step):
63 | """
64 | From the TF documentation:
65 | tf.summary.histogram takes an arbitrarily sized and shaped Tensor, and
66 | compresses it into a histogram data structure consisting of many bins with
67 | widths and counts.
68 |
69 | TensorFlow uses non-uniformly distributed bins, which is better than using
70 | numpy's uniform bins for activations and parameters which converge around zero,
71 | but we don't add that logic here.
72 |
73 | https://www.tensorflow.org/programmers_guide/tensorboard_histograms
74 | """
75 | hist, edges = np.histogram(tensor, bins=200)
76 | tfhist = tf.HistogramProto(
77 | min=np.min(tensor),
78 | max=np.max(tensor),
79 | num=int(np.prod(tensor.shape)),
80 | sum=np.sum(tensor),
81 | sum_squares=np.sum(np.square(tensor)))
82 |
83 | # From the TF documentation:
84 | # Parallel arrays encoding the bucket boundaries and the bucket values.
85 | # bucket(i) is the count for the bucket i. The range for a bucket is:
86 | # i == 0: -DBL_MAX .. bucket_limit(0)
87 | # i != 0: bucket_limit(i-1) .. bucket_limit(i)
88 | tfhist.bucket_limit.extend(edges[1:])
89 | tfhist.bucket.extend(hist)
90 |
91 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)])
92 | self.writers[0].add_summary(summary, step)
93 |
94 | def sync_to_file(self):
95 | for writer in self.writers:
96 | writer.flush()
97 |
--------------------------------------------------------------------------------
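A usage sketch for TBBackend based on the methods above (the log directory and loss values are placeholders):

```python
from distiller.data_loggers.tbbackend import TBBackend

tb = TBBackend('logs/experiment_0')
for step in range(100):
    loss = 1.0 / (step + 1)  # stand-in for a real training loss
    tb.scalar_summary('train/loss', loss, step)
tb.sync_to_file()  # flush every writer to disk so TensorBoard sees the events
```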
/distiller/directives.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Scheduling directives
18 |
19 | Scheduling directives are instructions (directives) that the scheduler can
20 | execute as part of scheduling pruning activities.
21 | """
22 | from __future__ import division
23 | import torch
24 | import numpy as np
25 | from collections import defaultdict
26 | import logging
27 | msglogger = logging.getLogger()
28 |
29 | from torchnet.meter import AverageValueMeter
30 | from distiller.utils import sparsity, density
31 |
32 |
33 | class FreezeTraining(object):
34 | def __init__(self, name):
35 | print("------FreezeTraining--------")
36 | self.name = name
37 |
38 | def freeze_training(model, which_params, freeze):
39 | """This function will freeze/defrost training for certain layers.
40 |
41 | Sometimes, when we prune and retrain a certain layer type,
42 | we'd like to freeze the training of the other layers.
43 | """
44 | for param in model.parameters():
45 | pname = model_find_param_name(model, param)  # helper defined in distiller/__init__.py; it matches the Parameter by identity, not its .data tensor
46 | if pname is None:
47 | continue
48 | for ptype in which_params:
49 | if ptype in pname:
50 | # see: http://pytorch.org/docs/master/notes/autograd.html?highlight=grad_fn
51 | param.requires_grad = not freeze
52 | if freeze:
53 | msglogger.info('Freezing: ' + pname)
54 | else:
55 | msglogger.info('Defrosting: ' + pname)
56 |
57 |
58 | def freeze_all(model, freeze):
59 | msglogger.info('{} all parameters'.format('Freezing' if freeze else 'Defrosting'))
60 | for param in model.parameters():
61 | param.requires_grad = not freeze
62 |
63 |
64 | def adjust_dropout(module, new_probability):
65 | """Replace the dropout probability of dropout layers
66 |
67 | As explained in the paper "Learning both Weights and Connections for
68 | Efficient Neural Networks":
69 | Dropout is widely used to prevent over-fitting, and this also applies to retraining.
70 | During retraining, however, the dropout ratio must be adjusted to account for the
71 | change in model capacity. In dropout, each parameter is probabilistically dropped
72 | during training, but will come back during inference. In pruning, parameters are
73 | dropped forever after pruning and have no chance to come back during both training
74 | and inference. As the parameters get sparse, the classifier will select the most
75 | informative predictors and thus have much less prediction variance, which reduces
76 | over-fitting. As pruning already reduced model capacity, the retraining dropout ratio
77 | should be smaller.
78 | """
79 | if type(module) in [torch.nn.Dropout,
80 | torch.nn.Dropout2d,
81 | torch.nn.Dropout3d,
82 | torch.nn.AlphaDropout]:
83 | msglogger.info("Adjusting dropout probability")
84 | module.p = new_probability
85 | else:
86 | for child in module.children():
87 | adjust_dropout(child, new_probability)
88 |
--------------------------------------------------------------------------------
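A sketch of adjust_dropout and freeze_all in a prune-then-retrain flow, per the docstring above (the model and probability value are illustrative):

```python
import torch.nn as nn
from distiller.directives import adjust_dropout, freeze_all

model = nn.Sequential(nn.Linear(128, 64), nn.Dropout(0.5), nn.Linear(64, 10))

adjust_dropout(model, 0.25)      # recursively lowers p on every dropout layer
freeze_all(model, freeze=True)   # requires_grad = False on all parameters
freeze_all(model, freeze=False)  # ...and back on ("defrost")
```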
/distiller/learning_rate.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from bisect import bisect_right
18 | from torch.optim.lr_scheduler import _LRScheduler
19 |
20 |
21 | class PolynomialLR(_LRScheduler):
22 | """Set the learning rate for each parameter group using a polynomial defined as:
23 | lr = base_lr * (1 - T_cur/T_max) ^ (power), where T_cur is the current epoch and T_max is the maximum number of
24 | epochs.
25 |
26 | Args:
27 | optimizer (Optimizer): Wrapped optimizer.
28 | T_max (int): Maximum number of epochs
29 | power (int): Degree of polynomial
30 | last_epoch (int): The index of last epoch. Default: -1.
31 | """
32 | def __init__(self, optimizer, T_max, power, last_epoch=-1):
33 | self.T_max = T_max
34 | self.power = power
35 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
36 |
37 | def get_lr(self):
38 | # base_lr * (1 - iter/max_iter) ^ (power)
39 | return [base_lr * (1 - self.last_epoch / self.T_max) ** self.power
40 | for base_lr in self.base_lrs]
41 |
42 |
43 | class MultiStepMultiGammaLR(_LRScheduler):
44 | """Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.
45 |
46 | Args:
47 | optimizer (Optimizer): Wrapped optimizer.
48 | milestones (list): List of epoch indices. Must be increasing.
49 | gammas (list): List of gamma values. Must have same length as milestones.
50 | last_epoch (int): The index of last epoch. Default: -1.
51 | """
52 | def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
53 | if not list(milestones) == sorted(milestones):
54 | raise ValueError('Milestones should be a list of'
55 | ' increasing integers. Got {}'.format(milestones))
56 | if len(milestones) != len(gammas):
57 | raise ValueError('Milestones and Gammas lists should be of same length.')
58 |
59 | self.milestones = milestones
60 | self.multiplicative_gammas = [1]
61 | for idx, gamma in enumerate(gammas):
62 | self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx])
63 |
64 | super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch)
65 |
66 | def get_lr(self):
67 | idx = bisect_right(self.milestones, self.last_epoch)
68 | return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs]
69 |
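A small sketch of how the two schedulers behave (the SGD setup is illustrative). PolynomialLR anneals each base rate to zero at T_max via base_lr * (1 - epoch/T_max) ** power; MultiStepMultiGammaLR folds the per-milestone gammas into cumulative factors, here [1, 0.1, 0.05]:

import torch
from torch.optim import SGD

opt = SGD([torch.nn.Parameter(torch.zeros(3))], lr=0.1)
sched = MultiStepMultiGammaLR(opt, milestones=[30, 60], gammas=[0.1, 0.5])
# lr is 0.1 up to epoch 29, 0.01 from epoch 30, and 0.005 from epoch 60.
for epoch in range(90):
    ...          # train one epoch, then step the schedule
    sched.step()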
--------------------------------------------------------------------------------
/distiller/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """This package contains CIFAR image classification models for pytorch"""
18 |
19 | from .simplenet_cifar import *
20 | from .resnet_cifar import *
21 | from .preresnet_cifar import *
22 | from .vgg_cifar import *
23 | from .resnet_cifar_earlyexit import *
24 | from .plain_cifar import *
25 |
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/plain_cifar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/plain_cifar.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/preresnet_cifar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/preresnet_cifar.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/resnet_cifar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/resnet_cifar.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/resnet_cifar_earlyexit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/resnet_cifar_earlyexit.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/simplenet_cifar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/simplenet_cifar.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/__pycache__/vgg_cifar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/vgg_cifar.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/cifar10/plain_cifar.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Plain for CIFAR10
18 |
19 | Plain for CIFAR10, based on "Deep Residual Learning for Image Recognition".
20 |
21 | @inproceedings{DBLP:conf/cvpr/HeZRS16,
22 | author = {Kaiming He and
23 | Xiangyu Zhang and
24 | Shaoqing Ren and
25 | Jian Sun},
26 | title = {Deep Residual Learning for Image Recognition},
27 | booktitle = {{CVPR}},
28 | pages = {770--778},
29 | publisher = {{IEEE} Computer Society},
30 | year = {2016}
31 | }
32 |
33 | """
34 | import torch.nn as nn
35 | import math
36 |
37 |
38 | __all__ = ['plain20_cifar']
39 |
40 | NUM_CLASSES = 10
41 |
42 |
43 | def conv3x3(in_planes, out_planes, stride=1):
44 | """3x3 convolution with padding"""
45 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
46 | padding=1, bias=False)
47 |
48 |
49 | class BasicBlock(nn.Module):
50 | expansion = 1
51 |
52 | def __init__(self, inplanes, planes, stride=1):
53 | super().__init__()
54 | self.conv1 = conv3x3(inplanes, planes, stride)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 | self.relu1 = nn.ReLU(inplace=False)
57 | self.conv2 = conv3x3(planes, planes)
58 | self.bn2 = nn.BatchNorm2d(planes)
59 | self.relu2 = nn.ReLU(inplace=False)
60 |
61 | def forward(self, x):
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu1(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 | out = self.relu2(out)
69 | return out
70 |
71 |
72 | class PlainCifar(nn.Module):
73 | def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES):
74 | self.inplanes = 16
75 | super().__init__()
76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77 | self.bn1 = nn.BatchNorm2d(self.inplanes)
78 | self.relu = nn.ReLU(inplace=True)
79 | self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1)
80 | self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2)
81 | self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2)
82 |
83 | self.avgpool = nn.AvgPool2d(8, stride=1)
84 | self.fc = nn.Linear(64 * block.expansion, num_classes)
85 |
86 | for m in self.modules():
87 | if isinstance(m, nn.Conv2d):
88 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
89 | m.weight.data.normal_(0, math.sqrt(2. / n))
90 | elif isinstance(m, nn.BatchNorm2d):
91 | m.weight.data.fill_(1)
92 | m.bias.data.zero_()
93 |
94 | def _make_layer(self, block, planes, num_blocks, stride):
95 | # Each layer is composed of num_blocks blocks, each holding two 3x3 convolutions;
96 | # the first block usually downsamples its input and doubles the number of filters/feature-maps.
97 | blocks = []
98 | inplanes = self.inplanes
99 | # First block is special (downsamples and adds filters)
100 | blocks.append(block(inplanes, planes, stride))
101 |
102 | self.inplanes = planes * block.expansion
103 | for i in range(num_blocks - 1):
104 | blocks.append(block(self.inplanes, planes, stride=1))
105 | return nn.Sequential(*blocks)
106 |
107 | def forward(self, x):
108 | x = self.conv1(x)
109 | x = self.bn1(x)
110 | x = self.relu(x)
111 |
112 | x = self.layer1(x)
113 | x = self.layer2(x)
114 | x = self.layer3(x)
115 |
116 | x = self.avgpool(x)
117 | x = x.view(x.size(0), -1)
118 | x = self.fc(x)
119 | return x
120 |
121 |
122 | def plain20_cifar(**kwargs):
123 | model = PlainCifar(BasicBlock, [3, 3, 3], **kwargs)
124 | return model
125 |
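The depth bookkeeping behind the name: one stem convolution plus 3 stages of 3 blocks with 2 convolutions each gives 19 conv layers, and the final Linear makes 20. The two stride-2 stages shrink 32x32 inputs to 8x8, which AvgPool2d(8) reduces to 1x1x64. A quick shape check (illustrative):

import torch

model = plain20_cifar()
print(model(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 10])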
--------------------------------------------------------------------------------
/distiller/models/cifar10/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Resnet for CIFAR10
18 |
19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
21 | changes for the 10-class Cifar-10 dataset.
22 | This ResNet also has layer gates, to be able to dynamically remove layers.
23 |
24 | @inproceedings{DBLP:conf/cvpr/HeZRS16,
25 | author = {Kaiming He and
26 | Xiangyu Zhang and
27 | Shaoqing Ren and
28 | Jian Sun},
29 | title = {Deep Residual Learning for Image Recognition},
30 | booktitle = {{CVPR}},
31 | pages = {770--778},
32 | publisher = {{IEEE} Computer Society},
33 | year = {2016}
34 | }
35 |
36 | """
37 | import torch.nn as nn
38 | import math
39 | import torch.utils.model_zoo as model_zoo
40 |
41 |
42 | __all__ = ['resnet20_cifar', 'resnet32_cifar', 'resnet44_cifar', 'resnet56_cifar']
43 |
44 | NUM_CLASSES = 10
45 |
46 | def conv3x3(in_planes, out_planes, stride=1):
47 | """3x3 convolution with padding"""
48 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
49 | padding=1, bias=False)
50 |
51 | class BasicBlock(nn.Module):
52 | expansion = 1
53 |
54 | def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None):
55 | super(BasicBlock, self).__init__()
56 | self.block_gates = block_gates
57 | self.conv1 = conv3x3(inplanes, planes, stride)
58 | self.bn1 = nn.BatchNorm2d(planes)
59 | self.relu1 = nn.ReLU(inplace=False) # To enable layer removal inplace must be False
60 | self.conv2 = conv3x3(planes, planes)
61 | self.bn2 = nn.BatchNorm2d(planes)
62 | self.relu2 = nn.ReLU(inplace=False)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = out = x
68 |
69 | if self.block_gates[0]:
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu1(out)
73 |
74 | if self.block_gates[1]:
75 | out = self.conv2(out)
76 | out = self.bn2(out)
77 |
78 | if self.downsample is not None:
79 | residual = self.downsample(x)
80 |
81 | out += residual
82 | out = self.relu2(out)
83 |
84 | return out
85 |
86 |
87 | class ResNetCifar(nn.Module):
88 |
89 | def __init__(self, block, layers, num_classes=NUM_CLASSES):
90 | self.nlayers = 0
91 | # Each layer manages its own gates
92 | self.layer_gates = []
93 | for layer in range(3):
94 | # For each of the 3 layers, create block gates: each block has two layers
95 | self.layer_gates.append([]) # [True, True] * layers[layer])
96 | for blk in range(layers[layer]):
97 | self.layer_gates[layer].append([True, True])
98 |
99 | self.inplanes = 16 # 64
100 | super(ResNetCifar, self).__init__()
101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
102 | self.bn1 = nn.BatchNorm2d(self.inplanes)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0])
105 | self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2)
106 | self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2)
107 | self.avgpool = nn.AvgPool2d(8, stride=1)
108 | self.fc = nn.Linear(64 * block.expansion, num_classes)
109 |
110 | for m in self.modules():
111 | if isinstance(m, nn.Conv2d):
112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
113 | m.weight.data.normal_(0, math.sqrt(2. / n))
114 | elif isinstance(m, nn.BatchNorm2d):
115 | m.weight.data.fill_(1)
116 | m.bias.data.zero_()
117 |
118 | def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
119 | downsample = None
120 | if stride != 1 or self.inplanes != planes * block.expansion:
121 | downsample = nn.Sequential(
122 | nn.Conv2d(self.inplanes, planes * block.expansion,
123 | kernel_size=1, stride=stride, bias=False),
124 | nn.BatchNorm2d(planes * block.expansion),
125 | )
126 |
127 | layers = []
128 | layers.append(block(layer_gates[0], self.inplanes, planes, stride, downsample))
129 | self.inplanes = planes * block.expansion
130 | for i in range(1, blocks):
131 | layers.append(block(layer_gates[i], self.inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | x = self.conv1(x)
137 | x = self.bn1(x)
138 | x = self.relu(x)
139 |
140 | x = self.layer1(x)
141 | x = self.layer2(x)
142 | x = self.layer3(x)
143 |
144 | x = self.avgpool(x)
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 |
148 | return x
149 |
150 |
151 | def resnet20_cifar(**kwargs):
152 | model = ResNetCifar(BasicBlock, [3, 3, 3], **kwargs)
153 | return model
154 |
155 | def resnet32_cifar(**kwargs):
156 | model = ResNetCifar(BasicBlock, [5, 5, 5], **kwargs)
157 | return model
158 |
159 | def resnet44_cifar(**kwargs):
160 | model = ResNetCifar(BasicBlock, [7, 7, 7], **kwargs)
161 | return model
162 |
163 | def resnet56_cifar(**kwargs):
164 | model = ResNetCifar(BasicBlock, [9, 9, 9], **kwargs)
165 | return model
166 |
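Each BasicBlock keeps a reference to its two-entry gate list, so gates must be mutated in place; rebinding an entry of model.layer_gates to a fresh list would not reach the block. Note also that turning off both gates of a downsampling block (the first block of layer2/layer3) would add tensors of mismatched shapes. A minimal sketch:

import torch

model = resnet20_cifar()
model.layer_gates[1][1][1] = False   # skip conv2/bn2 of layer2's second block
print(model(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 10])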
--------------------------------------------------------------------------------
/distiller/models/cifar10/resnet_cifar_earlyexit.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Resnet for CIFAR10
18 |
19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition".
20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate
21 | changes for the 10-class Cifar-10 dataset.
22 | This ResNet also has layer gates, to be able to dynamically remove layers.
23 |
24 | @inproceedings{DBLP:conf/cvpr/HeZRS16,
25 | author = {Kaiming He and
26 | Xiangyu Zhang and
27 | Shaoqing Ren and
28 | Jian Sun},
29 | title = {Deep Residual Learning for Image Recognition},
30 | booktitle = {{CVPR}},
31 | pages = {770--778},
32 | publisher = {{IEEE} Computer Society},
33 | year = {2016}
34 | }
35 |
36 | """
37 | import torch.nn as nn
38 | import math
39 | import torch.utils.model_zoo as model_zoo
40 | import torchvision.models as models
41 | from .resnet_cifar import BasicBlock
42 | from .resnet_cifar import ResNetCifar
43 |
44 |
45 | __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
46 | 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit']
47 |
48 | NUM_CLASSES = 10
49 |
50 | def conv3x3(in_planes, out_planes, stride=1):
51 | """3x3 convolution with padding"""
52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
53 | padding=1, bias=False)
54 |
55 |
56 | class ResNetCifarEarlyExit(ResNetCifar):
57 |
58 | def __init__(self, block, layers, num_classes=NUM_CLASSES):
59 | super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes)
60 |
61 | # Define early exit layers
62 | self.linear_exit0 = nn.Linear(1600, num_classes)
63 |
64 |
65 | def forward(self, x):
66 | x = self.conv1(x)
67 | x = self.bn1(x)
68 | x = self.relu(x)
69 |
70 | x = self.layer1(x)
71 |
72 | # Add early exit layers
73 | exit0 = nn.functional.avg_pool2d(x, 3)
74 | exit0 = exit0.view(exit0.size(0), -1)
75 | exit0 = self.linear_exit0(exit0)
76 |
77 | x = self.layer2(x)
78 | x = self.layer3(x)
79 |
80 | x = self.avgpool(x)
81 | x = x.view(x.size(0), -1)
82 | x = self.fc(x)
83 |
84 | # return a list of logits (one per exit)
85 | output = []
86 | output.append(exit0)
87 | output.append(x)
88 | return output
89 |
90 |
91 | def resnet20_cifar_earlyexit(**kwargs):
92 | model = ResNetCifarEarlyExit(BasicBlock, [3, 3, 3], **kwargs)
93 | return model
94 |
95 | def resnet32_cifar_earlyexit(**kwargs):
96 | model = ResNetCifarEarlyExit(BasicBlock, [5, 5, 5], **kwargs)
97 | return model
98 |
99 | def resnet44_cifar_earlyexit(**kwargs):
100 | model = ResNetCifarEarlyExit(BasicBlock, [7, 7, 7], **kwargs)
101 | return model
102 |
103 | def resnet56_cifar_earlyexit(**kwargs):
104 | model = ResNetCifarEarlyExit(BasicBlock, [9, 9, 9], **kwargs)
105 | return model
106 |
107 | def resnet110_cifar_earlyexit(**kwargs):
108 | model = ResNetCifarEarlyExit(BasicBlock, [18, 18, 18], **kwargs)
109 | return model
110 |
111 | def resnet1202_cifar_earlyexit(**kwargs):
112 | model = ResNetCifarEarlyExit(BasicBlock, [200, 200, 200], **kwargs)
113 | return model
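
The 1600 in linear_exit0 comes from the exit's pooling arithmetic: after layer1 the feature map is 16x32x32, and avg_pool2d(x, 3) (stride defaults to the kernel size) yields 16x10x10 = 1600 features. The forward pass returns the exit logits followed by the final logits:

import torch

model = resnet20_cifar_earlyexit()
exit0, final = model(torch.randn(4, 3, 32, 32))
print(exit0.shape, final.shape)   # torch.Size([4, 10]) torch.Size([4, 10])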
--------------------------------------------------------------------------------
/distiller/models/cifar10/simplenet_cifar.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | __all__ = ['simplenet_cifar']
21 |
22 |
23 | class Simplenet(nn.Module):
24 | def __init__(self):
25 | super(Simplenet, self).__init__()
26 | self.conv1 = nn.Conv2d(3, 6, 5)
27 | self.relu_conv1 = nn.ReLU()
28 | self.pool1 = nn.MaxPool2d(2, 2)
29 | self.conv2 = nn.Conv2d(6, 16, 5)
30 | self.relu_conv2 = nn.ReLU()
31 | self.pool2 = nn.MaxPool2d(2, 2)
32 | self.fc1 = nn.Linear(16 * 5 * 5, 120)
33 | self.relu_fc1 = nn.ReLU()
34 | self.fc2 = nn.Linear(120, 84)
35 | self.relu_fc2 = nn.ReLU()
36 | self.fc3 = nn.Linear(84, 10)
37 |
38 | def forward(self, x):
39 | x = self.pool1(self.relu_conv1(self.conv1(x)))
40 | x = self.pool2(self.relu_conv2(self.conv2(x)))
41 | x = x.view(-1, 16 * 5 * 5)
42 | x = self.relu_fc1(self.fc1(x))
43 | x = self.relu_fc2(self.fc2(x))
44 | x = self.fc3(x)
45 | return x
46 |
47 |
48 | def simplenet_cifar():
49 | model = Simplenet()
50 | return model
51 |
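The hard-coded 16 * 5 * 5 follows from the shape walk: 32x32 inputs lose 4 pixels per 5x5 convolution and halve at each pool, i.e. 32 -> 28 -> 14 -> 10 -> 5, leaving a 16x5x5 map to flatten. A quick check (illustrative):

import torch

model = simplenet_cifar()
print(model(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 10])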
--------------------------------------------------------------------------------
/distiller/models/cifar10/vgg_cifar.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """VGG for CIFAR10
18 |
19 | VGG for CIFAR10, based on "Very Deep Convolutional Networks for Large-Scale
20 | Image Recognition".
21 | This is based on TorchVision's implementation of VGG for ImageNet, with
22 | appropriate changes for the 10-class Cifar-10 dataset.
23 | We replaced the three linear classifiers with a single one.
24 | """
25 |
26 | import torch.nn as nn
27 |
28 | __all__ = [
29 | 'VGGCifar', 'vgg11_cifar', 'vgg11_bn_cifar', 'vgg13_cifar', 'vgg13_bn_cifar', 'vgg16_cifar', 'vgg16_bn_cifar',
30 | 'vgg19_bn_cifar', 'vgg19_cifar',
31 | ]
32 |
33 |
34 | class VGGCifar(nn.Module):
35 | def __init__(self, features, num_classes=10, init_weights=True):
36 | super(VGGCifar, self).__init__()
37 | self.features = features
38 | self.classifier = nn.Linear(512, num_classes)
39 | if init_weights:
40 | self._initialize_weights()
41 |
42 | def forward(self, x):
43 | x = self.features(x)
44 | x = x.view(x.size(0), -1)
45 | x = self.classifier(x)
46 | return x
47 |
48 | def _initialize_weights(self):
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d):
51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
52 | if m.bias is not None:
53 | nn.init.constant_(m.bias, 0)
54 | elif isinstance(m, nn.BatchNorm2d):
55 | nn.init.constant_(m.weight, 1)
56 | nn.init.constant_(m.bias, 0)
57 | elif isinstance(m, nn.Linear):
58 | nn.init.normal_(m.weight, 0, 0.01)
59 | nn.init.constant_(m.bias, 0)
60 |
61 |
62 | def make_layers(cfg, batch_norm=False):
63 | layers = []
64 | in_channels = 3
65 | for v in cfg:
66 | if v == 'M':
67 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
68 | else:
69 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
70 | if batch_norm:
71 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
72 | else:
73 | layers += [conv2d, nn.ReLU(inplace=True)]
74 | in_channels = v
75 | return nn.Sequential(*layers)
76 |
77 |
78 | cfg = {
79 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
80 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
81 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
82 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
83 | }
84 |
85 |
86 | def vgg11_cifar(**kwargs):
87 | """VGG 11-layer model (configuration "A")"""
88 | model = VGGCifar(make_layers(cfg['A']), **kwargs)
89 | return model
90 |
91 |
92 | def vgg11_bn_cifar(**kwargs):
93 | """VGG 11-layer model (configuration "A") with batch normalization"""
94 | model = VGGCifar(make_layers(cfg['A'], batch_norm=True), **kwargs)
95 | return model
96 |
97 |
98 | def vgg13_cifar(**kwargs):
99 | """VGG 13-layer model (configuration "B")"""
100 | model = VGGCifar(make_layers(cfg['B']), **kwargs)
101 | return model
102 |
103 |
104 | def vgg13_bn_cifar(**kwargs):
105 | """VGG 13-layer model (configuration "B") with batch normalization"""
106 | model = VGGCifar(make_layers(cfg['B'], batch_norm=True), **kwargs)
107 | return model
108 |
109 |
110 | def vgg16_cifar(**kwargs):
111 | """VGG 16-layer model (configuration "D")
112 | """
113 | model = VGGCifar(make_layers(cfg['D']), **kwargs)
114 | return model
115 |
116 |
117 | def vgg16_bn_cifar(**kwargs):
118 | """VGG 16-layer model (configuration "D") with batch normalization"""
119 | model = VGGCifar(make_layers(cfg['D'], batch_norm=True), **kwargs)
120 | return model
121 |
122 |
123 | def vgg19_cifar(**kwargs):
124 | """VGG 19-layer model (configuration "E")
125 | """
126 | model = VGGCifar(make_layers(cfg['E']), **kwargs)
127 | return model
128 |
129 |
130 | def vgg19_bn_cifar(**kwargs):
131 | """VGG 19-layer model (configuration 'E') with batch normalization"""
132 | model = VGGCifar(make_layers(cfg['E'], batch_norm=True), **kwargs)
133 | return model
134 |
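In the cfg lists, integers are conv output channels and 'M' is a 2x2 max-pool; the five 'M' entries take 32x32 inputs down to 1x1, which is why the single classifier is Linear(512, num_classes). For example:

import torch

model = vgg11_bn_cifar(num_classes=10)          # configuration 'A' with batch-norm
print(model(torch.randn(2, 3, 32, 32)).shape)   # torch.Size([2, 10])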
--------------------------------------------------------------------------------
/distiller/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """This package contains ImageNet image classification models not found in torchvision"""
18 |
19 | from .mobilenet import *
20 | from .mobilenet_dropout import *
21 | from .preresnet_imagenet import *
22 | from .alexnet_batchnorm import *
23 | from .resnet_earlyexit import *
24 | from .resnet import *
25 |
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/alexnet_batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/alexnet_batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/mobilenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/mobilenet.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/mobilenet_dropout.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/mobilenet_dropout.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/preresnet_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/preresnet_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/__pycache__/resnet_earlyexit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/resnet_earlyexit.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/imagenet/alexnet_batchnorm.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """
18 | AlexNet model with batch-norm layers.
19 | Model configuration based on the AlexNet DoReFa example in TensorPack:
20 | https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py
21 |
22 | Code based on the AlexNet PyTorch sample, with the required changes.
23 | """
24 |
25 | import math
26 | import torch.nn as nn
27 |
28 | __all__ = ['AlexNetBN', 'alexnet_bn']
29 |
30 |
31 | class AlexNetBN(nn.Module):
32 |
33 | def __init__(self, num_classes=1000):
34 | super(AlexNetBN, self).__init__()
35 | self.features = nn.Sequential(
36 | nn.Conv2d(3, 96, kernel_size=12, stride=4), # conv0 (224x224x3) -> (54x54x96)
37 | nn.ReLU(inplace=True),
38 | nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False), # conv1 (54x54x96) -> (54x54x256)
39 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn1 (54x54x256)
40 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # pool1 (54x54x256) -> (27x27x256)
41 | nn.ReLU(inplace=True),
42 |
43 | nn.Conv2d(256, 384, kernel_size=3, padding=1, bias=False), # conv2 (27x27x256) -> (27x27x384)
44 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn2 (27x27x384)
45 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # pool2 (27x27x384) -> (14x14x384)
46 | nn.ReLU(inplace=True),
47 |
48 | nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False), # conv3 (14x14x384) -> (14x14x384)
49 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn3 (14x14x384)
50 | nn.ReLU(inplace=True),
51 |
52 | nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False), # conv4 (14x14x384) -> (14x14x256)
53 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn4 (14x14x256)
54 | nn.MaxPool2d(kernel_size=3, stride=2), # pool4 (14x14x256) -> (6x6x256)
55 | nn.ReLU(inplace=True),
56 | )
57 | self.classifier = nn.Sequential(
58 | nn.Linear(256 * 6 * 6, 4096, bias=False), # fc0
59 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc0
60 | nn.ReLU(inplace=True),
61 | nn.Linear(4096, 4096, bias=False), # fc1
62 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc1
63 | nn.ReLU(inplace=True),
64 | nn.Linear(4096, num_classes), # fct
65 | )
66 |
67 | for m in self.modules():
68 | if isinstance(m, (nn.Conv2d, nn.Linear)):
69 | fan_in, k_size = (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) if isinstance(m, nn.Conv2d) \
70 | else (m.in_features, 1)
71 | n = k_size * fan_in
72 | m.weight.data.normal_(0, math.sqrt(2. / n))
73 | if hasattr(m, 'bias') and m.bias is not None:
74 | m.bias.data.fill_(0)
75 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
76 | m.weight.data.fill_(1)
77 | m.bias.data.zero_()
78 |
79 | def forward(self, x):
80 | x = self.features(x)
81 | x = x.view(x.size(0), 256 * 6 * 6)
82 | x = self.classifier(x)
83 | return x
84 |
85 |
86 | def alexnet_bn(**kwargs):
87 | r"""AlexNet model with batch-norm layers.
88 | Model configuration based on the AlexNet DoReFa example in `TensorPack
89 | <https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py>`
90 | """
91 | model = AlexNetBN(**kwargs)
92 | return model
93 |
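A quick smoke test (illustrative; eval() is used because the BatchNorm1d layers cannot run in training mode on a batch of one):

import torch

model = alexnet_bn()
model.eval()
print(model(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])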
--------------------------------------------------------------------------------
/distiller/models/imagenet/mobilenet.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from math import floor
18 | import torch.nn as nn
19 |
20 | __all__ = ['mobilenet', 'mobilenet_025', 'mobilenet_050', 'mobilenet_075']
21 |
22 |
23 | class MobileNet(nn.Module):
24 | def __init__(self, channel_multiplier=1.0, min_channels=8):
25 | super(MobileNet, self).__init__()
26 |
27 | if channel_multiplier <= 0:
28 | raise ValueError('channel_multiplier must be > 0')
29 |
30 | def conv_bn_relu(n_ifm, n_ofm, kernel_size, stride=1, padding=0, groups=1):
31 | return [
32 | nn.Conv2d(n_ifm, n_ofm, kernel_size, stride=stride, padding=padding, groups=groups, bias=False),
33 | nn.BatchNorm2d(n_ofm),
34 | nn.ReLU(inplace=True)
35 | ]
36 |
37 | def depthwise_conv(n_ifm, n_ofm, stride):
38 | return nn.Sequential(
39 | *conv_bn_relu(n_ifm, n_ifm, 3, stride=stride, padding=1, groups=n_ifm),
40 | *conv_bn_relu(n_ifm, n_ofm, 1, stride=1)
41 | )
42 |
43 | base_channels = [32, 64, 128, 256, 512, 1024]
44 | self.channels = [max(floor(n * channel_multiplier), min_channels) for n in base_channels]
45 |
46 | self.model = nn.Sequential(
47 | nn.Sequential(*conv_bn_relu(3, self.channels[0], 3, stride=2, padding=1)),
48 | depthwise_conv(self.channels[0], self.channels[1], 1),
49 | depthwise_conv(self.channels[1], self.channels[2], 2),
50 | depthwise_conv(self.channels[2], self.channels[2], 1),
51 | depthwise_conv(self.channels[2], self.channels[3], 2),
52 | depthwise_conv(self.channels[3], self.channels[3], 1),
53 | depthwise_conv(self.channels[3], self.channels[4], 2),
54 | depthwise_conv(self.channels[4], self.channels[4], 1),
55 | depthwise_conv(self.channels[4], self.channels[4], 1),
56 | depthwise_conv(self.channels[4], self.channels[4], 1),
57 | depthwise_conv(self.channels[4], self.channels[4], 1),
58 | depthwise_conv(self.channels[4], self.channels[4], 1),
59 | depthwise_conv(self.channels[4], self.channels[5], 2),
60 | depthwise_conv(self.channels[5], self.channels[5], 1),
61 | nn.AvgPool2d(7),
62 | )
63 | self.fc = nn.Linear(self.channels[5], 1000)
64 |
65 | def forward(self, x):
66 | x = self.model(x)
67 | x = x.view(-1, x.size(1))
68 | x = self.fc(x)
69 | return x
70 |
71 |
72 | def mobilenet_025():
73 | return MobileNet(channel_multiplier=0.25)
74 |
75 |
76 | def mobilenet_050():
77 | return MobileNet(channel_multiplier=0.5)
78 |
79 |
80 | def mobilenet_075():
81 | return MobileNet(channel_multiplier=0.75)
82 |
83 |
84 | def mobilenet():
85 | return MobileNet()
86 |
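channel_multiplier scales the base widths [32, 64, 128, 256, 512, 1024] with floor(), clamped from below by min_channels; the five stride-2 stages reduce 224x224 inputs by 32x to the 7x7 map consumed by AvgPool2d(7). For instance:

net = mobilenet_025()
print(net.channels)   # [8, 16, 32, 64, 128, 256]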
--------------------------------------------------------------------------------
/distiller/models/imagenet/mobilenet_dropout.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Source: https://github.com/mit-han-lab/amc-compressed-models/blob/master/models/mobilenet_v1.py
18 |
19 | The code has been modified to remove code related to AMC.
20 | """
21 |
22 | __all__ = ['mobilenet_v1_dropout', 'mobilenet_v1_dropout_025', 'mobilenet_v1_dropout_050', 'mobilenet_v1_dropout_075']
23 |
24 |
25 | import torch
26 | import torch.nn as nn
27 | import math
28 |
29 |
30 | def conv_bn(inp, oup, stride):
31 | return nn.Sequential(
32 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
33 | nn.BatchNorm2d(oup),
34 | nn.ReLU(inplace=True)
35 | )
36 |
37 |
38 | def conv_dw(inp, oup, stride):
39 | return nn.Sequential(
40 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
41 | nn.BatchNorm2d(inp),
42 | nn.ReLU(inplace=True),
43 |
44 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
45 | nn.BatchNorm2d(oup),
46 | nn.ReLU(inplace=True),
47 | )
48 |
49 |
50 | class MobileNet(nn.Module):
51 | def __init__(self, n_class=1000, channel_multiplier=1.):
52 | super(MobileNet, self).__init__()
53 | in_planes = int(32 * channel_multiplier)
54 | a = int(64 * channel_multiplier)
55 | cfg = [a, (a*2, 2), a*2, (a*4, 2), a*4, (a*8, 2), a*8, a*8, a*8, a*8, a*8, (a*16, 2), a*16]
56 |
57 | self.conv1 = conv_bn(3, in_planes, stride=2)
58 | self.features = self._make_layers(in_planes, cfg, conv_dw)
59 | #self.dropout = nn.Dropout(0.2)
60 | self.classifier = nn.Sequential(
61 | nn.Dropout(0.2),
62 | nn.Linear(cfg[-1], n_class),
63 | )
64 |
65 | self._initialize_weights()
66 |
67 | def forward(self, x):
68 | x = self.conv1(x)
69 | x = self.features(x)
70 | x = x.mean(3).mean(2) # global average pooling
71 | #x = self.dropout(x)
72 | x = self.classifier(x)
73 | return x
74 |
75 | def _make_layers(self, in_planes, cfg, layer):
76 | layers = []
77 | for x in cfg:
78 | out_planes = x if isinstance(x, int) else x[0]
79 | stride = 1 if isinstance(x, int) else x[1]
80 | layers.append(layer(in_planes, out_planes, stride))
81 | in_planes = out_planes
82 | return nn.Sequential(*layers)
83 |
84 | def _initialize_weights(self):
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | if m.bias is not None:
90 | m.bias.data.zero_()
91 | elif isinstance(m, nn.BatchNorm2d):
92 | m.weight.data.fill_(1)
93 | m.bias.data.zero_()
94 | elif isinstance(m, nn.Linear):
95 | n = m.weight.size(1)
96 | m.weight.data.normal_(0, 0.01)
97 | m.bias.data.zero_()
98 |
99 |
100 | def mobilenet_v1_dropout_025():
101 | return MobileNet(channel_multiplier=0.25)
102 |
103 |
104 | def mobilenet_v1_dropout_050():
105 | return MobileNet(channel_multiplier=0.5)
106 |
107 |
108 | def mobilenet_v1_dropout_075():
109 | return MobileNet(channel_multiplier=0.75)
110 |
111 |
112 | def mobilenet_v1_dropout():
113 | return MobileNet()
114 |
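With channel_multiplier=1.0 we get a=64, so cfg expands to [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024], where a tuple carries (out_channels, stride); _make_layers turns each entry into one depthwise-separable block:

net = mobilenet_v1_dropout()
print(len(net.features))   # 13 depthwise-separable blocks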
--------------------------------------------------------------------------------
/distiller/models/imagenet/resnet_earlyexit.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | import torchvision.models as models
5 | from torchvision.models.resnet import Bottleneck
6 | from torchvision.models.resnet import BasicBlock
7 |
8 |
9 | __all__ = ['resnet50_earlyexit']
10 |
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | """3x3 convolution with padding"""
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15 | padding=1, bias=False)
16 |
17 |
18 | class ResNetEarlyExit(models.ResNet):
19 |
20 | def __init__(self, block, layers, num_classes=1000):
21 | super(ResNetEarlyExit, self).__init__(block, layers, num_classes)
22 |
23 | # Define early exit layers
24 | self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True)
25 | self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True)
26 | self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True)
27 | self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes)
28 | self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes)
29 |
30 | def forward(self, x):
31 | x = self.conv1(x)
32 | x = self.bn1(x)
33 | x = self.relu(x)
34 | x = self.maxpool(x)
35 |
36 | x = self.layer1(x)
37 |
38 | # Add early exit layers
39 | exit0 = self.avgpool(x)
40 | exit0 = self.conv1_exit0(exit0)
41 | exit0 = self.conv2_exit0(exit0)
42 | exit0 = self.avgpool(exit0)
43 | exit0 = exit0.view(exit0.size(0), -1)
44 | exit0 = self.fc_exit0(exit0)
45 |
46 | x = self.layer2(x)
47 |
48 | # Add early exit layers
49 | exit1 = self.conv1_exit1(x)
50 | exit1 = self.avgpool(exit1)
51 | exit1 = exit1.view(exit1.size(0), -1)
52 | exit1 = self.fc_exit1(exit1)
53 |
54 | x = self.layer3(x)
55 | x = self.layer4(x)
56 |
57 | x = self.avgpool(x)
58 | x = x.view(x.size(0), -1)
59 | x = self.fc(x)
60 |
61 | # return a list of logits (one per exit)
62 | output = []
63 | output.append(exit0)
64 | output.append(exit1)
65 | output.append(x)
66 | return output
67 |
68 |
69 | def resnet50_earlyexit(pretrained=False, **kwargs):
70 | """Constructs a ResNet-50 model.
71 | """
72 | model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs)
73 | return model
74 |
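The 147 and 192 multipliers encode the exits' shape arithmetic, sketched below under the assumption of the torchvision contemporary to this code, where models.ResNet.avgpool is AvgPool2d(7, stride=1); with the newer AdaptiveAvgPool2d((1, 1)) these dimensions no longer line up:

# exit0: layer1 output 256x56x56 -avgpool(7,s1)-> 256x50x50
#        -conv1_exit0(7,s2,p3)-> 50x25x25 -conv2_exit0(7,s2,p3)-> 12x13x13
#        -avgpool(7,s1)-> 12x7x7 = 588 = 147 * block.expansion (4)
# exit1: layer2 output 512x28x28 -conv1_exit1(7,s2,p3)-> 12x14x14
#        -avgpool(7,s1)-> 12x8x8 = 768 = 192 * block.expansion (4)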
--------------------------------------------------------------------------------
/distiller/models/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """This package contains MNIST image classification models for pytorch"""
18 |
19 | from .simplenet_mnist import *
--------------------------------------------------------------------------------
/distiller/models/mnist/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/mnist/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/mnist/__pycache__/simplenet_mnist.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/mnist/__pycache__/simplenet_mnist.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/models/mnist/simplenet_mnist.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """An implementation of a trivial MNIST model.
18 |
19 | The original network definition is sourced here: https://github.com/pytorch/examples/blob/master/mnist/main.py
20 | """
21 |
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 |
25 |
26 | __all__ = ['simplenet_mnist', 'simplenet_v2_mnist']
27 |
28 |
29 | class Simplenet(nn.Module):
30 | def __init__(self):
31 | super().__init__()
32 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
33 | self.relu1 = nn.ReLU(inplace=False)
34 | self.pool1 = nn.MaxPool2d(2, 2)
35 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
36 | self.relu2 = nn.ReLU(inplace=False)
37 | self.pool2 = nn.MaxPool2d(2, 2)
38 | self.fc1 = nn.Linear(4*4*50, 500)
39 | self.relu3 = nn.ReLU(inplace=False)
40 | self.fc2 = nn.Linear(500, 10)
41 |
42 | def forward(self, x):
43 | x = self.pool1(self.relu1(self.conv1(x)))
44 | x = self.pool2(self.relu2(self.conv2(x)))
45 | x = x.view(x.size(0), -1)
46 | x = self.relu3(self.fc1(x))
47 | x = self.fc2(x)
48 | return x
49 |
50 |
51 | class Simplenet_v2(nn.Module):
52 | """
53 | This is Simplenet but with only one small Linear layer, instead of two Linear layers,
54 | one of which is large.
55 | 26K parameters.
56 | python compress_classifier.py ${MNIST_PATH} --arch=simplenet_v2_mnist --vs=0 --lr=0.01
57 |
58 | ==> Best [Top1: 98.970 Top5: 99.970 Sparsity:0.00 Params: 26000 on epoch: 54]
59 | """
60 | def __init__(self):
61 | super().__init__()
62 | self.conv1 = nn.Conv2d(1, 20, 5, 1)
63 | self.relu1 = nn.ReLU(inplace=False)
64 | self.pool1 = nn.MaxPool2d(2, 2)
65 | self.conv2 = nn.Conv2d(20, 50, 5, 1)
66 | self.relu2 = nn.ReLU(inplace=False)
67 | self.pool2 = nn.MaxPool2d(2, 2)
68 | self.avgpool = nn.AvgPool2d(4, stride=1)
69 | self.fc = nn.Linear(50, 10)
70 |
71 | def forward(self, x):
72 | x = self.pool1(self.relu1(self.conv1(x)))
73 | x = self.pool2(self.relu2(self.conv2(x)))
74 | x = self.avgpool(x)
75 | x = x.view(x.size(0), -1)
76 | x = self.fc(x)
77 | return x
78 |
79 |
80 | def simplenet_mnist():
81 | model = Simplenet()
82 | return model
83 |
84 | def simplenet_v2_mnist():
85 | model = Simplenet_v2()
86 | return model
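
Simplenet flattens the 50x4x4 map into an 800-to-500 Linear (about 400K parameters in that layer alone), while Simplenet_v2 average-pools the 4x4 map away so its only Linear is 50-to-10, giving the ~26K total quoted in the docstring:

import torch

v1, v2 = simplenet_mnist(), simplenet_v2_mnist()
x = torch.randn(1, 1, 28, 28)                    # MNIST-sized input
print(v1(x).shape, v2(x).shape)                  # torch.Size([1, 10]) twice
print(sum(p.numel() for p in v2.parameters()))   # ~26K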
--------------------------------------------------------------------------------
/distiller/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .eltwise import *
18 | from .grouping import *
19 | from .matmul import *
20 | from .rnn import *
21 | from .aggregate import Norm
22 |
23 | __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
24 | 'Concat', 'Chunk', 'Split', 'Stack',
25 | 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
26 | 'Norm']
27 |
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/aggregate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/aggregate.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/eltwise.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/eltwise.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/grouping.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/grouping.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/matmul.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/matmul.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/rnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/rnn.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/__pycache__/tsvd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/tsvd.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/modules/aggregate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Norm(nn.Module):
6 | """
7 | A module wrapper for vector/matrix norm
8 | """
9 | def __init__(self, p='fro', dim=None, keepdim=False):
10 | super(Norm, self).__init__()
11 | self.p = p
12 | self.dim = dim
13 | self.keepdim = keepdim
14 |
15 | def forward(self, x: torch.Tensor):
16 | return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
17 |
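Wrapping torch.norm in an nn.Module makes the operation visible to tooling that traverses named modules (e.g. for summaries or quantization). Usage sketch:

import torch

norm = Norm(p=2, dim=1, keepdim=True)   # row-wise L2 norm
print(norm(torch.randn(4, 8)).shape)    # torch.Size([4, 1])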
--------------------------------------------------------------------------------
/distiller/modules/eltwise.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import torch
17 | import torch.nn as nn
18 |
19 |
20 | class EltwiseAdd(nn.Module):
21 | def __init__(self, inplace=False):
22 | super(EltwiseAdd, self).__init__()
23 |
24 | self.inplace = inplace
25 |
26 | def forward(self, *input):
27 | res = input[0]
28 | if self.inplace:
29 | for t in input[1:]:
30 | res += t
31 | else:
32 | for t in input[1:]:
33 | res = res + t
34 | return res
35 |
36 |
37 | class EltwiseMult(nn.Module):
38 | def __init__(self, inplace=False):
39 | super(EltwiseMult, self).__init__()
40 | self.inplace = inplace
41 |
42 | def forward(self, *input):
43 | res = input[0]
44 | if self.inplace:
45 | for t in input[1:]:
46 | res *= t
47 | else:
48 | for t in input[1:]:
49 | res = res * t
50 | return res
51 |
52 |
53 | class EltwiseDiv(nn.Module):
54 | def __init__(self, inplace=False):
55 | super(EltwiseDiv, self).__init__()
56 | self.inplace = inplace
57 |
58 | def forward(self, x: torch.Tensor, y):
59 | if self.inplace:
60 | return x.div_(y)
61 | return x.div(y)
62 |
63 |
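These wrappers take any number of inputs and fold them left to right; the inplace variants reuse the first tensor's storage. For example:

import torch

a, b, c = (torch.ones(2, 2) for _ in range(3))
print(EltwiseAdd()(a, b, c))    # every element is 3.0
print(EltwiseDiv()(a, 2.0))     # every element is 0.5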
--------------------------------------------------------------------------------
/distiller/modules/grouping.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class Concat(nn.Module):
22 | def __init__(self, dim=0):
23 | super(Concat, self).__init__()
24 | self.dim = dim
25 |
26 | def forward(self, *seq):
27 | return torch.cat(seq, dim=self.dim)
28 |
29 |
30 | class Chunk(nn.Module):
31 | def __init__(self, chunks, dim=0):
32 | super(Chunk, self).__init__()
33 | self.chunks = chunks
34 | self.dim = dim
35 |
36 | def forward(self, tensor):
37 | return tensor.chunk(self.chunks, dim=self.dim)
38 |
39 |
40 | class Split(nn.Module):
41 | def __init__(self, split_size_or_sections, dim=0):
42 | super(Split, self).__init__()
43 | self.split_size_or_sections = split_size_or_sections
44 | self.dim = dim
45 |
46 | def forward(self, tensor):
47 | return torch.split(tensor, self.split_size_or_sections, dim=self.dim)
48 |
49 |
50 | class Stack(nn.Module):
51 | def __init__(self, dim=0):
52 | super(Stack, self).__init__()
53 | self.dim = dim
54 |
55 | def forward(self, seq):
56 | return torch.stack(seq, dim=self.dim)
57 |
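Module versions of torch.cat/chunk/split/stack, so these reshaping ops appear as named modules. A round trip (illustrative):

import torch

x = torch.randn(4, 6)
halves = Chunk(2, dim=1)(x)      # two (4, 3) tensors
back = Concat(dim=1)(*halves)    # (4, 6) again
stacked = Stack(dim=0)([x, x])   # (2, 4, 6)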
--------------------------------------------------------------------------------
/distiller/modules/matmul.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | import torch
17 | import torch.nn as nn
18 |
19 |
20 | class Matmul(nn.Module):
21 | """
22 | A wrapper module for matmul operation between 2 tensors.
23 | """
24 | def __init__(self):
25 | super(Matmul, self).__init__()
26 |
27 | def forward(self, a: torch.Tensor, b: torch.Tensor):
28 | return a.matmul(b)
29 |
30 |
31 | class BatchMatmul(nn.Module):
32 | """
33 | A wrapper module for torch.bmm operation between 2 tensors.
34 | """
35 | def __init__(self):
36 | super(BatchMatmul, self).__init__()
37 |
38 |     def forward(self, a: torch.Tensor, b: torch.Tensor):
39 | return torch.bmm(a, b)
--------------------------------------------------------------------------------
/distiller/modules/tsvd.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2019 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """Truncated-SVD module.
18 |
19 | For an example of how truncated-SVD can be used, see this Jupyter notebook:
20 | https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb
21 |
22 | """
23 | import torch, torch.nn as nn
24 | def truncated_svd(W, l):
25 | """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD.
26 |
27 | For the original implementation (MIT license), see Faster-RCNN:
28 | https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
29 | We replaced numpy operations with pytorch operations (so that we can leverage the GPU).
30 |
31 | Arguments:
32 | W: N x M weights matrix
33 | l: number of singular values to retain
34 | Returns:
35 |         Ul, SV: matrices such that W \approx Ul*SV
36 | """
37 |
38 | U, s, V = torch.svd(W, some=True)
39 |
40 | Ul = U[:, :l]
41 | sl = s[:l]
42 | V = V.t()
43 | Vl = V[:l, :]
44 |
45 | SV = torch.mm(torch.diag(sl), Vl)
46 | return Ul, SV
47 |
48 |
49 | class TruncatedSVD(nn.Module):
50 | def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):
51 | super().__init__()
52 | self.replaced_gemm = replaced_gemm
53 | print("W = {}".format(gemm_weights.shape))
54 | self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))
55 | print("U = {}".format(self.U.shape))
56 |
57 | self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()
58 | self.fc_u.weight.data = self.U
59 |
60 | print("SV = {}".format(self.SV.shape))
61 | self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()
62 |         self.fc_sv.weight.data = self.SV  # SV is (l x M), already matching nn.Linear's (out, in) weight shape
63 |
64 | def forward(self, x):
65 | x = self.fc_sv.forward(x)
66 | x = self.fc_u.forward(x)
67 | return x
68 |
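A quick numeric check of the decomposition above (a hedged sketch; assumes truncated_svd from this file is in scope, and the sizes are made up): keeping l singular values turns one N x M weight matrix into an l x M matrix followed by an N x l matrix, so N*M weights become l*(N + M).

    import torch

    N, M, l = 64, 128, 16
    W = torch.randn(N, M)
    Ul, SV = truncated_svd(W, l)        # Ul: (N, l), SV: (l, M)
    print(Ul.shape, SV.shape)           # torch.Size([64, 16]) torch.Size([16, 128])
    print(N * M, l * (N + M))           # 8192 -> 3072 weights
    rel_err = torch.norm(W - Ul @ SV) / torch.norm(W)
    print(rel_err.item())               # approximation error from truncation
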
--------------------------------------------------------------------------------
/distiller/pruning/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """
18 | :mod:`distiller.pruning` is a package implementing various pruning algorithms.
19 | """
20 |
21 | from .magnitude_pruner import MagnitudeParameterPruner
22 | from .automated_gradual_pruner import AutomatedGradualPruner, \
23 | L1RankedStructureParameterPruner_AGP, \
24 | L2RankedStructureParameterPruner_AGP, \
25 | ActivationAPoZRankedFilterPruner_AGP, \
26 | ActivationMeanRankedFilterPruner_AGP, \
27 | GradientRankedFilterPruner_AGP, \
28 | RandomRankedFilterPruner_AGP, \
29 | BernoulliFilterPruner_AGP
30 | from .level_pruner import SparsityLevelParameterPruner
31 | from .sensitivity_pruner import SensitivityPruner
32 | from .splicing_pruner import SplicingPruner
33 | from .structure_pruner import StructureParameterPruner
34 | from .ranked_structures_pruner import L1RankedStructureParameterPruner, \
35 | L2RankedStructureParameterPruner, \
36 | ActivationAPoZRankedFilterPruner, \
37 | ActivationMeanRankedFilterPruner, \
38 | GradientRankedFilterPruner, \
39 | RandomRankedFilterPruner, \
40 | RandomLevelStructureParameterPruner, \
41 | BernoulliFilterPruner, \
42 | FMReconstructionChannelPruner
43 | from .baidu_rnn_pruner import BaiduRNNPruner
44 | from .greedy_filter_pruning import greedy_pruner
45 |
46 | del magnitude_pruner
47 | del automated_gradual_pruner
48 | del level_pruner
49 | del sensitivity_pruner
50 | del structure_pruner
51 | del ranked_structures_pruner
52 |
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/automated_gradual_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/automated_gradual_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/baidu_rnn_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/baidu_rnn_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/greedy_filter_pruning.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/greedy_filter_pruning.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/level_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/level_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/magnitude_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/magnitude_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/ranked_structures_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/ranked_structures_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/sensitivity_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/sensitivity_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/splicing_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/splicing_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/__pycache__/structure_pruner.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/structure_pruner.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/pruning/baidu_rnn_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .pruner import _ParameterPruner
18 | from .level_pruner import SparsityLevelParameterPruner
19 | from distiller.utils import *
20 |
21 | import distiller
22 |
23 | class BaiduRNNPruner(_ParameterPruner):
24 | """An element-wise pruner for RNN networks.
25 |
26 | Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017).
27 | Exploring Sparsity in Recurrent Neural Networks.
28 | (https://arxiv.org/abs/1704.05119)
29 |
30 | This implementation differs slightly from the algorithm in the original paper:
31 | the paper changes the pruning rate at training-step granularity, while
32 | Distiller controls the pruning rate at epoch granularity.
33 |
34 | Equation (1):
35 |
36 | 2 * q * freq
37 | start_slope = -------------------------------------------------------
38 | 2 * (ramp_itr - start_itr ) + 3 * (end_itr - ramp_itr )
39 |
40 |
41 | Pruning algorithm (1):
42 |
43 | if current itr < ramp itr then
44 | threshold = start_slope * (current_itr - start_itr + 1) / freq
45 | else
46 | threshold = (start_slope * (ramp_itr - start_itr + 1) +
47 | ramp_slope * (current_itr - ramp_itr + 1)) / freq
48 | end if
49 |
50 | mask = abs(param) < threshold
51 | """
52 |
53 | def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights):
54 | # Initialize the pruner, using a configuration that originates from the
55 | # schedule YAML file.
56 | super(BaiduRNNPruner, self).__init__(name)
57 | self.params_names = weights
58 | assert self.params_names
59 |
60 | # This is the 'q' value that appears in equation (1) of the paper
61 | self.q = q
62 | # This is the number of epochs to wait after starting_epoch, before we
63 | # begin ramping up the pruning rate.
64 | # In other words, between epochs 'starting_epoch' and 'starting_epoch'+
65 | # self.ramp_epoch_offset the pruning slope is 'self.start_slope'. After
66 | # that, the slope is 'self.ramp_slope'
67 | self.ramp_epoch_offset = ramp_epoch_offset
68 | self.ramp_slope_mult = ramp_slope_mult
69 | self.ramp_slope = None
70 | self.start_slope = None
71 |
72 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
73 | if param_name not in self.params_names:
74 | return
75 |
76 | starting_epoch = meta['starting_epoch']
77 | current_epoch = meta['current_epoch']
78 | ending_epoch = meta['ending_epoch']
79 | freq = meta['frequency']
80 |
81 | ramp_epoch = self.ramp_epoch_offset + starting_epoch
82 |
83 | # Calculate start slope
84 | if self.start_slope is None:
85 | # We want to calculate these values only once, and then cache them.
86 | self.start_slope = (2 * self.q * freq) / (2*(ramp_epoch - starting_epoch) + 3*(ending_epoch - ramp_epoch))
87 | self.ramp_slope = self.start_slope * self.ramp_slope_mult
88 |
89 | if current_epoch < ramp_epoch:
90 | eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq
91 | else:
92 | eps = (self.start_slope * (ramp_epoch - starting_epoch + 1) +
93 | self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq
94 |
95 | # After computing the threshold, we can create the mask
96 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps)
97 |
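To make the schedule above concrete, here is a hedged numeric sketch of equation (1) and pruning algorithm (1), with invented hyper-parameters (q=0.9, freq=1, ramp_slope_mult=2, epochs 0..9, ramp starting at epoch 4):

    q, freq, ramp_slope_mult = 0.9, 1, 2
    start_epoch, ramp_epoch, end_epoch = 0, 4, 10

    start_slope = (2 * q * freq) / (2 * (ramp_epoch - start_epoch)
                                    + 3 * (end_epoch - ramp_epoch))
    ramp_slope = start_slope * ramp_slope_mult

    for epoch in range(start_epoch, end_epoch):
        if epoch < ramp_epoch:
            eps = start_slope * (epoch - start_epoch + 1) / freq
        else:
            eps = (start_slope * (ramp_epoch - start_epoch + 1)
                   + ramp_slope * (epoch - ramp_epoch + 1)) / freq
        print(epoch, round(eps, 4))  # threshold rises slowly, then twice as fast
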
--------------------------------------------------------------------------------
/distiller/pruning/level_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch
18 | from .pruner import _ParameterPruner
19 | import distiller
20 |
21 | class SparsityLevelParameterPruner(_ParameterPruner):
22 | """Prune to an exact pruning level specification.
23 |
24 | This pruner is very similar to MagnitudeParameterPruner, but instead of
25 | specifying an absolute threshold for pruning, you specify a target sparsity
26 | level (expressed as a fraction: 0.5 means 50% sparsity.)
27 |
28 | To find the threshold, we view the tensor as one large 1D vector, sort it by
29 | absolute value, and take the k smallest elements; the largest of these is the threshold.
30 | """
31 |
32 | def __init__(self, name, levels, **kwargs):
33 | super(SparsityLevelParameterPruner, self).__init__(name)
34 | self.levels = levels
35 | assert self.levels
36 |
37 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
38 | # If there is a specific sparsity level specified for this module, then
39 | # use it. Otherwise try to use the default level ('*').
40 | desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0))
41 | if desired_sparsity == 0:
42 | return
43 |
44 | self.prune_level(param, param_name, zeros_mask_dict, desired_sparsity)
45 |
46 | @staticmethod
47 | def prune_level(param, param_name, zeros_mask_dict, desired_sparsity):
48 | bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True)
49 | threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away
50 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
51 |
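The thresholding trick in prune_level, restated as a hedged stand-alone sketch (without the zeros_mask_dict plumbing): the k smallest-magnitude elements define the threshold, so the resulting mask hits the requested sparsity almost exactly.

    import torch

    param = torch.randn(1000)
    desired_sparsity = 0.5
    k = int(desired_sparsity * param.numel())
    bottomk, _ = torch.topk(param.abs().view(-1), k, largest=False, sorted=True)
    threshold = bottomk[-1]                     # largest element we prune away
    mask = (param.abs() > threshold).float()
    print(1.0 - mask.mean().item())             # ~0.5 measured sparsity
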
--------------------------------------------------------------------------------
/distiller/pruning/magnitude_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .pruner import _ParameterPruner
18 | import distiller
19 |
20 |
21 | class MagnitudeParameterPruner(_ParameterPruner):
22 | """This is the most basic magnitude-based pruner.
23 |
24 | This pruner supports configuring a scalar threshold for each layer.
25 | A default threshold is mandatory and is used for layers without explicit
26 | threshold setting.
27 |
28 | """
29 | def __init__(self, name, thresholds, **kwargs):
30 | """
31 | Usually, a Pruner is constructed by the compression schedule parser
32 | found in distiller/config.py.
33 | The constructor is passed a dictionary of thresholds, as explained below.
34 |
35 | Args:
36 | name (string): the name of the pruner (used only for debug)
37 |             thresholds (dict): a dictionary of thresholds, with the key being the
38 | parameter name.
39 | A special key, '*', represents the default threshold value. If
40 | set_param_mask is invoked on a parameter tensor that does not have
41 | an explicit entry in the 'thresholds' dictionary, then this default
42 | value is used.
43 | Currently it is mandatory to include a '*' key in 'thresholds'.
44 | """
45 | super(MagnitudeParameterPruner, self).__init__(name)
46 | assert thresholds is not None
47 | # Make sure there is a default threshold to use
48 | assert '*' in thresholds
49 | self.thresholds = thresholds
50 |
51 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
52 | threshold = self.thresholds.get(param_name, self.thresholds['*'])
53 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
54 |
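A hedged sketch of the thresholds dictionary the docstring describes; the layer name is hypothetical, and '*' supplies the mandatory default:

    thresholds = {
        '*': 0.02,                       # default threshold for every parameter
        'layer1.conv1.weight': 0.05,     # hypothetical per-layer override
    }
    pruner = MagnitudeParameterPruner('magnitude_pruner', thresholds)
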
--------------------------------------------------------------------------------
/distiller/pruning/pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch
18 | import distiller
19 |
20 | class _ParameterPruner(object):
21 | """Base class for all pruners.
22 |
23 | Arguments:
24 | name: pruner name is used mainly for debugging.
25 | """
26 | def __init__(self, name):
27 | self.name = name
28 |
29 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
30 | raise NotImplementedError
31 |
32 | def threshold_model(model, threshold):
33 | """Threshold an entire model using the provided threshold
34 |
35 | This function prunes weights only (biases are left untouched).
36 | """
37 | for name, p in model.named_parameters():
38 | if 'weight' in name:
39 | mask = distiller.threshold_mask(p.data, threshold)
40 |             p.data.mul_(mask)  # in-place multiply; reassignment is unnecessary
41 |
--------------------------------------------------------------------------------
/distiller/pruning/sensitivity_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .pruner import _ParameterPruner
18 | import distiller
19 | import torch
20 |
21 | class SensitivityPruner(_ParameterPruner):
22 | """Use algorithm from "Learning both Weights and Connections for Efficient
23 | Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf
24 |
25 | I.e.: "The pruning threshold is chosen as a quality parameter multiplied
26 |     by the standard deviation of a layer's weights."
27 | In this code, the "quality parameter" is referred to as "sensitivity" and
28 | is based on the values learned from performing sensitivity analysis.
29 |
30 |     Note that this implementation deviates slightly from the algorithm Song Han
31 |     describes in his PhD dissertation: here the threshold value is set only
32 |     once, whereas the dissertation uses a threshold that grows at each
33 |     iteration. The growing scheme requires n+1 hyper-parameters (n being the
34 |     number of pruning iterations): the initial threshold and the threshold
35 |     increase (delta) at each pruning iteration.
36 |     The implementation that follows takes advantage of the fact that as pruning
37 |     progresses, more weights are pulled toward zero, and therefore the threshold
38 |     "traps" more weights. Thus, we can use fewer hyper-parameters and achieve the
39 |     same results.
40 | """
41 |
42 | def __init__(self, name, sensitivities, **kwargs):
43 | super(SensitivityPruner, self).__init__(name)
44 | self.sensitivities = sensitivities
45 |
46 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
47 | if not hasattr(param, 'stddev'):
48 | param.stddev = torch.std(param).item()
49 |
50 | if param_name not in self.sensitivities:
51 | if '*' not in self.sensitivities:
52 | return
53 | else:
54 | sensitivity = self.sensitivities['*']
55 | else:
56 | sensitivity = self.sensitivities[param_name]
57 |
58 | threshold = param.stddev * sensitivity
59 |
60 | # After computing the threshold, we can create the mask
61 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold)
62 |
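A hedged numeric sketch of the rule above: for roughly Gaussian weights, a threshold of s standard deviations prunes about erf(s/sqrt(2)) of the elements, e.g. s = 0.675 prunes roughly half.

    import torch

    param = torch.randn(100000)
    sensitivity = 0.675
    threshold = param.std().item() * sensitivity
    sparsity = (param.abs() <= threshold).float().mean().item()
    print(threshold, sparsity)   # threshold ~0.675, sparsity ~0.5
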
--------------------------------------------------------------------------------
/distiller/pruning/splicing_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 |
18 | from .pruner import _ParameterPruner
19 | import torch
20 | import logging
21 | msglogger = logging.getLogger()
22 |
23 |
24 | class SplicingPruner(_ParameterPruner):
25 | """A pruner that both prunes and splices connections.
26 |
27 | The idea of pruning and splicing working in tandem was first proposed in the following
28 | NIPS paper from Intel Labs China in 2016:
29 | Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen.
30 | NIPS 2016, https://arxiv.org/abs/1608.04493.
31 |
32 | A SplicingPruner works best with a Dynamic Network Surgery schedule.
33 | The original Caffe code from the authors of the paper is available here:
34 | https://github.com/yiwenguo/Dynamic-Network-Surgery/blob/master/src/caffe/layers/compress_conv_layer.cpp
35 | """
36 |
37 | def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0):
38 | """Arguments:
39 | """
40 | super(SplicingPruner, self).__init__(name)
41 | self.sensitivities = sensitivities
42 | self.low_thresh_mult = low_thresh_mult
43 | self.hi_thresh_mult = hi_thresh_mult
44 | self.sensitivity_multiplier = sensitivity_multiplier
45 |
46 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
47 | if param_name not in self.sensitivities:
48 | if '*' not in self.sensitivities:
49 | return
50 | else:
51 | sensitivity = self.sensitivities['*']
52 | else:
53 | sensitivity = self.sensitivities[param_name]
54 |
55 | if not hasattr(param, '_std'):
56 | # Compute the mean and standard-deviation once, and cache them.
57 | param._std = torch.std(param.abs()).item()
58 | param._mean = torch.mean(param.abs()).item()
59 |
60 | if self.sensitivity_multiplier > 0:
61 | # Linearly growing sensitivity - for now this is hard-coded
62 | starting_epoch = meta['starting_epoch']
63 | current_epoch = meta['current_epoch']
64 | sensitivity *= (current_epoch - starting_epoch) * self.sensitivity_multiplier + 1
65 |
66 | threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult
67 | threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult
68 |
69 | if zeros_mask_dict[param_name].mask is None:
70 | zeros_mask_dict[param_name].mask = torch.ones_like(param)
71 |
72 |         # This code implements equation (3) of the "Dynamic Network Surgery" paper:
73 | #
74 | # 0 if a > |W|
75 | # h(W) = mask if a <= |W| < b
76 | # 1 if b <= |W|
77 | #
78 | # h(W) is the so-called "network surgery function".
79 | # mask is the mask used in the previous iteration.
80 | # a and b are the low and high thresholds, respectively.
81 | # We followed the example implementation from Yiwen Guo in Caffe, and used the
82 | # weight tensor's starting mean and std.
83 | # This is very similar to the initialization performed by distiller.SensitivityPruner.
84 |
85 | mask = zeros_mask_dict[param_name].mask
86 | zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type())
87 | weights_abs = param.abs()
88 | new_mask = torch.where(threshold_low > weights_abs, zeros, mask)
89 | new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask)
90 | zeros_mask_dict[param_name].mask = new_mask
91 |
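The three-band update in equation (3), as a tiny hedged example with made-up thresholds a=0.2 and b=0.5: weights below a are cut, weights at or above b are spliced back in, and everything in between keeps its previous mask.

    import torch

    a, b = 0.2, 0.5
    param = torch.tensor([0.1, 0.3, 0.7, -0.25, -0.6])
    prev_mask = torch.tensor([1., 0., 0., 1., 0.])

    w = param.abs()
    zeros, ones = torch.zeros_like(prev_mask), torch.ones_like(prev_mask)
    new_mask = torch.where(w < a, zeros, prev_mask)   # prune the weakest
    new_mask = torch.where(w >= b, ones, new_mask)    # splice the strongest
    print(new_mask)                                   # tensor([0., 0., 1., 1., 1.])
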
--------------------------------------------------------------------------------
/distiller/pruning/structure_pruner.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import logging
18 | from .pruner import _ParameterPruner
19 | import distiller
20 | msglogger = logging.getLogger()
21 |
22 | class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner):
23 | """Prune parameter structures.
24 |
25 |     Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the elements
26 |     in the structure is below the threshold, then the structure is pruned.
27 |
28 | We use the average, instead of plain L1-norm, because we don't want the threshold to depend on
29 | the structure size.
30 | """
31 | def __init__(self, name, model, reg_regims, threshold_criteria):
32 | super(StructureParameterPruner, self).__init__(name)
33 | self.name = name
34 | self.model = model
35 | self.reg_regims = reg_regims
36 | self.threshold_criteria = threshold_criteria
37 | assert threshold_criteria in ["Max", "Mean_Abs"]
38 |
39 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
40 | if param_name not in self.reg_regims.keys():
41 | return
42 |
43 | group_type = self.reg_regims[param_name][1]
44 | threshold = self.reg_regims[param_name][0]
45 | zeros_mask_dict[param_name].mask = self.group_threshold_mask(param,
46 | group_type,
47 | threshold,
48 | self.threshold_criteria)
49 |
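A hedged sketch of the "Mean_Abs" idea on conv filters (structure sizes and threshold are invented): because the mean of |w| is used rather than the raw L1 sum, the same threshold works for small and large filters alike.

    import torch

    weights = torch.randn(8, 4, 3, 3)          # 8 filters, each 4x3x3
    threshold = 0.75
    mean_abs = weights.abs().view(weights.size(0), -1).mean(dim=1)
    keep = mean_abs > threshold                 # per-filter pruning decision
    print(mean_abs)
    print(keep)
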
--------------------------------------------------------------------------------
/distiller/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .quantizer import Quantizer
18 | from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \
19 | LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
20 | RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode
21 | from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
22 |
23 | del quantizer
24 | del range_linear
25 | del clipped_linear
26 |
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/clipped_linear.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/clipped_linear.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/q_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/q_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/quantizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/quantizer.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/range_linear.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/range_linear.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/quantization/__pycache__/sim_bn_fold.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/sim_bn_fold.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | from .l1_regularizer import L1Regularizer
18 | from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer
19 |
20 | del l1_regularizer
21 | del group_regularizer
22 |
--------------------------------------------------------------------------------
/distiller/regularization/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/__pycache__/drop_filter.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/drop_filter.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/__pycache__/group_regularizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/group_regularizer.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/__pycache__/l1_regularizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/l1_regularizer.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/__pycache__/regularizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/regularizer.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller/regularization/drop_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .regularizer import _Regularizer
6 |
7 |
8 | class Conv2dWithMask(nn.Conv2d):
9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
10 | padding=0, dilation=1, groups=1, bias=True):
11 |
12 | super(Conv2dWithMask, self).__init__(
13 | in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
14 | padding=padding, dilation=dilation, groups=groups, bias=bias)
15 |
16 | self.test_mask = None
17 | self.p_mask = 1.0
18 | self.frequency = 16
19 |
20 | def forward(self, input):
21 | if self.training:
22 | self.frequency -= 1
23 | if self.frequency == 0:
24 | sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels)
25 | param = self.weight
26 | l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1)
27 | mask = torch.tensor(sample)
28 | mask = mask.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
29 | mask = mask.view(self.weight.shape).to(param.device)
30 | mask = mask.type(param.type())
31 | masked_weights = self.weight * mask
32 | masked_l1norm = masked_weights.detach().view(param.size(0), -1).norm(p=1, dim=1)
33 | pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item()
34 | pruning_factor = max(0.2, pruning_factor)
35 | weight = masked_weights / pruning_factor
36 | self.frequency = 16
37 | else:
38 | weight = self.weight
39 | else:
40 | weight = self.weight
41 | return F.conv2d(input, weight, self.bias, self.stride,
42 | self.padding, self.dilation, self.groups)
43 |
44 |
45 | # Replaces all Conv2d layers in the target model with 'Conv2dWithMask'.
46 | def replace_conv2d(container):
47 | for name, module in container.named_children():
48 |         if isinstance(module, nn.Conv2d):
49 | print("replacing: ", name)
50 | new_module = Conv2dWithMask(in_channels=module.in_channels,
51 | out_channels=module.out_channels,
52 | kernel_size=module.kernel_size, padding=module.padding,
53 |                                         stride=module.stride, dilation=module.dilation, groups=module.groups, bias=module.bias is not None)  # nn.Conv2d expects a bool for bias
54 | setattr(container, name, new_module)
55 | replace_conv2d(module)
56 |
57 |
58 | class DropFilterRegularizer(_Regularizer):
59 | def __init__(self, name, model, reg_regims, threshold_criteria=None):
60 | super().__init__(name, model, reg_regims, threshold_criteria)
61 | replace_conv2d(model)
62 |
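A hedged usage sketch of replace_conv2d on a throwaway model; note that the masking branch only fires on every 16th training-mode forward, per the frequency counter above.

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(16, 32, 3, padding=1))
    replace_conv2d(model)                      # prints the replaced layer names
    model.train()
    out = model(torch.randn(2, 3, 32, 32))
    print(type(model[0]).__name__)             # Conv2dWithMask
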
--------------------------------------------------------------------------------
/distiller/regularization/l1_regularizer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | """L1-norm regularization"""
18 |
19 | import torch
20 | import math
21 | import numpy as np
22 | import distiller
23 | from .regularizer import _Regularizer, EPSILON
24 |
25 | class L1Regularizer(_Regularizer):
26 | def __init__(self, name, model, reg_regims, threshold_criteria=None):
27 | super(L1Regularizer, self).__init__(name, model, reg_regims, threshold_criteria)
28 |
29 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict):
30 | if param_name in self.reg_regims:
31 | strength = self.reg_regims[param_name]
32 | regularizer_loss += L1Regularizer.__add_l1(param, strength)
33 |
34 | return regularizer_loss
35 |
36 | def threshold(self, param, param_name, zeros_mask_dict):
37 | """Soft threshold for L1-norm regularizer"""
38 | if self.threshold_criteria is None or param_name not in self.reg_regims:
39 | return
40 |
41 | strength = self.reg_regims[param_name]
42 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength)
43 | zeros_mask_dict[param_name].is_regularization_mask = True
44 |
45 | @staticmethod
46 | def __add_l1(var, strength):
47 | return var.abs().sum() * strength
48 |
49 | @staticmethod
50 | def __add_l1_all(loss, model, reg_regims):
51 | for param_name, param in model.named_parameters():
52 | if param_name in reg_regims.keys():
53 | strength = reg_regims[param_name]
54 | loss += L1Regularizer.__add_l1(param, strength)
55 |         return loss
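
A hedged sketch of how the regularizer folds into a training step (the model and strength are invented): loss() accumulates strength * sum(|w|) onto the task loss for every parameter listed in reg_regims.

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    regularizer = L1Regularizer('l1', model, reg_regims={'weight': 1e-4})

    task_loss = model(torch.randn(4, 10)).pow(2).mean()
    total_loss = task_loss
    for name, param in model.named_parameters():
        total_loss = regularizer.loss(param, name, total_loss, zeros_mask_dict=None)
    total_loss.backward()
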
--------------------------------------------------------------------------------
/distiller/regularization/regularizer.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | EPSILON = 1e-8
18 |
19 | class _Regularizer(object):
20 | def __init__(self, name, model, reg_regims, threshold_criteria):
21 | """Regularization base class.
22 |
23 | Args:
24 |             reg_regims: regularization regimen. A dictionary of
25 |                 reg_regims[<param-name>] = [lambda, structure-type]
26 | """
27 | self.name = name
28 | self.model = model
29 | self.reg_regims = reg_regims
30 | self.threshold_criteria = threshold_criteria
31 |
32 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict):
33 | raise NotImplementedError
34 |
35 | def threshold(self, param, param_name, zeros_mask_dict):
36 | raise NotImplementedError
37 |
--------------------------------------------------------------------------------
/example~:
--------------------------------------------------------------------------------
1 | channels:
2 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
3 | show_channel_urls: true
4 |
--------------------------------------------------------------------------------
/ljt_mobilefacenet_y2.sh:
--------------------------------------------------------------------------------
1 | # BUSID
2 | python main.py --mode sa \
3 | --model mobilefacenet_y2_ljt \
4 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \
5 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \
6 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \
7 | --data_source company \
8 | --fpgm
9 |
10 | python main.py --mode prune \
11 | --model mobilefacenet_y2_ljt \
12 | --save_model_pt \
13 | --data_source company \
14 | --fpgm \
15 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \
16 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \
17 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \
18 |
19 |
20 | # TYLG
21 | python main.py --mode sa \
22 | --model mobilefacenet_y2_ljt \
23 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \
24 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \
25 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \
26 | --data_source company \
27 | --fpgm
28 |
29 | python main.py --mode prune \
30 | --model mobilefacenet_y2_ljt \
31 | --save_model_pt \
32 | --data_source company \
33 | --fpgm \
34 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \
35 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \
36 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \
37 |
38 | # XCH
39 | python main.py --mode sa \
40 | --model mobilefacenet_y2_ljt \
41 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \
42 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/le_re_0.4_144x122 \
43 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/id_life_image_list_bmppair.txt \
44 | --data_source company \
45 | --fpgm
--------------------------------------------------------------------------------
/ljt_shufflefacenet_v2.sh:
--------------------------------------------------------------------------------
1 | # BUSID
2 | python main.py --mode sa \
3 | --model shufflefacenet_v2_ljt \
4 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \
5 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \
6 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \
7 | --data_source company \
8 | --fpgm
9 |
10 | # XCH
11 | python main.py --mode sa \
12 | --model shufflefacenet_v2_ljt \
13 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \
14 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/le_re_0.4_144x122 \
15 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/id_life_image_list_bmppair.txt \
16 | --data_source company \
17 | --fpgm
18 |
19 | # TYLG
20 | python main.py --mode sa \
21 | --model shufflefacenet_v2_ljt \
22 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \
23 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \
24 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \
25 | --data_source company \
26 | --fpgm
27 |
28 | python main.py --mode prune \
29 | --model shufflefacenet_v2_ljt \
30 | --save_model_pt \
31 | --data_source company \
32 | --fpgm \
33 | --best_model_path /home/linx/model/ljt/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \
34 | --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_144x122 \
35 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/12 1:38 PM
4 | import argparse
5 | from sensitivity_analysis import sensitivity_analysis
6 | from pruning import prune
7 | from quantization import quantization
8 | from datetime import datetime
9 | import os
10 | from train_module.train_with_insight_face import face_learner
11 |
12 |
13 | def get_time():
14 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')
15 |
16 |
17 | def main():
18 | parser = argparse.ArgumentParser(description='prune face recognition model')
19 |
20 |     # Pruning
21 |     parser.add_argument('--lr', default=0.01, type=float, help='learning rate for retraining, usually 1/10 of the original training rate')
22 |     parser.add_argument('--weight_decay', default=4e-5, type=float, help='weight decay')
23 |     parser.add_argument('--save_pruned_model_root', default='work_space/models/pruned_model/', help='directory for saving the pruned model definition and files')
24 | parser.add_argument('--momentum', default=0.9, type=float)
25 |
26 |     parser.add_argument('--epoch', default=30, type=int, help='number of epochs to retrain after pruning')
27 |     parser.add_argument('--head_path', default=None, help='path to the training head')
28 | parser.add_argument('--device', default='cuda:0')
29 |
30 |     parser.add_argument('--print_freq', type=int, default=1, help='print accuracy info every N iterations')
31 |     parser.add_argument('--save_model_pt', default=False, action='store_true', help='whether to save a .pt file')
32 |
33 | parser.add_argument('--batch_size', type=int, default=256)
34 | parser.add_argument('--embedding_size', type=int, default=512)
35 |     parser.add_argument('--pruned_save_model_path', default='work_space/pruned_model', help='path for saving the pruned model')
36 |     parser.add_argument('--sensitivity_csv_path', default='work_space/sensitivity_data', help='path for saving the CSV file produced by pruning sensitivity analysis')
37 |
38 |     # The following arguments must be set for each run
39 |     parser.add_argument('--mode', default=None, choices=['prune', 'quantization', 'test', 'sa', 'finetune'], help='prune: prune only; quantization: quantize; '
40 |                                                                                                                  'sa: sensitivity analysis; '
41 |                                                                                                                  'finetune: prune and then finetune')
42 |
43 |     parser.add_argument('--best_model_path', default=None, help='path to the best trained model file, which will be pruned')
44 |     parser.add_argument('--test_root_path', default=None, help='root path of the test set')
45 |     parser.add_argument('--img_list_label_path', default=None, help='path to the test-set pair list')
46 | parser.add_argument('--model', default=None,
47 | choices=['mobilefacenet', 'resnet34', 'mobilefacenet_y2', 'resnet50', 'resnet100',
48 | 'mobilefacenet_lzc', 'mobilenetv3', 'resnet34_lzc', 'resnet50_imagenet',
49 | 'mobilefacenet_y2_ljt', 'shufflefacenet_v2_ljt', 'resnet_50_ljt', 'resnet_100_ljt'
50 |                                  ], help='which model to prune')
51 |
52 |     parser.add_argument('--is_save', default=False, action='store_true', help='whether to save the model file')
53 |
54 |     parser.add_argument('--from_data_parallel', action='store_true', default=False, help='whether the model was trained with multiple GPUs (DataParallel)')
55 |
56 |     parser.add_argument('--data_source', choices=['lfw', 'company', 'company_zkx'], default='None',
57 |                         help="which test set to use; company -> zy's resnet50 and resnet100, company_zkx -> zkx's mobilefacenet_y2")
58 |
59 |     parser.add_argument('--fpgm', action='store_true', default=False, help='whether to use geometric-median (FPGM) pruning')
60 |     parser.add_argument('--hrank', action='store_true', default=False, help='whether to use HRank pruning')
61 |     parser.add_argument('--rank_path', default='./work_space/rank_conv/', help='HRank config files')
62 |
63 |     parser.add_argument('--yaml_path', default='yaml_file/auto_yaml.yaml', help='pruning config file')
64 |
65 |     parser.add_argument('--cal_flops_and_forward', default=False, action='store_true', help='whether to measure FLOPs and forward time')
66 |
67 | parser.add_argument('--test_batch_size', type=int, default=256)
68 |
69 |     # Arguments that must be set for quantization
70 |     parser.add_argument('--quantize-mode', type=str, choices=['symmetric', 'asymmetric-signed', 'unsigned'], default='symmetric',
71 |                         help='quantization mode: map weight values to a symmetric, signed-asymmetric, or unsigned-asymmetric range')
72 |
73 |     parser.add_argument('--fp16', action='store_true', default=False, help='use half-precision quantization; if set, the modes above are ignored')
74 |
75 |     parser.add_argument('--input_size', type=int, default=112, help='input image size')
76 |
77 |     parser.add_argument('--quantized_save_model_path', default='work_space/quantized_model', help='path for saving the quantized model')
78 |
79 |     # Arguments required for finetuning
80 |     parser.add_argument('--pruned_checkpoint', type=str, default=None, help='path to the pruned model checkpoint')
81 |     parser.add_argument('--train_data_path', type=str, default=None, help='path to the training set used for finetuning')
82 |     parser.add_argument('--milestones', type=str, default='12,15,18', help='epochs at which the learning rate is decayed')
83 |     parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size')
84 | parser.add_argument('--pin_memory', type=bool, default=True)
85 | parser.add_argument('--num_workers', type=int, default=4)
86 |     parser.add_argument('--work_path', type=str, default='work_space/finetune', help='directory for files produced during training')
87 |     parser.add_argument('--finetune_pruned_model', action='store_true', default=False, help='finetune the pruned model')
88 |
89 | args = parser.parse_args()
90 |
91 | if args.mode == 'prune':
92 | prune(args)
93 |
94 | elif args.mode == 'sa':
95 | sensitivity_analysis(args)
96 |
97 | elif args.mode == 'quantization':
98 | quantization(args)
99 |
100 | elif args.mode == 'finetune':
101 | args.work_path = os.path.join(args.work_path, get_time())
102 | os.mkdir(args.work_path)
103 |
104 | args.log_path = os.path.join(args.work_path, 'log')
105 | args.save_path = os.path.join(args.work_path, 'save')
106 | args.model_path = os.path.join(args.work_path, 'model')
107 |
108 | os.mkdir(args.log_path)
109 | os.mkdir(args.save_path)
110 | os.mkdir(args.model_path)
111 |
112 | args.log_path = os.path.join(args.log_path, get_time())
113 | args.milestones = list(map(int, args.milestones.split(',')))
114 |
115 | learner = face_learner(args)
116 | learner.train(args)
117 |
118 |
119 | if __name__ == '__main__':
120 | main()
121 |
--------------------------------------------------------------------------------
/model_define/__pycache__/MobileFaceNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileFaceNet.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/MobileFaceNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileFaceNet.cpython-37.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/MobileNetV3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileNetV3.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/MobileNetV3.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileNetV3.cpython-37.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/load_state_dict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/load_state_dict.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/load_state_dict.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/load_state_dict.cpython-37.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/model_resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model_resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/model_resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model_resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/model_define/__pycache__/resnet50_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/resnet50_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/load_state_dict.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/28 6:09 PM
4 |
5 | from model_define.model_resnet import fresnet50_v3, fresnet100_v3, fresnet34_v3
6 | from collections import OrderedDict
7 | from model_define.MobileFaceNet import MobileFaceNet
8 | from model_define.model import MobileFaceNet_sor, ResNet34, MobileFaceNet_y2
9 | from model_define.MobileNetV3 import MobileNetV3_Large
10 | from model_define.resnet50_imagenet import resnet_50
11 | from model_define.mobilefacenet_y2_ljt.mobilefacenet_big import MobileFaceNet_y2_ljt
12 | from model_define.shufflefacenet_v2_ljt.ShuffleFaceNetV2 import ShuffleFaceNetV2
13 | from model_define.resnet_50_ljt.resnet_50 import fresnet50_v3_ljt
14 | from model_define.resnet_100_ljt.resnet_100 import fresnet100_v3_ljt
15 | import torch
16 |
17 |
18 | def load_state_dict(args):
19 |
20 | if args.model == 'mobilefacenet':
21 | model = MobileFaceNet_sor(args.embedding_size)
22 |
23 | elif args.model == 'resnet34':
24 | model = fresnet34_v3((112, 112))
25 | args.lr = 0.001
26 |
27 | elif args.model == 'mobilefacenet_y2':
28 | model = MobileFaceNet_y2(args.embedding_size)
29 |
30 | elif args.model == 'resnet50':
31 | model = fresnet50_v3()
32 |
33 | elif args.model == 'resnet50_imagenet':
34 | model = resnet_50()
35 |
36 | elif args.model == 'resnet100':
37 | model = fresnet100_v3()
38 |
39 | elif args.model == 'mobilefacenet_lzc':
40 | model = MobileFaceNet(128, (5, 5))
41 |
42 | elif args.model == 'mobilenetv3':
43 | model = MobileNetV3_Large(4)
44 |
45 | elif args.model == 'resnet34_lzc':
46 | model = fresnet34_v3((80, 80))
47 |
48 | elif args.model == 'mobilefacenet_y2_ljt':
49 | model = MobileFaceNet_y2_ljt()
50 |
51 | elif args.model == 'shufflefacenet_v2_ljt':
52 | model = ShuffleFaceNetV2(512, 2.0, (144, 122))
53 |
54 | elif args.model == 'resnet_50_ljt':
55 | model = fresnet50_v3_ljt()
56 |
57 | elif args.model == 'resnet_100_ljt':
58 | model = fresnet100_v3_ljt()
59 |
60 |     else:
61 |         raise ValueError('pruning is not supported for model: {}'.format(args.model))
62 |
63 |     print("loading {}'s checkpoint".format(args.model))
64 |
65 | state_dict = torch.load(args.best_model_path, map_location=args.device)
66 | if args.from_data_parallel:
67 | new_state_dict = OrderedDict()
68 | for k, v in state_dict.items():
69 |             new_state_dict[k[7:]] = v  # strip the 'module.' prefix added by nn.DataParallel
70 | state_dict = new_state_dict
71 | model.load_state_dict(state_dict)
72 |
73 | return model
74 |
--------------------------------------------------------------------------------
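A minimal sketch of how this loader is typically driven (the argument values below are hypothetical; only the fields that load_state_dict() actually reads are set):

    from argparse import Namespace
    from model_define.load_state_dict import load_state_dict

    # Hypothetical arguments for illustration only.
    args = Namespace(model='mobilefacenet',
                     embedding_size=512,
                     best_model_path='work_space/model_train_best/example.pth',  # hypothetical path
                     device='cpu',
                     from_data_parallel=False)

    model = load_state_dict(args)  # builds the network and loads its checkpoint
    model.eval()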
/model_define/mobilefacenet_y2_ljt/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/14 6:24 PM
4 |
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/common_utility.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch.nn.functional as F
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | '''
9 | Network common utilities
10 | '''
11 |
12 |
13 | def l2_norm(input, axis=1):
14 | norm = torch.norm(input, 2, axis, True)
15 | output = torch.div(input, norm)
16 | return output
17 |
18 |
19 | class L2Norm(Module):
20 | def forward(self, input):
21 | return F.normalize(input)
22 |
23 |
24 | class Flatten(Module):
25 | def forward(self, input):
26 | return input.view(input.size(0), -1)
27 | # for onnx model convert
28 | # batch_size = np.array(input.size(0))
29 | # batch_size.astype(dtype=np.int32)
30 | # return input.view(batch_size, 512)
31 |
32 |
33 | class GDC(nn.Module):
34 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
35 | super(GDC, self).__init__()
36 | self.conv = nn.Conv2d(in_c, out_channels=out_c, kernel_size=kernel,
37 | groups=groups, stride=stride, padding=padding, bias=False)
38 | self.bn = nn.BatchNorm2d(out_c)
39 |
40 | def forward(self, x):
41 | x = self.conv(x)
42 | x = self.bn(x)
43 | return x
44 |
45 |
46 | def Get_Conv_Size(height, width, kernel, stride, padding, rpt_num):
47 | conv_h = height
48 | conv_w = width
49 | for _ in range(rpt_num):
50 | conv_h = int((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
51 | conv_w = int((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
52 | return conv_h * conv_w
53 |
54 |
55 | def Get_Conv_kernel(height, width, kernel, stride, padding, rpt_num):
56 | conv_h = height
57 | conv_w = width
58 | for _ in range(rpt_num):
59 | conv_h = math.ceil((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
60 | conv_w = math.ceil((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
61 | print(conv_h, conv_w)
62 | return (int(conv_h), int(conv_w))
63 |
64 |
65 | def Get_Conv_kernel_floor(height, width, kernel, stride, padding, rpt_num):
66 | conv_h = height
67 | conv_w = width
68 | for _ in range(rpt_num):
69 | conv_h = math.floor((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
70 | conv_w = math.floor((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
71 | print(conv_h, conv_w)
72 | # print(conv_h, conv_w)
73 | return (int(conv_h), int(conv_w))
74 |
75 |
76 | def get_dense_ave_pooling_size(height, width, block_config):
77 | size1 = Get_Conv_kernel(height, width, (3, 3), (2, 2), (1, 1), 1)
78 | size2 = Get_Conv_kernel(size1[0], size1[1], (2, 2), (2, 2), (1, 1), 1)
79 | # print(size1)
80 | size3 = Get_Conv_kernel(size2[0], size2[1], (2, 2), (2, 2), (0, 0), len(block_config) - 1)
81 | return size3
82 |
83 |
84 | def get_shuffle_ave_pooling_size(height, width, using_pool=False):
85 | first_batch_num = 2
86 | if using_pool:
87 | first_batch_num = 3
88 |
89 | size1 = Get_Conv_kernel(height, width, (3, 3), (2, 2), (0, 0), first_batch_num)
90 | # print(size1)
91 | size2 = Get_Conv_kernel(size1[0], size1[1], (2, 2), (2, 2), (0, 0), 2)
92 | return size2
93 |
94 |
95 | def get_ghost_dw_size(height, width):
96 | size1 = Get_Conv_kernel_floor(height, width, (3, 3), (2, 2), (3 // 2, 3 // 2), 3)
97 | size1 = Get_Conv_kernel_floor(size1[0], size1[1], (5, 5), (2, 2), (5 // 2, 5 // 2), 2)
98 | return size1
99 |
100 |
101 | if __name__ == "__main__":
102 | # get_dense_ave_pooling_size(112,112, [1,2,3,4])
103 | # print("="*10)
104 | # get_shuffle_ave_pooling_size(112,112,True)
105 | # print("=" * 10)
106 | # get_shuffle_ave_pooling_size(112, 112, False)
107 | # print("=" * 10)
108 | Get_Conv_kernel_floor(112, 112, (3, 3), (2, 2), (1, 1), 4)
109 | print("=" * 10)
110 |
--------------------------------------------------------------------------------
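A quick check of the convolution-size arithmetic above: for a 112x112 input with a 3x3 kernel, stride 2 and padding 1, one repetition gives (112 - 3 + 2*1) / 2 + 1 = 56.5, which Get_Conv_Size floors with int() while Get_Conv_kernel rounds up with math.ceil (a sketch):

    from model_define.mobilefacenet_y2_ljt.common_utility import Get_Conv_Size, Get_Conv_kernel

    # int(56.5) = 56 per dimension, so the flattened size is 56 * 56 = 3136.
    assert Get_Conv_Size(112, 112, (3, 3), (2, 2), (1, 1), 1) == 56 * 56

    # math.ceil(56.5) = 57, so the ceil variant returns (57, 57) for the same input.
    print(Get_Conv_kernel(112, 112, (3, 3), (2, 2), (1, 1), 1))  # prints "57 57", returns (57, 57)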
/model_define/resnet_100_ljt/__pycache__/resnet_100.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/resnet_100_ljt/__pycache__/resnet_100.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/resnet_50_ljt/__pycache__/resnet_50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/resnet_50_ljt/__pycache__/resnet_50.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/shufflefacenet_v2_ljt/__pycache__/ShuffleFaceNetV2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/shufflefacenet_v2_ljt/__pycache__/ShuffleFaceNetV2.cpython-36.pyc
--------------------------------------------------------------------------------
/mytest.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/18 11:31 AM
4 | from model_define.shufflefacenet_v2_ljt.ShuffleFaceNetV2 import ShuffleFaceNetV2
5 | import torch
6 |
7 | state_dict = torch.load('/home/linx/model/ljt/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth')
8 |
9 | for k, v in state_dict.items():
10 | print(k, v.shape)
11 | pass
12 |
13 | net = ShuffleFaceNetV2(512, '2.0', (144, 122))
14 | print(net)
--------------------------------------------------------------------------------
/pruning.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/10 2:21 PM
4 |
5 | import matplotlib
6 | import torch
7 | import distiller
8 | import os
9 | from commom_utils.utils import cal_flops, test_speed, get_time
10 | from model_define.load_state_dict import load_state_dict
11 | from test_module.test_on_diverse_dataset import test
12 | import numpy as np
13 | matplotlib.use('agg')
14 |
15 |
16 | def prune(args):
17 | model = load_state_dict(args)
18 | model = model.cuda()
19 | epoch = 0
20 |
21 | # acc = test(args, model)
22 |     # print('accuracy before pruning: {}'.format(acc))
23 |
24 | if args.fpgm:
25 |         print('Pruning with the FPGM algorithm')
26 | if args.cal_flops_and_forward:
27 | flops, params = cal_flops(model, [1, 3, 144, 122])
28 | forward_time = test_speed(model, [1, 3, 144, 122])
29 |         print('Forward time before pruning: {}ms, flops={}, params={}'.format(forward_time, flops, params))
30 | conv_dict = {}
31 |
32 | if args.hrank:
33 |
34 | cnt = 1
35 | layer_name = 'conv1.conv.weight'
36 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
37 | cnt += 1
38 | for key, value in model.block_info.items():
39 | if value == 1:
40 | layer_name = '{}.conv.conv.weight'.format(key)
41 | conv_dict[layer_name] = np.load(args.rank_path+'rank_conv'+str(cnt)+'.npy')
42 | cnt += 1
43 | layer_name = '{}.conv_dw.conv.weight'.format(key)
44 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
45 | cnt += 1
46 | layer_name = '{}.project.conv.weight'.format(key)
47 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
48 | cnt += 1
49 | else:
50 | for j in range(value):
51 | layer_name = '{}.model.{}.conv.conv.weight'.format(key,j)
52 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
53 | cnt += 1
54 | layer_name = '{}.model.{}.conv_dw.conv.weight'.format(key,j)
55 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
56 | cnt += 1
57 | layer_name = '{}.model.{}.project.conv.weight'.format(key,j)
58 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
59 | cnt += 1
60 | layer_name = 'conv_6_sep.conv.weight'
61 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
62 | cnt += 1
63 | layer_name = 'conv_6_dw.conv.weight'
64 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
65 | cnt += 1
66 | print(len(conv_dict))
67 |
68 | model.train()
69 | compression_scheduler = distiller.config.file_config(model, None, args.yaml_path)
70 | compression_scheduler.on_epoch_begin(epoch, fpgm=args.fpgm, HRank=args.hrank, conv_index=conv_dict)
71 | compression_scheduler.on_minibatch_begin(epoch, minibatch_id=0, minibatches_per_epoch=0)
72 |
73 | if args.save_model_pt:
74 | torch.save(model.state_dict(), os.path.join(args.pruned_save_model_path, 'model_{}.pt'.format(args.model)))
75 |         print('Model saved!')
76 | config_model_arr = []
77 | if args.model != 'shufflefacenet_v2_ljt':
78 | for k, v in model.state_dict().items():
79 | if len(v.shape) == 4 and k.find('downsample') == -1 and k.find('fc') == -1:
80 | config_model_arr.append(v.shape[0])
81 | else:
82 | for k, v in model.state_dict().items():
83 | if k.find('branch_main.0.weight') != -1:
84 | config_model_arr.append(v.shape[0])
85 |     print('Per-layer output channels: {}, {} layers in total'.format(config_model_arr, len(config_model_arr)))
86 |
87 | if not os.path.exists('work_space/layers_out/'):
88 | os.mkdir('work_space/layers_out/')
89 |
90 | with open('work_space/layers_out/out.txt', 'w') as f:
91 | f.write(str(config_model_arr))
92 | # print(model)
93 | acc = test(args, model)
94 |     print('Accuracy after pruning: {}'.format(acc))
95 |
96 | flops, params = cal_flops(model, [1, 3, 112, 112])
97 |     forward_time = 0  # forward-speed test skipped here
98 |     print('Forward time after pruning: {}ms, flops={}, params={}'.format(forward_time, flops, params))
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
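The per-layer output widths that prune() writes to work_space/layers_out/out.txt are a plain Python list literal (see the file later in this dump), so they can be read back when redefining the pruned network; a sketch:

    import ast

    # out.txt holds a list literal such as "[64, 64, 52, ...]"; parse it back safely.
    with open('work_space/layers_out/out.txt') as f:
        layers_out = ast.literal_eval(f.read())

    print('{} pruned layer widths, first five: {}'.format(len(layers_out), layers_out[:5]))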
/pruning_analysis_tools/plot_csv.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/9 10:03 AM
4 |
5 | import pandas as pd
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def plot(sensitivity, save_root):
11 | """
12 |     Plot a sensitivity line chart for each layer
13 |     :param sensitivity: dict mapping each layer name to its sensitivity
14 |                         values at each sparsity level
15 |     :param save_root: directory where the line charts are saved
16 | :return:
17 | """
18 | i = 0
19 | for k, v in sensitivity.items():
20 |
21 |         if i % 7 == 6:  # flush a figure every 7 layers; a trailing partial group is never saved
22 | plt.ylabel('top1')
23 | plt.xlabel('sparsity')
24 | plt.title(str(i-6) + '-' + str(i+1) + ' Pruning Sensitivity')
25 | plt.legend(loc='lower center',
26 | ncol=2, mode="expand", borderaxespad=0.)
27 | plt.savefig(save_root + '/' + str(i-6) + '-' + str(i+1) + '.png', format='png')
28 | plt.close()
29 |
30 | sense = v
31 | name = k
32 | sparsities = np.arange(0, 0.95, 0.1)
33 | plt.plot(sparsities, sense, label=name)
34 | i += 1
35 |
36 |
37 | def plot_csv(csv_name, save_root):
38 | """
39 |     Plot sensitivity curves from a CSV file
40 |     :param csv_name: CSV file name
41 |     :param save_root: directory where the images are saved
42 | :return:
43 | """
44 | data = pd.read_csv(csv_name)
45 | sensitivity = {}
46 |
47 | for x in data.values:
48 | sensitivity[x[0]] = []
49 |
50 | for x in data.values:
51 | sensitivity[x[0]].append(x[2])
52 |
53 | plot(sensitivity, save_root)
54 |
55 |
56 | if __name__ == '__main__':
57 | plot_csv('/home/yeluyue/lz/program/compression_tool/work_space/sensitivity_data/sensitivity_mobilefacenet_y2_2020-09-04-13-08.csv', '/home/yeluyue/lz/program/compression_tool/work_space/sensitivity_data')
58 |
--------------------------------------------------------------------------------
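plot_csv() indexes each CSV row positionally: x[0] is the parameter name and x[2] the accuracy. This assumes the column order written by distiller's sensitivities_to_csv (parameter, sparsity, then the metric columns); a quick way to confirm on a real file:

    import pandas as pd

    # Hypothetical file name; use any CSV produced by the sa mode.
    df = pd.read_csv('work_space/sensitivity_data/example.csv')
    print(df.columns.tolist())               # expected to start with the parameter and sparsity columns
    print(df.values[0][0], df.values[0][2])  # layer name and its accuracy at the first sparsity level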
/quantization.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/12 1:38 PM
4 | import torch
5 | from distiller.quantization.range_linear import PostTrainLinearQuantizer
6 | from model_define.load_state_dict import load_state_dict
7 | from test_module.test_on_diverse_dataset import test
8 | from commom_utils.utils import get_time
9 | import os
10 |
11 |
12 | def quantization(args):
13 | model = load_state_dict(args)
14 | model = model.to(args.device)
15 |
16 | acc = test(args, model)
17 |     print('Accuracy before quantization: {}'.format(acc))
18 |
19 | quantizer = PostTrainLinearQuantizer(model, fp16=args.fp16)
20 |
21 | quantizer.prepare_model(torch.rand([1, 3, args.input_size, args.input_size]))
22 |
23 | torch.save(model.state_dict(), os.path.join(args.quantized_save_model_path, args.model + get_time() + '.pt'))
24 |     print('Model saved')
25 |
26 | acc = test(args, model)
27 |     print('Accuracy after quantization: {}'.format(acc))
28 |
--------------------------------------------------------------------------------
/quantization.sh:
--------------------------------------------------------------------------------
1 | # Quantization run script
2 | python main.py --mode quantization \
3 | --model resnet100 \
4 | --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \
5 | --from_data_parallel \
6 | --data_source company \
7 | --test_root_path data/test_data/fc_0.4_112x112 \
8 | --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \
9 | --test_batch_size 256 \
10 | --fp16
11 |
12 |
--------------------------------------------------------------------------------
/resnet_ljt.sh:
--------------------------------------------------------------------------------
1 | # ================================ResNet-50==============================================================================
2 |
3 | # TYLG
4 | python main.py --mode sa \
5 | --model resnet_50_ljt \
6 | --best_model_path /home/linx/model/ljt/2020-08-11-22-35_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_fResNet50v3cv-d512_model_iter-76608_TYLG-0.8070_PadMaskYTBYGlassM280-0.9305_BusIDPhoto-0.6541.pth \
7 | --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_144x122 \
8 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
9 | --data_source company \
10 | --fpgm \
11 | --from_data_parallel
12 |
13 | python main.py --mode prune \
14 | --model resnet_50_ljt \
15 | --save_model_pt \
16 | --data_source company \
17 | --fpgm \
18 | --best_model_path /home/linx/model/ljt/2020-06-26-12-13_CombineMargin-ljt-m0.9m0.4m0.15s64_le_re_0.4_112x112_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet50v3cv-d512_model_iter-113680_Idoa-0.8011_IdoaMask-0.8325_TYLG-0.8443.pth \
19 | --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_112x112 \
20 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
21 | --from_data_parallel
22 |
23 | # ================================ResNet-100==============================================================================
24 | python main.py --mode sa \
25 | --model resnet_100_ljt \
26 | --best_model_path /home/linx/model/ljt/2020-06-27-12-59_CombineMargin-zk-O1D1Ls-m0.9m0.4m0.15s64_fc_0.4_144x122_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet100v3cv-d512_model_iter-96628_Idoa-0.8996_IdoaMask-0.9127_TYLG-0.9388.pth \
27 | --test_root_path /media/linx/B0C6A127C6A0EF32/200914_data_model_ljt/TYLG/fc_0.4_144x122 \
28 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
29 | --data_source company \
30 | --fpgm \
31 | --from_data_parallel
32 |
33 |
34 | python main.py --mode prune \
35 | --model resnet_100_ljt \
36 | --save_model_pt \
37 | --data_source company \
38 | --fpgm \
39 | --best_model_path /home/linx/model/ljt/2020-06-27-12-59_CombineMargin-zk-O1D1Ls-m0.9m0.4m0.15s64_fc_0.4_144x122_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet100v3cv-d512_model_iter-96628_Idoa-0.8996_IdoaMask-0.9127_TYLG-0.9388.pth \
40 | --test_root_path /home/linx/dataset/company_test_data/TYLG/fc_0.4_144x122 \
41 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
42 | --from_data_parallel
--------------------------------------------------------------------------------
/sensitivity_analysis.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/9 2:17 PM
4 |
5 | import torch
6 | import os
7 | from collections import OrderedDict
8 | from copy import deepcopy
9 | import distiller
10 | from distiller.scheduler import CompressionScheduler
11 | import numpy as np
12 | from model_define.load_state_dict import load_state_dict
13 | import time
14 | from test_module.test_on_diverse_dataset import test
15 | from datetime import datetime
16 |
17 |
18 | def get_time():
19 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')
20 |
21 |
22 | def perform_sensitivity_analysis(model, net_params, sparsities, args):
23 |
24 | sensitivities = OrderedDict()
25 |     print('Testing the accuracy of the original model')
26 | accuracy = test(args, model)
27 |
28 |     print('Original model accuracy: {}'.format(accuracy))
29 |
30 | if args.fpgm:
31 |         print('Generating the sensitivity curves with geometric-median (FPGM) pruning')
32 | conv_dict = {}
33 | if args.hrank:
34 |         print('Pruning with HRank')
35 | cnt = 1
36 | layer_name = 'conv1.conv.weight'
37 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
38 | cnt += 1
39 | for key, value in model.block_info.items():
40 | if value == 1:
41 | layer_name = '{}.conv.conv.weight'.format(key)
42 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
43 | cnt += 1
44 | layer_name = '{}.conv_dw.conv.weight'.format(key)
45 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
46 | cnt += 1
47 | layer_name = '{}.project.conv.weight'.format(key)
48 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
49 | cnt += 1
50 | else:
51 | for j in range(value):
52 | layer_name = '{}.model.{}.conv.conv.weight'.format(key, j)
53 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
54 | cnt += 1
55 | layer_name = '{}.model.{}.conv_dw.conv.weight'.format(key, j)
56 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
57 | cnt += 1
58 | layer_name = '{}.model.{}.project.conv.weight'.format(key, j)
59 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
60 | cnt += 1
61 | layer_name = 'conv_6_sep.conv.weight'
62 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
63 | cnt += 1
64 | layer_name = 'conv_6_dw.conv.weight'
65 | conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
66 | cnt += 1
67 | print(len(conv_dict))
68 |
69 | for param_name in net_params:
70 | if model.state_dict()[param_name].dim() not in [4]:
71 | continue
72 |
73 | model_cpy = deepcopy(model)
74 |
75 | sensitivity = OrderedDict()
76 |
77 |         # Prune each layer at every sparsity level and test the accuracy (0.0 -> 0.9)
78 | for sparsity_level in sparsities:
79 |
80 | sparsity_level = float(sparsity_level)
81 |
82 | print(param_name, sparsity_level)
83 |
84 | pruner = distiller.pruning.L1RankedStructureParameterPruner("sensitivity",
85 | group_type="Filters",
86 | desired_sparsity=sparsity_level,
87 | weights=param_name)
88 |
89 | policy = distiller.PruningPolicy(pruner, pruner_args=None)
90 | scheduler = CompressionScheduler(model_cpy)
91 | scheduler.add_policy(policy, epochs=[0])
92 |
93 | scheduler.on_epoch_begin(0, fpgm=args.fpgm, HRank=args.hrank, conv_index=conv_dict)
94 |
95 | scheduler.mask_all_weights()
96 |
97 | accuracy = test(args, model_cpy)
98 |
99 |             print('Accuracy after pruning at sparsity {}: {}'.format(sparsity_level, accuracy))
100 |
101 | sensitivity[sparsity_level] = (accuracy, 0, 0)
102 | sensitivities[param_name] = sensitivity
103 |
104 | return sensitivities
105 |
106 |
107 | def sensitivity_analysis(args):
108 |
109 | model = load_state_dict(args)
110 | model.eval()
111 | model.cuda()
112 |
113 |     sparsities = np.arange(0.0, 0.95, 0.1)  # sparsity levels 0.0, 0.1, ..., 0.9
114 | which_params = [param_name for param_name, _ in model.named_parameters()]
115 |
116 | start_time = time.time()
117 |
118 |     sensitivity = perform_sensitivity_analysis(model, which_params, sparsities, args)
119 |
120 | end_time = time.time()
121 |     print('Pruning sensitivity analysis took {}h in total'.format((end_time - start_time) / 3600))
122 | # distiller.sensitivities_to_png(sensitivity, 'work_space/sensitivity_data/sensitivity_{}.png'.format(args.model))
123 | distiller.sensitivities_to_csv(sensitivity, os.path.join(args.sensitivity_csv_path, 'sensitivity_{}_{}.csv'.format(args.model, get_time())))
124 |
125 |
126 |
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
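Since perform_sensitivity_analysis() skips every parameter whose tensor is not 4-D, the candidate list can equivalently be filtered up front; a one-line sketch that could replace the named_parameters() comprehension above:

    # Keep only 4-D (conv filter) weights as pruning candidates.
    which_params = [name for name, p in model.named_parameters() if p.dim() == 4]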
/src/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Data loading from an image list
3 | 1. no Normalize transform
4 | 2. images are loaded with OpenCV
5 | 3. pixel values are not divided by 255
6 | @19-10-21 mod by ljt: add data augmentation
7 | @19-11-29 mod by zkx: moved the preprocessing pipelines into a dict for lookup
8 | '''
9 |
10 | import os
11 | from src.loader import transforms_V2 as trans
12 | from src.loader.autoaugment import ImageNetPolicy
13 | from torch.utils.data import DataLoader
14 | from src.loader.utility import opencv_loader, read_image_list_test
15 | from src.loader.read_image_list_io import DatasetFromList, DatasetFromListTriplet
16 | import torch
17 |
18 | pwd_path = os.path.dirname(__file__)
19 | from src.dataset import TrainSet,TestSet
20 |
21 |
22 | class DefineTrans:
23 | def __init__(self, input_size):
24 | self.image_preprocess = {
25 | 'D1':trans.Compose([
26 | trans.ToPILImage(),
27 | trans.RandomResizedCrop(input_size, scale=(0.99, 1.01)),
28 | trans.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),
29 | trans.RandomHorizontalFlip(),
30 | trans.ToTensor(),
31 | ]),
32 | 'D2':trans.Compose([
33 | trans.ToPILImage(),
34 | trans.RandomResizedCrop(input_size, scale=(0.9, 1.1)),
35 | trans.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
36 | trans.RandomHorizontalFlip(),
37 | trans.ToTensor(),
38 | ]),
39 | 'D3': trans.Compose([
40 | trans.ToPILImage(),
41 | trans.RandomResizedCrop(input_size, scale=(0.9, 1.1)),
42 | trans.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
43 | trans.RandomRotation(5),
44 | trans.RandomHorizontalFlip(),
45 | trans.ToTensor(),
46 | ]),
47 | 'D9':trans.Compose([
48 | trans.ToPILImage(),
49 | trans.RandomHorizontalFlip(),
50 | ImageNetPolicy(),
51 | trans.ToTensor(),
52 | ]),
53 | 'None':trans.Compose([
54 | trans.ToTensor(),
55 | ])
56 |
57 | # others
58 | # trans.RandomPerspective(),
59 | # trans.RandomErasing(),
60 | # trans.RandomRotation(5),
61 | # trans.RandomApply([trans.RandomResizedCrop((width, height), scale=(0.99, 1.01))]),
62 | # trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
63 | }
64 |
65 | # Data loading for the training set
66 | def get_train_dataset(img_root_path, image_list_path, data_aug):
67 | height = int(img_root_path.split('_')[-1].split('x')[0])
68 | width = int(img_root_path.split('_')[-1].split('x')[1])
69 |
70 | # @2019-11-29 zkx get input image process method from dict
71 | input_process = DefineTrans((height, width))
72 | if data_aug is None:
73 | train_transform = input_process.image_preprocess['None']
74 | else:
75 | train_transform = input_process.image_preprocess[data_aug]
76 |
77 | ds = DatasetFromList(img_root_path, image_list_path, opencv_loader, train_transform, None)
78 | class_num = ds[-1][1] + 1
79 | return ds, class_num
80 |
81 |
82 | def get_train_list_loader(conf):
83 | train_set = conf.data_mode
84 | img_root_path = TrainSet[train_set]['root_path']
85 | img_label_list = TrainSet[train_set]['label_list']
86 |
87 | print("img_label_list")
88 | print(img_label_list)
89 |
90 | patch_info = conf.patch_info
91 | root_path = '{}/{}'.format(img_root_path, patch_info)
92 | celeb_ds, celeb_class_num = get_train_dataset(root_path, img_label_list, conf.data_aug)
93 | print("images:")
94 | print(len(celeb_ds))
95 | ds = celeb_ds
96 | class_num = celeb_class_num
97 | if conf.distributed:
98 | train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
99 | else:
100 | train_sampler = None
101 |     # @2020-02-25 ljt: added drop_last to fix the BatchNorm error on the last step
102 | loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=(train_sampler is None), pin_memory=conf.pin_memory,
103 | num_workers=conf.num_workers, sampler=train_sampler,drop_last=True)
104 | return loader, class_num, train_sampler
105 |
106 |
107 | # Data loading for the test set
108 | def get_test_dataset(img_root_path, image_list_path):
109 |
110 | # @2019-11-29 zkx get input image process method from dict
111 | input_process = DefineTrans((1, 1))
112 | test_transform = input_process.image_preprocess["None"]
113 | ds = DatasetFromList(img_root_path, image_list_path, opencv_loader, test_transform, None, read_image_list_test)
114 | return ds
115 |
116 | def get_batch_test_data(image_root_path, image_list_path, batch_size, num_workers):
117 | ds = get_test_dataset(image_root_path, image_list_path)
118 | loader = DataLoader(ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
119 | return loader
120 |
121 |
--------------------------------------------------------------------------------
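A minimal sketch of pulling a test loader out of this module (paths are hypothetical; because the test dataset uses read_image_list_test, each batch pairs image tensors with their path prefixes):

    from src.data_loader import get_batch_test_data

    loader = get_batch_test_data('data/test_data/example_112x112',   # hypothetical image root
                                 'data/test_data/example_list.txt',  # hypothetical image list
                                 batch_size=64, num_workers=4)

    for images, prefixes in loader:
        print(images.shape, prefixes[0])
        break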
/src/loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__init__.py
--------------------------------------------------------------------------------
/src/loader/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/autoaugment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/autoaugment.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/autoaugment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/autoaugment.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/functional.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/functional_V2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional_V2.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/functional_V2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional_V2.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/read_image_list_io.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/read_image_list_io.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/read_image_list_io.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/read_image_list_io.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms_V2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms_V2.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms_V2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms_V2.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/utility.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/utility.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/utility.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/utility.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/read_image_list_io.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | '''
3 | What DatasetFromList does:
4 | 1. Reads image paths and labels from a list file
5 | 2. Overrides __getitem__ so instances can be indexed with [] to return (image, label)
6 | 3. Bundling path lookup and image loading into one class neatly decouples the two steps
7 | '''
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | from .utility import read_image_list,read_image_triplet_list
11 |
12 | class DatasetFromList(Dataset):
13 | """A generic data loader where the image list arrange in this way: ::
14 |
15 | class_x/xxx.ext 0
16 | class_x/xxy.ext 0
17 | class_x/xxz.ext 0
18 |
19 | class_y/123.ext 1
20 | class_y/nsdf3.ext 1
21 | class_y/asd932_.ext 1
22 |
23 | Args:
24 | root (string): Root directory path.
25 | image_list_path (string) : where to load image list
26 | loader (callable): A function to load a sample given its path.
27 | image_list_loader (callable) : A function to read image-label pair or image-image-prefix pair
28 | transform (callable, optional): A function/transform that takes in
29 | a sample and returns a transformed version.
30 | E.g, ``transforms.RandomCrop`` for images.
31 | target_transform (callable, optional): A function/transform that takes
32 | in the target and transforms it.
33 |
34 | Attributes:
35 | samples (list): List of (sample path, image_index) tuples
36 | """
37 |
38 | def __init__(self, root, image_list_path, loader ,
39 | transform=None, target_transform=None, image_list_loader = read_image_list):
40 | samples = image_list_loader(root, image_list_path)
41 | if len(samples) == 0:
42 | raise(RuntimeError("Found 0 files in image_list: " + image_list_path + "\n"))
43 |
44 | self.root = root
45 | self.loader = loader
46 | self.samples = samples
47 | self.transform = transform
48 | self.target_transform = target_transform
49 |
50 | def __getitem__(self, index):
51 | """
52 | Args:
53 | index (int): Index
54 |
55 | Variable:
56 | self.samples (list): [image path, image label]
57 |
58 | Returns:
59 | tuple: (sample, target) where target is class_index of the target class.
60 | """
61 | path, target = self.samples[index]
62 | import os
63 | if not os.path.exists(path):
64 | print("Not exists..", path)
65 |
66 | sample = self.loader(path)
67 | assert isinstance(sample, np.ndarray) , path
68 |
69 | if self.transform is not None:
70 | sample = self.transform(sample)
71 | if self.target_transform is not None:
72 | target = self.target_transform(target)
73 |
74 | return sample, target
75 |
76 | def __len__(self):
77 | return len(self.samples)
78 |
79 | def __repr__(self):
80 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
81 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
82 | fmt_str += ' Root Location: {}\n'.format(self.root)
83 | tmp = ' Transforms (if any): '
84 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
85 | tmp = ' Target Transforms (if any): '
86 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
87 | return fmt_str
88 |
89 |
90 | class DatasetFromListTriplet(DatasetFromList):
91 | def __init__(self, root, image_list_path, loader ,
92 | transform=None, target_transform=None, image_list_loader = read_image_triplet_list):
93 | super(DatasetFromListTriplet,self).__init__(root, image_list_path, loader, transform,
94 | target_transform, image_list_loader)
95 |
96 | def __getitem__(self, index):
97 | """
98 | Args:
99 | index (tuple): includes (key, image_index)
100 | Variable:
101 | self.samples (dictionary): key(int) image_label, value(list) same label's image path
102 |
103 | Returns:
104 | tuple: (sample, target) where target is class_index of the target class.
105 | """
106 | target = index[0]
107 | image_path = self.samples[index[0]][index[1]]
108 | sample = self.loader(image_path)
109 |
110 | if self.transform is not None:
111 | sample = self.transform(sample)
112 | if self.target_transform is not None:
113 | target = self.target_transform(target)
114 |
115 | return sample, target
116 |
117 | def get_class_num(self):
118 | return len(self.samples)
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
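A short sketch of instantiating DatasetFromList directly, using the OpenCV loader and transforms from src/loader (paths are hypothetical; the list file holds "class_x/xxx.ext 0" lines as in the docstring):

    from src.loader.read_image_list_io import DatasetFromList
    from src.loader.utility import opencv_loader
    from src.loader import transforms_V2 as trans

    ds = DatasetFromList('data/train_data/example_root',      # hypothetical root
                         'data/train_data/example_list.txt',  # hypothetical list file
                         opencv_loader,
                         transform=trans.Compose([trans.ToTensor()]))

    img, label = ds[0]  # __getitem__ returns (image tensor, class index)
    print(len(ds), img.shape, label)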
/src/loader/utility.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import cv2
3 |
4 | def read_image_list(root_path, image_list_path):
5 | f = open(image_list_path, 'r')
6 | # f = open(image_list_path, encoding='utf-8', mode='r')
7 | data = f.read().splitlines()
8 | f.close()
9 |
10 | samples = []
11 | for line in data:
12 | sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
13 | class_index = int(line.split(' ')[1])
14 | samples.append((sample_path, class_index))
15 | return samples
16 |
17 | def read_image_triplet_list(root_path, image_list_path):
18 | f = open(image_list_path, 'r')
19 | data = f.read().splitlines()
20 | f.close()
21 |
22 | label_num = int(data[-1].split(' ')[-1])+1
23 | samples = {}
24 | for i in range(label_num):
25 | samples[i] = []
26 |
27 | for line in data:
28 | sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
29 | class_index = int(line.split(' ')[1])
30 | samples[class_index].append(sample_path)
31 | return samples
32 |
33 |
34 | def read_image_list_test(root_path, image_list_path):
35 | f = open(image_list_path, 'r')
36 | data = f.read().splitlines()
37 | f.close()
38 |
39 | samples = []
40 | for line in data:
41 | sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
42 | image_prefix = line.split(' ')[0]
43 | samples.append((sample_path, image_prefix))
44 | return samples
45 |
46 |
47 | def pil_loader(path):
48 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
49 | with open(path, 'rb') as f:
50 | img = Image.open(f)
51 | return img.convert('RGB')
52 | #return img.convert('BGR')
53 |
54 | def opencv_loader(path):
55 | img = cv2.imread(path)
56 | return img
57 |
58 |
59 | # def accimage_loader(path):
60 | # import accimage
61 | # try:
62 | # return accimage.Image(path)
63 | # except IOError:
64 | # # Potentially a decoding problem, fall back to PIL.Image
65 | # return pil_loader(path)
66 |
67 |
68 | # def default_loader(path):
69 | # from torchvision import get_image_backend
70 | # if get_image_backend() == 'accimage':
71 | # return accimage_loader(path)
72 | # else:
73 | # return pil_loader(path)
--------------------------------------------------------------------------------
/test_class.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/29 8:19 PM
4 | import torch
5 | import torch.nn as nn
6 | from tensorboardX import SummaryWriter
7 | from MobileNetV3 import MobileNetV3_Large
8 |
9 | writer = SummaryWriter('/home/user1/linx/program/LightFaceNet/work_space/log')
10 |
11 | model = MobileNetV3_Large(4)
12 |
13 | model = nn.DataParallel(model)
14 |
15 | model.load_state_dict(torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/model_train_best/2019-10-12'
16 | '-16-04_LiveBody_le_0.2_80x80_fake-20190924-train-data_live-0926_MobileNetv3Large'
17 | '-c4_pytorch_iter_14000.pth'))
18 |
19 | for name, param in model.named_parameters():
20 | if len(param.shape) == 4:
21 | param = param.view(param.shape[0], -1)
22 | param = torch.norm(param, dim=1)
23 | print(param.shape)
24 | writer.add_histogram(name, param.clone().cpu().data.numpy())
25 |
26 | # from scipy.spatial import distance
27 | # import numpy as np
28 | #
29 | # a = np.array([[1, 2], [3, 4]])
30 | # b = np.array([[2, 3], [4, 5]])
31 | # print(distance.cdist(a, a, metric='euclidean'))
32 |
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_diverse_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_diverse_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_diverse_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_diverse_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_classification.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_classification.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_classification.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_classification.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_recognition.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_recognition.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_recognition.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_recognition.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_with_insight_face.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_with_insight_face.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/test_on_diverse_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/12 2:02 PM
4 | from test_module.test_on_face_recognition import TestOnFaceRecognition
5 | from test_module.test_on_face_classification import TestOnFaceClassification
6 | from test_module.test_with_insight_face import TestWithInsightFace
7 |
8 |
9 | def test(args, model):
10 |     if args.model in ('mobilenetv3', 'mobilefacenet_lzc', 'resnet34_lzc') and args.data_source != 'lfw':
11 | test = TestOnFaceClassification(model, args.test_root_path, args.img_list_label_path)
12 | acc = test.test(args.test_batch_size)
13 | return acc
14 | elif args.model == 'resnet50_imagenet' or args.data_source == 'lfw':
15 | test = TestWithInsightFace(model)
16 | agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame = test.get_val_data(args.test_root_path)
17 | acc_lfw = test.test(lfw, lfw_issame, args.test_batch_size, args.device)
18 | # acc_cfp = test.test(cfp_fp, cfp_fp_issame, args.test_batch_size, args.device)
19 | # acc_agedb = test.test(agedb_30, agedb_30_issame, args.test_batch_size, args.device)
20 | return acc_lfw
21 | else:
22 | test = TestOnFaceRecognition(model, args.test_root_path, args.img_list_label_path, args.data_source)
23 | accuracy = test.test(args.test_batch_size)
24 | return accuracy
--------------------------------------------------------------------------------
/test_module/test_on_face_classification.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/1 6:14 PM
4 | import cv2
5 | import torch
6 | import numpy as np
7 | from sklearn.metrics import roc_curve, auc
8 | import torch.nn.functional as F
9 | from dataloader import Face_Classification_Data
10 | from tqdm import tqdm
11 |
12 |
13 | def opencv_loader(path):
14 | img = cv2.imread(path, -1)
15 | return img
16 |
17 |
18 | class TestOnFaceClassification:
19 |
20 |     def __init__(self, model, root_path, list_label, device=0, t=1e-3):
21 |         self.model = model
22 |         self.root_path = root_path
23 |         self.list_label = list_label
24 |         self.device = device
25 |         self.t = t
28 |
29 | def test(self, batch_size):
30 | self.model.eval()
31 | self.model.to(self.device)
32 | data_loader = Face_Classification_Data(self.root_path, self.list_label, batch_size).test_loader
33 | file = open(self.list_label, 'r')
34 |         lines = file.readlines()  # read but currently unused
35 | file.close()
36 |
37 | label_list = []
38 | score_list = []
39 | with torch.no_grad():
40 | for img, label in tqdm(data_loader):
41 | input = img.to(self.device)
42 | out = self.model(input)
43 | out = F.softmax(out, dim=1)
44 |
45 | for idx in range(len(out)):
46 |
47 | label_list.append(label.cpu().numpy()[idx])
48 | score_list.append(out[idx].cpu().numpy()[1])
49 |
50 | fpr, tpr, thresholds = roc_curve(np.array(label_list), np.array(score_list), pos_label=1)
51 | fpr = np.around(fpr, decimals=7)
52 | index = np.argmin(abs(fpr - self.t))
53 | index_all = np.where(fpr == fpr[index])
54 | max_acc = np.max(tpr[index_all])
55 |
56 | return max_acc
57 |
58 |
--------------------------------------------------------------------------------
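The "accuracy" returned by this tester is really the true-positive rate at the operating point whose false-positive rate is closest to t (1e-3 by default). A toy reproduction of that thresholding with nothing beyond sklearn:

    import numpy as np
    from sklearn.metrics import roc_curve

    labels = np.array([0, 0, 0, 1, 1, 1])
    scores = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.90])

    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    index = np.argmin(abs(fpr - 1e-3))       # FPR bin closest to the target t
    index_all = np.where(fpr == fpr[index])  # all thresholds sharing that FPR
    print(np.max(tpr[index_all]))            # best TPR at that false-positive rate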
/train_module/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/1 下午1:43
4 |
--------------------------------------------------------------------------------
/train_module/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/train_module/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/train_module/__pycache__/train_with_insight_face.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/train_module/__pycache__/train_with_insight_face.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/finetune/2020-09-08-09-18/log/2020-09-08-09-18/events.out.tfevents.1599527940.yeluyue:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/finetune/2020-09-08-09-18/log/2020-09-08-09-18/events.out.tfevents.1599527940.yeluyue
--------------------------------------------------------------------------------
/work_space/layers_out/out.txt:
--------------------------------------------------------------------------------
1 | [64, 64, 52, 64, 52, 58, 52, 128, 116, 116, 116, 128, 116, 128, 116, 231, 231, 231, 231, 231, 231, 256, 231, 256, 231, 231, 231, 205, 231, 256, 231, 256, 231, 256, 231, 256, 231, 256, 231, 256, 231, 205, 231, 461, 461, 512, 461, 512, 461]
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_mobilefacenet_y2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_mobilefacenet_y2.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_resnet50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_resnet50.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_resnet50_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_resnet50_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/sensitivity_data/resnet50_imagenet_0.9773/L1Rank/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 | net_thinner:
3 | arch: None
4 | class: FilterRemover
5 | dataset: 112x112
6 | thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 | - 0
10 | pruner:
11 | instance_name: filter_pruner_10
12 | - epochs:
13 | - 0
14 | pruner:
15 | instance_name: filter_pruner_20
16 | - epochs:
17 | - 0
18 | pruner:
19 | instance_name: filter_pruner_30
20 | - epochs:
21 | - 0
22 | pruner:
23 | instance_name: filter_pruner_50
24 | - epochs:
25 | - 0
26 | pruner:
27 | instance_name: filter_pruner_60
28 | - epochs:
29 | - 0
30 | pruner:
31 | instance_name: filter_pruner_70
32 | - epochs:
33 | - 0
34 | pruner:
35 | instance_name: filter_pruner_90
36 | - epochs:
37 | - 0
38 | extension:
39 | instance_name: net_thinner
40 | pruners:
41 | filter_pruner_10:
42 | class: L1RankedStructureParameterPruner
43 | desired_sparsity: 0.1
44 | group_type: Filters
45 | weights:
46 | - layer4.0.conv1.weight
47 | filter_pruner_20:
48 | class: L1RankedStructureParameterPruner
49 | desired_sparsity: 0.2
50 | group_type: Filters
51 | weights:
52 | - conv1.weight
53 | - layer2.0.conv2.weight
54 | - layer2.1.conv2.weight
55 | - layer2.2.conv2.weight
56 | - layer2.3.conv2.weight
57 | - layer3.0.conv2.weight
58 | - layer3.1.conv2.weight
59 | - layer3.2.conv2.weight
60 | - layer3.3.conv2.weight
61 | - layer3.4.conv2.weight
62 | - layer3.5.conv2.weight
63 | - layer4.0.conv2.weight
64 | - layer4.1.conv2.weight
65 | - layer4.2.conv2.weight
66 | filter_pruner_30:
67 | class: L1RankedStructureParameterPruner
68 | desired_sparsity: 0.3
69 | group_type: Filters
70 | weights:
71 | - layer3.3.conv1.weight
72 | - layer3.5.conv1.weight
73 | filter_pruner_50:
74 | class: L1RankedStructureParameterPruner
75 | desired_sparsity: 0.5
76 | group_type: Filters
77 | weights:
78 | - layer3.0.conv1.weight
79 | - layer4.1.conv1.weight
80 | filter_pruner_60:
81 | class: L1RankedStructureParameterPruner
82 | desired_sparsity: 0.6
83 | group_type: Filters
84 | weights:
85 | - layer2.2.conv1.weight
86 | - layer3.1.conv1.weight
87 | - layer3.2.conv1.weight
88 | - layer3.4.conv1.weight
89 | - layer4.2.conv1.weight
90 | filter_pruner_70:
91 | class: L1RankedStructureParameterPruner
92 | desired_sparsity: 0.7
93 | group_type: Filters
94 | weights:
95 | - layer1.0.conv1.weight
96 | - layer1.1.conv1.weight
97 | - layer1.2.conv1.weight
98 | - layer2.0.conv1.weight
99 | - layer2.1.conv1.weight
100 | - layer2.3.conv1.weight
101 | filter_pruner_90:
102 | class: L1RankedStructureParameterPruner
103 | desired_sparsity: 0.9
104 | group_type: Filters
105 | weights:
106 | - layer1.0.conv2.weight
107 | - layer1.1.conv2.weight
108 | - layer1.2.conv2.weight
109 | version: 1
110 |
--------------------------------------------------------------------------------
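Note: the schedule above is a one-shot Distiller pruning recipe. Layers are bucketed into filter_pruner_XX pruners by desired sparsity (from the L1Rank sensitivity scan), every policy fires at epoch 0, and the net_thinner FilterRemover extension physically removes the masked filters. A minimal sketch of how such a file is consumed, assuming the vendored Distiller's file_config API; the repo's main script may sequence these calls differently:

import torch
import distiller

def one_shot_prune(model, yaml_path):
    # file_config parses the YAML into a CompressionScheduler; it needs an
    # optimizer handle even for one-shot pruning (hyperparameters don't matter here).
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = distiller.file_config(model, optimizer, yaml_path)
    # All policies in the file are bound to epoch 0: the pruners rank filters
    # and set zero-masks, and the net_thinner extension thins the network.
    scheduler.on_epoch_begin(0)
    scheduler.mask_all_weights()  # zero the masked weights in place
    return model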
/work_space/sensitivity_data/resnet50_imagenet_0.9773/fpgm/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 | net_thinner:
3 | arch: None
4 | class: FilterRemover
5 | dataset: 112x112
6 | thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 | - 0
10 | pruner:
11 | instance_name: filter_pruner_10
12 | - epochs:
13 | - 0
14 | pruner:
15 | instance_name: filter_pruner_20
16 | - epochs:
17 | - 0
18 | pruner:
19 | instance_name: filter_pruner_30
20 | - epochs:
21 | - 0
22 | pruner:
23 | instance_name: filter_pruner_40
24 | - epochs:
25 | - 0
26 | pruner:
27 | instance_name: filter_pruner_50
28 | - epochs:
29 | - 0
30 | pruner:
31 | instance_name: filter_pruner_60
32 | - epochs:
33 | - 0
34 | pruner:
35 | instance_name: filter_pruner_70
36 | - epochs:
37 | - 0
38 | pruner:
39 | instance_name: filter_pruner_90
40 | - epochs:
41 | - 0
42 | extension:
43 | instance_name: net_thinner
44 | pruners:
45 | filter_pruner_10:
46 | class: L1RankedStructureParameterPruner
47 | desired_sparsity: 0.1
48 | group_type: Filters
49 | weights:
50 | - layer4.0.conv1.weight
51 | filter_pruner_20:
52 | class: L1RankedStructureParameterPruner
53 | desired_sparsity: 0.2
54 | group_type: Filters
55 | weights:
56 | - conv1.weight
57 | - layer2.0.conv2.weight
58 | - layer2.1.conv2.weight
59 | - layer2.2.conv2.weight
60 | - layer2.3.conv2.weight
61 | - layer3.0.conv2.weight
62 | - layer3.1.conv2.weight
63 | - layer3.2.conv2.weight
64 | - layer3.3.conv2.weight
65 | - layer3.4.conv2.weight
66 | - layer3.5.conv2.weight
67 | - layer4.0.conv2.weight
68 | - layer4.1.conv2.weight
69 | - layer4.2.conv2.weight
70 | filter_pruner_30:
71 | class: L1RankedStructureParameterPruner
72 | desired_sparsity: 0.3
73 | group_type: Filters
74 | weights:
75 | - layer3.3.conv1.weight
76 | - layer3.5.conv1.weight
77 | filter_pruner_40:
78 | class: L1RankedStructureParameterPruner
79 | desired_sparsity: 0.4
80 | group_type: Filters
81 | weights:
82 | - layer3.0.conv1.weight
83 | filter_pruner_50:
84 | class: L1RankedStructureParameterPruner
85 | desired_sparsity: 0.5
86 | group_type: Filters
87 | weights:
88 | - layer3.2.conv1.weight
89 | - layer4.1.conv1.weight
90 | filter_pruner_60:
91 | class: L1RankedStructureParameterPruner
92 | desired_sparsity: 0.6
93 | group_type: Filters
94 | weights:
95 | - layer3.4.conv1.weight
96 | filter_pruner_70:
97 | class: L1RankedStructureParameterPruner
98 | desired_sparsity: 0.7
99 | group_type: Filters
100 | weights:
101 | - layer1.0.conv1.weight
102 | - layer1.1.conv1.weight
103 | - layer1.2.conv1.weight
104 | - layer2.0.conv1.weight
105 | - layer2.1.conv1.weight
106 | - layer2.2.conv1.weight
107 | - layer2.3.conv1.weight
108 | - layer3.1.conv1.weight
109 | - layer4.2.conv1.weight
110 | filter_pruner_90:
111 | class: L1RankedStructureParameterPruner
112 | desired_sparsity: 0.9
113 | group_type: Filters
114 | weights:
115 | - layer1.0.conv2.weight
116 | - layer1.1.conv2.weight
117 | - layer1.2.conv2.weight
118 | version: 1
119 |
--------------------------------------------------------------------------------
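Note: despite the fpgm/ directory name, this schedule still declares class: L1RankedStructureParameterPruner; only the per-layer sparsity assignment differs from the L1Rank file (it adds a filter_pruner_40 bucket for layer3.0.conv1 and shifts several layer2/layer3 conv1 weights between buckets). A small sketch, assuming only PyYAML, that flattens two such schedules into {weight: sparsity} maps for a layer-by-layer comparison; the paths in the usage comment are the files dumped above:

import yaml

def layer_sparsities(path):
    """Map each pruned weight tensor to its desired_sparsity."""
    with open(path) as f:
        sched = yaml.safe_load(f)
    out = {}
    for pruner in sched.get("pruners", {}).values():
        for w in pruner.get("weights", []):
            out[w] = pruner["desired_sparsity"]
    return out

# l1 = layer_sparsities("work_space/sensitivity_data/resnet50_imagenet_0.9773/L1Rank/auto_yaml.yaml")
# fpgm = layer_sparsities("work_space/sensitivity_data/resnet50_imagenet_0.9773/fpgm/auto_yaml.yaml")
# for name in sorted(set(l1) | set(fpgm)):
#     print(f"{name:36s} L1={l1.get(name)}  FPGM={fpgm.get(name)}")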
/yaml_file/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 | net_thinner:
3 | arch: None
4 | class: FilterRemover
5 | dataset: 112x112
6 | thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 | - 0
10 | pruner:
11 | instance_name: filter_pruner_10
12 | - epochs:
13 | - 0
14 | pruner:
15 | instance_name: filter_pruner_20
16 | - epochs:
17 | - 0
18 | extension:
19 | instance_name: net_thinner
20 | pruners:
21 | filter_pruner_10:
22 | class: L1RankedStructureParameterPruner
23 | desired_sparsity: 0.1
24 | group_type: Filters
25 | weights:
26 | # - conv1.weight
27 | - layer1.2.conv1.weight
28 | - layer2.0.conv2.weight
29 | - layer2.0.downsample.0.weight
30 | - layer2.1.conv1.weight
31 | - layer2.1.conv2.weight
32 | - layer2.2.conv2.weight
33 | - layer2.3.conv2.weight
34 | - layer3.0.conv1.weight
35 | - layer3.0.conv2.weight
36 | - layer3.0.downsample.0.weight
37 | - layer3.1.conv1.weight
38 | - layer3.1.conv2.weight
39 | - layer3.2.conv1.weight
40 | - layer3.2.conv2.weight
41 | - layer3.3.conv2.weight
42 | - layer3.4.conv2.weight
43 | - layer3.5.conv1.weight
44 | - layer3.5.conv2.weight
45 | - layer3.6.conv2.weight
46 | - layer3.7.conv2.weight
47 | - layer3.8.conv2.weight
48 | - layer3.9.conv2.weight
49 | - layer3.10.conv2.weight
50 | - layer3.11.conv2.weight
51 | - layer3.12.conv2.weight
52 | - layer3.13.conv2.weight
53 | - layer4.0.conv1.weight
54 | - layer4.0.conv2.weight
55 | - layer4.0.downsample.0.weight
56 | - layer4.1.conv2.weight
57 | - layer4.2.conv2.weight
58 | filter_pruner_20:
59 | class: L1RankedStructureParameterPruner
60 | desired_sparsity: 0.2
61 | group_type: Filters
62 | weights:
63 | - layer1.0.conv2.weight
64 | - layer1.0.downsample.0.weight
65 | - layer1.1.conv2.weight
66 | - layer1.2.conv2.weight
67 | - layer3.6.conv1.weight
68 | - layer3.13.conv1.weight
69 | version: 1
70 |
--------------------------------------------------------------------------------
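Note: this top-level yaml_file/auto_yaml.yaml targets a deeper backbone (fourteen layer3 blocks, plus downsample convolutions) with only two sparsity buckets. The auto_yaml name suggests these files are generated from a per-layer sparsity map; the following is a plausible reconstruction of such a generator, assuming PyYAML, not the repo's actual code -- though PyYAML's key-sorting dump would explain the alphabetical key order seen in the dumps above:

import yaml

def build_schedule(sparsities, arch="None", dataset="112x112"):
    """Emit an auto_yaml-style Distiller schedule from {weight_name: sparsity}."""
    groups = {}
    for name, s in sparsities.items():
        # Bucket layers by sparsity expressed in percent, e.g. 0.2 -> filter_pruner_20.
        groups.setdefault(int(round(s * 100)), []).append(name)

    pruners, policies = {}, []
    for pct in sorted(groups):
        pname = "filter_pruner_%d" % pct
        pruners[pname] = {
            "class": "L1RankedStructureParameterPruner",
            "desired_sparsity": pct / 100.0,
            "group_type": "Filters",
            "weights": sorted(groups[pct]),
        }
        policies.append({"pruner": {"instance_name": pname}, "epochs": [0]})
    # net_thinner runs last so the masked filters are physically removed.
    policies.append({"extension": {"instance_name": "net_thinner"}, "epochs": [0]})

    return {
        "version": 1,
        "pruners": pruners,
        "extensions": {
            "net_thinner": {
                "class": "FilterRemover",
                "thinning_func_str": "remove_filters",
                # the dumps above store the literal string "None" for arch
                "arch": arch,
                "dataset": dataset,
            }
        },
        "policies": policies,
    }

if __name__ == "__main__":
    demo = {"layer1.0.conv2.weight": 0.9, "layer4.0.conv1.weight": 0.1}
    print(yaml.safe_dump(build_schedule(demo), default_flow_style=False))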