├── .gitignore
├── .vscode
├── launch.json
└── settings.json
├── README.md
├── applications
└── efficientnet
│ ├── .gitignore
│ ├── .vscode
│ ├── launch.json
│ └── settings.json
│ ├── README.md
│ ├── dms_efficientnet.py
│ ├── requirements.txt
│ ├── scripts
│ └── DMS-450
│ │ ├── prune.sh
│ │ └── retrain.sh
│ ├── test_eff.py
│ ├── timm_dataset.py
│ ├── timm_pruning.py
│ └── timm_retrain.py
├── dms
├── __init__.py
├── dtopk_src.py
└── modules
│ ├── __init__.py
│ ├── algorithm.py
│ ├── dtp.py
│ ├── mutable.py
│ ├── mutator.py
│ ├── op.py
│ ├── optimizer.py
│ ├── scheduler.py
│ ├── unit.py
│ └── utils.py
├── images
└── compare.png
├── mmrazor
├── .circleci
│ ├── config.yml
│ ├── docker
│ │ └── Dockerfile
│ └── test.yml
├── .dev_scripts
│ ├── benchmark_summary_analyse.py
│ ├── benchmark_test.py
│ ├── benchmark_train.py
│ └── meta_files_test.py
├── .github
│ ├── CONTRIBUTING.md
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.md
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ └── general-questions.md
│ ├── pull_request_template.md
│ └── workflows
│ │ ├── build.yml
│ │ ├── deploy.yml
│ │ └── lint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_zh-CN.md
├── configs
│ ├── _base_
│ │ ├── datasets
│ │ │ └── mmcls
│ │ │ │ ├── cifar100_bs16_auto_aug.py
│ │ │ │ └── pipelines
│ │ │ │ └── auto_aug_cifar.py
│ │ ├── nas_backbones
│ │ │ ├── attentive_mobilenetv3_supernet.py
│ │ │ ├── darts_supernet.py
│ │ │ ├── dsnas_shufflenet_supernet.py
│ │ │ ├── ofa_mobilenetv3_supernet.py
│ │ │ ├── spos_mobilenet_supernet.py
│ │ │ └── spos_shufflenet_supernet.py
│ │ ├── settings
│ │ │ ├── cifar10_darts_subnet.py
│ │ │ ├── cifar10_darts_supernet.py
│ │ │ ├── imagenet_bs1024_dsnas.py
│ │ │ ├── imagenet_bs1024_spos.py
│ │ │ ├── imagenet_bs2048_AdamW.py
│ │ │ ├── imagenet_bs2048_autoslim.py
│ │ │ ├── imagenet_bs2048_autoslim_pil.py
│ │ │ ├── imagenet_bs2048_bignas.py
│ │ │ ├── imagenet_bs2048_dmcp.py
│ │ │ └── imagenet_bs2048_ofa.py
│ │ └── vanilla_models
│ │ │ └── wrn16_2_cifar10.py
│ ├── distill
│ │ ├── mmcls
│ │ │ ├── abloss
│ │ │ │ ├── README.md
│ │ │ │ ├── abloss_logits_resnet50_resnet18_8xb32_in1k.py
│ │ │ │ ├── abloss_pretrain_backbone_resnet50_resnet18_8xb32_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── byot
│ │ │ │ ├── README.md
│ │ │ │ ├── byot_resnet18_8xb16_cifar100.py
│ │ │ │ └── metafile.yml
│ │ │ ├── crd
│ │ │ │ ├── README.md
│ │ │ │ ├── crd_neck_r50_r18_8xb16_cifar10.py
│ │ │ │ └── datasets
│ │ │ │ │ └── crd_cifar10_bs16.py
│ │ │ ├── dafl
│ │ │ │ ├── README.md
│ │ │ │ ├── dafl_logits_resnet34_resnet18_8xb256_cifar10.py
│ │ │ │ └── metafile.yml
│ │ │ ├── deit
│ │ │ │ ├── README.md
│ │ │ │ ├── deit-base_regnety160_pt-16xb64_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── dfad
│ │ │ │ ├── README.md
│ │ │ │ ├── dfad_logits_resnet34_resnet18_8xb32_cifar10.py
│ │ │ │ └── metafile.yml
│ │ │ ├── dkd
│ │ │ │ ├── README.md
│ │ │ │ ├── dkd_resnet34_resnet18_8xb32_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── factor_transfer
│ │ │ │ ├── README.md
│ │ │ │ ├── factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain.py
│ │ │ │ ├── factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_train.py
│ │ │ │ └── metafile.yml
│ │ │ ├── fitnets
│ │ │ │ ├── README.md
│ │ │ │ ├── fitnets_backbone_logits_resnet50_resnet18_8xb32_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── kd
│ │ │ │ ├── README.md
│ │ │ │ ├── kd_logits_resnet34_resnet18_8xb32_in1k.py
│ │ │ │ ├── kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py
│ │ │ │ ├── kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── ofd
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ └── ofd_backbone_resnet50_resnet18_8xb16_cifar10.py
│ │ │ ├── rkd
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ └── rkd_neck_resnet34_resnet18_8xb32_in1k.py
│ │ │ ├── wsld
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ └── wsld_logits_resnet34_resnet18_8xb32_in1k.py
│ │ │ └── zskt
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ └── zskt_backbone_logits_resnet34_resnet18_8xb16_cifar10.py
│ │ ├── mmdet
│ │ │ ├── cwd
│ │ │ │ ├── README.md
│ │ │ │ ├── cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py
│ │ │ │ ├── cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py
│ │ │ │ ├── cwd_fpn_retina_r101_retina_r50_1x_coco.py
│ │ │ │ ├── cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py
│ │ │ │ └── metafile.yml
│ │ │ ├── fbkd
│ │ │ │ ├── README.md
│ │ │ │ ├── fbkd_fpn_faster-rcnn_r101_faster-rcnn_r50_1x_coco.py
│ │ │ │ └── metafile.yml
│ │ │ ├── mgd
│ │ │ │ ├── README.md
│ │ │ │ └── mgd_fpn_retina_x101_retina_r50_2x_coco.py
│ │ │ └── pkd
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ ├── pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py
│ │ │ │ ├── pkd_fpn_fcos_x101_retina_r50_1x_coco.py
│ │ │ │ ├── pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py
│ │ │ │ ├── pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py
│ │ │ │ └── pkd_fpn_retina_x101_retina_r50_2x_coco.py
│ │ ├── mmdet3d
│ │ │ └── pkd
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ └── pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py
│ │ └── mmseg
│ │ │ └── cwd
│ │ │ ├── README.md
│ │ │ ├── cwd_logits_pspnet_r101-d8_pspnet_r18-d8_4xb2-80k_cityscapes-512x1024.py
│ │ │ └── metafile.yml
│ ├── nas
│ │ ├── mmcls
│ │ │ ├── autoformer
│ │ │ │ ├── AUTOFORMER_SUBNET_B.yaml
│ │ │ │ ├── README.md
│ │ │ │ ├── autoformer_search_8xb128_in1k.py
│ │ │ │ ├── autoformer_subnet_8xb256_in1k.py
│ │ │ │ └── autoformer_supernet_32xb256_in1k.py
│ │ │ ├── autoslim
│ │ │ │ ├── README.md
│ │ │ │ ├── autoslim_mbv2_1.5x_search_8xb256_in1k.py
│ │ │ │ ├── autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py
│ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py
│ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py
│ │ │ │ ├── autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py
│ │ │ │ ├── autoslim_mbv2_1.5x_supernet_8xb256_in1k.py
│ │ │ │ └── metafile.yml
│ │ │ ├── bignas
│ │ │ │ ├── ATTENTIVE_SUBNET_A0.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A1.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A2.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A3.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A4.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A5.yaml
│ │ │ │ ├── ATTENTIVE_SUBNET_A6.yaml
│ │ │ │ ├── README.md
│ │ │ │ ├── attentive_mobilenet_search_8xb128_in1k.py
│ │ │ │ ├── attentive_mobilenet_subnet_8xb256_in1k.py
│ │ │ │ └── attentive_mobilenet_supernet_32xb64_in1k.py
│ │ │ ├── darts
│ │ │ │ ├── DARTS_SUBNET_CIFAR_MMRAZOR_97.32.yaml
│ │ │ │ ├── DARTS_SUBNET_CIFAR_PAPER_ALIAS.yaml
│ │ │ │ ├── README.md
│ │ │ │ ├── darts_subnet_1xb96_cifar10_2.0.py
│ │ │ │ ├── darts_subnet_1xb96_cifar10_2.0_mmrazor.py
│ │ │ │ ├── darts_supernet_unroll_1xb96_cifar10.py
│ │ │ │ └── metafile.yml
│ │ │ ├── dsnas
│ │ │ │ ├── DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml
│ │ │ │ ├── README.md
│ │ │ │ ├── dsnas_subnet_8xb128_in1k.py
│ │ │ │ └── dsnas_supernet_8xb128_in1k.py
│ │ │ ├── onceforall
│ │ │ │ ├── OFA_SUBNET_NOTE8_LAT22.yaml
│ │ │ │ ├── OFA_SUBNET_NOTE8_LAT31.yaml
│ │ │ │ ├── README.md
│ │ │ │ ├── ofa_mobilenet_search_8xb128_in1k.py
│ │ │ │ ├── ofa_mobilenet_subnet_8xb256_in1k.py
│ │ │ │ └── ofa_mobilenet_supernet_32xb64_in1k.py
│ │ │ └── spos
│ │ │ │ ├── README.md
│ │ │ │ ├── SPOS_SUBNET.yaml
│ │ │ │ ├── faster-rcnn_nas_backbone_fpn_1x_coco.py
│ │ │ │ ├── metafile.yml
│ │ │ │ ├── spos_mobilenet_search_8xb128_in1k.py
│ │ │ │ ├── spos_mobilenet_subnet_8xb128_in1k.py
│ │ │ │ ├── spos_mobilenet_supernet_8xb128_in1k.py
│ │ │ │ ├── spos_shufflenet_search_8xb128_in1k.py
│ │ │ │ ├── spos_shufflenet_search_predictor_8xb128_in1k.py
│ │ │ │ ├── spos_shufflenet_subnet_8xb128_in1k.py
│ │ │ │ └── spos_shufflenet_supernet_8xb128_in1k.py
│ │ └── mmdet
│ │ │ └── detnas
│ │ │ ├── DETNAS_SUBNET.yaml
│ │ │ ├── README.md
│ │ │ ├── detnas_frcnn_shufflenet_search_coco_1x.py
│ │ │ ├── detnas_frcnn_shufflenet_subnet_coco_1x.py
│ │ │ ├── detnas_frcnn_shufflenet_supernet_coco_1x.py
│ │ │ ├── detnas_retina_shufflenet_supernet_coco_1x.py
│ │ │ ├── detnas_shufflenet_subnet_8xb128_in1k.py
│ │ │ ├── detnas_shufflenet_supernet_8xb128_in1k.py
│ │ │ └── metafile.yml
│ ├── pruning
│ │ ├── base
│ │ │ └── group_fisher
│ │ │ │ ├── README.md
│ │ │ │ ├── group_fisher_deploy_template.py
│ │ │ │ ├── group_fisher_finetune_template.py
│ │ │ │ └── group_fisher_prune_template.py
│ │ ├── mmcls
│ │ │ ├── dcff
│ │ │ │ ├── README.md
│ │ │ │ ├── dcff_compact_resnet_8xb32_in1k.py
│ │ │ │ ├── dcff_resnet_8xb32_in1k.py
│ │ │ │ └── fix_subnet.json
│ │ │ ├── dmcp
│ │ │ │ ├── DMCP_MBV2_100M.json
│ │ │ │ ├── DMCP_R50_2G.json
│ │ │ │ ├── README.md
│ │ │ │ ├── dmcp_mbv2_subnet_32xb64.py
│ │ │ │ ├── dmcp_mbv2_supernet_32xb64.py
│ │ │ │ ├── dmcp_resnet50_subnet_32xb64.py
│ │ │ │ ├── dmcp_resnet50_supernet_32xb64.py
│ │ │ │ └── metafile.yml
│ │ │ ├── group_fisher
│ │ │ │ ├── README.md
│ │ │ │ ├── mobilenet
│ │ │ │ │ ├── group_fisher_act_deploy_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_deploy_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py
│ │ │ │ │ ├── metafile.yml
│ │ │ │ │ └── script.sh
│ │ │ │ └── resnet50
│ │ │ │ │ ├── group_fisher_act_deploy_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_act_finetune_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_act_finetune_resnet50_8xb32_in1k_dist.py
│ │ │ │ │ ├── group_fisher_act_prune_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_deploy_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_finetune_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── group_fisher_flops_prune_resnet50_8xb32_in1k.py
│ │ │ │ │ ├── metafile.yml
│ │ │ │ │ └── script.sh
│ │ │ └── l1-norm
│ │ │ │ ├── README.md
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_a.py
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_a_deploy.py
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_b.py
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_b_deploy.py
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_c.py
│ │ │ │ ├── l1-norm_resnet34_8xb32_in1k_c_deploy.py
│ │ │ │ ├── metafile.yml
│ │ │ │ └── script.sh
│ │ ├── mmdet
│ │ │ ├── dcff
│ │ │ │ ├── README.md
│ │ │ │ ├── dcff_compact_faster_rcnn_resnet50_8xb4_coco.py
│ │ │ │ ├── dcff_faster_rcnn_resnet50_8xb4_coco.py
│ │ │ │ ├── dcff_faster_rcnn_resnet50_fpn.py
│ │ │ │ └── fix_subnet.json
│ │ │ └── group_fisher
│ │ │ │ ├── README.md
│ │ │ │ └── retinanet
│ │ │ │ ├── group_fisher_act_deploy_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── group_fisher_flops_deploy_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py
│ │ │ │ ├── metafile.yml
│ │ │ │ └── script.sh
│ │ ├── mmpose
│ │ │ ├── dcff
│ │ │ │ ├── README.md
│ │ │ │ ├── dcff_compact_topdown_heatmap_resnet50_coco.py
│ │ │ │ ├── dcff_topdown_heatmap_resnet50_coco.py
│ │ │ │ └── fix_subnet.json
│ │ │ └── group_fisher
│ │ │ │ ├── group_fisher_deploy_rtmpose-s_8xb256-420e_aic-coco-256x192.py
│ │ │ │ ├── group_fisher_deploy_rtmpose-s_8xb256-420e_coco-256x192.py
│ │ │ │ ├── group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.py
│ │ │ │ ├── group_fisher_finetune_rtmpose-s_8xb256-420e_coco-256x192.py
│ │ │ │ ├── group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.py
│ │ │ │ ├── group_fisher_prune_rtmpose-s_8xb256-420e_coco-256x192.py
│ │ │ │ └── script.sh
│ │ └── mmseg
│ │ │ └── dcff
│ │ │ ├── README.md
│ │ │ ├── dcff_compact_pointrend_resnet50_8xb2_cityscapes.py
│ │ │ ├── dcff_pointrend_resnet50_8xb2_cityscapes.py
│ │ │ ├── fix_subnet.json
│ │ │ └── pointrend_resnet50.py
│ ├── quantization
│ │ ├── deploy_cfgs
│ │ │ ├── mmcls
│ │ │ │ ├── classification_openvino_dynamic-224x224.py
│ │ │ │ └── classification_tensorrt-int8-explicit_dynamic-224x224.py
│ │ │ └── mmdet
│ │ │ │ ├── detection_openvino_dynamic-800x1344.py
│ │ │ │ └── detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py
│ │ ├── ptq
│ │ │ └── base
│ │ │ │ ├── README.md
│ │ │ │ ├── metafile.yml
│ │ │ │ ├── ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_openvino_retina_r50_1x_coco_calib32xb32.py
│ │ │ │ ├── ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py
│ │ │ │ ├── ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py
│ │ │ │ ├── ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py
│ │ │ │ └── ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py
│ │ └── qat
│ │ │ ├── base
│ │ │ ├── README.md
│ │ │ ├── metafile.yml
│ │ │ └── qat_openvino_resnet18_10e_8xb32_in1k.py
│ │ │ └── lsq
│ │ │ ├── README.md
│ │ │ ├── lsq_openvino_resnet18_8xb32_100e_in1k.py
│ │ │ ├── lsq_openvino_resnet18_8xb32_10e_in1k.py
│ │ │ └── metafile.yml
│ └── vanilla
│ │ └── mmcls
│ │ └── wide-resnet
│ │ ├── README.md
│ │ ├── wrn16-w2_b16x8_cifar10.py
│ │ ├── wrn22-w4_b16x8_cifar10.py
│ │ ├── wrn28-w4_b16x8_cifar10.py
│ │ └── wrn40-w2_b16x8_cifar10.py
├── demo
│ └── pruning
│ │ └── config_pruning.ipynb
├── docker
│ ├── Dockerfile
│ └── serve
│ │ ├── Dockerfile
│ │ ├── config.properties
│ │ └── entrypoint.sh
├── docs
│ ├── en
│ │ ├── Makefile
│ │ ├── _static
│ │ │ ├── css
│ │ │ │ └── readthedocs.css
│ │ │ └── image
│ │ │ │ └── mmrazor-logo.png
│ │ ├── advanced_guides
│ │ │ ├── algorithm.md
│ │ │ ├── apply_existing_algorithms_to_new_tasks.md
│ │ │ ├── customize_architectures.md
│ │ │ ├── customize_kd_algorithms.md
│ │ │ ├── customize_mixed_algorithms.md
│ │ │ ├── customize_nas_algorithms.md
│ │ │ ├── customize_pruning_algorithms.md
│ │ │ ├── customize_quantization_algorithms.md
│ │ │ ├── delivery.md
│ │ │ ├── index.rst
│ │ │ ├── mutable.md
│ │ │ ├── mutator.md
│ │ │ ├── recorder.md
│ │ │ └── tutorials
│ │ │ │ ├── how_to_prune_your_model.md
│ │ │ │ └── how_to_use_config_tool_of_pruning.md
│ │ ├── api.rst
│ │ ├── conf.py
│ │ ├── get_started
│ │ │ ├── installation.md
│ │ │ ├── model_zoo.md
│ │ │ └── overview.md
│ │ ├── imgs
│ │ │ └── pruning
│ │ │ │ ├── draw-config.png
│ │ │ │ ├── framework-ChanelMutator.png
│ │ │ │ ├── framework-algorithm.png
│ │ │ │ ├── framework-framework.png
│ │ │ │ ├── framework-graph.png
│ │ │ │ ├── framework-op.png
│ │ │ │ ├── pruning_framework.png
│ │ │ │ └── unit.png
│ │ ├── index.rst
│ │ ├── make.bat
│ │ ├── notes
│ │ │ ├── changelog.md
│ │ │ ├── contribution_guide.md
│ │ │ └── faq.md
│ │ ├── switch_language.md
│ │ └── user_guides
│ │ │ ├── 1_learn_about_config.md
│ │ │ ├── 2_train_different_types_algorithms.md
│ │ │ ├── 3_train_with_different_devices.md
│ │ │ ├── 4_test_a_model.md
│ │ │ ├── index.rst
│ │ │ ├── pruning_user_guide.md
│ │ │ └── quantization_user_guide.md
│ └── zh_cn
│ │ ├── Makefile
│ │ ├── api.rst
│ │ ├── conf.py
│ │ ├── index.rst
│ │ ├── make.bat
│ │ ├── switch_language.md
│ │ └── user_guides
│ │ └── visualization.md
├── mmrazor
│ ├── __init__.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── crd_dataset_wrapper.py
│ │ └── transforms
│ │ │ ├── __init__.py
│ │ │ ├── auto_augment.py
│ │ │ ├── auto_augmentv2.py
│ │ │ └── formatting.py
│ ├── engine
│ │ ├── __init__.py
│ │ ├── hooks
│ │ │ ├── __init__.py
│ │ │ ├── dmcp_subnet_hook.py
│ │ │ ├── dump_subnet_hook.py
│ │ │ ├── estimate_resources_hook.py
│ │ │ ├── group_fisher_hooks.py
│ │ │ ├── stop_distillation_hook.py
│ │ │ └── visualization_hook.py
│ │ ├── optimizers
│ │ │ ├── __init__.py
│ │ │ └── optimizer_constructor.py
│ │ └── runner
│ │ │ ├── __init__.py
│ │ │ ├── autoslim_greedy_search_loop.py
│ │ │ ├── darts_loop.py
│ │ │ ├── distill_val_loop.py
│ │ │ ├── evolution_search_loop.py
│ │ │ ├── iteprune_val_loop.py
│ │ │ ├── quantization_loops.py
│ │ │ ├── slimmable_val_loop.py
│ │ │ ├── subnet_sampler_loop.py
│ │ │ ├── subnet_val_loop.py
│ │ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── calibrate_bn_mixin.py
│ │ │ ├── check.py
│ │ │ └── genetic.py
│ ├── implementations
│ │ ├── __init__.py
│ │ ├── pruning
│ │ │ ├── __init__.py
│ │ │ ├── group_fisher
│ │ │ │ ├── __init__.py
│ │ │ │ ├── algorithm.py
│ │ │ │ ├── counters.py
│ │ │ │ ├── hook.py
│ │ │ │ ├── mutator.py
│ │ │ │ ├── ops.py
│ │ │ │ ├── prune_deploy_sub_model.py
│ │ │ │ ├── prune_sub_model.py
│ │ │ │ └── unit.py
│ │ │ └── sparse_gpt
│ │ │ │ ├── __init__.py
│ │ │ │ ├── compressor.py
│ │ │ │ ├── ops.py
│ │ │ │ ├── sparse24_utils.py
│ │ │ │ └── utils.py
│ │ └── quantization
│ │ │ └── gptq
│ │ │ ├── __init__.py
│ │ │ ├── compressor.py
│ │ │ ├── custom_autotune.py
│ │ │ ├── gptq.py
│ │ │ ├── ops.py
│ │ │ ├── quantizer.py
│ │ │ └── utils.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── algorithms
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── distill
│ │ │ │ ├── __init__.py
│ │ │ │ └── configurable
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── datafree_distillation.py
│ │ │ │ │ ├── fpn_teacher_distill.py
│ │ │ │ │ ├── overhaul_feature_distillation.py
│ │ │ │ │ ├── self_distill.py
│ │ │ │ │ └── single_teacher_distill.py
│ │ │ ├── nas
│ │ │ │ ├── __init__.py
│ │ │ │ ├── autoformer.py
│ │ │ │ ├── autoslim.py
│ │ │ │ ├── bignas.py
│ │ │ │ ├── darts.py
│ │ │ │ ├── dsnas.py
│ │ │ │ └── spos.py
│ │ │ ├── pruning
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dcff.py
│ │ │ │ ├── dmcp.py
│ │ │ │ ├── group_fisher_algoritho.py
│ │ │ │ ├── ite_prune_algorithm.py
│ │ │ │ └── slimmable_network.py
│ │ │ └── quantization
│ │ │ │ ├── __init__.py
│ │ │ │ └── mm_architecture.py
│ │ ├── architectures
│ │ │ ├── __init__.py
│ │ │ ├── backbones
│ │ │ │ ├── __init__.py
│ │ │ │ ├── darts_backbone.py
│ │ │ │ ├── searchable_autoformer.py
│ │ │ │ ├── searchable_mobilenet_v2.py
│ │ │ │ ├── searchable_mobilenet_v3.py
│ │ │ │ ├── searchable_shufflenet_v2.py
│ │ │ │ └── wideresnet.py
│ │ │ ├── classifiers
│ │ │ │ ├── __init__.py
│ │ │ │ └── image.py
│ │ │ ├── connectors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_connector.py
│ │ │ │ ├── byot_connector.py
│ │ │ │ ├── convmodule_connector.py
│ │ │ │ ├── crd_connector.py
│ │ │ │ ├── factor_transfer_connectors.py
│ │ │ │ ├── fbkd_connector.py
│ │ │ │ ├── mgd_connector.py
│ │ │ │ ├── norm_connector.py
│ │ │ │ ├── ofd_connector.py
│ │ │ │ └── torch_connector.py
│ │ │ ├── dynamic_ops
│ │ │ │ ├── __init__.py
│ │ │ │ ├── bricks
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── dynamic_container.py
│ │ │ │ │ ├── dynamic_conv.py
│ │ │ │ │ ├── dynamic_embed.py
│ │ │ │ │ ├── dynamic_function.py
│ │ │ │ │ ├── dynamic_linear.py
│ │ │ │ │ ├── dynamic_multi_head_attention.py
│ │ │ │ │ ├── dynamic_norm.py
│ │ │ │ │ ├── dynamic_relative_position.py
│ │ │ │ │ └── group_fisher_ops.py
│ │ │ │ ├── head
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── dynamic_linear_head.py
│ │ │ │ └── mixins
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── dynamic_conv_mixins.py
│ │ │ │ │ ├── dynamic_layernorm_mixins.py
│ │ │ │ │ └── dynamic_mixins.py
│ │ │ ├── generators
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_generator.py
│ │ │ │ ├── dafl_generator.py
│ │ │ │ └── zskt_generator.py
│ │ │ ├── heads
│ │ │ │ ├── __init__.py
│ │ │ │ ├── darts_subnet_head.py
│ │ │ │ └── deit_head.py
│ │ │ ├── necks
│ │ │ │ ├── __init__.py
│ │ │ │ └── squeezemean_with_dropout.py
│ │ │ ├── ops
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── common.py
│ │ │ │ ├── darts_series.py
│ │ │ │ ├── efficientnet_series.py
│ │ │ │ ├── function.py
│ │ │ │ ├── gather_tensors.py
│ │ │ │ ├── mobilenet_series.py
│ │ │ │ ├── shufflenet_series.py
│ │ │ │ └── transformer_series.py
│ │ │ └── utils
│ │ │ │ ├── __init__.py
│ │ │ │ ├── mutable_register.py
│ │ │ │ └── set_dropout.py
│ │ ├── distillers
│ │ │ ├── __init__.py
│ │ │ ├── base_distiller.py
│ │ │ ├── byot_distiller.py
│ │ │ ├── configurable_distiller.py
│ │ │ └── ofd_distiller.py
│ │ ├── fake_quants
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── lsq.py
│ │ │ └── torch_fake_quants.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── ab_loss.py
│ │ │ ├── at_loss.py
│ │ │ ├── crd_loss.py
│ │ │ ├── cross_entropy_loss.py
│ │ │ ├── cwd.py
│ │ │ ├── dafl_loss.py
│ │ │ ├── decoupled_kd.py
│ │ │ ├── dist_loss.py
│ │ │ ├── factor_transfer_loss.py
│ │ │ ├── fbkd_loss.py
│ │ │ ├── kd_soft_ce_loss.py
│ │ │ ├── kl_divergence.py
│ │ │ ├── l1_loss.py
│ │ │ ├── l2_loss.py
│ │ │ ├── mgd_loss.py
│ │ │ ├── ofd_loss.py
│ │ │ ├── pkd_loss.py
│ │ │ ├── relational_kd.py
│ │ │ └── weighted_soft_label_distillation.py
│ │ ├── mutables
│ │ │ ├── __init__.py
│ │ │ ├── base_mutable.py
│ │ │ ├── derived_mutable.py
│ │ │ ├── mutable_channel
│ │ │ │ ├── MutableChannel.md
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_mutable_channel.py
│ │ │ │ ├── mutable_channel_container.py
│ │ │ │ ├── oneshot_mutable_channel.py
│ │ │ │ ├── sequential_mutable_channel.py
│ │ │ │ ├── simple_mutable_channel.py
│ │ │ │ └── units
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── channel_unit.py
│ │ │ │ │ ├── dcff_channel_unit.py
│ │ │ │ │ ├── dmcp_channel_unit.py
│ │ │ │ │ ├── group_fisher_unit.py
│ │ │ │ │ ├── l1_mutable_channel_unit.py
│ │ │ │ │ ├── mutable_channel_unit.ipynb
│ │ │ │ │ ├── mutable_channel_unit.py
│ │ │ │ │ ├── one_shot_mutable_channel_unit.py
│ │ │ │ │ ├── sequential_mutable_channel_unit.py
│ │ │ │ │ ├── slimmable_channel_unit.py
│ │ │ │ │ └── utils.py
│ │ │ ├── mutable_module
│ │ │ │ ├── __init__.py
│ │ │ │ ├── diff_mutable_module.py
│ │ │ │ ├── mutable_module.py
│ │ │ │ └── one_shot_mutable_module.py
│ │ │ └── mutable_value
│ │ │ │ ├── __init__.py
│ │ │ │ └── mutable_value.py
│ │ ├── mutators
│ │ │ ├── __init__.py
│ │ │ ├── base_mutator.py
│ │ │ ├── channel_mutator
│ │ │ │ ├── __init__.py
│ │ │ │ ├── channel_mutator.ipynb
│ │ │ │ ├── channel_mutator.py
│ │ │ │ ├── dcff_channel_mutator.py
│ │ │ │ ├── dmcp_channel_mutator.py
│ │ │ │ ├── group_fisher_mutator.py
│ │ │ │ ├── one_shot_channel_mutator.py
│ │ │ │ └── slimmable_channel_mutator.py
│ │ │ ├── group_mixin.py
│ │ │ └── nas_mutator.py
│ │ ├── observers
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── lsq.py
│ │ │ └── torch_observers.py
│ │ ├── quantizers
│ │ │ ├── __init__.py
│ │ │ ├── academic_quantizer.py
│ │ │ ├── base.py
│ │ │ ├── exporters
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_quantize_exporter.py
│ │ │ │ ├── openvino_quantize_exporter.py
│ │ │ │ ├── optim_utils.py
│ │ │ │ └── tensorrt_quantize_exporter.py
│ │ │ ├── native_quantizer.py
│ │ │ ├── openvino_quantizer.py
│ │ │ └── tensorrt_quantizer.py
│ │ ├── task_modules
│ │ │ ├── __init__.py
│ │ │ ├── delivery
│ │ │ │ ├── __init__.py
│ │ │ │ ├── delivery_manager.py
│ │ │ │ ├── distill_delivery.py
│ │ │ │ ├── function_outputs_delivery.py
│ │ │ │ └── method_outputs_delivery.py
│ │ │ ├── demo_inputs
│ │ │ │ ├── __init__.py
│ │ │ │ ├── default_demo_inputs.py
│ │ │ │ ├── demo_inputs.py
│ │ │ │ ├── mmpose_demo_input.py
│ │ │ │ └── mmseg_demo_input.py
│ │ │ ├── estimators
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_estimator.py
│ │ │ │ ├── counters
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── flops_params_counter.py
│ │ │ │ │ ├── latency_counter.py
│ │ │ │ │ └── op_counters
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── activation_layer_counter.py
│ │ │ │ │ │ ├── base_counter.py
│ │ │ │ │ │ ├── conv_layer_counter.py
│ │ │ │ │ │ ├── deconv_layer_counter.py
│ │ │ │ │ │ ├── group_fisher_counters.py
│ │ │ │ │ │ ├── linear_layer_counter.py
│ │ │ │ │ │ ├── norm_layer_counter.py
│ │ │ │ │ │ ├── pooling_layer_counter.py
│ │ │ │ │ │ └── upsample_layer_counter.py
│ │ │ │ └── resource_estimator.py
│ │ │ ├── predictor
│ │ │ │ ├── __init__.py
│ │ │ │ ├── handler
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── base_handler.py
│ │ │ │ │ ├── carts_handler.py
│ │ │ │ │ ├── gp_handler.py
│ │ │ │ │ ├── mlp_handler.py
│ │ │ │ │ └── rbf_handler.py
│ │ │ │ └── metric_predictor.py
│ │ │ ├── recorder
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base_recorder.py
│ │ │ │ ├── function_inputs_recorder.py
│ │ │ │ ├── function_outputs_recorder.py
│ │ │ │ ├── method_inputs_recorder.py
│ │ │ │ ├── method_outputs_recorder.py
│ │ │ │ ├── module_inputs_recorder.py
│ │ │ │ ├── module_outputs_recorder.py
│ │ │ │ ├── param_recorder.py
│ │ │ │ └── recorder_manager.py
│ │ │ └── tracer
│ │ │ │ ├── __init__.py
│ │ │ │ ├── backward_tracer.py
│ │ │ │ ├── channel_analyzer.py
│ │ │ │ ├── fx
│ │ │ │ ├── __init__.py
│ │ │ │ ├── custom_tracer.py
│ │ │ │ └── graph_utils.py
│ │ │ │ ├── fx_tracer.py
│ │ │ │ ├── loss_calculator
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cascade_encoder_decoder_loss_calculator.py
│ │ │ │ ├── image_classifier_loss_calculator.py
│ │ │ │ ├── single_stage_detector_loss_calculator.py
│ │ │ │ ├── sum_loss_calculator.py
│ │ │ │ ├── top_down_pose_estimator_loss_calculator.py
│ │ │ │ └── two_stage_detector_loss_calculator.py
│ │ │ │ ├── parsers.py
│ │ │ │ └── path.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── expandable_utils
│ │ │ ├── __init__.py
│ │ │ ├── ops.py
│ │ │ ├── tools.py
│ │ │ └── unit.py
│ │ │ ├── make_divisible.py
│ │ │ ├── misc.py
│ │ │ ├── optim_wrapper.py
│ │ │ ├── parse_values.py
│ │ │ ├── quantization_util.py
│ │ │ └── utils.py
│ ├── registry
│ │ ├── __init__.py
│ │ └── registry.py
│ ├── structures
│ │ ├── __init__.py
│ │ ├── graph
│ │ │ ├── __init__.py
│ │ │ ├── base_graph.py
│ │ │ ├── channel_flow.py
│ │ │ ├── channel_graph.py
│ │ │ ├── channel_nodes.py
│ │ │ ├── module_graph.py
│ │ │ └── pseudo_fx_graph.py
│ │ ├── quantization
│ │ │ ├── __init__.py
│ │ │ ├── backend_config
│ │ │ │ ├── __init__.py
│ │ │ │ ├── academic.py
│ │ │ │ ├── common_operator_config_utils.py
│ │ │ │ ├── mapping.py
│ │ │ │ ├── native.py
│ │ │ │ ├── openvino.py
│ │ │ │ └── tensorrt.py
│ │ │ └── qconfig.py
│ │ └── subnet
│ │ │ ├── __init__.py
│ │ │ ├── candidate.py
│ │ │ └── fix_subnet.py
│ ├── testing
│ │ ├── __init__.py
│ │ ├── _fast_stop_training_hook.py
│ │ └── _fx_models.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── index_dict.py
│ │ ├── log_tools.py
│ │ ├── misc.py
│ │ ├── placeholder.py
│ │ ├── runtime_info.py
│ │ ├── setup_env.py
│ │ └── typing.py
│ ├── version.py
│ └── visualization
│ │ ├── __init__.py
│ │ └── local_visualizer.py
├── model-index.yml
├── projects
│ └── mmrazor_large
│ │ ├── README.md
│ │ ├── algorithms
│ │ ├── GPTQ.md
│ │ └── SparseGPT.md
│ │ └── examples
│ │ ├── ResNet
│ │ ├── README.md
│ │ ├── resnet18_gptq.py
│ │ └── resnet18_sparse_gpt.py
│ │ └── language_models
│ │ ├── LLaMA
│ │ ├── README.md
│ │ ├── datautils.py
│ │ ├── llama_gptq.py
│ │ ├── llama_sparse_gpt.py
│ │ ├── llama_sparse_gpt_fsdp.py
│ │ └── utils.py
│ │ └── OPT
│ │ ├── README.md
│ │ ├── datautils.py
│ │ ├── opt_gptq.py
│ │ ├── opt_sparse_gpt.py
│ │ ├── opt_sparse_gpt_fsdp.py
│ │ └── utils.py
├── requirements.txt
├── requirements
│ ├── docs.txt
│ ├── mminstall.txt
│ ├── optional.txt
│ ├── readthedocs.txt
│ ├── runtime.txt
│ └── tests.txt
├── resources
│ ├── design_and_implement.png
│ ├── mmrazor-logo.png
│ ├── qq_group_qrcode.jpg
│ ├── xiaozhushou_weixin_qrcode.jpeg
│ └── zhihu_qrcode.jpg
├── setup.cfg
├── setup.py
├── tests
│ ├── __init__.py
│ ├── data
│ │ ├── MBV2_220M.yaml
│ │ ├── MBV2_320M.yaml
│ │ ├── MBV2_530M.yaml
│ │ ├── MBV2_slimmable_channel_config.json
│ │ ├── MBV2_slimmable_config.json
│ │ ├── __init__.py
│ │ ├── color.jpeg
│ │ ├── concat_subnet1.yaml
│ │ ├── concat_subnet2.yaml
│ │ ├── dataset
│ │ │ ├── a
│ │ │ │ └── 1.JPG
│ │ │ ├── ann.json
│ │ │ ├── ann.txt
│ │ │ ├── b
│ │ │ │ ├── 2.jpeg
│ │ │ │ └── subb
│ │ │ │ │ └── 3.jpg
│ │ │ ├── classes.txt
│ │ │ └── multi_label_ann.json
│ │ ├── model_library.py
│ │ ├── models.py
│ │ ├── subnet1.yaml
│ │ ├── subnet2.yaml
│ │ ├── test_models
│ │ │ ├── test_algorithm
│ │ │ │ └── MBV2_220M.yaml
│ │ │ ├── test_mutator
│ │ │ │ └── subnet1.json
│ │ │ ├── test_subnet
│ │ │ │ └── mockmodel_subnet.yaml
│ │ │ └── test_task_modules
│ │ │ │ └── mmcls_cfg.py
│ │ ├── test_registry
│ │ │ ├── registry_architecture_config.py
│ │ │ ├── registry_subnet_config.py
│ │ │ └── subnet.json
│ │ └── tracer_passed_models.py
│ ├── test_core
│ │ ├── __init__.py
│ │ ├── test_delivers
│ │ │ ├── test_deliver_manager.py
│ │ │ ├── test_function_outputs_deliver.py
│ │ │ ├── test_method_outputs_deliver.py
│ │ │ └── toy_module.py
│ │ ├── test_graph
│ │ │ ├── __init__.py
│ │ │ ├── test_channel_flow.py
│ │ │ ├── test_channel_graph.py
│ │ │ ├── test_graph.py
│ │ │ └── test_prune_tracer_model.py
│ │ ├── test_recorders
│ │ │ ├── test_base_recorder.py
│ │ │ ├── test_func_inputs_recorder.py
│ │ │ ├── test_func_outputs_recorder.py
│ │ │ ├── test_method_inputs_recorder.py
│ │ │ ├── test_method_outputs_recorder.py
│ │ │ ├── test_module_recorders.py
│ │ │ ├── test_param_recorder.py
│ │ │ ├── test_recorder_manager.py
│ │ │ └── toy_mod.py
│ │ └── test_tracer
│ │ │ ├── __init__.py
│ │ │ ├── test_backward_tracer.py
│ │ │ ├── test_fx_tracer.py
│ │ │ ├── test_loss_calculator.py
│ │ │ └── test_prune_tracer.py
│ ├── test_data.py
│ ├── test_datasets
│ │ ├── test_datasets.py
│ │ └── test_transforms
│ │ │ └── test_formatting.py
│ ├── test_doc.py
│ ├── test_engine
│ │ └── test_hooks
│ │ │ ├── test_stop_distillation_hook.py
│ │ │ └── test_visualization_hook.py
│ ├── test_impl
│ │ ├── __init__.py
│ │ ├── test_pruning
│ │ │ ├── __init__.py
│ │ │ ├── test_group_fisher
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_algorithm.py
│ │ │ │ ├── test_prune_deploy_sub_model.py
│ │ │ │ ├── test_prune_sub_model.py
│ │ │ │ └── test_unit.py
│ │ │ └── test_sparse_gpt
│ │ │ │ └── test_op.py
│ │ └── test_quantization
│ │ │ └── test_gptq
│ │ │ └── test_op_gptq.py
│ ├── test_models
│ │ ├── __init__.py
│ │ ├── test_algorithms
│ │ │ ├── __init__.py
│ │ │ ├── test_autoformer.py
│ │ │ ├── test_autoslim.py
│ │ │ ├── test_base_algorithm.py
│ │ │ ├── test_bignas.py
│ │ │ ├── test_darts.py
│ │ │ ├── test_datafree_distill.py
│ │ │ ├── test_dcff_network.py
│ │ │ ├── test_dmcp.py
│ │ │ ├── test_dsnas.py
│ │ │ ├── test_general_quant.py
│ │ │ ├── test_mm_architecture.py
│ │ │ ├── test_ofd_algo.py
│ │ │ ├── test_prune_algorithm.py
│ │ │ ├── test_self_distill.py
│ │ │ ├── test_single_teacher_distill.py
│ │ │ ├── test_slimmable_network.py
│ │ │ ├── test_spos.py
│ │ │ └── toy_models.py
│ │ ├── test_architectures
│ │ │ ├── test_backbones
│ │ │ │ ├── test_autoformerbackbone.py
│ │ │ │ ├── test_dartsbackbone.py
│ │ │ │ ├── test_searchable_mobilenet_v2.py
│ │ │ │ ├── test_searchable_mobilenet_v3.py
│ │ │ │ ├── test_searchable_shufflenet_v2.py
│ │ │ │ └── utils.py
│ │ │ ├── test_connectors
│ │ │ │ └── test_connectors.py
│ │ │ ├── test_dynamic_op
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_bricks
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── test_dynamic_attention.py
│ │ │ │ │ ├── test_dynamic_container.py
│ │ │ │ │ ├── test_dynamic_conv.py
│ │ │ │ │ ├── test_dynamic_embed.py
│ │ │ │ │ ├── test_dynamic_layernorm.py
│ │ │ │ │ ├── test_dynamic_linear.py
│ │ │ │ │ ├── test_dynamic_norm.py
│ │ │ │ │ ├── test_dynamic_relative_position.py
│ │ │ │ │ └── test_dynamic_resizer.py
│ │ │ │ └── utils.py
│ │ │ └── test_generators
│ │ │ │ └── test_generators.py
│ │ ├── test_classifier
│ │ │ └── test_imageclassifier.py
│ │ ├── test_distillers
│ │ │ ├── test_byot_distill.py
│ │ │ └── test_configurable_distill.py
│ │ ├── test_fake_quants
│ │ │ ├── test_lsq_fake_quants.py
│ │ │ └── test_torch_fake_quants.py
│ │ ├── test_losses
│ │ │ ├── test_distillation_losses.py
│ │ │ └── test_general_losses.py
│ │ ├── test_mutables
│ │ │ ├── __init__.py
│ │ │ ├── test_derived_mutable.py
│ │ │ ├── test_diffchoiceroute.py
│ │ │ ├── test_diffop.py
│ │ │ ├── test_gumbelchoiceroute.py
│ │ │ ├── test_mutable_channel
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_mutable_channels.py
│ │ │ │ ├── test_sequential_mutable_channel.py
│ │ │ │ └── test_units
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── test_dcff_channel_unit.py
│ │ │ │ │ ├── test_l1_mutable_channel_unit.py
│ │ │ │ │ ├── test_mutable_channel_units.py
│ │ │ │ │ ├── test_one_shot_mutable_channel_unit.py
│ │ │ │ │ └── test_sequential_mutable_channel_unit.py
│ │ │ ├── test_mutable_value.py
│ │ │ ├── test_onehotop.py
│ │ │ ├── test_oneshotop.py
│ │ │ └── test_sequential_mutable_channel.py
│ │ ├── test_mutators
│ │ │ ├── __init__.py
│ │ │ ├── test_channel_mutator.py
│ │ │ ├── test_dcff_mutator.py
│ │ │ ├── test_dmcp_mutator.py
│ │ │ └── test_nas_mutator.py
│ │ ├── test_observers
│ │ │ ├── test_lsq_observer.py
│ │ │ └── test_torch_observers.py
│ │ ├── test_quantizers
│ │ │ ├── test_academic_quantizer.py
│ │ │ ├── test_exporter.py
│ │ │ ├── test_native_quantizer.py
│ │ │ ├── test_openvino_quantizer.py
│ │ │ └── test_tensorrt_quantizer.py
│ │ ├── test_subnet
│ │ │ ├── test_candidate.py
│ │ │ └── test_fix_subnet.py
│ │ ├── test_task_modules
│ │ │ ├── __init__.py
│ │ │ ├── test_custom_tracer.py
│ │ │ ├── test_demo_inputs
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_demo_inputs.py
│ │ │ ├── test_estimators
│ │ │ │ └── test_flops_params.py
│ │ │ ├── test_graph_utils.py
│ │ │ └── test_predictors
│ │ │ │ └── test_metric_predictor.py
│ │ └── test_utils
│ │ │ ├── __init__.py
│ │ │ └── test_expandable_utils
│ │ │ ├── __init__.py
│ │ │ └── test_expand.py
│ ├── test_registry
│ │ └── test_registry.py
│ ├── test_runners
│ │ ├── test_autoslim_greedy_search_loop.py
│ │ ├── test_darts_loop.py
│ │ ├── test_distill_val_loop.py
│ │ ├── test_evolution_search_loop.py
│ │ ├── test_quantization_loop.py
│ │ ├── test_subnet_sampler_loop.py
│ │ └── test_utils
│ │ │ ├── test_calibrate_bn_mixin.py
│ │ │ ├── test_check.py
│ │ │ └── test_genetic.py
│ ├── test_structures
│ │ ├── test_backendconfig.py
│ │ └── test_qconfig.py
│ ├── test_tools
│ │ ├── __init__.py
│ │ └── test_tools.py
│ ├── test_utils
│ │ ├── test_index_dict.py
│ │ └── test_placeholder.py
│ ├── test_visualizer
│ │ └── test_visualizer.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── set_dist_env.py
│ │ └── set_torch_thread.py
└── tools
│ ├── dist_test.sh
│ ├── dist_train.sh
│ ├── misc
│ └── print_config.py
│ ├── model_converters
│ ├── convert_attentivenas_nas_ckpt.py
│ ├── convert_bignas_gml_ckpt.py
│ ├── convert_kd_ckpt.py
│ ├── convert_kd_ckpt_to_student.py
│ ├── convert_ofa_ckpt.py
│ ├── convert_quant_ckpt.py
│ ├── convert_supernet2subnet.py
│ └── publish_model.py
│ ├── pruning
│ ├── get_channel_units.py
│ ├── get_flops.py
│ ├── get_l1_prune_config.py
│ └── get_static_model_from_algorithm.py
│ ├── ptq.py
│ ├── slurm_test.sh
│ ├── slurm_train.sh
│ ├── test.py
│ ├── train.py
│ └── visualizations
│ ├── demo.jpg
│ ├── feature_diff_visualization.py
│ ├── feature_visualization.py
│ ├── vis_configs
│ ├── backbone_feature_diff_visualization.py
│ ├── backbone_feature_visualization.py
│ ├── fpn_feature_diff_visualization.py
│ └── fpn_feature_visualization.py
│ └── vis_scheduler.py
├── pyproject.toml
└── test_dms.py
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Debug Tests",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "program": "${file}",
12 | "purpose": [
13 | "debug-test"
14 | ],
15 | "console": "integratedTerminal",
16 | "justMyCode": true
17 | }
18 | ]
19 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.unittestArgs": [
3 | "-v",
4 | "-s",
5 | ".",
6 | "-p",
7 | "test*.py"
8 | ],
9 | "python.testing.pytestEnabled": false,
10 | "python.testing.unittestEnabled": true
11 | }
--------------------------------------------------------------------------------
/applications/efficientnet/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | output
3 | checkpoints
--------------------------------------------------------------------------------
/applications/efficientnet/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python Debugger: Current File",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "justMyCode": false
14 | }
15 | ]
16 | }
--------------------------------------------------------------------------------
/applications/efficientnet/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.unittestArgs": [
3 | "-v",
4 | "-s",
5 | ".",
6 | "-p",
7 | "test*.py"
8 | ],
9 | "python.testing.pytestEnabled": false,
10 | "python.testing.unittestEnabled": true
11 | }
--------------------------------------------------------------------------------
/applications/efficientnet/README.md:
--------------------------------------------------------------------------------
1 | # DMS-EfficientNet
2 |
3 | ## Getting Started
4 |
5 | ```
6 | pip install -r requirements.txt
7 | ```
8 |
9 | ## Pruning and Retraining
10 |
11 | ```
12 | export Variant=DMS-450 # for example
13 | scripts/${Variant}/prune.sh
14 | scripts/${Variant}/retrain.sh
15 | ```
--------------------------------------------------------------------------------
/applications/efficientnet/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1+cu117
2 | torchvision==0.14.1+cu117
3 | torchaudio==0.13.1
4 | timm==0.9.2
5 | mmcv==2.0.1
6 | joblib
7 | mmcls==1.0.0rc6
--------------------------------------------------------------------------------
/applications/efficientnet/scripts/DMS-450/prune.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=8 --master_port 29112 timm_pruning.py ./data/imagenet_torch --model efficientnet_b4 -b 96 --sched step --epochs 4 \
2 | --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 \
3 | --drop 0.5 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 \
4 | --experiment dms_450_pruned --pin-mem --input-size 3 224 224 \
5 | --target 0.296 --mutator_lr 2e-5 --loss_weight 1 --skip_full_target
6 |
7 |
--------------------------------------------------------------------------------
/applications/efficientnet/scripts/DMS-450/retrain.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=8 --master_port 29112 \
2 | timm_retrain.py ./data/imagenet_torch --model efficientnet_b4 -b 96 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 \
3 | --experiment dms_450_retrain_distil --pin-mem --resume output/train/dms_450_retrain_distil/last.pth.tar \
4 | --pruned output/train/dms_450_prune/last.pth.tar --input-size 3 224 224 --target 0.2924 --teacher timm/efficientnet_b4 --teacher_input_image_size 320
5 |
6 |
--------------------------------------------------------------------------------
/dms/__init__.py:
--------------------------------------------------------------------------------
1 | from .dtopk_src import differentiable_topk, MASK_THRESHOLD
2 |
--------------------------------------------------------------------------------
/dms/dtopk_src.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mmengine.dist import all_reduce
4 |
5 | MASK_THRESHOLD = 0.5
6 |
7 |
8 | @torch.jit.script
9 | def _dtopk(c: torch.Tensor, a: torch.Tensor, lambda_: float = 1.0):
10 | y = (c - a) * c.numel() * lambda_
11 | y = y.sigmoid()
12 | return y
13 |
14 |
@torch.jit.script
def differentiable_topk(
    c: torch.Tensor,
    a: torch.Tensor,
    lambda_: float = 1.0,
    normalize: bool = True,
):
    """
    Differentiable top-k operator: Elements with large importance are kept.

    Args:
        c: importance score of elements
        a: pruning ratio of elements
        lambda_: hyper-parameter to control the polarity of the generated mask, default to 1.
        normalize: whether to normalize the importance score, default is True
    Returns:
        soft masks of elements
    """
    # Degenerate single-element case: there is no ranking to compute, so
    # the mask depends only on how far the pruning ratio is from 0.5.
    if c.numel() == 1:
        return (-(0.5 - a) * lambda_).sigmoid()

    score = c
    if normalize:
        # Replace raw scores with their normalized rank in [0, 1]: build the
        # [N, N] pairwise-difference matrix, then take the fraction of
        # elements each score dominates (ties and self-comparison included).
        pairwise = c.unsqueeze(-1) - c.unsqueeze(-2)
        score = (pairwise >= 0).float().mean(dim=-1)

    # Turn (rank - pruning ratio) gaps into soft masks in (0, 1).
    return _dtopk(score, a, lambda_=lambda_)
