├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── README.md ├── applications └── efficientnet │ ├── .gitignore │ ├── .vscode │ ├── launch.json │ └── settings.json │ ├── README.md │ ├── dms_efficientnet.py │ ├── requirements.txt │ ├── scripts │ └── DMS-450 │ │ ├── prune.sh │ │ └── retrain.sh │ ├── test_eff.py │ ├── timm_dataset.py │ ├── timm_pruning.py │ └── timm_retrain.py ├── dms ├── __init__.py ├── dtopk_src.py └── modules │ ├── __init__.py │ ├── algorithm.py │ ├── dtp.py │ ├── mutable.py │ ├── mutator.py │ ├── op.py │ ├── optimizer.py │ ├── scheduler.py │ ├── unit.py │ └── utils.py ├── images └── compare.png ├── mmrazor ├── .circleci │ ├── config.yml │ ├── docker │ │ └── Dockerfile │ └── test.yml ├── .dev_scripts │ ├── benchmark_summary_analyse.py │ ├── benchmark_test.py │ ├── benchmark_train.py │ └── meta_files_test.py ├── .github │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.md │ │ ├── config.yml │ │ ├── feature_request.md │ │ └── general-questions.md │ ├── pull_request_template.md │ └── workflows │ │ ├── build.yml │ │ ├── deploy.yml │ │ └── lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── README_zh-CN.md ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ └── mmcls │ │ │ │ ├── cifar100_bs16_auto_aug.py │ │ │ │ └── pipelines │ │ │ │ └── auto_aug_cifar.py │ │ ├── nas_backbones │ │ │ ├── attentive_mobilenetv3_supernet.py │ │ │ ├── darts_supernet.py │ │ │ ├── dsnas_shufflenet_supernet.py │ │ │ ├── ofa_mobilenetv3_supernet.py │ │ │ ├── spos_mobilenet_supernet.py │ │ │ └── spos_shufflenet_supernet.py │ │ ├── settings │ │ │ ├── cifar10_darts_subnet.py │ │ │ ├── cifar10_darts_supernet.py │ │ │ ├── imagenet_bs1024_dsnas.py │ │ │ ├── imagenet_bs1024_spos.py │ │ │ ├── imagenet_bs2048_AdamW.py │ │ │ ├── imagenet_bs2048_autoslim.py │ │ │ ├── imagenet_bs2048_autoslim_pil.py │ │ │ ├── imagenet_bs2048_bignas.py │ │ │ ├── imagenet_bs2048_dmcp.py │ │ │ └── imagenet_bs2048_ofa.py │ 
│ └── vanilla_models │ │ │ └── wrn16_2_cifar10.py │ ├── distill │ │ ├── mmcls │ │ │ ├── abloss │ │ │ │ ├── README.md │ │ │ │ ├── abloss_logits_resnet50_resnet18_8xb32_in1k.py │ │ │ │ ├── abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── byot │ │ │ │ ├── README.md │ │ │ │ ├── byot_resnet18_8xb16_cifar100.py │ │ │ │ └── metafile.yml │ │ │ ├── crd │ │ │ │ ├── README.md │ │ │ │ ├── crd_neck_r50_r18_8xb16_cifar10.py │ │ │ │ └── datasets │ │ │ │ │ └── crd_cifar10_bs16.py │ │ │ ├── dafl │ │ │ │ ├── README.md │ │ │ │ ├── dafl_logits_resnet34_resnet18_8xb256_cifar10.py │ │ │ │ └── metafile.yml │ │ │ ├── deit │ │ │ │ ├── README.md │ │ │ │ ├── deit-base_regnety160_pt-16xb64_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── dfad │ │ │ │ ├── README.md │ │ │ │ ├── dfad_logits_resnet34_resnet18_8xb32_cifar10.py │ │ │ │ └── metafile.yml │ │ │ ├── dkd │ │ │ │ ├── README.md │ │ │ │ ├── dkd_resnet34_resnet18_8xb32_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── factor_transfer │ │ │ │ ├── README.md │ │ │ │ ├── factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain.py │ │ │ │ ├── factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_train.py │ │ │ │ └── metafile.yml │ │ │ ├── fitnets │ │ │ │ ├── README.md │ │ │ │ ├── fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── kd │ │ │ │ ├── README.md │ │ │ │ ├── kd_logits_resnet34_resnet18_8xb32_in1k.py │ │ │ │ ├── kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py │ │ │ │ ├── kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── ofd │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ └── ofd_backbone_resnet50_resnet18_8xb16_cifar10.py │ │ │ ├── rkd │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ └── rkd_neck_resnet34_resnet18_8xb32_in1k.py │ │ │ ├── wsld │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ └── wsld_logits_resnet34_resnet18_8xb32_in1k.py │ │ │ └── zskt │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ 
└── zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py │ │ ├── mmdet │ │ │ ├── cwd │ │ │ │ ├── README.md │ │ │ │ ├── cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py │ │ │ │ ├── cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py │ │ │ │ ├── cwd_fpn_retina_r101_retina_r50_1x_coco.py │ │ │ │ ├── cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py │ │ │ │ └── metafile.yml │ │ │ ├── fbkd │ │ │ │ ├── README.md │ │ │ │ ├── fbkd_fpn_faster-rcnn_r101_faster-rcnn_r50_1x_coco.py │ │ │ │ └── metafile.yml │ │ │ ├── mgd │ │ │ │ ├── README.md │ │ │ │ └── mgd_fpn_retina_x101_retina_r50_2x_coco.py │ │ │ └── pkd │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ ├── pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py │ │ │ │ ├── pkd_fpn_fcos_x101_retina_r50_1x_coco.py │ │ │ │ ├── pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py │ │ │ │ ├── pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py │ │ │ │ └── pkd_fpn_retina_x101_retina_r50_2x_coco.py │ │ ├── mmdet3d │ │ │ └── pkd │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ └── pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py │ │ └── mmseg │ │ │ └── cwd │ │ │ ├── README.md │ │ │ ├── cwd_logits_pspnet_r101-d8_pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py │ │ │ └── metafile.yml │ ├── nas │ │ ├── mmcls │ │ │ ├── autoformer │ │ │ │ ├── AUTOFORMER_SUBNET_B.yaml │ │ │ │ ├── README.md │ │ │ │ ├── autoformer_search_8xb128_in1k.py │ │ │ │ ├── autoformer_subnet_8xb256_in1k.py │ │ │ │ └── autoformer_supernet_32xb256_in1k.py │ │ │ ├── autoslim │ │ │ │ ├── README.md │ │ │ │ ├── autoslim_mbv2_1.5x_search_8xb256_in1k.py │ │ │ │ ├── autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py │ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py │ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py │ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py │ │ │ │ ├── autoslim_mbv2_1.5x_supernet_8xb256_in1k.py │ │ │ │ └── metafile.yml │ │ │ ├── bignas │ │ │ │ ├── ATTENTIVE_SUBNET_A0.yaml │ │ │ │ ├── ATTENTIVE_SUBNET_A1.yaml │ │ 
│ │ ├── ATTENTIVE_SUBNET_A2.yaml │ │ │ │ ├── ATTENTIVE_SUBNET_A3.yaml │ │ │ │ ├── ATTENTIVE_SUBNET_A4.yaml │ │ │ │ ├── ATTENTIVE_SUBNET_A5.yaml │ │ │ │ ├── ATTENTIVE_SUBNET_A6.yaml │ │ │ │ ├── README.md │ │ │ │ ├── attentive_mobilenet_search_8xb128_in1k.py │ │ │ │ ├── attentive_mobilenet_subnet_8xb256_in1k.py │ │ │ │ └── attentive_mobilenet_supernet_32xb64_in1k.py │ │ │ ├── darts │ │ │ │ ├── DARTS_SUBNET_CIFAR_MMRAZOR_97.32.yaml │ │ │ │ ├── DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml │ │ │ │ ├── README.md │ │ │ │ ├── darts_subnet_1xb96_cifar10_2.0.py │ │ │ │ ├── darts_subnet_1xb96_cifar10_2.0_mmrazor.py │ │ │ │ ├── darts_supernet_unroll_1xb96_cifar10.py │ │ │ │ └── metafile.yml │ │ │ ├── dsnas │ │ │ │ ├── DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml │ │ │ │ ├── README.md │ │ │ │ ├── dsnas_subnet_8xb128_in1k.py │ │ │ │ └── dsnas_supernet_8xb128_in1k.py │ │ │ ├── onceforall │ │ │ │ ├── OFA_SUBNET_NOTE8_LAT22.yaml │ │ │ │ ├── OFA_SUBNET_NOTE8_LAT31.yaml │ │ │ │ ├── README.md │ │ │ │ ├── ofa_mobilenet_search_8xb128_in1k.py │ │ │ │ ├── ofa_mobilenet_subnet_8xb256_in1k.py │ │ │ │ └── ofa_mobilenet_supernet_32xb64_in1k.py │ │ │ └── spos │ │ │ │ ├── README.md │ │ │ │ ├── SPOS_SUBNET.yaml │ │ │ │ ├── faster-rcnn_nas_backbone_fpn_1x_coco.py │ │ │ │ ├── metafile.yml │ │ │ │ ├── spos_mobilenet_search_8xb128_in1k.py │ │ │ │ ├── spos_mobilenet_subnet_8xb128_in1k.py │ │ │ │ ├── spos_mobilenet_supernet_8xb128_in1k.py │ │ │ │ ├── spos_shufflenet_search_8xb128_in1k.py │ │ │ │ ├── spos_shufflenet_search_predictor_8xb128_in1k.py │ │ │ │ ├── spos_shufflenet_subnet_8xb128_in1k.py │ │ │ │ └── spos_shufflenet_supernet_8xb128_in1k.py │ │ └── mmdet │ │ │ └── detnas │ │ │ ├── DETNAS_SUBNET.yaml │ │ │ ├── README.md │ │ │ ├── detnas_frcnn_shufflenet_search_coco_1x.py │ │ │ ├── detnas_frcnn_shufflenet_subnet_coco_1x.py │ │ │ ├── detnas_frcnn_shufflenet_supernet_coco_1x.py │ │ │ ├── detnas_retina_shufflenet_supernet_coco_1x.py │ │ │ ├── detnas_shufflenet_subnet_8xb128_in1k.py │ │ │ ├── 
detnas_shufflenet_supernet_8xb128_in1k.py │ │ │ └── metafile.yml │ ├── pruning │ │ ├── base │ │ │ └── group_fisher │ │ │ │ ├── README.md │ │ │ │ ├── group_fisher_deploy_template.py │ │ │ │ ├── group_fisher_finetune_template.py │ │ │ │ └── group_fisher_prune_template.py │ │ ├── mmcls │ │ │ ├── dcff │ │ │ │ ├── README.md │ │ │ │ ├── dcff_compact_resnet_8xb32_in1k.py │ │ │ │ ├── dcff_resnet_8xb32_in1k.py │ │ │ │ └── fix_subnet.json │ │ │ ├── dmcp │ │ │ │ ├── DMCP_MBV2_100M.json │ │ │ │ ├── DMCP_R50_2G.json │ │ │ │ ├── README.md │ │ │ │ ├── dmcp_mbv2_subnet_32xb64.py │ │ │ │ ├── dmcp_mbv2_supernet_32xb64.py │ │ │ │ ├── dmcp_resnet50_subnet_32xb64.py │ │ │ │ ├── dmcp_resnet50_supernet_32xb64.py │ │ │ │ └── metafile.yml │ │ │ ├── group_fisher │ │ │ │ ├── README.md │ │ │ │ ├── mobilenet │ │ │ │ │ ├── group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py │ │ │ │ │ ├── metafile.yml │ │ │ │ │ └── script.sh │ │ │ │ └── resnet50 │ │ │ │ │ ├── group_fisher_act_deploy_resnet50_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_act_finetune_resnet50_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_act_finetune_resnet50_8xb32_in1k_dist.py │ │ │ │ │ ├── group_fisher_act_prune_resnet50_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_deploy_resnet50_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_finetune_resnet50_8xb32_in1k.py │ │ │ │ │ ├── group_fisher_flops_prune_resnet50_8xb32_in1k.py │ │ │ │ │ ├── metafile.yml │ │ │ │ │ └── script.sh │ │ │ └── l1-norm │ │ │ │ ├── README.md │ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_a.py │ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_a_deploy.py │ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_b.py │ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_b_deploy.py │ │ │ │ ├── 
l1-norm_resnet34_8xb32_in1k_c.py │ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_c_deploy.py │ │ │ │ ├── metafile.yml │ │ │ │ └── script.sh │ │ ├── mmdet │ │ │ ├── dcff │ │ │ │ ├── README.md │ │ │ │ ├── dcff_compact_faster_rcnn_resnet50_8xb4_coco.py │ │ │ │ ├── dcff_faster_rcnn_resnet50_8xb4_coco.py │ │ │ │ ├── dcff_faster_rcnn_resnet50_fpn.py │ │ │ │ └── fix_subnet.json │ │ │ └── group_fisher │ │ │ │ ├── README.md │ │ │ │ └── retinanet │ │ │ │ ├── group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py │ │ │ │ ├── metafile.yml │ │ │ │ └── script.sh │ │ ├── mmpose │ │ │ ├── dcff │ │ │ │ ├── README.md │ │ │ │ ├── dcff_compact_topdown_heatmap_resnet50_coco.py │ │ │ │ ├── dcff_topdown_heatmap_resnet50_coco.py │ │ │ │ └── fix_subnet.json │ │ │ └── group_fisher │ │ │ │ ├── group_fisher_deploy_rtmpose-s_8xb256-420e_aic-coco-256x192.py │ │ │ │ ├── group_fisher_deploy_rtmpose-s_8xb256-420e_coco-256x192.py │ │ │ │ ├── group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.py │ │ │ │ ├── group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.py │ │ │ │ ├── group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.py │ │ │ │ ├── group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.py │ │ │ │ └── script.sh │ │ └── mmseg │ │ │ └── dcff │ │ │ ├── README.md │ │ │ ├── dcff_compact_pointrend_resnet50_8xb2_cityscapes.py │ │ │ ├── dcff_pointrend_resnet50_8xb2_cityscapes.py │ │ │ ├── fix_subnet.json │ │ │ └── pointrend_resnet50.py │ ├── quantization │ │ ├── deploy_cfgs │ │ │ ├── mmcls │ │ │ │ ├── classification_openvino_dynamic-224x224.py │ │ │ │ └── classification_tensorrt-int8-explicit_dynamic-224x224.py │ │ │ └── mmdet │ │ │ │ ├── 
detection_openvino_dynamic-800x1344.py │ │ │ │ └── detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py │ │ ├── ptq │ │ │ └── base │ │ │ │ ├── README.md │ │ │ │ ├── metafile.yml │ │ │ │ ├── ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_openvino_retina_r50_1x_coco_calib32xb32.py │ │ │ │ ├── ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py │ │ │ │ ├── ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py │ │ │ │ ├── ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py │ │ │ │ └── ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py │ │ └── qat │ │ │ ├── base │ │ │ ├── README.md │ │ │ ├── metafile.yml │ │ │ └── qat_openvino_resnet18_10e_8xb32_in1k.py │ │ │ └── lsq │ │ │ ├── README.md │ │ │ ├── lsq_openvino_resnet18_8xb32_100e_in1k.py │ │ │ ├── lsq_openvino_resnet18_8xb32_10e_in1k.py │ │ │ └── metafile.yml │ └── vanilla │ │ └── mmcls │ │ └── wide-resnet │ │ ├── README.md │ │ ├── wrn16-w2_b16x8_cifar10.py │ │ ├── wrn22-w4_b16x8_cifar10.py │ │ ├── wrn28-w4_b16x8_cifar10.py │ │ └── wrn40-w2_b16x8_cifar10.py ├── demo │ └── pruning │ │ └── config_pruning.ipynb ├── docker │ ├── Dockerfile │ └── serve │ │ ├── Dockerfile │ │ ├── config.properties │ │ └── entrypoint.sh ├── docs │ ├── en │ │ ├── Makefile │ │ ├── _static │ │ │ ├── css │ │ │ │ └── readthedocs.css │ │ │ └── image │ │ │ │ └── mmrazor-logo.png │ │ ├── advanced_guides │ │ │ ├── algorithm.md │ │ │ ├── apply_existing_algorithms_to_new_tasks.md │ │ │ ├── customize_architectures.md │ │ │ ├── customize_kd_algorithms.md │ │ │ ├── customize_mixed_algorithms.md │ │ │ ├── customize_nas_algorithms.md │ │ │ ├── customize_pruning_algorithms.md │ │ │ ├── customize_quantization_algorithms.md │ │ │ ├── delivery.md │ │ │ ├── index.rst │ │ │ ├── mutable.md │ │ │ ├── mutator.md │ │ │ ├── 
recorder.md │ │ │ └── tutorials │ │ │ │ ├── how_to_prune_your_model.md │ │ │ │ └── how_to_use_config_tool_of_pruning.md │ │ ├── api.rst │ │ ├── conf.py │ │ ├── get_started │ │ │ ├── installation.md │ │ │ ├── model_zoo.md │ │ │ └── overview.md │ │ ├── imgs │ │ │ └── pruning │ │ │ │ ├── draw-config.png │ │ │ │ ├── framework-ChanelMutator.png │ │ │ │ ├── framework-algorithm.png │ │ │ │ ├── framework-framework.png │ │ │ │ ├── framework-graph.png │ │ │ │ ├── framework-op.png │ │ │ │ ├── pruning_framework.png │ │ │ │ └── unit.png │ │ ├── index.rst │ │ ├── make.bat │ │ ├── notes │ │ │ ├── changelog.md │ │ │ ├── contribution_guide.md │ │ │ └── faq.md │ │ ├── switch_language.md │ │ └── user_guides │ │ │ ├── 1_learn_about_config.md │ │ │ ├── 2_train_different_types_algorithms.md │ │ │ ├── 3_train_with_different_devices.md │ │ │ ├── 4_test_a_model.md │ │ │ ├── index.rst │ │ │ ├── pruning_user_guide.md │ │ │ └── quantization_user_guide.md │ └── zh_cn │ │ ├── Makefile │ │ ├── api.rst │ │ ├── conf.py │ │ ├── index.rst │ │ ├── make.bat │ │ ├── switch_language.md │ │ └── user_guides │ │ └── visualization.md ├── mmrazor │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── crd_dataset_wrapper.py │ │ └── transforms │ │ │ ├── __init__.py │ │ │ ├── auto_augment.py │ │ │ ├── auto_augmentv2.py │ │ │ └── formatting.py │ ├── engine │ │ ├── __init__.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ ├── dmcp_subnet_hook.py │ │ │ ├── dump_subnet_hook.py │ │ │ ├── estimate_resources_hook.py │ │ │ ├── group_fisher_hooks.py │ │ │ ├── stop_distillation_hook.py │ │ │ └── visualization_hook.py │ │ ├── optimizers │ │ │ ├── __init__.py │ │ │ └── optimizer_constructor.py │ │ └── runner │ │ │ ├── __init__.py │ │ │ ├── autoslim_greedy_search_loop.py │ │ │ ├── darts_loop.py │ │ │ ├── distill_val_loop.py │ │ │ ├── evolution_search_loop.py │ │ │ ├── iteprune_val_loop.py │ │ │ ├── quantization_loops.py │ │ │ ├── slimmable_val_loop.py │ │ │ ├── subnet_sampler_loop.py │ │ │ ├── subnet_val_loop.py │ │ │ └── 
utils │ │ │ ├── __init__.py │ │ │ ├── calibrate_bn_mixin.py │ │ │ ├── check.py │ │ │ └── genetic.py │ ├── implementations │ │ ├── __init__.py │ │ ├── pruning │ │ │ ├── __init__.py │ │ │ ├── group_fisher │ │ │ │ ├── __init__.py │ │ │ │ ├── algorithm.py │ │ │ │ ├── counters.py │ │ │ │ ├── hook.py │ │ │ │ ├── mutator.py │ │ │ │ ├── ops.py │ │ │ │ ├── prune_deploy_sub_model.py │ │ │ │ ├── prune_sub_model.py │ │ │ │ └── unit.py │ │ │ └── sparse_gpt │ │ │ │ ├── __init__.py │ │ │ │ ├── compressor.py │ │ │ │ ├── ops.py │ │ │ │ ├── sparse24_utils.py │ │ │ │ └── utils.py │ │ └── quantization │ │ │ └── gptq │ │ │ ├── __init__.py │ │ │ ├── compressor.py │ │ │ ├── custom_autotune.py │ │ │ ├── gptq.py │ │ │ ├── ops.py │ │ │ ├── quantizer.py │ │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── algorithms │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── distill │ │ │ │ ├── __init__.py │ │ │ │ └── configurable │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── datafree_distillation.py │ │ │ │ │ ├── fpn_teacher_distill.py │ │ │ │ │ ├── overhaul_feature_distillation.py │ │ │ │ │ ├── self_distill.py │ │ │ │ │ └── single_teacher_distill.py │ │ │ ├── nas │ │ │ │ ├── __init__.py │ │ │ │ ├── autoformer.py │ │ │ │ ├── autoslim.py │ │ │ │ ├── bignas.py │ │ │ │ ├── darts.py │ │ │ │ ├── dsnas.py │ │ │ │ └── spos.py │ │ │ ├── pruning │ │ │ │ ├── __init__.py │ │ │ │ ├── dcff.py │ │ │ │ ├── dmcp.py │ │ │ │ ├── group_fisher_algoritho.py │ │ │ │ ├── ite_prune_algorithm.py │ │ │ │ └── slimmable_network.py │ │ │ └── quantization │ │ │ │ ├── __init__.py │ │ │ │ └── mm_architecture.py │ │ ├── architectures │ │ │ ├── __init__.py │ │ │ ├── backbones │ │ │ │ ├── __init__.py │ │ │ │ ├── darts_backbone.py │ │ │ │ ├── searchable_autoformer.py │ │ │ │ ├── searchable_mobilenet_v2.py │ │ │ │ ├── searchable_mobilenet_v3.py │ │ │ │ ├── searchable_shufflenet_v2.py │ │ │ │ └── wideresnet.py │ │ │ ├── classifiers │ │ │ │ ├── __init__.py │ │ │ │ └── image.py │ │ │ ├── connectors │ │ │ │ ├── __init__.py │ │ │ │ ├── 
base_connector.py │ │ │ │ ├── byot_connector.py │ │ │ │ ├── convmodule_connector.py │ │ │ │ ├── crd_connector.py │ │ │ │ ├── factor_transfer_connectors.py │ │ │ │ ├── fbkd_connector.py │ │ │ │ ├── mgd_connector.py │ │ │ │ ├── norm_connector.py │ │ │ │ ├── ofd_connector.py │ │ │ │ └── torch_connector.py │ │ │ ├── dynamic_ops │ │ │ │ ├── __init__.py │ │ │ │ ├── bricks │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dynamic_container.py │ │ │ │ │ ├── dynamic_conv.py │ │ │ │ │ ├── dynamic_embed.py │ │ │ │ │ ├── dynamic_function.py │ │ │ │ │ ├── dynamic_linear.py │ │ │ │ │ ├── dynamic_multi_head_attention.py │ │ │ │ │ ├── dynamic_norm.py │ │ │ │ │ ├── dynamic_relative_position.py │ │ │ │ │ └── group_fisher_ops.py │ │ │ │ ├── head │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── dynamic_linear_head.py │ │ │ │ └── mixins │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── dynamic_conv_mixins.py │ │ │ │ │ ├── dynamic_layernorm_mixins.py │ │ │ │ │ └── dynamic_mixins.py │ │ │ ├── generators │ │ │ │ ├── __init__.py │ │ │ │ ├── base_generator.py │ │ │ │ ├── dafl_generator.py │ │ │ │ └── zskt_generator.py │ │ │ ├── heads │ │ │ │ ├── __init__.py │ │ │ │ ├── darts_subnet_head.py │ │ │ │ └── deit_head.py │ │ │ ├── necks │ │ │ │ ├── __init__.py │ │ │ │ └── squeezemean_with_dropout.py │ │ │ ├── ops │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── common.py │ │ │ │ ├── darts_series.py │ │ │ │ ├── efficientnet_series.py │ │ │ │ ├── function.py │ │ │ │ ├── gather_tensors.py │ │ │ │ ├── mobilenet_series.py │ │ │ │ ├── shufflenet_series.py │ │ │ │ └── transformer_series.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── mutable_register.py │ │ │ │ └── set_dropout.py │ │ ├── distillers │ │ │ ├── __init__.py │ │ │ ├── base_distiller.py │ │ │ ├── byot_distiller.py │ │ │ ├── configurable_distiller.py │ │ │ └── ofd_distiller.py │ │ ├── fake_quants │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── lsq.py │ │ │ └── torch_fake_quants.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── ab_loss.py │ │ │ ├── 
at_loss.py │ │ │ ├── crd_loss.py │ │ │ ├── cross_entropy_loss.py │ │ │ ├── cwd.py │ │ │ ├── dafl_loss.py │ │ │ ├── decoupled_kd.py │ │ │ ├── dist_loss.py │ │ │ ├── factor_transfer_loss.py │ │ │ ├── fbkd_loss.py │ │ │ ├── kd_soft_ce_loss.py │ │ │ ├── kl_divergence.py │ │ │ ├── l1_loss.py │ │ │ ├── l2_loss.py │ │ │ ├── mgd_loss.py │ │ │ ├── ofd_loss.py │ │ │ ├── pkd_loss.py │ │ │ ├── relational_kd.py │ │ │ └── weighted_soft_label_distillation.py │ │ ├── mutables │ │ │ ├── __init__.py │ │ │ ├── base_mutable.py │ │ │ ├── derived_mutable.py │ │ │ ├── mutable_channel │ │ │ │ ├── MutableChannel.md │ │ │ │ ├── __init__.py │ │ │ │ ├── base_mutable_channel.py │ │ │ │ ├── mutable_channel_container.py │ │ │ │ ├── oneshot_mutable_channel.py │ │ │ │ ├── sequential_mutable_channel.py │ │ │ │ ├── simple_mutable_channel.py │ │ │ │ └── units │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── channel_unit.py │ │ │ │ │ ├── dcff_channel_unit.py │ │ │ │ │ ├── dmcp_channel_unit.py │ │ │ │ │ ├── group_fisher_unit.py │ │ │ │ │ ├── l1_mutable_channel_unit.py │ │ │ │ │ ├── mutable_channel_unit.ipynb │ │ │ │ │ ├── mutable_channel_unit.py │ │ │ │ │ ├── one_shot_mutable_channel_unit.py │ │ │ │ │ ├── sequential_mutable_channel_unit.py │ │ │ │ │ ├── slimmable_channel_unit.py │ │ │ │ │ └── utils.py │ │ │ ├── mutable_module │ │ │ │ ├── __init__.py │ │ │ │ ├── diff_mutable_module.py │ │ │ │ ├── mutable_module.py │ │ │ │ └── one_shot_mutable_module.py │ │ │ └── mutable_value │ │ │ │ ├── __init__.py │ │ │ │ └── mutable_value.py │ │ ├── mutators │ │ │ ├── __init__.py │ │ │ ├── base_mutator.py │ │ │ ├── channel_mutator │ │ │ │ ├── __init__.py │ │ │ │ ├── channel_mutator.ipynb │ │ │ │ ├── channel_mutator.py │ │ │ │ ├── dcff_channel_mutator.py │ │ │ │ ├── dmcp_channel_mutator.py │ │ │ │ ├── group_fisher_mutator.py │ │ │ │ ├── one_shot_channel_mutator.py │ │ │ │ └── slimmable_channel_mutator.py │ │ │ ├── group_mixin.py │ │ │ └── nas_mutator.py │ │ ├── observers │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── lsq.py 
│ │ │ └── torch_observers.py │ │ ├── quantizers │ │ │ ├── __init__.py │ │ │ ├── academic_quantizer.py │ │ │ ├── base.py │ │ │ ├── exporters │ │ │ │ ├── __init__.py │ │ │ │ ├── base_quantize_exporter.py │ │ │ │ ├── openvino_quantize_exporter.py │ │ │ │ ├── optim_utils.py │ │ │ │ └── tensorrt_quantize_exporter.py │ │ │ ├── native_quantizer.py │ │ │ ├── openvino_quantizer.py │ │ │ └── tensorrt_quantizer.py │ │ ├── task_modules │ │ │ ├── __init__.py │ │ │ ├── delivery │ │ │ │ ├── __init__.py │ │ │ │ ├── delivery_manager.py │ │ │ │ ├── distill_delivery.py │ │ │ │ ├── function_outputs_delivery.py │ │ │ │ └── method_outputs_delivery.py │ │ │ ├── demo_inputs │ │ │ │ ├── __init__.py │ │ │ │ ├── default_demo_inputs.py │ │ │ │ ├── demo_inputs.py │ │ │ │ ├── mmpose_demo_input.py │ │ │ │ └── mmseg_demo_input.py │ │ │ ├── estimators │ │ │ │ ├── __init__.py │ │ │ │ ├── base_estimator.py │ │ │ │ ├── counters │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── flops_params_counter.py │ │ │ │ │ ├── latency_counter.py │ │ │ │ │ └── op_counters │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── activation_layer_counter.py │ │ │ │ │ │ ├── base_counter.py │ │ │ │ │ │ ├── conv_layer_counter.py │ │ │ │ │ │ ├── deconv_layer_counter.py │ │ │ │ │ │ ├── group_fisher_counters.py │ │ │ │ │ │ ├── linear_layer_counter.py │ │ │ │ │ │ ├── norm_layer_counter.py │ │ │ │ │ │ ├── pooling_layer_counter.py │ │ │ │ │ │ └── upsample_layer_counter.py │ │ │ │ └── resource_estimator.py │ │ │ ├── predictor │ │ │ │ ├── __init__.py │ │ │ │ ├── handler │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_handler.py │ │ │ │ │ ├── carts_handler.py │ │ │ │ │ ├── gp_handler.py │ │ │ │ │ ├── mlp_handler.py │ │ │ │ │ └── rbf_handler.py │ │ │ │ └── metric_predictor.py │ │ │ ├── recorder │ │ │ │ ├── __init__.py │ │ │ │ ├── base_recorder.py │ │ │ │ ├── function_inputs_recorder.py │ │ │ │ ├── function_outputs_recorder.py │ │ │ │ ├── method_inputs_recorder.py │ │ │ │ ├── method_outputs_recorder.py │ │ │ │ ├── module_inputs_recorder.py │ │ │ │ ├── 
module_outputs_recorder.py │ │ │ │ ├── param_recorder.py │ │ │ │ └── recorder_manager.py │ │ │ └── tracer │ │ │ │ ├── __init__.py │ │ │ │ ├── backward_tracer.py │ │ │ │ ├── channel_analyzer.py │ │ │ │ ├── fx │ │ │ │ ├── __init__.py │ │ │ │ ├── custom_tracer.py │ │ │ │ └── graph_utils.py │ │ │ │ ├── fx_tracer.py │ │ │ │ ├── loss_calculator │ │ │ │ ├── __init__.py │ │ │ │ ├── cascade_encoder_decoder_loss_calculator.py │ │ │ │ ├── image_classifier_loss_calculator.py │ │ │ │ ├── single_stage_detector_loss_calculator.py │ │ │ │ ├── sum_loss_calculator.py │ │ │ │ ├── top_down_pose_estimator_loss_calculator.py │ │ │ │ └── two_stage_detector_loss_calculator.py │ │ │ │ ├── parsers.py │ │ │ │ └── path.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── expandable_utils │ │ │ ├── __init__.py │ │ │ ├── ops.py │ │ │ ├── tools.py │ │ │ └── unit.py │ │ │ ├── make_divisible.py │ │ │ ├── misc.py │ │ │ ├── optim_wrapper.py │ │ │ ├── parse_values.py │ │ │ ├── quantization_util.py │ │ │ └── utils.py │ ├── registry │ │ ├── __init__.py │ │ └── registry.py │ ├── structures │ │ ├── __init__.py │ │ ├── graph │ │ │ ├── __init__.py │ │ │ ├── base_graph.py │ │ │ ├── channel_flow.py │ │ │ ├── channel_graph.py │ │ │ ├── channel_nodes.py │ │ │ ├── module_graph.py │ │ │ └── pseudo_fx_graph.py │ │ ├── quantization │ │ │ ├── __init__.py │ │ │ ├── backend_config │ │ │ │ ├── __init__.py │ │ │ │ ├── academic.py │ │ │ │ ├── common_operator_config_utils.py │ │ │ │ ├── mapping.py │ │ │ │ ├── native.py │ │ │ │ ├── openvino.py │ │ │ │ └── tensorrt.py │ │ │ └── qconfig.py │ │ └── subnet │ │ │ ├── __init__.py │ │ │ ├── candidate.py │ │ │ └── fix_subnet.py │ ├── testing │ │ ├── __init__.py │ │ ├── _fast_stop_training_hook.py │ │ └── _fx_models.py │ ├── utils │ │ ├── __init__.py │ │ ├── index_dict.py │ │ ├── log_tools.py │ │ ├── misc.py │ │ ├── placeholder.py │ │ ├── runtime_info.py │ │ ├── setup_env.py │ │ └── typing.py │ ├── version.py │ └── visualization │ │ ├── __init__.py │ │ └── local_visualizer.py ├── 
model-index.yml ├── projects │ └── mmrazor_large │ │ ├── README.md │ │ ├── algorithms │ │ ├── GPTQ.md │ │ └── SparseGPT.md │ │ └── examples │ │ ├── ResNet │ │ ├── README.md │ │ ├── resnet18_gptq.py │ │ └── resnet18_sparse_gpt.py │ │ └── language_models │ │ ├── LLaMA │ │ ├── README.md │ │ ├── datautils.py │ │ ├── llama_gptq.py │ │ ├── llama_sparse_gpt.py │ │ ├── llama_sparse_gpt_fsdp.py │ │ └── utils.py │ │ └── OPT │ │ ├── README.md │ │ ├── datautils.py │ │ ├── opt_gptq.py │ │ ├── opt_sparse_gpt.py │ │ ├── opt_sparse_gpt_fsdp.py │ │ └── utils.py ├── requirements.txt ├── requirements │ ├── docs.txt │ ├── mminstall.txt │ ├── optional.txt │ ├── readthedocs.txt │ ├── runtime.txt │ └── tests.txt ├── resources │ ├── design_and_implement.png │ ├── mmrazor-logo.png │ ├── qq_group_qrcode.jpg │ ├── xiaozhushou_weixin_qrcode.jpeg │ └── zhihu_qrcode.jpg ├── setup.cfg ├── setup.py ├── tests │ ├── __init__.py │ ├── data │ │ ├── MBV2_220M.yaml │ │ ├── MBV2_320M.yaml │ │ ├── MBV2_530M.yaml │ │ ├── MBV2_slimmable_channel_config.json │ │ ├── MBV2_slimmable_config.json │ │ ├── __init__.py │ │ ├── color.jpeg │ │ ├── concat_subnet1.yaml │ │ ├── concat_subnet2.yaml │ │ ├── dataset │ │ │ ├── a │ │ │ │ └── 1.JPG │ │ │ ├── ann.json │ │ │ ├── ann.txt │ │ │ ├── b │ │ │ │ ├── 2.jpeg │ │ │ │ └── subb │ │ │ │ │ └── 3.jpg │ │ │ ├── classes.txt │ │ │ └── multi_label_ann.json │ │ ├── model_library.py │ │ ├── models.py │ │ ├── subnet1.yaml │ │ ├── subnet2.yaml │ │ ├── test_models │ │ │ ├── test_algorithm │ │ │ │ └── MBV2_220M.yaml │ │ │ ├── test_mutator │ │ │ │ └── subnet1.json │ │ │ ├── test_subnet │ │ │ │ └── mockmodel_subnet.yaml │ │ │ └── test_task_modules │ │ │ │ └── mmcls_cfg.py │ │ ├── test_registry │ │ │ ├── registry_architecture_config.py │ │ │ ├── registry_subnet_config.py │ │ │ └── subnet.json │ │ └── tracer_passed_models.py │ ├── test_core │ │ ├── __init__.py │ │ ├── test_delivers │ │ │ ├── test_deliver_manager.py │ │ │ ├── test_function_outputs_deliver.py │ │ │ ├── 
test_method_outputs_deliver.py │ │ │ └── toy_module.py │ │ ├── test_graph │ │ │ ├── __init__.py │ │ │ ├── test_channel_flow.py │ │ │ ├── test_channel_graph.py │ │ │ ├── test_graph.py │ │ │ └── test_prune_tracer_model.py │ │ ├── test_recorders │ │ │ ├── test_base_recorder.py │ │ │ ├── test_func_inputs_recorder.py │ │ │ ├── test_func_outputs_recorder.py │ │ │ ├── test_method_inputs_recorder.py │ │ │ ├── test_method_outputs_recorder.py │ │ │ ├── test_module_recorders.py │ │ │ ├── test_param_recorder.py │ │ │ ├── test_recorder_manager.py │ │ │ └── toy_mod.py │ │ └── test_tracer │ │ │ ├── __init__.py │ │ │ ├── test_backward_tracer.py │ │ │ ├── test_fx_tracer.py │ │ │ ├── test_loss_calculator.py │ │ │ └── test_prune_tracer.py │ ├── test_data.py │ ├── test_datasets │ │ ├── test_datasets.py │ │ └── test_transforms │ │ │ └── test_formatting.py │ ├── test_doc.py │ ├── test_engine │ │ └── test_hooks │ │ │ ├── test_stop_distillation_hook.py │ │ │ └── test_visualization_hook.py │ ├── test_impl │ │ ├── __init__.py │ │ ├── test_pruning │ │ │ ├── __init__.py │ │ │ ├── test_group_fisher │ │ │ │ ├── __init__.py │ │ │ │ ├── test_algorithm.py │ │ │ │ ├── test_prune_deploy_sub_model.py │ │ │ │ ├── test_prune_sub_model.py │ │ │ │ └── test_unit.py │ │ │ └── test_sparse_gpt │ │ │ │ └── test_op.py │ │ └── test_quantization │ │ │ └── test_gptq │ │ │ └── test_op_gptq.py │ ├── test_models │ │ ├── __init__.py │ │ ├── test_algorithms │ │ │ ├── __init__.py │ │ │ ├── test_autoformer.py │ │ │ ├── test_autoslim.py │ │ │ ├── test_base_algorithm.py │ │ │ ├── test_bignas.py │ │ │ ├── test_darts.py │ │ │ ├── test_datafree_distill.py │ │ │ ├── test_dcff_network.py │ │ │ ├── test_dmcp.py │ │ │ ├── test_dsnas.py │ │ │ ├── test_general_quant.py │ │ │ ├── test_mm_architecture.py │ │ │ ├── test_ofd_algo.py │ │ │ ├── test_prune_algorithm.py │ │ │ ├── test_self_distill.py │ │ │ ├── test_single_teacher_distill.py │ │ │ ├── test_slimmable_network.py │ │ │ ├── test_spos.py │ │ │ └── toy_models.py │ │ ├── 
test_architectures │ │ │ ├── test_backbones │ │ │ │ ├── test_autoformerbackbone.py │ │ │ │ ├── test_dartsbackbone.py │ │ │ │ ├── test_searchable_mobilenet_v2.py │ │ │ │ ├── test_searchable_mobilenet_v3.py │ │ │ │ ├── test_searchable_shufflenet_v2.py │ │ │ │ └── utils.py │ │ │ ├── test_connectors │ │ │ │ └── test_connectors.py │ │ │ ├── test_dynamic_op │ │ │ │ ├── __init__.py │ │ │ │ ├── test_bricks │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── test_dynamic_attention.py │ │ │ │ │ ├── test_dynamic_container.py │ │ │ │ │ ├── test_dynamic_conv.py │ │ │ │ │ ├── test_dynamic_embed.py │ │ │ │ │ ├── test_dynamic_layernorm.py │ │ │ │ │ ├── test_dynamic_linear.py │ │ │ │ │ ├── test_dynamic_norm.py │ │ │ │ │ ├── test_dynamic_relative_position.py │ │ │ │ │ └── test_dynamic_resizer.py │ │ │ │ └── utils.py │ │ │ └── test_generators │ │ │ │ └── test_generators.py │ │ ├── test_classifier │ │ │ └── test_imageclassifier.py │ │ ├── test_distillers │ │ │ ├── test_byot_distill.py │ │ │ └── test_configurable_distill.py │ │ ├── test_fake_quants │ │ │ ├── test_lsq_fake_quants.py │ │ │ └── test_torch_fake_quants.py │ │ ├── test_losses │ │ │ ├── test_distillation_losses.py │ │ │ └── test_general_losses.py │ │ ├── test_mutables │ │ │ ├── __init__.py │ │ │ ├── test_derived_mutable.py │ │ │ ├── test_diffchoiceroute.py │ │ │ ├── test_diffop.py │ │ │ ├── test_gumbelchoiceroute.py │ │ │ ├── test_mutable_channel │ │ │ │ ├── __init__.py │ │ │ │ ├── test_mutable_channels.py │ │ │ │ ├── test_sequential_mutable_channel.py │ │ │ │ └── test_units │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── test_dcff_channel_unit.py │ │ │ │ │ ├── test_l1_mutable_channel_unit.py │ │ │ │ │ ├── test_mutable_channel_units.py │ │ │ │ │ ├── test_one_shot_mutable_channel_unit.py │ │ │ │ │ └── test_sequential_mutable_channel_unit.py │ │ │ ├── test_mutable_value.py │ │ │ ├── test_onehotop.py │ │ │ ├── test_oneshotop.py │ │ │ └── test_sequential_mutable_channel.py │ │ ├── test_mutators │ │ │ ├── __init__.py │ │ │ ├── test_channel_mutator.py 
│ │ │ ├── test_dcff_mutator.py │ │ │ ├── test_dmcp_mutator.py │ │ │ └── test_nas_mutator.py │ │ ├── test_observers │ │ │ ├── test_lsq_observer.py │ │ │ └── test_torch_observers.py │ │ ├── test_quantizers │ │ │ ├── test_academic_quantizer.py │ │ │ ├── test_exporter.py │ │ │ ├── test_native_quantizer.py │ │ │ ├── test_openvino_quantizer.py │ │ │ └── test_tensorrt_quantizer.py │ │ ├── test_subnet │ │ │ ├── test_candidate.py │ │ │ └── test_fix_subnet.py │ │ ├── test_task_modules │ │ │ ├── __init__.py │ │ │ ├── test_custom_tracer.py │ │ │ ├── test_demo_inputs │ │ │ │ ├── __init__.py │ │ │ │ └── test_demo_inputs.py │ │ │ ├── test_estimators │ │ │ │ └── test_flops_params.py │ │ │ ├── test_graph_utils.py │ │ │ └── test_predictors │ │ │ │ └── test_metric_predictor.py │ │ └── test_utils │ │ │ ├── __init__.py │ │ │ └── test_expandable_utils │ │ │ ├── __init__.py │ │ │ └── test_expand.py │ ├── test_registry │ │ └── test_registry.py │ ├── test_runners │ │ ├── test_autoslim_greedy_search_loop.py │ │ ├── test_darts_loop.py │ │ ├── test_distill_val_loop.py │ │ ├── test_evolution_search_loop.py │ │ ├── test_quantization_loop.py │ │ ├── test_subnet_sampler_loop.py │ │ └── test_utils │ │ │ ├── test_calibrate_bn_mixin.py │ │ │ ├── test_check.py │ │ │ └── test_genetic.py │ ├── test_structures │ │ ├── test_backendconfig.py │ │ └── test_qconfig.py │ ├── test_tools │ │ ├── __init__.py │ │ └── test_tools.py │ ├── test_utils │ │ ├── test_index_dict.py │ │ └── test_placeholder.py │ ├── test_visualizer │ │ └── test_visualizer.py │ └── utils │ │ ├── __init__.py │ │ ├── set_dist_env.py │ │ └── set_torch_thread.py └── tools │ ├── dist_test.sh │ ├── dist_train.sh │ ├── misc │ └── print_config.py │ ├── model_converters │ ├── convert_attentivenas_nas_ckpt.py │ ├── convert_bignas_gml_ckpt.py │ ├── convert_kd_ckpt.py │ ├── convert_kd_ckpt_to_student.py │ ├── convert_ofa_ckpt.py │ ├── convert_quant_ckpt.py │ ├── convert_supernet2subnet.py │ └── publish_model.py │ ├── pruning │ ├── get_channel_units.py 
│ ├── get_flops.py │ ├── get_l1_prune_config.py │ └── get_static_model_from_algorithm.py │ ├── ptq.py │ ├── slurm_test.sh │ ├── slurm_train.sh │ ├── test.py │ ├── train.py │ └── visualizations │ ├── demo.jpg │ ├── feature_diff_visualization.py │ ├── feature_visualization.py │ ├── vis_configs │ ├── backbone_feature_diff_visualization.py │ ├── backbone_feature_visualization.py │ ├── fpn_feature_diff_visualization.py │ └── fpn_feature_visualization.py │ └── vis_scheduler.py ├── pyproject.toml └── test_dms.py /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Debug Tests", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "purpose": [ 13 | "debug-test" 14 | ], 15 | "console": "integratedTerminal", 16 | "justMyCode": true 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.unittestArgs": [ 3 | "-v", 4 | "-s", 5 | ".", 6 | "-p", 7 | "test*.py" 8 | ], 9 | "python.testing.pytestEnabled": false, 10 | "python.testing.unittestEnabled": true 11 | } -------------------------------------------------------------------------------- /applications/efficientnet/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | output 3 | checkpoints -------------------------------------------------------------------------------- /applications/efficientnet/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible 
attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": false 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /applications/efficientnet/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.unittestArgs": [ 3 | "-v", 4 | "-s", 5 | ".", 6 | "-p", 7 | "test*.py" 8 | ], 9 | "python.testing.pytestEnabled": false, 10 | "python.testing.unittestEnabled": true 11 | } -------------------------------------------------------------------------------- /applications/efficientnet/README.md: -------------------------------------------------------------------------------- 1 | # DMS-EfficientNet 2 | 3 | ## Getting Started 4 | 5 | ``` 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | ## Pruning and Retraining 10 | Run the following from `applications/efficientnet` (the scripts use paths relative to this directory): 11 | ``` 12 | export Variant=DMS-450 # for example 13 | bash scripts/${Variant}/prune.sh 14 | bash scripts/${Variant}/retrain.sh 15 | ``` -------------------------------------------------------------------------------- /applications/efficientnet/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1+cu117 2 | torchvision==0.14.1+cu117 3 | torchaudio==0.13.1 4 | timm==0.9.2 5 | mmcv==2.0.1 6 | joblib 7 | mmcls==1.0.0rc6 -------------------------------------------------------------------------------- /applications/efficientnet/scripts/DMS-450/prune.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 --master_port 29112 timm_pruning.py ./data/imagenet_torch --model efficientnet_b4 -b 96 --sched step --epochs 4 \ 2
| --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 \ 3 | --drop 0.5 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 \ 4 | --experiment dms_450_pruned --pin-mem --input-size 3 224 224 \ 5 | --target 0.296 --mutator_lr 2e-5 --loss_weight 1 --skip_full_target 6 | 7 | -------------------------------------------------------------------------------- /applications/efficientnet/scripts/DMS-450/retrain.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 --master_port 29112 \ 2 | timm_retrain.py ./data/imagenet_torch --model efficientnet_b4 -b 96 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 \ 3 | --experiment dms_450_retrain_distil --pin-mem --resume output/train/dms_450_retrain_distil/last.pth.tar \ 4 | --pruned output/train/dms_450_prune/last.pth.tar --input-size 3 224 224 --target 0.2924 --teacher timm/efficientnet_b4 --teacher_input_image_size 320 5 | 6 | -------------------------------------------------------------------------------- /dms/__init__.py: -------------------------------------------------------------------------------- 1 | from .dtopk_src import differentiable_topk, MASK_THRESHOLD 2 | -------------------------------------------------------------------------------- /dms/dtopk_src.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmengine.dist import all_reduce 4 | 5 | MASK_THRESHOLD = 0.5 6 | 7 | 8 | @torch.jit.script 9 | def _dtopk(c: torch.Tensor, a: torch.Tensor, lambda_: float = 1.0): 10 | y = (c - a) * c.numel() * lambda_ 11 | y = y.sigmoid() 12 | return y 13 | 
14 | 15 | @torch.jit.script 16 | def differentiable_topk( 17 | c: torch.Tensor, 18 | a: torch.Tensor, 19 | lambda_: float = 1.0, 20 | normalize: bool = True, 21 | ): 22 | """ 23 | Differentiable top-k operator: Elements with large importance are kept. 24 | 25 | Args: 26 | c: importance score of elements 27 | a: pruning ratio of elements 28 | lambda_: hyper-parameter to control the polarity of the generated mask, default to 1. 29 | normalize: whether to normalize the importance score, default is True 30 | Returns: 31 | soft masks of elements 32 | """ 33 | 34 | if c.numel() == 1: 35 | return (-(0.5 - a) * lambda_).sigmoid() 36 | else: 37 | if normalize: 38 | c_compare = c.unsqueeze(-1) - c.unsqueeze(-2) # [N,N] 39 | c_ = (c_compare >= 0).float().mean(dim=-1) # normalize to [0,1] 40 | else: 41 | c_ = c 42 | 43 | imp = _dtopk(c_, a, lambda_=lambda_) # [0,1] 44 | return imp 45 | -------------------------------------------------------------------------------- /dms/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/dms/modules/__init__.py -------------------------------------------------------------------------------- /images/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/images/compare.png -------------------------------------------------------------------------------- /mmrazor/.circleci/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.8.1" 2 | ARG CUDA="10.2" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | # To fix GPG key error when running apt-get update 8 | RUN apt-key adv --fetch-keys 
https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 10 | 11 | RUN apt-get update && apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx 12 | -------------------------------------------------------------------------------- /mmrazor/.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | We appreciate all contributions to improve MMRazor. Please refer to [CONTRIBUTING.md](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) in MMCV for more details about the contributing guideline. 2 | -------------------------------------------------------------------------------- /mmrazor/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '[Bug]' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | ### Describe the bug 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | \[here\] 14 | 15 | ### To Reproduce 16 | 17 | The command you executed. 18 | 19 | ```shell 20 | [here] 21 | ``` 22 | 23 | ### Post related information 24 | 25 | 1. The output of `pip list | grep "mmcv\|mmrazor\|^torch"` 26 | \[here\] 27 | 2. Your config file if you modified it or created a new one. 28 | 29 | ```python 30 | [here] 31 | ``` 32 | 33 | 3. Your train log file if you meet the problem during training. 34 | \[here\] 35 | 4. Other code you modified in the `mmrazor` folder. 36 | \[here\] 37 | 38 | ### Additional context 39 | 40 | Add any other context about the problem here. 
41 | 42 | \[here\] 43 | -------------------------------------------------------------------------------- /mmrazor/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | 3 | contact_links: 4 | - name: MMRazor Documentation 5 | url: https://mmrazor.readthedocs.io/en/latest/ 6 | about: Check if your question is answered in docs 7 | -------------------------------------------------------------------------------- /mmrazor/.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '[Feature]' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | ### Describe the feature 10 | 11 | \[here\] 12 | 13 | ### Motivation 14 | 15 | A clear and concise description of the motivation of the feature. 16 | Ex1. It is inconvenient when \[....\]. 17 | Ex2. There is a recent paper \[....\], which is very helpful for \[....\]. 18 | 19 | \[here\] 20 | 21 | ### Related resources 22 | 23 | If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful. 24 | 25 | \[here\] 26 | 27 | ### Additional context 28 | 29 | Add any other context or screenshots about the feature request here. 30 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated. 
31 | 32 | \[here\] 33 | -------------------------------------------------------------------------------- /mmrazor/.github/ISSUE_TEMPLATE/general-questions.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General questions 3 | about: 'Ask general questions to get help ' 4 | title: '' 5 | labels: help wanted 6 | assignees: '' 7 | --- 8 | 9 | ### Checklist 10 | 11 | - I have searched related issues but cannot get the expected help. 12 | - I have read related documents and don't know what to do. 13 | 14 | ### Describe the question you meet 15 | 16 | \[here\] 17 | 18 | ### Post related information 19 | 20 | 1. The output of `pip list | grep "mmcv\|mmrazor\|^torch"` 21 | \[here\] 22 | 2. Your config file if you modified it or created a new one. 23 | 24 | ```python 25 | [here] 26 | ``` 27 | 28 | 3. Your train log file if you meet the problem during training. 29 | \[here\] 30 | 4. Other code you modified in the `mmrazor` folder. 31 | \[here\] 32 | -------------------------------------------------------------------------------- /mmrazor/.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: deploy 2 | 3 | on: push 4 | 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }} 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | build-n-publish: 11 | runs-on: ubuntu-latest 12 | if: startsWith(github.event.ref, 'refs/tags') 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7 19 | - name: Build MMRAZOR 20 | run: | 21 | pip install wheel 22 | python setup.py sdist bdist_wheel 23 | - name: Publish distribution to PyPI 24 | run: | 25 | pip install twine 26 | twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }} 27 | -------------------------------------------------------------------------------- /mmrazor/.github/workflows/lint.yml: 
-------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: [push, pull_request] 4 | 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }} 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.7 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.7 18 | - name: Install pre-commit hook 19 | run: | 20 | pip install pre-commit 21 | pre-commit install 22 | - name: Linting 23 | run: pre-commit run --all-files 24 | - name: Check docstring coverage 25 | run: | 26 | pip install interrogate 27 | interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmrazor 28 | -------------------------------------------------------------------------------- /mmrazor/.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: all 4 | 5 | python: 6 | version: 3.7 7 | install: 8 | - requirements: requirements/docs.txt 9 | - requirements: requirements/readthedocs.txt 10 | -------------------------------------------------------------------------------- /mmrazor/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements/*.txt 2 | include mmrazor/VERSION 3 | include mmrazor/.mim/model-index.yml 4 | include mmrazor/.mim/demo/*/* 5 | recursive-include mmrazor/.mim/configs *.py *.yml 6 | recursive-include mmrazor/.mim/tools *.sh *.py 7 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/nas_backbones/darts_supernet.py: -------------------------------------------------------------------------------- 1 | mutable_cfg = dict( 2 | type='mmrazor.DiffMutableOP', 3 | candidates=dict( 4 | zero=dict(type='mmrazor.DartsZero'), 5 | skip_connect=dict(type='mmrazor.DartsSkipConnect', 
use_drop_path=True), 6 | max_pool_3x3=dict( 7 | type='mmrazor.DartsPoolBN', pool_type='max', use_drop_path=True), 8 | avg_pool_3x3=dict( 9 | type='mmrazor.DartsPoolBN', pool_type='avg', use_drop_path=True), 10 | sep_conv_3x3=dict( 11 | type='mmrazor.DartsSepConv', kernel_size=3, use_drop_path=True), 12 | sep_conv_5x5=dict( 13 | type='mmrazor.DartsSepConv', kernel_size=5, use_drop_path=True), 14 | dil_conv_3x3=dict( 15 | type='mmrazor.DartsDilConv', kernel_size=3, use_drop_path=True), 16 | dil_conv_5x5=dict( 17 | type='mmrazor.DartsDilConv', kernel_size=5, use_drop_path=True))) 18 | 19 | route_cfg = dict(type='mmrazor.DiffChoiceRoute', with_arch_param=True) 20 | 21 | nas_backbone = dict( 22 | type='mmrazor.DartsBackbone', 23 | in_channels=3, 24 | base_channels=16, 25 | num_layers=8, 26 | num_nodes=4, 27 | stem_multiplier=3, 28 | out_indices=(7, ), 29 | mutable_cfg=mutable_cfg, 30 | route_cfg=route_cfg, 31 | norm_cfg=dict(type='BN', affine=False)) 32 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py: -------------------------------------------------------------------------------- 1 | norm_cfg = dict(type='BN', eps=0.01) 2 | 3 | _STAGE_MUTABLE = dict( 4 | type='mmrazor.OneHotMutableOP', 5 | fix_threshold=0.3, 6 | candidates=dict( 7 | shuffle_3x3=dict( 8 | type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg), 9 | shuffle_5x5=dict( 10 | type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg), 11 | shuffle_7x7=dict( 12 | type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg), 13 | shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg))) 14 | 15 | arch_setting = [ 16 | # Parameters to build layers. 3 parameters are needed to construct a 17 | # layer, from left to right: channel, num_blocks, mutable_cfg. 
18 | [64, 4, _STAGE_MUTABLE], 19 | [160, 4, _STAGE_MUTABLE], 20 | [320, 8, _STAGE_MUTABLE], 21 | [640, 4, _STAGE_MUTABLE] 22 | ] 23 | 24 | nas_backbone = dict( 25 | type='mmrazor.SearchableShuffleNetV2', 26 | widen_factor=1.0, 27 | arch_setting=arch_setting, 28 | norm_cfg=norm_cfg) 29 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/nas_backbones/spos_shufflenet_supernet.py: -------------------------------------------------------------------------------- 1 | _STAGE_MUTABLE = dict( 2 | _scope_='mmrazor', 3 | type='OneShotMutableOP', 4 | candidates=dict( 5 | shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3), 6 | shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5), 7 | shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7), 8 | shuffle_xception=dict(type='ShuffleXception'))) 9 | 10 | arch_setting = [ 11 | # Parameters to build layers. 3 parameters are needed to construct a 12 | # layer, from left to right: channel, num_blocks, mutable_cfg. 
13 | [64, 4, _STAGE_MUTABLE], 14 | [160, 4, _STAGE_MUTABLE], 15 | [320, 8, _STAGE_MUTABLE], 16 | [640, 4, _STAGE_MUTABLE] 17 | ] 18 | 19 | nas_backbone = dict( 20 | _scope_='mmrazor', 21 | type='SearchableShuffleNetV2', 22 | widen_factor=1.0, 23 | arch_setting=arch_setting) 24 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/settings/imagenet_bs2048_autoslim.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | './imagenet_bs1024_spos.py', 3 | ] 4 | 5 | _RandomResizedCrop_cfg = _base_.train_dataloader.dataset.pipeline[1] 6 | assert _RandomResizedCrop_cfg.type == 'RandomResizedCrop' 7 | _RandomResizedCrop_cfg.crop_ratio_range = (0.25, 1.0) 8 | 9 | optim_wrapper = dict(optimizer=dict(weight_decay=1e-4, nesterov=True)) 10 | 11 | train_dataloader = dict(batch_size=256) 12 | 13 | val_dataloader = dict(batch_size=256) 14 | 15 | test_dataloader = dict(batch_size=256) 16 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/settings/imagenet_bs2048_autoslim_pil.py: -------------------------------------------------------------------------------- 1 | _base_ = 'imagenet_bs2048_autoslim.py' 2 | 3 | _RandomResizedCrop_cfg = _base_.train_dataloader.dataset.pipeline[1] 4 | assert _RandomResizedCrop_cfg.type == 'RandomResizedCrop' 5 | _RandomResizedCrop_cfg.backend = 'pillow' 6 | 7 | _ResizeEdge_cfg_val = _base_.val_dataloader.dataset.pipeline[1] 8 | assert _ResizeEdge_cfg_val.type == 'ResizeEdge' 9 | _ResizeEdge_cfg_val.backend = 'pillow' 10 | 11 | _ResizeEdge_cfg_test = _base_.test_dataloader.dataset.pipeline[1] 12 | assert _ResizeEdge_cfg_test.type == 'ResizeEdge' 13 | _ResizeEdge_cfg_test.backend = 'pillow' 14 | -------------------------------------------------------------------------------- /mmrazor/configs/_base_/vanilla_models/wrn16_2_cifar10.py: 
-------------------------------------------------------------------------------- 1 | model = dict( 2 | _scope_='mmcls', 3 | type='ImageClassifier', 4 | backbone=dict( 5 | _scope_='mmrazor', 6 | type='WideResNet', 7 | depth=16, 8 | num_stages=3, 9 | widen_factor=2, 10 | ), 11 | neck=dict(type='GlobalAveragePooling'), 12 | head=dict( 13 | type='LinearClsHead', 14 | num_classes=10, 15 | in_channels=128, 16 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0), 17 | topk=(1, 5), 18 | )) 19 | 20 | find_unused_parameters = True 21 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmcls/byot/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Name: BYOT 3 | Metadata: 4 | Training Data: 5 | - CIFAR100 6 | Paper: 7 | URL: https://arxiv.org/abs/1905.08094 8 | Title: 'Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation' 9 | README: configs/distill/mmcls/byot/README.md 10 | Converted From: 11 | Code: 12 | URL: https://github.com/luanyunteng/pytorch-be-your-own-teacher 13 | Models: 14 | - Name: byot_resnet18_8xb16_cifar100 15 | In Collection: BYOT 16 | Metadata: 17 | inference time (ms/im): 18 | - value: 0.62 19 | hardware: V100 20 | backend: PyTorch 21 | batch size: 16 22 | mode: FP32 23 | resolution: (32, 32) 24 | Results: 25 | - Task: Classification 26 | Dataset: CIFAR100 27 | Metrics: 28 | Top 1 Accuracy: 80.66 29 | Top 5 Accuracy: 95.76 30 | Weights: https://download.openmmlab.com/mmrazor/v1/byot/byot_resnet18_8xb16_cifar100_20220817_191217-0251084e.pth 31 | Config: configs/distill/mmcls/byot/byot_resnet18_8xb16_cifar100.py 32 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmcls/factor_transfer/factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_train.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 |
'./factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain.py' 3 | ] 4 | 5 | train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1) 6 | 7 | model = dict( 8 | calculate_student_loss=True, 9 | student_trainable=True, 10 | distiller=dict( 11 | distill_losses=dict(loss_s4=dict(type='FTLoss', loss_weight=1.0)), 12 | connectors=dict(loss_s4_tfeat=dict(phase='train')), 13 | loss_forward_mappings=dict( 14 | _delete_=True, 15 | loss_s4=dict( 16 | s_feature=dict( 17 | from_student=True, 18 | recorder='bb_s4', 19 | connector='loss_s4_sfeat'), 20 | t_feature=dict( 21 | from_student=False, 22 | recorder='bb_s4', 23 | connector='loss_s4_tfeat'), 24 | ))), 25 | init_cfg=dict( 26 | type='Pretrained', 27 | checkpoint= # noqa: E251 28 | 'https://download.openmmlab.com/mmrazor/v1/factor_transfer/factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain_20220831_173259-ebdb09e2.pth' # noqa: E501 29 | )) 30 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./cwd_fpn_retina_r101_retina_r50_1x_coco.py'] 2 | 3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa: E501 4 | model = dict( 5 | architecture=dict( 6 | cfg_path='mmdet::gfl/gfl_r50_fpn_1x_coco.py', pretrained=False), 7 | teacher=dict( 8 | cfg_path='mmdet::gfl/gfl_r101_fpn_ms-2x_coco.py', pretrained=True), 9 | teacher_ckpt=teacher_ckpt) 10 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py'] 2 | 3 | model = dict( 4 | architecture=dict( 5 | 
cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', 6 | pretrained=False), 7 | teacher=dict( 8 | cfg_path='mmdet::retinanet/retinanet_r101_fpn_2x_coco.py', 9 | pretrained=True)) 10 | 11 | # optimizer 12 | optim_wrapper = dict( 13 | optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)) 14 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./cwd_fpn_retina_r101_retina_r50_1x_coco.py'] 2 | 3 | default_hooks = dict( 4 | checkpoint=dict(type='CheckpointHook', interval=-1), 5 | visualization=dict( 6 | _scope_='mmrazor', 7 | type='RazorVisualizationHook', 8 | enabled=True, 9 | recorders=dict( 10 | # todo: Maybe it is hard for users to understand why to add a 11 | # prefix `architecture.` 12 | neck=dict( 13 | _scope_='mmrazor', 14 | type='ModuleOutputs', 15 | source='architecture.neck')), 16 | mappings=dict( 17 | p3=dict(recorder='neck', data_idx=0), 18 | p4=dict(recorder='neck', data_idx=1), 19 | p5=dict(recorder='neck', data_idx=2), 20 | p6=dict(recorder='neck', data_idx=3)), 21 | out_dir='retina_vis')) 22 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/cwd/metafile.yml: -------------------------------------------------------------------------------- 1 | 2 | Models: 3 | - Name: cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco 4 | In Collection: CWD 5 | Metadata: 6 | Location: cls head 7 | Student: 8 | Metrics: 9 | box AP: 40.2 10 | Config: mmdet::gfl/gfl_r50_fpn_1x_coco.py 11 | Weights: https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r50_fpn_1x_coco/gfl_r50_fpn_1x_coco_20200629_121244-25944287.pth 12 | Teacher: 13 | Metrics: 14 | box AP: 44.7 15 | Config: mmdet::gfl/gfl_r101_fpn_ms-2x_coco.py 16 | Weights:
https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth 17 | Results: 18 | - Task: Object Detection 19 | Dataset: COCO 20 | Metrics: 21 | box AP: 41.9 22 | Config: configs/distill/mmdet/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py 23 | Weights: https://download.openmmlab.com/mmrazor/v1/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211222-c134bb21.pth 24 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py'] 2 | 3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth' # noqa: E501 4 | 5 | model = dict( 6 | architecture=dict( 7 | cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'), 8 | teacher=dict( 9 | cfg_path= # noqa: E251 10 | 'mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py'), 11 | teacher_ckpt=teacher_ckpt) 12 | 13 | # training schedule for 1x 14 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) 15 | 16 | # learning rate 17 | param_scheduler = [ 18 | dict( 19 | type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), 20 | dict( 21 | type='MultiStepLR', 22 | begin=0, 23 | end=12, 24 | by_epoch=True, 25 | milestones=[8, 11], 26 | gamma=0.1) 27 | ] 28 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py'] 2 | 3 | teacher_ckpt = 
'https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth' # noqa: E501 4 | 5 | model = dict( 6 | architecture=dict( 7 | cfg_path= # noqa: E251 8 | 'mmdet::reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py'), 9 | teacher=dict( 10 | cfg_path= # noqa: E251 11 | 'mmdet::reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py' # noqa: E501 12 | ), 13 | teacher_ckpt=teacher_ckpt) 14 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py'] 2 | 3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501 4 | 5 | model = dict( 6 | architecture=dict( 7 | cfg_path='mmdet::retinanet/retinanet_r50_fpn_2x_coco.py'), 8 | teacher=dict( 9 | cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py'), 10 | teacher_ckpt=teacher_ckpt, 11 | distiller=dict( 12 | distill_losses=dict( 13 | loss_pkd_fpn0=dict(loss_weight=10), 14 | loss_pkd_fpn1=dict(loss_weight=10), 15 | loss_pkd_fpn2=dict(loss_weight=10), 16 | loss_pkd_fpn3=dict(loss_weight=10)))) 17 | 18 | # optimizer 19 | optim_wrapper = dict(optimizer=dict(lr=0.01)) 20 | -------------------------------------------------------------------------------- /mmrazor/configs/distill/mmdet3d/pkd/metafile.yml: -------------------------------------------------------------------------------- 1 | Models: 2 | - Name: pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d 3 | In Collection: PKD 4 | Metadata: 5 | Location: FPN 6 | Student: 7 | Metrics: 8 | box AP: 26.8 9 | Config: 10 | Weights: 11 | Teacher: 12 | 
Metrics: 13 | box AP: 32.1 14 | Config: mmdet3d::fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d_finetune.py 15 | Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth 16 | Results: 17 | - Task: Object Detection 18 | Dataset: NuScenes 19 | Metrics: 20 | box AP: 29.3 21 | Config: configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py 22 | Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos3d_w10/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d_20220928_234557-0b51b62e.json?versionId=CAEQThiBgIDrvdC0oBgiIDNmNGNkNDZhM2RmNjQ1MmI4ZDM0OGNmYmFkYjk5ZjFi 23 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./autoformer_supernet_32xb256_in1k.py'] 2 | 3 | custom_hooks = None 4 | 5 | train_cfg = dict( 6 | _delete_=True, 7 | type='mmrazor.EvolutionSearchLoop', 8 | dataloader=_base_.val_dataloader, 9 | evaluator=_base_.val_evaluator, 10 | max_epochs=20, 11 | num_candidates=20, 12 | top_k=10, 13 | num_mutation=5, 14 | num_crossover=5, 15 | mutate_prob=0.2, 16 | constraints_range=dict(params=(0, 55)), 17 | score_key='accuracy/top1') 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoformer/autoformer_subnet_8xb256_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = 'autoformer_supernet_32xb256_in1k.py' 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/mmcls/autoformer/AUTOFORMER_SUBNET_B.yaml', 9 | # You can
also load the checkpoint of supernet instead of the specific 10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg` 11 | # with `init_weight_from_supernet = True`. 12 | init_weight_from_supernet=False, 13 | init_cfg=dict( 14 | type='Pretrained', 15 | checkpoint= # noqa: E251 16 | 'https://download.openmmlab.com/mmrazor/v1/autoformer/autoformer_supernet_32xb256_in1k_20220919_110144-c658ce8f.pth', # noqa: E501 17 | prefix='architecture.')) 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_search_8xb256_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./autoslim_mbv2_1.5x_supernet_8xb256_in1k.py'] 2 | 3 | model = dict(bn_training_mode=True) 4 | 5 | train_cfg = None 6 | optim_wrapper = None 7 | param_scheduler = None 8 | train_dataloader = None 9 | 10 | val_cfg = None 11 | val_dataloader = None 12 | val_evaluator = None 13 | 14 | test_cfg = dict( 15 | _delete_=True, 16 | type='mmrazor.AutoSlimGreedySearchLoop', 17 | dataloader=_base_.test_dataloader, 18 | evaluator=_base_.test_evaluator, 19 | target_flops=(500., 300., 200.)) 20 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py: -------------------------------------------------------------------------------- 1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' 2 | 3 | model = dict(deploy_index=0) 4 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py: -------------------------------------------------------------------------------- 1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' 2 | 3 | model = dict(deploy_index=1) 4 | 
-------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py: -------------------------------------------------------------------------------- 1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py' 2 | 3 | model = dict(deploy_index=2) 4 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/bignas/attentive_mobilenet_search_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./attentive_mobilenet_supernet_32xb64_in1k.py'] 2 | 3 | train_cfg = dict( 4 | _delete_=True, 5 | type='mmrazor.EvolutionSearchLoop', 6 | dataloader=_base_.val_dataloader, 7 | evaluator=_base_.val_evaluator, 8 | max_epochs=20, 9 | num_candidates=50, 10 | top_k=10, 11 | num_mutation=25, 12 | num_crossover=25, 13 | mutate_prob=0.1, 14 | calibrate_sample_num=4096, 15 | constraints_range=dict(flops=(0., 700.)), 16 | score_key='accuracy/top1') 17 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/bignas/attentive_mobilenet_subnet_8xb256_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = 'attentive_mobilenet_supernet_32xb64_in1k.py' 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/mmcls/bignas/ATTENTIVE_SUBNET_A0.yaml', 9 | # You can load the checkpoint of supernet instead of the specific 10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg` 11 | # with `init_weight_from_supernet = True`. 
12 | init_weight_from_supernet=True, 13 | init_cfg=dict( 14 | type='Pretrained', 15 | checkpoint= # noqa: E251 16 | 'https://download.openmmlab.com/mmrazor/v1/bignas/attentive_mobilenet_supernet_32xb64_in1k_flops-2G_acc-81.72_20221229_200440-954772a3.pth', # noqa: E501 17 | prefix='architecture.')) 18 | 19 | model_wrapper_cfg = None 20 | find_unused_parameters = True 21 | 22 | test_cfg = dict(evaluate_fixed_subnet=True) 23 | 24 | default_hooks = dict(checkpoint=None) 25 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmrazor::_base_/settings/cifar10_darts_supernet.py', 3 | 'mmrazor::_base_/nas_backbones/darts_supernet.py', 4 | 'mmcls::_base_/default_runtime.py', 5 | ] 6 | 7 | custom_hooks = [ 8 | dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True) 9 | ] 10 | 11 | # model 12 | model = dict( 13 | type='mmrazor.Darts', 14 | architecture=dict( 15 | type='ImageClassifier', 16 | backbone=_base_.nas_backbone, 17 | neck=dict(type='GlobalAveragePooling'), 18 | head=dict( 19 | type='LinearClsHead', 20 | num_classes=10, 21 | in_channels=256, 22 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0), 23 | topk=(1, 5), 24 | cal_acc=True)), 25 | mutator=dict(type='mmrazor.NasMutator'), 26 | unroll=True) 27 | 28 | model_wrapper_cfg = dict( 29 | type='mmrazor.DartsDDP', 30 | broadcast_buffers=False, 31 | find_unused_parameters=False) 32 | 33 | find_unused_parameter = False 34 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/darts/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Name: Darts 3 | Metadata: 4 | Training Data: 5 | - CIFAR-10 6 | Paper: 7 | URL: https://arxiv.org/abs/1806.09055 8 | Title: DARTS:Differentiable Architecture 
Search 9 | README: configs/nas/mmcls/darts/README.md 10 | Code: 11 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/darts.py 12 | Version: v0.1.0 13 | Converted From: 14 | Code: https://github.com/quark0/darts 15 | Models: 16 | - Name: darts_subnet_1xb96_cifar10_2.0 17 | In Collection: Darts 18 | Metadata: 19 | Params(M): 3.42 20 | Mutable: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_mutable_cfg.yaml 21 | Results: 22 | - Task: Image Classification 23 | Dataset: CIFAR-10 24 | Metrics: 25 | Top 1 Accuracy: 97.32 26 | Top 5 Accuracy: 99.94 27 | Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py 28 | Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_latest.pth 29 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml: -------------------------------------------------------------------------------- 1 | backbone.layers.0.0: 2 | chosen: shuffle_3x3 3 | backbone.layers.0.1: 4 | chosen: shuffle_7x7 5 | backbone.layers.0.2: 6 | chosen: shuffle_3x3 7 | backbone.layers.0.3: 8 | chosen: shuffle_5x5 9 | backbone.layers.1.0: 10 | chosen: shuffle_3x3 11 | backbone.layers.1.1: 12 | chosen: shuffle_3x3 13 | backbone.layers.1.2: 14 | chosen: shuffle_3x3 15 | backbone.layers.1.3: 16 | chosen: shuffle_7x7 17 | backbone.layers.2.0: 18 | chosen: shuffle_xception 19 | backbone.layers.2.1: 20 | chosen: shuffle_3x3 21 | backbone.layers.2.2: 22 | chosen: shuffle_3x3 23 | backbone.layers.2.3: 24 | chosen: shuffle_5x5 25 | backbone.layers.2.4: 26 | chosen: shuffle_3x3 27 | backbone.layers.2.5: 28 | chosen: shuffle_5x5 29 | backbone.layers.2.6: 30 | chosen: shuffle_7x7 31 | backbone.layers.2.7: 32 | chosen: shuffle_7x7 33 | backbone.layers.3.0: 34 | chosen: shuffle_xception 35 | backbone.layers.3.1: 36 | chosen: shuffle_3x3 
37 | backbone.layers.3.2: 38 | chosen: shuffle_7x7 39 | backbone.layers.3.3: 40 | chosen: shuffle_3x3 41 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./dsnas_supernet_8xb128_in1k.py'] 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet= # noqa: E251 9 | 'configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml' 10 | ) # noqa: E501 11 | 12 | find_unused_parameters = False 13 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py', 3 | 'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py', 4 | 'mmcls::_base_/default_runtime.py', 5 | ] 6 | 7 | custom_hooks = [ 8 | dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True) 9 | ] 10 | 11 | supernet = dict( 12 | _scope_='mmcls', 13 | type='ImageClassifier', 14 | data_preprocessor=_base_.data_preprocessor, 15 | backbone=_base_.nas_backbone, 16 | neck=dict(type='GlobalAveragePooling'), 17 | head=dict( 18 | type='LinearClsHead', 19 | num_classes=1000, 20 | in_channels=1024, 21 | loss=dict( 22 | type='LabelSmoothLoss', 23 | num_classes=1000, 24 | label_smooth_val=0.1, 25 | mode='original', 26 | loss_weight=1.0), 27 | topk=(1, 5))) 28 | 29 | # model 30 | model = dict( 31 | type='mmrazor.DSNAS', 32 | architecture=supernet, 33 | mutator=dict(type='mmrazor.NasMutator'), 34 | pretrain_epochs=15, 35 | finetune_epochs=_base_.search_epochs, 36 | ) 37 | 38 | model_wrapper_cfg = dict( 39 | type='mmrazor.DSNASDDP', 40 | broadcast_buffers=False, 41 | 
find_unused_parameters=True) 42 | 43 | randomness = dict(seed=48, diff_rank_seed=True) 44 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/onceforall/ofa_mobilenet_search_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./ofa_mobilenet_supernet_32xb64_in1k.py'] 2 | 3 | train_cfg = dict( 4 | _delete_=True, 5 | type='mmrazor.EvolutionSearchLoop', 6 | dataloader=_base_.val_dataloader, 7 | evaluator=_base_.val_evaluator, 8 | max_epochs=1, 9 | num_candidates=2, 10 | top_k=1, 11 | num_mutation=1, 12 | num_crossover=1, 13 | mutate_prob=0.1, 14 | calibrate_sample_num=4096, 15 | constraints_range=dict(flops=(0., 700.)), 16 | score_key='accuracy/top1') 17 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/onceforall/ofa_mobilenet_subnet_8xb256_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = 'ofa_mobilenet_supernet_32xb64_in1k.py' 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/mmcls/onceforall/OFA_SUBNET_NOTE8_LAT31.yaml', 9 | # You can also load the checkpoint of supernet instead of the specific 10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg` 11 | # with `init_weight_from_supernet = True`. 
12 | init_weight_from_supernet=False, 13 | init_cfg=dict( 14 | type='Pretrained', 15 | checkpoint= # noqa: E251 16 | 'https://download.openmmlab.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4031ms_top1%4072.8_finetune%4025.py_20221214_0939-981a8b2a.pth', # noqa: E501 17 | prefix='architecture.')) 18 | 19 | model_wrapper_cfg = None 20 | find_unused_parameters = True 21 | 22 | test_cfg = dict(evaluate_fixed_subnet=True) 23 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/SPOS_SUBNET.yaml: -------------------------------------------------------------------------------- 1 | backbone.layers.0.0: 2 | chosen: shuffle_7x7 3 | backbone.layers.0.1: 4 | chosen: shuffle_3x3 5 | backbone.layers.0.2: 6 | chosen: shuffle_7x7 7 | backbone.layers.0.3: 8 | chosen: shuffle_3x3 9 | backbone.layers.1.0: 10 | chosen: shuffle_xception 11 | backbone.layers.1.1: 12 | chosen: shuffle_5x5 13 | backbone.layers.1.2: 14 | chosen: shuffle_5x5 15 | backbone.layers.1.3: 16 | chosen: shuffle_3x3 17 | backbone.layers.2.0: 18 | chosen: shuffle_3x3 19 | backbone.layers.2.1: 20 | chosen: shuffle_5x5 21 | backbone.layers.2.2: 22 | chosen: shuffle_3x3 23 | backbone.layers.2.3: 24 | chosen: shuffle_5x5 25 | backbone.layers.2.4: 26 | chosen: shuffle_3x3 27 | backbone.layers.2.5: 28 | chosen: shuffle_xception 29 | backbone.layers.2.6: 30 | chosen: shuffle_5x5 31 | backbone.layers.2.7: 32 | chosen: shuffle_7x7 33 | backbone.layers.3.0: 34 | chosen: shuffle_7x7 35 | backbone.layers.3.1: 36 | chosen: shuffle_3x3 37 | backbone.layers.3.2: 38 | chosen: shuffle_5x5 39 | backbone.layers.3.3: 40 | chosen: shuffle_xception 41 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/faster-rcnn_nas_backbone_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | # Suppose you are in mmdet and want to use the searched subnet 2 | # as 
backbone for faster-rcnn, then you can just use this config. 3 | 4 | _base_ = [ 5 | '../_base_/models/faster-rcnn_r50_fpn.py', 6 | '../_base_/datasets/coco_detection.py', 7 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py', 8 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py' 9 | ] 10 | 11 | _base_.nas_backbone.out_indices = (0, 1, 2, 3) 12 | _base_.nas_backbone.with_last_layer = False 13 | nas_backbone = dict( 14 | # use mmrazor's build_func 15 | type='mmrazor.sub_model', 16 | cfg=_base_.nas_backbone, 17 | fix_subnet='/path/to/your/mmrazor/configs/nas/mmcls/spos/SPOS_SUBNET.yaml', 18 | extra_prefix='backbone.') 19 | 20 | _base_.model.backbone = nas_backbone 21 | _base_.model.neck.in_channels = [64, 160, 320, 640] 22 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Name: SPOS 3 | Metadata: 4 | Training Data: 5 | - ImageNet-1k 6 | Paper: 7 | URL: https://arxiv.org/abs/1904.00420 8 | Title: Single Path One-Shot Neural Architecture Search with Uniform Sampling 9 | README: configs/nas/mmcls/spos/README.md 10 | Code: 11 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/spos.py 12 | Version: v0.1.0 13 | Converted From: 14 | Code: https://github.com/megvii-model/SinglePathOneShot 15 | Models: 16 | - Name: spos_shufflenet_subnet_8xb128_in1k 17 | In Collection: SPOS 18 | Metadata: 19 | FLOPs: 330 MB 20 | Subnet: https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_subnet_cfg_v3.yaml 21 | Results: 22 | - Task: Image Classification 23 | Dataset: ImageNet-1k 24 | Metrics: 25 | Top 1 Accuracy: 73.87 26 | Top 5 Accuracy: 91.60 27 | Config: configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py 28 | Weights: 
https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth 29 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./spos_mobilenet_supernet_8xb128_in1k.py'] 2 | 3 | model = dict(norm_training=True) 4 | 5 | train_cfg = dict( 6 | _delete_=True, 7 | type='mmrazor.EvolutionSearchLoop', 8 | dataloader=_base_.val_dataloader, 9 | evaluator=_base_.val_evaluator, 10 | max_epochs=20, 11 | num_candidates=50, 12 | top_k=10, 13 | num_mutation=25, 14 | num_crossover=25, 15 | mutate_prob=0.1, 16 | constraints_range=dict(flops=(0., 465.)), 17 | score_key='accuracy/top1') 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_mobilenet_subnet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./spos_mobilenet_supernet_8xb128_in1k.py'] 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/spos/AngleNAS_SHUFFLENETV2_IN1k_2.0.yaml', 9 | init_cfg=dict( 10 | type='Pretrained', 11 | checkpoint= # noqa: E251 12 | 'https://download.openmmlab.com/mmrazor/v1/spos/spos_mobilenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth', # noqa: E501 13 | prefix='architecture.')) 14 | 15 | model_wrapper_cfg = None 16 | 17 | find_unused_parameters = False 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_mobilenet_supernet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 
'mmrazor::_base_/settings/imagenet_bs1024_spos.py', 3 | 'mmrazor::_base_/nas_backbones/spos_mobilenet_supernet.py', 4 | 'mmcls::_base_/default_runtime.py', 5 | ] 6 | 7 | # model 8 | supernet = dict( 9 | _scope_='mmcls', 10 | type='ImageClassifier', 11 | # data_preprocessor=_base_.preprocess_cfg, 12 | backbone=_base_.nas_backbone, 13 | neck=dict(type='GlobalAveragePooling'), 14 | head=dict( 15 | type='LinearClsHead', 16 | num_classes=1000, 17 | in_channels=1728, 18 | loss=dict( 19 | type='LabelSmoothLoss', 20 | num_classes=1000, 21 | label_smooth_val=0.1, 22 | mode='original', 23 | loss_weight=1.0), 24 | topk=(1, 5))) 25 | 26 | model = dict( 27 | type='mmrazor.SPOS', 28 | architecture=supernet, 29 | mutator=dict(type='mmrazor.NasMutator')) 30 | 31 | find_unused_parameters = True 32 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] 2 | 3 | model = dict(norm_training=True) 4 | 5 | train_cfg = dict( 6 | _delete_=True, 7 | type='mmrazor.EvolutionSearchLoop', 8 | dataloader=_base_.val_dataloader, 9 | evaluator=_base_.val_evaluator, 10 | max_epochs=20, 11 | num_candidates=50, 12 | top_k=10, 13 | num_mutation=25, 14 | num_crossover=25, 15 | mutate_prob=0.1, 16 | constraints_range=dict(flops=(0, 330)), 17 | score_key='accuracy/top1') 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_shufflenet_search_predictor_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] 2 | 3 | model = dict(norm_training=True) 4 | 5 | train_cfg = dict( 6 | _delete_=True, 7 | type='mmrazor.EvolutionSearchLoop', 8 | dataloader=_base_.val_dataloader, 9 | evaluator=_base_.val_evaluator, 10 
| max_epochs=20, 11 | num_candidates=50, 12 | top_k=10, 13 | num_mutation=25, 14 | num_crossover=25, 15 | mutate_prob=0.1, 16 | constraints_range=dict(flops=(0., 360.)), 17 | predictor_cfg=dict( 18 | type='mmrazor.MetricPredictor', 19 | train_samples=20, 20 | handler_cfg=dict(type='mmrazor.GaussProcessHandler')), 21 | ) 22 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py'] 2 | 3 | _base_.model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/mmcls/spos/SPOS_SUBNET.yaml', 9 | init_cfg=dict( 10 | type='Pretrained', 11 | checkpoint= # noqa: E251 12 | 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth', # noqa: E501 13 | prefix='architecture.')) 14 | 15 | model_wrapper_cfg = None 16 | 17 | find_unused_parameters = False 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmcls/spos/spos_shufflenet_supernet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmrazor::_base_/settings/imagenet_bs1024_spos.py', 3 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py', 4 | 'mmcls::_base_/default_runtime.py', 5 | ] 6 | 7 | # model 8 | supernet = dict( 9 | _scope_='mmcls', 10 | type='ImageClassifier', 11 | data_preprocessor=_base_.preprocess_cfg, 12 | backbone=_base_.nas_backbone, 13 | neck=dict(type='GlobalAveragePooling'), 14 | head=dict( 15 | type='LinearClsHead', 16 | num_classes=1000, 17 | in_channels=1024, 18 | loss=dict( 19 | type='LabelSmoothLoss', 20 | num_classes=1000, 21 | 
label_smooth_val=0.1, 22 | mode='original', 23 | loss_weight=1.0), 24 | topk=(1, 5))) 25 | 26 | model = dict( 27 | type='mmrazor.SPOS', 28 | architecture=supernet, 29 | mutator=dict(type='mmrazor.NasMutator')) 30 | 31 | find_unused_parameters = True 32 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml: -------------------------------------------------------------------------------- 1 | backbone.layers.0.0: 2 | chosen: shuffle_5x5 3 | backbone.layers.0.1: 4 | chosen: shuffle_3x3 5 | backbone.layers.0.2: 6 | chosen: shuffle_3x3 7 | backbone.layers.0.3: 8 | chosen: shuffle_3x3 9 | backbone.layers.1.0: 10 | chosen: shuffle_xception 11 | backbone.layers.1.1: 12 | chosen: shuffle_3x3 13 | backbone.layers.1.2: 14 | chosen: shuffle_xception 15 | backbone.layers.1.3: 16 | chosen: shuffle_7x7 17 | backbone.layers.2.0: 18 | chosen: shuffle_7x7 19 | backbone.layers.2.1: 20 | chosen: shuffle_7x7 21 | backbone.layers.2.2: 22 | chosen: shuffle_xception 23 | backbone.layers.2.3: 24 | chosen: shuffle_xception 25 | backbone.layers.2.4: 26 | chosen: shuffle_3x3 27 | backbone.layers.2.5: 28 | chosen: shuffle_7x7 29 | backbone.layers.2.6: 30 | chosen: shuffle_5x5 31 | backbone.layers.2.7: 32 | chosen: shuffle_xception 33 | backbone.layers.3.0: 34 | chosen: shuffle_7x7 35 | backbone.layers.3.1: 36 | chosen: shuffle_7x7 37 | backbone.layers.3.2: 38 | chosen: shuffle_7x7 39 | backbone.layers.3.3: 40 | chosen: shuffle_5x5 41 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py'] 2 | 3 | model = dict(norm_training=True) 4 | 5 | train_cfg = dict( 6 | _delete_=True, 7 | type='mmrazor.EvolutionSearchLoop', 8 | dataloader=_base_.val_dataloader, 9 | 
evaluator=_base_.val_evaluator, 10 | max_epochs=20, 11 | num_candidates=50, 12 | top_k=10, 13 | num_mutation=20, 14 | num_crossover=20, 15 | mutate_prob=0.1, 16 | constraints_range=dict(flops=(0, 330)), 17 | score_key='coco/bbox_mAP') 18 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py'] 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet='configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml', 9 | init_cfg=dict( 10 | type='Pretrained', 11 | checkpoint= # noqa: E251 12 | 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501 13 | prefix='architecture.')) 14 | 15 | find_unused_parameters = False 16 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_supernet_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmdet::_base_/models/faster-rcnn_r50_fpn.py', 3 | 'mmdet::_base_/datasets/coco_detection.py', 4 | 'mmdet::_base_/schedules/schedule_1x.py', 5 | 'mmdet::_base_/default_runtime.py', 6 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py' 7 | ] 8 | 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | 11 | supernet = _base_.model 12 | 13 | supernet.backbone = _base_.nas_backbone 14 | supernet.backbone.norm_cfg = norm_cfg 15 | supernet.backbone.out_indices = (0, 1, 2, 3) 16 | supernet.backbone.with_last_layer = False 17 | 18 | supernet.neck.norm_cfg = norm_cfg 19 | supernet.neck.in_channels = [64, 160, 320, 640] 20 | 21 | 
supernet.roi_head.bbox_head.norm_cfg = norm_cfg 22 | supernet.roi_head.bbox_head.type = 'Shared4Conv1FCBBoxHead' 23 | 24 | model = dict( 25 | _delete_=True, 26 | type='mmrazor.SPOS', 27 | architecture=supernet, 28 | mutator=dict(type='mmrazor.NasMutator')) 29 | 30 | find_unused_parameters = True 31 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_retina_shufflenet_supernet_coco_1x.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmdet::_base_/models/retinanet_r50_fpn.py', 3 | 'mmdet::_base_/datasets/coco_detection.py', 4 | 'mmdet::_base_/schedules/schedule_1x.py', 5 | 'mmdet::_base_/default_runtime.py', 6 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py' 7 | ] 8 | 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | 11 | supernet = _base_.model 12 | 13 | supernet.backbone = _base_.nas_backbone 14 | supernet.backbone.norm_cfg = norm_cfg 15 | supernet.backbone.out_indices = (0, 1, 2, 3) 16 | supernet.backbone.with_last_layer = False 17 | 18 | supernet.neck.norm_cfg = norm_cfg 19 | supernet.neck.in_channels = [64, 160, 320, 640] 20 | 21 | model = dict( 22 | _delete_=True, 23 | type='mmrazor.SPOS', 24 | architecture=supernet, 25 | mutator=dict(type='mmrazor.NasMutator')) 26 | 27 | find_unused_parameters = True 28 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_shufflenet_subnet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = './detnas_shufflenet_supernet_8xb128_in1k.py' 2 | 3 | model = dict( 4 | _scope_='mmrazor', 5 | type='sub_model', 6 | cfg=_base_.supernet, 7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself 8 | fix_subnet= # noqa: E251 9 | 
'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml' # noqa: E501 10 | ) 11 | 12 | find_unused_parameters = False 13 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/detnas_shufflenet_supernet_8xb128_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmrazor::nas/mmcls/spos/shufflenet/spos_shufflenet_supernet_8xb128_in1k.py' # noqa: E501 2 | -------------------------------------------------------------------------------- /mmrazor/configs/nas/mmdet/detnas/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Name: DetNAS 3 | Metadata: 4 | Training Data: 5 | - ImageNet-1k 6 | - COCO 7 | Paper: 8 | URL: https://arxiv.org/abs/1903.10979 9 | Title: DetNAS:Backbone Search for Object Detection 10 | README: configs/nas/mmdet/detnas/README.md 11 | Code: 12 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/detnas.py 13 | Version: v0.1.0 14 | Converted From: 15 | Code: https://github.com/megvii-model/DetNAS 16 | Models: 17 | - Name: detnas_frcnn_shufflenet_subnet_coco_1x 18 | In Collection: DetNAS 19 | Metadata: 20 | FLOPs(Backbone): 340 MB 21 | Params(Backbone): 3.35 MB 22 | Supernet: FRCNN-ShuffleNetV2 23 | Mutable: https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml 24 | Results: 25 | - Task: Object Detection 26 | Dataset: COCO 27 | Metrics: 28 | box AP: 37.5 29 | Config: configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py 30 | Weights: https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth 31 | 
-------------------------------------------------------------------------------- /mmrazor/configs/pruning/base/group_fisher/group_fisher_deploy_template.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """You have to fill these args. 3 | 4 | _base_(str): The path to your pretrain config file. 5 | fix_subnet (Union[dict,str]): The dict store the pruning structure or the 6 | json file including it. 7 | divisor (int): The divisor the make the channel number divisible. 8 | """ 9 | 10 | _base_ = '' 11 | fix_subnet = {} 12 | divisor = 8 13 | ############################################################################## 14 | 15 | architecture = _base_.model 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherDeploySubModel', 21 | architecture=architecture, 22 | fix_subnet=fix_subnet, 23 | divisor=divisor, 24 | ) 25 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/base/group_fisher/group_fisher_finetune_template.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 
8 | """ 9 | 10 | _base_ = '' 11 | pruned_path = '' 12 | finetune_lr = 0.1 13 | ############################################################################## 14 | 15 | algorithm = _base_.model 16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 17 | 18 | model = dict( 19 | _delete_=True, 20 | _scope_='mmrazor', 21 | type='GroupFisherSubModel', 22 | algorithm=algorithm, 23 | ) 24 | 25 | # restore lr 26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 27 | 28 | # remove pruning related hooks 29 | custom_hooks = _base_.custom_hooks[:-2] 30 | 31 | # delete ddp 32 | model_wrapper_cfg = None 33 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = ['dcff_resnet_8xb32_in1k.py'] 2 | 3 | # model settings 4 | _base_.model = dict( 5 | _scope_='mmrazor', 6 | type='sub_model', 7 | cfg=dict( 8 | cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False), 9 | fix_subnet='configs/pruning/mmcls/dcff/fix_subnet.json', 10 | mode='mutator', 11 | init_cfg=dict( 12 | type='Pretrained', 13 | checkpoint='configs/pruning/mmcls/dcff/fix_subnet_weight.pth')) 14 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/dmcp/metafile.yml: -------------------------------------------------------------------------------- 1 | # Models: 2 | # - Name: dmcp_resnet50_subnet_32xb64 3 | # In Collection: DMCP 4 | # Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py 5 | # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth 6 | # Results: 7 | # - Task: Image Classification 8 | # Dataset: ImageNet-1k 9 | # Metrics: 10 | # Top 1 Accuracy: 76.11 11 | # - Name: dmcp_mbv2_subnet_32xb64 12 | # In Collection: DMCP 13 | # Config: 
configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py 14 | # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth 15 | # Results: 16 | # - Task: Image Classification 17 | # Dataset: ImageNet-1k 18 | # Metrics: 19 | # Top 1 Accuracy: 67.22 20 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 8 | """ 9 | 10 | _base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth' # noqa 12 | finetune_lr = 0.045 13 | ############################################################################## 14 | algorithm = _base_.model 15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherSubModel', 21 | algorithm=algorithm, 22 | ) 23 | 24 | # restore lr 25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 26 | 27 | # remove pruning related hooks 28 | custom_hooks = _base_.custom_hooks[:-2] 29 | 30 | # delete ddp 31 | model_wrapper_cfg = None 32 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py: -------------------------------------------------------------------------------- 
1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 8 | """ 9 | 10 | _base_ = './group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth' # noqa 12 | finetune_lr = 0.045 13 | ############################################################################## 14 | 15 | algorithm = _base_.model 16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 17 | 18 | model = dict( 19 | _delete_=True, 20 | _scope_='mmrazor', 21 | type='GroupFisherSubModel', 22 | algorithm=algorithm, 23 | ) 24 | 25 | # restore lr 26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 27 | 28 | # remove pruning related hooks 29 | custom_hooks = _base_.custom_hooks[:-2] 30 | 31 | # delete ddp 32 | model_wrapper_cfg = None 33 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py' 2 | model = dict( 3 | mutator=dict( 4 | channel_unit_cfg=dict( 5 | default_args=dict(normalization_type='flops', ), ), ), ) 6 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml: -------------------------------------------------------------------------------- 1 | Models: 2 | - Name: group_fisher_act_finetune_mobilenet-v2_8xb32_in1k 3 | In Collection: GroupFisher 4 | Config: 
configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py 5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth 6 | Results: 7 | - Task: Image Classification 8 | Dataset: ImageNet-1k 9 | Metrics: 10 | Top 1 Accuracy: 70.82 11 | - Name: group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k 12 | In Collection: GroupFisher 13 | Config: configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py 14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth 15 | Results: 16 | - Task: Image Classification 17 | Dataset: ImageNet-1k 18 | Metrics: 19 | Top 1 Accuracy: 70.87 20 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 
8 | """ 9 | 10 | _base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth' # noqa 12 | finetune_lr = 0.1 13 | ############################################################################## 14 | algorithm = _base_.model 15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherSubModel', 21 | algorithm=algorithm, 22 | ) 23 | 24 | # restore lr 25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 26 | 27 | # remove pruning related hooks 28 | custom_hooks = _base_.custom_hooks[:-2] 29 | 30 | # delete ddp 31 | model_wrapper_cfg = None 32 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 
8 | """ 9 | 10 | _base_ = './group_fisher_flops_prune_resnet50_8xb32_in1k.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth' # noqa 12 | finetune_lr = 0.1 13 | ############################################################################## 14 | algorithm = _base_.model 15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherSubModel', 21 | algorithm=algorithm, 22 | ) 23 | 24 | # restore lr 25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 26 | 27 | # remove pruning related hooks 28 | custom_hooks = _base_.custom_hooks[:-2] 29 | 30 | # delete ddp 31 | model_wrapper_cfg = None 32 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py: -------------------------------------------------------------------------------- 1 | _base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py' 2 | model = dict( 3 | mutator=dict( 4 | channel_unit_cfg=dict( 5 | default_args=dict(normalization_type='flops', ), ), ), ) 6 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/group_fisher/resnet50/metafile.yml: -------------------------------------------------------------------------------- 1 | Models: 2 | - Name: group_fisher_act_finetune_resnet50_8xb32_in1k 3 | In Collection: GroupFisher 4 | Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py 5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth 6 | Results: 7 | - Task: Image Classification 8 | Dataset: ImageNet-1k 9 | Metrics: 10 | Top 1 Accuracy: 75.22 11 | - Name: 
group_fisher_flops_finetune_resnet50_8xb32_in1k 12 | In Collection: GroupFisher 13 | Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py 14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth 15 | Results: 16 | - Task: Image Classification 17 | Dataset: ImageNet-1k 18 | Metrics: 19 | Top 1 Accuracy: 75.61 20 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/l1-norm/metafile.yml: -------------------------------------------------------------------------------- 1 | Models: 2 | - Name: l1-norm_resnet34_8xb32_in1k_a 3 | In Collection: L1-norm 4 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py 5 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth 6 | Results: 7 | - Task: Image Classification 8 | Dataset: ImageNet-1k 9 | Metrics: 10 | Top 1 Accuracy: 73.61 11 | - Name: l1-norm_resnet34_8xb32_in1k_b 12 | In Collection: L1-norm 13 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b.py 14 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_b.pth 15 | Results: 16 | - Task: Image Classification 17 | Dataset: ImageNet-1k 18 | Metrics: 19 | Top 1 Accuracy: 73.20 20 | - Name: l1-norm_resnet34_8xb32_in1k_c 21 | In Collection: L1-norm 22 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c.py 23 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_c.pth 24 | Results: 25 | - Task: Image Classification 26 | Dataset: ImageNet-1k 27 | Metrics: 28 | Top 1 Accuracy: 73.89 29 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmcls/l1-norm/script.sh: 
-------------------------------------------------------------------------------- 1 | 2 | # export pruned checkpoint example 3 | 4 | python ./tools/pruning/get_static_model_from_algorithm.py configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth -o ./work_dirs/norm_resnet34_8xb32_in1k_a 5 | 6 | # deploy example 7 | 8 | razor_config=configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py 9 | deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py 10 | static_model_checkpoint_path=path/to/pruend/checkpoint 11 | 12 | python mmdeploy/tools/deploy.py $deploy_config \ 13 | $razor_config \ 14 | $static_model_checkpoint_path \ 15 | mmdeploy/tests/data/tiger.jpeg \ 16 | --work-dir ./work_dirs/mmdeploy 17 | 18 | python mmdeploy/tools/profiler.py $deploy_config \ 19 | $razor_config \ 20 | mmdeploy/demo/resources \ 21 | --model ./work_dirs/mmdeploy/end2end.onnx \ 22 | --shape 224x224 \ 23 | --device cpu \ 24 | --num-iter 1000 \ 25 | --warmup 100 26 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['dcff_faster_rcnn_resnet50_8xb4_coco.py'] 2 | 3 | # model settings 4 | _base_.model = dict( 5 | _scope_='mmrazor', 6 | type='sub_model', 7 | cfg=_base_.architecture, 8 | fix_subnet='configs/pruning/mmdet/dcff/fix_subnet.json', 9 | mode='mutator', 10 | init_cfg=dict( 11 | type='Pretrained', 12 | checkpoint='configs/pruning/mmdet/dcff/fix_subnet_weight.pth')) 13 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py: 
-------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 8 | """ 9 | 10 | _base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth' # noqa 12 | finetune_lr = 0.005 13 | ############################################################################## 14 | algorithm = _base_.model 15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherSubModel', 21 | algorithm=algorithm, 22 | ) 23 | 24 | # restore lr 25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 26 | 27 | # remove pruning related hooks 28 | custom_hooks = _base_.custom_hooks[:-2] 29 | 30 | # delete ddp 31 | model_wrapper_cfg = None 32 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 
8 | """ 9 | 10 | _base_ = './group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py' 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth' # noqa 12 | finetune_lr = 0.005 13 | ############################################################################## 14 | algorithm = _base_.model 15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 16 | 17 | model = dict( 18 | _delete_=True, 19 | _scope_='mmrazor', 20 | type='GroupFisherSubModel', 21 | algorithm=algorithm, 22 | ) 23 | 24 | # restore lr 25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 26 | 27 | # remove pruning related hooks 28 | custom_hooks = _base_.custom_hooks[:-2] 29 | 30 | # delete ddp 31 | model_wrapper_cfg = None 32 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py' 2 | model = dict( 3 | mutator=dict( 4 | channel_unit_cfg=dict( 5 | default_args=dict(normalization_type='flops', ), ), ), ) 6 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmdet/group_fisher/retinanet/metafile.yml: -------------------------------------------------------------------------------- 1 | Models: 2 | - Name: group_fisher_act_finetune_retinanet_r50_fpn_1x_coco 3 | In Collection: GroupFisher 4 | Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py 5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth 6 | Results: 7 | - Task: Object Detection 8 | Dataset: COCO 9 | Metrics: 10 | box AP: 36.5 11 | - Name: 
group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco 12 | In Collection: GroupFisher 13 | Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py 14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth 15 | Results: 16 | - Task: Object Detection 17 | Dataset: COCO 18 | Metrics: 19 | box AP: 36.6 20 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = ['dcff_topdown_heatmap_resnet50_coco.py'] 2 | 3 | # model settings 4 | _base_.model = dict( 5 | _scope_='mmrazor', 6 | type='sub_model', 7 | cfg=_base_.architecture, 8 | fix_subnet='configs/pruning/mmpose/dcff/fix_subnet.json', 9 | mode='mutator', 10 | init_cfg=dict( 11 | type='Pretrained', 12 | checkpoint='configs/pruning/mmpose/dcff/fix_subnet_weight.pth')) 13 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmpose/group_fisher/group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.py: -------------------------------------------------------------------------------- 1 | ############################################################################# 2 | """# You have to fill these args. 3 | 4 | _base_(str): The path to your pruning config file. 5 | pruned_path (str): The path to the checkpoint of the pruned model. 6 | finetune_lr (float): The lr rate to finetune. Usually, we directly use the lr 7 | rate of the pretrain. 
8 | """ 9 | 10 | _base_ = './group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.py' # noqa 11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth' # noqa 12 | finetune_lr = 4e-3 13 | ############################################################################## 14 | 15 | algorithm = _base_.model 16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path) 17 | 18 | model = dict( 19 | _delete_=True, 20 | _scope_='mmrazor', 21 | type='GroupFisherSubModel', 22 | algorithm=algorithm, 23 | ) 24 | 25 | # restore lr 26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr)) 27 | 28 | # remove pruning related hooks 29 | custom_hooks = _base_.custom_hooks[:-2] 30 | 31 | # delete ddp 32 | model_wrapper_cfg = None 33 | -------------------------------------------------------------------------------- /mmrazor/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py'] 2 | 3 | # model settings 4 | _base_.model = dict( 5 | _scope_='mmrazor', 6 | type='sub_model', 7 | cfg=_base_.architecture, 8 | fix_subnet='configs/pruning/mmseg/dcff/fix_subnet.json', 9 | mode='mutator', 10 | init_cfg=dict( 11 | type='Pretrained', 12 | checkpoint='configs/pruning/mmseg/dcff/fix_subnet_weight.pth')) 13 | -------------------------------------------------------------------------------- /mmrazor/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py: -------------------------------------------------------------------------------- 1 | deploy_cfg = dict( 2 | onnx_config=dict( 3 | type='onnx', 4 | export_params=True, 5 | keep_initializers_as_inputs=False, 6 | opset_version=11, 7 | save_file='end2end.onnx', 8 | input_names=['input'], 9 | output_names=['output'], 10 | input_shape=None, 11 | optimize=True, 
12 | dynamic_axes={ 13 | 'input': { 14 | 0: 'batch', 15 | 2: 'height', 16 | 3: 'width' 17 | }, 18 | 'output': { 19 | 0: 'batch' 20 | } 21 | }), 22 | backend_config=dict( 23 | type='openvino', 24 | model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), 25 | codebase_config=dict(type='mmcls', task='Classification'), 26 | function_record_to_pop=[ 27 | 'mmcls.models.classifiers.ImageClassifier.forward', 28 | 'mmcls.models.classifiers.BaseClassifier.forward' 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /mmrazor/configs/quantization/qat/base/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Name: QAT 3 | README: configs/quantization/qat/base/README.md 4 | Models: 5 | - Name: qat_openvino_resnet18_10e_8xb32_in1k.py 6 | In Collection: QAT 7 | Metadata: 8 | Backend: openvino 9 | Float Model: 10 | Config: mmcls::resnet/resnet18_8xb32_in1k.py 11 | Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth 12 | Metrics: 13 | Top 1 Accuracy: 69.90 14 | Results: 15 | - Task: Image Classification 16 | Dataset: ImageNet-1k 17 | Metrics: 18 | Top 1 Accuracy: 69.98 19 | Config: configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py 20 | Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth 21 | -------------------------------------------------------------------------------- /mmrazor/configs/vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 'mmcls::_base_/datasets/cifar10_bs16.py', 3 | '../../../_base_/vanilla_models/wrn16_2_cifar10.py', 4 | 'mmcls::_base_/schedules/cifar10_bs128.py', 5 | 'mmcls::_base_/default_runtime.py', 6 | ] 7 | test_evaluator = dict(topk=(1, 5)) 8 | 
-------------------------------------------------------------------------------- /mmrazor/configs/vanilla/mmcls/wide-resnet/wrn22-w4_b16x8_cifar10.py: -------------------------------------------------------------------------------- 1 | _base_ = ['wrn16-w2_b16x8_cifar10.py'] 2 | model = dict( 3 | backbone=dict(depth=22, widen_factor=4), head=dict(in_channels=256, )) 4 | -------------------------------------------------------------------------------- /mmrazor/configs/vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py: -------------------------------------------------------------------------------- 1 | _base_ = ['wrn16-w2_b16x8_cifar10.py'] 2 | model = dict( 3 | backbone=dict(depth=28, widen_factor=4), head=dict(in_channels=256, )) 4 | -------------------------------------------------------------------------------- /mmrazor/configs/vanilla/mmcls/wide-resnet/wrn40-w2_b16x8_cifar10.py: -------------------------------------------------------------------------------- 1 | _base_ = ['wrn16-w2_b16x8_cifar10.py'] 2 | model = dict( 3 | backbone=dict(depth=40, widen_factor=2), head=dict(in_channels=128, )) 4 | -------------------------------------------------------------------------------- /mmrazor/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.6.0" 2 | ARG CUDA="10.1" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 9 | ENV CMAKE_PREFIX_PATH="(dirname(which conda))/../" 10 | 11 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Install MMCV 16 | RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html 17 | 18 | # Install MMClassification 19 | RUN conda clean --all 20 | RUN 
git clone https://github.com/open-mmlab/mmclassification.git 21 | WORKDIR ./mmclassification 22 | RUN pip install --no-cache-dir -e . 23 | -------------------------------------------------------------------------------- /mmrazor/docker/serve/config.properties: -------------------------------------------------------------------------------- 1 | inference_address=http://0.0.0.0:8080 2 | management_address=http://0.0.0.0:8081 3 | metrics_address=http://0.0.0.0:8082 4 | model_store=/home/model-server/model-store 5 | load_models=all 6 | -------------------------------------------------------------------------------- /mmrazor/docker/serve/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | if [[ "$1" = "serve" ]]; then 5 | shift 1 6 | torchserve --start --ts-config /home/model-server/config.properties 7 | else 8 | eval "$@" 9 | fi 10 | 11 | # prevent docker exit 12 | tail -f /dev/null 13 | -------------------------------------------------------------------------------- /mmrazor/docs/en/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /mmrazor/docs/en/_static/css/readthedocs.css: -------------------------------------------------------------------------------- 1 | .header-logo { 2 | background-image: url("../image/mmrazor-logo.png"); 3 | background-size: 125px 40px; 4 | height: 40px; 5 | width: 125px; 6 | } 7 | -------------------------------------------------------------------------------- /mmrazor/docs/en/_static/image/mmrazor-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/_static/image/mmrazor-logo.png -------------------------------------------------------------------------------- /mmrazor/docs/en/advanced_guides/index.rst: -------------------------------------------------------------------------------- 1 | Key Concepts 2 | *************** 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | algorithm.md 8 | mutator.md 9 | mutable.md 10 | recorder.md 11 | delivery.md 12 | 13 | Development tutorials 14 | ************************ 15 | 16 | .. 
toctree:: 17 | :maxdepth: 1 18 | 19 | customize_architectures.md 20 | customize_nas_algorithms.md 21 | customize_pruning_algorithms.md 22 | customize_kd_algorithms.md 23 | customize_quantization_algorithms.md 24 | customize_mixed_algorithms.md 25 | apply_existing_algorithms_to_new_tasks.md 26 | -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/draw-config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/draw-config.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/framework-ChanelMutator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-ChanelMutator.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/framework-algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-algorithm.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/framework-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-framework.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/framework-graph.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-graph.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/framework-op.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-op.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/pruning_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/pruning_framework.png -------------------------------------------------------------------------------- /mmrazor/docs/en/imgs/pruning/unit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/unit.png -------------------------------------------------------------------------------- /mmrazor/docs/en/index.rst: -------------------------------------------------------------------------------- 1 | .. mmrazor documentation master file, created by 2 | sphinx-quickstart on Mon Aug 29 15:21:38 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to mmrazor's documentation! 7 | =================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 1 11 | :caption: Get Started: 12 | 13 | get_started/overview.md 14 | get_started/installation.md 15 | get_started/model_zoo.md 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: User Guides: 20 | 21 | user_guides/index.rst 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Advanced Guides 26 | 27 | advanced_guides/index.rst 28 | 29 | .. toctree:: 30 | :maxdepth: 1 31 | :caption: Notes 32 | 33 | notes/changelog.md 34 | notes/contribution_guide.md 35 | notes/faq.md 36 | 37 | .. toctree:: 38 | :maxdepth: 1 39 | :caption: APIs Reference 40 | 41 | api.rst 42 | 43 | .. toctree:: 44 | :maxdepth: 1 45 | :caption: Switch Language 46 | 47 | switch_language.md 48 | 49 | 50 | Indices and tables 51 | ================== 52 | 53 | * :ref:`genindex` 54 | * :ref:`modindex` 55 | * :ref:`search` 56 | -------------------------------------------------------------------------------- /mmrazor/docs/en/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /mmrazor/docs/en/notes/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | -------------------------------------------------------------------------------- /mmrazor/docs/en/switch_language.md: -------------------------------------------------------------------------------- 1 | ## English 2 | 3 | ## 简体中文 4 | -------------------------------------------------------------------------------- /mmrazor/docs/en/user_guides/1_learn_about_config.md: -------------------------------------------------------------------------------- 1 | # Learn about Configs 2 | 3 | ## Directory structure of configs in mmrazor 4 | 5 | ![image](https://user-images.githubusercontent.com/88702197/187635756-55f80a44-161b-4af9-b226-9b7aef68a139.png) 6 | 7 | `mmxxx`: refers to the task repositories of OpenMMLab, such as mmcls, mmdet, mmseg and so on. 8 | 9 | `_base_`: includes configs of datasets, experiment settings and model architectures. 10 | 11 | `distill`/`nas`/`pruning`: model compression algorithms. 12 | 13 | `vanilla`: task models owned by mmrazor. 14 | 15 | ## More about config 16 | 17 | Please refer to [config](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/config.md) in mmengine. 18 | -------------------------------------------------------------------------------- /mmrazor/docs/en/user_guides/index.rst: -------------------------------------------------------------------------------- 1 | Train & Test 2 | ************** 3 | 4 | ..
toctree:: 5 | :maxdepth: 1 6 | 7 | 8 | 1_learn_about_config.md 9 | 2_train_different_types_algorithms.md 10 | 3_train_with_different_devices.md 11 | 4_test_a_model.md 12 | 13 | Quantization 14 | ************ 15 | 16 | .. toctree:: 17 | :maxdepth: 1 18 | 19 | quantization_user_guide.md 20 | 21 | Useful Tools 22 | ************ 23 | 24 | Please refer to the documentation of the upstream OpenMMLab task repositories. 25 | -------------------------------------------------------------------------------- /mmrazor/docs/zh_cn/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /mmrazor/docs/zh_cn/index.rst: -------------------------------------------------------------------------------- 1 | 欢迎来到 MMRazor 的中文文档! 2 | ========================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: 开始你的第一步 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | :caption: 快速启动 11 | 12 | .. toctree:: 12 | :maxdepth: 2 14 | :caption: 教程 15 | 16 | .. toctree:: 17 | :caption: 语言切换 18 | 19 | switch_language.md 20 | 21 | ..
toctree:: 22 | :caption: 接口文档(英文) 23 | 24 | api.rst 25 | 26 | Indices and tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`search` 31 | -------------------------------------------------------------------------------- /mmrazor/docs/zh_cn/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /mmrazor/docs/zh_cn/switch_language.md: -------------------------------------------------------------------------------- 1 | ## English 2 | 3 | ## 简体中文 4 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import mmcv 3 | import mmengine 4 | from mmengine.utils import digit_version 5 | 6 | from .version import __version__ 7 | 8 | mmcv_minimum_version = '2.0.0rc1' 9 | mmcv_maximum_version = '2.1.0' # inclusive upper bound (checked with <=) 10 | mmcv_version = digit_version(mmcv.__version__) 11 | 12 | mmengine_minimum_version = '0.1.0' 13 | mmengine_maximum_version = '1.0.0' # exclusive upper bound (checked with <) 14 | mmengine_version = digit_version(mmengine.__version__) 15 | 16 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 17 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 18 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 19 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 20 | 21 | assert (mmengine_version >= digit_version(mmengine_minimum_version) 22 | and mmengine_version < digit_version(mmengine_maximum_version)), \ 23 | f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ 24 | f'Please install mmengine>={mmengine_minimum_version}, ' \ 25 | f'<{mmengine_maximum_version}.' 26 | 27 | __all__ = ['__version__', 'digit_version'] 28 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .crd_dataset_wrapper import CRDDataset 3 | from .transforms import AutoAugment, AutoAugmentV2, PackCRDClsInputs 4 | 5 | __all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs', 'CRDDataset'] 6 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .auto_augment import AutoAugment 3 | from .auto_augmentv2 import AutoAugmentV2 4 | from .formatting import PackCRDClsInputs 5 | 6 | __all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs'] 7 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .hooks import (DMCPSubnetHook, DumpSubnetHook, EstimateResourcesHook, 3 | StopDistillHook) 4 | from .optimizers import SeparateOptimWrapperConstructor 5 | from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, 6 | DartsIterBasedTrainLoop, EvolutionSearchLoop, 7 | GreedySamplerTrainLoop, LSQEpochBasedLoop, PTQLoop, 8 | QATEpochBasedLoop, QATValLoop, SelfDistillValLoop, 9 | SingleTeacherDistillValLoop, SlimmableValLoop, 10 | SubnetValLoop) 11 | 12 | __all__ = [ 13 | 'DMCPSubnetHook', 'StopDistillHook', 'SeparateOptimWrapperConstructor', 14 | 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 15 | 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 16 | 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 17 | 'SelfDistillValLoop', 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 18 | 'PTQLoop', 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' 19 | ] 20 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .dmcp_subnet_hook import DMCPSubnetHook 3 | from .dump_subnet_hook import DumpSubnetHook 4 | from .estimate_resources_hook import EstimateResourcesHook 5 | from .stop_distillation_hook import StopDistillHook 6 | from .visualization_hook import RazorVisualizationHook 7 | 8 | __all__ = [ 9 | 'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook', 10 | 'DMCPSubnetHook', 'StopDistillHook' 11 | ] 12 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/hooks/group_fisher_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """This file includes the modules in the impl folder. 3 | 4 | As it only records impl modules, it is not initialized automatically. 5 | """ 6 | from mmrazor.implementations.pruning.group_fisher import \ 7 | PruningStructureHook # noqa 8 | from mmrazor.implementations.pruning.group_fisher import \ 9 | ResourceInfoHook # noqa 10 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/hooks/stop_distillation_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmengine.hooks import Hook 3 | from mmengine.model import is_model_wrapper 4 | 5 | from mmrazor.registry import HOOKS 6 | 7 | 8 | @HOOKS.register_module() 9 | class StopDistillHook(Hook): 10 | """Stop distillation once training reaches a given epoch. 11 | 12 | Args: 13 | stop_epoch (int): Stop distillation from this epoch onwards.
14 | """ 15 | 16 | priority = 'LOW' 17 | 18 | def __init__(self, stop_epoch: int) -> None: 19 | self.stop_epoch = stop_epoch 20 | 21 | def before_train_epoch(self, runner) -> None: 22 | """Set ``model.distillation_stopped`` once the epoch is reached.""" 23 | if runner.epoch >= self.stop_epoch: 24 | model = runner.model 25 | # TODO: refactor after mmengine using model wrapper 26 | if is_model_wrapper(model): 27 | model = model.module 28 | assert hasattr(model, 'distillation_stopped') 29 | 30 | runner.logger.info('Distillation has been stopped!') 31 | model.distillation_stopped = True 32 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .optimizer_constructor import SeparateOptimWrapperConstructor 3 | 4 | __all__ = ['SeparateOptimWrapperConstructor'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/runner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop 3 | from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop 4 | from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop 5 | from .evolution_search_loop import EvolutionSearchLoop 6 | from .iteprune_val_loop import ItePruneValLoop 7 | from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, 8 | QATValLoop) 9 | from .slimmable_val_loop import SlimmableValLoop 10 | from .subnet_sampler_loop import GreedySamplerTrainLoop 11 | from .subnet_val_loop import SubnetValLoop 12 | 13 | __all__ = [ 14 | 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 15 | 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 16 | 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', 17 | 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', 18 | 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop' 19 | ] 20 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/runner/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .calibrate_bn_mixin import CalibrateBNMixin 3 | from .check import check_subnet_resources 4 | from .genetic import crossover 5 | 6 | __all__ = ['crossover', 'check_subnet_resources', 'CalibrateBNMixin'] 7 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/engine/runner/utils/genetic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import copy 3 | 4 | import numpy as np 5 | 6 | from mmrazor.utils import SingleMutatorRandomSubnet 7 | 8 | 9 | def crossover(random_subnet1: SingleMutatorRandomSubnet, 10 | random_subnet2: SingleMutatorRandomSubnet, 11 | prob: float = 0.5) -> SingleMutatorRandomSubnet: 12 | """Crossover two subnets in a genetic algorithm. 13 | 14 | Args: 15 | random_subnet1 (SingleMutatorRandomSubnet): One of the subnets to 16 | crossover; it is deep-copied and never modified. 17 | random_subnet2 (SingleMutatorRandomSubnet): One of the subnets to 18 | crossover. 19 | prob (float): Probability that each choice of `random_subnet2` is 20 | taken, sampled independently per group. Defaults to 0.5. 21 | 22 | Returns: 23 | SingleMutatorRandomSubnet: A copy of `random_subnet1` with some choices replaced by those of `random_subnet2`. 24 | """ 25 | assert prob >= 0. and prob <= 1., \ 26 | 'The probability of crossover has to be between 0 and 1' 27 | crossover_subnet = copy.deepcopy(random_subnet1) 28 | for group_id, choice in random_subnet2.items(): 29 | if np.random.random_sample() < prob: 30 | crossover_subnet[group_id] = choice 31 | return crossover_subnet 32 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """impl folder is an experimental file structure to store algorithm 3 | implementations. 4 | 5 | Previous file structure splits the files of an algorithm into different folders 6 | according to the types of these files. It may make it hard to understand an 7 | algorithm. So we add the impl folder, where all files of an algorithm are 8 | stored in one folder. As this structure is experimental, it may change rapidly. 9 | """ 10 | 11 | from .
import pruning # noqa 12 | 13 | __all__ = ['pruning'] 14 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/pruning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from . import group_fisher, sparse_gpt 3 | 4 | __all__ = ['group_fisher', 'sparse_gpt'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/pruning/group_fisher/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .algorithm import GroupFisherAlgorithm 3 | from .counters import GroupFisherConv2dCounter, GroupFisherLinearCounter 4 | from .hook import PruningStructureHook, ResourceInfoHook 5 | from .mutator import GroupFisherChannelMutator 6 | from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin 7 | from .prune_deploy_sub_model import GroupFisherDeploySubModel 8 | from .prune_sub_model import GroupFisherSubModel 9 | from .unit import GroupFisherChannelUnit 10 | 11 | __all__ = [ 12 | 'GroupFisherDeploySubModel', 13 | 'GroupFisherSubModel', 14 | 'GroupFisherAlgorithm', 15 | 'GroupFisherConv2dCounter', 16 | 'GroupFisherLinearCounter', 17 | 'PruningStructureHook', 18 | 'ResourceInfoHook', 19 | 'GroupFisherChannelMutator', 20 | 'GroupFisherChannelUnit', 21 | 'GroupFisherConv2d', 22 | 'GroupFisherLinear', 23 | 'GroupFisherMixin', 24 | ] 25 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/pruning/group_fisher/counters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmrazor.models.task_modules.estimators.counters import ( 3 | DynamicConv2dCounter, DynamicLinearCounter) 4 | from mmrazor.registry import TASK_UTILS 5 | 6 | 7 | @TASK_UTILS.register_module() 8 | class GroupFisherConv2dCounter(DynamicConv2dCounter): 9 | """Resource counter for ``GroupFisherConv2d``; reuses ``DynamicConv2dCounter`` as is.""" 10 | pass 11 | 12 | 13 | @TASK_UTILS.register_module() 14 | class GroupFisherLinearCounter(DynamicLinearCounter): 15 | """Resource counter for ``GroupFisherLinear``; reuses ``DynamicLinearCounter`` as is.""" 16 | pass 17 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/pruning/sparse_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .compressor import SparseGptCompressor 3 | from .ops import SparseGptLinear, SparseGptMixIn 4 | from .utils import replace_with_dynamic_ops 5 | 6 | __all__ = [ 7 | 'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops', 8 | 'SparseGptCompressor' 9 | ] 10 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | @torch.no_grad() 6 | def is_weight_sparse_24(weight: torch.Tensor, dim: int = -1) -> torch.Tensor: 7 | """Check if every group of 4 consecutive weights along ``dim`` has at least 2 zeros (2:4 sparsity).""" 8 | weight = weight.transpose(-1, dim).reshape(-1, 4) # (N, 4); groups align with `dim` only if its size is a multiple of 4 9 | is_zero = (weight == 0).sum(-1) # (N,) number of zeros per group 10 | return (is_zero >= 2).all() 11 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/implementations/quantization/gptq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .compressor import GPTQCompressor 3 | from .gptq import GPTQMixIn 4 | from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear 5 | from .quantizer import Quantizer 6 | 7 | __all__ = [ 8 | 'GPTQCompressor', 9 | 'GPTQMixIn', 10 | 'GPTQConv2d', 11 | 'GPTQLinear', 12 | 'TritonGPTQLinear', 13 | 'Quantizer', 14 | ] 15 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .algorithms import * # noqa: F401,F403 3 | from .architectures import * # noqa: F401,F403 4 | from .distillers import * # noqa: F401,F403 5 | from .fake_quants import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .mutables import * # noqa: F401,F403 8 | from .mutators import * # noqa: F401,F403 9 | from .observers import * # noqa: F401,F403 10 | from .quantizers import * # noqa: F401,F403 11 | from .task_modules import * # noqa: F401,F403 12 | from .utils import * # noqa: F401,F403 13 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .base import BaseAlgorithm 3 | from .distill import (DAFLDataFreeDistillation, DataFreeDistillation, 4 | FpnTeacherDistill, OverhaulFeatureDistillation, 5 | SelfDistill, SingleTeacherDistill) 6 | from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP, 7 | BigNAS, BigNASDDP, Darts, DartsDDP) 8 | from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP 9 | from .pruning.ite_prune_algorithm import ItePruneAlgorithm 10 | from .quantization import MMArchitectureQuant, MMArchitectureQuantDDP 11 | 12 | __all__ = [ 13 | 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', 14 | 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', 15 | 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 16 | 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 17 | 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', 18 | 'BigNASDDP', 'DMCP', 'DMCPDDP', 'MMArchitectureQuant', 19 | 'MMArchitectureQuantDDP' 20 | ] 21 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/distill/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .configurable import (DAFLDataFreeDistillation, DataFreeDistillation, 3 | FpnTeacherDistill, OverhaulFeatureDistillation, 4 | SelfDistill, SingleTeacherDistill) 5 | 6 | __all__ = [ 7 | 'SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill', 8 | 'DataFreeDistillation', 'DAFLDataFreeDistillation', 9 | 'OverhaulFeatureDistillation' 10 | ] 11 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/distill/configurable/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .datafree_distillation import (DAFLDataFreeDistillation, 3 | DataFreeDistillation) 4 | from .fpn_teacher_distill import FpnTeacherDistill 5 | from .overhaul_feature_distillation import OverhaulFeatureDistillation 6 | from .self_distill import SelfDistill 7 | from .single_teacher_distill import SingleTeacherDistill 8 | 9 | __all__ = [ 10 | 'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill', 11 | 'DataFreeDistillation', 'DAFLDataFreeDistillation', 12 | 'OverhaulFeatureDistillation' 13 | ] 14 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/nas/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .autoformer import Autoformer 3 | from .autoslim import AutoSlim, AutoSlimDDP 4 | from .bignas import BigNAS, BigNASDDP 5 | from .darts import Darts, DartsDDP 6 | from .dsnas import DSNAS, DSNASDDP 7 | from .spos import SPOS 8 | 9 | __all__ = [ 10 | 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'BigNAS', 'BigNASDDP', 'Darts', 11 | 'DartsDDP', 'DSNAS', 'DSNASDDP', 'DSNAS', 'DSNASDDP', 'Autoformer' 12 | ] 13 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/pruning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dcff import DCFF 3 | from .dmcp import DMCP, DMCPDDP 4 | from .slimmable_network import SlimmableNetwork, SlimmableNetworkDDP 5 | 6 | __all__ = [ 7 | 'SlimmableNetwork', 'SlimmableNetworkDDP', 'DCFF', 'DMCP', 'DMCPDDP' 8 | ] 9 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/pruning/group_fisher_algoritho.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | """This file includes the modules in the impl folder. 3 | 4 | As it only records impl modules, it is not initialized automatically. 5 | """ 6 | from mmrazor.implementations.pruning.group_fisher import \ 7 | GroupFisherAlgorithm # noqa 8 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/algorithms/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .mm_architecture import MMArchitectureQuant, MMArchitectureQuantDDP 3 | 4 | __all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .classifiers import * # noqa: F401,F403 4 | from .connectors import * # noqa: F401,F403 5 | from .dynamic_ops import * # noqa: F401,F403 6 | from .generators import * # noqa: F401,F403 7 | from .heads import * # noqa: F401,F403 8 | from .necks import * # noqa: F401,F403 9 | from .ops import * # noqa: F401,F403 10 | from .utils import * # noqa: F401,F403 11 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .darts_backbone import DartsBackbone 3 | from .searchable_autoformer import AutoformerBackbone 4 | from .searchable_mobilenet_v2 import SearchableMobileNetV2 5 | from .searchable_mobilenet_v3 import AttentiveMobileNetV3 6 | from .searchable_shufflenet_v2 import SearchableShuffleNetV2 7 | from .wideresnet import WideResNet 8 | 9 | __all__ = [ 10 | 'DartsBackbone', 'AutoformerBackbone', 'SearchableMobileNetV2', 11 | 'AttentiveMobileNetV3', 'SearchableShuffleNetV2', 'WideResNet' 12 | ] 13 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .image import SearchableImageClassifier 3 | 4 | __all__ = ['SearchableImageClassifier'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/connectors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .byot_connector import BYOTConnector 3 | from .convmodule_connector import ConvModuleConnector 4 | from .crd_connector import CRDConnector 5 | from .factor_transfer_connectors import Paraphraser, Translator 6 | from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector 7 | from .mgd_connector import MGDConnector 8 | from .norm_connector import NormConnector 9 | from .ofd_connector import OFDTeacherConnector 10 | from .torch_connector import TorchFunctionalConnector, TorchNNConnector 11 | 12 | __all__ = [ 13 | 'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector', 14 | 'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector', 15 | 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector', 16 | 'NormConnector' 17 | ] 18 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/connectors/norm_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | from mmcv.cnn import build_norm_layer 6 | 7 | from mmrazor.registry import MODELS 8 | from .base_connector import BaseConnector 9 | 10 | 11 | @MODELS.register_module() 12 | class NormConnector(BaseConnector): 13 | """Normalize features with a layer built from ``norm_cfg``.""" 14 | def __init__(self, in_channels, norm_cfg, init_cfg: Optional[Dict] = None): 15 | super(NormConnector, self).__init__(init_cfg) 16 | _, self.norm = build_norm_layer(norm_cfg, in_channels) 17 | 18 | def forward_train(self, feature: torch.Tensor) -> torch.Tensor: 19 | return self.norm(feature) 20 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/connectors/ofd_connector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional 3 | 4 | import torch 5 | 6 | from mmrazor.registry import MODELS 7 | from .base_connector import BaseConnector 8 | 9 | 10 | @MODELS.register_module() 11 | class OFDTeacherConnector(BaseConnector): 12 | """Connector designed for ``OverhaulFeatureDistillation``. 13 | 14 | Args: 15 | init_cfg (Optional[Dict], optional): Initialization config dict. 16 | Defaults to None. 17 | """ 18 | 19 | def __init__(self, init_cfg: Optional[Dict] = None) -> None: 20 | super().__init__(init_cfg) 21 | self.margin: Optional[torch.Tensor] = None 22 | 23 | def init_margin(self, margin: torch.Tensor) -> None: 24 | """Set the margin used to clamp teacher features; called by 25 | ``OverhaulFeatureDistillation``. 26 | 27 | Args: 28 | margin (torch.Tensor): Lower bound broadcast against the feature. 29 | """ 30 | self.margin = margin 31 | 32 | def forward_train(self, feature: torch.Tensor) -> torch.Tensor: 33 | """Return the detached feature clamped from below by ``margin``.""" 34 | assert self.margin is not None, ( 35 | 'margin must be initialized before training.') 36 | self.margin = self.margin.to(feature.device) 37 | feature = torch.max(feature.detach(), self.margin) 38 | return feature 39 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/dynamic_ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bricks import * # noqa: F401,F403 3 | from .head import * # noqa: F401,F403 4 | from .mixins import * # noqa: F401,F403 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """This file includes the modules in the impl folder. 3 | 4 | As it only records impl modules, it is not initialized automatically.
5 | """ 6 | from mmrazor.implementations.pruning.group_fisher import \ 7 | GroupFisherConv2d # noqa 8 | from mmrazor.implementations.pruning.group_fisher import \ 9 | GroupFisherLinear # noqa 10 | from mmrazor.implementations.pruning.group_fisher import \ 11 | GroupFisherMixin # noqa 12 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/dynamic_ops/head/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dynamic_linear_head import DynamicLinearClsHead # noqa: F401 3 | 4 | __all__ = ['DynamicLinearClsHead'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dynamic_conv_mixins import DynamicConvMixin 3 | from .dynamic_layernorm_mixins import DynamicLayerNormMixin 4 | from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin, 5 | DynamicLinearMixin, DynamicMixin) 6 | 7 | __all__ = [ 8 | 'DynamicChannelMixin', 9 | 'DynamicBatchNormMixin', 10 | 'DynamicLinearMixin', 11 | 'DynamicMixin', 12 | 'DynamicConvMixin', 13 | 'DynamicLayerNormMixin', 14 | ] 15 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/generators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .dafl_generator import DAFLGenerator 3 | from .zskt_generator import ZSKTGenerator 4 | 5 | __all__ = ['DAFLGenerator', 'ZSKTGenerator'] 6 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .darts_subnet_head import DartsSubnetClsHead 3 | from .deit_head import DeiTClsHead 4 | 5 | __all__ = ['DartsSubnetClsHead', 'DeiTClsHead'] 6 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .squeezemean_with_dropout import SqueezeMeanPoolingWithDropout 3 | 4 | __all__ = ['SqueezeMeanPoolingWithDropout'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .common import Identity 3 | from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv, 4 | DartsSkipConnect, DartsZero) 5 | from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv 6 | from .function import InputResizer 7 | from .gather_tensors import GatherTensors 8 | from .mobilenet_series import MBBlock 9 | from .shufflenet_series import ShuffleBlock, ShuffleXception 10 | from .transformer_series import MultiheadAttention, RelativePosition2D 11 | 12 | __all__ = [ 13 | 'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv', 14 | 'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity', 15 | 'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors', 'InputResizer', 16 | 'RelativePosition2D', 'MultiheadAttention' 17 | ] 18 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/ops/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmengine.model import BaseModule 3 | 4 | 5 | class BaseOP(BaseModule): 6 | """Base class for searchable operations. 7 | 8 | Args: 9 | in_channels (int): The input channels of the operation. 10 | out_channels (int): The output channels of the operation. 11 | stride (int): Stride of the operation. Defaults to 1. 12 | **kwargs: Forwarded unchanged to ``BaseModule.__init__``. 13 | """ 14 | def __init__(self, in_channels, out_channels, stride=1, **kwargs): 15 | super(BaseOP, self).__init__(**kwargs) 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.stride = stride 20 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/architectures/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
from .mutable_register import mutate_conv_module, mutate_mobilenet_layer
from .set_dropout import set_dropout

__all__ = ['mutate_conv_module', 'mutate_mobilenet_layer', 'set_dropout']

# ---- /mmrazor/mmrazor/models/distillers/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .base_distiller import BaseDistiller
from .byot_distiller import BYOTDistiller
from .configurable_distiller import ConfigurableDistiller
from .ofd_distiller import OFDDistiller

__all__ = [
    'ConfigurableDistiller', 'BaseDistiller', 'BYOTDistiller', 'OFDDistiller'
]

# ---- /mmrazor/mmrazor/models/distillers/base_distiller.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Dict, Optional

from mmengine.model import BaseModule

from ..algorithms.base import LossResults


class BaseDistiller(BaseModule, ABC):
    """Abstract base class for distillers.

    Concrete distillers must implement :meth:`compute_distill_losses`.

    Args:
        init_cfg (dict, optional): Config for the distiller, forwarded to
            :class:`mmengine.model.BaseModule`. Defaults to None.
    """

    def __init__(self, init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)

    @abstractmethod
    def compute_distill_losses(self) -> LossResults:
        """Compute the distillation losses automatically.

        Returns:
            LossResults: The computed distillation losses.
        """

# ---- /mmrazor/mmrazor/models/fake_quants/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | from .base import BaseFakeQuantize 3 | from .lsq import (LearnableFakeQuantize, enable_param_learning, 4 | enable_static_estimate, enable_val) 5 | from .torch_fake_quants import register_torch_fake_quants 6 | 7 | __all__ = [ 8 | 'BaseFakeQuantize', 'register_torch_fake_quants', 'LearnableFakeQuantize', 9 | 'enable_val', 'enable_param_learning', 'enable_static_estimate' 10 | ] 11 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/fake_quants/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | try: 3 | from torch.ao.quantization import FakeQuantize 4 | except ImportError: 5 | from mmrazor.utils import get_placeholder 6 | FakeQuantize = get_placeholder('torch>=1.13') 7 | 8 | BaseFakeQuantize = FakeQuantize 9 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
from .ab_loss import ABLoss
from .at_loss import ATLoss
from .crd_loss import CRDLoss
from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .dist_loss import DISTLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .mgd_loss import MGDLoss
from .ofd_loss import OFDLoss
from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD

__all__ = [
    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
    'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
    'DISTLoss'
]

# ---- /mmrazor/mmrazor/models/losses/cross_entropy_loss.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy between student predictions and teacher hard labels.

    The teacher predictions are detached and turned into class indices with
    ``argmax`` over dim 1 before being used as the target.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Compute the weighted cross entropy loss.

        Args:
            preds_S (torch.Tensor): Student predictions.
            preds_T (torch.Tensor): Teacher predictions; detached, so no
                gradient flows back into the teacher.

        Returns:
            torch.Tensor: ``F.cross_entropy`` (default reduction) scaled by
            ``loss_weight``.
        """
        hard_targets = preds_T.detach().argmax(dim=1)
        return self.loss_weight * F.cross_entropy(preds_S, hard_targets)

