├── .gitignore
├── .idea
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pruning_tool.iml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── commom_utils
│   ├── my_to_tensor.py
│   └── utils.py
├── data
│   └── train_data
│       ├── ms1m_10k_loader.py
│       └── verifacation.py
├── dataloader.py
├── distiller
│   ├── __init__.py
│   ├── apputils
│   │   ├── __init__.py
│   │   ├── checkpoint.py
│   │   ├── data_loaders.py
│   │   ├── dataset_summaries.py
│   │   ├── execution_env.py
│   │   └── image_classifier.py
│   ├── config.py
│   ├── data_loggers
│   │   ├── __init__.py
│   │   ├── collector.py
│   │   ├── logger.py
│   │   └── tbbackend.py
│   ├── directives.py
│   ├── knowledge_distillation.py
│   ├── learning_rate.py
│   ├── model_summaries.py
│   ├── model_transforms.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cifar10
│   │   │   ├── __init__.py
│   │   │   ├── plain_cifar.py
│   │   │   ├── preresnet_cifar.py
│   │   │   ├── resnet_cifar.py
│   │   │   ├── resnet_cifar_earlyexit.py
│   │   │   ├── simplenet_cifar.py
│   │   │   └── vgg_cifar.py
│   │   ├── imagenet
│   │   │   ├── __init__.py
│   │   │   ├── alexnet_batchnorm.py
│   │   │   ├── mobilenet.py
│   │   │   ├── mobilenet_dropout.py
│   │   │   ├── preresnet_imagenet.py
│   │   │   ├── resnet.py
│   │   │   └── resnet_earlyexit.py
│   │   └── mnist
│   │       ├── __init__.py
│   │       └── simplenet_mnist.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── aggregate.py
│   │   ├── eltwise.py
│   │   ├── grouping.py
│   │   ├── matmul.py
│   │   ├── rnn.py
│   │   └── tsvd.py
│   ├── policy.py
│   ├── pruning
│   │   ├── __init__.py
│   │   ├── automated_gradual_pruner.py
│   │   ├── baidu_rnn_pruner.py
│   │   ├── greedy_filter_pruning.py
│   │   ├── level_pruner.py
│   │   ├── magnitude_pruner.py
│   │   ├── pruner.py
│   │   ├── ranked_structures_pruner.py
│   │   ├── sensitivity_pruner.py
│   │   ├── splicing_pruner.py
│   │   └── structure_pruner.py
│   ├── quantization
│   │   ├── __init__.py
│   │   ├── clipped_linear.py
│   │   ├── q_utils.py
│   │   ├── quantizer.py
│   │   ├── range_linear.py
│   │   └── sim_bn_fold.py
│   ├── regularization
│   │   ├── __init__.py
│   │   ├── drop_filter.py
│   │   ├── group_regularizer.py
│   │   ├── l1_regularizer.py
│   │   └── regularizer.py
│   ├── scheduler.py
│   ├── sensitivity.py
│   ├── summary_graph.py
│   ├── thinning.py
│   ├── thresholding.py
│   └── utils.py
├── example~
│   ├── ljt_mobilefacenet_y2.sh
│   └── ljt_shufflefacenet_v2.sh
├── main.py
├── model_define
│   ├── MobileFaceNet.py
│   ├── MobileNetV3.py
│   ├── load_state_dict.py
│   ├── mobilefacenet_y2_ljt
│   │   ├── __init__.py
│   │   ├── common_utility.py
│   │   ├── mobilefacenet_big.py
│   │   └── network_elems.py
│   ├── model.py
│   ├── model_resnet.py
│   ├── resnet50_imagenet.py
│   ├── resnet_100_ljt
│   │   └── resnet_100.py
│   ├── resnet_50_ljt
│   │   └── resnet_50.py
│   └── shufflefacenet_v2_ljt
│       ├── ShuffleFaceNetV2.py
│       └── blocks.py
├── mytest.py
├── prune.sh
├── pruning.py
├── pruning_analysis_tools
│   ├── auto_make_yaml.py
│   └── plot_csv.py
├── quantization.py
├── quantization.sh
├── resnet_ljt.sh
├── sensitivity_analysis.py
├── src
│   ├── data_loader.py
│   ├── dataset.py
│   └── loader
│       ├── __init__.py
│       ├── autoaugment.py
│       ├── functional.py
│       ├── functional_V2.py
│       ├── read_image_list_io.py
│       ├── transforms.py
│       ├── transforms_V2.py
│       └── utility.py
├── test_class.py
├── test_module
│   ├── test_on_diverse_dataset.py
│   ├── test_on_face_classification.py
│   ├── test_on_face_recognition.py
│   └── test_with_insight_face.py
├── train_module
│   ├── __init__.py
│   └── train_with_insight_face.py
├── work_space
│   ├── finetune
│   │   └── 2020-09-08-09-18
│   │       └── log
│   │           └── 2020-09-08-09-18
│   │               └── events.out.tfevents.1599527940.yeluyue
│   ├── layers_out
│   │   └── out.txt
│   ├── pruned_define_model
│   │   ├── make_prune_resnet100.py
│   │   ├── make_prune_resnet34_lzc.py
│   │   ├── make_pruned_mobilefacenet_lzc.py
│   │   ├── make_pruned_mobilefacenet_y2.py
│   │   ├── make_pruned_mobilefacenet_y2_ljt.py
│   │   ├── make_pruned_mobilefacenet_zkx.py
│   │   ├── make_pruned_resnet34.py
│   │   ├── make_pruned_resnet50.py
│   │   ├── make_pruned_resnet50_imagenet.py
│   │   ├── make_pruned_resnet_100_ljt.py
│   │   ├── make_pruned_resnet_50_ljt.py
│   │   └── make_pruned_shufflefacenet_v2_ljt.py
│   └── sensitivity_data
│       ├── mobilefacenet_0.8228
│       │   └── sensitivity.csv
│       ├── mobilefacenet_lzc_0.4755
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_lzc.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_lzc.csv
│       ├── mobilefacenet_y2_0.8651
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_1.csv
│       ├── mobilefacenet_y2_company_0.9867
│       │   └── sensitivity_mobilefacenet_y2_2020-09-04-13-08.csv
│       ├── mobilefacenet_y2_ljt_BUSID
│       │   ├── FPGM
│       │   │   └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-00-13.csv
│       │   └── L1Rank
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-15-15-52-BUSID.csv
│       ├── mobilefacenet_y2_ljt_TYLG
│       │   ├── FPGM
│       │   │   └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-00-48.csv
│       │   └── L1Rank
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-15-16-35-TYLG.csv
│       ├── mobilefacenet_y2_ljt_XCH
│       │   └── FPGM
│       │       └── sensitivity_mobilefacenet_y2_ljt_2020-09-16-14-18.csv
│       ├── mobilefacenet_y2_zkx_0.7889
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_2019-11-05-06-30.csv
│       ├── mobilefacenet_y2_zkx_0.8118
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilefacenet_y2_2019-11-04-09-33.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilefacenet_y2_2019-11-05-05-59.csv
│       ├── mobilenetv3_0.6613
│       │   ├── L1Rank
│       │   │   └── sensitivity_mobilenetv3.csv
│       │   └── fpgm
│       │       └── sensitivity_mobilenetv3.csv
│       ├── resnet100_0.7710
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet100.csv
│       │   └── fpgm
│       │       └── sensitivity_resnet100_2019-11-07-20-49.csv
│       ├── resnet34_0.7217
│       │   └── sensitivity_resnet34.csv
│       ├── resnet34_lzc_0.6335
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet34_lzc.csv
│       │   └── fpgm
│       │       └── sensitivity_resnet34_lzc.csv
│       ├── resnet50_0.7517
│       │   ├── L1Rank
│       │   │   └── sensitivity_resnet50.csv
│       │   └── fpgm
│       │       ├── sensitivity_resnet50_2019-11-07-13-42.csv
│       │       └── sensitivity_resnet50_2019-11-13-15-04.csv
│       ├── resnet50_imagenet_0.9773
│       │   ├── L1Rank
│       │   │   ├── auto_yaml.yaml
│       │   │   └── sensitivity_resnet50_imagenet_2020-09-02-12-37.csv
│       │   └── fpgm
│       │       ├── auto_yaml.yaml
│       │       └── sensitivity_resnet50_imagenet_2020-08-27-17-22.csv
│       ├── resnet_100_ljt
│       │   └── FPGM
│       │       └── sensitivity_resnet_100_ljt_2020-09-23-08-25.csv
│       ├── resnet_50_ljt
│       │   └── fpgm
│       │       └── sensitivity_resnet_50_ljt_2020-09-22-23-00.csv
│       ├── sensitivity_resnet50_2020-09-07-17-13.csv
│       ├── shufflefacenet_v2_ljt_BUSID
│       │   └── sensitivity_shufflefacenet_v2_ljt_2020-09-16-19-16.csv
│       ├── shufflefacenet_v2_ljt_TYLG
│       │   └── FPGM
│       │       └── sensitivity_shufflefacenet_v2_ljt_2020-09-18-15-53.csv
│       └── shufflefacenet_v2_ljt_XCH
│           └── sensitivity_shufflefacenet_v2_ljt_2020-09-17-02-47.csv
└── yaml_file
    └── auto_yaml.yaml

/.gitignore:
--------------------------------------------------------------------------------
*.pth
*.png
*.minivision-linx
*.pt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# compression_tool
A compression tool for minivision

* Supports pruning the following models:
    - [x] VGG
    - [x] ResNet
    - [x] MobileFaceNet
    - [x] ShuffleNet_v2
* Supports the following pruning algorithms:
    - [x] L1Rank
    - [x] FPGM
    - [x] HRank
    - [ ] EagleEye (not released yet)

## Install
git clone https://github.com/BlossomingL/compression_tool

Environment: PyTorch 1.x

## Files and their functions
* commom_utils: commonly used utility functions
* data: the test and training datasets, plus the required pair lists
* model_define: definition files of the models to be pruned
* pruning_analysis_tools: two pruning-analysis tools; auto_make_yaml.py generates a yaml configuration file from the sensitivity csv, and plot_csv.py plots the sensitivity curves from the csv
* test_module: the various test modules (accuracy, ROC, etc.)
* train_module: the various training modules
* work_space: model files, pruned model definition files, and sensitivity-analysis data (csv files)
* yaml_file: the auto-generated yaml files
* dataloader.py: multi-threaded batch data loading
* main.py: main entry point
* prune.sh: pruning script
* pruning.py: pruning code
* quantization.py: quantization code
* quantization.sh: quantization script
* sensitivity_analysis.py: sensitivity analysis; produces the csv files

## Pruning workflow
* Preparation: this tool depends on distiller; the setup files live in the Distiller directory, so open a shell there and run `python setup.py install`. Put the model definition file under model_define, the best-accuracy trained .pt file under work_space/model_train_best, the test set under data, and the test code under test_module.

* Sensitivity analysis: run pruning.py. For example, to run pruning sensitivity analysis on resnet100:

```sh
python pruning.py --mode sa \  # sa (sensitivity analysis) mode
       --model resnet100 \  # model name; must be one of the supported choices
       --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \  # trained model file
       --from_data_parallel \  # set if the model file was trained on multiple GPUs
       --test_root_path data/test_data/fc_0.4_112x112 \  # test-set root path
       --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \  # pair-list path
       --fpgm \  # prune with the FPGM algorithm
       --data_source company  # dataset source
```
This produces a csv file under work_space/sensitivity_data.
* Generate the yaml file: run pruning_analysis_tools/auto_make_yaml.py. The arguments of its config_yaml function must be configured by hand: argument 1 is the csv file path, argument 2 the accuracy you expect after pruning, argument 3 the model name (resnet100 for the example above), and argument 4 the input image size. A yaml file is then written to the yaml_file folder (a usage sketch follows at the end of this README).

* Pruning: run pruning.py. For example, to prune resnet100:
```sh
python pruning.py --mode prune \
       --model resnet100 \
       --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \
       --from_data_parallel \
       --fpgm \
       --save_model_pt \  # save the pruned model file
       --test_root_path data/test_data/fc_0.4_112x112 \
       --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \
       --data_source company
```
This saves a pruned model file to work_space/pruned_model (named after the model), and also prints each layer's out parameter, which is exactly the keep array used in the pruned model definition file.

## About distiller
The hard-pruning part of this tool (actually removing the channels) reuses code from the distiller framework, and because of the peculiarities of our models the framework source had to be modified before pruning could work. The changes:
* In distiller/apputils/data_loaders.py, classification_get_input_shape: dataset must have the same value as the dataset parameter in the yaml file; for example, if the network input size is 80x80, the yaml dataset parameter is also 80x80. To support other input types, this code has to be changed.
* distiller/policy.py: added the fpgm option.
* distiller/thinning.py: added two features:
    * PReLU layers can now be pruned (handle_prelu_layers, append_prelu_thinning_directive). Note: if a PReLU layer uses the default single parameter, this code must be commented out, otherwise it fails.
    * The first layer of the company's Block is a BN layer, which the upstream code cannot prune; handle_bn_layers_bn1 adds support for it.
* distiller/pruning/ranked_structures_pruner.py: added the FPGM algorithm (the block starting at `if fpgm` inside rank_and_prune_filters).
* distiller/summary_graph.py: fixed an upstream bug by wrapping add_footprint_attr in a try/except block.
* Added the HRank pruning method (CVPR 2020).

## References
[1] HRank: https://arxiv.org/abs/2002.10179
[2] FPGM: https://arxiv.org/abs/1811.00250
[3] Distiller: https://github.com/NervanaSystems/distiller

--------------------------------------------------------------------------------
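A usage sketch for the yaml-generation step described in the README above. The argument order is taken from the README; the exact signature of config_yaml and the value formats are assumptions, and the accuracy here is a made-up example, so check pruning_analysis_tools/auto_make_yaml.py before use.

```python
# Hypothetical invocation of config_yaml -- argument order from the README,
# exact signature and value formats are assumptions.
from pruning_analysis_tools.auto_make_yaml import config_yaml

config_yaml(
    'work_space/sensitivity_data/resnet100_0.7710/fpgm/sensitivity_resnet100_2019-11-07-20-49.csv',  # 1: sensitivity csv
    0.76,          # 2: accuracy expected after pruning (example value)
    'resnet100',   # 3: model name
    112,           # 4: input image size (format assumed)
)
```

--------------------------------------------------------------------------------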
/commom_utils/my_to_tensor.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: LinX
# datetime: 2019/11/1 2:04 PM
import torch
import numpy as np


def to_tensor(pic):

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float()
    else:
        return img


class ToTensor:
    def __init__(self):
        pass

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)

--------------------------------------------------------------------------------
/commom_utils/utils.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: LinX
# datetime: 2019/10/10 2:22 PM

import time
from thop import profile
from tqdm import tqdm
from model_define.model import ResNet34
from model_define.resnet50_imagenet import resnet_50
from timeit import default_timer as timer
import torch
from datetime import datetime
from torchvision import transforms as trans
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import io


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


def de_preprocess(tensor):
    return tensor * 0.5 + 0.5


def get_time():
    return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')


def test_speed(model, shape=[1, 3, 122, 122], device='gpu', test_time=10000):
    model.eval()
    inputs = torch.rand(shape)
    if device == 'gpu':
        model = model.to('cuda')
        inputs = inputs.to('cuda')
    else:
        model = model.to('cpu')
    print('Testing forward time, this may take a few minutes')
    start_time = timer()
    with torch.no_grad():
        for i in tqdm(range(test_time)):
            model(inputs)
    count = timer() - start_time
    forward_time = (count / test_time) * 1000
    print('Average forward time: {} ms'.format(forward_time))
    return forward_time


def cal_flops(model, input_shape, device='gpu'):
    input_random = torch.rand(input_shape)
    if device == 'gpu':
        input_random = input_random.to('cuda')
        model = model.to('cuda')
    else:
        model = model.to('cpu')
    flops, params = profile(model, inputs=(input_random, ), verbose=False)
    return flops / (1024 * 1024 * 1024), params / (1024 * 1024)


hflip = trans.Compose([
    de_preprocess,
    trans.ToPILImage(),
    trans.functional.hflip,
    trans.ToTensor(),
    trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])


def hflip_batch(imgs_tensor):
    hfliped_imgs = torch.empty_like(imgs_tensor)
    for i, img_ten in enumerate(imgs_tensor):
        hfliped_imgs[i] = hflip(img_ten)
    return hfliped_imgs


def separate_bn_paras(modules):
    if not isinstance(modules, list):
        modules = [*modules.modules()]
    paras_only_bn = []
    paras_wo_bn = []
    for layer in modules:
        if 'model' in str(layer.__class__):
            continue
        if 'container' in str(layer.__class__):
            continue
        else:
            if 'batchnorm' in str(layer.__class__):
                paras_only_bn.extend([*layer.parameters()])
            else:
                paras_wo_bn.extend([*layer.parameters()])
    return paras_only_bn, paras_wo_bn


def gen_plot(fpr, tpr):
    """Create a pyplot plot and save to buffer."""
    plt.figure()
    plt.xlabel("FPR", fontsize=14)
    plt.ylabel("TPR", fontsize=14)
    plt.title("ROC Curve", fontsize=14)
    plot = plt.plot(fpr, tpr, linewidth=2)
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg')
    buf.seek(0)
    plt.close()
    return buf


def main():
    # model = ResNet34()
    # state_dict = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/model_train_best'
    #                         '/resnet34_model_2019-10-12-19-20_accuracy:0.7216981_step:84816_lin.pth')
    # model.load_state_dict(state_dict)
    # test_speed(model)

    # model = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/pruned_model/model_resnet34.pkl')
    # state_dict = torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/pruned_model'
    #                         '/resnet34_best_pruned_0.6556604.pt')
    # model.load_state_dict(state_dict)
    # test_speed(model)

    model = resnet_50()
    state_dict = torch.load('/home/linx/program/InsightFace_Pytorch/work_space/2020-08-20-10-32/models/model_accuracy'
                            ':0.9708333333333334_step:163760_best_acc_lfw.pth')
    model.load_state_dict(state_dict)
    print(cal_flops(model, (1, 3, 112, 112)))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
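A usage sketch (not part of the repo) for the custom ToTensor above. Note that, unlike torchvision's ToTensor, this class keeps the 0-255 value range, so the normalization constants below are assumed to be on that scale; the image path is hypothetical.

```python
from PIL import Image
from torchvision import transforms as trans
from commom_utils.my_to_tensor import ToTensor

transform = trans.Compose([
    ToTensor(),  # HWC PIL image -> CHW float tensor, no division by 255
    trans.Normalize([127.5, 127.5, 127.5], [128.0, 128.0, 128.0]),  # assumed statistics
])
tensor = transform(Image.open('face.bmp').convert('RGB')).unsqueeze(0)  # hypothetical image, add batch dim
```

--------------------------------------------------------------------------------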
/data/train_data/ms1m_10k_loader.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: linx
# datetime 2020/9/1 11:35 AM
from pathlib import Path
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torchvision import transforms as trans
from torchvision.datasets import ImageFolder
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np
import cv2
import bcolz
import pickle
import mxnet as mx
from tqdm import tqdm
import os


def de_preprocess(tensor):
    return tensor * 0.5 + 0.5


def get_train_dataset(imgs_folder):
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    ds = ImageFolder(imgs_folder, train_transform)
    class_num = ds[-1][1] + 1
    return ds, class_num


def get_train_loader(args):
    ms1m_ds, ms1m_class_num = get_train_dataset(os.path.join(args.train_data_path, 'imgs'))
    print('ms1m loader generated')

    ds = ms1m_ds
    class_num = ms1m_class_num

    loader = DataLoader(ds, batch_size=args.train_batch_size, shuffle=True, pin_memory=args.pin_memory,
                        num_workers=args.num_workers)
    return loader, class_num


def load_bin(path, rootdir, transform, image_size=[112, 112]):
    if not rootdir.exists():
        rootdir.mkdir()
    bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes')
    data = bcolz.fill([len(bins), 3, image_size[0], image_size[1]], dtype=np.float32, rootdir=rootdir, mode='w')
    for i in range(len(bins)):
        _bin = bins[i]
        img = mx.image.imdecode(_bin).asnumpy()
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = Image.fromarray(img.astype(np.uint8))
        data[i, ...] = transform(img)
        if (i + 1) % 1000 == 0:  # progress indicator every 1000 bins
            print('loading bin', i + 1)
    print(data.shape)
    np.save(str(rootdir) + '_list', np.array(issame_list))
    return data, issame_list


def get_val_pair(path, name):
    carray = bcolz.carray(rootdir=os.path.join(path, name), mode='r')
    issame = np.load(os.path.join(path, '{}_list.npy'.format(name)))
    return carray, issame


def get_val_data(data_path):
    agedb_30, agedb_30_issame = get_val_pair(data_path, 'agedb_30')  # 12000 image pairs, 6000 corresponding labels
    cfp_fp, cfp_fp_issame = get_val_pair(data_path, 'cfp_fp')
    lfw, lfw_issame = get_val_pair(data_path, 'lfw')
    return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame


def load_mx_rec(rec_path):
    save_path = rec_path / 'imgs'
    if not save_path.exists():
        save_path.mkdir()
    imgrec = mx.recordio.MXIndexedRecordIO(str(rec_path / 'train.idx'), str(rec_path / 'train.rec'), 'r')
    img_info = imgrec.read_idx(0)
    header, _ = mx.recordio.unpack(img_info)
    max_idx = int(header.label[0])
    for idx in tqdm(range(1, max_idx)):
        img_info = imgrec.read_idx(idx)
        header, img = mx.recordio.unpack_img(img_info)
        label = int(header.label)
        img = Image.fromarray(img)
        label_path = save_path / str(label)
        if not label_path.exists():
            label_path.mkdir()
        img.save(label_path / '{}.jpg'.format(idx), quality=95)

--------------------------------------------------------------------------------
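A usage sketch (not part of the repo) for get_train_loader above. The attribute names match what the function reads off args; the path and values are made-up examples.

```python
from types import SimpleNamespace
from data.train_data.ms1m_10k_loader import get_train_loader

args = SimpleNamespace(
    train_data_path='data/train_data/ms1m_10k',  # hypothetical; must contain an imgs/ subfolder
    train_batch_size=128,
    pin_memory=True,
    num_workers=4,
)
loader, class_num = get_train_loader(args)
for imgs, labels in loader:
    break  # a training step would consume the batch here
```

--------------------------------------------------------------------------------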
/distiller/__init__.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from .utils import *
from .thresholding import GroupThresholdMixin, threshold_mask, group_threshold_mask
from .config import file_config, dict_config, config_component_from_file_by_class
from .model_summaries import *
from .scheduler import *
from .sensitivity import *
from .directives import *
from .policy import *
from .thinning import *
from .knowledge_distillation import KnowledgeDistillationPolicy, DistillationLossWeights
from .summary_graph import SummaryGraph, onnx_name_2_pytorch_name

import logging
logging.captureWarnings(True)

del dict_config
del thinning

# Distiller version
__version__ = "0.4.0-pre"


def model_find_param_name(model, param_to_find):
    """Look up the name of a model parameter.

    Arguments:
        model: the model to search
        param_to_find: the parameter whose name we want to look up

    Returns:
        The parameter name (string) or None, if the parameter was not found.
    """
    for name, param in model.named_parameters():
        if param is param_to_find:
            return name
    return None


def model_find_module_name(model, module_to_find):
    """Look up the name of a module in a model.

    Arguments:
        model: the model to search
        module_to_find: the module whose name we want to look up

    Returns:
        The module name (string) or None, if the module was not found.
    """
    for name, m in model.named_modules():
        if m == module_to_find:
            return name
    return None


def model_find_param(model, param_to_find_name):
    """Look up a model parameter by its name.

    Arguments:
        model: the model to search
        param_to_find_name: the name of the parameter that we are searching for

    Returns:
        The parameter or None, if the parameter name was not found.
    """
    for name, param in model.named_parameters():
        if name == param_to_find_name:
            return param
    return None


def model_find_module(model, module_to_find):
    """Given a module name, find the module in the provided model.

    Arguments:
        model: the model to search
        module_to_find: the name of the module we want to look up

    Returns:
        The module or None, if the module was not found.
    """
    for name, m in model.named_modules():
        if name == module_to_find:
            return m
    return None


def check_pytorch_version():
    from pkg_resources import parse_version
    if parse_version(torch.__version__) < parse_version('1.1.0'):
        msg = "\n\nWRONG PYTORCH VERSION\n"\
              "The Distiller \'master\' branch now requires at least PyTorch version 1.1.0 due to "\
              "PyTorch API changes which are not backward-compatible. Version detected is {}.\n"\
              "To make sure PyTorch and all other dependencies are installed with their correct versions, " \
              "go to the Distiller repo root directory and run:\n\n"\
              "pip install -e .\n".format(torch.__version__)
        raise RuntimeError(msg)


check_pytorch_version()

--------------------------------------------------------------------------------
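A usage sketch (not part of the repo) for the four lookup helpers defined above, on a throwaway model; it assumes distiller and its dependencies are importable.

```python
import torch.nn as nn
import distiller

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))

distiller.model_find_module_name(model, model[1])        # -> '1'
distiller.model_find_module(model, '1')                  # -> the BatchNorm2d module
distiller.model_find_param(model, '0.weight')            # -> the conv weight parameter
distiller.model_find_param_name(model, model[0].weight)  # -> '0.weight'
```

--------------------------------------------------------------------------------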
/distiller/apputils/__init__.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This package contains Python code and classes that are meant to make your life easier,
when working with distiller.

"""
from .data_loaders import *
from .checkpoint import *
from .execution_env import *
from .dataset_summaries import *

del data_loaders
del checkpoint
del execution_env
del dataset_summaries

--------------------------------------------------------------------------------
/distiller/apputils/dataset_summaries.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import distiller
import numpy as np
import logging
msglogger = logging.getLogger()


def dataset_summary(data_loader):
    """Create a histogram of class membership distribution within a dataset.

    It is important to examine our training, validation, and test
    datasets, to make sure that they are balanced.
    """
    msglogger.info("Analyzing dataset:")
    print_frequency = 50
    for batch, (input, label_batch) in enumerate(data_loader):
        try:
            all_labels = np.append(all_labels, distiller.to_np(label_batch))
        except NameError:
            all_labels = distiller.to_np(label_batch)
        if (batch + 1) % print_frequency == 0:
            # progress indicator
            print("batch: %d" % batch)

    # NOTE: the histogram is hardcoded to 1000 bins (class IDs 0..999), so
    # nclasses below is always 1000 regardless of the actual dataset.
    hist = np.histogram(all_labels, bins=np.arange(1000 + 1))
    nclasses = len(hist[0])
    for data_class, size in enumerate(hist[0]):
        msglogger.info("\tClass {} = {}".format(data_class, size))
    msglogger.info("Dataset contains {} items".format(len(data_loader.sampler)))
    msglogger.info("Found {} classes".format(nclasses))
    msglogger.info("Average: {} samples per class".format(np.mean(hist[0])))

--------------------------------------------------------------------------------
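A usage sketch: dataset_summary only needs a DataLoader, and MNIST here is just a stand-in dataset (download path assumed). Mind the 1000-bin assumption noted above.

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from distiller.apputils import dataset_summary  # re-exported by apputils/__init__.py

ds = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
dataset_summary(DataLoader(ds, batch_size=256))  # logs per-class counts via msglogger
```

--------------------------------------------------------------------------------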
/distiller/data_loggers/__init__.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .collector import *
from .logger import PythonLogger, TensorBoardLogger, CsvLogger

del logger
del collector

--------------------------------------------------------------------------------
/distiller/data_loggers/tbbackend.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A TensorBoard backend.

Writes logs to a file using a Google's TensorBoard protobuf format.
See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto
"""
import os
# Disable FutureWarning from TensorFlow
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import tensorflow as tf
import numpy as np


class TBBackend(object):
    def __init__(self, log_dir):
        self.writers = []
        self.log_dir = log_dir
        self.writers.append(tf.summary.FileWriter(log_dir))

    def scalar_summary(self, tag, scalar, step):
        """From TF documentation:
        tag: name for the data. Used by TensorBoard plugins to organize data.
        value: value associated with the tag (a float).
        """
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
        self.writers[0].add_summary(summary, step)

    def list_summary(self, tag, list, step, multi_graphs):
        """Log a relatively small list of scalars.

        We want to track the progress of multiple scalar parameters in a single graph.
        The list provides a single value for each of the parameters we are tracking.

        NOTE: There are two ways to log multiple values in TB and neither one is optimal.
        1. Use a single writer: in this case all of the parameters use the same color, and
           distinguishing between them is difficult.
        2. Use multiple writers: in this case each parameter has its own color which helps
           to visually separate the parameters. However, each writer logs to a different
           file and this creates a lot of files which slow down the TB load.
        """
        for i, scalar in enumerate(list):
            if multi_graphs and (i+1 > len(self.writers)):
                self.writers.append(tf.summary.FileWriter(os.path.join(self.log_dir, str(i))))
            summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=scalar)])
            self.writers[0 if not multi_graphs else i].add_summary(summary, step)

    def histogram_summary(self, tag, tensor, step):
        """
        From the TF documentation:
        tf.summary.histogram takes an arbitrarily sized and shaped Tensor, and
        compresses it into a histogram data structure consisting of many bins with
        widths and counts.

        TensorFlow uses non-uniformly distributed bins, which is better than using
        numpy's uniform bins for activations and parameters which converge around zero,
        but we don't add that logic here.

        https://www.tensorflow.org/programmers_guide/tensorboard_histograms
        """
        hist, edges = np.histogram(tensor, bins=200)
        tfhist = tf.HistogramProto(
            min=np.min(tensor),
            max=np.max(tensor),
            num=int(np.prod(tensor.shape)),
            sum=np.sum(tensor),
            sum_squares=np.sum(np.square(tensor)))

        # From the TF documentation:
        # Parallel arrays encoding the bucket boundaries and the bucket values.
        # bucket(i) is the count for the bucket i. The range for a bucket is:
        #   i == 0:  -DBL_MAX .. bucket_limit(0)
        #   i != 0:  bucket_limit(i-1) .. bucket_limit(i)
        tfhist.bucket_limit.extend(edges[1:])
        tfhist.bucket.extend(hist)

        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=tfhist)])
        self.writers[0].add_summary(summary, step)

    def sync_to_file(self):
        for writer in self.writers:
            writer.flush()

--------------------------------------------------------------------------------
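A usage sketch (not part of the repo) for TBBackend above. It assumes a TensorFlow 1.x environment, since the class uses tf.summary.FileWriter; the log directory is made up.

```python
from distiller.data_loggers.tbbackend import TBBackend

tb = TBBackend(log_dir='work_space/tb_logs')  # hypothetical directory
for step in range(100):
    tb.scalar_summary('train/loss', 1.0 / (step + 1), step)
# one curve per tracked value, each in its own writer (and color) when multi_graphs=True
tb.list_summary('sparsity/filters', [0.1, 0.5, 0.7], step=99, multi_graphs=True)
tb.sync_to_file()  # flush all writers to disk
```

--------------------------------------------------------------------------------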
/distiller/directives.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Scheduling directives

Scheduling directives are instructions (directives) that the scheduler can
execute as part of scheduling pruning activities.
"""
from __future__ import division
import torch
import numpy as np
from collections import defaultdict
import logging
msglogger = logging.getLogger()

from torchnet.meter import AverageValueMeter
from distiller.utils import sparsity, density


class FreezeTraining(object):
    def __init__(self, name):
        print("------FreezeTraining--------")
        self.name = name


def freeze_training(model, which_params, freeze):
    """This function will freeze/defrost training for certain layers.

    Sometimes, when we prune and retrain a certain layer type,
    we'd like to freeze the training of the other layers.
    """
    # NOTE: model_find_param_name is defined in the top-level distiller
    # package (distiller/__init__.py); it is not imported in this module.
    for param in model.parameters():
        pname = model_find_param_name(model, param.data)
        if pname is None:
            continue
        for ptype in which_params:
            if ptype in pname:
                # see: http://pytorch.org/docs/master/notes/autograd.html?highlight=grad_fn
                param.requires_grad = not freeze
                if freeze:
                    msglogger.info('Freezing: ' + pname)
                else:
                    msglogger.info('Defrosting: ' + pname)


def freeze_all(model, freeze):
    msglogger.info('{} all parameters'.format('Freezing' if freeze else 'Defrosting'))
    for param in model.parameters():
        param.requires_grad = not freeze


def adjust_dropout(module, new_probability):
    """Replace the dropout probability of dropout layers

    As explained in the paper "Learning both Weights and Connections for
    Efficient Neural Networks":
    Dropout is widely used to prevent over-fitting, and this also applies to retraining.
    During retraining, however, the dropout ratio must be adjusted to account for the
    change in model capacity. In dropout, each parameter is probabilistically dropped
    during training, but will come back during inference. In pruning, parameters are
    dropped forever after pruning and have no chance to come back during both training
    and inference. As the parameters get sparse, the classifier will select the most
    informative predictors and thus have much less prediction variance, which reduces
    over-fitting. As pruning already reduced model capacity, the retraining dropout ratio
    should be smaller.
    """
    if type(module) in [torch.nn.Dropout,
                        torch.nn.Dropout2d,
                        torch.nn.Dropout3d,
                        torch.nn.AlphaDropout]:
        msglogger.info("Adjusting dropout probability")
        module.p = new_probability
    else:
        for child in module.children():
            adjust_dropout(child, new_probability)

--------------------------------------------------------------------------------
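A usage sketch (not part of the repo) for the directives above on a toy model; the dropout value of 0.1 just follows the retraining advice quoted in adjust_dropout's docstring.

```python
import torch.nn as nn
from distiller.directives import freeze_all, adjust_dropout

model = nn.Sequential(nn.Linear(64, 64), nn.Dropout(0.5), nn.Linear(64, 10))

freeze_all(model, freeze=True)   # sets requires_grad=False on every parameter
adjust_dropout(model, 0.1)       # recursively sets every Dropout layer's p to 0.1
freeze_all(model, freeze=False)  # defrost before fine-tuning
```

--------------------------------------------------------------------------------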
/distiller/learning_rate.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler


class PolynomialLR(_LRScheduler):
    """Set the learning rate for each parameter group using a polynomial defined as:
    lr = base_lr * (1 - T_cur/T_max) ^ (power), where T_cur is the current epoch and T_max is the maximum number of
    epochs.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of epochs
        power (int): Degree of polynomial
        last_epoch (int): The index of last epoch. Default: -1.
    """
    def __init__(self, optimizer, T_max, power, last_epoch=-1):
        self.T_max = T_max
        self.power = power
        super(PolynomialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # base_lr * (1 - iter/max_iter) ^ (power)
        return [base_lr * (1 - self.last_epoch / self.T_max) ** self.power
                for base_lr in self.base_lrs]


class MultiStepMultiGammaLR(_LRScheduler):
    """Similar to torch.optim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gammas (list): List of gamma values. Must have same length as milestones.
        last_epoch (int): The index of last epoch. Default: -1.
    """
    def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
        if not list(milestones) == sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}', milestones)
        if len(milestones) != len(gammas):
            raise ValueError('Milestones and Gammas lists should be of same length.')

        self.milestones = milestones
        self.multiplicative_gammas = [1]
        for idx, gamma in enumerate(gammas):
            self.multiplicative_gammas.append(gamma * self.multiplicative_gammas[idx])

        super(MultiStepMultiGammaLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        idx = bisect_right(self.milestones, self.last_epoch)
        return [base_lr * self.multiplicative_gammas[idx] for base_lr in self.base_lrs]

--------------------------------------------------------------------------------
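A usage sketch (not part of the repo) for the two schedulers above; the constructor arguments come from the docstrings, and the concrete numbers are examples.

```python
import torch
from distiller.learning_rate import PolynomialLR, MultiStepMultiGammaLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

scheduler = PolynomialLR(optimizer, T_max=100, power=2)  # lr = 0.1 * (1 - epoch/100)**2
# alternative: multiply lr by 0.1 at epoch 30, then additionally by 0.5 at epoch 60
# scheduler = MultiStepMultiGammaLR(optimizer, milestones=[30, 60], gammas=[0.1, 0.5])

for epoch in range(100):
    pass  # one training epoch (optimizer.step() per batch) goes here
    scheduler.step()
```

--------------------------------------------------------------------------------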
-------------------------------------------------------------------------------- /distiller/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains CIFAR image classification models for pytorch""" 18 | 19 | from .simplenet_cifar import * 20 | from .resnet_cifar import * 21 | from .preresnet_cifar import * 22 | from .vgg_cifar import * 23 | from .resnet_cifar_earlyexit import * 24 | from .plain_cifar import * 25 | -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/plain_cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/plain_cifar.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/preresnet_cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/preresnet_cifar.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/resnet_cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/resnet_cifar.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/resnet_cifar_earlyexit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/resnet_cifar_earlyexit.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/simplenet_cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/simplenet_cifar.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/__pycache__/vgg_cifar.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/cifar10/__pycache__/vgg_cifar.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/cifar10/plain_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Plain for CIFAR10 18 | 19 | Plain for CIFAR10, based on "Deep Residual Learning for Image Recognition". 20 | 21 | @inproceedings{DBLP:conf/cvpr/HeZRS16, 22 | author = {Kaiming He and 23 | Xiangyu Zhang and 24 | Shaoqing Ren and 25 | Jian Sun}, 26 | title = {Deep Residual Learning for Image Recognition}, 27 | booktitle = {{CVPR}}, 28 | pages = {770--778}, 29 | publisher = {{IEEE} Computer Society}, 30 | year = {2016} 31 | } 32 | 33 | """ 34 | import torch.nn as nn 35 | import math 36 | 37 | 38 | __all__ = ['plain20_cifar'] 39 | 40 | NUM_CLASSES = 10 41 | 42 | 43 | def conv3x3(in_planes, out_planes, stride=1): 44 | """3x3 convolution with padding""" 45 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 46 | padding=1, bias=False) 47 | 48 | 49 | class BasicBlock(nn.Module): 50 | expansion = 1 51 | 52 | def __init__(self, inplanes, planes, stride=1): 53 | super().__init__() 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.relu1 = nn.ReLU(inplace=False) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.relu2 = nn.ReLU(inplace=False) 60 | 61 | def forward(self, x): 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu1(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu2(out) 69 | return out 70 | 71 | 72 | class PlainCifar(nn.Module): 73 | def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES): 74 | self.inplanes = 16 75 | super().__init__() 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1) 80 | self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2) 81 | self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2) 82 | 83 | self.avgpool = nn.AvgPool2d(8, stride=1) 84 | self.fc = nn.Linear(64 * block.expansion, num_classes) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 89 | m.weight.data.normal_(0, math.sqrt(2. / n)) 90 | elif isinstance(m, nn.BatchNorm2d): 91 | m.weight.data.fill_(1) 92 | m.bias.data.zero_() 93 | 94 | def _make_layer(self, block, planes, num_blocks, stride): 95 | # Each layer is composed of 2*num_blocks convolutional layers; the first block usually 96 | # performs downsampling of the input, and doubling of the number of filters/feature-maps.
97 | blocks = [] 98 | inplanes = self.inplanes 99 | # First block is special (downsamples and adds filters) 100 | blocks.append(block(inplanes, planes, stride)) 101 | 102 | self.inplanes = planes * block.expansion 103 | for i in range(num_blocks - 1): 104 | blocks.append(block(self.inplanes, planes, stride=1)) 105 | return nn.Sequential(*blocks) 106 | 107 | def forward(self, x): 108 | x = self.conv1(x) 109 | x = self.bn1(x) 110 | x = self.relu(x) 111 | 112 | x = self.layer1(x) 113 | x = self.layer2(x) 114 | x = self.layer3(x) 115 | 116 | x = self.avgpool(x) 117 | x = x.view(x.size(0), -1) 118 | x = self.fc(x) 119 | return x 120 | 121 | 122 | def plain20_cifar(**kwargs): 123 | model = PlainCifar(BasicBlock, [3, 3, 3], **kwargs) 124 | return model 125 | -------------------------------------------------------------------------------- /distiller/models/cifar10/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Resnet for CIFAR10 18 | 19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition". 20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate 21 | changes for the 10-class Cifar-10 dataset. 22 | This ResNet also has layer gates, to be able to dynamically remove layers. 
23 | 24 | @inproceedings{DBLP:conf/cvpr/HeZRS16, 25 | author = {Kaiming He and 26 | Xiangyu Zhang and 27 | Shaoqing Ren and 28 | Jian Sun}, 29 | title = {Deep Residual Learning for Image Recognition}, 30 | booktitle = {{CVPR}}, 31 | pages = {770--778}, 32 | publisher = {{IEEE} Computer Society}, 33 | year = {2016} 34 | } 35 | 36 | """ 37 | import torch.nn as nn 38 | import math 39 | import torch.utils.model_zoo as model_zoo 40 | 41 | 42 | __all__ = ['resnet20_cifar', 'resnet32_cifar', 'resnet44_cifar', 'resnet56_cifar'] 43 | 44 | NUM_CLASSES = 10 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | """3x3 convolution with padding""" 48 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 49 | padding=1, bias=False) 50 | 51 | class BasicBlock(nn.Module): 52 | expansion = 1 53 | 54 | def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None): 55 | super(BasicBlock, self).__init__() 56 | self.block_gates = block_gates 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.relu1 = nn.ReLU(inplace=False) # To enable layer removal inplace must be False 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.relu2 = nn.ReLU(inplace=False) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = out = x 68 | 69 | if self.block_gates[0]: 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu1(out) 73 | 74 | if self.block_gates[1]: 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu2(out) 83 | 84 | return out 85 | 86 | 87 | class ResNetCifar(nn.Module): 88 | 89 | def __init__(self, block, layers, num_classes=NUM_CLASSES): 90 | self.nlayers = 0 91 | # Each layer manages its own gates 92 | self.layer_gates = [] 93 | for layer in range(3): 94 | # For each of the 3 layers, create block gates: each block has two layers 95 | self.layer_gates.append([]) # [True, True] * layers[layer]) 96 | for blk in range(layers[layer]): 97 | self.layer_gates[layer].append([True, True]) 98 | 99 | self.inplanes = 16 # 64 100 | super(ResNetCifar, self).__init__() 101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(self.inplanes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0]) 105 | self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2) 106 | self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2) 107 | self.avgpool = nn.AvgPool2d(8, stride=1) 108 | self.fc = nn.Linear(64 * block.expansion, num_classes) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, layer_gates, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(planes * block.expansion), 125 | ) 126 | 127 | layers = [] 128 | layers.append(block(layer_gates[0], self.inplanes, planes, stride, downsample)) 129 | self.inplanes = planes * block.expansion 130 | for i in range(1, blocks): 131 | layers.append(block(layer_gates[i], self.inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | x = self.conv1(x) 137 | x = self.bn1(x) 138 | x = self.relu(x) 139 | 140 | x = self.layer1(x) 141 | x = self.layer2(x) 142 | x = self.layer3(x) 143 | 144 | x = self.avgpool(x) 145 | x = x.view(x.size(0), -1) 146 | x = self.fc(x) 147 | 148 | return x 149 | 150 | 151 | def resnet20_cifar(**kwargs): 152 | model = ResNetCifar(BasicBlock, [3, 3, 3], **kwargs) 153 | return model 154 | 155 | def resnet32_cifar(**kwargs): 156 | model = ResNetCifar(BasicBlock, [5, 5, 5], **kwargs) 157 | return model 158 | 159 | def resnet44_cifar(**kwargs): 160 | model = ResNetCifar(BasicBlock, [7, 7, 7], **kwargs) 161 | return model 162 | 163 | def resnet56_cifar(**kwargs): 164 | model = ResNetCifar(BasicBlock, [9, 9, 9], **kwargs) 165 | return model 166 | -------------------------------------------------------------------------------- /distiller/models/cifar10/resnet_cifar_earlyexit.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Resnet for CIFAR10 18 | 19 | Resnet for CIFAR10, based on "Deep Residual Learning for Image Recognition". 20 | This is based on TorchVision's implementation of ResNet for ImageNet, with appropriate 21 | changes for the 10-class Cifar-10 dataset. 22 | This ResNet also has layer gates, to be able to dynamically remove layers. 
23 | 24 | @inproceedings{DBLP:conf/cvpr/HeZRS16, 25 | author = {Kaiming He and 26 | Xiangyu Zhang and 27 | Shaoqing Ren and 28 | Jian Sun}, 29 | title = {Deep Residual Learning for Image Recognition}, 30 | booktitle = {{CVPR}}, 31 | pages = {770--778}, 32 | publisher = {{IEEE} Computer Society}, 33 | year = {2016} 34 | } 35 | 36 | """ 37 | import torch.nn as nn 38 | import math 39 | import torch.utils.model_zoo as model_zoo 40 | import torchvision.models as models 41 | from .resnet_cifar import BasicBlock 42 | from .resnet_cifar import ResNetCifar 43 | 44 | 45 | __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit', 46 | 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit'] 47 | 48 | NUM_CLASSES = 10 49 | 50 | def conv3x3(in_planes, out_planes, stride=1): 51 | """3x3 convolution with padding""" 52 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | 55 | 56 | class ResNetCifarEarlyExit(ResNetCifar): 57 | 58 | def __init__(self, block, layers, num_classes=NUM_CLASSES): 59 | super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes) 60 | 61 | # Define early exit layers 62 | self.linear_exit0 = nn.Linear(1600, num_classes) 63 | 64 | 65 | def forward(self, x): 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | 70 | x = self.layer1(x) 71 | 72 | # Add early exit layers 73 | exit0 = nn.functional.avg_pool2d(x, 3) 74 | exit0 = exit0.view(exit0.size(0), -1) 75 | exit0 = self.linear_exit0(exit0) 76 | 77 | x = self.layer2(x) 78 | x = self.layer3(x) 79 | 80 | x = self.avgpool(x) 81 | x = x.view(x.size(0), -1) 82 | x = self.fc(x) 83 | 84 | # return a list of probabilities 85 | output = [] 86 | output.append(exit0) 87 | output.append(x) 88 | return output 89 | 90 | 91 | def resnet20_cifar_earlyexit(**kwargs): 92 | model = ResNetCifarEarlyExit(BasicBlock, [3, 3, 3], **kwargs) 93 | return model 94 | 95 | def resnet32_cifar_earlyexit(**kwargs): 96 | model = ResNetCifarEarlyExit(BasicBlock, [5, 5, 5], **kwargs) 97 | return model 98 | 99 | def resnet44_cifar_earlyexit(**kwargs): 100 | model = ResNetCifarEarlyExit(BasicBlock, [7, 7, 7], **kwargs) 101 | return model 102 | 103 | def resnet56_cifar_earlyexit(**kwargs): 104 | model = ResNetCifarEarlyExit(BasicBlock, [9, 9, 9], **kwargs) 105 | return model 106 | 107 | def resnet110_cifar_earlyexit(**kwargs): 108 | model = ResNetCifarEarlyExit(BasicBlock, [18, 18, 18], **kwargs) 109 | return model 110 | 111 | def resnet1202_cifar_earlyexit(**kwargs): 112 | model = ResNetCifarEarlyExit(BasicBlock, [200, 200, 200], **kwargs) 113 | return model -------------------------------------------------------------------------------- /distiller/models/cifar10/simplenet_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | __all__ = ['simplenet_cifar'] 21 | 22 | 23 | class Simplenet(nn.Module): 24 | def __init__(self): 25 | super(Simplenet, self).__init__() 26 | self.conv1 = nn.Conv2d(3, 6, 5) 27 | self.relu_conv1 = nn.ReLU() 28 | self.pool1 = nn.MaxPool2d(2, 2) 29 | self.conv2 = nn.Conv2d(6, 16, 5) 30 | self.relu_conv2 = nn.ReLU() 31 | self.pool2 = nn.MaxPool2d(2, 2) 32 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 33 | self.relu_fc1 = nn.ReLU() 34 | self.fc2 = nn.Linear(120, 84) 35 | self.relu_fc2 = nn.ReLU() 36 | self.fc3 = nn.Linear(84, 10) 37 | 38 | def forward(self, x): 39 | x = self.pool1(self.relu_conv1(self.conv1(x))) 40 | x = self.pool2(self.relu_conv2(self.conv2(x))) 41 | x = x.view(-1, 16 * 5 * 5) 42 | x = self.relu_fc1(self.fc1(x)) 43 | x = self.relu_fc2(self.fc2(x)) 44 | x = self.fc3(x) 45 | return x 46 | 47 | 48 | def simplenet_cifar(): 49 | model = Simplenet() 50 | return model 51 | -------------------------------------------------------------------------------- /distiller/models/cifar10/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """VGG for CIFAR10 18 | 19 | VGG for CIFAR10, based on "Very Deep Convolutional Networks for Large-Scale 20 | Image Recognition". 21 | This is based on TorchVision's implementation of VGG for ImageNet, with 22 | appropriate changes for the 10-class Cifar-10 dataset. 23 | We replaced the three linear classifiers with a single one. 
24 | """ 25 | 26 | import torch.nn as nn 27 | 28 | __all__ = [ 29 | 'VGGCifar', 'vgg11_cifar', 'vgg11_bn_cifar', 'vgg13_cifar', 'vgg13_bn_cifar', 'vgg16_cifar', 'vgg16_bn_cifar', 30 | 'vgg19_bn_cifar', 'vgg19_cifar', 31 | ] 32 | 33 | 34 | class VGGCifar(nn.Module): 35 | def __init__(self, features, num_classes=10, init_weights=True): 36 | super(VGGCifar, self).__init__() 37 | self.features = features 38 | self.classifier = nn.Linear(512, num_classes) 39 | if init_weights: 40 | self._initialize_weights() 41 | 42 | def forward(self, x): 43 | x = self.features(x) 44 | x = x.view(x.size(0), -1) 45 | x = self.classifier(x) 46 | return x 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | elif isinstance(m, nn.BatchNorm2d): 55 | nn.init.constant_(m.weight, 1) 56 | nn.init.constant_(m.bias, 0) 57 | elif isinstance(m, nn.Linear): 58 | nn.init.normal_(m.weight, 0, 0.01) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | 62 | def make_layers(cfg, batch_norm=False): 63 | layers = [] 64 | in_channels = 3 65 | for v in cfg: 66 | if v == 'M': 67 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 68 | else: 69 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 70 | if batch_norm: 71 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 72 | else: 73 | layers += [conv2d, nn.ReLU(inplace=True)] 74 | in_channels = v 75 | return nn.Sequential(*layers) 76 | 77 | 78 | cfg = { 79 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 80 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 81 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 82 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 83 | } 84 | 85 | 86 | def vgg11_cifar(**kwargs): 87 | """VGG 11-layer model (configuration "A")""" 88 | model = VGGCifar(make_layers(cfg['A']), **kwargs) 89 | return model 90 | 91 | 92 | def vgg11_bn_cifar(**kwargs): 93 | """VGG 11-layer model (configuration "A") with batch normalization""" 94 | model = VGGCifar(make_layers(cfg['A'], batch_norm=True), **kwargs) 95 | return model 96 | 97 | 98 | def vgg13_cifar(**kwargs): 99 | """VGG 13-layer model (configuration "B")""" 100 | model = VGGCifar(make_layers(cfg['B']), **kwargs) 101 | return model 102 | 103 | 104 | def vgg13_bn_cifar(**kwargs): 105 | """VGG 13-layer model (configuration "B") with batch normalization""" 106 | model = VGGCifar(make_layers(cfg['B'], batch_norm=True), **kwargs) 107 | return model 108 | 109 | 110 | def vgg16_cifar(**kwargs): 111 | """VGG 16-layer model (configuration "D") 112 | """ 113 | model = VGGCifar(make_layers(cfg['D']), **kwargs) 114 | return model 115 | 116 | 117 | def vgg16_bn_cifar(**kwargs): 118 | """VGG 16-layer model (configuration "D") with batch normalization""" 119 | model = VGGCifar(make_layers(cfg['D'], batch_norm=True), **kwargs) 120 | return model 121 | 122 | 123 | def vgg19_cifar(**kwargs): 124 | """VGG 19-layer model (configuration "E") 125 | """ 126 | model = VGGCifar(make_layers(cfg['E']), **kwargs) 127 | return model 128 | 129 | 130 | def vgg19_bn_cifar(**kwargs): 131 | """VGG 19-layer model (configuration 'E') with batch normalization""" 132 | model = VGGCifar(make_layers(cfg['E'], batch_norm=True), **kwargs) 133 | return model 134 | 
-------------------------------------------------------------------------------- /distiller/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains ImageNet image classification models not found in torchvision""" 18 | 19 | from .mobilenet import * 20 | from .mobilenet_dropout import * 21 | from .preresnet_imagenet import * 22 | from .alexnet_batchnorm import * 23 | from .resnet_earlyexit import * 24 | from .resnet import * 25 | -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/alexnet_batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/alexnet_batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/mobilenet_dropout.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/mobilenet_dropout.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/preresnet_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/preresnet_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/__pycache__/resnet_earlyexit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/imagenet/__pycache__/resnet_earlyexit.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/imagenet/alexnet_batchnorm.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """ 18 | AlexNet model with batch-norm layers. 19 | Model configuration based on the AlexNet DoReFa example in TensorPack: 20 | https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py 21 | 22 | Code based on the AlexNet PyTorch sample, with the required changes. 23 | """ 24 | 25 | import math 26 | import torch.nn as nn 27 | 28 | __all__ = ['AlexNetBN', 'alexnet_bn'] 29 | 30 | 31 | class AlexNetBN(nn.Module): 32 | 33 | def __init__(self, num_classes=1000): 34 | super(AlexNetBN, self).__init__() 35 | self.features = nn.Sequential( 36 | nn.Conv2d(3, 96, kernel_size=12, stride=4), # conv0 (224x224x3) -> (54x54x96) 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False), # conv1 (54x54x96) -> (54x54x256) 39 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn1 (54x54x256) 40 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), # pool1 (54x54x256) -> (27x27x256) 41 | nn.ReLU(inplace=True), 42 | 43 | nn.Conv2d(256, 384, kernel_size=3, padding=1, bias=False), # conv2 (27x27x256) -> (27x27x384) 44 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn2 (27x27x384) 45 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), # pool2 (27x27x384) -> (14x14x384) 46 | nn.ReLU(inplace=True), 47 | 48 | nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False), # conv3 (14x14x384) -> (14x14x384) 49 | nn.BatchNorm2d(384, eps=1e-4, momentum=0.9), # bn3 (14x14x384) 50 | nn.ReLU(inplace=True), 51 | 52 | nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False), # conv4 (14x14x384) -> (14x14x256) 53 | nn.BatchNorm2d(256, eps=1e-4, momentum=0.9), # bn4 (14x14x256) 54 | nn.MaxPool2d(kernel_size=3, stride=2), # pool4 (14x14x256) -> (6x6x256) 55 | nn.ReLU(inplace=True), 56 | ) 57 | self.classifier = nn.Sequential( 58 | nn.Linear(256 * 6 * 6, 4096, bias=False), # fc0 59 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc0 60 | nn.ReLU(inplace=True), 61 | nn.Linear(4096, 4096, bias=False), # fc1 62 | nn.BatchNorm1d(4096, eps=1e-4, momentum=0.9), # bnfc1 63 | nn.ReLU(inplace=True), 64 | nn.Linear(4096, 
num_classes), # fct 65 | ) 66 | 67 | for m in self.modules(): 68 | if isinstance(m, (nn.Conv2d, nn.Linear)): 69 | fan_in, k_size = (m.in_channels, m.kernel_size[0] * m.kernel_size[1]) if isinstance(m, nn.Conv2d) \ 70 | else (m.in_features, 1) 71 | n = k_size * fan_in 72 | m.weight.data.normal_(0, math.sqrt(2. / n)) 73 | if hasattr(m, 'bias') and m.bias is not None: 74 | m.bias.data.fill_(0) 75 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 76 | m.weight.data.fill_(1) 77 | m.bias.data.zero_() 78 | 79 | def forward(self, x): 80 | x = self.features(x) 81 | x = x.view(x.size(0), 256 * 6 * 6) 82 | x = self.classifier(x) 83 | return x 84 | 85 | 86 | def alexnet_bn(**kwargs): 87 | r"""AlexNet model with batch-norm layers. 88 | Model configuration based on the AlexNet DoReFa example in `TensorPack 89 | <https://github.com/tensorpack/tensorpack/blob/master/examples/DoReFa-Net/alexnet-dorefa.py>` 90 | """ 91 | model = AlexNetBN(**kwargs) 92 | return model 93 | -------------------------------------------------------------------------------- /distiller/models/imagenet/mobilenet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from math import floor 18 | import torch.nn as nn 19 | 20 | __all__ = ['mobilenet', 'mobilenet_025', 'mobilenet_050', 'mobilenet_075'] 21 | 22 | 23 | class MobileNet(nn.Module): 24 | def __init__(self, channel_multiplier=1.0, min_channels=8): 25 | super(MobileNet, self).__init__() 26 | 27 | if channel_multiplier <= 0: 28 | raise ValueError('channel_multiplier must be > 0') 29 | 30 | def conv_bn_relu(n_ifm, n_ofm, kernel_size, stride=1, padding=0, groups=1): 31 | return [ 32 | nn.Conv2d(n_ifm, n_ofm, kernel_size, stride=stride, padding=padding, groups=groups, bias=False), 33 | nn.BatchNorm2d(n_ofm), 34 | nn.ReLU(inplace=True) 35 | ] 36 | 37 | def depthwise_conv(n_ifm, n_ofm, stride): 38 | return nn.Sequential( 39 | *conv_bn_relu(n_ifm, n_ifm, 3, stride=stride, padding=1, groups=n_ifm), 40 | *conv_bn_relu(n_ifm, n_ofm, 1, stride=1) 41 | ) 42 | 43 | base_channels = [32, 64, 128, 256, 512, 1024] 44 | self.channels = [max(floor(n * channel_multiplier), min_channels) for n in base_channels] 45 | 46 | self.model = nn.Sequential( 47 | nn.Sequential(*conv_bn_relu(3, self.channels[0], 3, stride=2, padding=1)), 48 | depthwise_conv(self.channels[0], self.channels[1], 1), 49 | depthwise_conv(self.channels[1], self.channels[2], 2), 50 | depthwise_conv(self.channels[2], self.channels[2], 1), 51 | depthwise_conv(self.channels[2], self.channels[3], 2), 52 | depthwise_conv(self.channels[3], self.channels[3], 1), 53 | depthwise_conv(self.channels[3], self.channels[4], 2), 54 | depthwise_conv(self.channels[4], self.channels[4], 1), 55 | depthwise_conv(self.channels[4], self.channels[4], 1), 56 | depthwise_conv(self.channels[4], self.channels[4], 1), 57 | depthwise_conv(self.channels[4], self.channels[4], 1), 58 | depthwise_conv(self.channels[4], self.channels[4], 1), 59 | depthwise_conv(self.channels[4], self.channels[5], 2), 60 | depthwise_conv(self.channels[5], self.channels[5], 1), 61 | nn.AvgPool2d(7), 62 | ) 63 | self.fc = nn.Linear(self.channels[5], 1000) 64 | 65 | def forward(self, x): 66 | x = self.model(x) 67 | x = x.view(-1, x.size(1)) 68 | x = self.fc(x) 69 | return x 70 | 71 | 72 | def mobilenet_025(): 73 | return MobileNet(channel_multiplier=0.25) 74 | 75 | 76 | def mobilenet_050(): 77 | return MobileNet(channel_multiplier=0.5) 78 | 79 | 80 | def mobilenet_075(): 81 | return MobileNet(channel_multiplier=0.75) 82 | 83 | 84 | def mobilenet(): 85 | return MobileNet() 86 |
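The channel_multiplier/min_channels logic above uniformly thins the network while keeping a floor on layer width. A quick worked check of the 0.25 case, computed by hand from the two lines that build self.channels:

    from math import floor

    base_channels = [32, 64, 128, 256, 512, 1024]
    channel_multiplier, min_channels = 0.25, 8
    channels = [max(floor(n * channel_multiplier), min_channels) for n in base_channels]
    assert channels == [8, 16, 32, 64, 128, 256]  # the widths mobilenet_025() ends up with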
-------------------------------------------------------------------------------- /distiller/models/imagenet/mobilenet_dropout.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Source: https://github.com/mit-han-lab/amc-compressed-models/blob/master/models/mobilenet_v1.py 18 | 19 | The code has been modified to remove code related to AMC. 20 | """ 21 | 22 | __all__ = ['mobilenet_v1_dropout'] 23 | 24 | 25 | import torch 26 | import torch.nn as nn 27 | import math 28 | 29 | 30 | def conv_bn(inp, oup, stride): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 33 | nn.BatchNorm2d(oup), 34 | nn.ReLU(inplace=True) 35 | ) 36 | 37 | 38 | def conv_dw(inp, oup, stride): 39 | return nn.Sequential( 40 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 41 | nn.BatchNorm2d(inp), 42 | nn.ReLU(inplace=True), 43 | 44 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 45 | nn.BatchNorm2d(oup), 46 | nn.ReLU(inplace=True), 47 | ) 48 | 49 | 50 | class MobileNet(nn.Module): 51 | def __init__(self, n_class=1000, channel_multiplier=1.): 52 | super(MobileNet, self).__init__() 53 | in_planes = int(32 * channel_multiplier) 54 | a = int(64 * channel_multiplier) 55 | cfg = [a, (a*2, 2), a*2, (a*4, 2), a*4, (a*8, 2), a*8, a*8, a*8, a*8, a*8, (a*16, 2), a*16] 56 | 57 | self.conv1 = conv_bn(3, in_planes, stride=2) 58 | self.features = self._make_layers(in_planes, cfg, conv_dw) 59 | #self.dropout = nn.Dropout(0.2) 60 | self.classifier = nn.Sequential( 61 | nn.Dropout(0.2), 62 | nn.Linear(cfg[-1], n_class), 63 | ) 64 | 65 | self._initialize_weights() 66 | 67 | def forward(self, x): 68 | x = self.conv1(x) 69 | x = self.features(x) 70 | x = x.mean(3).mean(2) # global average pooling 71 | #x = self.dropout(x) 72 | x = self.classifier(x) 73 | return x 74 | 75 | def _make_layers(self, in_planes, cfg, layer): 76 | layers = [] 77 | for x in cfg: 78 | out_planes = x if isinstance(x, int) else x[0] 79 | stride = 1 if isinstance(x, int) else x[1] 80 | layers.append(layer(in_planes, out_planes, stride)) 81 | in_planes = out_planes 82 | return nn.Sequential(*layers) 83 | 84 | def _initialize_weights(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] *
m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | if m.bias is not None: 90 | m.bias.data.zero_() 91 | elif isinstance(m, nn.BatchNorm2d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | elif isinstance(m, nn.Linear): 95 | n = m.weight.size(1) 96 | m.weight.data.normal_(0, 0.01) 97 | m.bias.data.zero_() 98 | 99 | 100 | def mobilenet_v1_dropout_025(): 101 | return MobileNet(channel_multiplier=0.25) 102 | 103 | 104 | def mobilenet_v1_dropout_050(): 105 | return MobileNet(channel_multiplier=0.5) 106 | 107 | 108 | def mobilenet_v1_dropout_075(): 109 | return MobileNet(channel_multiplier=0.75) 110 | 111 | 112 | def mobilenet_v1_dropout(): 113 | return MobileNet() 114 | -------------------------------------------------------------------------------- /distiller/models/imagenet/resnet_earlyexit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision.models as models 5 | from torchvision.models.resnet import Bottleneck 6 | from torchvision.models.resnet import BasicBlock 7 | 8 | 9 | __all__ = ['resnet50_earlyexit'] 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class ResNetEarlyExit(models.ResNet): 19 | 20 | def __init__(self, block, layers, num_classes=1000): 21 | super(ResNetEarlyExit, self).__init__(block, layers, num_classes) 22 | 23 | # Define early exit layers 24 | self.conv1_exit0 = nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True) 25 | self.conv2_exit0 = nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True) 26 | self.conv1_exit1 = nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True) 27 | self.fc_exit0 = nn.Linear(147 * block.expansion, num_classes) 28 | self.fc_exit1 = nn.Linear(192 * block.expansion, num_classes) 29 | 30 | def forward(self, x): 31 | x = self.conv1(x) 32 | x = self.bn1(x) 33 | x = self.relu(x) 34 | x = self.maxpool(x) 35 | 36 | x = self.layer1(x) 37 | 38 | # Add early exit layers 39 | exit0 = self.avgpool(x) 40 | exit0 = self.conv1_exit0(exit0) 41 | exit0 = self.conv2_exit0(exit0) 42 | exit0 = self.avgpool(exit0) 43 | exit0 = exit0.view(exit0.size(0), -1) 44 | exit0 = self.fc_exit0(exit0) 45 | 46 | x = self.layer2(x) 47 | 48 | # Add early exit layers 49 | exit1 = self.conv1_exit1(x) 50 | exit1 = self.avgpool(exit1) 51 | exit1 = exit1.view(exit1.size(0), -1) 52 | exit1 = self.fc_exit1(exit1) 53 | 54 | x = self.layer3(x) 55 | x = self.layer4(x) 56 | 57 | x = self.avgpool(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.fc(x) 60 | 61 | # return a list of probabilities 62 | output = [] 63 | output.append(exit0) 64 | output.append(exit1) 65 | output.append(x) 66 | return output 67 | 68 | 69 | def resnet50_earlyexit(pretrained=False, **kwargs): 70 | """Constructs a ResNet-50 model. 71 | """ 72 | model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs) 73 | return model 74 | -------------------------------------------------------------------------------- /distiller/models/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """This package contains MNIST image classification models for pytorch""" 18 | 19 | from .simplenet_mnist import * -------------------------------------------------------------------------------- /distiller/models/mnist/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/mnist/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/mnist/__pycache__/simplenet_mnist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/models/mnist/__pycache__/simplenet_mnist.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/models/mnist/simplenet_mnist.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | #      http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """An implementation of a trivial MNIST model. 18 |   19 | The original network definition is sourced here: https://github.com/pytorch/examples/blob/master/mnist/main.py 20 | """ 21 | 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | 26 | __all__ = ['simplenet_mnist', 'simplenet_v2_mnist'] 27 | 28 | 29 | class Simplenet(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 33 | self.relu1 = nn.ReLU(inplace=False) 34 | self.pool1 = nn.MaxPool2d(2, 2) 35 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 36 | self.relu2 = nn.ReLU(inplace=False) 37 | self.pool2 = nn.MaxPool2d(2, 2) 38 | self.fc1 = nn.Linear(4*4*50, 500) 39 | self.relu3 = nn.ReLU(inplace=False) 40 | self.fc2 = nn.Linear(500, 10) 41 | 42 | def forward(self, x): 43 | x = self.pool1(self.relu1(self.conv1(x))) 44 | x = self.pool2(self.relu2(self.conv2(x))) 45 | x = x.view(x.size(0), -1) 46 | x = self.relu3(self.fc1(x)) 47 | x = self.fc2(x) 48 | return x 49 | 50 | 51 | class Simplenet_v2(nn.Module): 52 | """ 53 | This is Simplenet but with only one small Linear layer, instead of two Linear layers, 54 | one of which is large. 55 | 26K parameters. 
56 | python compress_classifier.py ${MNIST_PATH} --arch=simplenet_mnist --vs=0 --lr=0.01 57 | 58 | ==> Best [Top1: 98.970 Top5: 99.970 Sparsity:0.00 Params: 26000 on epoch: 54] 59 | """ 60 | def __init__(self): 61 | super().__init__() 62 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 63 | self.relu1 = nn.ReLU(inplace=False) 64 | self.pool1 = nn.MaxPool2d(2, 2) 65 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 66 | self.relu2 = nn.ReLU(inplace=False) 67 | self.pool2 = nn.MaxPool2d(2, 2) 68 | self.avgpool = nn.AvgPool2d(4, stride=1) 69 | self.fc = nn.Linear(50, 10) 70 | 71 | def forward(self, x): 72 | x = self.pool1(self.relu1(self.conv1(x))) 73 | x = self.pool2(self.relu2(self.conv2(x))) 74 | x = self.avgpool(x) 75 | x = x.view(x.size(0), -1) 76 | x = self.fc(x) 77 | return x 78 | 79 | 80 | def simplenet_mnist(): 81 | model = Simplenet() 82 | return model 83 | 84 | def simplenet_v2_mnist(): 85 | model = Simplenet_v2() 86 | return model -------------------------------------------------------------------------------- /distiller/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .eltwise import * 18 | from .grouping import * 19 | from .matmul import * 20 | from .rnn import * 21 | from .aggregate import Norm 22 | 23 | __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul', 24 | 'Concat', 'Chunk', 'Split', 'Stack', 25 | 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm', 26 | 'Norm'] 27 | -------------------------------------------------------------------------------- /distiller/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/aggregate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/aggregate.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/eltwise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/eltwise.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/grouping.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/grouping.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/matmul.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/matmul.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/rnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/rnn.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/__pycache__/tsvd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/modules/__pycache__/tsvd.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/modules/aggregate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Norm(nn.Module): 6 | """ 7 | A module wrapper for vector/matrix norm 8 | """ 9 | def __init__(self, p='fro', dim=None, keepdim=False): 10 | super(Norm, self).__init__() 11 | self.p = p 12 | self.dim = dim 13 | self.keepdim = keepdim 14 | 15 | def forward(self, x: torch.Tensor): 16 | return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim) 17 | -------------------------------------------------------------------------------- /distiller/modules/eltwise.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | class EltwiseAdd(nn.Module): 21 | def __init__(self, inplace=False): 22 | super(EltwiseAdd, self).__init__() 23 | 24 | self.inplace = inplace 25 | 26 | def forward(self, *input): 27 | res = input[0] 28 | if self.inplace: 29 | for t in input[1:]: 30 | res += t 31 | else: 32 | for t in input[1:]: 33 | res = res + t 34 | return res 35 | 36 | 37 | class EltwiseMult(nn.Module): 38 | def __init__(self, inplace=False): 39 | super(EltwiseMult, self).__init__() 40 | self.inplace = inplace 41 | 42 | def forward(self, *input): 43 | res = input[0] 44 | if self.inplace: 45 | for t in input[1:]: 46 | res *= t 47 | else: 48 | for t in input[1:]: 49 | res = res * t 50 | return res 51 | 52 | 53 | class EltwiseDiv(nn.Module): 54 | def __init__(self, inplace=False): 55 | super(EltwiseDiv, self).__init__() 56 | self.inplace = inplace 57 | 58 | def forward(self, x: torch.Tensor, y): 59 | if self.inplace: 60 | return x.div_(y) 61 | return x.div(y) 62 | 63 | -------------------------------------------------------------------------------- /distiller/modules/grouping.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class Concat(nn.Module): 22 | def __init__(self, dim=0): 23 | super(Concat, self).__init__() 24 | self.dim = dim 25 | 26 | def forward(self, *seq): 27 | return torch.cat(seq, dim=self.dim) 28 | 29 | 30 | class Chunk(nn.Module): 31 | def __init__(self, chunks, dim=0): 32 | super(Chunk, self).__init__() 33 | self.chunks = chunks 34 | self.dim = dim 35 | 36 | def forward(self, tensor): 37 | return tensor.chunk(self.chunks, dim=self.dim) 38 | 39 | 40 | class Split(nn.Module): 41 | def __init__(self, split_size_or_sections, dim=0): 42 | super(Split, self).__init__() 43 | self.split_size_or_sections = split_size_or_sections 44 | self.dim = dim 45 | 46 | def forward(self, tensor): 47 | return torch.split(tensor, self.split_size_or_sections, dim=self.dim) 48 | 49 | 50 | class Stack(nn.Module): 51 | def __init__(self, dim=0): 52 | super(Stack, self).__init__() 53 | self.dim = dim 54 | 55 | def forward(self, seq): 56 | return torch.stack(seq, dim=self.dim) 57 | -------------------------------------------------------------------------------- /distiller/modules/matmul.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | class Matmul(nn.Module): 21 | """ 22 | A wrapper module for matmul operation between 2 tensors. 23 | """ 24 | def __init__(self): 25 | super(Matmul, self).__init__() 26 | 27 | def forward(self, a: torch.Tensor, b: torch.Tensor): 28 | return a.matmul(b) 29 | 30 | 31 | class BatchMatmul(nn.Module): 32 | """ 33 | A wrapper module for torch.bmm operation between 2 tensors. 34 | """ 35 | def __init__(self): 36 | super(BatchMatmul, self).__init__() 37 | 38 | def forward(self, a: torch.Tensor, b: torch.Tensor): 39 | return torch.bmm(a, b)
-------------------------------------------------------------------------------- /distiller/modules/tsvd.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2019 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | """Truncated-SVD module. 18 | 19 | For an example of how truncated-SVD can be used, see this Jupyter notebook: 20 | https://github.com/NervanaSystems/distiller/blob/master/jupyter/truncated_svd.ipynb 21 | 22 | """ 23 | import torch 24 | import torch.nn as nn 25 | 26 | 27 | def truncated_svd(W, l): 28 | """Compress the weight matrix W of an inner product (fully connected) layer using truncated SVD. 29 | 30 | For the original implementation (MIT license), see Faster-RCNN: 31 | https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py 32 | We replaced numpy operations with pytorch operations (so that we can leverage the GPU). 33 | 34 | Arguments: 35 | W: N x M weights matrix 36 | l: number of singular values to retain 37 | Returns: 38 | Ul, SV: matrices such that W \approx Ul @ SV 39 | """ 40 | 41 | U, s, V = torch.svd(W, some=True) 42 | 43 | Ul = U[:, :l] 44 | sl = s[:l] 45 | V = V.t() 46 | Vl = V[:l, :] 47 | 48 | SV = torch.mm(torch.diag(sl), Vl) 49 | return Ul, SV 50 | 51 | 52 | class TruncatedSVD(nn.Module): 53 | def __init__(self, replaced_gemm, gemm_weights, preserve_ratio): 54 | super().__init__() 55 | self.replaced_gemm = replaced_gemm 56 | print("W = {}".format(gemm_weights.shape)) 57 | self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0))) 58 | print("U = {}".format(self.U.shape)) 59 | 60 | self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda() 61 | self.fc_u.weight.data = self.U 62 | 63 | print("SV = {}".format(self.SV.shape)) 64 | self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda() 65 | self.fc_sv.weight.data = self.SV 66 | 67 | def forward(self, x): 68 | x = self.fc_sv.forward(x)  # project onto the truncated right factor 69 | x = self.fc_u.forward(x)   # then expand with the truncated left factor 70 | return x 71 |
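A minimal CPU-only sketch of the helper above (the 256x512 shape and the rank 128 are illustrative assumptions; the TruncatedSVD module itself additionally copies the two factors into replacement nn.Linear layers and moves them to the GPU):

    import torch
    from distiller.modules.tsvd import truncated_svd

    W = torch.randn(256, 512)           # weight matrix of a hypothetical Linear layer
    Ul, SV = truncated_svd(W, l=128)    # keep the 128 largest singular values
    print(Ul.shape, SV.shape)           # torch.Size([256, 128]) torch.Size([128, 512])
    print(torch.dist(W, Ul @ SV))       # Frobenius error of the rank-128 approximation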
19 | """ 20 | 21 | from .magnitude_pruner import MagnitudeParameterPruner 22 | from .automated_gradual_pruner import AutomatedGradualPruner, \ 23 | L1RankedStructureParameterPruner_AGP, \ 24 | L2RankedStructureParameterPruner_AGP, \ 25 | ActivationAPoZRankedFilterPruner_AGP, \ 26 | ActivationMeanRankedFilterPruner_AGP, \ 27 | GradientRankedFilterPruner_AGP, \ 28 | RandomRankedFilterPruner_AGP, \ 29 | BernoulliFilterPruner_AGP 30 | from .level_pruner import SparsityLevelParameterPruner 31 | from .sensitivity_pruner import SensitivityPruner 32 | from .splicing_pruner import SplicingPruner 33 | from .structure_pruner import StructureParameterPruner 34 | from .ranked_structures_pruner import L1RankedStructureParameterPruner, \ 35 | L2RankedStructureParameterPruner, \ 36 | ActivationAPoZRankedFilterPruner, \ 37 | ActivationMeanRankedFilterPruner, \ 38 | GradientRankedFilterPruner, \ 39 | RandomRankedFilterPruner, \ 40 | RandomLevelStructureParameterPruner, \ 41 | BernoulliFilterPruner, \ 42 | FMReconstructionChannelPruner 43 | from .baidu_rnn_pruner import BaiduRNNPruner 44 | from .greedy_filter_pruning import greedy_pruner 45 | 46 | del magnitude_pruner 47 | del automated_gradual_pruner 48 | del level_pruner 49 | del sensitivity_pruner 50 | del structure_pruner 51 | del ranked_structures_pruner 52 | -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/automated_gradual_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/automated_gradual_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/baidu_rnn_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/baidu_rnn_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/greedy_filter_pruning.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/greedy_filter_pruning.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/level_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/level_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/magnitude_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/magnitude_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/ranked_structures_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/ranked_structures_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/sensitivity_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/sensitivity_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/splicing_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/splicing_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/__pycache__/structure_pruner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/pruning/__pycache__/structure_pruner.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/pruning/baidu_rnn_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | from .level_pruner import SparsityLevelParameterPruner 19 | from distiller.utils import * 20 | 21 | import distiller 22 | 23 | class BaiduRNNPruner(_ParameterPruner): 24 | """An element-wise pruner for RNN networks. 25 | 26 | Narang, Sharan & Diamos, Gregory & Sengupta, Shubho & Elsen, Erich. (2017). 27 | Exploring Sparsity in Recurrent Neural Networks. 
28 | (https://arxiv.org/abs/1704.05119) 29 | 30 | This implementation differs slightly from the algorithm in the original paper: 31 | the paper adjusts the pruning rate at training-step granularity, while 32 | Distiller controls the pruning rate at epoch granularity. 33 | 34 | Equation (1): 35 | 36 | 2 * q * freq 37 | start_slope = ------------------------------------------------------- 38 | 2 * (ramp_itr - start_itr ) + 3 * (end_itr - ramp_itr ) 39 | 40 | 41 | Pruning algorithm (1): 42 | 43 | if current itr < ramp itr then 44 | threshold = start_slope * (current_itr - start_itr + 1) / freq 45 | else 46 | threshold = (start_slope * (ramp_itr - start_itr + 1) + 47 | ramp_slope * (current_itr - ramp_itr + 1)) / freq 48 | end if 49 | 50 | mask = abs(param) < threshold 51 | """ 52 | 53 | def __init__(self, name, q, ramp_epoch_offset, ramp_slope_mult, weights): 54 | # Initialize the pruner, using a configuration that originates from the 55 | # schedule YAML file. 56 | super(BaiduRNNPruner, self).__init__(name) 57 | self.params_names = weights 58 | assert self.params_names 59 | 60 | # This is the 'q' value that appears in equation (1) of the paper 61 | self.q = q 62 | # This is the number of epochs to wait after starting_epoch, before we 63 | # begin ramping up the pruning rate. 64 | # In other words, between epochs 'starting_epoch' and 'starting_epoch' + 65 | # self.ramp_epoch_offset the pruning slope is 'self.start_slope'. After 66 | # that, the slope is 'self.ramp_slope'. 67 | self.ramp_epoch_offset = ramp_epoch_offset 68 | self.ramp_slope_mult = ramp_slope_mult 69 | self.ramp_slope = None 70 | self.start_slope = None 71 | 72 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 73 | if param_name not in self.params_names: 74 | return 75 | 76 | starting_epoch = meta['starting_epoch'] 77 | current_epoch = meta['current_epoch'] 78 | ending_epoch = meta['ending_epoch'] 79 | freq = meta['frequency'] 80 | 81 | ramp_epoch = self.ramp_epoch_offset + starting_epoch 82 | 83 | # Calculate the start slope 84 | if self.start_slope is None: 85 | # We want to calculate these values only once, and then cache them. 86 | self.start_slope = (2 * self.q * freq) / (2*(ramp_epoch - starting_epoch) + 3*(ending_epoch - ramp_epoch)) 87 | self.ramp_slope = self.start_slope * self.ramp_slope_mult 88 | 89 | if current_epoch < ramp_epoch: 90 | eps = self.start_slope * (current_epoch - starting_epoch + 1) / freq 91 | else: 92 | eps = (self.start_slope * (ramp_epoch - starting_epoch + 1) + 93 | self.ramp_slope * (current_epoch - ramp_epoch + 1)) / freq 94 | 95 | # After computing the threshold, we can create the mask 96 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, eps) 97 | -------------------------------------------------------------------------------- /distiller/pruning/level_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | 21 | class SparsityLevelParameterPruner(_ParameterPruner): 22 | """Prune to an exact pruning level specification. 23 | 24 | This pruner is very similar to MagnitudeParameterPruner, but instead of 25 | specifying an absolute threshold for pruning, you specify a target sparsity 26 | level (expressed as a fraction: 0.5 means 50% sparsity). 27 | 28 | To find the threshold, we view the tensor as one large 1D vector, take the k 29 | elements with the smallest absolute values, and use the largest of those as the threshold. 30 | """ 31 | 32 | def __init__(self, name, levels, **kwargs): 33 | super(SparsityLevelParameterPruner, self).__init__(name) 34 | self.levels = levels 35 | assert self.levels 36 | 37 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 38 | # If there is a specific sparsity level specified for this module, then 39 | # use it. Otherwise try to use the default level ('*'). 40 | desired_sparsity = self.levels.get(param_name, self.levels.get('*', 0)) 41 | if desired_sparsity == 0: 42 | return 43 | 44 | self.prune_level(param, param_name, zeros_mask_dict, desired_sparsity) 45 | 46 | @staticmethod 47 | def prune_level(param, param_name, zeros_mask_dict, desired_sparsity): 48 | bottomk, _ = torch.topk(param.abs().view(-1), int(desired_sparsity * param.numel()), largest=False, sorted=True) 49 | threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away 50 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 51 | -------------------------------------------------------------------------------- /distiller/pruning/magnitude_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | 20 | 21 | class MagnitudeParameterPruner(_ParameterPruner): 22 | """This is the most basic magnitude-based pruner. 23 | 24 | This pruner supports configuring a scalar threshold for each layer. 25 | A default threshold is mandatory and is used for layers without an explicit 26 | threshold setting. 27 | 28 | """ 29 | def __init__(self, name, thresholds, **kwargs): 30 | """ 31 | Usually, a Pruner is constructed by the compression schedule parser 32 | found in distiller/config.py. 33 | The constructor is passed a dictionary of thresholds, as explained below. 34 | 35 | Args: 36 | name (string): the name of the pruner (used only for debug) 37 | thresholds (dict): a dictionary of thresholds, with the key being the 38 | parameter name. 39 | A special key, '*', represents the default threshold value.
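            For example, a hypothetical configuration that prunes
            'layer1.conv1.weight' with a threshold of 0.05 and every other
            weight tensor with the default 0.02:

                thresholds = {'*': 0.02, 'layer1.conv1.weight': 0.05}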
If 40 | set_param_mask is invoked on a parameter tensor that does not have 41 | an explicit entry in the 'thresholds' dictionary, then this default 42 | value is used. 43 | Currently it is mandatory to include a '*' key in 'thresholds'. 44 | """ 45 | super(MagnitudeParameterPruner, self).__init__(name) 46 | assert thresholds is not None 47 | # Make sure there is a default threshold to use 48 | assert '*' in thresholds 49 | self.thresholds = thresholds 50 | 51 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 52 | threshold = self.thresholds.get(param_name, self.thresholds['*']) 53 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 54 | -------------------------------------------------------------------------------- /distiller/pruning/pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | import distiller 19 | 20 | class _ParameterPruner(object): 21 | """Base class for all pruners. 22 | 23 | Arguments: 24 | name: pruner name is used mainly for debugging. 25 | """ 26 | def __init__(self, name): 27 | self.name = name 28 | 29 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 30 | raise NotImplementedError 31 | 32 | def threshold_model(model, threshold): 33 | """Threshold an entire model using the provided threshold. 34 | 35 | This function prunes weights only (biases are left untouched). 36 | """ 37 | for name, p in model.named_parameters(): 38 | if 'weight' in name: 39 | mask = distiller.threshold_mask(p.data, threshold) 40 | p.data.mul_(mask) 41 | -------------------------------------------------------------------------------- /distiller/pruning/sensitivity_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .pruner import _ParameterPruner 18 | import distiller 19 | import torch 20 | 21 | class SensitivityPruner(_ParameterPruner): 22 | """Use the algorithm from "Learning both Weights and Connections for Efficient 23 | Neural Networks" - https://arxiv.org/pdf/1506.02626v3.pdf 24 | 25 | I.e.: "The pruning threshold is chosen as a quality parameter multiplied 26 | by the standard deviation of a layer's weights."
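    Concretely, threshold = sensitivity * std(W). For example (hypothetical
    numbers), with a sensitivity of 2.0 and a layer whose weight standard
    deviation is 0.05, every weight with absolute value below 0.1 is masked
    to zero.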
27 | In this code, the "quality parameter" is referred to as "sensitivity" and 28 | is based on the values learned from performing sensitivity analysis. 29 | 30 | Note that this implementation deviates slightly from the algorithm Song Han 31 | describes in his PhD dissertation: here the threshold value is set only 32 | once, whereas the dissertation describes a threshold that grows at each 33 | iteration. That scheme requires n+1 hyper-parameters (n being the number of 34 | pruning iterations): the initial threshold and the threshold increase (delta) 35 | at each pruning iteration. 36 | The implementation that follows takes advantage of the fact that as pruning 37 | progresses, more weights are pulled toward zero, and therefore the threshold 38 | "traps" more weights. Thus, we can use fewer hyper-parameters and achieve the 39 | same results. 40 | """ 41 | 42 | def __init__(self, name, sensitivities, **kwargs): 43 | super(SensitivityPruner, self).__init__(name) 44 | self.sensitivities = sensitivities 45 | 46 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 47 | if not hasattr(param, 'stddev'): 48 | param.stddev = torch.std(param).item() 49 | 50 | if param_name not in self.sensitivities: 51 | if '*' not in self.sensitivities: 52 | return 53 | else: 54 | sensitivity = self.sensitivities['*'] 55 | else: 56 | sensitivity = self.sensitivities[param_name] 57 | 58 | threshold = param.stddev * sensitivity 59 | 60 | # After computing the threshold, we can create the mask 61 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold) 62 | -------------------------------------------------------------------------------- /distiller/pruning/splicing_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | 18 | from .pruner import _ParameterPruner 19 | import torch 20 | import logging 21 | msglogger = logging.getLogger() 22 | 23 | 24 | class SplicingPruner(_ParameterPruner): 25 | """A pruner that both prunes and splices connections. 26 | 27 | The idea of pruning and splicing working in tandem was first proposed in the following 28 | NIPS paper from Intel Labs China in 2016: 29 | Dynamic Network Surgery for Efficient DNNs, Yiwen Guo, Anbang Yao, Yurong Chen. 30 | NIPS 2016, https://arxiv.org/abs/1608.04493. 31 | 32 | A SplicingPruner works best with a Dynamic Network Surgery schedule.
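    Schematically, both thresholds are derived from the (cached) mean and std of
    the weight tensor's absolute values -- a sketch with hypothetical multipliers:

        t_low = (mean + std * sensitivity) * low_thresh_mult   # e.g. low_thresh_mult = 0.9
        t_hi  = (mean + std * sensitivity) * hi_thresh_mult    # e.g. hi_thresh_mult = 1.1

    Weights whose magnitude falls below t_low are masked out, weights at or above
    t_hi are spliced back in, and weights in between keep their previous mask.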
33 | The original Caffe code from the authors of the paper is available here: 34 | https://github.com/yiwenguo/Dynamic-Network-Surgery/blob/master/src/caffe/layers/compress_conv_layer.cpp 35 | """ 36 | 37 | def __init__(self, name, sensitivities, low_thresh_mult, hi_thresh_mult, sensitivity_multiplier=0): 38 | """Arguments: per-parameter 'sensitivities' (a '*' key sets the default value); 39 | the pruning/splicing threshold multipliers; and an optional linear per-epoch sensitivity growth factor.""" 40 | super(SplicingPruner, self).__init__(name) 41 | self.sensitivities = sensitivities 42 | self.low_thresh_mult = low_thresh_mult 43 | self.hi_thresh_mult = hi_thresh_mult 44 | self.sensitivity_multiplier = sensitivity_multiplier 45 | 46 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 47 | if param_name not in self.sensitivities: 48 | if '*' not in self.sensitivities: 49 | return 50 | else: 51 | sensitivity = self.sensitivities['*'] 52 | else: 53 | sensitivity = self.sensitivities[param_name] 54 | 55 | if not hasattr(param, '_std'): 56 | # Compute the mean and standard-deviation once, and cache them. 57 | param._std = torch.std(param.abs()).item() 58 | param._mean = torch.mean(param.abs()).item() 59 | 60 | if self.sensitivity_multiplier > 0: 61 | # Linearly growing sensitivity - for now this is hard-coded 62 | starting_epoch = meta['starting_epoch'] 63 | current_epoch = meta['current_epoch'] 64 | sensitivity *= (current_epoch - starting_epoch) * self.sensitivity_multiplier + 1 65 | 66 | threshold_low = (param._mean + param._std * sensitivity) * self.low_thresh_mult 67 | threshold_hi = (param._mean + param._std * sensitivity) * self.hi_thresh_mult 68 | 69 | if zeros_mask_dict[param_name].mask is None: 70 | zeros_mask_dict[param_name].mask = torch.ones_like(param) 71 | 72 | # This code implements equation (3) of the "Dynamic Network Surgery" paper: 73 | # 74 | # 0 if a > |W| 75 | # h(W) = mask if a <= |W| < b 76 | # 1 if b <= |W| 77 | # 78 | # h(W) is the so-called "network surgery function". 79 | # mask is the mask used in the previous iteration. 80 | # a and b are the low and high thresholds, respectively. 81 | # We followed the example implementation from Yiwen Guo in Caffe, and used the 82 | # weight tensor's starting mean and std. 83 | # This is very similar to the initialization performed by distiller.SensitivityPruner. 84 | 85 | mask = zeros_mask_dict[param_name].mask 86 | zeros, ones = torch.tensor([0]).type(mask.type()), torch.tensor([1]).type(mask.type()) 87 | weights_abs = param.abs() 88 | new_mask = torch.where(threshold_low > weights_abs, zeros, mask) 89 | new_mask = torch.where(threshold_hi <= weights_abs, ones, new_mask) 90 | zeros_mask_dict[param_name].mask = new_mask 91 | -------------------------------------------------------------------------------- /distiller/pruning/structure_pruner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | import logging 18 | from .pruner import _ParameterPruner 19 | import distiller 20 | msglogger = logging.getLogger() 21 | 22 | class StructureParameterPruner(distiller.GroupThresholdMixin, _ParameterPruner): 23 | """Prune parameter structures. 24 | 25 | Pruning criterion: average L1-norm. If the average L1-norm (absolute value) of the elements 26 | in the structure is below the threshold, then the structure is pruned. 27 | 28 | We use the average, instead of the plain L1-norm, because we don't want the threshold to depend on 29 | the structure size. 30 | """ 31 | def __init__(self, name, model, reg_regims, threshold_criteria): 32 | super(StructureParameterPruner, self).__init__(name) 33 | self.name = name 34 | self.model = model 35 | self.reg_regims = reg_regims 36 | self.threshold_criteria = threshold_criteria 37 | assert threshold_criteria in ["Max", "Mean_Abs"] 38 | 39 | def set_param_mask(self, param, param_name, zeros_mask_dict, meta): 40 | if param_name not in self.reg_regims.keys(): 41 | return 42 | 43 | group_type = self.reg_regims[param_name][1] 44 | threshold = self.reg_regims[param_name][0] 45 | zeros_mask_dict[param_name].mask = self.group_threshold_mask(param, 46 | group_type, 47 | threshold, 48 | self.threshold_criteria) 49 | -------------------------------------------------------------------------------- /distiller/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | from .quantizer import Quantizer 18 | from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \ 19 | LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \ 20 | RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode 21 | from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer 22 | 23 | del quantizer 24 | del range_linear 25 | del clipped_linear 26 | -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/clipped_linear.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/clipped_linear.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/q_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/q_utils.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/quantizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/quantizer.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/range_linear.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/range_linear.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/quantization/__pycache__/sim_bn_fold.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/quantization/__pycache__/sim_bn_fold.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from .l1_regularizer import L1Regularizer 18 | from .group_regularizer import GroupLassoRegularizer, GroupVarianceRegularizer 19 | 20 | del l1_regularizer 21 | del group_regularizer 22 | -------------------------------------------------------------------------------- /distiller/regularization/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/__pycache__/drop_filter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/drop_filter.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/__pycache__/group_regularizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/group_regularizer.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/__pycache__/l1_regularizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/l1_regularizer.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/__pycache__/regularizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/distiller/regularization/__pycache__/regularizer.cpython-36.pyc -------------------------------------------------------------------------------- /distiller/regularization/drop_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .regularizer import _Regularizer 6 | 7 | 8 | class Conv2dWithMask(nn.Conv2d): 9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 10 | padding=0, dilation=1, groups=1, bias=True): 11 | 12 | super(Conv2dWithMask, self).__init__( 13 | in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 14 | padding=padding, dilation=dilation, groups=groups, bias=bias) 15 | 16 | self.test_mask = None 17 | self.p_mask = 1.0 18 | self.frequency = 16 19 | 20 | def forward(self, input): 21 | if self.training: 22 | self.frequency -= 1 23 | if 
self.frequency == 0: 24 | sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels) 25 | param = self.weight 26 | l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1) 27 | mask = torch.tensor(sample) 28 | mask = mask.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() 29 | mask = mask.view(self.weight.shape).to(param.device) 30 | mask = mask.type(param.type()) 31 | masked_weights = self.weight * mask 32 | masked_l1norm = masked_weights.detach().view(param.size(0), -1).norm(p=1, dim=1) 33 | pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item() 34 | pruning_factor = max(0.2, pruning_factor) 35 | weight = masked_weights / pruning_factor 36 | self.frequency = 16 37 | else: 38 | weight = self.weight 39 | else: 40 | weight = self.weight 41 | return F.conv2d(input, weight, self.bias, self.stride, 42 | self.padding, self.dilation, self.groups) 43 | 44 | 45 | # replaces all conv2d layers in target's model with 'Conv2dWithMask' 46 | def replace_conv2d(container): 47 | for name, module in container.named_children(): 48 | if isinstance(module, nn.Conv2d): 49 | print("replacing: ", name) 50 | new_module = Conv2dWithMask(in_channels=module.in_channels, 51 | out_channels=module.out_channels, 52 | kernel_size=module.kernel_size, padding=module.padding, 53 | stride=module.stride, bias=module.bias is not None) 54 | setattr(container, name, new_module) 55 | replace_conv2d(module) 56 | 57 | 58 | class DropFilterRegularizer(_Regularizer): 59 | def __init__(self, name, model, reg_regims, threshold_criteria=None): 60 | super().__init__(name, model, reg_regims, threshold_criteria) 61 | replace_conv2d(model) 62 | -------------------------------------------------------------------------------- /distiller/regularization/l1_regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # 16 | 17 | """L1-norm regularization""" 18 | 19 | import torch 20 | import math 21 | import numpy as np 22 | import distiller 23 | from .regularizer import _Regularizer, EPSILON 24 | 25 | class L1Regularizer(_Regularizer): 26 | def __init__(self, name, model, reg_regims, threshold_criteria=None): 27 | super(L1Regularizer, self).__init__(name, model, reg_regims, threshold_criteria) 28 | 29 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): 30 | if param_name in self.reg_regims: 31 | strength = self.reg_regims[param_name] 32 | regularizer_loss += L1Regularizer.__add_l1(param, strength) 33 | 34 | return regularizer_loss 35 | 36 | def threshold(self, param, param_name, zeros_mask_dict): 37 | """Soft threshold for L1-norm regularizer""" 38 | if self.threshold_criteria is None or param_name not in self.reg_regims: 39 | return 40 | 41 | strength = self.reg_regims[param_name] 42 | zeros_mask_dict[param_name].mask = distiller.threshold_mask(param.data, threshold=strength) 43 | zeros_mask_dict[param_name].is_regularization_mask = True 44 | 45 | @staticmethod 46 | def __add_l1(var, strength): 47 | return var.abs().sum() * strength 48 | 49 | @staticmethod 50 | def __add_l1_all(loss, model, reg_regims): 51 | for param_name, param in model.named_parameters(): 52 | if param_name in reg_regims.keys(): 53 | strength = reg_regims[param_name] 54 | loss += L1Regularizer.__add_l1(param, strength) 55 | return loss -------------------------------------------------------------------------------- /distiller/regularization/regularizer.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | EPSILON = 1e-8 18 | 19 | class _Regularizer(object): 20 | def __init__(self, name, model, reg_regims, threshold_criteria): 21 | """Regularization base class. 22 | 23 | Args: 24 | reg_regims: regularization regimen.
A dictionary of the form 25 | reg_regims[param_name] = [lambda, structure-type] 26 | """ 27 | self.name = name 28 | self.model = model 29 | self.reg_regims = reg_regims 30 | self.threshold_criteria = threshold_criteria 31 | 32 | def loss(self, param, param_name, regularizer_loss, zeros_mask_dict): 33 | raise NotImplementedError 34 | 35 | def threshold(self, param, param_name, zeros_mask_dict): 36 | raise NotImplementedError 37 | -------------------------------------------------------------------------------- /example~: -------------------------------------------------------------------------------- 1 | channels: 2 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 3 | show_channel_urls: true 4 | -------------------------------------------------------------------------------- /ljt_mobilefacenet_y2.sh: -------------------------------------------------------------------------------- 1 | # BUSID 2 | python main.py --mode sa \ 3 | --model mobilefacenet_y2_ljt \ 4 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \ 5 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \ 6 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \ 7 | --data_source company \ 8 | --fpgm 9 | 10 | python main.py --mode prune \ 11 | --model mobilefacenet_y2_ljt \ 12 | --save_model_pt \ 13 | --data_source company \ 14 | --fpgm \ 15 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \ 16 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \ 17 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \ 18 | 19 | 20 | # TYLG 21 | python main.py --mode sa \ 22 | --model mobilefacenet_y2_ljt \ 23 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \ 24 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \ 25 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \ 26 | --data_source company \ 27 | --fpgm 28 | 29 | python main.py --mode prune \ 30 | --model mobilefacenet_y2_ljt \ 31 | --save_model_pt \ 32 | --data_source company \ 33 | --fpgm \ 34 | --best_model_path /home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \ 35 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \ 36 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \ 37 | 38 | # XCH 39 | python main.py --mode sa \ 40 | --model mobilefacenet_y2_ljt \ 41 | --best_model_path
/home/yeluyue/lz/model/2020-08-23-08-09_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-ID-INTRA-MIDDLE-30-INTER-90-HARD_MobileFaceNety2-d512-k-9-8_model_iter-125993_TYLG-0.7520_PadMaskYTBYGlassM280-0.9104_BusIDPhoto-0.7489-noamp.pth \ 42 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/le_re_0.4_144x122 \ 43 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/id_life_image_list_bmppair.txt \ 44 | --data_source company \ 45 | --fpgm -------------------------------------------------------------------------------- /ljt_shufflefacenet_v2.sh: -------------------------------------------------------------------------------- 1 | # BUSID 2 | python main.py --mode sa \ 3 | --model shufflefacenet_v2_ljt \ 4 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \ 5 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/le_re_0.4_144x122 \ 6 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/BusID/id_life_image_list_bmppair.txt \ 7 | --data_source company \ 8 | --fpgm 9 | 10 | # XCH 11 | python main.py --mode sa \ 12 | --model shufflefacenet_v2_ljt \ 13 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \ 14 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/le_re_0.4_144x122 \ 15 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/XCH/id_life_image_list_bmppair.txt \ 16 | --data_source company \ 17 | --fpgm 18 | 19 | # TYLG 20 | python main.py --mode sa \ 21 | --model shufflefacenet_v2_ljt \ 22 | --best_model_path /home/yeluyue/lz/model/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \ 23 | --test_root_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/le_re_0.4_144x122 \ 24 | --img_list_label_path /home/yeluyue/lz/dataset/200914_data_model_ljt/TYLG/id_life_image_list_bmppair.txt \ 25 | --data_source company \ 26 | --fpgm 27 | 28 | python main.py --mode prune \ 29 | --model shufflefacenet_v2_ljt \ 30 | --save_model_pt \ 31 | --data_source company \ 32 | --fpgm \ 33 | --best_model_path /home/linx/model/ljt/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth \ 34 | --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_144x122 \ 35 | --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: LinX 3 | # datetime: 2019/11/12 1:38 PM 4 | import argparse 5 | from sensitivity_analysis import sensitivity_analysis 6 | from pruning import prune 7 | from quantization import quantization 8 | from datetime import datetime 9 | import os 10 | from
train_module.train_with_insight_face import face_learner 11 | 12 | 13 | def get_time(): 14 | return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-') 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description='prune face recognition model') 19 | 20 | # Pruning 21 | parser.add_argument('--lr', default=0.01, type=float, help='retrain learning rate, usually 1/10 of the original training rate') 22 | parser.add_argument('--weight_decay', default=4e-5, type=float, help='weight decay') 23 | parser.add_argument('--save_pruned_model_root', default='work_space/models/pruned_model/', help='folder for the pruned model definition and saved model files') 24 | parser.add_argument('--momentum', default=0.9, type=float) 25 | 26 | parser.add_argument('--epoch', default=30, type=int, help='number of epochs to retrain after pruning') 27 | parser.add_argument('--head_path', default=None, help='training head') 28 | parser.add_argument('--device', default='cuda:0') 29 | 30 | parser.add_argument('--print_freq', type=int, default=1, help='how often (in iterations) to print accuracy info') 31 | parser.add_argument('--save_model_pt', default=False, action='store_true', help='whether to save a .pt file') 32 | 33 | parser.add_argument('--batch_size', type=int, default=256) 34 | parser.add_argument('--embedding_size', type=int, default=512) 35 | parser.add_argument('--pruned_save_model_path', default='work_space/pruned_model', help='where to save the pruned model') 36 | parser.add_argument('--sensitivity_csv_path', default='work_space/sensitivity_data', help='where to save the CSV files produced by pruning sensitivity analysis') 37 | 38 | # The following arguments must be set for every run 39 | parser.add_argument('--mode', default=None, choices=['prune', 'quantization', 'test', 'sa', 'finetune'], help='prune = prune only, quantization = quantize, ' 40 | 'sa = sensitivity analysis, ' 41 | 'finetune = prune and then finetune') 42 | 43 | parser.add_argument('--best_model_path', default=None, help='path to the best trained model checkpoint, which will be pruned') 44 | parser.add_argument('--test_root_path', default=None, help='root path of the test set') 45 | parser.add_argument('--img_list_label_path', default=None, help='path to the test-set pair list') 46 | parser.add_argument('--model', default=None, 47 | choices=['mobilefacenet', 'resnet34', 'mobilefacenet_y2', 'resnet50', 'resnet100', 48 | 'mobilefacenet_lzc', 'mobilenetv3', 'resnet34_lzc', 'resnet50_imagenet', 49 | 'mobilefacenet_y2_ljt', 'shufflefacenet_v2_ljt', 'resnet_50_ljt', 'resnet_100_ljt' 50 | ], help='which model to prune') 51 | 52 | parser.add_argument('--is_save', default=False, action='store_true', help='whether to save the model file') 53 | 54 | parser.add_argument('--from_data_parallel', action='store_true', default=False, help='whether the model was trained with multiple GPUs (DataParallel)') 55 | 56 | parser.add_argument('--data_source', choices=['lfw', 'company', 'company_zkx'], default='None', 57 | help='which test set to use: company -> zy\'s resnet50 and resnet100, company_zkx -> zkx\'s mobilefacenet_y2') 58 | 59 | parser.add_argument('--fpgm', action='store_true', default=False, help='whether to prune with the geometric-median (FPGM) criterion') 60 | parser.add_argument('--hrank', action='store_true', default=False, help='whether to prune with HRank') 61 | parser.add_argument('--rank_path', default='./work_space/rank_conv/', help='HRank configuration files') 62 | 63 | parser.add_argument('--yaml_path', default='yaml_file/auto_yaml.yaml', help='pruning configuration file') 64 | 65 | parser.add_argument('--cal_flops_and_forward', default=False, action='store_true', help='whether to measure FLOPs and forward time') 66 | 67 | parser.add_argument('--test_batch_size', type=int, default=256) 68 | 69 | # The following arguments are needed for quantization 70 | parser.add_argument('--quantize-mode', type=str, choices=['symmetric', 'asymmetric-signed', 'unsigned'], default='symmetric', 71 | help='quantization mode: map weight values to a symmetric, signed-asymmetric, or unsigned-asymmetric range') 72 | 73 | parser.add_argument('--fp16', action='store_true', default=False, help='use half-precision (FP16) quantization; if set, the modes above are ignored') 74 | 75 | parser.add_argument('--input_size', type=int,
default=112, help='input image size') 76 | 77 | parser.add_argument('--quantized_save_model_path', default='work_space/quantized_model', help='where to save the quantized model') 78 | 79 | # Arguments needed for finetuning 80 | parser.add_argument('--pruned_checkpoint', type=str, default=None, help='path to the pruned model checkpoint') 81 | parser.add_argument('--train_data_path', type=str, default=None, help='path to the training set used for finetuning') 82 | parser.add_argument('--milestones', type=str, default='12,15,18', help='epochs at which the learning rate is decayed') 83 | parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size') 84 | parser.add_argument('--pin_memory', type=bool, default=True) 85 | parser.add_argument('--num_workers', type=int, default=4) 86 | parser.add_argument('--work_path', type=str, default='work_space/finetune', help='directory for files produced during training') 87 | parser.add_argument('--finetune_pruned_model', action='store_true', default=False, help='finetune the pruned model') 88 | 89 | args = parser.parse_args() 90 | 91 | if args.mode == 'prune': 92 | prune(args) 93 | 94 | elif args.mode == 'sa': 95 | sensitivity_analysis(args) 96 | 97 | elif args.mode == 'quantization': 98 | quantization(args) 99 | 100 | elif args.mode == 'finetune': 101 | args.work_path = os.path.join(args.work_path, get_time()) 102 | os.mkdir(args.work_path) 103 | 104 | args.log_path = os.path.join(args.work_path, 'log') 105 | args.save_path = os.path.join(args.work_path, 'save') 106 | args.model_path = os.path.join(args.work_path, 'model') 107 | 108 | os.mkdir(args.log_path) 109 | os.mkdir(args.save_path) 110 | os.mkdir(args.model_path) 111 | 112 | args.log_path = os.path.join(args.log_path, get_time()) 113 | args.milestones = list(map(int, args.milestones.split(','))) 114 | 115 | learner = face_learner(args) 116 | learner.train(args) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /model_define/__pycache__/MobileFaceNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileFaceNet.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/MobileFaceNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileFaceNet.cpython-37.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/MobileNetV3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileNetV3.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/MobileNetV3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/MobileNetV3.cpython-37.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/load_state_dict.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/load_state_dict.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/load_state_dict.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/load_state_dict.cpython-37.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/model_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/model_resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/model_resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model_define/__pycache__/resnet50_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/__pycache__/resnet50_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/load_state_dict.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: LinX 3 | # datetime: 2019/10/28 6:09 PM 4 | 5 | from model_define.model_resnet import fresnet50_v3, fresnet100_v3, fresnet34_v3 6 | from collections import OrderedDict 7 | from model_define.MobileFaceNet import MobileFaceNet 8 | from model_define.model import MobileFaceNet_sor, ResNet34, MobileFaceNet_y2 9 | from model_define.MobileNetV3 import MobileNetV3_Large 10 | from model_define.resnet50_imagenet import resnet_50 11 | from model_define.mobilefacenet_y2_ljt.mobilefacenet_big import MobileFaceNet_y2_ljt 12 | from model_define.shufflefacenet_v2_ljt.ShuffleFaceNetV2 import ShuffleFaceNetV2 13 | from model_define.resnet_50_ljt.resnet_50 import fresnet50_v3_ljt 14 | from model_define.resnet_100_ljt.resnet_100 import fresnet100_v3_ljt 15 | import torch 16 | 17 | 18 | def load_state_dict(args): 19 | 20 | if args.model == 'mobilefacenet': 21 | model = MobileFaceNet_sor(args.embedding_size) 22 | 23 | elif args.model
== 'resnet34': 24 | model = fresnet34_v3((112, 112)) 25 | args.lr = 0.001 26 | 27 | elif args.model == 'mobilefacenet_y2': 28 | model = MobileFaceNet_y2(args.embedding_size) 29 | 30 | elif args.model == 'resnet50': 31 | model = fresnet50_v3() 32 | 33 | elif args.model == 'resnet50_imagenet': 34 | model = resnet_50() 35 | 36 | elif args.model == 'resnet100': 37 | model = fresnet100_v3() 38 | 39 | elif args.model == 'mobilefacenet_lzc': 40 | model = MobileFaceNet(128, (5, 5)) 41 | 42 | elif args.model == 'mobilenetv3': 43 | model = MobileNetV3_Large(4) 44 | 45 | elif args.model == 'resnet34_lzc': 46 | model = fresnet34_v3((80, 80)) 47 | 48 | elif args.model == 'mobilefacenet_y2_ljt': 49 | model = MobileFaceNet_y2_ljt() 50 | 51 | elif args.model == 'shufflefacenet_v2_ljt': 52 | model = ShuffleFaceNetV2(512, 2.0, (144, 122)) 53 | 54 | elif args.model == 'resnet_50_ljt': 55 | model = fresnet50_v3_ljt() 56 | 57 | elif args.model == 'resnet_100_ljt': 58 | model = fresnet100_v3_ljt() 59 | 60 | else: 61 | print('This model is not supported for pruning!') 62 | 63 | print('load {}\'s checkpoint'.format(args.model)) 64 | 65 | state_dict = torch.load(args.best_model_path, map_location=args.device) 66 | if args.from_data_parallel: 67 | new_state_dict = OrderedDict() 68 | for k, v in state_dict.items(): 69 | new_state_dict[k[7:]] = v 70 | state_dict = new_state_dict 71 | model.load_state_dict(state_dict) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /model_define/mobilefacenet_y2_ljt/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: linx 3 | # datetime 2020/9/14 6:24 PM 4 | -------------------------------------------------------------------------------- /model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc -------------------------------------------------------------------------------- /model_define/mobilefacenet_y2_ljt/common_utility.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | import torch.nn.functional as F 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | ''' 9 | Network common utilities 10 | ''' 11 | 12 | 13 | def l2_norm(input, axis=1): 14 | norm = torch.norm(input, 2, axis, True) 15 | output = torch.div(input, norm) 16 | return output 17 | 18 | 19 | class L2Norm(Module): 20
/model_define/mobilefacenet_y2_ljt/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/14 6:24 PM
4 |
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/mobilefacenet_big.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/mobilefacenet_y2_ljt/__pycache__/network_elems.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/mobilefacenet_y2_ljt/common_utility.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | import torch.nn.functional as F
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | '''
9 | Network's common utilities
10 | '''
11 |
12 |
13 | def l2_norm(input, axis=1):
14 |     norm = torch.norm(input, 2, axis, True)
15 |     output = torch.div(input, norm)
16 |     return output
17 |
18 |
19 | class L2Norm(Module):
20 |     def forward(self, input):
21 |         return F.normalize(input)
22 |
23 |
24 | class Flatten(Module):
25 |     def forward(self, input):
26 |         return input.view(input.size(0), -1)
27 |         # for onnx model convert
28 |         # batch_size = np.array(input.size(0))
29 |         # batch_size.astype(dtype=np.int32)
30 |         # return input.view(batch_size, 512)
31 |
32 |
33 | class GDC(nn.Module):
34 |     def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
35 |         super(GDC, self).__init__()
36 |         self.conv = nn.Conv2d(in_c, out_channels=out_c, kernel_size=kernel,
37 |                               groups=groups, stride=stride, padding=padding, bias=False)
38 |         self.bn = nn.BatchNorm2d(out_c)
39 |
40 |     def forward(self, x):
41 |         x = self.conv(x)
42 |         x = self.bn(x)
43 |         return x
44 |
45 |
46 | def Get_Conv_Size(height, width, kernel, stride, padding, rpt_num):  # flattened H*W after rpt_num convs
47 |     conv_h = height
48 |     conv_w = width
49 |     for _ in range(rpt_num):
50 |         conv_h = int((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
51 |         conv_w = int((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
52 |     return conv_h * conv_w
53 |
54 |
55 | def Get_Conv_kernel(height, width, kernel, stride, padding, rpt_num):  # (H, W) after rpt_num convs, rounding up
56 |     conv_h = height
57 |     conv_w = width
58 |     for _ in range(rpt_num):
59 |         conv_h = math.ceil((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
60 |         conv_w = math.ceil((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
61 |         print(conv_h, conv_w)
62 |     return (int(conv_h), int(conv_w))
63 |
64 |
65 | def Get_Conv_kernel_floor(height, width, kernel, stride, padding, rpt_num):  # (H, W) after rpt_num convs, rounding down
66 |     conv_h = height
67 |     conv_w = width
68 |     for _ in range(rpt_num):
69 |         conv_h = math.floor((conv_h - kernel[0] + 2 * padding[0]) / stride[0] + 1)
70 |         conv_w = math.floor((conv_w - kernel[1] + 2 * padding[1]) / stride[1] + 1)
71 |         print(conv_h, conv_w)
72 |     # print(conv_h, conv_w)
73 |     return (int(conv_h), int(conv_w))
74 |
75 |
76 | def get_dense_ave_pooling_size(height, width, block_config):
77 |     size1 = Get_Conv_kernel(height, width, (3, 3), (2, 2), (1, 1), 1)
78 |     size2 = Get_Conv_kernel(size1[0], size1[1], (2, 2), (2, 2), (1, 1), 1)
79 |     # print(size1)
80 |     size3 = Get_Conv_kernel(size2[0], size2[1], (2, 2), (2, 2), (0, 0), len(block_config) - 1)
81 |     return size3
82 |
83 |
84 | def get_shuffle_ave_pooling_size(height, width, using_pool=False):
85 |     first_batch_num = 2
86 |     if using_pool:
87 |         first_batch_num = 3
88 |
89 |     size1 = Get_Conv_kernel(height, width, (3, 3), (2, 2), (0, 0), first_batch_num)
90 |     # print(size1)
91 |     size2 = Get_Conv_kernel(size1[0], size1[1], (2, 2), (2, 2), (0, 0), 2)
92 |     return size2
93 |
94 |
95 | def get_ghost_dw_size(height, width):
96 |     size1 = Get_Conv_kernel_floor(height, width, (3, 3), (2, 2), (3 // 2, 3 // 2), 3)
97 |     size1 = Get_Conv_kernel_floor(size1[0], size1[1], (5, 5), (2, 2), (5 // 2, 5 // 2), 2)
98 |     return size1
99 |
100 |
101 | if __name__ == "__main__":
102 |     # get_dense_ave_pooling_size(112,112, [1,2,3,4])
103 |     # print("="*10)
104 |     # get_shuffle_ave_pooling_size(112,112,True)
105 |     # print("=" * 10)
106 |     # get_shuffle_ave_pooling_size(112, 112, False)
107 |     # print("=" * 10)
108 |     Get_Conv_kernel_floor(112, 112, (3, 3), (2, 2), (1, 1), 4)
109 |     print("=" * 10)
110 |
--------------------------------------------------------------------------------
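Note: all of the size helpers above iterate the standard convolution output formula, out = floor((in - k + 2p) / s) + 1, once per spatial dimension. A quick worked check in plain Python, mirroring the floor variant:

import math

def conv_out(size, kernel, stride, padding):
    # out = floor((in - k + 2p) / s) + 1
    return math.floor((size - kernel + 2 * padding) / stride) + 1

h = w = 112
for _ in range(4):  # as in Get_Conv_kernel_floor(112, 112, (3, 3), (2, 2), (1, 1), 4)
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
print(h, w)  # 7 7  (112 -> 56 -> 28 -> 14 -> 7)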
/model_define/resnet_100_ljt/__pycache__/resnet_100.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/resnet_100_ljt/__pycache__/resnet_100.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/resnet_50_ljt/__pycache__/resnet_50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/resnet_50_ljt/__pycache__/resnet_50.cpython-36.pyc
--------------------------------------------------------------------------------
/model_define/shufflefacenet_v2_ljt/__pycache__/ShuffleFaceNetV2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/model_define/shufflefacenet_v2_ljt/__pycache__/ShuffleFaceNetV2.cpython-36.pyc
--------------------------------------------------------------------------------
/mytest.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/18 11:31 AM
4 | from model_define.shufflefacenet_v2_ljt.ShuffleFaceNetV2 import ShuffleFaceNetV2
5 | import torch
6 |
7 | state_dict = torch.load('/home/linx/model/ljt/2020-09-15-10-53_CombineMargin-ljt914-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_ShuffleFaceNetA-2.0-d512_model_iter-76608_TYLG-0.7319_XCHoldClean-0.8198_BusIDPhoto-0.7310-noamp.pth')
8 |
9 | for k, v in state_dict.items():
10 |     print(k, v.shape)
11 |     pass
12 |
13 | net = ShuffleFaceNetV2(512, '2.0', (144, 122))
14 | print(net)
--------------------------------------------------------------------------------
/pruning.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/10 2:21 PM
4 |
5 | import matplotlib
6 | import torch
7 | import distiller
8 | import os
9 | from commom_utils.utils import cal_flops, test_speed, get_time
10 | from model_define.load_state_dict import load_state_dict
11 | from test_module.test_on_diverse_dataset import test
12 | import numpy as np
13 | matplotlib.use('agg')
14 |
15 |
16 | def prune(args):
17 |     model = load_state_dict(args)
18 |     model = model.cuda()
19 |     epoch = 0
20 |
21 |     # acc = test(args, model)
22 |     # print('Accuracy before pruning: {}'.format(acc))
23 |
24 |     if args.fpgm:
25 |         print('Pruning with the FPGM (geometric median) algorithm')
26 |     if args.cal_flops_and_forward:
27 |         flops, params = cal_flops(model, [1, 3, 144, 122])
28 |         forward_time = test_speed(model, [1, 3, 144, 122])
29 |         print('Before pruning: forward time {}ms, flops={}, params={}'.format(forward_time, flops, params))
30 |     conv_dict = {}
31 |
32 |     if args.hrank:
33 |         # load the precomputed HRank filter ranks (rank_conv*.npy), keyed by conv weight name
34 |         cnt = 1
35 |         layer_name = 'conv1.conv.weight'
36 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
37 |         cnt += 1
38 |         for key, value in model.block_info.items():
39 |             if value == 1:
40 |                 layer_name = '{}.conv.conv.weight'.format(key)
41 |                 conv_dict[layer_name] = np.load(args.rank_path+'rank_conv'+str(cnt)+'.npy')
42 |                 cnt += 1
43 |                 layer_name = '{}.conv_dw.conv.weight'.format(key)
44 |                 conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
45 |                 cnt += 1
46 |                 layer_name = '{}.project.conv.weight'.format(key)
47 |                 conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
48 |                 cnt += 1
49 |             else:
50 |                 for j in range(value):
51 |                     layer_name = '{}.model.{}.conv.conv.weight'.format(key, j)
52 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
53 |                     cnt += 1
54 |                     layer_name = '{}.model.{}.conv_dw.conv.weight'.format(key, j)
55 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
56 |                     cnt += 1
57 |                     layer_name = '{}.model.{}.project.conv.weight'.format(key, j)
58 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
59 |                     cnt += 1
60 |         layer_name = 'conv_6_sep.conv.weight'
61 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
62 |         cnt += 1
63 |         layer_name = 'conv_6_dw.conv.weight'
64 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
65 |         cnt += 1
66 |         print(len(conv_dict))
67 |
68 |     model.train()
69 |     compression_scheduler = distiller.config.file_config(model, None, args.yaml_path)
70 |     compression_scheduler.on_epoch_begin(epoch, fpgm=args.fpgm, HRank=args.hrank, conv_index=conv_dict)
71 |     compression_scheduler.on_minibatch_begin(epoch, minibatch_id=0, minibatches_per_epoch=0)
72 |
73 |     if args.save_model_pt:
74 |         torch.save(model.state_dict(), os.path.join(args.pruned_save_model_path, 'model_{}.pt'.format(args.model)))
75 |         print('Model saved!')
76 |     config_model_arr = []
77 |     if args.model != 'shufflefacenet_v2_ljt':
78 |         for k, v in model.state_dict().items():
79 |             if len(v.shape) == 4 and k.find('downsample') == -1 and k.find('fc') == -1:
80 |                 config_model_arr.append(v.shape[0])
81 |     else:
82 |         for k, v in model.state_dict().items():
83 |             if k.find('branch_main.0.weight') != -1:
84 |                 config_model_arr.append(v.shape[0])
85 |     print('Output channels per layer: {}, {} entries in total'.format(config_model_arr, len(config_model_arr)))
86 |
87 |     if not os.path.exists('work_space/layers_out/'):
88 |         os.mkdir('work_space/layers_out/')
89 |
90 |     with open('work_space/layers_out/out.txt', 'w') as f:
91 |         f.write(str(config_model_arr))
92 |     # print(model)
93 |     acc = test(args, model)
94 |     print('Accuracy after pruning: {}'.format(acc))
95 |
96 |     flops, params = cal_flops(model, [1, 3, 112, 112])
97 |     forward_time = 0
98 |     print('After pruning: forward time {}ms, flops={}, params={}'.format(forward_time, flops, params))
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
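Note: stripped of the HRank bookkeeping, prune() reduces to a handful of distiller calls; a condensed sketch (model construction, args and save path illustrative):

import distiller
import torch

model.train()  # masks are applied to a model in train mode
scheduler = distiller.config.file_config(model, None, 'yaml_file/auto_yaml.yaml')
# The epoch-0 policies fire here: each pruner ranks filters and records masks
# (this repo's patched scheduler also accepts FPGM/HRank ranking via kwargs).
scheduler.on_epoch_begin(0, fpgm=False, HRank=False, conv_index={})
scheduler.on_minibatch_begin(0, minibatch_id=0, minibatches_per_epoch=0)
torch.save(model.state_dict(), 'work_space/model_pruned.pt')  # illustrative path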
/pruning_analysis_tools/plot_csv.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/9 10:03 AM
4 |
5 | import pandas as pd
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def plot(sensitivity, save_root):
11 |     """
12 |     Draw one sensitivity curve per layer, grouping seven curves per figure.
13 |     :param sensitivity: dict mapping layer name -> list of accuracies, one per
14 |                         sparsity level in np.arange(0, 0.95, 0.1)
15 |     :param save_root: directory where the line plots are saved
16 |     :return:
17 |     """
18 |     i = 0
19 |     for k, v in sensitivity.items():
20 |
21 |         if i % 7 == 6:
22 |             plt.ylabel('top1')
23 |             plt.xlabel('sparsity')
24 |             plt.title(str(i-6) + '-' + str(i+1) + ' Pruning Sensitivity')
25 |             plt.legend(loc='lower center',
26 |                        ncol=2, mode="expand", borderaxespad=0.)
27 |             plt.savefig(save_root + '/' + str(i-6) + '-' + str(i+1) + '.png', format='png')
28 |             plt.close()
29 |
30 |         sense = v
31 |         name = k
32 |         sparsities = np.arange(0, 0.95, 0.1)
33 |         plt.plot(sparsities, sense, label=name)
34 |         i += 1
35 |
36 |
37 | def plot_csv(csv_name, save_root):
38 |     """
39 |     plot sensitivity
40 |     :param csv_name: CSV file produced by distiller.sensitivities_to_csv
41 |     :param save_root: directory where the figures are saved
42 |     :return:
43 |     """
44 |     data = pd.read_csv(csv_name)
45 |     sensitivity = {}
46 |
47 |     for x in data.values:
48 |         sensitivity[x[0]] = []
49 |
50 |     for x in data.values:
51 |         sensitivity[x[0]].append(x[2])
52 |
53 |     plot(sensitivity, save_root)
54 |
55 |
56 | if __name__ == '__main__':
57 |     plot_csv('/home/yeluyue/lz/program/compression_tool/work_space/sensitivity_data/sensitivity_mobilefacenet_y2_2020-09-04-13-08.csv', '/home/yeluyue/lz/program/compression_tool/work_space/sensitivity_data')
58 |
--------------------------------------------------------------------------------
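Note: judging by the x[0]/x[2] indexing above and by distiller.sensitivities_to_csv in sensitivity_analysis.py, each CSV row is assumed to be (layer name, sparsity, accuracy, ...); plot_csv groups rows by layer and plots column 2 over the fixed sparsity grid. A usage sketch (file name illustrative):

# Assumed row layout, one row per (layer, sparsity) pair:
#   conv1.weight,0.0,0.981,...
#   conv1.weight,0.1,0.979,...
from pruning_analysis_tools.plot_csv import plot_csv

plot_csv('work_space/sensitivity_data/sensitivity_mobilefacenet_y2.csv',  # illustrative file
         'work_space/sensitivity_data')  # writes one PNG per group of seven layers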
/quantization.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/12 1:38 PM
4 | import torch
5 | from distiller.quantization.range_linear import PostTrainLinearQuantizer
6 | from model_define.load_state_dict import load_state_dict
7 | from test_module.test_on_diverse_dataset import test
8 | from commom_utils.utils import get_time
9 | import os
10 |
11 |
12 | def quantization(args):
13 |     model = load_state_dict(args)
14 |     model = model.to(args.device)
15 |
16 |     acc = test(args, model)
17 |     print('Accuracy before quantization: {}'.format(acc))
18 |
19 |     quantizer = PostTrainLinearQuantizer(model, fp16=args.fp16)
20 |
21 |     quantizer.prepare_model(torch.rand([1, 3, args.input_size, args.input_size]))
22 |
23 |     torch.save(model.state_dict(), os.path.join(args.quantized_save_model_path, args.model + get_time() + '.pt'))
24 |     print('Model saved')
25 |
26 |     acc = test(args, model)
27 |     print('Accuracy after quantization: {}'.format(acc))
28 |
--------------------------------------------------------------------------------
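Note: the post-training quantization path above condenses to three calls; a hedged sketch (input size and save path illustrative):

import torch
from distiller.quantization.range_linear import PostTrainLinearQuantizer

quantizer = PostTrainLinearQuantizer(model, fp16=True)  # FP16 emulation, as with --fp16
quantizer.prepare_model(torch.rand([1, 3, 112, 112]))   # dummy input to trace the graph
torch.save(model.state_dict(), 'work_space/model_quantized.pt')  # illustrative path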
/quantization.sh:
--------------------------------------------------------------------------------
1 | # Quantization run script
2 | python main.py --mode quantization \
3 |        --model resnet100 \
4 |        --best_model_path work_space/model_train_best/2019-09-29-11-37_SVGArcFace-O1-b0.4s40t1.1_fc_0.4_112x112_2019-09-27-Adult-padSY-Bus_fResNet100v3cv-d512_model_iter-340000.pth \
5 |        --from_data_parallel \
6 |        --data_source company \
7 |        --test_root_path data/test_data/fc_0.4_112x112 \
8 |        --img_list_label_path data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt \
9 |        --test_batch_size 256 \
10 |        --fp16
11 |
12 |
--------------------------------------------------------------------------------
/resnet_ljt.sh:
--------------------------------------------------------------------------------
1 | # ================================ResNet-50==============================================================================
2 |
3 | # TYLG
4 | python main.py --mode sa \
5 |        --model resnet_50_ljt \
6 |        --best_model_path /home/linx/model/ljt/2020-08-11-22-35_CombineMargin-ljt83-m0.9m0.4m0.15s64_le_re_0.4_144x122_2020-07-30-Full-CLEAN-0803-2-MIDDLE-30_fResNet50v3cv-d512_model_iter-76608_TYLG-0.8070_PadMaskYTBYGlassM280-0.9305_BusIDPhoto-0.6541.pth \
7 |        --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_144x122 \
8 |        --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
9 |        --data_source company \
10 |        --fpgm \
11 |        --from_data_parallel
12 |
13 | python main.py --mode prune \
14 |        --model resnet_50_ljt \
15 |        --save_model_pt \
16 |        --data_source company \
17 |        --fpgm \
18 |        --best_model_path /home/linx/model/ljt/2020-06-26-12-13_CombineMargin-ljt-m0.9m0.4m0.15s64_le_re_0.4_112x112_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet50v3cv-d512_model_iter-113680_Idoa-0.8011_IdoaMask-0.8325_TYLG-0.8443.pth \
19 |        --test_root_path /home/linx/dataset/company_test_data/TYLG/le_re_0.4_112x112 \
20 |        --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
21 |        --from_data_parallel
22 |
23 | # ================================ResNet-100==============================================================================
24 | python main.py --mode sa \
25 |        --model resnet_100_ljt \
26 |        --best_model_path /home/linx/model/ljt/2020-06-27-12-59_CombineMargin-zk-O1D1Ls-m0.9m0.4m0.15s64_fc_0.4_144x122_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet100v3cv-d512_model_iter-96628_Idoa-0.8996_IdoaMask-0.9127_TYLG-0.9388.pth \
27 |        --test_root_path /media/linx/B0C6A127C6A0EF32/200914_data_model_ljt/TYLG/fc_0.4_144x122 \
28 |        --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
29 |        --data_source company \
30 |        --fpgm \
31 |        --from_data_parallel
32 |
33 |
34 | python main.py --mode prune \
35 |        --model resnet_100_ljt \
36 |        --save_model_pt \
37 |        --data_source company \
38 |        --fpgm \
39 |        --best_model_path /home/linx/model/ljt/2020-06-27-12-59_CombineMargin-zk-O1D1Ls-m0.9m0.4m0.15s64_fc_0.4_144x122_2020-05-26-PNTMS-CLEAN-MIDDLE-70_fResNet100v3cv-d512_model_iter-96628_Idoa-0.8996_IdoaMask-0.9127_TYLG-0.9388.pth \
40 |        --test_root_path /home/linx/dataset/company_test_data/TYLG/fc_0.4_144x122 \
41 |        --img_list_label_path /home/linx/dataset/company_test_data/TYLG/id_life_image_list_bmppair.txt \
42 |        --from_data_parallel
--------------------------------------------------------------------------------
/sensitivity_analysis.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/9 2:17 PM
4 |
5 | import torch
6 | import os
7 | from collections import OrderedDict
8 | from copy import deepcopy
9 | import distiller
10 | from distiller.scheduler import CompressionScheduler
11 | import numpy as np
12 | from model_define.load_state_dict import load_state_dict
13 | import time
14 | from test_module.test_on_diverse_dataset import test
15 | from datetime import datetime
16 |
17 |
18 | def get_time():
19 |     return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')
20 |
21 |
22 | def perform_sensitivity_analysis(model, net_params, sparsities, args):
23 |
24 |     sensitivities = OrderedDict()
25 |     print('Testing the accuracy of the original model')
26 |     accuracy = test(args, model)
27 |
28 |     print('Original model accuracy: {}'.format(accuracy))
29 |
30 |     if args.fpgm:
31 |         print('Using geometric-median (FPGM) pruning to produce the sensitivity curves')
32 |     conv_dict = {}
33 |     if args.hrank:
34 |         print('Using HRank pruning')
35 |         cnt = 1
36 |         layer_name = 'conv1.conv.weight'
37 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
38 |         cnt += 1
39 |         for key, value in model.block_info.items():
40 |             if value == 1:
41 |                 layer_name = '{}.conv.conv.weight'.format(key)
42 |                 conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
43 |                 cnt += 1
44 |                 layer_name = '{}.conv_dw.conv.weight'.format(key)
45 |                 conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
46 |                 cnt += 1
47 |                 layer_name = '{}.project.conv.weight'.format(key)
48 |                 conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
49 |                 cnt += 1
50 |             else:
51 |                 for j in range(value):
52 |                     layer_name = '{}.model.{}.conv.conv.weight'.format(key, j)
53 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
54 |                     cnt += 1
55 |                     layer_name = '{}.model.{}.conv_dw.conv.weight'.format(key, j)
56 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
57 |                     cnt += 1
58 |                     layer_name = '{}.model.{}.project.conv.weight'.format(key, j)
59 |                     conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
60 |                     cnt += 1
61 |         layer_name = 'conv_6_sep.conv.weight'
62 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
63 |         cnt += 1
64 |         layer_name = 'conv_6_dw.conv.weight'
65 |         conv_dict[layer_name] = np.load(args.rank_path + 'rank_conv' + str(cnt) + '.npy')
66 |         cnt += 1
67 |         print(len(conv_dict))
68 |
69 |     for param_name in net_params:
70 |         if model.state_dict()[param_name].dim() not in [4]:
71 |             continue
72 |
73 |         model_cpy = deepcopy(model)
74 |
75 |         sensitivity = OrderedDict()
76 |
77 |         # Prune each layer in turn at every sparsity level (0.0 -> 0.9, step 0.1) and test accuracy
78 |         for sparsity_level in sparsities:
79 |
80 |             sparsity_level = float(sparsity_level)
81 |
82 |             print(param_name, sparsity_level)
83 |
84 |             pruner = distiller.pruning.L1RankedStructureParameterPruner("sensitivity",
85 |                                                                         group_type="Filters",
86 |                                                                         desired_sparsity=sparsity_level,
87 |                                                                         weights=param_name)
88 |
89 |             policy = distiller.PruningPolicy(pruner, pruner_args=None)
90 |             scheduler = CompressionScheduler(model_cpy)
91 |             scheduler.add_policy(policy, epochs=[0])
92 |
93 |             scheduler.on_epoch_begin(0, fpgm=args.fpgm, HRank=args.hrank, conv_index=conv_dict)
94 |
95 |             scheduler.mask_all_weights()
96 |
97 |             accuracy = test(args, model_cpy)
98 |
99 |             print('Accuracy after pruning at sparsity {}: {}'.format(sparsity_level, accuracy))
100 |
101 |             sensitivity[sparsity_level] = (accuracy, 0, 0)
102 |         sensitivities[param_name] = sensitivity
103 |
104 |     return sensitivities
105 |
106 |
107 | def sensitivity_analysis(args):
108 |
109 |     model = load_state_dict(args)
110 |     model.eval()
111 |     model.cuda()
112 |
113 |     sparsities = np.arange(0.0, 0.95, 0.1)
114 |     which_params = [param_name for param_name, _ in model.named_parameters()]
115 |
116 |     start_time = time.time()
117 |
118 |     sensitivity = perform_sensitivity_analysis(model, which_params, sparsities, args)
119 |
120 |     end_time = time.time()
121 |     print('Sensitivity analysis took {} h in total'.format((end_time - start_time) / 3600))
122 |     # distiller.sensitivities_to_png(sensitivity, 'work_space/sensitivity_data/sensitivity_{}.png'.format(args.model))
123 |     distiller.sensitivities_to_csv(sensitivity, os.path.join(args.sensitivity_csv_path, 'sensitivity_{}_{}.csv'.format(args.model, get_time())))
124 |
125 |
126 |
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
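Note: the analysis yields, per layer, an accuracy at each sparsity level (sensitivity[sparsity_level] = (accuracy, 0, 0) above). A common next step, sketched here with an illustrative helper that is not part of this repo, is to pick for each layer the largest sparsity whose accuracy drop stays within a budget, which matches the shape of the per-layer sparsities in the auto_yaml.yaml files later in this dump:

def pick_sparsity(sensitivities, baseline_acc, max_drop=0.01):
    # sensitivities: {param_name: {sparsity: (accuracy, _, _)}} as built above
    chosen = {}
    for name, curve in sensitivities.items():
        ok = [s for s, (acc, _, _) in curve.items() if baseline_acc - acc <= max_drop]
        chosen[name] = max(ok) if ok else 0.0
    return chosen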
/src/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Data load from image list
3 | 1. with no Normalize
4 | 2. with opencv load image
5 | 3. image is not divided by 255
6 | @19-10-21 mod by ljt: add data augmentation
7 | @19-11-29 mod by zkx: moved the preprocessing pipelines into a dict for lookup
8 | '''
9 |
10 | import os
11 | from src.loader import transforms_V2 as trans
12 | from src.loader.autoaugment import ImageNetPolicy
13 | from torch.utils.data import DataLoader
14 | from src.loader.utility import opencv_loader, read_image_list_test
15 | from src.loader.read_image_list_io import DatasetFromList, DatasetFromListTriplet
16 | import torch
17 |
18 | pwd_path = os.path.dirname(__file__)
19 | from src.dataset import TrainSet, TestSet
20 |
21 |
22 | class DefineTrans:
23 |     def __init__(self, input_size):
24 |         self.image_preprocess = {
25 |             'D1': trans.Compose([
26 |                 trans.ToPILImage(),
27 |                 trans.RandomResizedCrop(input_size, scale=(0.99, 1.01)),
28 |                 trans.ColorJitter(brightness=0.125, contrast=0.125, saturation=0.125),
29 |                 trans.RandomHorizontalFlip(),
30 |                 trans.ToTensor(),
31 |             ]),
32 |             'D2': trans.Compose([
33 |                 trans.ToPILImage(),
34 |                 trans.RandomResizedCrop(input_size, scale=(0.9, 1.1)),
35 |                 trans.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
36 |                 trans.RandomHorizontalFlip(),
37 |                 trans.ToTensor(),
38 |             ]),
39 |             'D3': trans.Compose([
40 |                 trans.ToPILImage(),
41 |                 trans.RandomResizedCrop(input_size, scale=(0.9, 1.1)),
42 |                 trans.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
43 |                 trans.RandomRotation(5),
44 |                 trans.RandomHorizontalFlip(),
45 |                 trans.ToTensor(),
46 |             ]),
47 |             'D9': trans.Compose([
48 |                 trans.ToPILImage(),
49 |                 trans.RandomHorizontalFlip(),
50 |                 ImageNetPolicy(),
51 |                 trans.ToTensor(),
52 |             ]),
53 |             'None': trans.Compose([
54 |                 trans.ToTensor(),
55 |             ])
56 |
57 |             # others
58 |             # trans.RandomPerspective(),
59 |             # trans.RandomErasing(),
60 |             # trans.RandomRotation(5),
61 |             # trans.RandomApply([trans.RandomResizedCrop((width, height), scale=(0.99, 1.01))]),
62 |             # trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
63 |         }
64 |
65 | # Data loading for the training set
66 | def get_train_dataset(img_root_path, image_list_path, data_aug):
67 |     height = int(img_root_path.split('_')[-1].split('x')[0])
68 |     width = int(img_root_path.split('_')[-1].split('x')[1])
69 |
70 |     # @2019-11-29 zkx get input image process method from dict
71 |     input_process = DefineTrans((height, width))
72 |     if data_aug is None:
73 |         train_transform = input_process.image_preprocess['None']
74 |     else:
75 |         train_transform = input_process.image_preprocess[data_aug]
76 |
77 |     ds = DatasetFromList(img_root_path, image_list_path, opencv_loader, train_transform, None)
78 |     class_num = ds[-1][1] + 1
79 |     return ds, class_num
80 |
81 |
82 | def get_train_list_loader(conf):
83 |     train_set = conf.data_mode
84 |     img_root_path = TrainSet[train_set]['root_path']
85 |     img_label_list = TrainSet[train_set]['label_list']
86 |
87 |     print("img_label_list")
88 |     print(img_label_list)
89 |
90 |     patch_info = conf.patch_info
91 |     root_path = '{}/{}'.format(img_root_path, patch_info)
92 |     celeb_ds, celeb_class_num = get_train_dataset(root_path, img_label_list, conf.data_aug)
93 |     print("images:")
94 |     print(len(celeb_ds))
95 |     ds = celeb_ds
96 |     class_num = celeb_class_num
97 |     if conf.distributed:
98 |         train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
99 |     else:
100 |         train_sampler = None
101 |     # @2020-02-25 ljt: added drop_last to avoid the BN error on the final incomplete batch
102 |     loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=(train_sampler is None), pin_memory=conf.pin_memory,
103 |                         num_workers=conf.num_workers, sampler=train_sampler, drop_last=True)
104 |     return loader, class_num, train_sampler
105 |
106 |
107 | # Data loading for the test set
108 | def get_test_dataset(img_root_path, image_list_path):
109 |
110 |     # @2019-11-29 zkx get input image process method from dict
111 |     input_process = DefineTrans((1, 1))
112 |     test_transform = input_process.image_preprocess["None"]
113 |     ds = DatasetFromList(img_root_path, image_list_path, opencv_loader, test_transform, None, read_image_list_test)
114 |     return ds
115 |
116 | def get_batch_test_data(image_root_path, image_list_path, batch_size, num_workers):
117 |     ds = get_test_dataset(image_root_path, image_list_path)
118 |     loader = DataLoader(ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
119 |     return loader
120 |
121 |
--------------------------------------------------------------------------------
/src/loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__init__.py
--------------------------------------------------------------------------------
/src/loader/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/autoaugment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/autoaugment.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/autoaugment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/autoaugment.cpython-37.pyc
--------------------------------------------------------------------------------
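Note: a usage sketch for src/data_loader.py above; the test loader pairs each image tensor with its list-file prefix rather than a class label (paths taken from the shell scripts earlier and illustrative here):

from src.data_loader import get_batch_test_data

loader = get_batch_test_data('data/test_data/fc_0.4_112x112',
                             'data/test_data/fc_0.4_112x112/pair_list/id_life_image_list_bmppair.txt',
                             batch_size=256, num_workers=4)
for imgs, prefixes in loader:  # prefixes are the relative paths from the list file
    pass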
/src/loader/__pycache__/functional.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/functional_V2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional_V2.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/functional_V2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/functional_V2.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/read_image_list_io.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/read_image_list_io.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/read_image_list_io.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/read_image_list_io.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms_V2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms_V2.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/transforms_V2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/transforms_V2.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/utility.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/utility.cpython-36.pyc
--------------------------------------------------------------------------------
/src/loader/__pycache__/utility.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/src/loader/__pycache__/utility.cpython-37.pyc
--------------------------------------------------------------------------------
/src/loader/read_image_list_io.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | '''
3 | What DatasetFromList does:
4 | 1. reads image paths and labels from a list file
5 | 2. overrides __getitem__ so an instance can be indexed with [], returning (image, label)
6 | 3. bundling path listing and image loading into one class neatly decouples the two steps
7 | '''
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | from .utility import read_image_list, read_image_triplet_list
11 |
12 | class DatasetFromList(Dataset):
13 |     """A generic data loader where the image list is arranged in this way: ::
14 |
15 |         class_x/xxx.ext 0
16 |         class_x/xxy.ext 0
17 |         class_x/xxz.ext 0
18 |
19 |         class_y/123.ext 1
20 |         class_y/nsdf3.ext 1
21 |         class_y/asd932_.ext 1
22 |
23 |     Args:
24 |         root (string): Root directory path.
25 |         image_list_path (string) : where to load image list
26 |         loader (callable): A function to load a sample given its path.
27 |         image_list_loader (callable) : A function to read image-label pair or image-image-prefix pair
28 |         transform (callable, optional): A function/transform that takes in
29 |             a sample and returns a transformed version.
30 |             E.g, ``transforms.RandomCrop`` for images.
31 |         target_transform (callable, optional): A function/transform that takes
32 |             in the target and transforms it.
33 |
34 |     Attributes:
35 |         samples (list): List of (sample path, image_index) tuples
36 |     """
37 |
38 |     def __init__(self, root, image_list_path, loader,
39 |                  transform=None, target_transform=None, image_list_loader=read_image_list):
40 |         samples = image_list_loader(root, image_list_path)
41 |         if len(samples) == 0:
42 |             raise(RuntimeError("Found 0 files in image_list: " + image_list_path + "\n"))
43 |
44 |         self.root = root
45 |         self.loader = loader
46 |         self.samples = samples
47 |         self.transform = transform
48 |         self.target_transform = target_transform
49 |
50 |     def __getitem__(self, index):
51 |         """
52 |         Args:
53 |             index (int): Index
54 |
55 |         Variable:
56 |             self.samples (list): [image path, image label]
57 |
58 |         Returns:
59 |             tuple: (sample, target) where target is class_index of the target class.
60 |         """
61 |         path, target = self.samples[index]
62 |         import os
63 |         if not os.path.exists(path):
64 |             print("Not exists..", path)
65 |
66 |         sample = self.loader(path)
67 |         assert isinstance(sample, np.ndarray), path
68 |
69 |         if self.transform is not None:
70 |             sample = self.transform(sample)
71 |         if self.target_transform is not None:
72 |             target = self.target_transform(target)
73 |
74 |         return sample, target
75 |
76 |     def __len__(self):
77 |         return len(self.samples)
78 |
79 |     def __repr__(self):
80 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
81 |         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
82 |         fmt_str += '    Root Location: {}\n'.format(self.root)
83 |         tmp = '    Transforms (if any): '
84 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
85 |         tmp = '    Target Transforms (if any): '
86 |         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
87 |         return fmt_str
88 |
89 |
90 | class DatasetFromListTriplet(DatasetFromList):
91 |     def __init__(self, root, image_list_path, loader,
92 |                  transform=None, target_transform=None, image_list_loader=read_image_triplet_list):
93 |         super(DatasetFromListTriplet, self).__init__(root, image_list_path, loader, transform,
94 |                                                      target_transform, image_list_loader)
95 |
96 |     def __getitem__(self, index):
97 |         """
98 |         Args:
99 |             index (tuple): includes (key, image_index)
100 |         Variable:
101 |             self.samples (dictionary): key(int) image_label, value(list) same label's image path
102 |
103 |         Returns:
104 |             tuple: (sample, target) where target is class_index of the target class.
105 |         """
106 |         target = index[0]
107 |         image_path = self.samples[index[0]][index[1]]
108 |         sample = self.loader(image_path)
109 |
110 |         if self.transform is not None:
111 |             sample = self.transform(sample)
112 |         if self.target_transform is not None:
113 |             target = self.target_transform(target)
114 |
115 |         return sample, target
116 |
117 |     def get_class_num(self):
118 |         return len(self.samples)
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
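Note: a minimal DatasetFromList usage sketch (paths illustrative; the list file holds "relative/path.ext label" lines as documented above):

from src.loader.read_image_list_io import DatasetFromList
from src.loader.utility import opencv_loader

ds = DatasetFromList('data/train_root',      # image root (illustrative)
                     'data/train_list.txt',  # lines of "class_x/xxx.ext 0"
                     opencv_loader)          # loads a numpy BGR image with cv2
img, label = ds[0]
print(len(ds), label)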
/src/loader/utility.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import cv2
3 |
4 | def read_image_list(root_path, image_list_path):
5 |     f = open(image_list_path, 'r')
6 |     # f = open(image_list_path, encoding='utf-8', mode='r')
7 |     data = f.read().splitlines()
8 |     f.close()
9 |
10 |     samples = []
11 |     for line in data:
12 |         sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
13 |         class_index = int(line.split(' ')[1])
14 |         samples.append((sample_path, class_index))
15 |     return samples
16 |
17 | def read_image_triplet_list(root_path, image_list_path):
18 |     f = open(image_list_path, 'r')
19 |     data = f.read().splitlines()
20 |     f.close()
21 |
22 |     label_num = int(data[-1].split(' ')[-1])+1
23 |     samples = {}
24 |     for i in range(label_num):
25 |         samples[i] = []
26 |
27 |     for line in data:
28 |         sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
29 |         class_index = int(line.split(' ')[1])
30 |         samples[class_index].append(sample_path)
31 |     return samples
32 |
33 |
34 | def read_image_list_test(root_path, image_list_path):
35 |     f = open(image_list_path, 'r')
36 |     data = f.read().splitlines()
37 |     f.close()
38 |
39 |     samples = []
40 |     for line in data:
41 |         sample_path = '{}/{}'.format(root_path, line.split(' ')[0])
42 |         image_prefix = line.split(' ')[0]
43 |         samples.append((sample_path, image_prefix))
44 |     return samples
45 |
46 |
47 | def pil_loader(path):
48 |     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
49 |     with open(path, 'rb') as f:
50 |         img = Image.open(f)
51 |         return img.convert('RGB')
52 |         #return img.convert('BGR')
53 |
54 | def opencv_loader(path):
55 |     img = cv2.imread(path)
56 |     return img
57 |
58 |
59 | # def accimage_loader(path):
60 | #     import accimage
61 | #     try:
62 | #         return accimage.Image(path)
63 | #     except IOError:
64 | #         # Potentially a decoding problem, fall back to PIL.Image
65 | #         return pil_loader(path)
66 |
67 |
68 | # def default_loader(path):
69 | #     from torchvision import get_image_backend
70 | #     if get_image_backend() == 'accimage':
71 | #         return accimage_loader(path)
72 | #     else:
73 | #         return pil_loader(path)
--------------------------------------------------------------------------------
/test_class.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/10/29 8:19 PM
4 | import torch
5 | import torch.nn as nn
6 | from tensorboardX import SummaryWriter
7 | from MobileNetV3 import MobileNetV3_Large
8 |
9 | writer = SummaryWriter('/home/user1/linx/program/LightFaceNet/work_space/log')
10 |
11 | model = MobileNetV3_Large(4)
12 |
13 | model = nn.DataParallel(model)
14 |
15 | model.load_state_dict(torch.load('/home/user1/linx/program/LightFaceNet/work_space/models/model_train_best/2019-10-12'
16 |                                  '-16-04_LiveBody_le_0.2_80x80_fake-20190924-train-data_live-0926_MobileNetv3Large'
17 |                                  '-c4_pytorch_iter_14000.pth'))
18 |
19 | for name, param in model.named_parameters():
20 |     if len(param.shape) == 4:
21 |         param = param.view(param.shape[0], -1)
22 |         param = torch.norm(param, dim=1)
23 |         print(param.shape)
24 |         writer.add_histogram(name, param.clone().cpu().data.numpy())
25 |
26 | # from scipy.spatial import distance
27 | # import numpy as np
28 | #
29 | # a = np.array([[1, 2], [3, 4]])
30 | # b = np.array([[2, 3], [4, 5]])
31 | # print(distance.cdist(a, a, metric='euclidean'))
32 |
--------------------------------------------------------------------------------
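Note: test_class.py above reduces each 4-D conv weight to one L2 norm per output filter before histogramming, the same per-filter magnitude that ranked filter pruners sort on; the reshape-then-norm step in isolation:

import torch

w = torch.randn(64, 32, 3, 3)    # [out_channels, in_channels, kH, kW]
flat = w.view(w.shape[0], -1)    # one row per filter
norms = torch.norm(flat, dim=1)  # L2 norm of each filter -> shape [64]
print(norms.shape)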
/test_module/__pycache__/test_on_diverse_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_diverse_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_diverse_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_diverse_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_classification.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_classification.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_classification.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_classification.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_recognition.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_recognition.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_on_face_recognition.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_on_face_recognition.cpython-37.pyc
--------------------------------------------------------------------------------
/test_module/__pycache__/test_with_insight_face.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/test_module/__pycache__/test_with_insight_face.cpython-36.pyc
--------------------------------------------------------------------------------
/test_module/test_on_diverse_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/12 2:02 PM
4 | from test_module.test_on_face_recognition import TestOnFaceRecognition
5 | from test_module.test_on_face_classification import TestOnFaceClassification
6 | from test_module.test_with_insight_face import TestWithInsightFace
7 |
8 |
9 | def test(args, model):
10 |     # parenthesized explicitly: classification models, and only when not testing on lfw
11 |     if args.model in ('mobilenetv3', 'mobilefacenet_lzc', 'resnet34_lzc') and args.data_source != 'lfw':
12 |         test = TestOnFaceClassification(model, args.test_root_path, args.img_list_label_path)
13 |         acc = test.test(args.test_batch_size)
14 |         return acc
15 |     elif args.model == 'resnet50_imagenet' or args.data_source == 'lfw':
16 |         test = TestWithInsightFace(model)
17 |         agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame = test.get_val_data(args.test_root_path)
18 |         acc_lfw = test.test(lfw, lfw_issame, args.test_batch_size, args.device)
19 |         # acc_cfp = test.test(cfp_fp, cfp_fp_issame, args.test_batch_size, args.device)
20 |         # acc_agedb = test.test(agedb_30, agedb_30_issame, args.test_batch_size, args.device)
21 |         return acc_lfw
22 |     else:
23 |         test = TestOnFaceRecognition(model, args.test_root_path, args.img_list_label_path, args.data_source)
24 |         accuracy = test.test(args.test_batch_size)
25 |         return accuracy
--------------------------------------------------------------------------------
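Note: TestOnFaceClassification below scores a binary classifier by its TPR at a fixed FPR (t=1e-3 by default) rather than by plain accuracy; the metric in isolation, on toy scores:

import numpy as np
from sklearn.metrics import roc_curve

labels = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.1, 0.4, 0.2, 0.8, 0.7, 0.9])
fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
t = 1e-3
index = np.argmin(abs(fpr - t))          # FPR closest to the target
index_all = np.where(fpr == fpr[index])  # every operating point sharing that FPR
print(np.max(tpr[index_all]))            # best TPR at (approximately) FPR == t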
/test_module/test_on_face_classification.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: LinX
3 | # datetime: 2019/11/1 6:14 PM
4 | import cv2
5 | import torch
6 | import numpy as np
7 | from sklearn.metrics import roc_curve, auc
8 | import torch.nn.functional as F
9 | from dataloader import Face_Classification_Data
10 | from tqdm import tqdm
11 |
12 |
13 | def opencv_loader(path):
14 |     img = cv2.imread(path, -1)
15 |     return img
16 |
17 |
18 | class TestOnFaceClassification:
19 |
20 |     def __init__(self, model, root_path, list_label, device=0, t=1e-3):
21 |         self.model = model
22 |         self.root_path = root_path
23 |         self.list_label = list_label
24 |         self.image_root_path = root_path
25 |         self.list_label = list_label
26 |         self.device = device
27 |         self.t = t
28 |
29 |     def test(self, batch_size):
30 |         self.model.eval()
31 |         self.model.to(self.device)
32 |         data_loader = Face_Classification_Data(self.root_path, self.list_label, batch_size).test_loader
33 |         file = open(self.list_label, 'r')
34 |         lines = file.readlines()
35 |         file.close()
36 |
37 |         label_list = []
38 |         score_list = []
39 |         with torch.no_grad():
40 |             for img, label in tqdm(data_loader):
41 |                 input = img.to(self.device)
42 |                 out = self.model(input)
43 |                 out = F.softmax(out, dim=1)
44 |
45 |                 for idx in range(len(out)):
46 |
47 |                     label_list.append(label.cpu().numpy()[idx])
48 |                     score_list.append(out[idx].cpu().numpy()[1])
49 |
50 |         fpr, tpr, thresholds = roc_curve(np.array(label_list), np.array(score_list), pos_label=1)
51 |         fpr = np.around(fpr, decimals=7)
52 |         index = np.argmin(abs(fpr - self.t))
53 |         index_all = np.where(fpr == fpr[index])
54 |         max_acc = np.max(tpr[index_all])
55 |
56 |         return max_acc
57 |
58 |
--------------------------------------------------------------------------------
/train_module/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: linx
3 | # datetime 2020/9/1 1:43 PM
4 |
--------------------------------------------------------------------------------
/train_module/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/train_module/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/train_module/__pycache__/train_with_insight_face.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/train_module/__pycache__/train_with_insight_face.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/finetune/2020-09-08-09-18/log/2020-09-08-09-18/events.out.tfevents.1599527940.yeluyue:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/finetune/2020-09-08-09-18/log/2020-09-08-09-18/events.out.tfevents.1599527940.yeluyue
--------------------------------------------------------------------------------
/work_space/layers_out/out.txt:
--------------------------------------------------------------------------------
1 | [64, 64, 52, 64, 52, 58, 52, 128, 116, 116, 116, 128, 116, 128, 116, 231, 231, 231, 231, 231, 231, 256, 231, 256, 231, 231, 231, 205, 231, 256, 231, 256, 231, 256, 231, 256, 231, 256, 231, 256, 231, 205, 231, 461, 461, 512, 461, 512, 461]
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_mobilefacenet_y2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_mobilefacenet_y2.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_resnet50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_resnet50.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/pruned_define_model/__pycache__/make_pruned_resnet50_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlossomingL/compression_tool/412c47e53417de9c9de04b6d64ea3e07939a5b07/work_space/pruned_define_model/__pycache__/make_pruned_resnet50_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/work_space/sensitivity_data/resnet50_imagenet_0.9773/L1Rank/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 |   net_thinner:
3 |     arch: None
4 |     class: FilterRemover
5 |     dataset: 112x112
6 |     thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 |   - 0
10 |   pruner:
11 |     instance_name: filter_pruner_10
12 | - epochs:
13 |   - 0
14 |   pruner:
15 |     instance_name: filter_pruner_20
16 | - epochs:
17 |   - 0
18 |   pruner:
19 |     instance_name: filter_pruner_30
20 | - epochs:
21 |   - 0
22 |   pruner:
23 |     instance_name: filter_pruner_50
24 | - epochs:
25 |   - 0
26 |   pruner:
27 |     instance_name: filter_pruner_60
28 | - epochs:
29 |   - 0
30 |   pruner:
31 |     instance_name: filter_pruner_70
32 | - epochs:
33 |   - 0
34 |   pruner:
35 |     instance_name: filter_pruner_90
36 | - epochs:
37 |   - 0
38 |   extension:
39 |     instance_name: net_thinner
40 | pruners:
41 |   filter_pruner_10:
42 |     class: L1RankedStructureParameterPruner
43 |     desired_sparsity: 0.1
44 |     group_type: Filters
45 |     weights:
46 |     - layer4.0.conv1.weight
47 |   filter_pruner_20:
48 |     class: L1RankedStructureParameterPruner
49 |     desired_sparsity: 0.2
50 |     group_type: Filters
51 |     weights:
52 |     - conv1.weight
53 |     - layer2.0.conv2.weight
54 |     - layer2.1.conv2.weight
55 |     - layer2.2.conv2.weight
56 |     - layer2.3.conv2.weight
57 |     - layer3.0.conv2.weight
58 |     - layer3.1.conv2.weight
59 |     - layer3.2.conv2.weight
60 |     - layer3.3.conv2.weight
61 |     - layer3.4.conv2.weight
62 |     - layer3.5.conv2.weight
63 |     - layer4.0.conv2.weight
64 |     - layer4.1.conv2.weight
65 |     - layer4.2.conv2.weight
66 |   filter_pruner_30:
67 |     class: L1RankedStructureParameterPruner
68 |     desired_sparsity: 0.3
69 |     group_type: Filters
70 |     weights:
71 |     - layer3.3.conv1.weight
72 |     - layer3.5.conv1.weight
73 |   filter_pruner_50:
74 |     class: L1RankedStructureParameterPruner
75 |     desired_sparsity: 0.5
76 |     group_type: Filters
77 |     weights:
78 |     - layer3.0.conv1.weight
79 |     - layer4.1.conv1.weight
80 |   filter_pruner_60:
81 |     class: L1RankedStructureParameterPruner
82 |     desired_sparsity: 0.6
83 |     group_type: Filters
84 |     weights:
85 |     - layer2.2.conv1.weight
86 |     - layer3.1.conv1.weight
87 |     - layer3.2.conv1.weight
88 |     - layer3.4.conv1.weight
89 |     - layer4.2.conv1.weight
90 |   filter_pruner_70:
91 |     class: L1RankedStructureParameterPruner
92 |     desired_sparsity: 0.7
93 |     group_type: Filters
94 |     weights:
95 |     - layer1.0.conv1.weight
96 |     - layer1.1.conv1.weight
97 |     - layer1.2.conv1.weight
98 |     - layer2.0.conv1.weight
99 |     - layer2.1.conv1.weight
100 |     - layer2.3.conv1.weight
101 |   filter_pruner_90:
102 |     class: L1RankedStructureParameterPruner
103 |     desired_sparsity: 0.9
104 |     group_type: Filters
105 |     weights:
106 |     - layer1.0.conv2.weight
107 |     - layer1.1.conv2.weight
108 |     - layer1.2.conv2.weight
109 | version: 1
110 |
--------------------------------------------------------------------------------
/work_space/sensitivity_data/resnet50_imagenet_0.9773/fpgm/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 |   net_thinner:
3 |     arch: None
4 |     class: FilterRemover
5 |     dataset: 112x112
6 |     thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 |   - 0
10 |   pruner:
11 |     instance_name: filter_pruner_10
12 | - epochs:
13 |   - 0
14 |   pruner:
15 |     instance_name: filter_pruner_20
16 | - epochs:
17 |   - 0
18 |   pruner:
19 |     instance_name: filter_pruner_30
20 | - epochs:
21 |   - 0
22 |   pruner:
23 |     instance_name: filter_pruner_40
24 | - epochs:
25 |   - 0
26 |   pruner:
27 |     instance_name: filter_pruner_50
28 | - epochs:
29 |   - 0
30 |   pruner:
31 |     instance_name: filter_pruner_60
32 | - epochs:
33 |   - 0
34 |   pruner:
35 |     instance_name: filter_pruner_70
36 | - epochs:
37 |   - 0
38 |   pruner:
39 |     instance_name: filter_pruner_90
40 | - epochs:
41 |   - 0
42 |   extension:
43 |     instance_name: net_thinner
44 | pruners:
45 |   filter_pruner_10:
46 |     class: L1RankedStructureParameterPruner
47 |     desired_sparsity: 0.1
48 |     group_type: Filters
49 |     weights:
50 |     - layer4.0.conv1.weight
51 |   filter_pruner_20:
52 |     class: L1RankedStructureParameterPruner
53 |     desired_sparsity: 0.2
54 |     group_type: Filters
55 |     weights:
56 |     - conv1.weight
57 |     - layer2.0.conv2.weight
58 |     - layer2.1.conv2.weight
59 |     - layer2.2.conv2.weight
60 |     - layer2.3.conv2.weight
61 |     - layer3.0.conv2.weight
62 |     - layer3.1.conv2.weight
63 |     - layer3.2.conv2.weight
64 |     - layer3.3.conv2.weight
65 |     - layer3.4.conv2.weight
66 |     - layer3.5.conv2.weight
67 |     - layer4.0.conv2.weight
68 |     - layer4.1.conv2.weight
69 |     - layer4.2.conv2.weight
70 |   filter_pruner_30:
71 |     class: L1RankedStructureParameterPruner
72 |     desired_sparsity: 0.3
73 |     group_type: Filters
74 |     weights:
75 |     - layer3.3.conv1.weight
76 |     - layer3.5.conv1.weight
77 |   filter_pruner_40:
78 |     class: L1RankedStructureParameterPruner
79 |     desired_sparsity: 0.4
80 |     group_type: Filters
81 |     weights:
82 |     - layer3.0.conv1.weight
83 |   filter_pruner_50:
84 |     class: L1RankedStructureParameterPruner
85 |     desired_sparsity: 0.5
86 |     group_type: Filters
87 |     weights:
88 |     - layer3.2.conv1.weight
89 |     - layer4.1.conv1.weight
90 |   filter_pruner_60:
91 |     class: L1RankedStructureParameterPruner
92 |     desired_sparsity: 0.6
93 |     group_type: Filters
94 |     weights:
95 |     - layer3.4.conv1.weight
96 |   filter_pruner_70:
97 |     class: L1RankedStructureParameterPruner
98 |     desired_sparsity: 0.7
99 |     group_type: Filters
100 |     weights:
101 |     - layer1.0.conv1.weight
102 |     - layer1.1.conv1.weight
103 |     - layer1.2.conv1.weight
104 |     - layer2.0.conv1.weight
105 |     - layer2.1.conv1.weight
106 |     - layer2.2.conv1.weight
107 |     - layer2.3.conv1.weight
108 |     - layer3.1.conv1.weight
109 |     - layer4.2.conv1.weight
110 |   filter_pruner_90:
111 |     class: L1RankedStructureParameterPruner
112 |     desired_sparsity: 0.9
113 |     group_type: Filters
114 |     weights:
115 |     - layer1.0.conv2.weight
116 |     - layer1.1.conv2.weight
117 |     - layer1.2.conv2.weight
118 | version: 1
119 |
--------------------------------------------------------------------------------
/yaml_file/auto_yaml.yaml:
--------------------------------------------------------------------------------
1 | extensions:
2 |   net_thinner:
3 |     arch: None
4 |     class: FilterRemover
5 |     dataset: 112x112
6 |     thinning_func_str: remove_filters
7 | policies:
8 | - epochs:
9 |   - 0
10 |   pruner:
11 |     instance_name: filter_pruner_10
12 | - epochs:
13 |   - 0
14 |   pruner:
15 |     instance_name: filter_pruner_20
16 | - epochs:
17 |   - 0
18 |   extension:
19 |     instance_name: net_thinner
20 | pruners:
21 |   filter_pruner_10:
22 |     class: L1RankedStructureParameterPruner
23 |     desired_sparsity: 0.1
24 |     group_type: Filters
25 |     weights:
26 |     # - conv1.weight
27 |     - layer1.2.conv1.weight
28 |     - layer2.0.conv2.weight
29 |     - layer2.0.downsample.0.weight
30 |     - layer2.1.conv1.weight
31 |     - layer2.1.conv2.weight
32 |     - layer2.2.conv2.weight
33 |     - layer2.3.conv2.weight
34 |     - layer3.0.conv1.weight
35 |     - layer3.0.conv2.weight
36 |     - layer3.0.downsample.0.weight
37 |     - layer3.1.conv1.weight
38 |     - layer3.1.conv2.weight
39 |     - layer3.2.conv1.weight
40 |     - layer3.2.conv2.weight
41 |     - layer3.3.conv2.weight
42 |     - layer3.4.conv2.weight
43 |     - layer3.5.conv1.weight
44 |     - layer3.5.conv2.weight
45 |     - layer3.6.conv2.weight
46 |     - layer3.7.conv2.weight
47 |     - layer3.8.conv2.weight
48 |     - layer3.9.conv2.weight
49 |     - layer3.10.conv2.weight
50 |     - layer3.11.conv2.weight
51 |     - layer3.12.conv2.weight
52 |     - layer3.13.conv2.weight
53 |     - layer4.0.conv1.weight
54 |     - layer4.0.conv2.weight
55 |     - layer4.0.downsample.0.weight
56 |     - layer4.1.conv2.weight
57 |     - layer4.2.conv2.weight
58 |   filter_pruner_20:
59 |     class: L1RankedStructureParameterPruner
60 |     desired_sparsity: 0.2
61 |     group_type: Filters
62 |     weights:
63 |     - layer1.0.conv2.weight
64 |     - layer1.0.downsample.0.weight
65 |     - layer1.1.conv2.weight
66 |     - layer1.2.conv2.weight
67 |     - layer3.6.conv1.weight
68 |     - layer3.13.conv1.weight
69 | version: 1
70 |
--------------------------------------------------------------------------------
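Note: these generated schedules pair per-sparsity L1RankedStructureParameterPruner groups with a net_thinner FilterRemover extension, and pruning.py hands them to distiller.config.file_config. A quick way to inspect one (PyYAML; path as in this repo):

import yaml

with open('yaml_file/auto_yaml.yaml') as f:
    sched = yaml.safe_load(f)
for name, cfg in sorted(sched['pruners'].items()):
    print(name, cfg['desired_sparsity'], len(cfg['weights']), 'layers')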