45 |
--------------------------------------------------------------------------------
/dms/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/dms/modules/__init__.py
--------------------------------------------------------------------------------
/images/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/images/compare.png
--------------------------------------------------------------------------------
/mmrazor/.circleci/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.8.1"
2 | ARG CUDA="10.2"
3 | ARG CUDNN="7"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | # To fix GPG key error when running apt-get update
8 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
10 |
11 | RUN apt-get update && apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx
12 |
--------------------------------------------------------------------------------
/mmrazor/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | We appreciate all contributions to improve MMRazor. Please refer to [CONTRIBUTING.md](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) in MMCV for more details about the contributing guideline.
2 |
--------------------------------------------------------------------------------
/mmrazor/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: '[Bug]'
5 | labels: bug
6 | assignees: ''
7 | ---
8 |
9 | ### Describe the bug
10 |
11 | A clear and concise description of what the bug is.
12 |
13 | \[here\]
14 |
15 | ### To Reproduce
16 |
17 | The command you executed.
18 |
19 | ```shell
20 | [here]
21 | ```
22 |
23 | ### Post related information
24 |
25 | 1. The output of `pip list | grep "mmcv\|mmrazor\|^torch"`
26 | \[here\]
27 | 2. Your config file if you modified it or created a new one.
28 |
29 | ```python
30 | [here]
31 | ```
32 |
33 | 3. Your train log file if you meet the problem during training.
34 | \[here\]
35 | 4. Other code you modified in the `mmrazor` folder.
36 | \[here\]
37 |
38 | ### Additional context
39 |
40 | Add any other context about the problem here.
41 |
42 | \[here\]
43 |
--------------------------------------------------------------------------------
/mmrazor/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
3 | contact_links:
4 | - name: MMRazor Documentation
5 | url: https://mmrazor.readthedocs.io/en/latest/
6 | about: Check if your question is answered in docs
7 |
--------------------------------------------------------------------------------
/mmrazor/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: '[Feature]'
5 | labels: enhancement
6 | assignees: ''
7 | ---
8 |
9 | ### Describe the feature
10 |
11 | \[here\]
12 |
13 | ### Motivation
14 |
15 | A clear and concise description of the motivation of the feature.
16 | Ex1. It is inconvenient when \[....\].
17 | Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
18 |
19 | \[here\]
20 |
21 | ### Related resources
22 |
23 | If there is an official code release or third-party implementation, please also provide the information here, which would be very helpful.
24 |
25 | \[here\]
26 |
27 | ### Additional context
28 |
29 | Add any other context or screenshots about the feature request here.
30 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
31 |
32 | \[here\]
33 |
--------------------------------------------------------------------------------
/mmrazor/.github/ISSUE_TEMPLATE/general-questions.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: General questions
3 | about: 'Ask general questions to get help '
4 | title: ''
5 | labels: help wanted
6 | assignees: ''
7 | ---
8 |
9 | ### Checklist
10 |
11 | - I have searched related issues but cannot get the expected help.
12 | - I have read related documents and don't know what to do.
13 |
14 | ### Describe the question you meet
15 |
16 | \[here\]
17 |
18 | ### Post related information
19 |
20 | 1. The output of `pip list | grep "mmcv\|mmrazor\|^torch"`
21 | \[here\]
22 | 2. Your config file if you modified it or created a new one.
23 |
24 | ```python
25 | [here]
26 | ```
27 |
28 | 3. Your train log file if you meet the problem during training.
29 | \[here\]
30 | 4. Other code you modified in the `mmrazor` folder.
31 | \[here\]
32 |
--------------------------------------------------------------------------------
/mmrazor/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
1 | name: deploy
2 |
3 | on: push
4 |
5 | concurrency:
6 | group: ${{ github.workflow }}-${{ github.ref }}
7 | cancel-in-progress: true
8 |
9 | jobs:
10 | build-n-publish:
11 | runs-on: ubuntu-latest
12 | if: startsWith(github.event.ref, 'refs/tags')
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python 3.7
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: 3.7
19 | - name: Build MMRAZOR
20 | run: |
21 | pip install wheel
22 | python setup.py sdist bdist_wheel
23 | - name: Publish distribution to PyPI
24 | run: |
25 | pip install twine
26 | twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
27 |
--------------------------------------------------------------------------------
/mmrazor/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on: [push, pull_request]
4 |
5 | concurrency:
6 | group: ${{ github.workflow }}-${{ github.ref }}
7 | cancel-in-progress: true
8 |
9 | jobs:
10 | lint:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python 3.7
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: 3.7
18 | - name: Install pre-commit hook
19 | run: |
20 | pip install pre-commit
21 | pre-commit install
22 | - name: Linting
23 | run: pre-commit run --all-files
24 | - name: Check docstring coverage
25 | run: |
26 | pip install interrogate
27 | interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmrazor
28 |
--------------------------------------------------------------------------------
/mmrazor/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | formats: all
4 |
5 | python:
6 | version: 3.7
7 | install:
8 | - requirements: requirements/docs.txt
9 | - requirements: requirements/readthedocs.txt
10 |
--------------------------------------------------------------------------------
/mmrazor/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements/*.txt
2 | include mmrazor/VERSION
3 | include mmrazor/.mim/model-index.yml
4 | include mmrazor/.mim/demo/*/*
5 | recursive-include mmrazor/.mim/configs *.py *.yml
6 | recursive-include mmrazor/.mim/tools *.sh *.py
7 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/nas_backbones/darts_supernet.py:
--------------------------------------------------------------------------------
1 | mutable_cfg = dict(
2 | type='mmrazor.DiffMutableOP',
3 | candidates=dict(
4 | zero=dict(type='mmrazor.DartsZero'),
5 | skip_connect=dict(type='mmrazor.DartsSkipConnect', use_drop_path=True),
6 | max_pool_3x3=dict(
7 | type='mmrazor.DartsPoolBN', pool_type='max', use_drop_path=True),
8 | avg_pool_3x3=dict(
9 | type='mmrazor.DartsPoolBN', pool_type='avg', use_drop_path=True),
10 | sep_conv_3x3=dict(
11 | type='mmrazor.DartsSepConv', kernel_size=3, use_drop_path=True),
12 | sep_conv_5x5=dict(
13 | type='mmrazor.DartsSepConv', kernel_size=5, use_drop_path=True),
14 | dil_conv_3x3=dict(
15 | type='mmrazor.DartsDilConv', kernel_size=3, use_drop_path=True),
16 | dil_conv_5x5=dict(
17 | type='mmrazor.DartsDilConv', kernel_size=5, use_drop_path=True)))
18 |
19 | route_cfg = dict(type='mmrazor.DiffChoiceRoute', with_arch_param=True)
20 |
21 | nas_backbone = dict(
22 | type='mmrazor.DartsBackbone',
23 | in_channels=3,
24 | base_channels=16,
25 | num_layers=8,
26 | num_nodes=4,
27 | stem_multiplier=3,
28 | out_indices=(7, ),
29 | mutable_cfg=mutable_cfg,
30 | route_cfg=route_cfg,
31 | norm_cfg=dict(type='BN', affine=False))
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/nas_backbones/dsnas_shufflenet_supernet.py:
--------------------------------------------------------------------------------
1 | norm_cfg = dict(type='BN', eps=0.01)
2 |
3 | _STAGE_MUTABLE = dict(
4 | type='mmrazor.OneHotMutableOP',
5 | fix_threshold=0.3,
6 | candidates=dict(
7 | shuffle_3x3=dict(
8 | type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
9 | shuffle_5x5=dict(
10 | type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
11 | shuffle_7x7=dict(
12 | type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
13 | shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))
14 |
15 | arch_setting = [
16 | # Parameters to build layers. 3 parameters are needed to construct a
17 | # layer, from left to right: channel, num_blocks, mutable_cfg.
18 | [64, 4, _STAGE_MUTABLE],
19 | [160, 4, _STAGE_MUTABLE],
20 | [320, 8, _STAGE_MUTABLE],
21 | [640, 4, _STAGE_MUTABLE]
22 | ]
23 |
24 | nas_backbone = dict(
25 | type='mmrazor.SearchableShuffleNetV2',
26 | widen_factor=1.0,
27 | arch_setting=arch_setting,
28 | norm_cfg=norm_cfg)
29 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/nas_backbones/spos_shufflenet_supernet.py:
--------------------------------------------------------------------------------
1 | _STAGE_MUTABLE = dict(
2 | _scope_='mmrazor',
3 | type='OneShotMutableOP',
4 | candidates=dict(
5 | shuffle_3x3=dict(type='ShuffleBlock', kernel_size=3),
6 | shuffle_5x5=dict(type='ShuffleBlock', kernel_size=5),
7 | shuffle_7x7=dict(type='ShuffleBlock', kernel_size=7),
8 | shuffle_xception=dict(type='ShuffleXception')))
9 |
10 | arch_setting = [
11 | # Parameters to build layers. 3 parameters are needed to construct a
12 | # layer, from left to right: channel, num_blocks, mutable_cfg.
13 | [64, 4, _STAGE_MUTABLE],
14 | [160, 4, _STAGE_MUTABLE],
15 | [320, 8, _STAGE_MUTABLE],
16 | [640, 4, _STAGE_MUTABLE]
17 | ]
18 |
19 | nas_backbone = dict(
20 | _scope_='mmrazor',
21 | type='SearchableShuffleNetV2',
22 | widen_factor=1.0,
23 | arch_setting=arch_setting)
24 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/settings/imagenet_bs2048_autoslim.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | './imagenet_bs1024_spos.py',
3 | ]
4 |
5 | _RandomResizedCrop_cfg = _base_.train_dataloader.dataset.pipeline[1]
6 | assert _RandomResizedCrop_cfg.type == 'RandomResizedCrop'
7 | _RandomResizedCrop_cfg.crop_ratio_range = (0.25, 1.0)
8 |
9 | optim_wrapper = dict(optimizer=dict(weight_decay=1e-4, nesterov=True))
10 |
11 | train_dataloader = dict(batch_size=256)
12 |
13 | val_dataloader = dict(batch_size=256)
14 |
15 | test_dataloader = dict(batch_size=256)
16 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/settings/imagenet_bs2048_autoslim_pil.py:
--------------------------------------------------------------------------------
1 | _base_ = 'imagenet_bs2048_autoslim.py'
2 |
3 | _RandomResizedCrop_cfg = _base_.train_dataloader.dataset.pipeline[1]
4 | assert _RandomResizedCrop_cfg.type == 'RandomResizedCrop'
5 | _RandomResizedCrop_cfg.backend = 'pillow'
6 |
7 | _ResizeEdge_cfg_val = _base_.val_dataloader.dataset.pipeline[1]
8 | assert _ResizeEdge_cfg_val.type == 'ResizeEdge'
9 | _ResizeEdge_cfg_val.backend = 'pillow'
10 |
11 | _ResizeEdge_cfg_test = _base_.test_dataloader.dataset.pipeline[1]
12 | assert _ResizeEdge_cfg_test.type == 'ResizeEdge'
13 | _ResizeEdge_cfg_test.backend = 'pillow'
14 |
--------------------------------------------------------------------------------
/mmrazor/configs/_base_/vanilla_models/wrn16_2_cifar10.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | _scope_='mmcls',
3 | type='ImageClassifier',
4 | backbone=dict(
5 | _scope_='mmrazor',
6 | type='WideResNet',
7 | depth=16,
8 | num_stages=3,
9 | widen_factor=2,
10 | ),
11 | neck=dict(type='GlobalAveragePooling'),
12 | head=dict(
13 | type='LinearClsHead',
14 | num_classes=10,
15 | in_channels=128,
16 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
17 | topk=(1, 5),
18 | ))
19 |
20 | find_unused_parameters = True
21 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmcls/byot/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: BYOT
3 | Metadata:
4 | Training Data:
5 | - CIFAR100
6 | Paper:
7 | URL: https://arxiv.org/pdf/2107.06916.pdf
8 | Title: Training Compact CNNs for Image Classification using Dynamic-coded Filter Fusion
9 | README: configs/distill/mmcls/byot/README.md
10 | Converted From:
11 | Code:
12 | URL: https://github.com/luanyunteng/pytorch-be-your-own-teacher
13 | Models:
14 | - Name: byot_resnet18_8xb16_cifar100
15 | In Collection: BYOT
16 | Metadata:
17 | inference time (ms/im):
18 | - value: 0.62
19 | hardware: V100
20 | backend: PyTorch
21 | batch size: 16
22 | mode: FP32
23 | resolution: (32, 32)
24 | Results:
25 | - Task: Classification
26 | Dataset: CIFAR100
27 | Metrics:
28 | Top 1 Accuracy: 80.66
29 | Top 5 Accuracy: 95.76
30 | Weights: https://download.openmmlab.com/mmrazor/v1/byot/byot_resnet18_8xb16_cifar100_20220817_191217-0251084e.pth
31 | Config: configs/distill/mmcls/byot/byot_resnet18_8xb16_cifar100.py
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmcls/factor_transfer/factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_train.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | './factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain.py'
3 | ]
4 |
5 | train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
6 |
7 | model = dict(
8 | calculate_student_loss=True,
9 | student_trainable=True,
10 | distiller=dict(
11 | distill_losses=dict(loss_s4=dict(type='FTLoss', loss_weight=1.0)),
12 | connectors=dict(loss_s4_tfeat=dict(phase='train')),
13 | loss_forward_mappings=dict(
14 | _delete_=True,
15 | loss_s4=dict(
16 | s_feature=dict(
17 | from_student=True,
18 | recorder='bb_s4',
19 | connector='loss_s4_sfeat'),
20 | t_feature=dict(
21 | from_student=False,
22 | recorder='bb_s4',
23 | connector='loss_s4_tfeat'),
24 | ))),
25 | init_cfg=dict(
26 | type='Pretrained',
27 | checkpoint= # noqa: E251
28 | 'https://download.openmmlab.com/mmrazor/v1/factor_transfer/factor-transfer_backbone_resnet50_resnet18_8xb16_cifar10_pretrain_20220831_173259-ebdb09e2.pth' # noqa: E501
29 | ))
30 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./cwd_fpn_retina_r101_retina_r50_1x_coco.py']
2 |
3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth' # noqa: E501
4 | model = dict(
5 | architecture=dict(
6 | cfg_path='mmdet::gfl/gfl_r50_fpn_1x_coco.py', pretrained=False),
7 | teacher=dict(
8 | cfg_path='mmdet::gfl/gfl_r101_fpn_ms-2x_coco.py', pretrained=True),
9 | teacher_ckpt=teacher_ckpt)
10 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./cwd_fpn_frcnn_r101_frcnn_r50_1x_coco.py']
2 |
3 | model = dict(
4 | architecture=dict(
5 | cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py',
6 | pretrained=False),
7 | teacher=dict(
8 | cfg_path='mmdet::retinanet/retinanet_r101_fpn_2x_coco.py',
9 | pretrained=True))
10 |
11 | # optimizer
12 | optim_wrapper = dict(
13 | optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
14 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/cwd/cwd_fpn_retina_r101_retina_r50_1x_coco_visualization.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./cwd_fpn_retina_r101_retina_r50_1x_coco.py']
2 |
3 | default_hooks = dict(
4 | checkpoint=dict(type='CheckpointHook', interval=-1),
5 | visualization=dict(
6 | _scope_='mmrazor',
7 | type='RazorVisualizationHook',
8 | enabled=True,
9 | recorders=dict(
10 | # todo: Maybe it is hard for users to understand why to add a
11 | # prefix `architecture.`
12 | neck=dict(
13 | _scope_='mmrazor',
14 | type='ModuleOutputs',
15 | source='architecture.neck')),
16 | mappings=dict(
17 | p3=dict(recorder='neck', data_idx=0),
18 | p4=dict(recorder='neck', data_idx=1),
19 | p5=dict(recorder='neck', data_idx=2),
20 | p6=dict(recorder='neck', data_idx=3)),
21 | out_dir='retina_vis'))
22 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/cwd/metafile.yml:
--------------------------------------------------------------------------------
1 |
2 | Models:
3 | - Name: cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco
4 | In Collection: CWD
5 | Metadata:
6 | Location: cls head
7 | Student:
8 | Metrics:
9 | box AP: 40.2
10 | Config: mmdet::gfl/gfl_r50_fpn_1x_coco.py
11 | Weights: https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r50_fpn_1x_coco/gfl_r50_fpn_1x_coco_20200629_121244-25944287.pth
12 | Teacher:
13 | Metrics:
14 | box AP: 44.7
15 |       Config: mmdet::gfl/gfl_r101_fpn_mstrain_2x_coco.py
16 | Weights: https://download.openmmlab.com/mmdetection/v2.0/gfl/gfl_r101_fpn_mstrain_2x_coco/gfl_r101_fpn_mstrain_2x_coco_20200629_200126-dd12f847.pth
17 | Results:
18 | - Task: Object Detection
19 | Dataset: COCO
20 | Metrics:
21 | box AP: 41.9
22 | Config: configs/distill/mmdet/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco.py
23 | Weights: https://download.openmmlab.com/mmrazor/v1/cwd/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco/cwd_cls_head_gfl_r101_fpn_gfl_r50_fpn_1x_coco_20211222-c134bb21.pth
24 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py']
2 |
3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth' # noqa: E501
4 |
5 | model = dict(
6 | architecture=dict(
7 | cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'),
8 | teacher=dict(
9 | cfg_path= # noqa: E251
10 | 'mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py'),
11 | teacher_ckpt=teacher_ckpt)
12 |
13 | # training schedule for 1x
14 | train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
15 |
16 | # learning rate
17 | param_scheduler = [
18 | dict(
19 | type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
20 | dict(
21 | type='MultiStepLR',
22 | begin=0,
23 | end=12,
24 | by_epoch=True,
25 | milestones=[8, 11],
26 | gamma=0.1)
27 | ]
28 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py']
2 |
3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth' # noqa: E501
4 |
5 | model = dict(
6 | architecture=dict(
7 | cfg_path= # noqa: E251
8 | 'mmdet::reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py'),
9 | teacher=dict(
10 | cfg_path= # noqa: E251
11 | 'mmdet::reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py' # noqa: E501
12 | ),
13 | teacher_ckpt=teacher_ckpt)
14 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py']
2 |
3 | teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth' # noqa: E501
4 |
5 | model = dict(
6 | architecture=dict(
7 | cfg_path='mmdet::retinanet/retinanet_r50_fpn_2x_coco.py'),
8 | teacher=dict(
9 | cfg_path='mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py'),
10 | teacher_ckpt=teacher_ckpt,
11 | distiller=dict(
12 | distill_losses=dict(
13 | loss_pkd_fpn0=dict(loss_weight=10),
14 | loss_pkd_fpn1=dict(loss_weight=10),
15 | loss_pkd_fpn2=dict(loss_weight=10),
16 | loss_pkd_fpn3=dict(loss_weight=10))))
17 |
18 | # optimizer
19 | optim_wrapper = dict(optimizer=dict(lr=0.01))
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/distill/mmdet3d/pkd/metafile.yml:
--------------------------------------------------------------------------------
1 | Models:
2 | - Name: pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d
3 | In Collection: PKD
4 | Metadata:
5 | Location: FPN
6 | Student:
7 | Metrics:
8 | box AP: 26.8
9 | Config:
10 | Weights:
11 | Teacher:
12 | Metrics:
13 | box AP: 32.1
14 | Config: mmdet3d::fcos3d/fcos3d_r101-caffe-dcn_fpn_head-gn_8xb2-1x_nus-mono3d_finetune.py
15 | Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d_finetune_20210717_095645-8d806dc2.pth
16 | Results:
17 |   - Task: 3D Object Detection
18 |     Dataset: nuScenes
19 | Metrics:
20 | box AP: 29.3
21 | Config: configs/distill/mmdet3d/pkd/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d.py
22 | Weights: https://download.openmmlab.com/mmrazor/v1/pkd/pkd_fcos3d_w10/pkd_fpn_fcos3d_r101_fcos3d_r50_8xb2-1x_nus-mono3d_20220928_234557-0b51b62e.json?versionId=CAEQThiBgIDrvdC0oBgiIDNmNGNkNDZhM2RmNjQ1MmI4ZDM0OGNmYmFkYjk5ZjFi
23 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoformer/autoformer_search_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./autoformer_supernet_32xb256_in1k.py']
2 |
3 | custom_hooks = None
4 |
5 | train_cfg = dict(
6 | _delete_=True,
7 | type='mmrazor.EvolutionSearchLoop',
8 | dataloader=_base_.val_dataloader,
9 | evaluator=_base_.val_evaluator,
10 | max_epochs=20,
11 | num_candidates=20,
12 | top_k=10,
13 | num_mutation=5,
14 | num_crossover=5,
15 | mutate_prob=0.2,
16 | constraints_range=dict(params=(0, 55)),
17 | score_key='accuracy/top1')
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoformer/autoformer_subnet_8xb256_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = 'autoformer_supernet_32xb256_in1k.py'
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet='configs/nas/mmcls/autoformer/AUTOFORMER_SUBNET_B.yaml',
9 | # You can also load the checkpoint of supernet instead of the specific
10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg`
11 | # with `init_weight_from_supernet = True`.
12 | init_weight_from_supernet=False,
13 | init_cfg=dict(
14 | type='Pretrained',
15 | checkpoint= # noqa: E251
16 | 'https://download.openmmlab.com/mmrazor/v1/autoformer/autoformer_supernet_32xb256_in1k_20220919_110144-c658ce8f.pth', # noqa: E501
17 | prefix='architecture.'))
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_search_8xb256_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./autoslim_mbv2_1.5x_supernet_8xb256_in1k.py']
2 |
3 | model = dict(bn_training_mode=True)
4 |
5 | train_cfg = None
6 | optim_wrapper = None
7 | param_scheduler = None
8 | train_dataloader = None
9 |
10 | val_cfg = None
11 | val_dataloader = None
12 | val_evaluator = None
13 |
14 | test_cfg = dict(
15 | _delete_=True,
16 | type='mmrazor.AutoSlimGreedySearchLoop',
17 | dataloader=_base_.test_dataloader,
18 | evaluator=_base_.test_evaluator,
19 | target_flops=(500., 300., 200.))
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-220M.py:
--------------------------------------------------------------------------------
1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'
2 |
3 | model = dict(deploy_index=0)
4 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-320M.py:
--------------------------------------------------------------------------------
1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'
2 |
3 | model = dict(deploy_index=1)
4 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/autoslim/autoslim_mbv2_1.5x_subnet_8xb256_in1k_flops-530M.py:
--------------------------------------------------------------------------------
1 | _base_ = 'autoslim_mbv2_1.5x_slimmable_subnet_8xb256_in1k.py'
2 |
3 | model = dict(deploy_index=2)
4 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/bignas/attentive_mobilenet_search_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./attentive_mobilenet_supernet_32xb64_in1k.py']
2 |
3 | train_cfg = dict(
4 | _delete_=True,
5 | type='mmrazor.EvolutionSearchLoop',
6 | dataloader=_base_.val_dataloader,
7 | evaluator=_base_.val_evaluator,
8 | max_epochs=20,
9 | num_candidates=50,
10 | top_k=10,
11 | num_mutation=25,
12 | num_crossover=25,
13 | mutate_prob=0.1,
14 | calibrate_sample_num=4096,
15 | constraints_range=dict(flops=(0., 700.)),
16 | score_key='accuracy/top1')
17 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/bignas/attentive_mobilenet_subnet_8xb256_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = 'attentive_mobilenet_supernet_32xb64_in1k.py'
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet='configs/nas/mmcls/bignas/ATTENTIVE_SUBNET_A0.yaml',
9 | # You can load the checkpoint of supernet instead of the specific
10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg`
11 | # with `init_weight_from_supernet = True`.
12 | init_weight_from_supernet=True,
13 | init_cfg=dict(
14 | type='Pretrained',
15 | checkpoint= # noqa: E251
16 | 'https://download.openmmlab.com/mmrazor/v1/bignas/attentive_mobilenet_supernet_32xb64_in1k_flops-2G_acc-81.72_20221229_200440-954772a3.pth', # noqa: E501
17 | prefix='architecture.'))
18 |
19 | model_wrapper_cfg = None
20 | find_unused_parameters = True
21 |
22 | test_cfg = dict(evaluate_fixed_subnet=True)
23 |
24 | default_hooks = dict(checkpoint=None)
25 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/darts/darts_supernet_unroll_1xb96_cifar10.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmrazor::_base_/settings/cifar10_darts_supernet.py',
3 | 'mmrazor::_base_/nas_backbones/darts_supernet.py',
4 | 'mmcls::_base_/default_runtime.py',
5 | ]
6 |
7 | custom_hooks = [
8 | dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
9 | ]
10 |
11 | # model
12 | model = dict(
13 | type='mmrazor.Darts',
14 | architecture=dict(
15 | type='ImageClassifier',
16 | backbone=_base_.nas_backbone,
17 | neck=dict(type='GlobalAveragePooling'),
18 | head=dict(
19 | type='LinearClsHead',
20 | num_classes=10,
21 | in_channels=256,
22 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
23 | topk=(1, 5),
24 | cal_acc=True)),
25 | mutator=dict(type='mmrazor.NasMutator'),
26 | unroll=True)
27 |
28 | model_wrapper_cfg = dict(
29 | type='mmrazor.DartsDDP',
30 | broadcast_buffers=False,
31 | find_unused_parameters=False)
32 |
33 | find_unused_parameters = False
34 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/darts/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: Darts
3 | Metadata:
4 | Training Data:
5 | - CIFAR-10
6 | Paper:
7 | URL: https://arxiv.org/abs/1806.09055
8 |     Title: 'DARTS: Differentiable Architecture Search'
9 | README: configs/nas/mmcls/darts/README.md
10 | Code:
11 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/darts.py
12 | Version: v0.1.0
13 | Converted From:
14 | Code: https://github.com/quark0/darts
15 | Models:
16 | - Name: darts_subnet_1xb96_cifar10_2.0
17 | In Collection: Darts
18 | Metadata:
19 | Params(M): 3.42
20 | Mutable: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_mutable_cfg.yaml
21 | Results:
22 | - Task: Image Classification
23 | Dataset: CIFAR-10
24 | Metrics:
25 | Top 1 Accuracy: 97.32
26 | Top 5 Accuracy: 99.94
27 | Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py
28 | Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_latest.pth
29 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml:
--------------------------------------------------------------------------------
1 | backbone.layers.0.0:
2 | chosen: shuffle_3x3
3 | backbone.layers.0.1:
4 | chosen: shuffle_7x7
5 | backbone.layers.0.2:
6 | chosen: shuffle_3x3
7 | backbone.layers.0.3:
8 | chosen: shuffle_5x5
9 | backbone.layers.1.0:
10 | chosen: shuffle_3x3
11 | backbone.layers.1.1:
12 | chosen: shuffle_3x3
13 | backbone.layers.1.2:
14 | chosen: shuffle_3x3
15 | backbone.layers.1.3:
16 | chosen: shuffle_7x7
17 | backbone.layers.2.0:
18 | chosen: shuffle_xception
19 | backbone.layers.2.1:
20 | chosen: shuffle_3x3
21 | backbone.layers.2.2:
22 | chosen: shuffle_3x3
23 | backbone.layers.2.3:
24 | chosen: shuffle_5x5
25 | backbone.layers.2.4:
26 | chosen: shuffle_3x3
27 | backbone.layers.2.5:
28 | chosen: shuffle_5x5
29 | backbone.layers.2.6:
30 | chosen: shuffle_7x7
31 | backbone.layers.2.7:
32 | chosen: shuffle_7x7
33 | backbone.layers.3.0:
34 | chosen: shuffle_xception
35 | backbone.layers.3.1:
36 | chosen: shuffle_3x3
37 | backbone.layers.3.2:
38 | chosen: shuffle_7x7
39 | backbone.layers.3.3:
40 | chosen: shuffle_3x3
41 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/dsnas/dsnas_subnet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./dsnas_supernet_8xb128_in1k.py']
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet= # noqa: E251
9 | 'configs/nas/mmcls/dsnas/DSNAS_SUBNET_IMAGENET_PAPER_ALIAS.yaml'
10 | ) # noqa: E501
11 |
12 | find_unused_parameters = False
13 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/dsnas/dsnas_supernet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmrazor::_base_/settings/imagenet_bs1024_dsnas.py',
3 | 'mmrazor::_base_/nas_backbones/dsnas_shufflenet_supernet.py',
4 | 'mmcls::_base_/default_runtime.py',
5 | ]
6 |
7 | custom_hooks = [
8 | dict(type='mmrazor.DumpSubnetHook', interval=10, by_epoch=True)
9 | ]
10 |
11 | supernet = dict(
12 | _scope_='mmcls',
13 | type='ImageClassifier',
14 | data_preprocessor=_base_.data_preprocessor,
15 | backbone=_base_.nas_backbone,
16 | neck=dict(type='GlobalAveragePooling'),
17 | head=dict(
18 | type='LinearClsHead',
19 | num_classes=1000,
20 | in_channels=1024,
21 | loss=dict(
22 | type='LabelSmoothLoss',
23 | num_classes=1000,
24 | label_smooth_val=0.1,
25 | mode='original',
26 | loss_weight=1.0),
27 | topk=(1, 5)))
28 |
29 | # model
30 | model = dict(
31 | type='mmrazor.DSNAS',
32 | architecture=supernet,
33 | mutator=dict(type='mmrazor.NasMutator'),
34 | pretrain_epochs=15,
35 | finetune_epochs=_base_.search_epochs,
36 | )
37 |
38 | model_wrapper_cfg = dict(
39 | type='mmrazor.DSNASDDP',
40 | broadcast_buffers=False,
41 | find_unused_parameters=True)
42 |
43 | randomness = dict(seed=48, diff_rank_seed=True)
44 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/onceforall/ofa_mobilenet_search_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./ofa_mobilenet_supernet_32xb64_in1k.py']
2 |
3 | train_cfg = dict(
4 | _delete_=True,
5 | type='mmrazor.EvolutionSearchLoop',
6 | dataloader=_base_.val_dataloader,
7 | evaluator=_base_.val_evaluator,
8 | max_epochs=1,
9 | num_candidates=2,
10 | top_k=1,
11 | num_mutation=1,
12 | num_crossover=1,
13 | mutate_prob=0.1,
14 | calibrate_sample_num=4096,
15 | constraints_range=dict(flops=(0., 700.)),
16 | score_key='accuracy/top1')
17 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/onceforall/ofa_mobilenet_subnet_8xb256_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = 'ofa_mobilenet_supernet_32xb64_in1k.py'
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet='configs/nas/mmcls/onceforall/OFA_SUBNET_NOTE8_LAT31.yaml',
9 | # You can also load the checkpoint of supernet instead of the specific
10 | # subnet by modifying the `checkpoint`(path) in the following `init_cfg`
11 | # with `init_weight_from_supernet = True`.
12 | init_weight_from_supernet=False,
13 | init_cfg=dict(
14 | type='Pretrained',
15 | checkpoint= # noqa: E251
16 | 'https://download.openmmlab.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4031ms_top1%4072.8_finetune%4025.py_20221214_0939-981a8b2a.pth', # noqa: E501
17 | prefix='architecture.'))
18 |
19 | model_wrapper_cfg = None
20 | find_unused_parameters = True
21 |
22 | test_cfg = dict(evaluate_fixed_subnet=True)
23 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/SPOS_SUBNET.yaml:
--------------------------------------------------------------------------------
1 | backbone.layers.0.0:
2 | chosen: shuffle_7x7
3 | backbone.layers.0.1:
4 | chosen: shuffle_3x3
5 | backbone.layers.0.2:
6 | chosen: shuffle_7x7
7 | backbone.layers.0.3:
8 | chosen: shuffle_3x3
9 | backbone.layers.1.0:
10 | chosen: shuffle_xception
11 | backbone.layers.1.1:
12 | chosen: shuffle_5x5
13 | backbone.layers.1.2:
14 | chosen: shuffle_5x5
15 | backbone.layers.1.3:
16 | chosen: shuffle_3x3
17 | backbone.layers.2.0:
18 | chosen: shuffle_3x3
19 | backbone.layers.2.1:
20 | chosen: shuffle_5x5
21 | backbone.layers.2.2:
22 | chosen: shuffle_3x3
23 | backbone.layers.2.3:
24 | chosen: shuffle_5x5
25 | backbone.layers.2.4:
26 | chosen: shuffle_3x3
27 | backbone.layers.2.5:
28 | chosen: shuffle_xception
29 | backbone.layers.2.6:
30 | chosen: shuffle_5x5
31 | backbone.layers.2.7:
32 | chosen: shuffle_7x7
33 | backbone.layers.3.0:
34 | chosen: shuffle_7x7
35 | backbone.layers.3.1:
36 | chosen: shuffle_3x3
37 | backbone.layers.3.2:
38 | chosen: shuffle_5x5
39 | backbone.layers.3.3:
40 | chosen: shuffle_xception
41 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/faster-rcnn_nas_backbone_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | # Suppose you are in mmdet and want to use the searched subnet
2 | # as backbone for faster-rcnn, then you can just use this config.
3 |
4 | _base_ = [
5 | '../_base_/models/faster-rcnn_r50_fpn.py',
6 | '../_base_/datasets/coco_detection.py',
7 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
8 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py'
9 | ]
10 |
11 | _base_.nas_backbone.out_indices = (0, 1, 2, 3)
12 | _base_.nas_backbone.with_last_layer = False
13 | nas_backbone = dict(
14 | # use mmrazor's build_func
15 | type='mmrazor.sub_model',
16 | cfg=_base_.nas_backbone,
17 | fix_subnet='/path/to/your/mmrazor/configs/nas/mmcls/spos/SPOS_SUBNET.yaml',
18 | extra_prefix='backbone.')
19 |
20 | _base_.model.backbone = nas_backbone
21 | _base_.model.neck.in_channels = [64, 160, 320, 640]
22 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: SPOS
3 | Metadata:
4 | Training Data:
5 | - ImageNet-1k
6 | Paper:
7 | URL: https://arxiv.org/abs/1904.00420
8 | Title: Single Path One-Shot Neural Architecture Search with Uniform Sampling
9 | README: configs/nas/mmcls/spos/README.md
10 | Code:
11 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/spos.py
12 | Version: v0.1.0
13 | Converted From:
14 | Code: https://github.com/megvii-model/SinglePathOneShot
15 | Models:
16 | - Name: spos_shufflenet_subnet_8xb128_in1k
17 | In Collection: SPOS
18 | Metadata:
19 | FLOPs: 330 MB
20 | Subnet: https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_subnet_cfg_v3.yaml
21 | Results:
22 | - Task: Image Classification
23 | Dataset: ImageNet-1k
24 | Metrics:
25 | Top 1 Accuracy: 73.87
26 | Top 5 Accuracy: 91.60
27 | Config: configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py
28 | Weights: https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth
29 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_mobilenet_search_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./spos_mobilenet_supernet_8xb128_in1k.py']
2 |
3 | model = dict(norm_training=True)
4 |
5 | train_cfg = dict(
6 | _delete_=True,
7 | type='mmrazor.EvolutionSearchLoop',
8 | dataloader=_base_.val_dataloader,
9 | evaluator=_base_.val_evaluator,
10 | max_epochs=20,
11 | num_candidates=50,
12 | top_k=10,
13 | num_mutation=25,
14 | num_crossover=25,
15 | mutate_prob=0.1,
16 | constraints_range=dict(flops=(0., 465.)),
17 | score_key='accuracy/top1')
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_mobilenet_subnet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./spos_mobilenet_supernet_8xb128_in1k.py']
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 |     fix_subnet='configs/nas/mmcls/spos/AngleNAS_SHUFFLENETV2_IN1k_2.0.yaml',
9 | init_cfg=dict(
10 | type='Pretrained',
11 | checkpoint= # noqa: E251
12 | 'https://download.openmmlab.com/mmrazor/v1/spos/spos_mobilenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth', # noqa: E501
13 | prefix='architecture.'))
14 |
15 | model_wrapper_cfg = None
16 |
17 | find_unused_parameters = False
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_mobilenet_supernet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmrazor::_base_/settings/imagenet_bs1024_spos.py',
3 | 'mmrazor::_base_/nas_backbones/spos_mobilenet_supernet.py',
4 | 'mmcls::_base_/default_runtime.py',
5 | ]
6 |
7 | # model
8 | supernet = dict(
9 | _scope_='mmcls',
10 | type='ImageClassifier',
11 | # data_preprocessor=_base_.preprocess_cfg,
12 | backbone=_base_.nas_backbone,
13 | neck=dict(type='GlobalAveragePooling'),
14 | head=dict(
15 | type='LinearClsHead',
16 | num_classes=1000,
17 | in_channels=1728,
18 | loss=dict(
19 | type='LabelSmoothLoss',
20 | num_classes=1000,
21 | label_smooth_val=0.1,
22 | mode='original',
23 | loss_weight=1.0),
24 | topk=(1, 5)))
25 |
26 | model = dict(
27 | type='mmrazor.SPOS',
28 | architecture=supernet,
29 | mutator=dict(type='mmrazor.NasMutator'))
30 |
31 | find_unused_parameters = True
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_shufflenet_search_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
2 |
3 | model = dict(norm_training=True)
4 |
5 | train_cfg = dict(
6 | _delete_=True,
7 | type='mmrazor.EvolutionSearchLoop',
8 | dataloader=_base_.val_dataloader,
9 | evaluator=_base_.val_evaluator,
10 | max_epochs=20,
11 | num_candidates=50,
12 | top_k=10,
13 | num_mutation=25,
14 | num_crossover=25,
15 | mutate_prob=0.1,
16 | constraints_range=dict(flops=(0, 330)),
17 | score_key='accuracy/top1')
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_shufflenet_search_predictor_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
2 |
3 | model = dict(norm_training=True)
4 |
5 | train_cfg = dict(
6 | _delete_=True,
7 | type='mmrazor.EvolutionSearchLoop',
8 | dataloader=_base_.val_dataloader,
9 | evaluator=_base_.val_evaluator,
10 | max_epochs=20,
11 | num_candidates=50,
12 | top_k=10,
13 | num_mutation=25,
14 | num_crossover=25,
15 | mutate_prob=0.1,
16 | constraints_range=dict(flops=(0., 360.)),
17 | predictor_cfg=dict(
18 | type='mmrazor.MetricPredictor',
19 | train_samples=20,
20 | handler_cfg=dict(type='mmrazor.GaussProcessHandler')),
21 | )
22 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_shufflenet_subnet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./spos_shufflenet_supernet_8xb128_in1k.py']
2 |
3 | _base_.model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet='configs/nas/mmcls/spos/SPOS_SUBNET.yaml',
9 | init_cfg=dict(
10 | type='Pretrained',
11 | checkpoint= # noqa: E251
12 | 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth', # noqa: E501
13 | prefix='architecture.'))
14 |
15 | model_wrapper_cfg = None
16 |
17 | find_unused_parameters = False
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmcls/spos/spos_shufflenet_supernet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmrazor::_base_/settings/imagenet_bs1024_spos.py',
3 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py',
4 | 'mmcls::_base_/default_runtime.py',
5 | ]
6 |
7 | # model
8 | supernet = dict(
9 | _scope_='mmcls',
10 | type='ImageClassifier',
11 | data_preprocessor=_base_.preprocess_cfg,
12 | backbone=_base_.nas_backbone,
13 | neck=dict(type='GlobalAveragePooling'),
14 | head=dict(
15 | type='LinearClsHead',
16 | num_classes=1000,
17 | in_channels=1024,
18 | loss=dict(
19 | type='LabelSmoothLoss',
20 | num_classes=1000,
21 | label_smooth_val=0.1,
22 | mode='original',
23 | loss_weight=1.0),
24 | topk=(1, 5)))
25 |
26 | model = dict(
27 | type='mmrazor.SPOS',
28 | architecture=supernet,
29 | mutator=dict(type='mmrazor.NasMutator'))
30 |
31 | find_unused_parameters = True
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml:
--------------------------------------------------------------------------------
1 | backbone.layers.0.0:
2 | chosen: shuffle_5x5
3 | backbone.layers.0.1:
4 | chosen: shuffle_3x3
5 | backbone.layers.0.2:
6 | chosen: shuffle_3x3
7 | backbone.layers.0.3:
8 | chosen: shuffle_3x3
9 | backbone.layers.1.0:
10 | chosen: shuffle_xception
11 | backbone.layers.1.1:
12 | chosen: shuffle_3x3
13 | backbone.layers.1.2:
14 | chosen: shuffle_xception
15 | backbone.layers.1.3:
16 | chosen: shuffle_7x7
17 | backbone.layers.2.0:
18 | chosen: shuffle_7x7
19 | backbone.layers.2.1:
20 | chosen: shuffle_7x7
21 | backbone.layers.2.2:
22 | chosen: shuffle_xception
23 | backbone.layers.2.3:
24 | chosen: shuffle_xception
25 | backbone.layers.2.4:
26 | chosen: shuffle_3x3
27 | backbone.layers.2.5:
28 | chosen: shuffle_7x7
29 | backbone.layers.2.6:
30 | chosen: shuffle_5x5
31 | backbone.layers.2.7:
32 | chosen: shuffle_xception
33 | backbone.layers.3.0:
34 | chosen: shuffle_7x7
35 | backbone.layers.3.1:
36 | chosen: shuffle_7x7
37 | backbone.layers.3.2:
38 | chosen: shuffle_7x7
39 | backbone.layers.3.3:
40 | chosen: shuffle_5x5
41 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_search_coco_1x.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py']
2 |
3 | model = dict(norm_training=True)
4 |
5 | train_cfg = dict(
6 | _delete_=True,
7 | type='mmrazor.EvolutionSearchLoop',
8 | dataloader=_base_.val_dataloader,
9 | evaluator=_base_.val_evaluator,
10 | max_epochs=20,
11 | num_candidates=50,
12 | top_k=10,
13 | num_mutation=20,
14 | num_crossover=20,
15 | mutate_prob=0.1,
16 | constraints_range=dict(flops=(0, 330)),
17 | score_key='coco/bbox_mAP')
18 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./detnas_frcnn_shufflenet_supernet_coco_1x.py']
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet='configs/nas/mmdet/detnas/DETNAS_SUBNET.yaml',
9 | init_cfg=dict(
10 | type='Pretrained',
11 | checkpoint= # noqa: E251
12 | 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501
13 | prefix='architecture.'))
14 |
15 | find_unused_parameters = False
16 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_supernet_coco_1x.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmdet::_base_/models/faster-rcnn_r50_fpn.py',
3 | 'mmdet::_base_/datasets/coco_detection.py',
4 | 'mmdet::_base_/schedules/schedule_1x.py',
5 | 'mmdet::_base_/default_runtime.py',
6 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py'
7 | ]
8 |
9 | norm_cfg = dict(type='SyncBN', requires_grad=True)
10 |
11 | supernet = _base_.model
12 |
13 | supernet.backbone = _base_.nas_backbone
14 | supernet.backbone.norm_cfg = norm_cfg
15 | supernet.backbone.out_indices = (0, 1, 2, 3)
16 | supernet.backbone.with_last_layer = False
17 |
18 | supernet.neck.norm_cfg = norm_cfg
19 | supernet.neck.in_channels = [64, 160, 320, 640]
20 |
21 | supernet.roi_head.bbox_head.norm_cfg = norm_cfg
22 | supernet.roi_head.bbox_head.type = 'Shared4Conv1FCBBoxHead'
23 |
24 | model = dict(
25 | _delete_=True,
26 | type='mmrazor.SPOS',
27 | architecture=supernet,
28 | mutator=dict(type='mmrazor.NasMutator'))
29 |
30 | find_unused_parameters = True
31 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_retina_shufflenet_supernet_coco_1x.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | 'mmdet::_base_/models/retinanet_r50_fpn.py',
3 | 'mmdet::_base_/datasets/coco_detection.py',
4 | 'mmdet::_base_/schedules/schedule_1x.py',
5 | 'mmdet::_base_/default_runtime.py',
6 | 'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py'
7 | ]
8 |
9 | norm_cfg = dict(type='SyncBN', requires_grad=True)
10 |
11 | supernet = _base_.model
12 |
13 | supernet.backbone = _base_.nas_backbone
14 | supernet.backbone.norm_cfg = norm_cfg
15 | supernet.backbone.out_indices = (0, 1, 2, 3)
16 | supernet.backbone.with_last_layer = False
17 |
18 | supernet.neck.norm_cfg = norm_cfg
19 | supernet.neck.in_channels = [64, 160, 320, 640]
20 |
21 | model = dict(
22 | _delete_=True,
23 | type='mmrazor.SPOS',
24 | architecture=supernet,
25 | mutator=dict(type='mmrazor.NasMutator'))
26 |
27 | find_unused_parameters = True
28 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_shufflenet_subnet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = './detnas_shufflenet_supernet_8xb128_in1k.py'
2 |
3 | model = dict(
4 | _scope_='mmrazor',
5 | type='sub_model',
6 | cfg=_base_.supernet,
7 | # NOTE: You can replace the yaml with the mutable_cfg searched by yourself
8 | fix_subnet= # noqa: E251
9 | 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml' # noqa: E501
10 | )
11 |
12 | find_unused_parameters = False
13 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/detnas_shufflenet_supernet_8xb128_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = 'mmrazor::nas/mmcls/spos/spos_shufflenet_supernet_8xb128_in1k.py'  # noqa: E501
2 |
--------------------------------------------------------------------------------
/mmrazor/configs/nas/mmdet/detnas/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: DetNAS
3 | Metadata:
4 | Training Data:
5 | - ImageNet-1k
6 | - COCO
7 | Paper:
8 | URL: https://arxiv.org/abs/1903.10979
9 |     Title: 'DetNAS: Backbone Search for Object Detection'
10 | README: configs/nas/mmdet/detnas/README.md
11 | Code:
12 | URL: https://github.com/open-mmlab/mmrazor/blob/v0.1.0/mmrazor/models/algorithms/detnas.py
13 | Version: v0.1.0
14 | Converted From:
15 | Code: https://github.com/megvii-model/DetNAS
16 | Models:
17 | - Name: detnas_frcnn_shufflenet_subnet_coco_1x
18 | In Collection: DetNAS
19 | Metadata:
20 | FLOPs(Backbone): 340 MB
21 | Params(Backbone): 3.35 MB
22 | Supernet: FRCNN-ShuffleNetV2
23 | Mutable: https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_subnet_cfg_v1.yaml
24 | Results:
25 | - Task: Object Detection
26 | Dataset: COCO
27 | Metrics:
28 | box AP: 37.5
29 | Config: configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py
30 | Weights: https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth
31 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/base/group_fisher/group_fisher_deploy_template.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_(str): The path to your pretrain config file.
5 | fix_subnet (Union[dict,str]): The dict store the pruning structure or the
6 | json file including it.
7 | divisor (int): The divisor the make the channel number divisible.
8 | """
9 |
10 | _base_ = ''
11 | fix_subnet = {}
12 | divisor = 8
13 | ##############################################################################
14 |
15 | architecture = _base_.model
16 |
17 | model = dict(
18 | _delete_=True,
19 | _scope_='mmrazor',
20 | type='GroupFisherDeploySubModel',
21 | architecture=architecture,
22 | fix_subnet=fix_subnet,
23 | divisor=divisor,
24 | )
25 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/base/group_fisher/group_fisher_finetune_template.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = ''
11 | pruned_path = ''
12 | finetune_lr = 0.1
13 | ##############################################################################
14 |
15 | algorithm = _base_.model  # the pruning algorithm wrapper from the base config
16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
17 |
18 | model = dict(
19 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
20 |     _scope_='mmrazor',
21 |     type='GroupFisherSubModel',
22 |     algorithm=algorithm,
23 | )
24 |
25 | # restore lr
26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
27 |
28 | # remove pruning related hooks
29 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks -- confirm in the base config
30 |
31 | # delete ddp
32 | model_wrapper_cfg = None
33 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = ['dcff_resnet_8xb32_in1k.py']
2 |
3 | # model settings
4 | _base_.model = dict(  # export the compact (statically pruned) DCFF sub-model
5 |     _scope_='mmrazor',
6 |     type='sub_model',
7 |     cfg=dict(
8 |         cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
9 |     fix_subnet='configs/pruning/mmcls/dcff/fix_subnet.json',  # pruning structure
10 |     mode='mutator',
11 |     init_cfg=dict(
12 |         type='Pretrained',
13 |         checkpoint='configs/pruning/mmcls/dcff/fix_subnet_weight.pth'))
14 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/dmcp/metafile.yml:
--------------------------------------------------------------------------------
1 | # Models:
2 | # - Name: dmcp_resnet50_subnet_32xb64
3 | # In Collection: DMCP
4 | # Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py
5 | # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth
6 | # Results:
7 | # - Task: Image Classification
8 | # Dataset: ImageNet-1k
9 | # Metrics:
10 | # Top 1 Accuracy: 76.11
11 | # - Name: dmcp_mbv2_subnet_32xb64
12 | # In Collection: DMCP
13 | # Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py
14 | # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth
15 | # Results:
16 | # - Task: Image Classification
17 | # Dataset: ImageNet-1k
18 | # Metrics:
19 | # Top 1 Accuracy: 67.22
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_prune_mobilenet-v2_8xb32_in1k.pth' # noqa
12 | finetune_lr = 0.045
13 | ##############################################################################
14 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
16 |
17 | model = dict(
18 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
19 |     _scope_='mmrazor',
20 |     type='GroupFisherSubModel',
21 |     algorithm=algorithm,
22 | )
23 |
24 | # restore lr
25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
26 |
27 | # remove pruning related hooks
28 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
29 |
30 | # delete ddp
31 | model_wrapper_cfg = None
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.pth' # noqa
12 | finetune_lr = 0.045
13 | ##############################################################################
14 |
15 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
17 |
18 | model = dict(
19 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
20 |     _scope_='mmrazor',
21 |     type='GroupFisherSubModel',
22 |     algorithm=algorithm,
23 | )
24 |
25 | # restore lr
26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
27 |
28 | # remove pruning related hooks
29 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
30 |
31 | # delete ddp
32 | model_wrapper_cfg = None
33 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_prune_mobilenet-v2_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = './group_fisher_act_prune_mobilenet-v2_8xb32_in1k.py'
2 | model = dict(  # override: normalize channel importance by flops instead of act
3 |     mutator=dict(
4 |         channel_unit_cfg=dict(
5 |             default_args=dict(normalization_type='flops', ), ), ), )
6 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml:
--------------------------------------------------------------------------------
1 | Models:
2 | - Name: group_fisher_act_finetune_mobilenet-v2_8xb32_in1k
3 | In Collection: GroupFisher
4 | Config: configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.py
5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/act/group_fisher_act_finetune_mobilenet-v2_8xb32_in1k.pth
6 | Results:
7 | - Task: Image Classification
8 | Dataset: ImageNet-1k
9 | Metrics:
10 | Top 1 Accuracy: 70.82
11 | - Name: group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k
12 | In Collection: GroupFisher
13 | Config: configs/pruning/mmcls/group_fisher/mobilenet/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.py
14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/mobilenet/flop/group_fisher_flops_finetune_mobilenet-v2_8xb32_in1k.pth
15 | Results:
16 | - Task: Image Classification
17 | Dataset: ImageNet-1k
18 | Metrics:
19 | Top 1 Accuracy: 70.87
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_prune_resnet50_8xb32_in1k.pth' # noqa
12 | finetune_lr = 0.1
13 | ##############################################################################
14 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
16 |
17 | model = dict(
18 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
19 |     _scope_='mmrazor',
20 |     type='GroupFisherSubModel',
21 |     algorithm=algorithm,
22 | )
23 |
24 | # restore lr
25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
26 |
27 | # remove pruning related hooks
28 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
29 |
30 | # delete ddp
31 | model_wrapper_cfg = None
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_flops_prune_resnet50_8xb32_in1k.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_prune_resnet50_8xb32_in1k.pth' # noqa
12 | finetune_lr = 0.1
13 | ##############################################################################
14 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
16 |
17 | model = dict(
18 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
19 |     _scope_='mmrazor',
20 |     type='GroupFisherSubModel',
21 |     algorithm=algorithm,
22 | )
23 |
24 | # restore lr
25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
26 |
27 | # remove pruning related hooks
28 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
29 |
30 | # delete ddp
31 | model_wrapper_cfg = None
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_prune_resnet50_8xb32_in1k.py:
--------------------------------------------------------------------------------
1 | _base_ = './group_fisher_act_prune_resnet50_8xb32_in1k.py'
2 | model = dict(  # override: normalize channel importance by flops instead of act
3 |     mutator=dict(
4 |         channel_unit_cfg=dict(
5 |             default_args=dict(normalization_type='flops', ), ), ), )
6 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/group_fisher/resnet50/metafile.yml:
--------------------------------------------------------------------------------
1 | Models:
2 | - Name: group_fisher_act_finetune_resnet50_8xb32_in1k
3 | In Collection: GroupFisher
4 | Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_act_finetune_resnet50_8xb32_in1k.py
5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/act/group_fisher_act_finetune_resnet50_8xb32_in1k.pth
6 | Results:
7 | - Task: Image Classification
8 | Dataset: ImageNet-1k
9 | Metrics:
10 | Top 1 Accuracy: 75.22
11 | - Name: group_fisher_flops_finetune_resnet50_8xb32_in1k
12 | In Collection: GroupFisher
13 | Config: configs/pruning/mmcls/group_fisher/resnet50/group_fisher_flops_finetune_resnet50_8xb32_in1k.py
14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/resnet50/flops/group_fisher_flops_finetune_resnet50_8xb32_in1k.pth
15 | Results:
16 | - Task: Image Classification
17 | Dataset: ImageNet-1k
18 | Metrics:
19 | Top 1 Accuracy: 75.61
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/l1-norm/metafile.yml:
--------------------------------------------------------------------------------
1 | Models:
2 | - Name: l1-norm_resnet34_8xb32_in1k_a
3 | In Collection: L1-norm
4 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py
5 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth
6 | Results:
7 | - Task: Image Classification
8 | Dataset: ImageNet-1k
9 | Metrics:
10 | Top 1 Accuracy: 73.61
11 | - Name: l1-norm_resnet34_8xb32_in1k_b
12 | In Collection: L1-norm
13 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_b.py
14 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_b.pth
15 | Results:
16 | - Task: Image Classification
17 | Dataset: ImageNet-1k
18 | Metrics:
19 | Top 1 Accuracy: 73.20
20 | - Name: l1-norm_resnet34_8xb32_in1k_c
21 | In Collection: L1-norm
22 | Config: configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_c.py
23 | Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_c.pth
24 | Results:
25 | - Task: Image Classification
26 | Dataset: ImageNet-1k
27 | Metrics:
28 | Top 1 Accuracy: 73.89
29 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmcls/l1-norm/script.sh:
--------------------------------------------------------------------------------
1 |
2 | # export pruned checkpoint example
3 |
4 | python ./tools/pruning/get_static_model_from_algorithm.py configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pruning/l1-norm/l1-norm_resnet34_8xb32_in1k_a.pth -o ./work_dirs/norm_resnet34_8xb32_in1k_a
5 |
6 | # deploy example
7 |
8 | razor_config=configs/pruning/mmcls/l1-norm/l1-norm_resnet34_8xb32_in1k_a_deploy.py
9 | deploy_config=mmdeploy/configs/mmcls/classification_onnxruntime_dynamic.py
10 | # placeholder: replace with the checkpoint exported above (typo "pruend" fixed)
11 | static_model_checkpoint_path=path/to/pruned/checkpoint
12 |
13 | # convert the static pruned model to ONNX with mmdeploy
14 | python mmdeploy/tools/deploy.py $deploy_config \
15 |     $razor_config \
16 |     $static_model_checkpoint_path \
17 |     mmdeploy/tests/data/tiger.jpeg \
18 |     --work-dir ./work_dirs/mmdeploy
19 |
20 | # benchmark the exported ONNX model on cpu
21 | python mmdeploy/tools/profiler.py $deploy_config \
22 |     $razor_config \
23 |     mmdeploy/demo/resources \
24 |     --model ./work_dirs/mmdeploy/end2end.onnx \
25 |     --shape 224x224 \
26 |     --device cpu \
27 |     --num-iter 1000 \
28 |     --warmup 100
29 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmdet/dcff/dcff_compact_faster_rcnn_resnet50_8xb4_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['dcff_faster_rcnn_resnet50_8xb4_coco.py']
2 |
3 | # model settings
4 | _base_.model = dict(  # export the compact (statically pruned) DCFF sub-model
5 |     _scope_='mmrazor',
6 |     type='sub_model',
7 |     cfg=_base_.architecture,  # architecture defined in the pruning config
8 |     fix_subnet='configs/pruning/mmdet/dcff/fix_subnet.json',  # pruning structure
9 |     mode='mutator',
10 |     init_cfg=dict(
11 |         type='Pretrained',
12 |         checkpoint='configs/pruning/mmdet/dcff/fix_subnet_weight.pth'))
13 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_prune_retinanet_r50_fpn_1x_coco.pth' # noqa
12 | finetune_lr = 0.005
13 | ##############################################################################
14 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
16 |
17 | model = dict(
18 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
19 |     _scope_='mmrazor',
20 |     type='GroupFisherSubModel',
21 |     algorithm=algorithm,
22 | )
23 |
24 | # restore lr
25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
26 |
27 | # remove pruning related hooks
28 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
29 |
30 | # delete ddp
31 | model_wrapper_cfg = None
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py'
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.pth' # noqa
12 | finetune_lr = 0.005
13 | ##############################################################################
14 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
15 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
16 |
17 | model = dict(
18 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
19 |     _scope_='mmrazor',
20 |     type='GroupFisherSubModel',
21 |     algorithm=algorithm,
22 | )
23 |
24 | # restore lr
25 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
26 |
27 | # remove pruning related hooks
28 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
29 |
30 | # delete ddp
31 | model_wrapper_cfg = None
32 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_prune_retinanet_r50_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = './group_fisher_act_prune_retinanet_r50_fpn_1x_coco.py'
2 | model = dict(  # override: normalize channel importance by flops instead of act
3 |     mutator=dict(
4 |         channel_unit_cfg=dict(
5 |             default_args=dict(normalization_type='flops', ), ), ), )
6 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmdet/group_fisher/retinanet/metafile.yml:
--------------------------------------------------------------------------------
1 | Models:
2 | - Name: group_fisher_act_finetune_retinanet_r50_fpn_1x_coco
3 | In Collection: GroupFisher
4 | Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.py
5 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/act/group_fisher_act_finetune_retinanet_r50_fpn_1x_coco.pth
6 | Results:
7 | - Task: Object Detection
8 | Dataset: COCO
9 | Metrics:
10 | box AP: 36.5
11 | - Name: group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco
12 | In Collection: GroupFisher
13 | Config: configs/pruning/mmdet/group_fisher/retinanet/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.py
14 | Weights: https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/retinanet/flops/group_fisher_flops_finetune_retinanet_r50_fpn_1x_coco.pth
15 | Results:
16 | - Task: Object Detection
17 | Dataset: COCO
18 | Metrics:
19 | box AP: 36.6
20 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmpose/dcff/dcff_compact_topdown_heatmap_resnet50_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = ['dcff_topdown_heatmap_resnet50_coco.py']
2 |
3 | # model settings
4 | _base_.model = dict(  # export the compact (statically pruned) DCFF sub-model
5 |     _scope_='mmrazor',
6 |     type='sub_model',
7 |     cfg=_base_.architecture,  # architecture defined in the pruning config
8 |     fix_subnet='configs/pruning/mmpose/dcff/fix_subnet.json',  # pruning structure
9 |     mode='mutator',
10 |     init_cfg=dict(
11 |         type='Pretrained',
12 |         checkpoint='configs/pruning/mmpose/dcff/fix_subnet_weight.pth'))
13 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmpose/group_fisher/group_fisher_finetune_rtmpose-s_8xb256-420e_aic-coco-256x192.py:
--------------------------------------------------------------------------------
1 | #############################################################################
2 | """You have to fill these args.
3 |
4 | _base_ (str): The path to your pruning config file.
5 | pruned_path (str): The path to the checkpoint of the pruned model.
6 | finetune_lr (float): The learning rate for finetuning. Usually we directly
7 |     use the lr of the pretraining schedule.
8 | """
9 |
10 | _base_ = './group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.py' # noqa
11 | pruned_path = 'https://download.openmmlab.com/mmrazor/v1/pruning/group_fisher/rtmpose-s/group_fisher_prune_rtmpose-s_8xb256-420e_aic-coco-256x192.pth' # noqa
12 | finetune_lr = 4e-3
13 | ##############################################################################
14 |
15 | algorithm = _base_.model  # pruning algorithm wrapper from the prune config
16 | algorithm.init_cfg = dict(type='Pretrained', checkpoint=pruned_path)  # load pruned weights
17 |
18 | model = dict(
19 |     _delete_=True,  # drop the algorithm wrapper; keep only the pruned sub-model
20 |     _scope_='mmrazor',
21 |     type='GroupFisherSubModel',
22 |     algorithm=algorithm,
23 | )
24 |
25 | # restore lr
26 | optim_wrapper = dict(optimizer=dict(lr=finetune_lr))
27 |
28 | # remove pruning related hooks
29 | custom_hooks = _base_.custom_hooks[:-2]  # NOTE(review): assumes the last two hooks are the pruning hooks
30 |
31 | # delete ddp
32 | model_wrapper_cfg = None
33 |
--------------------------------------------------------------------------------
/mmrazor/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py:
--------------------------------------------------------------------------------
1 | _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py']
2 |
3 | # model settings
4 | _base_.model = dict(  # export the compact (statically pruned) DCFF sub-model
5 |     _scope_='mmrazor',
6 |     type='sub_model',
7 |     cfg=_base_.architecture,  # architecture defined in the pruning config
8 |     fix_subnet='configs/pruning/mmseg/dcff/fix_subnet.json',  # pruning structure
9 |     mode='mutator',
10 |     init_cfg=dict(
11 |         type='Pretrained',
12 |         checkpoint='configs/pruning/mmseg/dcff/fix_subnet_weight.pth'))
13 |
--------------------------------------------------------------------------------
/mmrazor/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py:
--------------------------------------------------------------------------------
1 | deploy_cfg = dict(  # mmdeploy config: ONNX export + OpenVINO backend
2 |     onnx_config=dict(
3 |         type='onnx',
4 |         export_params=True,  # embed weights in the exported graph
5 |         keep_initializers_as_inputs=False,
6 |         opset_version=11,
7 |         save_file='end2end.onnx',
8 |         input_names=['input'],
9 |         output_names=['output'],
10 |         input_shape=None,  # no fixed export shape; see dynamic_axes below
11 |         optimize=True,
12 |         dynamic_axes={
13 |             'input': {
14 |                 0: 'batch',
15 |                 2: 'height',
16 |                 3: 'width'
17 |             },
18 |             'output': {
19 |                 0: 'batch'
20 |             }
21 |         }),
22 |     backend_config=dict(
23 |         type='openvino',
24 |         model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]),
25 |     codebase_config=dict(type='mmcls', task='Classification'),
26 |     function_record_to_pop=[  # NOTE(review): rewriter records removed at deploy time -- verify against mmdeploy docs
27 |         'mmcls.models.classifiers.ImageClassifier.forward',
28 |         'mmcls.models.classifiers.BaseClassifier.forward'
29 |     ],
30 | )
31 |
--------------------------------------------------------------------------------
/mmrazor/configs/quantization/qat/base/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Name: QAT
3 | README: configs/quantization/qat/base/README.md
4 | Models:
5 | - Name: qat_openvino_resnet18_10e_8xb32_in1k.py
6 | In Collection: QAT
7 | Metadata:
8 | Backend: openvino
9 | Float Model:
10 | Config: mmcls::resnet/resnet18_8xb32_in1k.py
11 | Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
12 | Metrics:
13 | Top 1 Accuracy: 69.90
14 | Results:
15 | - Task: Image Classification
16 | Dataset: ImageNet-1k
17 | Metrics:
18 | Top 1 Accuracy: 69.98
19 | Config: configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py
20 | Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth
21 |
--------------------------------------------------------------------------------
/mmrazor/configs/vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 |     'mmcls::_base_/datasets/cifar10_bs16.py',
3 |     '../../../_base_/vanilla_models/wrn16_2_cifar10.py',  # model definition
4 |     'mmcls::_base_/schedules/cifar10_bs128.py',
5 |     'mmcls::_base_/default_runtime.py',
6 | ]
7 | test_evaluator = dict(topk=(1, 5))  # report top-1 and top-5 accuracy
8 |
--------------------------------------------------------------------------------
/mmrazor/configs/vanilla/mmcls/wide-resnet/wrn22-w4_b16x8_cifar10.py:
--------------------------------------------------------------------------------
1 | _base_ = ['wrn16-w2_b16x8_cifar10.py']
2 | model = dict(  # WRN-22-4; head in_channels scales with widen_factor
3 |     backbone=dict(depth=22, widen_factor=4), head=dict(in_channels=256, ))
4 |
--------------------------------------------------------------------------------
/mmrazor/configs/vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py:
--------------------------------------------------------------------------------
1 | _base_ = ['wrn16-w2_b16x8_cifar10.py']
2 | model = dict(  # WRN-28-4; head in_channels scales with widen_factor
3 |     backbone=dict(depth=28, widen_factor=4), head=dict(in_channels=256, ))
4 |
--------------------------------------------------------------------------------
/mmrazor/configs/vanilla/mmcls/wide-resnet/wrn40-w2_b16x8_cifar10.py:
--------------------------------------------------------------------------------
1 | _base_ = ['wrn16-w2_b16x8_cifar10.py']
2 | model = dict(  # WRN-40-2; head in_channels scales with widen_factor
3 |     backbone=dict(depth=40, widen_factor=2), head=dict(in_channels=128, ))
4 |
--------------------------------------------------------------------------------
/mmrazor/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.6.0"
2 | ARG CUDA="10.1"
3 | ARG CUDNN="7"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9 | # Fixed: the command substitution was missing its `$` signs.
10 | # NOTE(review): Docker ENV does not run command substitution; the value is
11 | # stored literally and only expanded by a shell that later evaluates it.
12 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
13 |
14 | # system deps for opencv/ffmpeg and building extensions
15 | # (duplicate `libsm6` and `libxext6` entries removed)
16 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \
17 |     && apt-get clean \
18 |     && rm -rf /var/lib/apt/lists/*
19 |
20 | # Install MMCV (prebuilt wheel matching CUDA 10.1 / torch 1.6.0)
21 | RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
22 |
23 | # Install MMClassification
24 | RUN conda clean --all
25 | RUN git clone https://github.com/open-mmlab/mmclassification.git
26 | WORKDIR ./mmclassification
27 | RUN pip install --no-cache-dir -e .
28 |
--------------------------------------------------------------------------------
/mmrazor/docker/serve/config.properties:
--------------------------------------------------------------------------------
1 | inference_address=http://0.0.0.0:8080
2 | management_address=http://0.0.0.0:8081
3 | metrics_address=http://0.0.0.0:8082
4 | model_store=/home/model-server/model-store
5 | load_models=all
6 |
--------------------------------------------------------------------------------
/mmrazor/docker/serve/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | if [[ "$1" = "serve" ]]; then  # "serve" launches torchserve with the baked-in config
5 |     shift 1
6 |     torchserve --start --ts-config /home/model-server/config.properties
7 | else  # any other argument list is executed as a command
8 |     eval "$@"
9 | fi
10 |
11 | # prevent docker exit
12 | tail -f /dev/null
13 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 | background-image: url("../image/mmrazor-logo.png");
3 | background-size: 125px 40px;
4 | height: 40px;
5 | width: 125px;
6 | }
7 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/_static/image/mmrazor-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/_static/image/mmrazor-logo.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/advanced_guides/index.rst:
--------------------------------------------------------------------------------
1 | Key Concepts
2 | ***************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | algorithm.md
8 | mutator.md
9 | mutable.md
10 | recorder.md
11 | delivery.md
12 |
13 | Development tutorials
14 | ************************
15 |
16 | .. toctree::
17 | :maxdepth: 1
18 |
19 | customize_architectures.md
20 | customize_nas_algorithms.md
21 | customize_pruning_algorithms.md
22 | customize_kd_algorithms.md
23 | customize_quantization_algorithms.md
24 | customize_mixed_algorithms.md
25 | apply_existing_algorithms_to_new_tasks.md
26 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/draw-config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/draw-config.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/framework-ChanelMutator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-ChanelMutator.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/framework-algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-algorithm.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/framework-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-framework.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/framework-graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-graph.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/framework-op.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/framework-op.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/pruning_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/pruning_framework.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/imgs/pruning/unit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/docs/en/imgs/pruning/unit.png
--------------------------------------------------------------------------------
/mmrazor/docs/en/index.rst:
--------------------------------------------------------------------------------
1 | .. mmrazor documentation master file, created by
2 | sphinx-quickstart on Mon Aug 29 15:21:38 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to mmrazor's documentation!
7 | ===================================
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :caption: Get Started:
12 |
13 | get_started/overview.md
14 | get_started/installation.md
15 | get_started/model_zoo.md
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: User Guides:
20 |
21 | user_guides/index.rst
22 |
23 | .. toctree::
24 | :maxdepth: 2
25 | :caption: Advanced Guides
26 |
27 | advanced_guides/index.rst
28 |
29 | .. toctree::
30 | :maxdepth: 1
31 | :caption: Notes
32 |
33 | notes/changelog.md
34 | notes/contribution_guide.md
35 | notes/faq.md
36 |
37 | .. toctree::
38 | :maxdepth: 1
39 | :caption: APIs Reference
40 |
41 | api.rst
42 |
43 | .. toctree::
44 | :maxdepth: 1
45 | :caption: Switch Language
46 |
47 | switch_language.md
48 |
49 |
50 | Indices and tables
51 | ==================
52 |
53 | * :ref:`genindex`
54 | * :ref:`modindex`
55 | * :ref:`search`
56 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/notes/faq.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 |
3 | ## 简体中文
4 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/user_guides/1_learn_about_config.md:
--------------------------------------------------------------------------------
1 | # Learn about Configs
2 |
3 | ## Directory structure of configs in mmrazor
4 |
5 | 
6 |
7 | `mmxxx`: means some task repositories of OpenMMLab, such as mmcls, mmdet, mmseg and so on.
8 |
9 | `_base_`: includes configures of datasets, experiment settings and model architectures.
10 |
11 | `distill`/`nas`/`pruning`: model compression algorithms.
12 |
13 | `vanilla`: task models owned by mmrazor.
14 |
15 | ## More about config
16 |
17 | Please refer to [config](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/config.md) in mmengine.
18 |
--------------------------------------------------------------------------------
/mmrazor/docs/en/user_guides/index.rst:
--------------------------------------------------------------------------------
1 | Train & Test
2 | **************
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 |
8 | 1_learn_about_config.md
9 | 2_train_different_types_algorithms.md
10 | 3_train_with_different_devices.md
11 | 4_test_a_model.md
12 |
13 | Quantization
14 | ************
15 |
16 | .. toctree::
17 | :maxdepth: 1
18 |
19 | quantization_user_guide.md
20 |
21 | Useful Tools
22 | ************
23 |
24 | Please refer to the upstream applied repositories' docs.
25 |
--------------------------------------------------------------------------------
/mmrazor/docs/zh_cn/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/mmrazor/docs/zh_cn/index.rst:
--------------------------------------------------------------------------------
1 | 欢迎来到 MMRazor 的中文文档!
2 | =========================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: 开始你的第一步
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 | :caption: 快速启动
11 |
12 | .. toctree::
13 | :maxdepth: 2
14 | :caption: 教程
15 |
16 | .. toctree::
17 | :caption: 语言切换
18 |
19 | switch_language.md
20 |
21 | .. toctree::
22 | :caption: 接口文档(英文)
23 |
24 | api.rst
25 |
26 | Indices and tables
27 | ==================
28 |
29 | * :ref:`genindex`
30 | * :ref:`search`
31 |
--------------------------------------------------------------------------------
/mmrazor/docs/zh_cn/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/mmrazor/docs/zh_cn/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 |
3 | ## 简体中文
4 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import mmengine
4 | from mmengine.utils import digit_version
5 |
6 | from .version import __version__
7 |
8 | mmcv_minimum_version = '2.0.0rc1'
9 | mmcv_maximum_version = '2.1.0'
10 | mmcv_version = digit_version(mmcv.__version__)
11 |
12 | mmengine_minimum_version = '0.1.0'
13 | mmengine_maximum_version = '1.0.0'
14 | mmengine_version = digit_version(mmengine.__version__)
15 |
16 | assert (mmcv_version >= digit_version(mmcv_minimum_version)
17 | and mmcv_version <= digit_version(mmcv_maximum_version)), \
18 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \
19 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
20 |
21 | assert (mmengine_version >= digit_version(mmengine_minimum_version)
22 | and mmengine_version < digit_version(mmengine_maximum_version)), \
23 | f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
24 | f'Please install mmengine>={mmengine_minimum_version}, ' \
25 | f'<{mmengine_maximum_version}.'
26 |
27 | __all__ = ['__version__', 'digit_version']
28 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .crd_dataset_wrapper import CRDDataset
3 | from .transforms import AutoAugment, AutoAugmentV2, PackCRDClsInputs
4 |
5 | __all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs', 'CRDDataset']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .auto_augment import AutoAugment
3 | from .auto_augmentv2 import AutoAugmentV2
4 | from .formatting import PackCRDClsInputs
5 |
6 | __all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs']
7 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .hooks import (DMCPSubnetHook, DumpSubnetHook, EstimateResourcesHook,
3 | StopDistillHook)
4 | from .optimizers import SeparateOptimWrapperConstructor
5 | from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop,
6 | DartsIterBasedTrainLoop, EvolutionSearchLoop,
7 | GreedySamplerTrainLoop, LSQEpochBasedLoop, PTQLoop,
8 | QATEpochBasedLoop, QATValLoop, SelfDistillValLoop,
9 | SingleTeacherDistillValLoop, SlimmableValLoop,
10 | SubnetValLoop)
11 |
12 | __all__ = [
13 | 'DMCPSubnetHook', 'StopDistillHook', 'SeparateOptimWrapperConstructor',
14 | 'DumpSubnetHook', 'SingleTeacherDistillValLoop',
15 | 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop',
16 | 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook',
17 | 'SelfDistillValLoop', 'AutoSlimGreedySearchLoop', 'SubnetValLoop',
18 | 'PTQLoop', 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop'
19 | ]
20 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dmcp_subnet_hook import DMCPSubnetHook
3 | from .dump_subnet_hook import DumpSubnetHook
4 | from .estimate_resources_hook import EstimateResourcesHook
5 | from .stop_distillation_hook import StopDistillHook
6 | from .visualization_hook import RazorVisualizationHook
7 |
8 | __all__ = [
9 | 'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook',
10 | 'DMCPSubnetHook', 'StopDistillHook'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/hooks/group_fisher_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import \
7 | PruningStructureHook # noqa
8 | from mmrazor.implementations.pruning.group_fisher import \
9 | ResourceInfoHook # noqa
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/hooks/stop_distillation_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine.hooks import Hook
3 | from mmengine.model import is_model_wrapper
4 |
5 | from mmrazor.registry import HOOKS
6 |
7 |
8 | @HOOKS.register_module()
9 | class StopDistillHook(Hook):
10 | """Stop distilling at a certain time.
11 |
12 | Args:
13 | stop_epoch (int): Stop distillation at this epoch.
14 | """
15 |
16 | priority = 'LOW'
17 |
18 | def __init__(self, stop_epoch: int) -> None:
19 | self.stop_epoch = stop_epoch
20 |
21 | def before_train_epoch(self, runner) -> None:
22 | """Stop distillation."""
23 | if runner.epoch >= self.stop_epoch:
24 | model = runner.model
25 | # TODO: refactor after mmengine using model wrapper
26 | if is_model_wrapper(model):
27 | model = model.module
28 | assert hasattr(model, 'distillation_stopped')
29 |
30 | runner.logger.info('Distillation has been stopped!')
31 | model.distillation_stopped = True
32 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .optimizer_constructor import SeparateOptimWrapperConstructor
3 |
4 | __all__ = ['SeparateOptimWrapperConstructor']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/runner/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop
3 | from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
4 | from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
5 | from .evolution_search_loop import EvolutionSearchLoop
6 | from .iteprune_val_loop import ItePruneValLoop
7 | from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
8 | QATValLoop)
9 | from .slimmable_val_loop import SlimmableValLoop
10 | from .subnet_sampler_loop import GreedySamplerTrainLoop
11 | from .subnet_val_loop import SubnetValLoop
12 |
13 | __all__ = [
14 | 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
15 | 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
16 | 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
17 | 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
18 | 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop'
19 | ]
20 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/runner/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .calibrate_bn_mixin import CalibrateBNMixin
3 | from .check import check_subnet_resources
4 | from .genetic import crossover
5 |
6 | __all__ = ['crossover', 'check_subnet_resources', 'CalibrateBNMixin']
7 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/engine/runner/utils/genetic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 |
4 | import numpy as np
5 |
6 | from mmrazor.utils import SingleMutatorRandomSubnet
7 |
8 |
9 | def crossover(random_subnet1: SingleMutatorRandomSubnet,
10 | random_subnet2: SingleMutatorRandomSubnet,
11 | prob: float = 0.5) -> SingleMutatorRandomSubnet:
12 | """Crossover in genetic algorithm.
13 |
14 | Args:
15 | random_subnet1 (SINGLE_MUTATOR_RANDOM_SUBNET): One of the subnets to
16 | crossover.
17 | random_subnet2 (SINGLE_MUTATOR_RANDOM_SUBNET): One of the subnets to
18 | crossover.
19 |         prob (float): The probability of getting a choice from `random_subnet2`.
20 | Defaults to 0.5.
21 |
22 | Returns:
23 | SINGLE_MUTATOR_RANDOM_SUBNET: The result of crossover.
24 | """
25 | assert prob >= 0. and prob <= 1., \
26 | 'The probability of crossover has to be between 0 and 1'
27 | crossover_subnet = copy.deepcopy(random_subnet1)
28 | for group_id, choice in random_subnet2.items():
29 | if np.random.random_sample() < prob:
30 | crossover_subnet[group_id] = choice
31 | return crossover_subnet
32 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """impl folder is an experimental file structure to store algorithm
3 | implementations.
4 |
5 | Previous file structure splits the files of an algorithm into different folders
6 | according to the types of these files. It may make it hard to understand an
7 | algorithm. So we add the impl folder, where all files of an algorithm are
8 | stored in one folder. As this structure is experimental, it may change rapidly.
9 | """
10 |
11 | from . import pruning # noqa
12 |
13 | __all__ = ['pruning']
14 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/pruning/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from . import group_fisher, sparse_gpt
3 |
4 | __all__ = ['group_fisher', 'sparse_gpt']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/pruning/group_fisher/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .algorithm import GroupFisherAlgorithm
3 | from .counters import GroupFisherConv2dCounter, GroupFisherLinearCounter
4 | from .hook import PruningStructureHook, ResourceInfoHook
5 | from .mutator import GroupFisherChannelMutator
6 | from .ops import GroupFisherConv2d, GroupFisherLinear, GroupFisherMixin
7 | from .prune_deploy_sub_model import GroupFisherDeploySubModel
8 | from .prune_sub_model import GroupFisherSubModel
9 | from .unit import GroupFisherChannelUnit
10 |
11 | __all__ = [
12 | 'GroupFisherDeploySubModel',
13 | 'GroupFisherSubModel',
14 | 'GroupFisherAlgorithm',
15 | 'GroupFisherConv2dCounter',
16 | 'GroupFisherLinearCounter',
17 | 'PruningStructureHook',
18 | 'ResourceInfoHook',
19 | 'GroupFisherChannelMutator',
20 | 'GroupFisherChannelUnit',
21 | 'GroupFisherConv2d',
22 | 'GroupFisherLinear',
23 | 'GroupFisherMixin',
24 | ]
25 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/pruning/group_fisher/counters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmrazor.models.task_modules.estimators.counters import (
3 | DynamicConv2dCounter, DynamicLinearCounter)
4 | from mmrazor.registry import TASK_UTILS
5 |
6 |
7 | @TASK_UTILS.register_module()
8 | class GroupFisherConv2dCounter(DynamicConv2dCounter):
9 | """Counter of GroupFisherConv2d."""
10 | pass
11 |
12 |
13 | @TASK_UTILS.register_module()
14 | class GroupFisherLinearCounter(DynamicLinearCounter):
15 | """Counter of GroupFisherLinear."""
16 | pass
17 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/pruning/sparse_gpt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .compressor import SparseGptCompressor
3 | from .ops import SparseGptLinear, SparseGptMixIn
4 | from .utils import replace_with_dynamic_ops
5 |
6 | __all__ = [
7 | 'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
8 | 'SparseGptCompressor'
9 | ]
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/pruning/sparse_gpt/sparse24_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 |
5 | @torch.no_grad()
6 | def is_weight_sparse_24(weight: torch.Tensor, dim=-1):
7 |     """Check if the weight is sparse 24."""
8 | weight = weight.transpose(-1, dim).reshape(-1, 4) # N 4
9 | is_zero = (weight == 0).sum(-1) # N
10 | return (is_zero >= 2).all()
11 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/implementations/quantization/gptq/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .compressor import GPTQCompressor
3 | from .gptq import GPTQMixIn
4 | from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
5 | from .quantizer import Quantizer
6 |
7 | __all__ = [
8 | 'GPTQCompressor',
9 | 'GPTQMixIn',
10 | 'GPTQConv2d',
11 | 'GPTQLinear',
12 | 'TritonGPTQLinear',
13 | 'Quantizer',
14 | ]
15 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .algorithms import * # noqa: F401,F403
3 | from .architectures import * # noqa: F401,F403
4 | from .distillers import * # noqa: F401,F403
5 | from .fake_quants import * # noqa: F401,F403
6 | from .losses import * # noqa: F401,F403
7 | from .mutables import * # noqa: F401,F403
8 | from .mutators import * # noqa: F401,F403
9 | from .observers import * # noqa: F401,F403
10 | from .quantizers import * # noqa: F401,F403
11 | from .task_modules import * # noqa: F401,F403
12 | from .utils import * # noqa: F401,F403
13 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base import BaseAlgorithm
3 | from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
4 | FpnTeacherDistill, OverhaulFeatureDistillation,
5 | SelfDistill, SingleTeacherDistill)
6 | from .nas import (DSNAS, DSNASDDP, SPOS, Autoformer, AutoSlim, AutoSlimDDP,
7 | BigNAS, BigNASDDP, Darts, DartsDDP)
8 | from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP
9 | from .pruning.ite_prune_algorithm import ItePruneAlgorithm
10 | from .quantization import MMArchitectureQuant, MMArchitectureQuantDDP
11 |
12 | __all__ = [
13 | 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
14 | 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
15 | 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation',
16 | 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
17 | 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS',
18 | 'BigNASDDP', 'DMCP', 'DMCPDDP', 'MMArchitectureQuant',
19 | 'MMArchitectureQuantDDP'
20 | ]
21 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/distill/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .configurable import (DAFLDataFreeDistillation, DataFreeDistillation,
3 | FpnTeacherDistill, OverhaulFeatureDistillation,
4 | SelfDistill, SingleTeacherDistill)
5 |
6 | __all__ = [
7 | 'SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill',
8 | 'DataFreeDistillation', 'DAFLDataFreeDistillation',
9 | 'OverhaulFeatureDistillation'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/distill/configurable/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .datafree_distillation import (DAFLDataFreeDistillation,
3 | DataFreeDistillation)
4 | from .fpn_teacher_distill import FpnTeacherDistill
5 | from .overhaul_feature_distillation import OverhaulFeatureDistillation
6 | from .self_distill import SelfDistill
7 | from .single_teacher_distill import SingleTeacherDistill
8 |
9 | __all__ = [
10 | 'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
11 | 'DataFreeDistillation', 'DAFLDataFreeDistillation',
12 | 'OverhaulFeatureDistillation'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/nas/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .autoformer import Autoformer
3 | from .autoslim import AutoSlim, AutoSlimDDP
4 | from .bignas import BigNAS, BigNASDDP
5 | from .darts import Darts, DartsDDP
6 | from .dsnas import DSNAS, DSNASDDP
7 | from .spos import SPOS
8 |
9 | __all__ = [
10 | 'SPOS', 'AutoSlim', 'AutoSlimDDP', 'BigNAS', 'BigNASDDP', 'Darts',
11 |     'DartsDDP', 'DSNAS', 'DSNASDDP', 'Autoformer'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/pruning/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dcff import DCFF
3 | from .dmcp import DMCP, DMCPDDP
4 | from .slimmable_network import SlimmableNetwork, SlimmableNetworkDDP
5 |
6 | __all__ = [
7 | 'SlimmableNetwork', 'SlimmableNetworkDDP', 'DCFF', 'DMCP', 'DMCPDDP'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/pruning/group_fisher_algoritho.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import \
7 | GroupFisherAlgorithm # noqa
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/algorithms/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .mm_architecture import MMArchitectureQuant, MMArchitectureQuantDDP
3 |
4 | __all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .backbones import * # noqa: F401,F403
3 | from .classifiers import * # noqa: F401,F403
4 | from .connectors import * # noqa: F401,F403
5 | from .dynamic_ops import * # noqa: F401,F403
6 | from .generators import * # noqa: F401,F403
7 | from .heads import * # noqa: F401,F403
8 | from .necks import * # noqa: F401,F403
9 | from .ops import * # noqa: F401,F403
10 | from .utils import * # noqa: F401,F403
11 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .darts_backbone import DartsBackbone
3 | from .searchable_autoformer import AutoformerBackbone
4 | from .searchable_mobilenet_v2 import SearchableMobileNetV2
5 | from .searchable_mobilenet_v3 import AttentiveMobileNetV3
6 | from .searchable_shufflenet_v2 import SearchableShuffleNetV2
7 | from .wideresnet import WideResNet
8 |
9 | __all__ = [
10 | 'DartsBackbone', 'AutoformerBackbone', 'SearchableMobileNetV2',
11 | 'AttentiveMobileNetV3', 'SearchableShuffleNetV2', 'WideResNet'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/classifiers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .image import SearchableImageClassifier
3 |
4 | __all__ = ['SearchableImageClassifier']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/connectors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .byot_connector import BYOTConnector
3 | from .convmodule_connector import ConvModuleConnector
4 | from .crd_connector import CRDConnector
5 | from .factor_transfer_connectors import Paraphraser, Translator
6 | from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector
7 | from .mgd_connector import MGDConnector
8 | from .norm_connector import NormConnector
9 | from .ofd_connector import OFDTeacherConnector
10 | from .torch_connector import TorchFunctionalConnector, TorchNNConnector
11 |
12 | __all__ = [
13 | 'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector',
14 | 'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector',
15 | 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector',
16 | 'NormConnector'
17 | ]
18 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/connectors/norm_connector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | from mmcv.cnn import build_norm_layer
6 |
7 | from mmrazor.registry import MODELS
8 | from .base_connector import BaseConnector
9 |
10 |
11 | @MODELS.register_module()
12 | class NormConnector(BaseConnector):
13 |
14 | def __init__(self, in_channels, norm_cfg, init_cfg: Optional[Dict] = None):
15 | super(NormConnector, self).__init__(init_cfg)
16 | _, self.norm = build_norm_layer(norm_cfg, in_channels)
17 |
18 | def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
19 | return self.norm(feature)
20 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/connectors/ofd_connector.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional
3 |
4 | import torch
5 |
6 | from mmrazor.registry import MODELS
7 | from .base_connector import BaseConnector
8 |
9 |
@MODELS.register_module()
class OFDTeacherConnector(BaseConnector):
    """Connector designed for ``OverhaulFeatureDistillation``.

    Args:
        init_cfg (Optional[Dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self, init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg)
        # Fix: the attribute starts as None, so the annotation must be
        # Optional; it is set lazily via ``init_margin`` before training.
        self.margin: Optional[torch.Tensor] = None

    def init_margin(self, margin: torch.Tensor) -> None:
        """Initializing margin, will be called by
        ``OverhaulFeatureDistillation``.

        Args:
            margin (torch.Tensor): margin
        """
        self.margin = margin

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Clamp the detached teacher feature from below by the margin.

        Args:
            feature (torch.Tensor): Teacher feature map.

        Returns:
            torch.Tensor: Element-wise maximum of the detached feature and
            the margin.

        Raises:
            AssertionError: If ``init_margin`` was never called.
        """
        assert self.margin is not None, (
            'margin must be initialized before training.')
        # Keep the margin on the same device as the incoming feature.
        self.margin = self.margin.to(feature.device)
        feature = torch.max(feature.detach(), self.margin)
        return feature
39 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/dynamic_ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .bricks import * # noqa: F401,F403
3 | from .head import * # noqa: F401,F403
4 | from .mixins import * # noqa: F401,F403
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/dynamic_ops/bricks/group_fisher_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import \
7 | GroupFisherConv2d # noqa
8 | from mmrazor.implementations.pruning.group_fisher import \
9 | GroupFisherLinear # noqa
10 | from mmrazor.implementations.pruning.group_fisher import \
11 | GroupFisherMixin # noqa
12 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/dynamic_ops/head/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_linear_head import DynamicLinearClsHead  # noqa: F401

# Public API of the dynamic-op head subpackage.
__all__ = ['DynamicLinearClsHead']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/dynamic_ops/mixins/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dynamic_conv_mixins import DynamicConvMixin
from .dynamic_layernorm_mixins import DynamicLayerNormMixin
from .dynamic_mixins import (DynamicBatchNormMixin, DynamicChannelMixin,
                             DynamicLinearMixin, DynamicMixin)

# Public API of the dynamic-op mixins subpackage.
__all__ = [
    'DynamicChannelMixin',
    'DynamicBatchNormMixin',
    'DynamicLinearMixin',
    'DynamicMixin',
    'DynamicConvMixin',
    'DynamicLayerNormMixin',
]
15 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/generators/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dafl_generator import DAFLGenerator
from .zskt_generator import ZSKTGenerator

# Public API of the generators subpackage.
__all__ = ['DAFLGenerator', 'ZSKTGenerator']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/heads/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_subnet_head import DartsSubnetClsHead
from .deit_head import DeiTClsHead

# Public API of the heads subpackage.
__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/necks/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .squeezemean_with_dropout import SqueezeMeanPoolingWithDropout

# Public API of the necks subpackage.
__all__ = ['SqueezeMeanPoolingWithDropout']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/ops/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .common import Identity
from .darts_series import (DartsDilConv, DartsPoolBN, DartsSepConv,
                           DartsSkipConnect, DartsZero)
from .efficientnet_series import ConvBnAct, DepthwiseSeparableConv
from .function import InputResizer
from .gather_tensors import GatherTensors
from .mobilenet_series import MBBlock
from .shufflenet_series import ShuffleBlock, ShuffleXception
from .transformer_series import MultiheadAttention, RelativePosition2D

# Public API of the searchable-ops subpackage.
__all__ = [
    'ShuffleBlock', 'ShuffleXception', 'DartsPoolBN', 'DartsDilConv',
    'DartsSepConv', 'DartsSkipConnect', 'DartsZero', 'MBBlock', 'Identity',
    'ConvBnAct', 'DepthwiseSeparableConv', 'GatherTensors', 'InputResizer',
    'RelativePosition2D', 'MultiheadAttention'
]
18 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/ops/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine.model import BaseModule
3 |
4 |
class BaseOP(BaseModule):
    """Common base for searchable operations.

    Stores the channel/stride configuration shared by all concrete ops.

    Args:
        in_channels (int): The input channels of the operation.
        out_channels (int): The output channels of the operation.
        stride (int): Stride of the operation. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1, **kwargs):
        # Remaining keyword arguments (e.g. ``init_cfg``) go to BaseModule.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