# ---- /mmrazor/mmrazor/models/mutables/mutable_channel/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .base_mutable_channel import BaseMutableChannel
from .mutable_channel_container import MutableChannelContainer
from .oneshot_mutable_channel import OneShotMutableChannel
from .sequential_mutable_channel import SquentialMutableChannel
from .simple_mutable_channel import SimpleMutableChannel
from .units import (ChannelUnitType, DCFFChannelUnit, DMCPChannelUnit,
                    L1MutableChannelUnit, MutableChannelUnit,
                    OneShotMutableChannelUnit, SequentialMutableChannelUnit,
                    SlimmableChannelUnit)

__all__ = [
    'SimpleMutableChannel', 'L1MutableChannelUnit',
    'SequentialMutableChannelUnit', 'MutableChannelUnit',
    'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel',
    'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType',
    'DCFFChannelUnit', 'OneShotMutableChannel', 'DMCPChannelUnit'
]

# ---- /mmrazor/mmrazor/models/mutables/mutable_channel/units/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | from .dcff_channel_unit import DCFFChannelUnit 3 | from .dmcp_channel_unit import DMCPChannelUnit 4 | from .l1_mutable_channel_unit import L1MutableChannelUnit 5 | from .mutable_channel_unit import ChannelUnitType, MutableChannelUnit 6 | from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit 7 | from .sequential_mutable_channel_unit import SequentialMutableChannelUnit 8 | from .slimmable_channel_unit import SlimmableChannelUnit 9 | 10 | __all__ = [ 11 | 'L1MutableChannelUnit', 'MutableChannelUnit', 12 | 'SequentialMutableChannelUnit', 'OneShotMutableChannelUnit', 13 | 'SlimmableChannelUnit', 'ChannelUnitType', 'DCFFChannelUnit', 14 | 'DMCPChannelUnit' 15 | ] 16 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """This file includes the modules in the impl folder. 3 | 4 | As it only records impl modules, it is not initialized automatically. 5 | """ 6 | from mmrazor.implementations.pruning.group_fisher import \ 7 | GroupFisherChannelUnit # noqa 8 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutables/mutable_module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule, 3 | DiffMutableOP, OneHotMutableOP) 4 | from .mutable_module import MutableModule 5 | from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP 6 | 7 | __all__ = [ 8 | 'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute', 9 | 'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule', 10 | 'OneHotMutableOP' 11 | ] 12 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutables/mutable_value/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .mutable_value import MutableValue, OneShotMutableValue 3 | 4 | __all__ = ['MutableValue', 'OneShotMutableValue'] 5 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .channel_mutator import (ChannelMutator, DCFFChannelMutator, 3 | DMCPChannelMutator, OneShotChannelMutator, 4 | SlimmableChannelMutator) 5 | from .nas_mutator import NasMutator 6 | 7 | __all__ = [ 8 | 'ChannelMutator', 'DCFFChannelMutator', 'DMCPChannelMutator', 9 | 'SlimmableChannelMutator', 'NasMutator', 'OneShotChannelMutator' 10 | ] 11 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutators/channel_mutator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .channel_mutator import ChannelMutator 3 | from .dcff_channel_mutator import DCFFChannelMutator 4 | from .dmcp_channel_mutator import DMCPChannelMutator 5 | from .one_shot_channel_mutator import OneShotChannelMutator 6 | from .slimmable_channel_mutator import SlimmableChannelMutator 7 | 8 | __all__ = [ 9 | 'SlimmableChannelMutator', 'ChannelMutator', 'OneShotChannelMutator', 10 | 'DCFFChannelMutator', 'DMCPChannelMutator' 11 | ] 12 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """This file includes the modules in the impl folder. 3 | 4 | As it only records impl modules, it is not initialized automatically. 5 | """ 6 | from mmrazor.implementations.pruning.group_fisher import \ 7 | GroupFisherChannelMutator # noqa 8 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/observers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseObserver 3 | from .lsq import LSQObserver, LSQPerChannelObserver 4 | from .torch_observers import register_torch_observers 5 | 6 | __all__ = [ 7 | 'BaseObserver', 'register_torch_observers', 'LSQObserver', 8 | 'LSQPerChannelObserver' 9 | ] 10 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/observers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | try: 3 | from torch.ao.quantization.observer import UniformQuantizationObserverBase 4 | except ImportError: 5 | from mmrazor.utils import get_placeholder 6 | UniformQuantizationObserverBase = get_placeholder('torch>=1.13') 7 | 8 | BaseObserver = UniformQuantizationObserverBase 9 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/quantizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .academic_quantizer import AcademicQuantizer 3 | from .base import BaseQuantizer 4 | from .native_quantizer import TorchNativeQuantizer 5 | from .openvino_quantizer import OpenVINOQuantizer 6 | from .tensorrt_quantizer import TensorRTQuantizer 7 | 8 | __all__ = [ 9 | 'BaseQuantizer', 'AcademicQuantizer', 'TorchNativeQuantizer', 10 | 'TensorRTQuantizer', 'OpenVINOQuantizer' 11 | ] 12 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/quantizers/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .openvino_quantize_exporter import OpenVinoQuantizeExportor 3 | from .tensorrt_quantize_exporter import TensorRTExplicitExporter 4 | 5 | __all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter'] 6 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/task_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .delivery import * # noqa: F401,F403 3 | from .demo_inputs import * # noqa: F401,F403 4 | from .estimators import ResourceEstimator 5 | from .predictor import * # noqa: F401,F403 6 | from .recorder import * # noqa: F401,F403 7 | from .tracer import * # noqa: F401,F403 8 | 9 | __all__ = ['ResourceEstimator'] 10 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/task_modules/delivery/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .delivery_manager import DistillDeliveryManager 3 | from .function_outputs_delivery import FunctionOutputsDelivery 4 | from .method_outputs_delivery import MethodOutputsDelivery 5 | 6 | __all__ = [ 7 | 'FunctionOutputsDelivery', 'MethodOutputsDelivery', 8 | 'DistillDeliveryManager' 9 | ] 10 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/task_modules/demo_inputs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .default_demo_inputs import DefaultDemoInput, defaul_demo_inputs 3 | from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput, 4 | DefaultMMDemoInput, DefaultMMDetDemoInput, 5 | DefaultMMSegDemoInput) 6 | 7 | __all__ = [ 8 | 'defaul_demo_inputs', 9 | 'DefaultMMClsDemoInput', 10 | 'DefaultMMDetDemoInput', 11 | 'DefaultMMDemoInput', 12 | 'DefaultMMSegDemoInput', 13 | 'BaseDemoInput', 14 | 'DefaultDemoInput', 15 | ] 16 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/task_modules/estimators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
from .counters import *  # noqa: F401,F403
from .resource_estimator import ResourceEstimator