20 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/architectures/utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .mutable_register import mutate_conv_module, mutate_mobilenet_layer
from .set_dropout import set_dropout

# Public API of the architecture-utils subpackage.
__all__ = ['mutate_conv_module', 'mutate_mobilenet_layer', 'set_dropout']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/distillers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_distiller import BaseDistiller
from .byot_distiller import BYOTDistiller
from .configurable_distiller import ConfigurableDistiller
from .ofd_distiller import OFDDistiller

# Public API of the distillers subpackage.
__all__ = [
    'ConfigurableDistiller', 'BaseDistiller', 'BYOTDistiller', 'OFDDistiller'
]
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/distillers/base_distiller.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABC, abstractmethod
3 | from typing import Dict, Optional
4 |
5 | from mmengine.model import BaseModule
6 |
7 | from ..algorithms.base import LossResults
8 |
9 |
class BaseDistiller(BaseModule, ABC):
    """Base class for distiller.

    Combines mmengine's ``BaseModule`` (weight-init support) with ``ABC``
    so concrete distillers must implement ``compute_distill_losses``.

    Args:
        init_cfg (dict, optional): Config for distiller. Default to None.
    """

    def __init__(self, init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg)

    @abstractmethod
    def compute_distill_losses(self) -> LossResults:
        """Compute distill losses automatically.

        Concrete distillers override this and return a ``LossResults``
        mapping of named distillation losses.
        """
23 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/fake_quants/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFakeQuantize
from .lsq import (LearnableFakeQuantize, enable_param_learning,
                  enable_static_estimate, enable_val)
from .torch_fake_quants import register_torch_fake_quants

# Public API of the fake-quants subpackage.
__all__ = [
    'BaseFakeQuantize', 'register_torch_fake_quants', 'LearnableFakeQuantize',
    'enable_val', 'enable_param_learning', 'enable_static_estimate'
]
11 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/fake_quants/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | try:
3 | from torch.ao.quantization import FakeQuantize
4 | except ImportError:
5 | from mmrazor.utils import get_placeholder
6 | FakeQuantize = get_placeholder('torch>=1.13')
7 |
8 | BaseFakeQuantize = FakeQuantize
9 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/losses/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
from .at_loss import ATLoss
from .crd_loss import CRDLoss
from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .dist_loss import DISTLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
from .kl_divergence import KLDivergence
from .l1_loss import L1Loss
from .l2_loss import L2Loss
from .mgd_loss import MGDLoss
from .ofd_loss import OFDLoss
from .pkd_loss import PKDLoss
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
from .weighted_soft_label_distillation import WSLD

# Public API of the losses subpackage.
__all__ = [
    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
    'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
    'DISTLoss'
]
29 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/losses/cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from mmrazor.registry import MODELS
6 |
7 |
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Cross entropy between student logits and teacher hard labels."""
        # The teacher acts as a fixed target: no gradient flows through it,
        # and its argmax over the class dimension becomes the label.
        hard_targets = preds_T.detach().argmax(dim=1)
        return self.loss_weight * F.cross_entropy(preds_S, hard_targets)
24 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutables/mutable_channel/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base_mutable_channel import BaseMutableChannel
from .mutable_channel_container import MutableChannelContainer
from .oneshot_mutable_channel import OneShotMutableChannel
from .sequential_mutable_channel import SquentialMutableChannel
from .simple_mutable_channel import SimpleMutableChannel
from .units import (ChannelUnitType, DCFFChannelUnit, DMCPChannelUnit,
                    L1MutableChannelUnit, MutableChannelUnit,
                    OneShotMutableChannelUnit, SequentialMutableChannelUnit,
                    SlimmableChannelUnit)

# Public API of the mutable-channel subpackage.
# NOTE(review): ``SquentialMutableChannel`` looks misspelled, but the name
# must match its definition in ``sequential_mutable_channel.py``.
__all__ = [
    'SimpleMutableChannel', 'L1MutableChannelUnit',
    'SequentialMutableChannelUnit', 'MutableChannelUnit',
    'OneShotMutableChannelUnit', 'SlimmableChannelUnit', 'BaseMutableChannel',
    'MutableChannelContainer', 'SquentialMutableChannel', 'ChannelUnitType',
    'DCFFChannelUnit', 'OneShotMutableChannel', 'DMCPChannelUnit'
]
19 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutables/mutable_channel/units/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .dcff_channel_unit import DCFFChannelUnit
from .dmcp_channel_unit import DMCPChannelUnit
from .l1_mutable_channel_unit import L1MutableChannelUnit
from .mutable_channel_unit import ChannelUnitType, MutableChannelUnit
from .one_shot_mutable_channel_unit import OneShotMutableChannelUnit
from .sequential_mutable_channel_unit import SequentialMutableChannelUnit
from .slimmable_channel_unit import SlimmableChannelUnit

# Public API of the channel-units subpackage.
__all__ = [
    'L1MutableChannelUnit', 'MutableChannelUnit',
    'SequentialMutableChannelUnit', 'OneShotMutableChannelUnit',
    'SlimmableChannelUnit', 'ChannelUnitType', 'DCFFChannelUnit',
    'DMCPChannelUnit'
]
16 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutables/mutable_channel/units/group_fisher_unit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import \
7 | GroupFisherChannelUnit # noqa
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutables/mutable_module/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .diff_mutable_module import (DiffChoiceRoute, DiffMutableModule,
                                  DiffMutableOP, OneHotMutableOP)
from .mutable_module import MutableModule
from .one_shot_mutable_module import OneShotMutableModule, OneShotMutableOP

# Public API of the mutable-module subpackage.
__all__ = [
    'DiffMutableModule', 'DiffMutableOP', 'DiffChoiceRoute',
    'OneShotMutableOP', 'OneShotMutableModule', 'MutableModule',
    'OneHotMutableOP'
]
12 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutables/mutable_value/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .mutable_value import MutableValue, OneShotMutableValue

# Public API of the mutable-value subpackage.
__all__ = ['MutableValue', 'OneShotMutableValue']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutators/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_mutator import (ChannelMutator, DCFFChannelMutator,
                              DMCPChannelMutator, OneShotChannelMutator,
                              SlimmableChannelMutator)
from .nas_mutator import NasMutator

# Public API of the mutators subpackage.
__all__ = [
    'ChannelMutator', 'DCFFChannelMutator', 'DMCPChannelMutator',
    'SlimmableChannelMutator', 'NasMutator', 'OneShotChannelMutator'
]
11 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutators/channel_mutator/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .channel_mutator import ChannelMutator
from .dcff_channel_mutator import DCFFChannelMutator
from .dmcp_channel_mutator import DMCPChannelMutator
from .one_shot_channel_mutator import OneShotChannelMutator
from .slimmable_channel_mutator import SlimmableChannelMutator

# Public API of the channel-mutator subpackage.
__all__ = [
    'SlimmableChannelMutator', 'ChannelMutator', 'OneShotChannelMutator',
    'DCFFChannelMutator', 'DMCPChannelMutator'
]
12 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/mutators/channel_mutator/group_fisher_mutator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import \
7 | GroupFisherChannelMutator # noqa
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/observers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseObserver
from .lsq import LSQObserver, LSQPerChannelObserver
from .torch_observers import register_torch_observers

# Public API of the observers subpackage.
__all__ = [
    'BaseObserver', 'register_torch_observers', 'LSQObserver',
    'LSQPerChannelObserver'
]
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/observers/base.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Alias torch's observer base as the project base class; when the
# ``torch.ao.quantization`` API is unavailable, substitute a placeholder
# that reports the required version ('torch>=1.13') on use.
try:
    from torch.ao.quantization.observer import UniformQuantizationObserverBase
except ImportError:
    from mmrazor.utils import get_placeholder
    UniformQuantizationObserverBase = get_placeholder('torch>=1.13')

BaseObserver = UniformQuantizationObserverBase
9 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/quantizers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .academic_quantizer import AcademicQuantizer
from .base import BaseQuantizer
from .native_quantizer import TorchNativeQuantizer
from .openvino_quantizer import OpenVINOQuantizer
from .tensorrt_quantizer import TensorRTQuantizer

# Public API of the quantizers subpackage.
__all__ = [
    'BaseQuantizer', 'AcademicQuantizer', 'TorchNativeQuantizer',
    'TensorRTQuantizer', 'OpenVINOQuantizer'
]
12 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/quantizers/exporters/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .openvino_quantize_exporter import OpenVinoQuantizeExportor
from .tensorrt_quantize_exporter import TensorRTExplicitExporter

# Public API of the quantize-exporters subpackage.
__all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .delivery import *  # noqa: F401,F403
from .demo_inputs import *  # noqa: F401,F403
from .estimators import ResourceEstimator
from .predictor import *  # noqa: F401,F403
from .recorder import *  # noqa: F401,F403
from .tracer import *  # noqa: F401,F403

# Only the explicit import is listed; star imports extend the namespace.
__all__ = ['ResourceEstimator']
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/delivery/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .delivery_manager import DistillDeliveryManager
from .function_outputs_delivery import FunctionOutputsDelivery
from .method_outputs_delivery import MethodOutputsDelivery

# Public API of the delivery subpackage.
__all__ = [
    'FunctionOutputsDelivery', 'MethodOutputsDelivery',
    'DistillDeliveryManager'
]
10 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/demo_inputs/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .default_demo_inputs import DefaultDemoInput, defaul_demo_inputs
from .demo_inputs import (BaseDemoInput, DefaultMMClsDemoInput,
                          DefaultMMDemoInput, DefaultMMDetDemoInput,
                          DefaultMMSegDemoInput)

# Public API of the demo-inputs subpackage.
# NOTE(review): ``defaul_demo_inputs`` looks misspelled (``default``), but the
# name must match the definition in ``default_demo_inputs.py``.
__all__ = [
    'defaul_demo_inputs',
    'DefaultMMClsDemoInput',
    'DefaultMMDetDemoInput',
    'DefaultMMDemoInput',
    'DefaultMMSegDemoInput',
    'BaseDemoInput',
    'DefaultDemoInput',
]
16 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .counters import *  # noqa: F401,F403
from .resource_estimator import ResourceEstimator

# Public API of the estimators subpackage.
__all__ = ['ResourceEstimator']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/counters/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_params_counter import get_model_flops_params
from .latency_counter import get_model_latency
from .op_counters import *  # noqa: F401,F403

# Explicit exports; op counters are star-imported above.
__all__ = ['get_model_flops_params', 'get_model_latency']
7 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseCounter(object, metaclass=ABCMeta):
    """Base class of all op module counters in `TASK_UTILS`.

    In ResourceEstimator, `XXModuleCounter` is responsible for `XXModule`,
    which refers to estimator/flops_params_counter.py::get_counter_type().
    Users can customize a `ModuleACounter` and overwrite the `add_count_hook`
    method with a self-defined module `ModuleA`.
    """

    def __init__(self) -> None:
        pass

    # Fix: ``abc.abstractclassmethod`` is deprecated (since Python 3.3) and
    # contradicts the static nature of the hook; the canonical spelling is
    # ``@staticmethod`` stacked over ``@abstractmethod``.
    @staticmethod
    @abstractmethod
    def add_count_hook(module, input, output):
        """The main method of a `BaseCounter` which defines the way to
        calculate resources(flops/params) of the current module.

        Args:
            module (nn.Module): the module to be tested.
            input (_type_): input_tensor. Plz refer to `torch forward_hook`
            output (_type_): output_tensor. Plz refer to `torch forward_hook`
        """
        pass
29 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/group_fisher_counters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file includes the modules in the impl folder.
3 |
4 | As it only records impl modules, it is not initialized automatically.
5 | """
6 | from mmrazor.implementations.pruning.group_fisher import ( # noqa
7 | GroupFisherConv2dCounter, GroupFisherLinearCounter)
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 |
4 | from mmrazor.registry import TASK_UTILS
5 | from ..flops_params_counter import get_model_parameters_number
6 | from .base_counter import BaseCounter
7 |
8 |
@TASK_UTILS.register_module()
class LinearCounter(BaseCounter):
    """FLOPs/params counter for Linear operation series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate FLOPs and params onto ``module`` from tensor sizes."""
        first_input = input[0]
        # PyTorch checks dimensions itself, so only the last output dim
        # is needed here.
        out_features = output.shape[-1]
        module.__flops__ += int(np.prod(first_input.shape) * out_features)
        module.__params__ += get_model_parameters_number(module)