__all__ = ['ResourceEstimator']

# ---- /mmrazor/mmrazor/models/task_modules/estimators/counters/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_params_counter import get_model_flops_params
from .latency_counter import get_model_latency
from .op_counters import *  # noqa: F401,F403

__all__ = ['get_model_flops_params', 'get_model_latency']

# ---- /mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseCounter(object, metaclass=ABCMeta):
    """Base class of all op module counters in `TASK_UTILS`.

    In ResourceEstimator, `XXModuleCounter` is responsible for `XXModule`,
    which refers to estimator/flops_params_counter.py::get_counter_type().
    Users can customize a `ModuleACounter` and overwrite the `add_count_hook`
    method with a self-defined module `ModuleA`.
    """

    def __init__(self) -> None:
        pass

    # `staticmethod` over `abc.abstractmethod` is the documented way to
    # declare an abstract static method; `abc.abstractclassmethod` is
    # deprecated and would wrap a classmethod inside a staticmethod.
    @staticmethod
    @abstractmethod
    def add_count_hook(module, input, output):
        """The main method of a `BaseCounter` which defines the way to
        calculate resources(flops/params) of the current module.

        Args:
            module (nn.Module): the module to be tested.
            input (tuple): positional inputs of the module, as passed to a
                torch forward hook.
            output: output of the module, as passed to a torch forward hook.
        """
        pass