21 |
22 |
@TASK_UTILS.register_module()
class DynamicLinearCounter(LinearCounter):
    # Dynamic linear layers count resources exactly like LinearCounter;
    # this subclass only registers the extra name in TASK_UTILS.
    pass
26 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmrazor.registry import TASK_UTILS
3 | from ..flops_params_counter import get_model_parameters_number
4 | from .base_counter import BaseCounter
5 |
6 |
@TASK_UTILS.register_module()
class UpsampleCounter(BaseCounter):
    """FLOPs/params counter for Upsample function."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate FLOPs and params onto ``module``."""
        out = output[0]
        # FLOPs are proportional to the number of produced elements:
        # batch size times every remaining dimension.
        n_elements = out.shape[0]
        for dim_size in out.shape[1:]:
            n_elements *= dim_size
        module.__flops__ += int(n_elements)
        module.__params__ += get_model_parameters_number(module)
21 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/predictor/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .metric_predictor import MetricPredictor

# Public API of the predictor subpackage.
__all__ = ['MetricPredictor']
5 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/predictor/handler/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .carts_handler import CartsHandler
from .gp_handler import GaussProcessHandler
from .mlp_handler import MLPHandler
from .rbf_handler import RBFHandler

# Public API of the predictor-handler subpackage.
__all__ = ['CartsHandler', 'GaussProcessHandler', 'MLPHandler', 'RBFHandler']
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/predictor/handler/base_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from joblib import dump, load
3 |
4 |
class BaseHandler:
    """Base class for a handler.

    Note:
        The handler works through a specific machine learning algorithm,
        and is designed for predicting the evaluation metric of a model.
    """

    def __init__(self) -> None:
        pass

    def fit(self, train_data, train_label):
        """Training the model of handler."""
        pass

    def predict(self, test_data):
        """Predicting the metric using the model of handler."""
        pass

    def load(self, path):
        """Load pretrained weights for the handler."""
        self.model = load(path)

    def save(self, path):
        """Save the handler and return saved path for diff suffix."""
        # The suffix encodes the concrete handler class, e.g.
        # ``_rbfhandler.joblib``; only the suffix is lower-cased.
        suffix = f'_{type(self).__name__}.joblib'.lower()
        full_path = path + suffix
        dump(self.model, full_path)
        return full_path
33 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/recorder/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .function_inputs_recorder import FunctionInputsRecorder
from .function_outputs_recorder import FunctionOutputsRecorder
from .method_inputs_recorder import MethodInputsRecorder
from .method_outputs_recorder import MethodOutputsRecorder
from .module_inputs_recorder import ModuleInputsRecorder
from .module_outputs_recorder import ModuleOutputsRecorder
from .param_recorder import ParameterRecorder
from .recorder_manager import RecorderManager

# Public API of the recorder subpackage.
__all__ = [
    'FunctionOutputsRecorder', 'MethodOutputsRecorder',
    'ModuleOutputsRecorder', 'ParameterRecorder', 'RecorderManager',
    'ModuleInputsRecorder', 'MethodInputsRecorder', 'FunctionInputsRecorder'
]
16 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/recorder/module_inputs_recorder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Any, Tuple
3 |
4 | from torch import nn
5 |
6 | from mmrazor.registry import TASK_UTILS
7 | from .module_outputs_recorder import ModuleOutputsRecorder
8 |
9 |
@TASK_UTILS.register_module()
class ModuleInputsRecorder(ModuleOutputsRecorder):
    """Recorder for intermediate results which are PyTorch module's inputs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward_hook(self, module: nn.Module, inputs: Tuple,
                     outputs: Any) -> None:
        """Save the module's forward input.

        Args:
            module (:obj:`torch.nn.Module`): The module to register hook.
            inputs (tuple): The input of the module.
            outputs : The output of the module.
        """
        # Capture only while recording has been enabled by the manager.
        if not self.recording:
            return
        self.data_buffer.append(inputs)
28 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer
from .channel_analyzer import ChannelAnalyzer
# from .razor_tracer import RazorFxTracer
from .fx import (CustomTracer, UntracedMethodRegistry, build_graphmodule,
                 custom_symbolic_trace)
from .loss_calculator import *  # noqa: F401,F403
from .parsers import *  # noqa: F401,F403
from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode,
                   PathLinearNode, PathList, PathNode, PathNormNode)

# Public API of the tracer subpackage (star imports add further names).
__all__ = [
    'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode',
    'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode',
    'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry',
    'custom_symbolic_trace', 'build_graphmodule'
]
18 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/fx/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .custom_tracer import (CustomTracer, UntracedMethodRegistry,
                            build_graphmodule, custom_symbolic_trace)
from .graph_utils import (del_fakequant_after_function,
                          del_fakequant_after_method,
                          del_fakequant_after_module, del_fakequant_after_op,
                          del_fakequant_before_function,
                          del_fakequant_before_method,
                          del_fakequant_before_module, del_fakequant_before_op)

# Public API of the fx-tracer subpackage.
__all__ = [
    'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace',
    'build_graphmodule', 'del_fakequant_before_module',
    'del_fakequant_after_module', 'del_fakequant_after_function',
    'del_fakequant_before_function', 'del_fakequant_after_op',
    'del_fakequant_before_op', 'del_fakequant_before_method',
    'del_fakequant_after_method'
]
19 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .cascade_encoder_decoder_loss_calculator import \
3 | CascadeEncoderDecoderPseudoLoss
4 | from .image_classifier_loss_calculator import ImageClassifierPseudoLoss
5 | from .single_stage_detector_loss_calculator import \
6 | SingleStageDetectorPseudoLoss
7 | from .sum_loss_calculator import SumPseudoLoss
8 | from .top_down_pose_estimator_loss_calculator import \
9 | TopdownPoseEstimatorPseudoLoss
10 | from .two_stage_detector_loss_calculator import TwoStageDetectorPseudoLoss
11 |
12 | __all__ = [
13 | 'ImageClassifierPseudoLoss', 'SingleStageDetectorPseudoLoss',
14 | 'TwoStageDetectorPseudoLoss', 'TopdownPoseEstimatorPseudoLoss',
15 | 'CascadeEncoderDecoderPseudoLoss', 'SumPseudoLoss'
16 | ]
17 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/cascade_encoder_decoder_loss_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor.registry import TASK_UTILS
5 |
6 | try:
7 | from mmseg.models import CascadeEncoderDecoder
8 | except ImportError:
9 | from mmrazor.utils import get_placeholder
10 | CascadeEncoderDecoder = get_placeholder('mmseg')
11 |
12 |
@TASK_UTILS.register_module()
class CascadeEncoderDecoderPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `CascadeEncoderDecoder` in MMSegmentation with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (1, 3, 224, 224).
    """

    def __init__(self, input_shape=(1, 3, 224, 224)):
        self.input_shape = input_shape

    def __call__(self, model: CascadeEncoderDecoder) -> torch.Tensor:
        pseudo_img = torch.rand(self.input_shape)
        pseudo_output = model.backbone(pseudo_img)
        pseudo_output = model.neck(pseudo_output)
        # The decode heads are left unmodified; reduce every tensor of every
        # neck output level to a scalar pseudo loss.
        out = torch.tensor(0.)
        for levels in pseudo_output:
            out += sum([level.sum() for level in levels])
        return out