# ---- /mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py ----
# Copyright (c) OpenMMLab. All rights reserved.
"""This file includes the modules in the impl folder.

As it only records impl modules, it is not initialized automatically.
"""
from mmrazor.implementations.pruning.group_fisher import (  # noqa
    GroupFisherConv2dCounter, GroupFisherLinearCounter)

# ---- /mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter


@TASK_UTILS.register_module()
class LinearCounter(BaseCounter):
    """FLOPs/params counter for Linear operation series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Calculate FLOPs and params based on the size of input & output.

        FLOPs are ``prod(input.shape) * output.shape[-1]``, i.e. counted over
        every dimension of the input, including the batch dimension.
        """
        input = input[0]
        # PyTorch already checked the shapes in forward, so only the last
        # output dimension is needed here.
        output_last_dim = output.shape[-1]
        module.__flops__ += int(np.prod(input.shape) * output_last_dim)
        module.__params__ += get_model_parameters_number(module)


@TASK_UTILS.register_module()
class DynamicLinearCounter(LinearCounter):
    """Uses the counting rule of :class:`LinearCounter` unchanged."""

# ---- /mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter


@TASK_UTILS.register_module()
class UpsampleCounter(BaseCounter):
    """FLOPs/params counter for Upsample function."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Count one FLOP per element of the upsampled output.

        The whole output is counted, including the batch dimension, which
        matches how :class:`LinearCounter` counts over the full input.

        Args:
            module (nn.Module): the module to be tested.
            input (tuple): positional inputs of the module (unused).
            output: output of the module. A tensor is counted as-is; a
                tuple/list output is unwrapped to its first element.
        """
        # A forward hook receives the module's return value directly (a
        # tensor for upsampling), so indexing it with [0] would keep only
        # the first sample of the batch.
        out = output[0] if isinstance(output, (tuple, list)) else output
        module.__flops__ += int(out.numel())
        module.__params__ += get_model_parameters_number(module)