27 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/image_classifier_loss_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor.registry import TASK_UTILS
5 |
6 | try:
7 | from mmcls.models import ImageClassifier
8 | except ImportError:
9 | from mmrazor.utils import get_placeholder
10 | ImageClassifier = get_placeholder('mmcls')
11 |
12 |
@TASK_UTILS.register_module()
class ImageClassifierPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a `ImageClassifier`
    in MMClassification with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (2, 3, 224, 224).
    """

    def __init__(self, input_shape=(2, 3, 224, 224)):
        # Shape of the random tensor fed through the classifier.
        self.input_shape = input_shape

    def __call__(self, model: ImageClassifier) -> torch.Tensor:
        dummy_input = torch.rand(self.input_shape)
        # A scalar sum of the model output serves as the pseudo loss.
        return model(dummy_input).sum()
30 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/single_stage_detector_loss_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor.registry import TASK_UTILS
5 |
6 | try:
7 | from mmdet.models import SingleStageDetector
8 | except ImportError:
9 | from mmrazor.utils import get_placeholder
10 | SingleStageDetector = get_placeholder('mmdet')
11 |
12 |
@TASK_UTILS.register_module()
class SingleStageDetectorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `SingleStageDetector` in MMDetection with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (2, 3, 224, 224).
    """

    def __init__(self, input_shape=(2, 3, 224, 224)):
        # Shape of the random tensor fed through the detector.
        self.input_shape = input_shape

    def __call__(self, model: SingleStageDetector) -> torch.Tensor:
        dummy_img = torch.rand(self.input_shape)
        feature_levels = model(dummy_img)
        # Reduce every tensor of every output level to one scalar.
        total = torch.tensor(0.)
        for level_group in feature_levels:
            total += sum([tensor.sum() for tensor in level_group])
        return total
34 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/top_down_pose_estimator_loss_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor.registry import TASK_UTILS
5 |
6 | try:
7 | from mmpose.models import TopdownPoseEstimator
8 | except ImportError:
9 | from mmrazor.utils import get_placeholder
10 | TopdownPoseEstimator = get_placeholder('mmpose')
11 |
12 |
@TASK_UTILS.register_module()
class TopdownPoseEstimatorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a
    `TopdownPoseEstimator` in MMPose with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (1, 3, 224, 224).
    """

    def __init__(self, input_shape=(1, 3, 224, 224)):
        self.input_shape = input_shape

    def __call__(self, model: TopdownPoseEstimator) -> torch.Tensor:
        pseudo_img = torch.rand(self.input_shape)
        pseudo_output = model.backbone(pseudo_img)
        # Only the backbone is traced; heads are left immutable.
        out = torch.tensor(0.)
        for levels in pseudo_output:
            out += sum([level.sum() for level in levels])
        return out
26 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/task_modules/tracer/loss_calculator/two_stage_detector_loss_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor.registry import TASK_UTILS
5 |
6 | try:
7 | from mmdet.models import TwoStageDetector
8 | except ImportError:
9 | from mmrazor.utils import get_placeholder
10 | TwoStageDetector = get_placeholder('mmdet')
11 |
12 |
# TODO: adapt to mmdet 2.0
@TASK_UTILS.register_module()
class TwoStageDetectorPseudoLoss:
    """Calculate the pseudo loss to trace the topology of a `TwoStageDetector`
    in MMDet with `BackwardTracer`.

    Args:
        input_shape (Tuple): The shape of the pseudo input. Defaults to
            (1, 3, 224, 224).
    """

    def __init__(self, input_shape=(1, 3, 224, 224)):
        self.input_shape = input_shape

    def __call__(self, model: TwoStageDetector) -> torch.Tensor:
        pseudo_img = torch.rand(self.input_shape)
        pseudo_output = model.backbone(pseudo_img)
        pseudo_output = model.neck(pseudo_output)
        # Reduce every tensor of every neck output level to one scalar.
        out = torch.tensor(0.)
        for levels in pseudo_output:
            out += sum([level.sum() for level in levels])
        return out
28 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .make_divisible import make_divisible
3 | from .misc import add_prefix
4 | from .optim_wrapper import reinitialize_optim_wrapper_count_status
5 | from .parse_values import parse_values
6 | from .quantization_util import pop_rewriter_function_record, str2class
7 | from .utils import get_module_device, set_requires_grad
8 |
9 | __all__ = [
10 | 'make_divisible', 'add_prefix', 'reinitialize_optim_wrapper_count_status',
11 | 'str2class', 'get_module_device', 'set_requires_grad', 'parse_values',
12 | 'pop_rewriter_function_record'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/utils/expandable_utils/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
"""This module is used to expand the channels of a supernet.

We only expose some tool functions, rather than all DynamicOps and
MutableChannelUnits, as they use a few hacky operations.
"""
from .tools import (ExpandableUnit, expand_expandable_dynamic_model,
                    expand_static_model, make_channel_divisible,
                    to_expandable_model)

__all__ = [
    'make_channel_divisible',
    'to_expandable_model',
    'expand_expandable_dynamic_model',
    'expand_static_model',
    'ExpandableUnit'
]
17 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict
3 |
4 |
def add_prefix(inputs: Dict, prefix: str) -> Dict:
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict whose keys are ``f'{prefix}.{key}'``.
    """
    # Build the prefixed dict in one pass; insertion order is preserved.
    return {f'{prefix}.{key}': val for key, val in inputs.items()}
21 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/utils/parse_values.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List
3 |
4 |
def parse_values(candidate_lists: List[list]):
    """Parse a list with format `(min_range, max_range, step)`.

    NOTE: this method is required when customizing search space in configs.
    """

    def _expand(triple: List[int]) -> List[int]:
        # Each entry must be exactly a (min, max, step) triple.
        assert len(triple) == 3, (
            'The format should be `(min_range, max_range, step)` with dim=3, '
            f'but got dim={len(triple)}.')
        lo, hi, step = triple
        # ``hi`` is inclusive, hence the ``+ 1``.
        return list(range(lo, hi + 1, step))

    return list(map(_expand, candidate_lists))
19 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/models/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
def get_module_device(module: nn.Module) -> torch.device:
    """Get the device of a module.

    Args:
        module (nn.Module): A module contains the parameters.

    Returns:
        torch.device: The device of the module's first parameter.

    Raises:
        ValueError: If ``module`` has no parameters.
    """
    try:
        first_param = next(module.parameters())
    except StopIteration as e:
        raise ValueError('The input module should contain parameters.') from e

    # Always return a real ``torch.device``. The previous implementation
    # returned ``Tensor.get_device()`` (a plain int) for CUDA parameters,
    # contradicting the annotated return type.
    if first_param.is_cuda:
        return first_param.device

    return torch.device('cpu')
23 |
24 |
def set_requires_grad(nets: Union[nn.Module, List[nn.Module]],
                      requires_grad: bool = False) -> None:
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network. ``None`` entries are skipped.
        requires_grad (bool): Whether the networks require gradients or not
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for parameter in network.parameters():
            parameter.requires_grad = requires_grad