# ---- /mmrazor/mmrazor/models/task_modules/predictor/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .metric_predictor import MetricPredictor

__all__ = ['MetricPredictor']

# ---- /mmrazor/mmrazor/models/task_modules/predictor/handler/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .carts_handler import CartsHandler
from .gp_handler import GaussProcessHandler
from .mlp_handler import MLPHandler
from .rbf_handler import RBFHandler

__all__ = ['CartsHandler', 'GaussProcessHandler', 'MLPHandler', 'RBFHandler']

# ---- /mmrazor/mmrazor/models/task_modules/predictor/handler/base_handler.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from joblib import dump, load


class BaseHandler:
    """Base class for a handler.

    Note:
        A handler works through a specific machine learning algorithm and is
        designed for predicting the evaluation metric of a model.
    """

    def __init__(self) -> None:
        pass

    def fit(self, train_data, train_label):
        """Train the handler's model (no-op in the base class)."""

    def predict(self, test_data):
        """Predict the metric with the handler's model.

        The base class implementation does nothing and returns None.
        """

    def load(self, path):
        """Load the handler's model from ``path`` with ``joblib.load``."""
        self.model = load(path)

    def save(self, path):
        """Dump ``self.model`` with joblib and return the path actually used.

        The lowercased suffix ``_<classname>.joblib`` is appended to ``path``
        so that different handlers do not overwrite each other.
        """
        saved_path = path + f'_{self.__class__.__name__}.joblib'.lower()
        dump(self.model, saved_path)
        return saved_path

# ---- /mmrazor/mmrazor/models/task_modules/recorder/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .function_inputs_recorder import FunctionInputsRecorder
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_inputs_recorder import MethodInputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
from .param_recorder import ParameterRecorder
from .recorder_manager import RecorderManager

__all__ = [
    'FunctionOutputsRecorder', 'MethodOutputsRecorder',
    'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
    'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
]

# ---- /mmrazor/mmrazor/models/task_modules/recorder/module_inputs_recorder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Tuple

from torch import nn

from mmrazor.registry import TASK_UTILS
from .module_outputs_recorder import ModuleOutputsRecorder


@TASK_UTILS.register_module()
class ModuleInputsRecorder(ModuleOutputsRecorder):
    """Recorder for intermediate results that are a PyTorch module's inputs.

    Builds on :class:`ModuleOutputsRecorder` and overrides only
    :meth:`forward_hook`, so that the module's forward inputs are buffered.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are handled by ModuleOutputsRecorder.
        super().__init__(*args, **kwargs)

    def forward_hook(self, module: nn.Module, inputs: Tuple,
                     outputs: Any) -> None:
        """Append the module's forward input to ``self.data_buffer``.

        Nothing is stored while ``self.recording`` is falsy.

        Args:
            module (:obj:`torch.nn.Module`): The module to register hook.
            inputs (tuple): The input of the module.
            outputs : The output of the module (not stored).
        """
        if not self.recording:
            return
        self.data_buffer.append(inputs)

# ---- /mmrazor/mmrazor/models/task_modules/tracer/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer
from .channel_analyzer import ChannelAnalyzer
# from .razor_tracer import RazorFxTracer
from .fx import (CustomTracer, UntracedMethodRegistry, build_graphmodule,
                 custom_symbolic_trace)
from .loss_calculator import *  # noqa: F401,F403
from .parsers import *  # noqa: F401,F403
from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode,
                   PathLinearNode, PathList, PathNode, PathNormNode)

__all__ = [
    'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode',
    'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode',
    'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry',
    'custom_symbolic_trace', 'build_graphmodule'
]

# ---- /mmrazor/mmrazor/models/task_modules/tracer/fx/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | from .custom_tracer import (CustomTracer, UntracedMethodRegistry, 3 | build_graphmodule, custom_symbolic_trace) 4 | from .graph_utils import (del_fakequant_after_function, 5 | del_fakequant_after_method, 6 | del_fakequant_after_module, del_fakequant_after_op, 7 | del_fakequant_before_function, 8 | del_fakequant_before_method, 9 | del_fakequant_before_module, del_fakequant_before_op) 10 | 11 | __all__ = [ 12 | 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', 13 | 'build_graphmodule', 'del_fakequant_before_module', 14 | 'del_fakequant_after_module', 'del_fakequant_after_function', 15 | 'del_fakequant_before_function', 'del_fakequant_after_op', 16 | 'del_fakequant_before_op', 'del_fakequant_before_method', 17 | 'del_fakequant_after_method' 18 | ] 19 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
from .cascade_encoder_decoder_loss_calculator import \
    CascadeEncoderDecoderPseudoLoss
from .image_classifier_loss_calculator import ImageClassifierPseudoLoss
from .single_stage_detector_loss_calculator import \
    SingleStageDetectorPseudoLoss
from .sum_loss_calculator import SumPseudoLoss
from .top_down_pose_estimator_loss_calculator import \
    TopdownPoseEstimatorPseudoLoss
from .two_stage_detector_loss_calculator import TwoStageDetectorPseudoLoss

__all__ = [
    'ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss',
    'TwoStageDetectorPseudoLoss', 'TopdownPoseEstimatorPseudoLoss',
    'CascadeEncoderDecoderPseudoLoss', 'SumPseudoLoss'
]

# ---- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.registry import TASK_UTILS

try:
    from mmseg.models import CascadeEncoderDecoder
except ImportError:
    from mmrazor.utils import get_placeholder
    CascadeEncoderDecoder = get_placeholder('mmseg')


@TASK_UTILS.register_module()
class CascadeEncoderDecoderPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `CascadeEncoderDecoder` in MMSegmentation with `BackwardTracer`.

    A random ``(1, 3, 224, 224)`` image goes through the backbone and the
    neck, and every resulting value is summed into one scalar.
    """

    def __call__(self, model: CascadeEncoderDecoder) -> torch.Tensor:
        features = model.neck(model.backbone(torch.rand(1, 3, 224, 224)))
        # decode heads stay unmodified, so they are not part of the loss
        total = torch.tensor(0.)
        for group in features:
            total += sum(item.sum() for item in group)
        return total