40 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/registry/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
3 | MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
4 | OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
5 | RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
6 | VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS,
7 | sub_model)
8 |
9 | __all__ = [
10 | 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
11 | 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
12 | 'OPTIM_WRAPPERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
13 | 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS',
14 | 'VISUALIZERS', 'sub_model'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .quantization import * # noqa: F401,F403
3 | from .subnet import * # noqa: F401,F403
4 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/graph/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_graph import BaseGraph, BaseNode
3 | from .module_graph import ModuleGraph, ModuleNode
4 |
5 | __all__ = ['BaseGraph', 'BaseNode', 'ModuleNode', 'ModuleGraph']
6 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .backend_config import * # noqa: F401,F403
3 | from .qconfig import * # noqa: F401,F403
4 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/quantization/backend_config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .academic import (get_academic_backend_config,
3 | get_academic_backend_config_dict)
4 | from .mapping import BackendConfigs
5 | from .native import get_native_backend_config, get_native_backend_config_dict
6 | from .openvino import (get_openvino_backend_config,
7 | get_openvino_backend_config_dict)
8 | from .tensorrt import (get_tensorrt_backend_config,
9 | get_tensorrt_backend_config_dict)
10 |
11 | __all__ = [
12 | 'BackendConfigs',
13 | 'get_native_backend_config',
14 | 'get_native_backend_config_dict',
15 | 'get_academic_backend_config',
16 | 'get_academic_backend_config_dict',
17 | 'get_openvino_backend_config',
18 | 'get_openvino_backend_config_dict',
19 | 'get_tensorrt_backend_config',
20 | 'get_tensorrt_backend_config_dict',
21 | ]
22 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/quantization/backend_config/mapping.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmrazor import digit_version
5 | from .academic import get_academic_backend_config
6 | from .native import get_native_backend_config
7 | from .openvino import get_openvino_backend_config
8 | from .tensorrt import get_tensorrt_backend_config
9 |
# Build the real backend configs only when torch is 1.13.0-1.13.1; on any
# other version fall back to ``None`` placeholders so importing this module
# never fails. NOTE(review): presumably the underlying backend-config APIs
# are only compatible with torch 1.13.x -- confirm before widening the range.
if digit_version(
        torch.__version__) >= digit_version('1.13.0') and digit_version(
            torch.__version__) <= digit_version('1.13.1'):
    # Maps backend name -> prebuilt backend config object.
    BackendConfigs = {
        'academic': get_academic_backend_config(),
        'native': get_native_backend_config(),
        'tensorrt': get_tensorrt_backend_config(),
        'openvino': get_openvino_backend_config()
    }
else:
    # Placeholders: the keys exist but no config is available.
    BackendConfigs = {
        'academic': None,
        'native': None,
        'tensorrt': None,
        'openvino': None
    }
26 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/structures/subnet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .candidate import Candidates
3 | from .fix_subnet import convert_fix_subnet, export_fix_subnet, load_fix_subnet
4 |
5 | __all__ = [
6 | 'load_fix_subnet', 'export_fix_subnet', 'convert_fix_subnet', 'Candidates'
7 | ]
8 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/testing/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403
3 | from ._fx_models import * # noqa: F401, F403
4 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/testing/_fast_stop_training_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmengine.hooks import Hook
3 |
4 | from mmrazor.registry import HOOKS
5 |
6 |
@HOOKS.register_module()
class FastStopTrainingHook(Hook):
    """Hook that aborts training early by raising ``RuntimeError``.

    Intended for tests: stops a runner after a few iterations or epochs.
    """

    def __init__(self, by_epoch, save_ckpt=False, stop_iter_or_epoch=5):
        # Whether the runner trains by epoch (True) or by iteration (False).
        self.by_epoch = by_epoch
        # Whether a checkpoint must be written before stopping.
        self.save_ckpt = save_ckpt
        # Iteration/epoch threshold after which training is aborted.
        self.stop_iter_or_epoch = stop_iter_or_epoch

    def after_train_iter(self, runner, batch_idx: int, data_batch: None,
                         outputs: None) -> None:
        # Epoch-based runs that save weights must finish at least one epoch,
        # so never abort mid-epoch in that case.
        must_finish_epoch = self.save_ckpt and self.by_epoch
        if not must_finish_epoch and runner.iter >= self.stop_iter_or_epoch:
            raise RuntimeError('quick exit')

    def after_train_epoch(self, runner) -> None:
        if runner.epoch >= self.stop_iter_or_epoch - 1:
            raise RuntimeError('quick exit')
28 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .index_dict import IndexDict
3 | from .log_tools import get_level, print_log
4 | from .misc import find_latest_checkpoint
5 | from .placeholder import get_package_placeholder, get_placeholder
6 | from .runtime_info import RuntimeInfo
7 | from .setup_env import register_all_modules, setup_multi_processes
8 | from .typing import (FixMutable, MultiMutatorsRandomSubnet,
9 | SingleMutatorRandomSubnet, SupportRandomSubnet,
10 | ValidFixMutable)
11 |
12 | __all__ = [
13 | 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules',
14 | 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet',
15 | 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder',
16 | 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo',
17 | 'get_package_placeholder'
18 | ]
19 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/utils/log_tools.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | import torch.distributed as dist
5 | from mmengine import MMLogger
6 | from mmengine import print_log as engine_print_log
7 |
8 |
def get_level(level='info'):
    """Normalize a logging level given as a name or an int.

    Args:
        level (str | int): A level name such as ``'info'`` (any case) or an
            integer logging level.

    Returns:
        int: The numeric logging level.

    Raises:
        NotImplementedError: If ``level`` is neither str nor int.
    """
    if isinstance(level, int):
        return level
    if isinstance(level, str):
        name = level.upper()
        # Only names known to the logging module are accepted.
        assert name in logging._nameToLevel
        return logging._nameToLevel[name]
    raise NotImplementedError()
19 |
20 |
def print_log(msg, logger='current', level='info', only_rank0=True):
    """Log ``msg``, optionally only on the rank-0 process.

    Args:
        msg: The message to log.
        logger: Logger spec forwarded to mmengine's ``print_log``.
        level (str | int): Level understood by :func:`get_level`.
        only_rank0 (bool): If True and torch.distributed is initialized,
            log only on rank 0.
    """
    # Skip logging only on non-zero ranks when requested and dist is up;
    # short-circuiting keeps get_rank() from being called uninitialized.
    skip = only_rank0 and dist.is_initialized() and dist.get_rank() != 0
    if not skip:
        engine_print_log(msg, logger, get_level(level))
30 |
31 |
def set_log_level(level='debug'):
    """Set the level of the current MMLogger and of its first handler.

    Args:
        level (str | int): Level understood by :func:`get_level`.
    """
    numeric_level = get_level(level)
    current_logger = MMLogger.get_current_instance()
    # Both the handler and the logger filter records, so set both.
    current_logger.handlers[0].setLevel(numeric_level)
    current_logger.setLevel(numeric_level)
38 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import glob
3 | import os.path as osp
4 | import warnings
5 |
6 |
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint from the working directory.

    Args:
        path(str): The path to find checkpoints.
        suffix(str): File extension. Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint.

    References:
        .. [1] https://github.com/microsoft/SoftTeacher
           /blob/main/ssod/utils/patch.py
    """
    if not osp.exists(path):
        warnings.warn('The path of checkpoints does not exist.')
        return None

    # A ``latest.<suffix>`` file wins outright if present.
    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    candidates = glob.glob(osp.join(path, f'*.{suffix}'))
    if not candidates:
        warnings.warn('There are no checkpoints in the path.')
        return None

    # Keep the checkpoint whose trailing ``_<count>`` number is largest.
    best_path, best_count = None, -1
    for candidate in candidates:
        count = int(osp.basename(candidate).split('_')[-1].split('.')[0])
        if count > best_count:
            best_count, best_path = count, candidate
    return best_path
39 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved
2 |
3 | __version__ = '1.0.0'
4 |
5 |
def parse_version_info(version_str):
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.
    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
        (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
    """
    parsed = []
    for segment in version_str.split('.'):
        if segment.isdigit():
            parsed.append(int(segment))
        elif 'rc' in segment:
            # Split '0rc1' into the numeric part and the rc tag.
            pieces = segment.split('rc')
            parsed.append(int(pieces[0]))
            parsed.append(f'rc{pieces[1]}')
    return tuple(parsed)
24 |
25 |
26 | version_info = parse_version_info(__version__)
27 |
28 | __all__ = ['__version__', 'version_info', 'parse_version_info']
29 |
--------------------------------------------------------------------------------
/mmrazor/mmrazor/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .local_visualizer import modify
3 |
4 | __all__ = ['modify']
5 |
--------------------------------------------------------------------------------
/mmrazor/projects/mmrazor_large/examples/ResNet/README.md:
--------------------------------------------------------------------------------
1 | # Examples for ResNet
2 |
3 | ## SparseGPT
4 |
5 | For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md)
6 |
7 | ### Usage
8 |
9 | ```shell
10 | python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batchsize 128 --num_samples 512
11 | ```
12 |
13 | **Note**: this imagenet folder follows torch format.
14 |
15 | ## GPTQ
16 |
17 | For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md)
18 |
19 | ### Usage
20 |
21 | ```shell
22 | python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batchsize 128 --num_samples 512
23 | ```
24 |
25 | **Note**: this imagenet folder follows torch format.
26 |
--------------------------------------------------------------------------------
/mmrazor/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/optional.txt
2 | -r requirements/runtime.txt
3 | -r requirements/tests.txt
4 |
--------------------------------------------------------------------------------
/mmrazor/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | docutils==0.16.0
2 | m2r
3 | myst-parser
4 | git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
5 | sphinx==4.0.2
6 | sphinx-copybutton
7 | sphinx_markdown_tables
8 |
--------------------------------------------------------------------------------
/mmrazor/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcv>=2.0.0rc1
2 | mmengine>=0.1.0,<1.0.0
3 |
--------------------------------------------------------------------------------
/mmrazor/requirements/optional.txt:
--------------------------------------------------------------------------------
1 | pydacefit
2 | pySOT==0.2.3
3 | scipy
4 | timm
5 |
--------------------------------------------------------------------------------
/mmrazor/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | mmcv>=1.3.8
2 | ordered_set
3 | torch
4 | torchvision
5 |
--------------------------------------------------------------------------------
/mmrazor/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | ordered_set
2 | typing_extensions;python_version<"3.8"
3 |
--------------------------------------------------------------------------------
/mmrazor/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | coverage
2 | flake8
3 | interrogate
4 | isort==4.3.21
5 | nbconvert
6 | nbformat
7 | numpy < 1.24.0 # A temporary solution for tests with mmdet.
8 | onnx
9 | pytest
10 | triton==2.0.0
11 | xdoctest >= 0.10.0
12 | yapf
13 |
--------------------------------------------------------------------------------
/mmrazor/resources/design_and_implement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/design_and_implement.png
--------------------------------------------------------------------------------
/mmrazor/resources/mmrazor-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/mmrazor-logo.png
--------------------------------------------------------------------------------
/mmrazor/resources/qq_group_qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/qq_group_qrcode.jpg
--------------------------------------------------------------------------------
/mmrazor/resources/xiaozhushou_weixin_qrcode.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/xiaozhushou_weixin_qrcode.jpeg
--------------------------------------------------------------------------------
/mmrazor/resources/zhihu_qrcode.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/resources/zhihu_qrcode.jpg
--------------------------------------------------------------------------------
/mmrazor/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 |
4 | [aliases]
5 | test=pytest
6 |
7 | [yapf]
8 | based_on_style = pep8
9 | blank_line_before_nested_class_or_def = true
10 | split_before_expression_after_opening_paren = true
11 |
12 | [isort]
13 | line_length = 79
14 | multi_line_output = 0
15 | extra_standard_library = pkg_resources,setuptools
16 | known_first_party = mmrazor
17 | known_third_party=cv2,mmcls,mmcv,mmdet,mmseg,numpy,ordered_set,packaging,pytest,pytorch_sphinx_theme,torch,yaml
18 | no_lines_before = STDLIB,LOCALFOLDER
19 | default_section = THIRDPARTY
20 |
21 | [codespell]
22 | skip = *.ipynb
23 | quiet-level = 3
24 | ignore-words-list = patten,confectionary,nd,ty,formating
25 |
--------------------------------------------------------------------------------
/mmrazor/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/color.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/color.jpeg
--------------------------------------------------------------------------------
/mmrazor/tests/data/concat_subnet1.yaml:
--------------------------------------------------------------------------------
1 | op1.mutable_in_channels:
2 | current_choice: 3
3 | origin_channels: 3
4 | op1.mutable_out_channels:
5 | current_choice: 4
6 | origin_channels: 8
7 | bn1.mutable_num_features:
8 | current_choice: 4
9 | origin_channels: 8
10 | op2.mutable_in_channels:
11 | current_choice: 3
12 | origin_channels: 3
13 | op2.mutable_out_channels:
14 | current_choice: 4
15 | origin_channels: 8
16 | bn2.mutable_num_features:
17 | current_choice: 4
18 | origin_channels: 8
19 | op3.mutable_in_channels:
20 | current_choice: 8
21 | origin_channels: 16
22 | op3.mutable_out_channels:
23 | current_choice: 8
24 | origin_channels: 8
--------------------------------------------------------------------------------
/mmrazor/tests/data/concat_subnet2.yaml:
--------------------------------------------------------------------------------
1 | op1.mutable_in_channels:
2 | current_choice: 3
3 | origin_channels: 3
4 | op1.mutable_out_channels:
5 | current_choice: 8
6 | origin_channels: 8
7 | bn1.mutable_num_features:
8 | current_choice: 8
9 | origin_channels: 8
10 | op2.mutable_in_channels:
11 | current_choice: 3
12 | origin_channels: 3
13 | op2.mutable_out_channels:
14 | current_choice: 8
15 | origin_channels: 8
16 | bn2.mutable_num_features:
17 | current_choice: 8
18 | origin_channels: 8
19 | op3.mutable_in_channels:
20 | current_choice: 16
21 | origin_channels: 16
22 | op3.mutable_out_channels:
23 | current_choice: 8
24 | origin_channels: 8
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/a/1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/a/1.JPG
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/ann.json:
--------------------------------------------------------------------------------
1 | {
2 | "metainfo": {
3 | "categories": [
4 | {
5 | "category_name": "first",
6 | "id": 0
7 | },
8 | {
9 | "category_name": "second",
10 | "id": 1
11 | }
12 | ]
13 | },
14 | "data_list": [
15 | {
16 | "img_path": "a/1.JPG",
17 | "gt_label": 0
18 | },
19 | {
20 | "img_path": "b/2.jpeg",
21 | "gt_label": 1
22 | },
23 | {
24 | "img_path": "b/subb/2.jpeg",
25 | "gt_label": 1
26 | }
27 | ]
28 | }
29 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/ann.txt:
--------------------------------------------------------------------------------
1 | a/1.JPG 0
2 | b/2.jpeg 1
3 | b/subb/3.jpg 1
4 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/b/2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/b/2.jpeg
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/b/subb/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tests/data/dataset/b/subb/3.jpg
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/classes.txt:
--------------------------------------------------------------------------------
1 | bus
2 | car
3 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/dataset/multi_label_ann.json:
--------------------------------------------------------------------------------
1 | {
2 | "metainfo": {
3 | "categories": [
4 | {
5 | "category_name": "first",
6 | "id": 0
7 | },
8 | {
9 | "category_name": "second",
10 | "id": 1
11 | }
12 | ]
13 | },
14 | "data_list": [
15 | {
16 | "img_path": "a/1.JPG",
17 | "gt_label": [0]
18 | },
19 | {
20 | "img_path": "b/2.jpeg",
21 | "gt_label": [1]
22 | },
23 | {
24 | "img_path": "b/subb/2.jpeg",
25 | "gt_label": [0, 1]
26 | }
27 | ]
28 | }
29 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/subnet1.yaml:
--------------------------------------------------------------------------------
1 | op1.mutable_in_channels:
2 | current_choice: 3
3 | origin_channels: 3
4 | op1.mutable_out_channels:
5 | current_choice: 4
6 | origin_channels: 8
7 | bn1.mutable_num_features:
8 | current_choice: 4
9 | origin_channels: 8
10 | op2.mutable_in_channels:
11 | current_choice: 4
12 | origin_channels: 8
13 | op2.mutable_out_channels:
14 | current_choice: 4
15 | origin_channels: 8
16 | bn2.mutable_num_features:
17 | current_choice: 4
18 | origin_channels: 8
19 | op3.mutable_in_channels:
20 | current_choice: 4
21 | origin_channels: 8
22 | op3.mutable_out_channels:
23 | current_choice: 8
24 | origin_channels: 8
--------------------------------------------------------------------------------
/mmrazor/tests/data/subnet2.yaml:
--------------------------------------------------------------------------------
1 | op1.mutable_in_channels:
2 | current_choice: 3
3 | origin_channels: 3
4 | op1.mutable_out_channels:
5 | current_choice: 8
6 | origin_channels: 8
7 | bn1.mutable_num_features:
8 | current_choice: 8
9 | origin_channels: 8
10 | op2.mutable_in_channels:
11 | current_choice: 8
12 | origin_channels: 8
13 | op2.mutable_out_channels:
14 | current_choice: 8
15 | origin_channels: 8
16 | bn2.mutable_num_features:
17 | current_choice: 8
18 | origin_channels: 8
19 | op3.mutable_in_channels:
20 | current_choice: 8
21 | origin_channels: 8
22 | op3.mutable_out_channels:
23 | current_choice: 8
24 | origin_channels: 8
--------------------------------------------------------------------------------
/mmrazor/tests/data/test_models/test_mutator/subnet1.json:
--------------------------------------------------------------------------------
1 | {
2 | "op1_(0, 8)_8": {
3 | "init_args":{
4 | "num_channels":8,
5 | "divisor":1,
6 | "min_value":1,
7 | "min_ratio":0.9,
8 | "candidate_choices":[
9 | 6
10 | ],
11 | "choice_mode":"number"
12 | },
13 | "choice":6
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/test_models/test_subnet/mockmodel_subnet.yaml:
--------------------------------------------------------------------------------
1 | mutable1:
2 | chosen: conv1
3 | mutable2:
4 | chosen: conv2
5 | mutable3.0.kernel_size:
6 | chosen: 3
7 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/test_models/test_task_modules/mmcls_cfg.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Inherit the mmcls ResNet-18 ImageNet config via the cross-repo `scope::path` syntax.
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
--------------------------------------------------------------------------------
/mmrazor/tests/data/test_registry/registry_architecture_config.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

# Config fixture for the registry tests: a mock algorithm wrapping a mock
# architecture. `_return_architecture_` asks the builder to hand back the
# architecture instead of the algorithm.
# NOTE: the original `from platform import architecture` was an accidental
# auto-import and was never used; removed.
supernet = dict(
    type='MockModel',
)

model = dict(
    type='MockAlgorithm',
    architecture=supernet,
    _return_architecture_=True,  # PEP 8: no spaces around '=' in kwargs
)
14 |
15 |
--------------------------------------------------------------------------------
/mmrazor/tests/data/test_registry/registry_subnet_config.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.

# Config fixture: build a fixed subnet from a mock supernet through
# `mmrazor.sub_model`, pinning both mutables and prefixing keys with
# 'backbone.'.
supernet = dict(
    type='mmrazor.sub_model',
    cfg=dict(
        type='MockModel',
    ),
    # PEP 8: no spaces around '=' for keyword arguments.
    fix_subnet={
        'backbone.mutable1': {'chosen': 'conv1'},
        'backbone.mutable2': {'chosen': 'conv2'},
    },
    extra_prefix='backbone.',
)

model = dict(
    type='MockAlgorithm',
    architecture=supernet,
)
18 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_delivers/toy_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import random
3 |
4 | TOY_VAR = 'aaa'
5 |
6 |
def toy_func():
    """Return a pseudo-random integer in the inclusive range [0, 1000]."""
    return random.randint(0, 1000)
9 |
10 |
class ToyClass:
    """Stateful fixture for the deliver tests.

    Calling an instance advances an internal counter and returns the new
    value; ``count`` exposes the counter read-only.
    """

    def __init__(self):
        self._count = 0

    def random_int(self):
        """Return a fresh pseudo-random integer in [0, 1000]."""
        return random.randint(0, 1000)

    @property
    def count(self):
        """Number of times the instance has been called."""
        return self._count

    def __call__(self):
        """Increment the counter and return its new value."""
        self._count = self._count + 1
        return self._count
26 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_graph/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_graph/test_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import sys
3 | from unittest import TestCase
4 |
5 | import torch
6 |
# Graph tracing of deep models can recurse far past the default limit.
sys.setrecursionlimit(int(1e8))

# All graph tests run on CPU.
DEVICE = torch.device('cpu')


class TestGraph(TestCase):
    """Placeholder suite for graph-construction tests.

    The multiprocessing-based fx/backward tracer checks below are currently
    disabled; they are kept commented out as a reference for re-enabling.
    """
    pass
    # def test_init_from_fx_tracer(self) -> None:
    #     TestData = BackwardPassedModelManager.include_models()
    #     with SetTorchThread(1):
    #         with mp.Pool() as p:
    #             result = p.map(_test_init_from_fx_tracer, TestData)
    #         for res, model in zip(result, TestData):
    #             with self.subTest(model=model):
    #                 self.assertTrue(res[0], res[1])

    # def test_init_from_backward_tracer(self) -> None:
    #     TestData = FxPassedModelManager.include_models()
    #     with SetTorchThread(1) as _:
    #         with mp.Pool() as p:
    #             result = p.map(_test_init_from_backward_tracer, TestData)
    #         for res, model in zip(result, TestData):
    #             # test_init_from_backward_tracer(model)
    #             with self.subTest(model=model):
    #                 self.assertTrue(res[0], res[1])
32 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_recorders/test_method_inputs_recorder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | from mmrazor.models.task_modules import MethodInputsRecorder
5 |
6 |
class TestFuncOutputsRecorder(TestCase):
    # NOTE(review): the class name says "FuncOutputs" but the case actually
    # exercises MethodInputsRecorder — kept as-is to preserve discovery.

    def test_context_manager(self):
        """Recorded inputs are identical however the args were passed."""
        from toy_mod import ToyClass

        toy = ToyClass()

        recorder = MethodInputsRecorder('toy_mod.ToyClass.func')
        recorder.initialize()

        with recorder:
            _ = toy.func(x=1, y=2)
            _ = toy.func(1, y=2)
            _ = toy.func(y=2, x=1)

        # Each of the three calls above passed x=1 and y=2.
        for record_idx in range(3):
            for data_idx, expected in enumerate((1, 2)):
                self.assertTrue(
                    recorder.get_record_data(
                        record_idx=record_idx, data_idx=data_idx) == expected)
36 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_recorders/test_param_recorder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from mmrazor.models.task_modules import ParameterRecorder
8 |
9 |
class ToyModel(nn.Module):
    """Two-conv fixture; only ``toy_conv`` participates in the forward pass.

    ``no_record_conv`` exists solely so the recorder tests have a parameter
    that is never touched.
    """

    def __init__(self):
        super().__init__()
        self.toy_conv = nn.Conv2d(1, 1, 1)
        self.no_record_conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        out = self.toy_conv(x)
        return out
19 |
20 |
class TestParameterRecorder(TestCase):

    def test_prepare_from_model(self):
        """ParameterRecorder resolves a named parameter and keeps tracking it."""
        model = ToyModel()

        # Unknown parameter names are rejected at initialization.
        recorder = ParameterRecorder('AAA')
        with self.assertRaisesRegex(AssertionError, '"AAA" is not in the'):
            recorder.initialize(model)

        # Preparing without a model is an error.
        recorder = ParameterRecorder('toy_conv.bias')
        with self.assertRaisesRegex(AssertionError, 'model can not be None'):
            recorder.prepare_from_model()

        recorder.initialize(model)
        bias_weight = recorder.get_record_data()

        # assertEqual: `assertEquals` is a deprecated alias, removed in
        # Python 3.12.
        self.assertEqual(bias_weight, model.toy_conv.bias)

        with recorder:
            _ = model(torch.randn(1, 1, 1, 1))

        # The recorded object is still the live bias after a forward pass.
        self.assertEqual(bias_weight, model.toy_conv.bias)
43 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_recorders/toy_mod.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import random
3 |
4 | TOY_VAR = 'aaa'
5 |
6 |
def toy_func(a):
    """Identity helper: hand the single argument straight back."""
    return a
9 |
10 |
def toy_func2(a, b):
    """Return both arguments bundled as a tuple."""
    return (a, b)
13 |
14 |
def toy_list_func(a):
    """Return a list holding the argument three times."""
    return [a] * 3
17 |
18 |
def execute_toy_func(a):
    """Call ``toy_func`` so a function recorder can observe the inner call."""
    toy_func(a)
21 |
22 |
def execute_toy_func2(a, b):
    """Call ``toy_func2`` so a function recorder can observe the inner call."""
    toy_func2(a, b)
25 |
26 |
def execute_toy_list_func(a):
    """Call ``toy_list_func`` so a function recorder can observe the call."""
    toy_list_func(a)
29 |
30 |
class ToyClass:
    """Counter-based fixture used by the recorder tests."""

    TOY_CLS = 'TOY_CLASS'

    def __init__(self):
        self._count = 0

    def toy(self):
        """Advance the counter by one and return the new value."""
        self._count = self._count + 1
        return self._count

    def func(self, x, y=0):
        """Return the sum of ``x`` and ``y`` (``y`` defaults to 0)."""
        return x + y

    def __call__(self):
        """Calling the instance behaves exactly like :meth:`toy`."""
        self._count = self._count + 1
        return self._count
48 |
49 |
class Toy():
    """Fixture exposing scalar and list pseudo-random producers."""

    def toy_func(self):
        """Return one pseudo-random integer in [0, 1000]."""
        return random.randint(0, 1000)

    def toy_list_func(self):
        """Return three pseudo-random integers, each in [0, 1000]."""
        values = []
        for _ in range(3):
            values.append(random.randint(0, 1000))
        return values
57 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_tracer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_core/test_tracer/test_prune_tracer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | import torch
5 |
6 | from mmrazor import digit_version
7 | from mmrazor.models.task_modules.tracer import ChannelAnalyzer
8 | from ...data.models import SingleLineModel
9 |
10 |
class TestChannelAnalyzer(TestCase):
    """Smoke tests: ChannelAnalyzer should analyze a simple model with
    either tracer backend."""

    def test_backward_tracer(self):
        analyzer = ChannelAnalyzer(tracer_type='BackwardTracer')
        unit_configs = analyzer.analyze(SingleLineModel())
        print(unit_configs)

    def test_fx_tracer(self):
        # The fx tracer backend requires torch >= 1.12.
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('torch<1.12.0')
        analyzer = ChannelAnalyzer(tracer_type='FxTracer')
        unit_configs = analyzer.analyze(SingleLineModel())
        print(unit_configs)
26 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_doc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | from unittest import TestCase
4 |
5 | import nbformat
6 | from nbconvert.preprocessors import ExecutePreprocessor
7 |
# Notebook execution is opt-in: set TEST_DOC=true to enable.
TEST_DOC = os.getenv('TEST_DOC') == 'true'
notebook_paths = [
    './mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb',
    './mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb',  # noqa
    './demo/config_pruning.ipynb'
]


class TestDocs(TestCase):
    """Run every tutorial notebook end to end and fail on any error."""

    def setUp(self) -> None:
        if not TEST_DOC:
            self.skipTest('disabled')

    def test_notebooks(self):
        for path in notebook_paths:
            with self.subTest(path=path):
                # Notebooks are JSON, hence UTF-8 regardless of platform.
                with open(path, encoding='utf-8') as file:
                    nb_in = nbformat.read(file, nbformat.NO_CONVERT)
                ep = ExecutePreprocessor(
                    timeout=600, kernel_name='python3')
                try:
                    _ = ep.preprocess(nb_in)
                except Exception as e:
                    # Surface the underlying error; a bare fail() would
                    # discard the reason the notebook broke.
                    self.fail(str(e))
33 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_engine/test_hooks/test_stop_distillation_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 | from unittest.mock import Mock
4 |
5 | from mmrazor.engine import StopDistillHook
6 |
7 |
class TestStopDistillHook(TestCase):
    """StopDistillHook must flip ``distillation_stopped`` at ``stop_epoch``."""

    def setUp(self):
        self.hook = StopDistillHook(stop_epoch=5)
        runner = Mock()
        runner.model = Mock()
        runner.model.distillation_stopped = False

        runner.epoch = 0
        self.runner = runner

    def test_before_train_epoch(self):
        max_epochs = 10
        # Distillation stays on for epochs 0-4 and is stopped from epoch 5.
        target = [False] * 5 + [True] * 5
        for epoch in range(max_epochs):
            self.hook.before_train_epoch(self.runner)
            # assertEqual: `assertEquals` is a deprecated alias, removed in
            # Python 3.12.
            self.assertEqual(self.runner.model.distillation_stopped,
                             target[epoch])
            self.runner.epoch += 1
27 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_impl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_impl/test_pruning/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_impl/test_pruning/test_group_fisher/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_algorithms/test_general_quant.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | import torch.nn as nn
5 |
6 |
class ToyModel(nn.Module):
    """Empty module placeholder for the quantization algorithm tests."""

    def __init__(self) -> None:
        super().__init__()
        # TODO: add layers once the quantization algorithm is implemented.
12 |
13 |
class TestGeneralQuant(TestCase):
    """Skeleton suite for the general quantization algorithm.

    Every case is a placeholder until the algorithm is testable.
    """

    def test_init(self):
        pass

    def test_prepare(self):
        pass

    def test_convert(self):
        pass

    def test_states(self):
        pass

    def test_forward(self):
        pass
35 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_architectures/test_backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, List
3 |
4 | from torch import Tensor
5 | from torch.nn import Conv2d, Module
6 |
7 | from mmrazor.registry import MODELS
8 |
9 |
@MODELS.register_module()
class MockMutable(Module):
    """Minimal stand-in for a mutable op in backbone tests.

    Stores the candidate choices and module kwargs it was configured with
    and forwards through a single 3x3 same-padding convolution.
    """

    def __init__(self, choices: List[str], module_kwargs: Dict) -> None:
        super().__init__()
        self.choices = choices
        self.module_kwargs = module_kwargs
        # kernel_size=3 with padding=1 keeps the spatial size unchanged.
        self.conv = Conv2d(kernel_size=3, padding=1, **module_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_architectures/test_dynamic_op/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_architectures/test_dynamic_op/test_bricks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_architectures/test_dynamic_op/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional
3 |
4 | from mmrazor.models.architectures.dynamic_ops import DynamicMixin
5 | from mmrazor.utils.typing import DumpChosen
6 |
7 |
def fix_dynamic_op(op: DynamicMixin,
                   fix_mutables: Optional[Dict] = None) -> None:
    """Fix every mutable attribute of a dynamic op.

    When ``fix_mutables`` is provided, each mutable's choice is looked up
    under the ``mutable_attrs.<name>`` key; otherwise the mutable dumps its
    own current choice.
    """
    for name, mutable in op.mutable_attrs.items():
        if fix_mutables is None:
            chosen = mutable.dump_chosen()
        else:
            chosen = fix_mutables[f'mutable_attrs.{name}']

        # Normalize plain dicts into the DumpChosen named structure.
        if not isinstance(chosen, DumpChosen):
            chosen = DumpChosen(**chosen)

        mutable.fix_chosen(chosen.chosen)
21 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_fake_quants/test_torch_fake_quants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmrazor import digit_version
6 | from mmrazor.models.fake_quants import register_torch_fake_quants
7 | from mmrazor.registry import MODELS
8 |
9 |
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.13.0'),
    reason='version of torch < 1.13.0')
def test_register_torch_fake_quants():
    """Every fake-quant exported by torch must land in the MODELS registry."""
    registered = register_torch_fake_quants()
    assert isinstance(registered, list)
    for name in registered:
        assert MODELS.get(name)
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutables/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutables/test_mutable_channel/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutables/test_mutable_channel/test_units/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutables/test_mutable_channel/test_units/test_l1_mutable_channel_unit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | import torch.nn as nn
5 |
6 | from mmrazor.models.mutables import L1MutableChannelUnit
7 | from mmrazor.models.mutators import ChannelMutator
8 | from .....data.models import SingleLineModel
9 |
10 |
class TestL1MutableChannelUnit(TestCase):

    def test_init(self):
        """ChannelMutator should accept L1 units configured in ratio mode."""
        unit_cfg = dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'))
        mutator = ChannelMutator(channel_unit_cfg=unit_cfg)
        mutator.prepare_from_supernet(SingleLineModel())

    def test_convnd(self):
        """The L1 norm of an N-d conv is computed per output channel."""
        unit = L1MutableChannelUnit(8)
        conv3d = nn.Conv3d(3, 8, 3)
        norm = unit._get_l1_norm(conv3d, 0, 8)
        self.assertSequenceEqual(norm.shape, [8])
29 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutables/test_sequential_mutable_channel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from unittest import TestCase
3 |
4 | from mmrazor.models.mutables import SquentialMutableChannel
5 |
6 |
class TestSquentialMutableChannel(TestCase):

    def test_mul_float(self):
        """A float-scaled channel tracks the source channel's choice."""
        base = SquentialMutableChannel(10)
        derived = base * 0.5
        # 10 * 0.5 -> 5
        self.assertEqual(derived.current_choice, 5)
        base.current_choice = 5
        # 5 * 0.5 -> 2 (truncated)
        self.assertEqual(derived.current_choice, 2)
15 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_mutators/test_dmcp_mutator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcls.models import * # noqa: F401,F403
4 | from torch import Tensor, nn
5 | from torch.nn import Module
6 |
7 | from mmrazor.models.mutators import DMCPChannelMutator
8 |
9 |
class ResBlock(Module):
    """Small residual-style block: two BN'd 1x1 convs plus a skip into a
    third conv."""

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

    def forward(self, x: Tensor) -> Tensor:
        feat1 = self.bn1(self.op1(x))
        feat2 = self.bn2(self.op2(feat1))
        # Residual add: op3 consumes the sum of both branches.
        return self.op3(feat2 + feat1)
26 |
27 |
def test_DMCP_channel_mutator() -> None:
    """Smoke-test DMCPChannelMutator across all of its sampling modes."""
    imgs = torch.randn(16, 3, 224, 224)

    # ResBlock
    mutator = DMCPChannelMutator(channel_unit_cfg=dict(type='DMCPChannelUnit'))

    model = ResBlock()
    mutator.prepare_from_supernet(model)
    # Sample a subnet in each supported mode and run a forward pass each time;
    # the shape assertion below checks only the last iteration's output.
    for mode in ['max', 'min', 'random', 'expected', 'direct']:
        mutator.sample_subnet(mode, arch_train=True)
        out3 = model(imgs)

    assert out3.shape == (16, 8, 224, 224)
41 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_observers/test_torch_observers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmrazor import digit_version
6 | from mmrazor.models.observers import register_torch_observers
7 | from mmrazor.registry import MODELS
8 |
9 |
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.13.0'),
    reason='version of torch < 1.13.0')
def test_register_torch_observers():
    """Every observer exported by torch must land in the MODELS registry."""
    registered = register_torch_observers()
    assert isinstance(registered, list)
    for name in registered:
        assert MODELS.get(name)
19 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_task_modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_task_modules/test_demo_inputs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_task_modules/test_demo_inputs/test_demo_inputs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import unittest
3 |
4 | from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
5 | from ....data.tracer_passed_models import FxPassedModelManager
6 |
7 |
class TestDemoInputs(unittest.TestCase):

    def test_demo_inputs(self):
        """Every fx-traceable model must accept the default demo input."""
        for Model in FxPassedModelManager().include_models():
            with self.subTest(model=Model):
                demo = DefaultDemoInput(input_shape=[1, 3, 224, 224])
                model = Model()
                model.eval()
                try:
                    demo(model)
                    data = demo.get_data(model)
                    if isinstance(data, dict):
                        model(**data)
                    else:
                        model(data)
                except Exception as e:
                    self.fail(f'{e}')
25 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_models/test_utils/test_expandable_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_runners/test_utils/test_genetic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmrazor.engine.runner.utils import crossover
3 |
4 |
def test_crossover():
    """crossover must return a subnet of the same type and size as its
    parents."""
    fake_random_subnet1 = {i: f'{i}_choice1' for i in range(50)}
    fake_random_subnet2 = {i: f'{i}_choice2' for i in range(50)}

    result = crossover(fake_random_subnet1, fake_random_subnet2)

    # `is` compares type objects by identity; `==` on types is an
    # anti-pattern.
    assert type(result) is type(fake_random_subnet1)
    assert len(result) == len(fake_random_subnet1)
16 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_utils/test_index_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import unittest
3 |
4 | from mmrazor.utils.index_dict import IndexDict
5 |
6 |
class TestIndexDict(unittest.TestCase):

    def test_dict(self):
        """Keys stay sorted by range and overlapping ranges are rejected."""
        # Renamed from `dict`, which shadowed the builtin.
        index_dict = IndexDict()
        index_dict[(4, 5)] = 2
        index_dict[(1, 3)] = 1

        self.assertSequenceEqual(list(index_dict.keys()), [(1, 3), (4, 5)])
        # (2, 3) overlaps (1, 3) and must be refused.
        with self.assertRaisesRegex(AssertionError, 'overlap'):
            index_dict[2, 3] = 3
17 |
--------------------------------------------------------------------------------
/mmrazor/tests/test_utils/test_placeholder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import unittest
3 |
4 | import pytest
5 |
6 | from mmrazor.utils import get_placeholder
7 |
8 |
class TestPlaceholder(unittest.TestCase):

    def test_placeholder(self):
        """A placeholder must raise ImportError on use yet allow subclassing."""
        holder = get_placeholder('test')
        # Instantiating the placeholder is the "use" that must fail.
        with pytest.raises(ImportError):
            holder()
        from mmrazor.models.architectures.dynamic_ops import DynamicMixin

        # Deliberately unused: merely defining a subclass that mixes the
        # placeholder with a real mixin must not raise at class creation.
        class tmp(holder, DynamicMixin):
            pass
19 |
--------------------------------------------------------------------------------
/mmrazor/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .set_torch_thread import SetTorchThread
3 |
4 | __all__ = ['SetTorchThread']
5 |
--------------------------------------------------------------------------------
/mmrazor/tests/utils/set_dist_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import random
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 |
class SetDistEnv:
    """Context manager that spins up a single-process distributed group.

    On entry it points ``MASTER_ADDR``/``MASTER_PORT`` at localhost,
    initializes a world-size-1 process group (nccl if CUDA was requested,
    gloo otherwise) and tears the group down on exit.
    """

    def __init__(self, using_cuda=False, port=None) -> None:
        self.using_cuda = using_cuda
        if self.using_cuda:
            assert torch.cuda.is_available()
        # Pick a random high port to avoid collisions between test runs.
        self.port = random.randint(10000, 20000) if port is None else port

    def __enter__(self):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(self.port)

        # initialize the process group
        backend = 'nccl' if self.using_cuda else 'gloo'
        dist.init_process_group(backend, rank=0, world_size=1)

    def __exit__(self, exc_type, exc_value, tb):
        dist.destroy_process_group()
32 |
--------------------------------------------------------------------------------
/mmrazor/tests/utils/set_torch_thread.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 |
class SetTorchThread:
    """Context manager that temporarily overrides torch's thread count.

    Args:
        num_thread (int): thread count to apply inside the context.
            ``-1`` (the default) leaves the setting untouched.
    """

    def __init__(self, num_thread: int = -1) -> None:
        self.prev_num_threads = torch.get_num_threads()
        self.num_threads = num_thread

    def __enter__(self):
        if self.num_threads != -1:
            # Re-read the current count here: snapshotting only in __init__
            # restores a stale value when the manager is constructed early
            # or reused.
            self.prev_num_threads = torch.get_num_threads()
            torch.set_num_threads(self.num_threads)
        # Returning self supports `with SetTorchThread(1) as _:` usage.
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if self.num_threads != -1:
            torch.set_num_threads(self.prev_num_threads)
18 |
--------------------------------------------------------------------------------
/mmrazor/tools/dist_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch distributed testing via torch.distributed.launch.
# Usage: [env NNODES=.. NODE_RANK=.. PORT=.. MASTER_ADDR=..] \
#   ./dist_test.sh CONFIG CHECKPOINT GPUS [extra args forwarded to test.py]

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Fix: quote "$0", "$CONFIG", "$CHECKPOINT" and "${@:4}" so paths and
# arguments containing spaces survive word splitting (the original mixed
# quoted and unquoted $0 and forwarded args unquoted).
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/test.py \
    "$CONFIG" \
    "$CHECKPOINT" \
    --launcher pytorch \
    "${@:4}"
23 |
--------------------------------------------------------------------------------
/mmrazor/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch distributed training via torch.distributed.launch.
# Usage: [env NNODES=.. NODE_RANK=.. PORT=.. MASTER_ADDR=..] \
#   ./dist_train.sh CONFIG GPUS [extra args forwarded to train.py]

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Fix: quote "$0", "$CONFIG" and "${@:3}" so paths and arguments with
# spaces survive word splitting (they were unquoted).
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    "$CONFIG" \
    --launcher pytorch "${@:3}"
20 |
--------------------------------------------------------------------------------
/mmrazor/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Single-node distributed testing launcher.
# Usage: [env PORT=..] ./slurm_test.sh CONFIG CHECKPOINT GPUS [extra args]

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29500}

# Fix: quote "$0", "$CONFIG", "$CHECKPOINT" and "${@:4}" so paths and
# arguments with spaces survive word splitting (they were unquoted).
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
11 |
--------------------------------------------------------------------------------
/mmrazor/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Slurm training launcher.
# Usage: [env GPUS=.. GPUS_PER_NODE=.. CPUS_PER_TASK=.. SRUN_ARGS=..] \
#   ./slurm_train.sh PARTITION JOB_NAME CONFIG WORK_DIR [extra train.py args]

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
# NOTE(review): flattening "$@" into a scalar loses per-argument quoting;
# kept as-is for compatibility with existing callers.
PY_ARGS=${@:5}

# Fix: quote "$0" so a script path containing spaces works.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/mmrazor/tools/visualizations/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LKJacky/Differentiable-Model-Scaling/27188926d8e0815ba1b6159b47741ea289ba2ca2/mmrazor/tools/visualizations/demo.jpg
--------------------------------------------------------------------------------
/mmrazor/tools/visualizations/vis_configs/backbone_feature_diff_visualization.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Feature-diff visualization config: both models record the outputs of their
# `backbone` module, and levels p3..p6 map to output indices 0..3.

# configs for the 1st model
recorders1 = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
mappings1 = {
    f'p{level}': dict(recorder='backbone', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}

# configs for the 2nd model
recorders2 = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
mappings2 = {
    f'p{level}': dict(recorder='backbone', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}
19 |
--------------------------------------------------------------------------------
/mmrazor/tools/visualizations/vis_configs/backbone_feature_visualization.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# Feature visualization config: record the `backbone` module's outputs and
# expose levels p3..p6 as output indices 0..3.
recorders = dict(
    backbone=dict(_scope_='mmrazor', type='ModuleOutputs', source='backbone'))
mappings = {
    f'p{level}': dict(recorder='backbone', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}
9 |
--------------------------------------------------------------------------------
/mmrazor/tools/visualizations/vis_configs/fpn_feature_diff_visualization.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# FPN feature-diff visualization config: both models record the outputs of
# their `neck` module, and levels p3..p6 map to output indices 0..3.

# configs for the 1st model
recorders1 = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings1 = {
    f'p{level}': dict(recorder='neck', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}

# configs for the 2nd model
recorders2 = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings2 = {
    f'p{level}': dict(recorder='neck', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}
19 |
--------------------------------------------------------------------------------
/mmrazor/tools/visualizations/vis_configs/fpn_feature_visualization.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# FPN feature visualization config: record the `neck` module's outputs and
# expose levels p3..p6 as output indices 0..3.
recorders = dict(
    neck=dict(_scope_='mmrazor', type='ModuleOutputs', source='neck'))
mappings = {
    f'p{level}': dict(recorder='neck', data_idx=level - 3)
    for level in (3, 4, 5, 6)
}
9 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "dms"
7 | version = "0.0.1"
8 | authors = [
9 | { name="kai liu", email="author@example.com" },
10 | ]
11 | description = "Differentiable Model Scaling"
12 | readme = "README.md"
13 | requires-python = ">=3.7"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 |
20 | [project.urls]
21 | "Homepage" = "https://github.com/pypa/sampleproject"
22 | "Bug Tracker" = "https://github.com/pypa/sampleproject/issues"
23 |
--------------------------------------------------------------------------------
/test_dms.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from dms import differentiable_topk, MASK_THRESHOLD
3 | import torch
4 |
5 | torch.set_printoptions(precision=3)
6 |
7 |
class TestDMS(unittest.TestCase):
    """Smoke test for the differentiable top-k operator."""

    def test_dms_topk(self):
        num_elements = 10
        for ratio in (0, 0.3, 0.5, 0.7, 1.0):
            importance = torch.rand([num_elements])
            ratio_t = torch.tensor([ratio])
            soft_mask = differentiable_topk(
                importance, ratio_t, lambda_=importance.numel())

            # An element counts as "remained" when its soft-mask value lies
            # above the threshold.
            n_remain = (soft_mask > MASK_THRESHOLD).float().sum().int().item()
            # NOTE(review): int() truncates, and (1 - ratio_t) * N is computed
            # in float32, which can land just below an integer — confirm
            # whether round() would be safer here.
            self.assertTrue(n_remain == int((1 - ratio_t) * num_elements))
            print(
                (
                    f"Pruning Ratio: {ratio_t}\n"
                    f"Element Importance: {importance}\n"
                    f"Soft mask: {soft_mask}\n"
                    f"Number of Remained Elements: {n_remain}\n"
                ),
            )
26 |
--------------------------------------------------------------------------------