# ---- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.registry import TASK_UTILS

try:
    from mmcls.models import ImageClassifier
except ImportError:
    from mmrazor.utils import get_placeholder
    ImageClassifier = get_placeholder('mmcls')


@TASK_UTILS.register_module()
class ImageClassifierPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a `ImageClassifier`
    in MMClassification with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (2, 3, 224, 224).
    """

    def __init__(self, input_shape=(2, 3, 224, 224)):
        self.input_shape = input_shape

    def __call__(self, model: ImageClassifier) -> torch.Tensor:
        """Run the whole model on a random input and sum its output."""
        return model(torch.rand(self.input_shape)).sum()

# ---- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.registry import TASK_UTILS

try:
    from mmdet.models import SingleStageDetector
except ImportError:
    from mmrazor.utils import get_placeholder
    SingleStageDetector = get_placeholder('mmdet')


@TASK_UTILS.register_module()
class SingleStageDetectorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `SingleStageDetector` in MMDetection with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (2, 3, 224, 224).
    """

    def __init__(self, input_shape=(2, 3, 224, 224)):
        self.input_shape = input_shape

    def __call__(self, model: SingleStageDetector) -> torch.Tensor:
        """Run the model on a random input and sum every nested output."""
        outputs = model(torch.rand(self.input_shape))
        total = torch.tensor(0.)
        for group in outputs:
            total += sum(item.sum() for item in group)
        return total

# ---- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.registry import TASK_UTILS

try:
    from mmpose.models import TopdownPoseEstimator
except ImportError:
    from mmrazor.utils import get_placeholder
    TopdownPoseEstimator = get_placeholder('mmpose')


@TASK_UTILS.register_module()
class TopdownPoseEstimatorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `TopdownPoseEstimator` in MMPose with `BackwardTracer`.

    Only the backbone is run on a random ``(1, 3, 224, 224)`` image.
    """

    def __call__(self, model: TopdownPoseEstimator) -> torch.Tensor:
        features = model.backbone(torch.rand(1, 3, 224, 224))
        # decode heads are immutable, so they are not part of the loss
        total = torch.tensor(0.)
        for group in features:
            total += sum(item.sum() for item in group)
        return total

# ---- /mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmrazor.registry import TASK_UTILS

try:
    from mmdet.models import TwoStageDetector
except ImportError:
    from mmrazor.utils import get_placeholder
    TwoStageDetector = get_placeholder('mmdet')


# TODO: adapt to mmdet 2.0
@TASK_UTILS.register_module()
class TwoStageDetectorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a `TwoStageDetector`
    in MMDet with `BackwardTracer`.

    A random ``(1, 3, 224, 224)`` image goes through the backbone and the
    neck, and every resulting value is summed into one scalar.
    """

    def __call__(self, model: TwoStageDetector) -> torch.Tensor:
        features = model.neck(model.backbone(torch.rand(1, 3, 224, 224)))
        total = torch.tensor(0.)
        for group in features:
            total += sum(item.sum() for item in group)
        return total

# ---- /mmrazor/mmrazor/models/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | from .make_divisible import make_divisible 3 | from .misc import add_prefix 4 | from .optim_wrapper import reinitialize_optim_wrapper_count_status 5 | from .parse_values import parse_values 6 | from .quantization_util import pop_rewriter_function_record, str2class 7 | from .utils import get_module_device, set_requires_grad 8 | 9 | __all__ = [ 10 | 'make_divisible', 'add_prefix', 'reinitialize_optim_wrapper_count_status', 11 | 'str2class', 'get_module_device', 'set_requires_grad', 'parse_values', 12 | 'pop_rewriter_function_record' 13 | ] 14 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/utils/expandable_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """This module is used to expand the channels of a supernet. 3 | 4 | We only expose some tool functions, rather than all DynamicOps and 5 | MutableChannelUnits, as They uses a few hacky operations. 6 | """ 7 | from .tools import (expand_expandable_dynamic_model, expand_static_model,ExpandableUnit, 8 | make_channel_divisible, to_expandable_model) 9 | 10 | __all__ = [ 11 | 'make_channel_divisible', 12 | 'to_expandable_model', 13 | 'expand_expandable_dynamic_model', 14 | 'expand_static_model', 15 | 'ExpandableUnit' 16 | ] 17 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict 3 | 4 | 5 | def add_prefix(inputs: Dict, prefix: str) -> Dict: 6 | """Add prefix for dict. 7 | 8 | Args: 9 | inputs (dict): The input dict with str keys. 10 | prefix (str): The prefix to add. 11 | 12 | Returns: 13 | dict: The dict with keys updated with ``prefix``. 
14 | """ 15 | 16 | outputs = dict() 17 | for name, value in inputs.items(): 18 | outputs[f'{prefix}.{name}'] = value 19 | 20 | return outputs 21 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/utils/parse_values.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List 3 | 4 | 5 | def parse_values(candidate_lists: List[list]): 6 | """Parse a list with format `(min_range, max_range, step)`. 7 | 8 | NOTE: this method is required when customizing search space in configs. 9 | """ 10 | 11 | def _range_to_list(input_range: List[int]) -> List[int]: 12 | assert len(input_range) == 3, ( 13 | 'The format should be `(min_range, max_range, step)` with dim=3, ' 14 | f'but got dim={len(input_range)}.') 15 | start, end, step = input_range 16 | return list(range(start, end + 1, step)) 17 | 18 | return [_range_to_list(i) for i in candidate_lists] 19 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/models/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def get_module_device(module: nn.Module) -> torch.device: 9 | """Get the device of a module. 10 | 11 | Args: 12 | module (nn.Module): A module contains the parameters. 13 | """ 14 | try: 15 | next(module.parameters()) 16 | except StopIteration as e: 17 | raise ValueError('The input module should contain parameters.') from e 18 | 19 | if next(module.parameters()).is_cuda: 20 | return next(module.parameters()).get_device() 21 | 22 | return torch.device('cpu') 23 | 24 | 25 | def set_requires_grad(nets: Union[nn.Module, List[nn.Module]], 26 | requires_grad: bool = False) -> None: 27 | """Set requires_grad for all the networks. 
28 | 29 | Args: 30 | nets (nn.Module | list[nn.Module]): A list of networks or a single 31 | network. 32 | requires_grad (bool): Whether the networks require gradients or not 33 | """ 34 | if not isinstance(nets, list): 35 | nets = [nets] 36 | for net in nets: 37 | if net is not None: 38 | for param in net.parameters(): 39 | param.requires_grad = requires_grad 40 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/registry/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, 3 | MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, 4 | OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS, 5 | RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, 6 | VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS, 7 | sub_model) 8 | 9 | __all__ = [ 10 | 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 11 | 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 12 | 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 13 | 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 14 | 'VISUALIZERS', 'sub_model' 15 | ] 16 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .quantization import * # noqa: F401,F403 3 | from .subnet import * # noqa: F401,F403 4 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .base_graph import BaseGraph, BaseNode 3 | from .module_graph import ModuleGraph, ModuleNode 4 | 5 | __all__ = ['BaseGraph', 'BaseNode', 'ModuleNode', 'ModuleGraph'] 6 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backend_config import * # noqa: F401,F403 3 | from .qconfig import * # noqa: F401,F403 4 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/quantization/backend_config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .academic import (get_academic_backend_config, 3 | get_academic_backend_config_dict) 4 | from .mapping import BackendConfigs 5 | from .native import get_native_backend_config, get_native_backend_config_dict 6 | from .openvino import (get_openvino_backend_config, 7 | get_openvino_backend_config_dict) 8 | from .tensorrt import (get_tensorrt_backend_config, 9 | get_tensorrt_backend_config_dict) 10 | 11 | __all__ = [ 12 | 'BackendConfigs', 13 | 'get_native_backend_config', 14 | 'get_native_backend_config_dict', 15 | 'get_academic_backend_config', 16 | 'get_academic_backend_config_dict', 17 | 'get_openvino_backend_config', 18 | 'get_openvino_backend_config_dict', 19 | 'get_tensorrt_backend_config', 20 | 'get_tensorrt_backend_config_dict', 21 | ] 22 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/quantization/backend_config/mapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | 4 | from mmrazor import digit_version 5 | from .academic import get_academic_backend_config 6 | from .native import get_native_backend_config 7 | from .openvino import get_openvino_backend_config 8 | from .tensorrt import get_tensorrt_backend_config 9 | 10 | if digit_version( 11 | torch.__version__) >= digit_version('1.13.0') and digit_version( 12 | torch.__version__) <= digit_version('1.13.1'): 13 | BackendConfigs = { 14 | 'academic': get_academic_backend_config(), 15 | 'native': get_native_backend_config(), 16 | 'tensorrt': get_tensorrt_backend_config(), 17 | 'openvino': get_openvino_backend_config() 18 | } 19 | else: 20 | BackendConfigs = { 21 | 'academic': None, 22 | 'native': None, 23 | 'tensorrt': None, 24 | 'openvino': None 25 | } 26 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/structures/subnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .candidate import Candidates 3 | from .fix_subnet import convert_fix_subnet, export_fix_subnet, load_fix_subnet 4 | 5 | __all__ = [ 6 | 'load_fix_subnet', 'export_fix_subnet', 'convert_fix_subnet', 'Candidates' 7 | ] 8 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403 3 | from ._fx_models import * # noqa: F401, F403 4 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/testing/_fast_stop_training_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmengine.hooks import Hook 3 | 4 | from mmrazor.registry import HOOKS 5 | 6 | 7 | @HOOKS.register_module() 8 | class FastStopTrainingHook(Hook): 9 | """Set runner's epoch information to the model.""" 10 | 11 | def __init__(self, by_epoch, save_ckpt=False, stop_iter_or_epoch=5): 12 | self.by_epoch = by_epoch 13 | self.save_ckpt = save_ckpt 14 | self.stop_iter_or_epoch = stop_iter_or_epoch 15 | 16 | def after_train_iter(self, runner, batch_idx: int, data_batch: None, 17 | outputs: None) -> None: 18 | if self.save_ckpt and self.by_epoch: 19 | # If it is epoch-based and want to save weights, 20 | # we must run at least 1 epoch. 21 | return 22 | if runner.iter >= self.stop_iter_or_epoch: 23 | raise RuntimeError('quick exit') 24 | 25 | def after_train_epoch(self, runner) -> None: 26 | if runner.epoch >= self.stop_iter_or_epoch - 1: 27 | raise RuntimeError('quick exit') 28 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .index_dict import IndexDict 3 | from .log_tools import get_level, print_log 4 | from .misc import find_latest_checkpoint 5 | from .placeholder import get_package_placeholder, get_placeholder 6 | from .runtime_info import RuntimeInfo 7 | from .setup_env import register_all_modules, setup_multi_processes 8 | from .typing import (FixMutable, MultiMutatorsRandomSubnet, 9 | SingleMutatorRandomSubnet, SupportRandomSubnet, 10 | ValidFixMutable) 11 | 12 | __all__ = [ 13 | 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 14 | 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', 15 | 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', 16 | 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo', 17 | 'get_package_placeholder' 18 | ] 19 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/utils/log_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import logging 3 | 4 | import torch.distributed as dist 5 | from mmengine import MMLogger 6 | from mmengine import print_log as engine_print_log 7 | 8 | 9 | def get_level(level='info'): 10 | if isinstance(level, str): 11 | level = level.upper() 12 | assert level in logging._nameToLevel 13 | level = logging._nameToLevel[level] 14 | elif isinstance(level, int): 15 | pass 16 | else: 17 | raise NotImplementedError() 18 | return level 19 | 20 | 21 | def print_log(msg, logger='current', level='info', only_rank0=True): 22 | 23 | if only_rank0 and dist.is_initialized(): 24 | if dist.get_rank() == 0: 25 | engine_print_log(msg, logger, get_level(level)) 26 | else: 27 | pass 28 | else: 29 | engine_print_log(msg, logger, get_level(level)) 30 | 31 | 32 | def set_log_level(level='debug'): 33 | level = get_level(level) 34 | 35 | logger = MMLogger.get_current_instance() 36 | logger.handlers[0].setLevel(level) 37 | logger.setLevel(level) 38 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os.path as osp 4 | import warnings 5 | 6 | 7 | def find_latest_checkpoint(path, suffix='pth'): 8 | """Find the latest checkpoint from the working directory. 9 | 10 | Args: 11 | path(str): The path to find checkpoints. 12 | suffix(str): File extension. Defaults to pth. 13 | 14 | Returns: 15 | latest_path(str | None): File path of the latest checkpoint. 16 | 17 | References: 18 | .. 
[1] https://github.com/microsoft/SoftTeacher 19 | /blob/main/ssod/utils/patch.py 20 | """ 21 | if not osp.exists(path): 22 | warnings.warn('The path of checkpoints does not exist.') 23 | return None 24 | if osp.exists(osp.join(path, f'latest.{suffix}')): 25 | return osp.join(path, f'latest.{suffix}') 26 | 27 | checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) 28 | if len(checkpoints) == 0: 29 | warnings.warn('There are no checkpoints in the path.') 30 | return None 31 | latest = -1 32 | latest_path = None 33 | for checkpoint in checkpoints: 34 | count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) 35 | if count > latest: 36 | latest = count 37 | latest_path = checkpoint 38 | return latest_path 39 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved 2 | 3 | __version__ = '1.0.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | """Parse a version string into a tuple. 8 | 9 | Args: 10 | version_str (str): The version string. 11 | Returns: 12 | tuple[int | str]: The version info, e.g., "1.3.0" is parsed into 13 | (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 14 | """ 15 | version_info = [] 16 | for x in version_str.split('.'): 17 | if x.isdigit(): 18 | version_info.append(int(x)) 19 | elif x.find('rc') != -1: 20 | patch_version = x.split('rc') 21 | version_info.append(int(patch_version[0])) 22 | version_info.append(f'rc{patch_version[1]}') 23 | return tuple(version_info) 24 | 25 | 26 | version_info = parse_version_info(__version__) 27 | 28 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 29 | -------------------------------------------------------------------------------- /mmrazor/mmrazor/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. 
All rights reserved. 2 | from .local_visualizer import modify 3 | 4 | __all__ = ['modify'] 5 | -------------------------------------------------------------------------------- /mmrazor/projects/mmrazor_large/examples/ResNet/README.md: -------------------------------------------------------------------------------- 1 | # Examples for ResNet 2 | 3 | ## SparseGPT 4 | 5 | For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md) 6 | 7 | ### Usage 8 | 9 | ```shell 10 | python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batchsize 128 --num_samples 512 11 | ``` 12 | 13 | **Note**: this imagenet folder follows torch format. 14 | 15 | ## GPTQ 16 | 17 | For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md) 18 | 19 | ### Usage 20 | 21 | ```shell 22 | python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batchsize 128 --num_samples 512 23 | ``` 24 | 25 | **Note**: this imagenet folder follows torch format. 
26 | -------------------------------------------------------------------------------- /mmrazor/requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/optional.txt 2 | -r requirements/runtime.txt 3 | -r requirements/tests.txt 4 | -------------------------------------------------------------------------------- /mmrazor/requirements/docs.txt: -------------------------------------------------------------------------------- 1 | docutils==0.16.0 2 | m2r 3 | myst-parser 4 | git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | sphinx==4.0.2 6 | sphinx-copybutton 7 | sphinx_markdown_tables 8 | -------------------------------------------------------------------------------- /mmrazor/requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcv>=2.0.0rc1 2 | mmengine>=0.1.0,<1.0.0 3 | -------------------------------------------------------------------------------- /mmrazor/requirements/optional.txt: -------------------------------------------------------------------------------- 1 | pydacefit 2 | pySOT==0.2.3 3 | scipy 4 | timm 5 | -------------------------------------------------------------------------------- /mmrazor/requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv>=1.3.8 2 | ordered_set 3 | torch 4 | torchvision 5 | -------------------------------------------------------------------------------- /mmrazor/requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | ordered_set 2 | typing_extensions;python_version<"3.8" 3 | -------------------------------------------------------------------------------- /mmrazor/requirements/tests.txt: -------------------------------------------------------------------------------- 1 | coverage 2 | flake8 3 | interrogate 4 | isort==4.3.21 5 | nbconvert 6 | nbformat 7 | 
numpy < 1.24.0 # A temporary solution for tests with mmdet. 8 | onnx 9 | pytest 10 | triton==2.0.0 11 | xdoctest >= 0.10.0 12 | yapf 13 | -------------------------------------------------------------------------------- /mmrazor/resources/design_and_implement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/design_and_implement.png -------------------------------------------------------------------------------- /mmrazor/resources/mmrazor-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/mmrazor-logo.png -------------------------------------------------------------------------------- /mmrazor/resources/qq_group_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/qq_group_qrcode.jpg -------------------------------------------------------------------------------- /mmrazor/resources/xiaozhushou_weixin_qrcode.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/xiaozhushou_weixin_qrcode.jpeg -------------------------------------------------------------------------------- /mmrazor/resources/zhihu_qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/zhihu_qrcode.jpg -------------------------------------------------------------------------------- 
/mmrazor/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [yapf] 8 | based_on_style = pep8 9 | blank_line_before_nested_class_or_def = true 10 | split_before_expression_after_opening_paren = true 11 | 12 | [isort] 13 | line_length = 79 14 | multi_line_output = 0 15 | extra_standard_library = pkg_resources,setuptools 16 | known_first_party = mmrazor 17 | known_third_party=cv2,mmcls,mmcv,mmdet,mmseg,numpy,ordered_set,packaging,pytest,pytorch_sphinx_theme,torch,yaml 18 | no_lines_before = STDLIB,LOCALFOLDER 19 | default_section = THIRDPARTY 20 | 21 | [codespell] 22 | skip = *.ipynb 23 | quiet-level = 3 24 | ignore-words-list = patten,confectionary,nd,ty,formating 25 | -------------------------------------------------------------------------------- /mmrazor/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /mmrazor/tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | -------------------------------------------------------------------------------- /mmrazor/tests/data/color.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/color.jpeg -------------------------------------------------------------------------------- /mmrazor/tests/data/concat_subnet1.yaml: -------------------------------------------------------------------------------- 1 | op1.mutable_in_channels: 2 | current_choice: 3 3 | origin_channels: 3 4 | op1.mutable_out_channels: 5 | current_choice: 4 6 | origin_channels: 8 7 | bn1.mutable_num_features: 8 | current_choice: 4 9 | origin_channels: 8 10 | op2.mutable_in_channels: 11 | current_choice: 3 12 | origin_channels: 3 13 | op2.mutable_out_channels: 14 | current_choice: 4 15 | origin_channels: 8 16 | bn2.mutable_num_features: 17 | current_choice: 4 18 | origin_channels: 8 19 | op3.mutable_in_channels: 20 | current_choice: 8 21 | origin_channels: 16 22 | op3.mutable_out_channels: 23 | current_choice: 8 24 | origin_channels: 8 -------------------------------------------------------------------------------- /mmrazor/tests/data/concat_subnet2.yaml: -------------------------------------------------------------------------------- 1 | op1.mutable_in_channels: 2 | current_choice: 3 3 | origin_channels: 3 4 | op1.mutable_out_channels: 5 | current_choice: 8 6 | origin_channels: 8 7 | bn1.mutable_num_features: 8 | current_choice: 8 9 | origin_channels: 8 10 | op2.mutable_in_channels: 11 | current_choice: 3 12 | origin_channels: 3 13 | op2.mutable_out_channels: 14 | current_choice: 8 15 | origin_channels: 8 16 | bn2.mutable_num_features: 17 | current_choice: 8 18 | origin_channels: 8 19 | op3.mutable_in_channels: 20 | current_choice: 16 21 | origin_channels: 16 22 | op3.mutable_out_channels: 23 | current_choice: 8 24 | origin_channels: 8 
-------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/a/1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/a/1.JPG -------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/ann.json: -------------------------------------------------------------------------------- 1 | { 2 | "metainfo": { 3 | "categories": [ 4 | { 5 | "category_name": "first", 6 | "id": 0 7 | }, 8 | { 9 | "category_name": "second", 10 | "id": 1 11 | } 12 | ] 13 | }, 14 | "data_list": [ 15 | { 16 | "img_path": "a/1.JPG", 17 | "gt_label": 0 18 | }, 19 | { 20 | "img_path": "b/2.jpeg", 21 | "gt_label": 1 22 | }, 23 | { 24 | "img_path": "b/subb/2.jpeg", 25 | "gt_label": 1 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/ann.txt: -------------------------------------------------------------------------------- 1 | a/1.JPG 0 2 | b/2.jpeg 1 3 | b/subb/3.jpg 1 4 | -------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/b/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/b/2.jpeg -------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/b/subb/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/b/subb/3.jpg -------------------------------------------------------------------------------- 
/mmrazor/tests/data/dataset/classes.txt: -------------------------------------------------------------------------------- 1 | bus 2 | car 3 | -------------------------------------------------------------------------------- /mmrazor/tests/data/dataset/multi_label_ann.json: -------------------------------------------------------------------------------- 1 | { 2 | "metainfo": { 3 | "categories": [ 4 | { 5 | "category_name": "first", 6 | "id": 0 7 | }, 8 | { 9 | "category_name": "second", 10 | "id": 1 11 | } 12 | ] 13 | }, 14 | "data_list": [ 15 | { 16 | "img_path": "a/1.JPG", 17 | "gt_label": [0] 18 | }, 19 | { 20 | "img_path": "b/2.jpeg", 21 | "gt_label": [1] 22 | }, 23 | { 24 | "img_path": "b/subb/2.jpeg", 25 | "gt_label": [0, 1] 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /mmrazor/tests/data/subnet1.yaml: -------------------------------------------------------------------------------- 1 | op1.mutable_in_channels: 2 | current_choice: 3 3 | origin_channels: 3 4 | op1.mutable_out_channels: 5 | current_choice: 4 6 | origin_channels: 8 7 | bn1.mutable_num_features: 8 | current_choice: 4 9 | origin_channels: 8 10 | op2.mutable_in_channels: 11 | current_choice: 4 12 | origin_channels: 8 13 | op2.mutable_out_channels: 14 | current_choice: 4 15 | origin_channels: 8 16 | bn2.mutable_num_features: 17 | current_choice: 4 18 | origin_channels: 8 19 | op3.mutable_in_channels: 20 | current_choice: 4 21 | origin_channels: 8 22 | op3.mutable_out_channels: 23 | current_choice: 8 24 | origin_channels: 8 -------------------------------------------------------------------------------- /mmrazor/tests/data/subnet2.yaml: -------------------------------------------------------------------------------- 1 | op1.mutable_in_channels: 2 | current_choice: 3 3 | origin_channels: 3 4 | op1.mutable_out_channels: 5 | current_choice: 8 6 | origin_channels: 8 7 | bn1.mutable_num_features: 8 | current_choice: 8 9 | origin_channels: 8 10 | 
op2.mutable_in_channels: 11 | current_choice: 8 12 | origin_channels: 8 13 | op2.mutable_out_channels: 14 | current_choice: 8 15 | origin_channels: 8 16 | bn2.mutable_num_features: 17 | current_choice: 8 18 | origin_channels: 8 19 | op3.mutable_in_channels: 20 | current_choice: 8 21 | origin_channels: 8 22 | op3.mutable_out_channels: 23 | current_choice: 8 24 | origin_channels: 8 -------------------------------------------------------------------------------- /mmrazor/tests/data/test_models/test_mutator/subnet1.json: -------------------------------------------------------------------------------- 1 | { 2 | "op1_(0, 8)_8": { 3 | "init_args":{ 4 | "num_channels":8, 5 | "divisor":1, 6 | "min_value":1, 7 | "min_ratio":0.9, 8 | "candidate_choices":[ 9 | 6 10 | ], 11 | "choice_mode":"number" 12 | }, 13 | "choice":6 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /mmrazor/tests/data/test_models/test_subnet/mockmodel_subnet.yaml: -------------------------------------------------------------------------------- 1 | mutable1: 2 | chosen: conv1 3 | mutable2: 4 | chosen: conv2 5 | mutable3.0.kernel_size: 6 | chosen: 3 7 | -------------------------------------------------------------------------------- /mmrazor/tests/data/test_models/test_task_modules/mmcls_cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | _base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] -------------------------------------------------------------------------------- /mmrazor/tests/data/test_registry/registry_architecture_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from platform import architecture 3 | 4 | 5 | supernet = dict( 6 | type='MockModel', 7 | ) 8 | 9 | model = dict( 10 | type='MockAlgorithm', 11 | architecture=supernet, 12 | _return_architecture_ = True, 13 | ) 14 | 15 | -------------------------------------------------------------------------------- /mmrazor/tests/data/test_registry/registry_subnet_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | supernet = dict( 3 | type='mmrazor.sub_model', 4 | cfg=dict( 5 | type='MockModel', 6 | ), 7 | fix_subnet = { 8 | 'backbone.mutable1': {'chosen':'conv1'}, 9 | 'backbone.mutable2': {'chosen':'conv2'}, 10 | }, 11 | extra_prefix='backbone.' 12 | ) 13 | 14 | model = dict( 15 | type='MockAlgorithm', 16 | architecture=supernet 17 | ) 18 | -------------------------------------------------------------------------------- /mmrazor/tests/test_core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /mmrazor/tests/test_core/test_delivers/toy_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import random 3 | 4 | TOY_VAR = 'aaa' 5 | 6 | 7 | def toy_func(): 8 | return random.randint(0, 1000) 9 | 10 | 11 | class ToyClass: 12 | 13 | def __init__(self): 14 | self._count = 0 15 | 16 | def random_int(self): 17 | return random.randint(0, 1000) 18 | 19 | @property 20 | def count(self): 21 | return self._count 22 | 23 | def __call__(self): 24 | self._count += 1 25 | return self._count 26 | -------------------------------------------------------------------------------- /mmrazor/tests/test_core/test_graph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /mmrazor/tests/test_core/test_graph/test_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import sys 3 | from unittest import TestCase 4 | 5 | import torch 6 | 7 | sys.setrecursionlimit(int(1e8)) 8 | 9 | DEVICE = torch.device('cpu') 10 | 11 | 12 | class TestGraph(TestCase): 13 | pass 14 | # def test_init_from_fx_tracer(self) -> None: 15 | # TestData = BackwardPassedModelManager.include_models() 16 | # with SetTorchThread(1): 17 | # with mp.Pool() as p: 18 | # result = p.map(_test_init_from_fx_tracer, TestData) 19 | # for res, model in zip(result, TestData): 20 | # with self.subTest(model=model): 21 | # self.assertTrue(res[0], res[1]) 22 | 23 | # def test_init_from_backward_tracer(self) -> None: 24 | # TestData = FxPassedModelManager.include_models() 25 | # with SetTorchThread(1) as _: 26 | # with mp.Pool() as p: 27 | # result = p.map(_test_init_from_backward_tracer, TestData) 28 | # for res, model in zip(result, TestData): 29 | # # test_init_from_backward_tracer(model) 30 | # with self.subTest(model=model): 31 | # self.assertTrue(res[0], res[1]) 32 | -------------------------------------------------------------------------------- 
# ---- /mmrazor/tests/test_core/test_recorders/test_method_inputs_recorder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmrazor.models.task_modules import MethodInputsRecorder


class TestMethodInputsRecorder(TestCase):
    """Tests for :class:`MethodInputsRecorder`."""

    def test_context_manager(self):
        # ``toy_mod`` sits next to this file; presumably importable because
        # the test directory is on ``sys.path`` when tests are collected.
        from toy_mod import ToyClass

        toy = ToyClass()

        recorder = MethodInputsRecorder('toy_mod.ToyClass.func')
        recorder.initialize()

        # The same logical call (x=1, y=2) in three spellings: all keywords,
        # mixed positional/keyword, and keywords in reverse order.
        with recorder:
            _ = toy.func(x=1, y=2)
            _ = toy.func(1, y=2)
            _ = toy.func(y=2, x=1)

        # Every record must expose x at data_idx 0 and y at data_idx 1,
        # regardless of how the arguments were passed.
        for record_idx in range(3):
            with self.subTest(record_idx=record_idx):
                x = recorder.get_record_data(
                    record_idx=record_idx, data_idx=0)
                y = recorder.get_record_data(
                    record_idx=record_idx, data_idx=1)
                self.assertEqual(x, 1)
                self.assertEqual(y, 2)

# ---- /mmrazor/tests/test_core/test_recorders/test_param_recorder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmrazor.models.task_modules import ParameterRecorder


class ToyModel(nn.Module):
    """Two 1x1 convs; only ``toy_conv`` is used in ``forward``."""

    def __init__(self):
        super().__init__()
        self.toy_conv = nn.Conv2d(1, 1, 1)
        self.no_record_conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.toy_conv(x)


class TestParameterRecorder(TestCase):

    def test_prepare_from_model(self):

        model = ToyModel()
        recorder = ParameterRecorder('AAA')
        with self.assertRaisesRegex(AssertionError, '"AAA" is not in the'):
            recorder.initialize(model)

        recorder = ParameterRecorder('toy_conv.bias')
        with self.assertRaisesRegex(AssertionError, 'model can not be None'):
            recorder.prepare_from_model()

        recorder.initialize(model)
        bias_weight = recorder.get_record_data()

        # ``assertEquals`` is a deprecated alias removed in Python 3.12, and
        # ``==`` on tensors is element-wise, so compare values explicitly.
        self.assertTrue(torch.equal(bias_weight, model.toy_conv.bias))

        with recorder:
            _ = model(torch.randn(1, 1, 1, 1))

        # The recorded value still matches the model's bias after a forward
        # pass inside the recorder context.
        self.assertTrue(torch.equal(bias_weight, model.toy_conv.bias))

# ---- /mmrazor/tests/test_core/test_recorders/toy_mod.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import random

TOY_VAR = 'aaa'


def toy_func(a):
    """Identity: hand ``a`` straight back."""
    return a


def toy_func2(a, b):
    """Return both arguments as a 2-tuple."""
    return a, b


def toy_list_func(a):
    """Return a three-element list whose items are all ``a``."""
    return [a] * 3


# The ``execute_*`` helpers call the toy functions through their module-level
# names (looked up at call time) and discard the return value.
def execute_toy_func(a):
    toy_func(a)


def execute_toy_func2(a, b):
    toy_func2(a, b)


def execute_toy_list_func(a):
    toy_list_func(a)


class ToyClass:
    """Small stateful class with a shared call counter."""

    TOY_CLS = 'TOY_CLASS'

    def __init__(self):
        self._count = 0

    def toy(self):
        """Bump the counter and return its new value."""
        self._count = self._count + 1
        return self._count

    def func(self, x, y=0):
        """Return ``x + y``."""
        return x + y

    def __call__(self):
        """Bump the counter and return its new value (same as ``toy``)."""
        self._count = self._count + 1
        return self._count


class Toy:
    """Helpers that return random integers drawn from [0, 1000]."""

    def toy_func(self):
        return random.randint(0, 1000)

    def toy_list_func(self):
        values = []
        for _ in range(3):
            values.append(random.randint(0, 1000))
        return values

# ---- /mmrazor/tests/test_core/test_tracer/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_core/test_tracer/test_prune_tracer.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmrazor import digit_version
from mmrazor.models.task_modules.tracer import ChannelAnalyzer
from ...data.models import SingleLineModel


class TestChannelAnalyzer(TestCase):
    """Smoke tests: ChannelAnalyzer runs with both supported tracer types."""

    def test_backward_tracer(self):
        model = SingleLineModel()
        tracer = ChannelAnalyzer(tracer_type='BackwardTracer')
        unit_configs = tracer.analyze(model)
        print(unit_configs)

    def test_fx_tracer(self):
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('torch<1.12.0')
        model = SingleLineModel()
        tracer = ChannelAnalyzer(tracer_type='FxTracer')
        unit_configs = tracer.analyze(model)
        print(unit_configs)

# ---- /mmrazor/tests/test_doc.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest import TestCase

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

# Notebook execution is opt-in: only runs when TEST_DOC is set to 'true'.
TEST_DOC = os.getenv('TEST_DOC') == 'true'
notebook_paths = [
    './mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb',
    './mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb',  # noqa
    './demo/config_pruning.ipynb'
]


class TestDocs(TestCase):
    """Execute every listed notebook end to end."""

    def setUp(self) -> None:
        if not TEST_DOC:
            self.skipTest('disabled')

    def test_notebooks(self):
        for path in notebook_paths:
            with self.subTest(path=path):
                with open(path, encoding='utf-8') as file:
                    nb_in = nbformat.read(file, nbformat.NO_CONVERT)
                ep = ExecutePreprocessor(timeout=600, kernel_name='python3')
                try:
                    _ = ep.preprocess(nb_in)
                except Exception as e:
                    # Report which notebook failed and why, instead of a
                    # bare failure with no message.
                    self.fail(f'{path}: {type(e).__name__}: {e}')

# --------------------------------------------------------------------------------
# ---- /mmrazor/tests/test_engine/test_hooks/test_stop_distillation_hook.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

from mmrazor.engine import StopDistillHook


class TestStopDistillHook(TestCase):

    def setUp(self):
        self.hook = StopDistillHook(stop_epoch=5)
        runner = Mock()
        runner.model = Mock()
        runner.model.distillation_stopped = False

        runner.epoch = 0
        self.runner = runner

    def test_before_train_epoch(self):
        max_epochs = 10
        # Expected flag per epoch: False for epochs 0-4, True from epoch 5
        # (``stop_epoch``) onwards.
        target = [False] * 5 + [True] * 5
        for epoch in range(max_epochs):
            with self.subTest(epoch=epoch):
                self.hook.before_train_epoch(self.runner)
                # ``assertEquals`` was removed in Python 3.12.
                self.assertEqual(self.runner.model.distillation_stopped,
                                 target[epoch])
            self.runner.epoch += 1

# ---- /mmrazor/tests/test_impl/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_impl/test_pruning/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_impl/test_pruning/test_group_fisher/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
# ---- /mmrazor/tests/test_models/test_algorithms/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_algorithms/test_general_quant.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch.nn as nn


class ToyModel(nn.Module):
    """Empty placeholder module; no layers are defined yet."""

    def __init__(self) -> None:
        super().__init__()
        # TODO: add layers for the quantization tests.


class TestGeneralQuant(TestCase):
    """Placeholder suite for general quantization (TODO).

    Every test below is a stub that currently passes without checking
    anything.
    """

    def test_init(self):
        ...

    def test_prepare(self):
        ...

    def test_convert(self):
        ...

    def test_states(self):
        ...

    def test_forward(self):
        ...

# ---- /mmrazor/tests/test_models/test_architectures/test_backbones/utils.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from torch import Tensor
from torch.nn import Conv2d, Module

from mmrazor.registry import MODELS


@MODELS.register_module()
class MockMutable(Module):
    """Stand-in mutable for backbone tests.

    It stores ``choices`` but always runs a single fixed 3x3 convolution
    built from ``module_kwargs``.
    """

    def __init__(self, choices: List[str], module_kwargs: Dict) -> None:
        super().__init__()

        self.choices = choices
        self.module_kwargs = module_kwargs
        kernel = 3
        # padding = k // 2 keeps H/W unchanged when stride and dilation are 1.
        self.conv = Conv2d(
            **module_kwargs, kernel_size=kernel, padding=kernel // 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)

# ---- /mmrazor/tests/test_models/test_architectures/test_dynamic_op/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_architectures/test_dynamic_op/utils.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

from mmrazor.models.architectures.dynamic_ops import DynamicMixin
from mmrazor.utils.typing import DumpChosen


def fix_dynamic_op(op: DynamicMixin,
                   fix_mutables: Optional[Dict] = None) -> None:
    """Fix every mutable attribute of ``op`` to a concrete choice.

    Args:
        op: Dynamic op whose ``mutable_attrs`` are fixed in place.
        fix_mutables: Optional mapping from ``'mutable_attrs.<name>'`` to
            the choice, given either as a ``DumpChosen`` or as a mapping of
            its constructor fields. When omitted, each mutable is fixed to
            the result of its own ``dump_chosen()``.
    """
    for name, mutable in op.mutable_attrs.items():
        chosen = (mutable.dump_chosen() if fix_mutables is None else
                  fix_mutables[f'mutable_attrs.{name}'])
        if not isinstance(chosen, DumpChosen):
            chosen = DumpChosen(**chosen)
        mutable.fix_chosen(chosen.chosen)

# ---- /mmrazor/tests/test_models/test_fake_quants/test_torch_fake_quants.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.fake_quants import register_torch_fake_quants
from mmrazor.registry import MODELS


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.13.0'),
    reason='version of torch < 1.13.0')
def test_register_torch_fake_quants():
    """Every name returned by the registrar resolves in ``MODELS``."""
    registered = register_torch_fake_quants()
    assert isinstance(registered, list)
    for name in registered:
        assert MODELS.get(name)

# ---- /mmrazor/tests/test_models/test_mutables/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_mutables/test_mutable_channel/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
# ---- /mmrazor/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch.nn as nn

from mmrazor.models.mutables import L1MutableChannelUnit
from mmrazor.models.mutators import ChannelMutator
from .....data.models import SingleLineModel


class TestL1MutableChannelUnit(TestCase):

    def test_init(self):
        """A ChannelMutator using L1 units can be prepared on a small model."""
        model = SingleLineModel()
        unit_cfg = dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'))
        mutator = ChannelMutator(channel_unit_cfg=unit_cfg)
        mutator.prepare_from_supernet(model)

    def test_convnd(self):
        """L1 norm of a 3-D conv yields one value per channel in [0, 8)."""
        num_channels = 8
        unit = L1MutableChannelUnit(num_channels)
        conv = nn.Conv3d(3, num_channels, 3)
        norm = unit._get_l1_norm(conv, 0, num_channels)
        self.assertSequenceEqual(norm.shape, [num_channels])

# ---- /mmrazor/tests/test_models/test_mutables/test_sequential_mutable_channel.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

from mmrazor.models.mutables import SquentialMutableChannel


class TestSquentialMutableChannel(TestCase):

    def test_mul_float(self):
        channel = SquentialMutableChannel(10)
        new_channel = channel * 0.5
        self.assertEqual(new_channel.current_choice, 5)
        # The derived channel tracks its source: 5 * 0.5 = 2.5 -> 2.
        channel.current_choice = 5
        self.assertEqual(new_channel.current_choice, 2)

# ---- /mmrazor/tests/test_models/test_mutators/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_mutators/test_dmcp_mutator.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import *  # noqa: F401,F403
from torch import Tensor, nn
from torch.nn import Module

from mmrazor.models.mutators import DMCPChannelMutator


class ResBlock(Module):
    """Three 1x1 convs with a residual add feeding ``op3``."""

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x1))
        x3 = self.op3(x2 + x1)
        return x3


def test_DMCP_channel_mutator() -> None:
    imgs = torch.randn(16, 3, 224, 224)

    # ResBlock
    mutator = DMCPChannelMutator(channel_unit_cfg=dict(type='DMCPChannelUnit'))

    model = ResBlock()
    mutator.prepare_from_supernet(model)
    for mode in ['max', 'min', 'random', 'expected', 'direct']:
        mutator.sample_subnet(mode, arch_train=True)
        out3 = model(imgs)
        # Check every sampling mode, not only the last one: batch size and
        # spatial size (all convs are 1x1) must survive each sample.
        assert out3.shape[0] == 16, mode
        assert tuple(out3.shape[2:]) == (224, 224), mode

    # Full shape check for the final ('direct') sample.
    assert out3.shape == (16, 8, 224, 224)

# ---- /mmrazor/tests/test_models/test_observers/test_torch_observers.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.observers import register_torch_observers
from mmrazor.registry import MODELS


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.13.0'),
    reason='version of torch < 1.13.0')
def test_register_torch_observers():
    """Every name returned by the registrar resolves in ``MODELS``."""
    TORCH_observers = register_torch_observers()
    assert isinstance(TORCH_observers, list)
    for observer in TORCH_observers:
        assert MODELS.get(observer)

# ---- /mmrazor/tests/test_models/test_task_modules/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_task_modules/test_demo_inputs/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_task_modules/test_demo_inputs/test_demo_inputs.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
from ....data.tracer_passed_models import FxPassedModelManager


class TestDemoInputs(unittest.TestCase):
    """Every fx-traceable test model accepts its generated demo input."""

    def test_demo_inputs(self):
        for Model in FxPassedModelManager().include_models():
            with self.subTest(model=Model):
                demo_input = DefaultDemoInput(input_shape=[1, 3, 224, 224])
                model = Model()
                model.eval()
                try:
                    demo_input(model)
                    data = demo_input.get_data(model)
                    # Demo data is either keyword arguments (dict) or a
                    # single positional input.
                    if isinstance(data, dict):
                        model(**data)
                    else:
                        model(data)
                except Exception as e:
                    self.fail(f'{type(e).__name__}: {e}')

# ---- /mmrazor/tests/test_models/test_utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_models/test_utils/test_expandable_utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_runners/test_utils/test_genetic.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.engine.runner.utils import crossover


def test_crossover():
    """Check that crossover of two 50-gene subnets keeps the container type,
    keys and size, and that each gene comes from one of the parents."""
    fake_random_subnet1 = {i: f'{i}_choice1' for i in range(50)}
    fake_random_subnet2 = {i: f'{i}_choice2' for i in range(50)}

    result = crossover(fake_random_subnet1, fake_random_subnet2)

    # Exact type match via ``is`` (flake8 E721) rather than ``==`` on types.
    assert type(result) is type(fake_random_subnet1)
    assert len(result) == len(fake_random_subnet1)
    # NOTE(review): assumes crossover only recombines existing genes.
    assert set(result) == set(fake_random_subnet1)
    for key, value in result.items():
        assert value in (fake_random_subnet1[key], fake_random_subnet2[key])

# ---- /mmrazor/tests/test_tools/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.

# ---- /mmrazor/tests/test_utils/test_index_dict.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

from mmrazor.utils.index_dict import IndexDict


class TestIndexDict(unittest.TestCase):

    def test_dict(self):
        # Named ``index_dict`` so the ``dict`` builtin is not shadowed.
        index_dict = IndexDict()
        index_dict[(4, 5)] = 2
        index_dict[(1, 3)] = 1

        # Keys come back sorted regardless of insertion order.
        self.assertSequenceEqual(list(index_dict.keys()), [(1, 3), (4, 5)])
        # (2, 3) overlaps the existing (1, 3) key (presumably treated as an
        # index range) and must be rejected.
        with self.assertRaisesRegex(AssertionError, 'overlap'):
            index_dict[2, 3] = 3

# ---- /mmrazor/tests/test_utils/test_placeholder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import pytest

from mmrazor.utils import get_placeholder


class TestPlaceholder(unittest.TestCase):

    def test_placeholder(self):
        holder = get_placeholder('test')
        # Calling the placeholder must raise ImportError.
        with pytest.raises(ImportError):
            holder()
        from mmrazor.models.architectures.dynamic_ops import DynamicMixin

        # The placeholder must remain usable as a base class alongside
        # DynamicMixin (creating the class must not raise).
        class tmp(holder, DynamicMixin):
            pass

# ---- /mmrazor/tests/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .set_torch_thread import SetTorchThread

__all__ = ['SetTorchThread']

# ---- /mmrazor/tests/utils/set_dist_env.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random

import torch
import torch.distributed as dist


class SetDistEnv:
    """Context manager for a single-process ``torch.distributed`` group.

    On enter it points MASTER_ADDR/MASTER_PORT at localhost and initializes
    a rank-0, world-size-1 process group; on exit it destroys the group.

    Args:
        using_cuda: Use the ``nccl`` backend (requires CUDA) instead of
            ``gloo``.
        port: Master port; a random port in [10000, 20000] when ``None``.
    """

    def __init__(self, using_cuda=False, port=None) -> None:
        self.using_cuda = using_cuda
        if self.using_cuda:
            assert torch.cuda.is_available()
        if port is None:
            port = random.randint(10000, 20000)
        self.port = port

    def __enter__(self):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(self.port)

        # initialize the process group
        backend = 'nccl' if self.using_cuda else 'gloo'
        dist.init_process_group(backend, rank=0, world_size=1)
        # Return self so ``with SetDistEnv() as env:`` binds the manager.
        return self

    def __exit__(self, exc_type, exc_value, tb):
        dist.destroy_process_group()

# ---- /mmrazor/tests/utils/set_torch_thread.py ----
# Copyright (c) OpenMMLab.
# All rights reserved.
import torch


class SetTorchThread:
    """Context manager that temporarily changes torch's CPU thread count.

    Args:
        num_thread: Number of threads to use inside the ``with`` block.
            ``-1`` (the default) leaves the setting untouched.
    """

    def __init__(self, num_thread: int = -1) -> None:
        # Kept for backward compatibility; refreshed in ``__enter__`` so the
        # value restored on exit is the one in effect when the block starts,
        # not the one at construction time.
        self.prev_num_threads = torch.get_num_threads()
        self.num_threads = num_thread

    def __enter__(self):
        if self.num_threads != -1:
            self.prev_num_threads = torch.get_num_threads()
            torch.set_num_threads(self.num_threads)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if self.num_threads != -1:
            torch.set_num_threads(self.prev_num_threads)

# ---- /mmrazor/tools/dist_test.sh ----
#!/usr/bin/env bash

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# All expansions are quoted so paths and extra args containing spaces
# survive intact; "${@:4}" forwards every remaining argument verbatim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    "$(dirname "$0")"/test.py \
    "$CONFIG" \
    "$CHECKPOINT" \
    --launcher pytorch \
    "${@:4}"

# ---- /mmrazor/tools/dist_train.sh ----
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# "${@:3}" forwards every remaining argument verbatim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    "$(dirname "$0")"/train.py \
    "$CONFIG" \
    --launcher pytorch "${@:3}"

# --------------------------------------------------------------------------------
# ---- /mmrazor/tools/slurm_test.sh ----
#!/usr/bin/env bash

# NOTE(review): despite its name, this script does not go through Slurm; it
# runs torch.distributed.launch locally, like dist_test.sh. Confirm whether
# an srun-based version (cf. slurm_train.sh) was intended.
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"

# ---- /mmrazor/tools/slurm_train.sh ----
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# Keep the extra arguments as an array so each one stays a separate word.
PY_ARGS=("${@:5}")

# train.py is resolved relative to this script (as PYTHONPATH already is),
# so the job no longer depends on being submitted from the repo root.
# SRUN_ARGS is deliberately unquoted: it may hold several srun flags that
# must be word-split.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p "${PARTITION}" \
    --job-name="${JOB_NAME}" \
    --gres=gpu:"${GPUS_PER_NODE}" \
    --ntasks="${GPUS}" \
    --ntasks-per-node="${GPUS_PER_NODE}" \
    --cpus-per-task="${CPUS_PER_TASK}" \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u "$(dirname "$0")"/train.py "${CONFIG}" --work-dir="${WORK_DIR}" --launcher="slurm" "${PY_ARGS[@]}"

# ---- /mmrazor/tools/visualizations/demo.jpg ----
# https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tools/visualizations/demo.jpg

# ---- /mmrazor/tools/visualizations/vis_configs/backbone_feature_diff_visualization.py ----
# Copyright (c) OpenMMLab. All rights reserved.
# configs for the 1st model
recorders1 = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
# Feature names p3..p6 map to outputs 0..3 of the 'backbone' recorder.
mappings1 = {
    f'p{level}': dict(recorder='backbone', data_idx=idx)
    for idx, level in enumerate(range(3, 7))
}

# configs for the 2nd model
recorders2 = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
mappings2 = {
    f'p{level}': dict(recorder='backbone', data_idx=idx)
    for idx, level in enumerate(range(3, 7))
}

# ---- /mmrazor/tools/visualizations/vis_configs/backbone_feature_visualization.py ----
# Copyright (c) OpenMMLab. All rights reserved.
recorders = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
mappings = {
    f'p{level}': dict(recorder='backbone', data_idx=idx)
    for idx, level in enumerate(range(3, 7))
}

# ---- /mmrazor/tools/visualizations/vis_configs/fpn_feature_diff_visualization.py ----
# Copyright (c) OpenMMLab. All rights reserved.
# configs for the 1st model
recorders1 = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings1 = dict(
    p3=dict(recorder='neck', data_idx=0),
    p4=dict(recorder='neck', data_idx=1),
    p5=dict(recorder='neck', data_idx=2),
    p6=dict(recorder='neck', data_idx=3))

# configs for the 2nd model
recorders2 = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings2 = dict(
    p3=dict(recorder='neck', data_idx=0),
    p4=dict(recorder='neck', data_idx=1),
    p5=dict(recorder='neck', data_idx=2),
    p6=dict(recorder='neck', data_idx=3))

# ---- /mmrazor/tools/visualizations/vis_configs/fpn_feature_visualization.py ----
# Copyright (c) OpenMMLab. All rights reserved.
recorders = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings = dict(
    p3=dict(recorder='neck', data_idx=0),
    p4=dict(recorder='neck', data_idx=1),
    p5=dict(recorder='neck', data_idx=2),
    p6=dict(recorder='neck', data_idx=3))

# ---- /pyproject.toml ----
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "dms"
version = "0.0.1"
# NOTE(review): placeholder e-mail address; replace with a real contact.
authors = [
  { name="kai liu", email="author@example.com" },
]
description = "Differentiable Model Scaling"
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

# Point at this project's repository instead of the pypa/sampleproject
# template URLs.
[project.urls]
"Homepage" = "https://github.com/LKJacky/Differentiable-Model-Scaling"
"Bug Tracker" = "https://github.com/LKJacky/Differentiable-Model-Scaling/issues"

# ---- /test_dms.py ----
import unittest
from dms import differentiable_topk, MASK_THRESHOLD
import torch

torch.set_printoptions(precision=3)


class TestDMS(unittest.TestCase):
    def test_dms_topk(self):
        """The differentiable top-k mask keeps (1 - ratio) * N elements."""
        N = 10
        for ratio in [0, 0.3, 0.5, 0.7, 1.0]:
            with self.subTest(pruning_ratio=ratio):
                c = torch.rand([N])
                # Separate name for the tensor form so the loop variable is
                # not rebound mid-iteration.
                a = torch.tensor([ratio])
                m = differentiable_topk(c, a, lambda_=c.numel())

                n_remain = (m > MASK_THRESHOLD).float().sum().int().item()
                # NOTE(review): truncating a float product can be off by one
                # for some ratios (e.g. int(0.29 * 100) == 28); it appears to
                # hold for the ratios used here.
                self.assertEqual(n_remain, int((1 - a) * N))
                print(
                    f"Pruning Ratio: {a}\n"
                    f"Element Importance: {c}\n"
                    f"Soft mask: {m}\n"
                    f"Number of Remained Elements: {n_remain}\n"
                )

# --------------------------------------------------------------------------------