├── detection ├── __init__.py ├── model │ ├── teacher │ │ ├── __init__.py │ │ └── teacher.py │ ├── backbone │ │ ├── __init__.py │ │ ├── fpn.py │ │ └── mobilenetv2.py │ ├── __init__.py │ └── reviewkd.py ├── configs │ ├── DKD │ │ ├── DKD-MV2-R50.yaml │ │ ├── DKD-R50-R101.yaml │ │ ├── ReviewDKD-MV2-R50.yaml │ │ ├── ReviewDKD-R50-R101.yaml │ │ ├── DKD-R18-R101.yaml │ │ └── ReviewDKD-R18-R101.yaml │ ├── ReviewKD │ │ ├── ReviewKD-MV2-R50.yaml │ │ ├── ReviewKD-R50-R101.yaml │ │ ├── ReviewKD-MV2-R50-Mask.yaml │ │ ├── ReviewKD-R50-R101-Mask.yaml │ │ ├── ReviewKD-R18-R101.yaml │ │ └── ReviewKD-R18-R101-Mask.yaml │ └── Base-Distillation.yaml ├── README.md └── train_net.py ├── mdistiller ├── __init__.py ├── models │ ├── __init__.py │ ├── imagenet │ │ ├── __init__.py │ │ └── mobilenetv1.py │ └── cifar │ │ ├── __init__.py │ │ ├── mv2_tinyimagenet.py │ │ ├── ShuffleNetv1.py │ │ ├── wrn.py │ │ ├── mobilenetv2.py │ │ ├── ShuffleNetv2.py │ │ ├── vgg.py │ │ ├── resnetv2.py │ │ └── resnet.py ├── engine │ ├── __init__.py │ ├── utils.py │ ├── cfg.py │ └── dot.py ├── distillers │ ├── __init__.py │ ├── KD.py │ ├── _common.py │ ├── SP.py │ ├── AT.py │ ├── FitNet.py │ ├── NST.py │ ├── _base.py │ ├── PKT.py │ ├── RKD.py │ ├── DKD.py │ ├── KDSVD.py │ ├── VID.py │ ├── ReviewKD.py │ └── OFD.py └── dataset │ ├── __init__.py │ ├── imagenet.py │ ├── tiny_imagenet.py │ └── cifar100.py ├── .gitattributes ├── .github ├── dkd.png ├── dot.png └── mdistiller.png ├── requirements.txt ├── configs ├── cifar100 │ ├── kd.yaml │ ├── sp.yaml │ ├── at.yaml │ ├── nst.yaml │ ├── pkt.yaml │ ├── rkd.yaml │ ├── vid.yaml │ ├── kdsvd.yaml │ ├── ofd.yaml │ ├── dkd │ │ ├── res50_mv2.yaml │ │ ├── wrn40_2_shuv1.yaml │ │ ├── res32x4_res8x4.yaml │ │ ├── res32x4_shuv1.yaml │ │ ├── res32x4_shuv2.yaml │ │ ├── vgg13_vgg8.yaml │ │ ├── res56_res20.yaml │ │ ├── res110_res32.yaml │ │ ├── vgg13_mv2.yaml │ │ ├── wrn40_2_wrn_16_2.yaml │ │ └── wrn40_2_wrn_40_1.yaml │ ├── fitnet.yaml │ ├── vanilla.yaml │ ├── crd.yaml │ ├── dot │ │ ├── 
vgg13_vgg8.yaml │ │ ├── res32x4_shuv2.yaml │ │ └── res32x4_res8x4.yaml │ └── reviewkd.yaml ├── tiny_imagenet │ └── dot │ │ ├── r18_mv2.yaml │ │ └── r18_shuv2.yaml └── imagenet │ ├── r34_r18 │ ├── at.yaml │ ├── dkd.yaml │ ├── kd.yaml │ ├── crd.yaml │ ├── dot.yaml │ └── reviewkd.yaml │ └── r50_mv1 │ ├── at.yaml │ ├── ofd.yaml │ ├── dkd.yaml │ ├── kd.yaml │ ├── crd.yaml │ ├── dot.yaml │ └── reviewkd.yaml ├── setup.py ├── LICENSE ├── .gitignore ├── tools ├── eval.py └── train.py └── README.md /detection/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mdistiller/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * linguist-language=Python 2 | -------------------------------------------------------------------------------- /detection/model/teacher/__init__.py: -------------------------------------------------------------------------------- 1 | from .teacher import build_teacher 2 | -------------------------------------------------------------------------------- /.github/dkd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/mdistiller/HEAD/.github/dkd.png -------------------------------------------------------------------------------- /.github/dot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/mdistiller/HEAD/.github/dot.png -------------------------------------------------------------------------------- /.github/mdistiller.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/megvii-research/mdistiller/HEAD/.github/mdistiller.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | torchvision==0.10.0 3 | tensorboard-logger==0.1.0 4 | yacs 5 | wandb 6 | tqdm 7 | tensorboardX 8 | -------------------------------------------------------------------------------- /mdistiller/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import cifar_model_dict, tiny_imagenet_model_dict 2 | from .imagenet import imagenet_model_dict 3 | -------------------------------------------------------------------------------- /detection/model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import build_resnet_backbone_kd 2 | from .fpn import build_resnet_fpn_backbone_kd, build_mobilenetv2_fpn_backbone 3 | -------------------------------------------------------------------------------- /detection/model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .rcnn import RCNNKD 4 | from .config import add_distillation_cfg 5 | from .backbone import build_resnet_fpn_backbone_kd 6 | -------------------------------------------------------------------------------- /mdistiller/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import BaseTrainer, CRDTrainer, DOT, CRDDOT 2 | trainer_dict = { 3 | "base": BaseTrainer, 4 | "crd": CRDTrainer, 5 | "dot": DOT, 6 | "crd_dot": CRDDOT, 7 | } 8 | -------------------------------------------------------------------------------- /mdistiller/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, 
resnet34, resnet50, resnet101, resnet152 2 | from .mobilenetv1 import MobileNetV1 3 | 4 | 5 | imagenet_model_dict = { 6 | "ResNet18": resnet18, 7 | "ResNet34": resnet34, 8 | "ResNet50": resnet50, 9 | "ResNet101": resnet101, 10 | "MobileNetV1": MobileNetV1, 11 | } 12 | -------------------------------------------------------------------------------- /configs/cifar100/kd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/sp.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "sp,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "AT" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | 
-------------------------------------------------------------------------------- /configs/cifar100/nst.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "nst,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NST" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/pkt.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "pkt,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "PKT" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vid.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vid,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "VID" 7 | TEACHER: 
"resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kdsvd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kdsvd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KDSVD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/ofd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "ofd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "OFD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res50_mv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | 
-------------------------------------------------------------------------------- /configs/cifar100/fitnet.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "fitnet,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "FITNET" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn40_2_shuv1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,shuv1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "ShuffleV1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | 
DISTILLER: 6 | TYPE: "CRD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "crd" 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res32x4_res8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res32x4_shuv1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,shuv1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res32x4_shuv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | 
TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res56_res20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 2.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/res110_res32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 2.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/vgg13_mv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | 
NAME: "" 3 | TAG: "dkd,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | 21 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dot/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,vgg13,vgg8" 4 | PROJECT: "dot_cifar" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | 
LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "dot" 19 | DOT: 20 | DELTA: 0.075 21 | -------------------------------------------------------------------------------- /configs/cifar100/dot/res32x4_shuv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,res32x4,shuv2" 4 | PROJECT: "dot_cifar" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "dot" 19 | DOT: 20 | DELTA: 0.075 21 | -------------------------------------------------------------------------------- /configs/cifar100/dot/res32x4_res8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,res32x4,res8x4" 4 | PROJECT: "dot_cifar" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "dot" 19 | DOT: 20 | DELTA: 0.075 21 | -------------------------------------------------------------------------------- /configs/cifar100/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVIEWKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | REVIEWKD: 10 | REVIEWKD_WEIGHT: 5.0 11 | SOLVER: 12 | BATCH_SIZE: 64 13 | EPOCHS: 240 14 | LR: 0.05 15 | LR_DECAY_STAGES: [150, 180, 210] 16 | LR_DECAY_RATE: 0.1 17 | WEIGHT_DECAY: 0.0005 18 | MOMENTUM: 0.9 19 | TYPE: 
"SGD" 20 | 21 | -------------------------------------------------------------------------------- /configs/tiny_imagenet/dot/r18_mv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,r18,mv2" 4 | PROJECT: "dot_tinyimagenet" 5 | DATASET: 6 | TYPE: "tiny_imagenet" 7 | NUM_WORKERS: 16 8 | DISTILLER: 9 | TYPE: "KD" 10 | TEACHER: "ResNet18" 11 | STUDENT: "MobileNetV2" 12 | SOLVER: 13 | BATCH_SIZE: 256 14 | EPOCHS: 200 15 | LR: 0.2 16 | LR_DECAY_STAGES: [60, 120, 160] 17 | LR_DECAY_RATE: 0.1 18 | WEIGHT_DECAY: 0.0005 19 | MOMENTUM: 0.9 20 | TYPE: "SGD" 21 | TRAINER: "dot" 22 | DOT: 23 | DELTA: 0.075 24 | -------------------------------------------------------------------------------- /configs/tiny_imagenet/dot/r18_shuv2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,r18,shuv2" 4 | PROJECT: "dot_tinyimagenet" 5 | DATASET: 6 | TYPE: "tiny_imagenet" 7 | NUM_WORKERS: 16 8 | DISTILLER: 9 | TYPE: "KD" 10 | TEACHER: "ResNet18" 11 | STUDENT: "ShuffleV2" 12 | SOLVER: 13 | BATCH_SIZE: 256 14 | EPOCHS: 200 15 | LR: 0.2 16 | LR_DECAY_STAGES: [60, 120, 160] 17 | LR_DECAY_RATE: 0.1 18 | WEIGHT_DECAY: 0.0005 19 | MOMENTUM: 0.9 20 | TYPE: "SGD" 21 | TRAINER: "dot" 22 | DOT: 23 | DELTA: 0.075 24 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "AT" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 
23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 10 26 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "AT" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 10 26 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/ofd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "ofd,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "OFD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 128 16 | EPOCHS: 100 17 | LR: 0.05 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | OFD: 24 | LOSS: 25 | FEAT_WEIGHT: 0.0001 26 | LOG: 27 | TENSORBOARD_FREQ: 50 28 | SAVE_CHECKPOINT_FREQ: 10 29 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/dkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "DKD" 12 | TEACHER: "ResNet34" 13 | STUDENT: 
"ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 10 26 | DKD: 27 | CE_WEIGHT: 1.0 28 | BETA: 0.5 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/kd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1 25 | LOSS: 26 | CE_WEIGHT: 0.5 27 | KD_WEIGHT: 0.5 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 10 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/dkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "DKD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 10 26 | DKD: 27 | CE_WEIGHT: 1.0 28 | BETA: 2.0 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- 
/configs/imagenet/r34_r18/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "CRD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "crd" 24 | CRD: 25 | FEAT: 26 | STUDENT_DIM: 512 27 | TEACHER_DIM: 512 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 10 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/kd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1 25 | LOSS: 26 | CE_WEIGHT: 0.5 27 | KD_WEIGHT: 0.5 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 10 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "CRD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | 
BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "crd" 24 | CRD: 25 | FEAT: 26 | STUDENT_DIM: 1024 27 | TEACHER_DIM: 2048 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 10 31 | -------------------------------------------------------------------------------- /detection/configs/DKD/DKD-MV2-R50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/DKD-MV2-R50 3 | MODEL: 4 | BACKBONE: 5 | NAME: "build_mobilenetv2_fpn_backbone" 6 | FREEZE_AT: 0 7 | WEIGHTS: "pretrained/mv2-r50.pth" 8 | MOBILENETV2: 9 | OUT_FEATURES: ["m2", "m3", "m4", "m5"] 10 | FPN: 11 | IN_FEATURES: ["m2", "m3", "m4", "m5"] 12 | PROPOSAL_GENERATOR: 13 | NAME: "RPN" 14 | ROI_HEADS: 15 | NAME: "StandardROIHeads" 16 | 17 | TEACHER: 18 | MODEL: 19 | RESNETS: 20 | DEPTH: 50 21 | KD: 22 | TYPE: "DKD" 23 | 24 | SOLVER: 25 | IMS_PER_BATCH: 8 26 | BASE_LR: 0.01 27 | MAX_ITER: 180000 28 | STEPS: 29 | - 120000 30 | - 160000 31 | 32 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/dot.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,dot,res34,res18" 4 | PROJECT: "dot_imagenet" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "dot" 24 | DOT: 25 | DELTA: 0.09 26 | KD: 27 | TEMPERATURE: 1 28 | LOSS: 29 | CE_WEIGHT: 0.5 30 | KD_WEIGHT: 0.5 31 | LOG: 32 | TENSORBOARD_FREQ: 50 33 | SAVE_CHECKPOINT_FREQ: 10 34 | 
"""Distiller registry: maps the config key ``DISTILLER.TYPE`` to its class.

Each distiller wraps a (student, teacher) pair; ``Vanilla`` is the
no-distillation baseline that trains the student alone.
"""
from ._base import Vanilla
from .KD import KD
from .AT import AT
from .OFD import OFD
from .RKD import RKD
from .FitNet import FitNet
from .KDSVD import KDSVD
# NOTE(review): a CRD.py module is not visible in the distillers/ listing of
# this tree — confirm .CRD exists, otherwise this import fails at package load.
from .CRD import CRD
from .NST import NST
from .PKT import PKT
from .SP import SP
from .VID import VID
from .ReviewKD import ReviewKD
from .DKD import DKD

# Keys are the (upper-case) values accepted for DISTILLER.TYPE in the
# experiment configs (see configs/**/*.yaml).
distiller_dict = {
    "NONE": Vanilla,
    "KD": KD,
    "AT": AT,
    "OFD": OFD,
    "RKD": RKD,
    "FITNET": FitNet,
    "KDSVD": KDSVD,
    "CRD": CRD,
    "NST": NST,
    "PKT": PKT,
    "SP": SP,
    "VID": VID,
    "REVIEWKD": ReviewKD,
    "DKD": DKD,
}
13 | AUTHOR = "zhaoborui" 14 | DOCLINES = __doc__ 15 | 16 | setuptools.setup( 17 | name=DISTNAME, 18 | packages=setuptools.find_packages(), 19 | version="0.1", 20 | description=DESCRIPTION, 21 | long_description=DOCLINES, 22 | long_description_content_type="text/markdown", 23 | author=AUTHOR, 24 | ) 25 | -------------------------------------------------------------------------------- /detection/configs/DKD/ReviewDKD-MV2-R50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewDKD-MV2-R50 3 | MODEL: 4 | BACKBONE: 5 | NAME: "build_mobilenetv2_fpn_backbone" 6 | FREEZE_AT: 0 7 | WEIGHTS: "pretrained/mv2-r50.pth" 8 | MOBILENETV2: 9 | OUT_FEATURES: ["m2", "m3", "m4", "m5"] 10 | FPN: 11 | IN_FEATURES: ["m2", "m3", "m4", "m5"] 12 | 13 | PROPOSAL_GENERATOR: 14 | NAME: "RPN" 15 | ROI_HEADS: 16 | NAME: "StandardROIHeads" 17 | 18 | TEACHER: 19 | MODEL: 20 | RESNETS: 21 | DEPTH: 50 22 | KD: 23 | TYPE: "ReviewDKD" 24 | REVIEWKD: 25 | LOSS_WEIGHT: 2.0 26 | 27 | SOLVER: 28 | IMS_PER_BATCH: 8 29 | BASE_LR: 0.01 30 | MAX_ITER: 180000 31 | STEPS: 32 | - 120000 33 | - 160000 34 | 35 | -------------------------------------------------------------------------------- /detection/configs/ReviewKD/ReviewKD-MV2-R50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewKD-MV2-R50 3 | MODEL: 4 | BACKBONE: 5 | NAME: "build_mobilenetv2_fpn_backbone" 6 | FREEZE_AT: 0 7 | WEIGHTS: "pretrained/mv2-r50.pth" 8 | MOBILENETV2: 9 | OUT_FEATURES: ["m2", "m3", "m4", "m5"] 10 | FPN: 11 | IN_FEATURES: ["m2", "m3", "m4", "m5"] 12 | 13 | PROPOSAL_GENERATOR: 14 | NAME: "RPN" 15 | ROI_HEADS: 16 | NAME: "StandardROIHeads" 17 | 18 | TEACHER: 19 | MODEL: 20 | RESNETS: 21 | DEPTH: 50 22 | KD: 23 | TYPE: "ReviewKD" 24 | REVIEWKD: 25 | LOSS_WEIGHT: 2.0 26 | 27 | SOLVER: 28 | IMS_PER_BATCH: 8 29 | BASE_LR: 0.01 30 | 
MAX_ITER: 180000 31 | STEPS: 32 | - 120000 33 | - 160000 34 | 35 | -------------------------------------------------------------------------------- /detection/configs/DKD/ReviewDKD-R50-R101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewDKD-R50-R101 3 | MODEL: 4 | BACKBONE: 5 | NAME: "build_resnet_fpn_backbone_kd" 6 | WEIGHTS: "pretrained/r50-r101.pth" 7 | RESNETS: 8 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 9 | DEPTH: 50 10 | FPN: 11 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 12 | PROPOSAL_GENERATOR: 13 | NAME: "RPN" 14 | ROI_HEADS: 15 | NAME: "StandardROIHeads" 16 | 17 | TEACHER: 18 | MODEL: 19 | RESNETS: 20 | DEPTH: 101 21 | KD: 22 | TYPE: "ReviewDKD" 23 | REVIEWKD: 24 | LOSS_WEIGHT: 1.0 25 | 26 | SOLVER: 27 | IMS_PER_BATCH: 8 28 | BASE_LR: 0.01 29 | MAX_ITER: 180000 30 | STEPS: 31 | - 120000 32 | - 160000 33 | 34 | -------------------------------------------------------------------------------- /detection/configs/ReviewKD/ReviewKD-R50-R101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewKD-R50-R101 3 | MODEL: 4 | BACKBONE: 5 | NAME: "build_resnet_fpn_backbone_kd" 6 | WEIGHTS: "pretrained/r50-r101.pth" 7 | RESNETS: 8 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 9 | DEPTH: 50 10 | FPN: 11 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 12 | PROPOSAL_GENERATOR: 13 | NAME: "RPN" 14 | ROI_HEADS: 15 | NAME: "StandardROIHeads" 16 | 17 | TEACHER: 18 | MODEL: 19 | RESNETS: 20 | DEPTH: 101 21 | KD: 22 | TYPE: "ReviewKD" 23 | REVIEWKD: 24 | LOSS_WEIGHT: 1.0 25 | 26 | SOLVER: 27 | IMS_PER_BATCH: 8 28 | BASE_LR: 0.01 29 | MAX_ITER: 180000 30 | STEPS: 31 | - 120000 32 | - 160000 33 | 34 | -------------------------------------------------------------------------------- /detection/configs/ReviewKD/ReviewKD-MV2-R50-Mask.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewKD-MV2-R50-Mask 3 | MODEL: 4 | MASK_ON: True 5 | BACKBONE: 6 | NAME: "build_mobilenetv2_fpn_backbone" 7 | FREEZE_AT: 0 8 | WEIGHTS: "pretrained/mv2-r50mask.pth" 9 | MOBILENETV2: 10 | OUT_FEATURES: ["m2", "m3", "m4", "m5"] 11 | FPN: 12 | IN_FEATURES: ["m2", "m3", "m4", "m5"] 13 | PROPOSAL_GENERATOR: 14 | NAME: "RPN" 15 | ROI_HEADS: 16 | NAME: "StandardROIHeads" 17 | 18 | TEACHER: 19 | MODEL: 20 | RESNETS: 21 | DEPTH: 50 22 | KD: 23 | TYPE: "ReviewKD" 24 | REVIEWKD: 25 | LOSS_WEIGHT: 1.0 26 | 27 | SOLVER: 28 | IMS_PER_BATCH: 8 29 | BASE_LR: 0.01 30 | MAX_ITER: 180000 31 | STEPS: 32 | - 120000 33 | - 160000 34 | 35 | -------------------------------------------------------------------------------- /detection/configs/ReviewKD/ReviewKD-R50-R101-Mask.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewKD-R50-R101-Mask 3 | MODEL: 4 | MASK_ON: True 5 | BACKBONE: 6 | NAME: "build_resnet_fpn_backbone_kd" 7 | WEIGHTS: "pretrained/r50-r101mask.pth" 8 | RESNETS: 9 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 10 | DEPTH: 50 11 | FPN: 12 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 13 | PROPOSAL_GENERATOR: 14 | NAME: "RPN" 15 | ROI_HEADS: 16 | NAME: "StandardROIHeads" 17 | 18 | TEACHER: 19 | MODEL: 20 | RESNETS: 21 | DEPTH: 101 22 | KD: 23 | TYPE: "ReviewKD" 24 | REVIEWKD: 25 | LOSS_WEIGHT: 0.8 26 | 27 | SOLVER: 28 | IMS_PER_BATCH: 8 29 | BASE_LR: 0.01 30 | MAX_ITER: 180000 31 | STEPS: 32 | - 120000 33 | - 160000 34 | 35 | -------------------------------------------------------------------------------- /detection/configs/DKD/DKD-R18-R101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/DKD-R18-R101 3 | INPUT: 4 | FORMAT: 
'RGB' 5 | MODEL: 6 | PIXEL_STD: [57.375, 57.120, 58.395] 7 | BACKBONE: 8 | NAME: "build_resnet_fpn_backbone_kd" 9 | WEIGHTS: "pretrained/r18-r101.pth" 10 | RESNETS: 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | DEPTH: 18 13 | RES2_OUT_CHANNELS: 64 14 | FPN: 15 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 16 | PROPOSAL_GENERATOR: 17 | NAME: "RPN" 18 | ROI_HEADS: 19 | NAME: "StandardROIHeads" 20 | 21 | TEACHER: 22 | MODEL: 23 | RESNETS: 24 | DEPTH: 101 25 | KD: 26 | TYPE: "DKD" 27 | 28 | SOLVER: 29 | IMS_PER_BATCH: 8 30 | BASE_LR: 0.01 31 | MAX_ITER: 180000 32 | STEPS: 33 | - 120000 34 | - 160000 35 | 36 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVIEWKD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 10 26 | REVIEWKD: 27 | CE_WEIGHT: 1.0 28 | REVIEWKD_WEIGHT: 1.0 29 | WARMUP_EPOCHS: 1 30 | SHAPES: [1,7,14,28,56] 31 | OUT_SHAPES: [1,7,14,28,56] 32 | IN_CHANNELS: [64,128,256,512,512] 33 | OUT_CHANNELS: [64,128,256,512,512] 34 | STU_PREACT: True 35 | -------------------------------------------------------------------------------- /detection/configs/DKD/ReviewDKD-R18-R101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewDKD-R18-R101 3 | INPUT: 4 | FORMAT: 'RGB' 5 | MODEL: 6 | PIXEL_STD: [57.375, 57.120, 58.395] 7 | BACKBONE: 8 | NAME: 
"build_resnet_fpn_backbone_kd" 9 | WEIGHTS: "pretrained/r18-r101.pth" 10 | RESNETS: 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | DEPTH: 18 13 | RES2_OUT_CHANNELS: 64 14 | FPN: 15 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 16 | PROPOSAL_GENERATOR: 17 | NAME: "RPN" 18 | ROI_HEADS: 19 | NAME: "StandardROIHeads" 20 | 21 | TEACHER: 22 | MODEL: 23 | RESNETS: 24 | DEPTH: 101 25 | KD: 26 | TYPE: "ReviewDKD" 27 | REVIEWKD: 28 | LOSS_WEIGHT: 1.2 29 | 30 | SOLVER: 31 | IMS_PER_BATCH: 8 32 | BASE_LR: 0.01 33 | MAX_ITER: 180000 34 | STEPS: 35 | - 120000 36 | - 160000 37 | 38 | -------------------------------------------------------------------------------- /detection/configs/ReviewKD/ReviewKD-R18-R101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-Distillation.yaml" 2 | OUTPUT_DIR: output/ReviewKD-R18-R101 3 | INPUT: 4 | FORMAT: 'RGB' 5 | MODEL: 6 | PIXEL_STD: [57.375, 57.120, 58.395] 7 | BACKBONE: 8 | NAME: "build_resnet_fpn_backbone_kd" 9 | WEIGHTS: "pretrained/r18-r101.pth" 10 | RESNETS: 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | DEPTH: 18 13 | RES2_OUT_CHANNELS: 64 14 | FPN: 15 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 16 | PROPOSAL_GENERATOR: 17 | NAME: "RPN" 18 | ROI_HEADS: 19 | NAME: "StandardROIHeads" 20 | 21 | TEACHER: 22 | MODEL: 23 | RESNETS: 24 | DEPTH: 101 25 | KD: 26 | TYPE: "ReviewKD" 27 | REVIEWKD: 28 | LOSS_WEIGHT: 1.2 29 | 30 | SOLVER: 31 | IMS_PER_BATCH: 8 32 | BASE_LR: 0.01 33 | MAX_ITER: 180000 34 | STEPS: 35 | - 120000 36 | - 160000 37 | 38 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res50,mobilenetv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVIEWKD" 
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.roi_heads import build_roi_heads
from detectron2.checkpoint import DetectionCheckpointer

from torch import nn


class Teacher(nn.Module):
    """Container for a frozen teacher detector.

    Holds backbone + (optional) RPN and ROI heads; it defines no forward()
    of its own — the distillation code calls the sub-modules directly.
    """

    def __init__(self, backbone, proposal_generator, roi_heads):
        super().__init__()
        self.backbone = backbone
        self.proposal_generator = proposal_generator
        self.roi_heads = roi_heads


def build_teacher(cfg):
    """Build the teacher model described by ``cfg.TEACHER`` and freeze it.

    Args:
        cfg: root detectron2 CfgNode; only the ``TEACHER`` sub-config is used.

    Returns:
        Teacher: module whose parameters all have ``requires_grad = False``.
        For RetinaNet-style (one-stage) teachers, ``proposal_generator`` and
        ``roi_heads`` are ``None``.
    """
    teacher_cfg = cfg.TEACHER
    backbone = build_backbone(teacher_cfg)
    # One-stage (RetinaNet) architectures have no RPN / ROI heads.
    # (was: `if not 'Retina' in ...` — use the `not in` idiom)
    if "Retina" not in teacher_cfg.MODEL.META_ARCHITECTURE:
        # output_shape() was computed twice; hoist it.
        shapes = backbone.output_shape()
        proposal_generator = build_proposal_generator(teacher_cfg, shapes)
        roi_heads = build_roi_heads(teacher_cfg, shapes)
    else:
        proposal_generator = None
        roi_heads = None
    teacher = Teacher(backbone, proposal_generator, roi_heads)
    # The teacher is never updated during distillation.
    # NOTE(review): parameters are frozen here but the module is not switched
    # to .eval() — confirm callers put the teacher in eval mode so that
    # BatchNorm/dropout behave deterministically.
    for param in teacher.parameters():
        param.requires_grad = False
    return teacher
import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def kd_loss(logits_student, logits_teacher, temperature):
    """Temperature-scaled KL divergence between student and teacher logits."""
    log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # "batchmean" sums the divergence over classes and averages over the
    # batch (identical to reduction="none" -> .sum(1).mean()); the T^2
    # factor restores the gradient magnitude of the softened targets.
    return F.kl_div(log_pred_student, pred_teacher, reduction="batchmean") * (
        temperature**2
    )


class KD(Distiller):
    """Distilling the Knowledge in a Neural Network"""

    def __init__(self, student, teacher, cfg):
        super(KD, self).__init__(student, teacher)
        # softening temperature and loss weights from the experiment config
        self.temperature = cfg.KD.TEMPERATURE
        self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT
        self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, _ = self.student(image)
        # the teacher is frozen — no gradients through it
        with torch.no_grad():
            logits_teacher, _ = self.teacher(image)

        hard_loss = F.cross_entropy(logits_student, target)
        soft_loss = kd_loss(logits_student, logits_teacher, self.temperature)
        losses_dict = {
            "loss_ce": self.ce_loss_weight * hard_loss,
            "loss_kd": self.kd_loss_weight * soft_loss,
        }
        return logits_student, losses_dict
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvReg(nn.Module):
    """Convolutional regression.

    Maps a student feature map of shape ``s_shape`` onto a teacher feature
    map of shape ``t_shape`` (both NCHW) with a single conv + BN (+ ReLU):
    stride-2 conv when the student is twice as large, transposed conv when
    it is half as large, and a valid conv for small size gaps.

    Args:
        s_shape: (N, C, H, W) of the student feature.
        t_shape: (N, C, H, W) of the teacher feature.
        use_relu: apply ReLU after BatchNorm (default True).

    Raises:
        NotImplementedError: if the student map is smaller than the teacher
            map by other than exactly a factor of two.
    """

    def __init__(self, s_shape, t_shape, use_relu=True):
        super(ConvReg, self).__init__()
        self.use_relu = use_relu
        s_N, s_C, s_H, s_W = s_shape
        t_N, t_C, t_H, t_W = t_shape
        if s_H == 2 * t_H:
            # student is 2x larger: stride-2 conv downsamples to teacher size
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
        elif s_H * 2 == t_H:
            # student is 2x smaller: transposed conv upsamples to teacher size
            self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
        elif s_H >= t_H:
            # mild mismatch: a valid conv with kernel (1+dH, 1+dW) shrinks exactly
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1 + s_H - t_H, 1 + s_W - t_W))
        else:
            # Fix: the original `raise NotImplemented(...)` raised a TypeError,
            # since NotImplemented is a sentinel value, not an exception class.
            raise NotImplementedError(
                "student size {}, teacher size {}".format(s_H, t_H)
            )
        self.bn = nn.BatchNorm2d(t_C)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.use_relu:
            return self.relu(self.bn(x))
        else:
            return self.bn(x)


def get_feat_shapes(student, teacher, input_size):
    """Run one dummy (1, 3, *input_size) batch through both models and return
    ([student feat shapes], [teacher feat shapes]) from their "feats" lists."""
    data = torch.randn(1, 3, *input_size)
    with torch.no_grad():
        _, feat_s = student(data)
        _, feat_t = teacher(data)
    feat_s_shapes = [f.shape for f in feat_s["feats"]]
    feat_t_shapes = [f.shape for f in feat_t["feats"]]
    return feat_s_shapes, feat_t_shapes
import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def single_stage_at_loss(f_s, f_t, p):
    """Attention-transfer loss between one student and one teacher feature map."""

    def _attention(feat):
        # p-th power averaged over channels -> spatial attention map,
        # flattened per sample and L2-normalized.
        return F.normalize(feat.pow(p).mean(1).reshape(feat.size(0), -1))

    # Pool the spatially larger map down so both attention maps align.
    s_H, t_H = f_s.shape[2], f_t.shape[2]
    if s_H > t_H:
        f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
    elif s_H < t_H:
        f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
    return (_attention(f_s) - _attention(f_t)).pow(2).mean()


def at_loss(g_s, g_t, p):
    """Sum per-stage attention losses over paired student/teacher feature lists."""
    total = 0
    for f_s, f_t in zip(g_s, g_t):
        total = total + single_stage_at_loss(f_s, f_t, p)
    return total


class AT(Distiller):
    """
    Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer
    src code: https://github.com/szagoruyko/attention-transfer
    """

    def __init__(self, student, teacher, cfg):
        super(AT, self).__init__(student, teacher)
        # attention exponent and loss weights from the experiment config
        self.p = cfg.AT.P
        self.ce_loss_weight = cfg.AT.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.AT.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)
        # frozen teacher — no gradients through it
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        ce = F.cross_entropy(logits_student, target)
        # the stem feature (index 0) is skipped; only later stages are matched
        at = at_loss(
            feature_student["feats"][1:], feature_teacher["feats"][1:], self.p
        )
        return logits_student, {
            "loss_ce": self.ce_loss_weight * ce,
            "loss_kd": self.feat_loss_weight * at,
        }
from .resnet import build_resnet_backbone_kd
# NOTE(review): resnet.py is not visible in the backbone/ listing of this
# tree — confirm the module exists alongside fpn.py and mobilenetv2.py.
from .mobilenetv2 import build_mobilenetv2_backbone
from detectron2.modeling.backbone import BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
from detectron2.layers import Conv2d, ShapeSpec, get_norm



@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone_kd(cfg, input_shape: ShapeSpec):
    # add ResNet 18
    """
    ResNet + FPN backbone for the KD detectors; the bottom-up part is the
    project's ResNet variant (build_resnet_backbone_kd), not detectron2's
    stock builder — presumably to expose KD features / ResNet-18, confirm.

    Args:
        cfg: a detectron2 CfgNode

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone_kd(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        # extra top pyramid level produced by max-pooling the last FPN output
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone

@BACKBONE_REGISTRY.register()
def build_mobilenetv2_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    MobileNetV2 + FPN backbone (student side of the MV2-R50 configs).

    Args:
        cfg: a detectron2 CfgNode

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_mobilenetv2_backbone(cfg, input_shape)
    in_features = cfg.MODEL.FPN.IN_FEATURES
    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=cfg.MODEL.FPN.NORM,
        # extra top pyramid level produced by max-pooling the last FPN output
        top_block=LastLevelMaxPool(),
        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
    )
    return backbone
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | 25 | detectron2 26 | 27 | Copyright 2020 - present, Facebook, Inc 28 | 29 | Licensed under the Apache License, Version 2.0 (the "License"); 30 | you may not use this file except in compliance with the License. 31 | You may obtain a copy of the License at 32 | 33 | http://www.apache.org/licenses/LICENSE-2.0 34 | 35 | Unless required by applicable law or agreed to in writing, software 36 | distributed under the License is distributed on an "AS IS" BASIS, 37 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 38 | See the License for the specific language governing permissions and 39 | limitations under the License. 
import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


def nst_loss(g_s, g_t):
    """Sum the per-stage NST losses over paired student/teacher feature lists."""
    return sum([single_stage_nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])


def single_stage_nst_loss(f_s, f_t):
    """NST loss for one feature pair: squared MMD between the channel-wise
    activation distributions, with a polynomial kernel."""
    # Match spatial sizes by average-pooling the larger map down to the smaller.
    s_H, t_H = f_s.shape[2], f_t.shape[2]
    if s_H > t_H:
        f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
    elif s_H < t_H:
        f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))

    # Flatten each channel's activation map to (N, C, H*W) and L2-normalize
    # along the spatial dimension, so every channel becomes a unit vector.
    f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
    f_s = F.normalize(f_s, dim=2)
    f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
    f_t = F.normalize(f_t, dim=2)

    # Squared MMD: k(t,t) + k(s,s) - 2*k(s,t). The teacher-teacher term is
    # detached — it is constant with respect to the student's parameters.
    return (
        poly_kernel(f_t, f_t).mean().detach()
        + poly_kernel(f_s, f_s).mean()
        - 2 * poly_kernel(f_s, f_t).mean()
    )


def poly_kernel(a, b):
    """Degree-2 polynomial kernel (a·b)^2, computed pairwise over the channel
    dimension via broadcasting; returns an (N, C, C) tensor."""
    a = a.unsqueeze(1)
    b = b.unsqueeze(2)
    res = (a * b).sum(-1).pow(2)
    return res


class NST(Distiller):
    """
    Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
    """

    def __init__(self, student, teacher, cfg):
        super(NST, self).__init__(student, teacher)
        # loss weights from the experiment config
        self.ce_loss_weight = cfg.NST.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.NST.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)
        # frozen teacher — no gradients through it
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        # the stem feature (index 0) is skipped; only later stages are matched
        loss_feat = self.feat_loss_weight * nst_loss(
            feature_student["feats"][1:], feature_teacher["feats"][1:]
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict
class Distiller(nn.Module):
    """Base wrapper for a distillation method.

    Holds a trainable student and a teacher that is kept in eval mode;
    concrete methods override ``forward_train`` (and, if they add modules,
    ``get_learnable_parameters`` / ``get_extra_parameters``).
    """

    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.student = student
        self.teacher = teacher

    def train(self, mode=True):
        """Set train/eval mode on all children, then pin the teacher to eval."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for child in self.children():
            child.train(mode)
        # The teacher never trains, regardless of the requested mode.
        self.teacher.eval()
        return self

    def get_learnable_parameters(self):
        # Only the student is optimized by default; distillers that introduce
        # extra trainable modules should re-implement this.
        return [param for _, param in self.student.named_parameters()]

    def get_extra_parameters(self):
        # Count of extra parameters introduced by the distiller (none here).
        return 0

    def forward_train(self, **kwargs):
        # Training-time forward for the distillation method.
        raise NotImplementedError()

    def forward_test(self, image):
        # At test time only the student's logits are needed.
        return self.student(image)[0]

    def forward(self, **kwargs):
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)


class Vanilla(nn.Module):
    """Plain supervised training of the student (no distillation)."""

    def __init__(self, student):
        super(Vanilla, self).__init__()
        self.student = student

    def get_learnable_parameters(self):
        return [param for _, param in self.student.named_parameters()]

    def forward_train(self, image, target, **kwargs):
        """Return (logits, {"ce": cross-entropy loss}) for one batch."""
        logits_student, _ = self.student(image)
        ce_loss = F.cross_entropy(logits_student, target)
        return logits_student, {"ce": ce_loss}

    def forward_test(self, image):
        return self.student(image)[0]

    def forward(self, **kwargs):
        if not self.training:
            return self.forward_test(kwargs["image"])
        return self.forward_train(**kwargs)
def pkt_loss(f_s, f_t, eps=1e-7):
    """Probabilistic Knowledge Transfer loss.

    Compares the row-normalized pairwise cosine-similarity distributions of
    the student batch features ``f_s`` and teacher batch features ``f_t``
    via an element-wise mean KL divergence (teacher || student).
    """

    def _unit_rows(feat):
        # L2-normalize each row; zero out NaNs produced by all-zero rows.
        norm = torch.sqrt(torch.sum(feat**2, dim=1, keepdim=True))
        feat = feat / (norm + eps)
        feat[feat != feat] = 0
        return feat

    stu = _unit_rows(f_s)
    tea = _unit_rows(f_t)

    # Cosine-similarity matrices, rescaled from [-1, 1] into [0, 1].
    sim_s = (torch.mm(stu, stu.transpose(0, 1)) + 1.0) / 2.0
    sim_t = (torch.mm(tea, tea.transpose(0, 1)) + 1.0) / 2.0

    # Row-normalize so each row is a probability distribution.
    sim_s = sim_s / torch.sum(sim_s, dim=1, keepdim=True)
    sim_t = sim_t / torch.sum(sim_t, dim=1, keepdim=True)

    # Mean element-wise KL divergence between the two distributions.
    return torch.mean(sim_t * torch.log((sim_t + eps) / (sim_s + eps)))
class MobileNetV1(nn.Module):
    """MobileNetV1 backbone (ImageNet variant: 224x224 input, 1000 classes).

    ``self.model`` is a flat ``nn.Sequential`` of 15 entries:
    index 0 is a stride-2 conv-bn-relu stem, indices 1-13 are depthwise
    separable blocks (``conv_dw``), and index 14 is a 7x7 average pool.
    ``forward`` additionally exposes intermediate features for distillation.
    """

    def __init__(self, **kwargs):
        super(MobileNetV1, self).__init__()

        def conv_bn(inp, oup, stride):
            # Standard 3x3 conv -> BN -> ReLU stem block.
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        def conv_dw(inp, oup, stride):
            # Depthwise 3x3 conv (groups=inp) followed by pointwise 1x1 conv,
            # each with BN + ReLU.  The trailing ReLU is the last entry, which
            # lets forward() strip it with [:-1] to obtain pre-activations.
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(3, 32, 2),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x, is_feat=False):
        # Each featN is a PRE-activation: the stage's last block is run with
        # its final ReLU stripped ([:-1]), and the ReLU is re-applied manually
        # before the next stage.  Stage boundaries: indices 3, 5, 11, 13.
        feat1 = self.model[3][:-1](self.model[0:3](x))
        feat2 = self.model[5][:-1](self.model[4:5](F.relu(feat1)))
        feat3 = self.model[11][:-1](self.model[6:11](F.relu(feat2)))
        feat4 = self.model[13][:-1](self.model[12:13](F.relu(feat3)))
        feat5 = self.model[14](F.relu(feat4))  # 7x7 average pool
        avg = feat5.reshape(-1, 1024)
        out = self.fc(avg)

        # Feature dict consumed by distillers: post-ReLU stage features,
        # pre-activation stage features, and the pooled embedding.
        feats = {}
        feats["pooled_feat"] = avg
        feats["feats"] = [F.relu(feat1), F.relu(feat2), F.relu(feat3), F.relu(feat4)]
        feats["preact_feats"] = [feat1, feat2, feat3, feat4]
        return out, feats

    def get_bn_before_relu(self):
        # In a conv_dw block the second-to-last module ([-2]) is the BN that
        # directly precedes the final ReLU (used by OFD-style distillers).
        bn1 = self.model[3][-2]
        bn2 = self.model[5][-2]
        bn3 = self.model[11][-2]
        bn4 = self.model[13][-2]
        return [bn1, bn2, bn3, bn4]

    def get_stage_channels(self):
        # Channel counts of feat1..feat4 (outputs of model[3], [5], [11], [13]).
        return [128, 256, 512, 1024]
def get_dataset(cfg):
    """Build dataloaders for the dataset named by ``cfg.DATASET.TYPE``.

    Returns ``(train_loader, val_loader, num_data, num_classes)``.  CRD-style
    distillers get the contrastive "sample" loader variant, which additionally
    needs the negative-sampling parameters from ``cfg.CRD``.

    Raises ``NotImplementedError`` for an unknown dataset type.
    """
    dataset_type = cfg.DATASET.TYPE
    batch_size = cfg.SOLVER.BATCH_SIZE
    val_batch_size = cfg.DATASET.TEST.BATCH_SIZE
    num_workers = cfg.DATASET.NUM_WORKERS

    if dataset_type == "cifar100":
        num_classes = 100
        if cfg.DISTILLER.TYPE == "CRD":
            train_loader, val_loader, num_data = get_cifar100_dataloaders_sample(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
                k=cfg.CRD.NCE.K,
                mode=cfg.CRD.MODE,
            )
        else:
            train_loader, val_loader, num_data = get_cifar100_dataloaders(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
            )
    elif dataset_type == "imagenet":
        num_classes = 1000
        if cfg.DISTILLER.TYPE == "CRD":
            train_loader, val_loader, num_data = get_imagenet_dataloaders_sample(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
                k=cfg.CRD.NCE.K,
            )
        else:
            train_loader, val_loader, num_data = get_imagenet_dataloaders(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
            )
    elif dataset_type == "tiny_imagenet":
        num_classes = 200
        if cfg.DISTILLER.TYPE in ("CRD", "CRDKD"):
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader_sample(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
                k=cfg.CRD.NCE.K,
            )
        else:
            train_loader, val_loader, num_data = get_tinyimagenet_dataloader(
                batch_size=batch_size,
                val_batch_size=val_batch_size,
                num_workers=num_workers,
            )
    else:
        raise NotImplementedError(dataset_type)

    return train_loader, val_loader, num_data, num_classes
per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | 33 | TEACHER: 34 | MODEL: 35 | META_ARCHITECTURE: "GeneralizedRCNN" 36 | BACKBONE: 37 | NAME: "build_resnet_fpn_backbone" 38 | RESNETS: 39 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 40 | FPN: 41 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 42 | ANCHOR_GENERATOR: 43 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 44 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 45 | RPN: 46 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 47 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 48 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 49 | # Detectron1 uses 2000 proposals per-batch, 50 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 51 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
def _pdist(e, squared, eps):
    """Pairwise distances between rows of ``e`` with a zeroed diagonal.

    Returns squared Euclidean distances if ``squared`` is True, otherwise
    Euclidean; values are clamped to ``eps`` before the sqrt for stability.
    """
    sq_norms = e.pow(2).sum(dim=1)
    dist = (sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * (e @ e.t())).clamp(min=eps)
    if not squared:
        dist = dist.sqrt()
    dist = dist.clone()
    n = len(e)
    dist[range(n), range(n)] = 0
    return dist


def rkd_loss(f_s, f_t, squared=False, eps=1e-12, distance_weight=25, angle_weight=50):
    """Relational KD loss: distance relation + angle relation terms.

    Both student and teacher pairwise structures are compared with a smooth
    L1 loss; teacher-side statistics are computed under ``no_grad``.
    """
    stu = f_s.view(f_s.shape[0], -1)
    tea = f_t.view(f_t.shape[0], -1)

    # Distance term: normalize each pairwise-distance matrix by its mean
    # positive entry, then match student to teacher.
    with torch.no_grad():
        t_dist = _pdist(tea, squared, eps)
        t_dist = t_dist / t_dist[t_dist > 0].mean()
    s_dist = _pdist(stu, squared, eps)
    s_dist = s_dist / s_dist[s_dist > 0].mean()
    loss_distance = F.smooth_l1_loss(s_dist, t_dist)

    # Angle term: cosines of angles formed by every triplet of embeddings.
    with torch.no_grad():
        t_diff = tea.unsqueeze(0) - tea.unsqueeze(1)
        t_diff = F.normalize(t_diff, p=2, dim=2)
        t_angle = torch.bmm(t_diff, t_diff.transpose(1, 2)).view(-1)
    s_diff = stu.unsqueeze(0) - stu.unsqueeze(1)
    s_diff = F.normalize(s_diff, p=2, dim=2)
    s_angle = torch.bmm(s_diff, s_diff.transpose(1, 2)).view(-1)
    loss_angle = F.smooth_l1_loss(s_angle, t_angle)

    return distance_weight * loss_distance + angle_weight * loss_angle
= cat_mask(pred_student, gt_mask, other_mask) 14 | pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask) 15 | log_pred_student = torch.log(pred_student) 16 | tckd_loss = ( 17 | F.kl_div(log_pred_student, pred_teacher, size_average=False) 18 | * (temperature**2) 19 | / target.shape[0] 20 | ) 21 | pred_teacher_part2 = F.softmax( 22 | logits_teacher / temperature - 1000.0 * gt_mask, dim=1 23 | ) 24 | log_pred_student_part2 = F.log_softmax( 25 | logits_student / temperature - 1000.0 * gt_mask, dim=1 26 | ) 27 | nckd_loss = ( 28 | F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False) 29 | * (temperature**2) 30 | / target.shape[0] 31 | ) 32 | return alpha * tckd_loss + beta * nckd_loss 33 | 34 | 35 | def _get_gt_mask(logits, target): 36 | target = target.reshape(-1) 37 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 38 | return mask 39 | 40 | 41 | def _get_other_mask(logits, target): 42 | target = target.reshape(-1) 43 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 44 | return mask 45 | 46 | 47 | def cat_mask(t, mask1, mask2): 48 | t1 = (t * mask1).sum(dim=1, keepdims=True) 49 | t2 = (t * mask2).sum(1, keepdims=True) 50 | rt = torch.cat([t1, t2], dim=1) 51 | return rt 52 | 53 | 54 | class DKD(Distiller): 55 | """Decoupled Knowledge Distillation(CVPR 2022)""" 56 | 57 | def __init__(self, student, teacher, cfg): 58 | super(DKD, self).__init__(student, teacher) 59 | self.ce_loss_weight = cfg.DKD.CE_WEIGHT 60 | self.alpha = cfg.DKD.ALPHA 61 | self.beta = cfg.DKD.BETA 62 | self.temperature = cfg.DKD.T 63 | self.warmup = cfg.DKD.WARMUP 64 | 65 | def forward_train(self, image, target, **kwargs): 66 | logits_student, _ = self.student(image) 67 | with torch.no_grad(): 68 | logits_teacher, _ = self.teacher(image) 69 | 70 | # losses 71 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 72 | loss_dkd = min(kwargs["epoch"] / self.warmup, 1.0) * dkd_loss( 73 | logits_student, 74 | 
def kdsvd_loss(g_s, g_t, k):
    """KDSVD loss over lists of student/teacher feature maps.

    For each stage, the top singular components of student and teacher
    features are extracted, sign-aligned, and compared across consecutive
    stages through an RBF correlation; only stage pairs (i-1, i) contribute.
    """
    v_sb = None  # student right-singular vectors from the previous stage
    v_tb = None  # teacher right-singular vectors from the previous stage
    losses = []
    for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t):
        u_t, s_t, v_t = svd(f_t, k)
        # Student keeps k+3 components before alignment prunes/permutes them.
        u_s, s_s, v_s = svd(f_s, k + 3)
        v_s, v_t = align_rsv(v_s, v_t)
        s_t = s_t.unsqueeze(1)
        # NOTE(review): both sides are scaled by the TEACHER singular values
        # (v_s * s_t, not s_s) — matches the original SSKD_SVD code; confirm
        # this is intentional before "fixing" it.
        v_t = v_t * s_t
        v_s = v_s * s_t

        if i > 0:
            # RBF similarity between current and previous stage components.
            s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8)
            t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8)

            # Squared error vs. the detached teacher correlation; non-finite
            # entries (from degenerate SVDs) are zeroed out.
            l2loss = (s_rbf - t_rbf.detach()).pow(2)
            l2loss = torch.where(
                torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss)
            )
            losses.append(l2loss.sum())

        v_tb = v_t
        v_sb = v_s

    # Normalize each stage loss by the batch size before summing.
    bsz = g_s[0].shape[0]
    losses = [l / bsz for l in losses]
    return sum(losses)


def svd(feat, n=1):
    """Batched truncated SVD of an NCHW feature map.

    Reshapes (N, C, H, W) to (N, C*H, W), runs torch.svd, scrubs NaN/Inf
    values, and L2-normalizes the top ``n`` components of u, s, and v.
    """
    size = feat.shape
    assert len(size) == 4

    x = feat.view(size[0], size[1] * size[2], size[3]).float()
    u, s, v = torch.svd(x)

    u = removenan(u)
    s = removenan(s)
    v = removenan(v)

    if n > 0:
        u = F.normalize(u[:, :, :n], dim=1)
        s = F.normalize(s[:, :n], dim=1)
        v = F.normalize(v[:, :, :n], dim=1)

    return u, s, v


def removenan(x):
    # Replace any non-finite entry (NaN/Inf) with zero.
    x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
    return x


def align_rsv(a, b):
    """Sign/permutation-align right singular vectors ``a`` to reference ``b``.

    Builds a mask that is sign(cos) where |cos(a_i, b_j)| attains its
    column-wise maximum and zero elsewhere, then projects ``a`` through it so
    each aligned student component matches its best teacher component's sign.
    """
    cosine = torch.matmul(a.transpose(-2, -1), b)
    max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True)
    mask = torch.where(
        torch.eq(max_abs_cosine, torch.abs(cosine)),
        torch.sign(cosine),
        torch.zeros_like(cosine),
    )
    a = torch.matmul(a, mask)
    return a, b


class KDSVD(Distiller):
    """
    Self-supervised Knowledge Distillation using Singular Value Decomposition
    original Tensorflow code: https://github.com/sseung0703/SSKD_SVD
    """

    def __init__(self, student, teacher, cfg):
        super(KDSVD, self).__init__(student, teacher)
        # k: number of singular components kept per stage (teacher side).
        self.k = cfg.KDSVD.K
        self.ce_loss_weight = cfg.KDSVD.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.KDSVD.LOSS.FEAT_WEIGHT

    def forward_train(self, image, target, **kwargs):
        """Return (student logits, {"loss_ce", "loss_kd"}) for one batch."""
        logits_student, feature_student = self.student(image)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        # The first stage feature is skipped ([1:]) for the SVD loss.
        loss_feat = self.feat_loss_weight * kdsvd_loss(
            feature_student["feats"][1:], feature_teacher["feats"][1:], self.k
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict
kernel_size=1), 19 | nn.Sigmoid(), 20 | ) 21 | else: 22 | self.att_conv = None 23 | nn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignore 24 | nn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignore 25 | 26 | def forward(self, x, y=None, shape=None): 27 | n,_,h,w = x.shape 28 | # transform student features 29 | x = self.conv1(x) 30 | if self.att_conv is not None: 31 | # upsample residual features 32 | shape = x.shape[-2:] 33 | y = F.interpolate(y, shape, mode="nearest") 34 | # fusion 35 | z = torch.cat([x, y], dim=1) 36 | z = self.att_conv(z) 37 | x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w)) 38 | # output 39 | y = self.conv2(x) 40 | return y, x 41 | 42 | class ReviewKD(nn.Module): 43 | def __init__( 44 | self, in_channels, out_channels, mid_channel 45 | ): 46 | super(ReviewKD, self).__init__() 47 | 48 | abfs = nn.ModuleList() 49 | 50 | for idx, in_channel in enumerate(in_channels): 51 | abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1)) 52 | 53 | 54 | self.abfs = abfs[::-1] 55 | 56 | def forward(self, student_features): 57 | x = student_features[::-1] 58 | results = [] 59 | out_features, res_features = self.abfs[0](x[0]) 60 | results.append(out_features) 61 | for features, abf in zip(x[1:], self.abfs[1:]): 62 | out_features, res_features = abf(features, res_features) 63 | results.insert(0, out_features) 64 | 65 | return results 66 | 67 | 68 | def build_kd_trans(cfg): 69 | in_channels = [256,256,256,256,256] 70 | out_channels = [256,256,256,256,256] 71 | mid_channel = 256 72 | model = ReviewKD(in_channels, out_channels, mid_channel) 73 | return model 74 | 75 | def hcl(fstudent, fteacher): 76 | loss_all = 0.0 77 | for fs, ft in zip(fstudent, fteacher): 78 | n,c,h,w = fs.shape 79 | loss = F.mse_loss(fs, ft, reduction='mean') 80 | cnt = 1.0 81 | tot = 1.0 82 | for l in [4,2,1]: 83 | if l >=h: 84 | continue 85 | tmpfs = F.adaptive_avg_pool2d(fs, (l,l)) 86 | tmpft = F.adaptive_avg_pool2d(ft, (l,l)) 87 
| cnt /= 2.0 88 | loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt 89 | tot += cnt 90 | loss = loss / tot 91 | loss_all = loss_all + loss 92 | return loss_all 93 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/mv2_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LinearBottleNeck(nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, stride, t=6, class_num=100): 9 | super().__init__() 10 | 11 | self.residual = nn.Sequential( 12 | nn.Conv2d(in_channels, in_channels * t, 1), 13 | nn.BatchNorm2d(in_channels * t), 14 | nn.ReLU6(inplace=True), 15 | 16 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 17 | nn.BatchNorm2d(in_channels * t), 18 | nn.ReLU6(inplace=True), 19 | 20 | nn.Conv2d(in_channels * t, out_channels, 1), 21 | nn.BatchNorm2d(out_channels) 22 | ) 23 | 24 | self.stride = stride 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | 28 | def forward(self, x): 29 | 30 | residual = self.residual(x) 31 | 32 | if self.stride == 1 and self.in_channels == self.out_channels: 33 | residual += x 34 | 35 | return residual 36 | 37 | class MobileNetV2(nn.Module): 38 | 39 | def __init__(self, num_classes=100): 40 | super().__init__() 41 | 42 | self.pre = nn.Sequential( 43 | nn.Conv2d(3, 32, 1, padding=1), 44 | nn.BatchNorm2d(32), 45 | nn.ReLU6(inplace=True) 46 | ) 47 | 48 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 49 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 50 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 51 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 52 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 53 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 54 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 55 | 56 | self.conv1 = nn.Sequential( 57 | 
nn.Conv2d(320, 1280, 1), 58 | nn.BatchNorm2d(1280), 59 | nn.ReLU6(inplace=True) 60 | ) 61 | 62 | self.conv2 = nn.Conv2d(1280, num_classes, 1) 63 | 64 | def forward(self, x): 65 | x = self.pre(x) 66 | f0 = x 67 | x = self.stage1(x) 68 | x = self.stage2(x) 69 | f1 = x 70 | x = self.stage3(x) 71 | f2 = x 72 | x = self.stage4(x) 73 | f3 = x 74 | x = self.stage5(x) 75 | x = self.stage6(x) 76 | x = self.stage7(x) 77 | x = self.conv1(x) 78 | f4 = x 79 | x = F.adaptive_avg_pool2d(x, 1) 80 | avg = x 81 | x = self.conv2(x) 82 | x = x.view(x.size(0), -1) 83 | feats = {} 84 | feats["feats"] = [f0, f1, f2, f3, f4] 85 | feats["pooled_feat"] = avg 86 | 87 | return x, feats 88 | 89 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 90 | 91 | layers = [] 92 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 93 | 94 | while repeat - 1: 95 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 96 | repeat -= 1 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def mobilenetv2_tinyimagenet(**kwargs): 101 | return MobileNetV2(**kwargs) -------------------------------------------------------------------------------- /mdistiller/engine/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import sys 6 | import time 7 | from tqdm import tqdm 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def validate(val_loader, distiller): 30 | batch_time, losses, top1, top5 = [AverageMeter() for _ in range(4)] 31 | criterion = nn.CrossEntropyLoss() 32 | num_iter = 
len(val_loader) 33 | pbar = tqdm(range(num_iter)) 34 | 35 | distiller.eval() 36 | with torch.no_grad(): 37 | start_time = time.time() 38 | for idx, (image, target) in enumerate(val_loader): 39 | image = image.float() 40 | image = image.cuda(non_blocking=True) 41 | target = target.cuda(non_blocking=True) 42 | output = distiller(image=image) 43 | loss = criterion(output, target) 44 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 45 | batch_size = image.size(0) 46 | losses.update(loss.cpu().detach().numpy().mean(), batch_size) 47 | top1.update(acc1[0], batch_size) 48 | top5.update(acc5[0], batch_size) 49 | 50 | # measure elapsed time 51 | batch_time.update(time.time() - start_time) 52 | start_time = time.time() 53 | msg = "Top-1:{top1.avg:.3f}| Top-5:{top5.avg:.3f}".format( 54 | top1=top1, top5=top5 55 | ) 56 | pbar.set_description(log_msg(msg, "EVAL")) 57 | pbar.update() 58 | pbar.close() 59 | return top1.avg, top5.avg, losses.avg 60 | 61 | 62 | def log_msg(msg, mode="INFO"): 63 | color_map = { 64 | "INFO": 36, 65 | "TRAIN": 32, 66 | "EVAL": 31, 67 | } 68 | msg = "\033[{}m[{}] {}\033[0m".format(color_map[mode], mode, msg) 69 | return msg 70 | 71 | 72 | def adjust_learning_rate(epoch, cfg, optimizer): 73 | steps = np.sum(epoch > np.asarray(cfg.SOLVER.LR_DECAY_STAGES)) 74 | if steps > 0: 75 | new_lr = cfg.SOLVER.LR * (cfg.SOLVER.LR_DECAY_RATE**steps) 76 | for param_group in optimizer.param_groups: 77 | param_group["lr"] = new_lr 78 | return new_lr 79 | return cfg.SOLVER.LR 80 | 81 | 82 | def accuracy(output, target, topk=(1,)): 83 | with torch.no_grad(): 84 | maxk = max(topk) 85 | batch_size = target.size(0) 86 | _, pred = output.topk(maxk, 1, True, True) 87 | pred = pred.t() 88 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 89 | res = [] 90 | for k in topk: 91 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 92 | res.append(correct_k.mul_(100.0 / batch_size)) 93 | return res 94 | 95 | 96 | def save_checkpoint(obj, path): 97 | with 
open(path, "wb") as f: 98 | torch.save(obj, f) 99 | 100 | 101 | def load_checkpoint(path): 102 | with open(path, "rb") as f: 103 | return torch.load(f, map_location="cpu") 104 | -------------------------------------------------------------------------------- /mdistiller/distillers/VID.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from ._base import Distiller 7 | from ._common import get_feat_shapes 8 | 9 | 10 | def conv1x1(in_channels, out_channels, stride=1): 11 | return nn.Conv2d( 12 | in_channels, out_channels, kernel_size=1, padding=0, bias=False, stride=stride 13 | ) 14 | 15 | 16 | def vid_loss(regressor, log_scale, f_s, f_t, eps=1e-5): 17 | # pool for dimentsion match 18 | s_H, t_H = f_s.shape[2], f_t.shape[2] 19 | if s_H > t_H: 20 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 21 | elif s_H < t_H: 22 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 23 | else: 24 | pass 25 | pred_mean = regressor(f_s) 26 | pred_var = torch.log(1.0 + torch.exp(log_scale)) + eps 27 | pred_var = pred_var.view(1, -1, 1, 1).to(pred_mean) 28 | neg_log_prob = 0.5 * ((pred_mean - f_t) ** 2 / pred_var + torch.log(pred_var)) 29 | loss = torch.mean(neg_log_prob) 30 | return loss 31 | 32 | 33 | class VID(Distiller): 34 | """ 35 | Variational Information Distillation for Knowledge Transfer (CVPR 2019), 36 | code from author: https://github.com/ssahn0215/variational-information-distillation 37 | """ 38 | 39 | def __init__(self, student, teacher, cfg): 40 | super(VID, self).__init__(student, teacher) 41 | self.ce_loss_weight = cfg.VID.LOSS.CE_WEIGHT 42 | self.feat_loss_weight = cfg.VID.LOSS.FEAT_WEIGHT 43 | self.init_pred_var = cfg.VID.INIT_PRED_VAR 44 | self.eps = cfg.VID.EPS 45 | feat_s_shapes, feat_t_shapes = get_feat_shapes( 46 | self.student, self.teacher, cfg.VID.INPUT_SIZE 47 | ) 48 | feat_s_channels = [s[1] for s in feat_s_shapes[1:]] 49 
| feat_t_channels = [s[1] for s in feat_t_shapes[1:]] 50 | self.init_vid_modules(feat_s_channels, feat_t_channels) 51 | 52 | def init_vid_modules(self, feat_s_shapes, feat_t_shapes): 53 | self.regressors = nn.ModuleList() 54 | self.log_scales = [] 55 | for s, t in zip(feat_s_shapes, feat_t_shapes): 56 | regressor = nn.Sequential( 57 | conv1x1(s, t), nn.ReLU(), conv1x1(t, t), nn.ReLU(), conv1x1(t, t) 58 | ) 59 | self.regressors.append(regressor) 60 | log_scale = torch.nn.Parameter( 61 | np.log(np.exp(self.init_pred_var - self.eps) - 1.0) * torch.ones(t) 62 | ) 63 | self.log_scales.append(log_scale) 64 | 65 | def get_learnable_parameters(self): 66 | parameters = super().get_learnable_parameters() 67 | for regressor in self.regressors: 68 | parameters += list(regressor.parameters()) 69 | return parameters 70 | 71 | def get_extra_parameters(self): 72 | num_p = 0 73 | for regressor in self.regressors: 74 | for p in regressor.parameters(): 75 | num_p += p.numel() 76 | return num_p 77 | 78 | def forward_train(self, image, target, **kwargs): 79 | logits_student, feature_student = self.student(image) 80 | with torch.no_grad(): 81 | _, feature_teacher = self.teacher(image) 82 | 83 | # losses 84 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 85 | loss_vid = 0 86 | for i in range(len(feature_student["feats"][1:])): 87 | loss_vid += vid_loss( 88 | self.regressors[i], 89 | self.log_scales[i], 90 | feature_student["feats"][1:][i], 91 | feature_teacher["feats"][1:][i], 92 | self.eps, 93 | ) 94 | loss_vid = self.feat_loss_weight * loss_vid 95 | losses_dict = { 96 | "loss_ce": loss_ce, 97 | "loss_kd": loss_vid, 98 | } 99 | return logits_student, losses_dict 100 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 
def main(cfg, resume, opts):
    """Train a (distillation) experiment described by ``cfg``.

    Builds the experiment name/tags, optionally initializes wandb, constructs
    the dataloaders, the student (and teacher, unless DISTILLER.TYPE is
    "NONE"), wraps everything in the chosen distiller + DataParallel, and
    hands off to the configured trainer.
    """
    # Experiment name defaults to the tag; any command-line overrides are
    # appended as "key:value" tags so runs are distinguishable in wandb.
    experiment_name = cfg.EXPERIMENT.NAME
    if experiment_name == "":
        experiment_name = cfg.EXPERIMENT.TAG
    tags = cfg.EXPERIMENT.TAG.split(",")
    if opts:
        addtional_tags = ["{}:{}".format(k, v) for k, v in zip(opts[::2], opts[1::2])]
        tags += addtional_tags
        experiment_name += ",".join(addtional_tags)
    experiment_name = os.path.join(cfg.EXPERIMENT.PROJECT, experiment_name)
    if cfg.LOG.WANDB:
        try:
            import wandb

            wandb.init(project=cfg.EXPERIMENT.PROJECT, name=experiment_name, tags=tags)
        # NOTE(review): bare except — deliberately best-effort (missing wandb,
        # no network, ...), falls back to local logging.
        except:
            print(log_msg("Failed to use WANDB", "INFO"))
            cfg.LOG.WANDB = False

    # cfg & loggers
    show_cfg(cfg)
    # init dataloader & models
    train_loader, val_loader, num_data, num_classes = get_dataset(cfg)

    # vanilla: train the student alone, no teacher is loaded
    if cfg.DISTILLER.TYPE == "NONE":
        if cfg.DATASET.TYPE == "imagenet":
            model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
        elif cfg.DATASET.TYPE == "tiny_imagenet":
            model_student = tiny_imagenet_model_dict[cfg.DISTILLER.STUDENT][0](num_classes=num_classes)
        else:
            model_student = cifar_model_dict[cfg.DISTILLER.STUDENT][0](
                num_classes=num_classes
            )
        distiller = distiller_dict[cfg.DISTILLER.TYPE](model_student)
    # distillation
    else:
        print(log_msg("Loading teacher model", "INFO"))
        if cfg.DATASET.TYPE == "imagenet":
            # ImageNet teachers ship pretrained from torchvision-style dicts.
            model_teacher = imagenet_model_dict[cfg.DISTILLER.TEACHER](pretrained=True)
            model_student = imagenet_model_dict[cfg.DISTILLER.STUDENT](pretrained=False)
        else:
            # cifar/tiny-imagenet dicts map name -> (constructor, checkpoint path).
            model_dict = tiny_imagenet_model_dict if cfg.DATASET.TYPE == "tiny_imagenet" else cifar_model_dict
            net, pretrain_model_path = model_dict[cfg.DISTILLER.TEACHER]
            assert (
                pretrain_model_path is not None
            ), "no pretrain model for teacher {}".format(cfg.DISTILLER.TEACHER)
            model_teacher = net(num_classes=num_classes)
            model_teacher.load_state_dict(load_checkpoint(pretrain_model_path)["model"])
            model_student = model_dict[cfg.DISTILLER.STUDENT][0](
                num_classes=num_classes
            )
        # CRD additionally needs the dataset size for its contrastive memory bank.
        if cfg.DISTILLER.TYPE == "CRD":
            distiller = distiller_dict[cfg.DISTILLER.TYPE](
                model_student, model_teacher, cfg, num_data
            )
        else:
            distiller = distiller_dict[cfg.DISTILLER.TYPE](
                model_student, model_teacher, cfg
            )
    distiller = torch.nn.DataParallel(distiller.cuda())

    if cfg.DISTILLER.TYPE != "NONE":
        print(
            log_msg(
                "Extra parameters of {}: {}\033[0m".format(
                    cfg.DISTILLER.TYPE, distiller.module.get_extra_parameters()
                ),
                "INFO",
            )
        )

    # train
    trainer = trainer_dict[cfg.SOLVER.TRAINER](
        experiment_name, distiller, train_loader, val_loader, cfg
    )
    trainer.train(resume=resume)
def show_cfg(cfg):
    """Pretty-print the subset of the config that is relevant to this run."""
    dump_cfg = CN()
    # Copy the always-relevant top-level sections.
    for section in ("EXPERIMENT", "DATASET", "DISTILLER", "SOLVER", "LOG"):
        setattr(dump_cfg, section, cfg.get(section))
    # Append the method-specific section (e.g. cfg.KD) when one exists.
    if cfg.DISTILLER.TYPE in cfg:
        dump_cfg.update({cfg.DISTILLER.TYPE: cfg.get(cfg.DISTILLER.TYPE)})
    print(log_msg("CONFIG:\n{}".format(dump_cfg.dump()), "INFO"))
CFG.RKD.DISTANCE_WEIGHT = 25 76 | CFG.RKD.ANGLE_WEIGHT = 50 77 | CFG.RKD.LOSS = CN() 78 | CFG.RKD.LOSS.CE_WEIGHT = 1.0 79 | CFG.RKD.LOSS.FEAT_WEIGHT = 1.0 80 | CFG.RKD.PDIST = CN() 81 | CFG.RKD.PDIST.EPSILON = 1e-12 82 | CFG.RKD.PDIST.SQUARED = False 83 | 84 | # FITNET CFG 85 | CFG.FITNET = CN() 86 | CFG.FITNET.HINT_LAYER = 2 # (0, 1, 2, 3, 4) 87 | CFG.FITNET.INPUT_SIZE = (32, 32) 88 | CFG.FITNET.LOSS = CN() 89 | CFG.FITNET.LOSS.CE_WEIGHT = 1.0 90 | CFG.FITNET.LOSS.FEAT_WEIGHT = 100.0 91 | 92 | # KDSVD CFG 93 | CFG.KDSVD = CN() 94 | CFG.KDSVD.K = 1 95 | CFG.KDSVD.LOSS = CN() 96 | CFG.KDSVD.LOSS.CE_WEIGHT = 1.0 97 | CFG.KDSVD.LOSS.FEAT_WEIGHT = 1.0 98 | 99 | # OFD CFG 100 | CFG.OFD = CN() 101 | CFG.OFD.LOSS = CN() 102 | CFG.OFD.LOSS.CE_WEIGHT = 1.0 103 | CFG.OFD.LOSS.FEAT_WEIGHT = 0.001 104 | CFG.OFD.CONNECTOR = CN() 105 | CFG.OFD.CONNECTOR.KERNEL_SIZE = 1 106 | 107 | # NST CFG 108 | CFG.NST = CN() 109 | CFG.NST.LOSS = CN() 110 | CFG.NST.LOSS.CE_WEIGHT = 1.0 111 | CFG.NST.LOSS.FEAT_WEIGHT = 50.0 112 | 113 | # PKT CFG 114 | CFG.PKT = CN() 115 | CFG.PKT.LOSS = CN() 116 | CFG.PKT.LOSS.CE_WEIGHT = 1.0 117 | CFG.PKT.LOSS.FEAT_WEIGHT = 30000.0 118 | 119 | # SP CFG 120 | CFG.SP = CN() 121 | CFG.SP.LOSS = CN() 122 | CFG.SP.LOSS.CE_WEIGHT = 1.0 123 | CFG.SP.LOSS.FEAT_WEIGHT = 3000.0 124 | 125 | # VID CFG 126 | CFG.VID = CN() 127 | CFG.VID.LOSS = CN() 128 | CFG.VID.LOSS.CE_WEIGHT = 1.0 129 | CFG.VID.LOSS.FEAT_WEIGHT = 1.0 130 | CFG.VID.EPS = 1e-5 131 | CFG.VID.INIT_PRED_VAR = 5.0 132 | CFG.VID.INPUT_SIZE = (32, 32) 133 | 134 | # CRD CFG 135 | CFG.CRD = CN() 136 | CFG.CRD.MODE = "exact" # ("exact", "relax") 137 | CFG.CRD.FEAT = CN() 138 | CFG.CRD.FEAT.DIM = 128 139 | CFG.CRD.FEAT.STUDENT_DIM = 256 140 | CFG.CRD.FEAT.TEACHER_DIM = 256 141 | CFG.CRD.LOSS = CN() 142 | CFG.CRD.LOSS.CE_WEIGHT = 1.0 143 | CFG.CRD.LOSS.FEAT_WEIGHT = 0.8 144 | CFG.CRD.NCE = CN() 145 | CFG.CRD.NCE.K = 16384 146 | CFG.CRD.NCE.MOMENTUM = 0.5 147 | CFG.CRD.NCE.TEMPERATURE = 0.07 148 | 149 | # ReviewKD CFG 
class ImageNet(ImageFolder):
    """ImageFolder that additionally returns the sample index.

    The index is consumed by index-based samplers (see
    ``ImageNetInstanceSample`` below, used for CRD-style contrastive
    distillation).
    """

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, index
def get_imagenet_train_transform(mean, std):
    """Standard ImageNet training augmentation: random resized crop,
    horizontal flip, tensor conversion, and normalization."""
    return transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
def get_imagenet_val_loader(val_batch_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], num_workers=16):
    """Build the ImageNet validation loader.

    Args:
        val_batch_size: batch size used for evaluation.
        mean, std: normalization statistics (ImageNet defaults).
        num_workers: dataloader worker count. Previously hard-coded to 16
            (inconsistent with the train loaders, which take it as an
            argument); now a parameter whose default preserves the old
            behavior.

    Returns:
        A non-shuffled DataLoader over ``<data_folder>/val``.
    """
    test_transform = get_imagenet_test_transform(mean, std)
    test_folder = os.path.join(data_folder, 'val')
    test_set = ImageFolder(test_folder, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return test_loader
class ImageFolderInstance(datasets.ImageFolder):
    """ImageFolder that also returns the sample index with each item.

    NOTE(review): unlike the parent class, this override never applies
    ``target_transform`` to the label — confirm that is intentional.
    """

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target, index
60 | """ 61 | img, target, index = super().__getitem__(index) 62 | 63 | if self.is_sample: 64 | # sample contrastive examples 65 | pos_idx = index 66 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True) 67 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 68 | return img, target, index, sample_idx 69 | else: 70 | return img, target, index 71 | 72 | 73 | def get_tinyimagenet_dataloader(batch_size, val_batch_size, num_workers): 74 | """Data Loader for tiny-imagenet""" 75 | train_transform = transforms.Compose([ 76 | transforms.RandomRotation(20), 77 | transforms.RandomHorizontalFlip(0.5), 78 | transforms.ToTensor(), 79 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 80 | ]) 81 | test_transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 84 | ]) 85 | train_folder = os.path.join(data_folder, "train") 86 | test_folder = os.path.join(data_folder, "val") 87 | train_set = ImageFolderInstance(train_folder, transform=train_transform) 88 | num_data = len(train_set) 89 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 90 | train_loader = DataLoader( 91 | train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers 92 | ) 93 | test_loader = DataLoader( 94 | test_set, batch_size=val_batch_size, shuffle=False, num_workers=1 95 | ) 96 | return train_loader, test_loader, num_data 97 | 98 | 99 | def get_tinyimagenet_dataloader_sample(batch_size, val_batch_size, num_workers, k): 100 | """Data Loader for tiny-imagenet""" 101 | train_transform = transforms.Compose([ 102 | transforms.RandomRotation(20), 103 | transforms.RandomHorizontalFlip(0.5), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 106 | ]) 107 | test_transform = transforms.Compose([ 108 | transforms.ToTensor(), 109 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 
class ShuffleBlock(nn.Module):
    """Channel shuffle from ShuffleNet: interleaves channels across groups.

    [N, C, H, W] -> [N, g, C/g, H, W] -> swap the two group axes -> [N, C, H, W]
    """

    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Shuffle the channel dimension of ``x`` across ``self.groups`` groups."""
        batch, channels, height, width = x.size()
        grouped = x.reshape(batch, self.groups, channels // self.groups, height, width)
        shuffled = grouped.permute(0, 2, 1, 3, 4)
        return shuffled.reshape(batch, channels, height, width)
    def forward(self, x):
        """Run one ShuffleNet bottleneck.

        Returns ``(out, preact)`` when this is the last block of a stage
        (``self.is_last``), otherwise just ``out``.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)  # channel shuffle across groups
        out = F.relu(self.bn2(self.conv2(out)))  # depthwise 3x3 (stride 1 or 2)
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        # Stride-2 blocks concatenate with the avg-pooled shortcut (growing the
        # channel count); stride-1 blocks add it as a residual instead.
        preact = torch.cat([out, res], 1) if self.stride == 2 else out + res
        out = F.relu(preact)
        if self.is_last:
            return out, preact
        else:
            return out
    def forward(
        self,
        x,
    ):
        """Forward pass.

        Returns ``(logits, feats)`` where ``feats`` collects the per-stage
        feature maps consumed by feature-based distillers.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        f0 = out
        # Each stage's final Bottleneck was built with is_last=True, so the
        # stage returns (activated output, pre-activation feature).
        out, f1_pre = self.layer1(out)
        f1 = out
        out, f2_pre = self.layer2(out)
        f2 = out
        out, f3_pre = self.layer3(out)
        f3 = out
        out = F.avg_pool2d(out, 4)
        out = out.reshape(out.size(0), -1)
        f4 = out  # pooled, flattened embedding fed to the classifier
        out = self.linear(out)

        feats = {}
        feats["feats"] = [f0, f1, f2, f3]
        feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre]
        feats["pooled_feat"] = f4

        return out, feats
18 | for l in [4, 2, 1]: 19 | if l >= h: 20 | continue 21 | tmpfs = F.adaptive_avg_pool2d(fs, (l, l)) 22 | tmpft = F.adaptive_avg_pool2d(ft, (l, l)) 23 | cnt /= 2.0 24 | loss += F.mse_loss(tmpfs, tmpft, reduction="mean") * cnt 25 | tot += cnt 26 | loss = loss / tot 27 | loss_all = loss_all + loss 28 | return loss_all 29 | 30 | 31 | class ReviewKD(Distiller): 32 | def __init__(self, student, teacher, cfg): 33 | super(ReviewKD, self).__init__(student, teacher) 34 | self.shapes = cfg.REVIEWKD.SHAPES 35 | self.out_shapes = cfg.REVIEWKD.OUT_SHAPES 36 | in_channels = cfg.REVIEWKD.IN_CHANNELS 37 | out_channels = cfg.REVIEWKD.OUT_CHANNELS 38 | self.ce_loss_weight = cfg.REVIEWKD.CE_WEIGHT 39 | self.reviewkd_loss_weight = cfg.REVIEWKD.REVIEWKD_WEIGHT 40 | self.warmup_epochs = cfg.REVIEWKD.WARMUP_EPOCHS 41 | self.stu_preact = cfg.REVIEWKD.STU_PREACT 42 | self.max_mid_channel = cfg.REVIEWKD.MAX_MID_CHANNEL 43 | 44 | abfs = nn.ModuleList() 45 | mid_channel = min(512, in_channels[-1]) 46 | for idx, in_channel in enumerate(in_channels): 47 | abfs.append( 48 | ABF( 49 | in_channel, 50 | mid_channel, 51 | out_channels[idx], 52 | idx < len(in_channels) - 1, 53 | ) 54 | ) 55 | self.abfs = abfs[::-1] 56 | 57 | def get_learnable_parameters(self): 58 | return super().get_learnable_parameters() + list(self.abfs.parameters()) 59 | 60 | def get_extra_parameters(self): 61 | num_p = 0 62 | for p in self.abfs.parameters(): 63 | num_p += p.numel() 64 | return num_p 65 | 66 | def forward_train(self, image, target, **kwargs): 67 | logits_student, features_student = self.student(image) 68 | with torch.no_grad(): 69 | logits_teacher, features_teacher = self.teacher(image) 70 | 71 | # get features 72 | if self.stu_preact: 73 | x = features_student["preact_feats"] + [ 74 | features_student["pooled_feat"].unsqueeze(-1).unsqueeze(-1) 75 | ] 76 | else: 77 | x = features_student["feats"] + [ 78 | features_student["pooled_feat"].unsqueeze(-1).unsqueeze(-1) 79 | ] 80 | x = x[::-1] 81 | results = [] 82 | 
    def __init__(self, in_channel, mid_channel, out_channel, fuse):
        """Attention-based fusion block used by ReviewKD.

        Args:
            in_channel: channels of the incoming student feature.
            mid_channel: shared intermediate width used for fusion.
            out_channel: channels of the produced (teacher-aligned) feature.
            fuse: whether this level fuses with the residual path coming
                from the previous (deeper) level.
        """
        super(ABF, self).__init__()
        # 1x1 projection of the student feature to the shared mid width.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channel),
        )
        # 3x3 conv producing the output feature compared against the teacher.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_channel),
        )
        if fuse:
            # Two sigmoid attention maps weighting the current feature and the
            # upsampled residual before they are summed in forward().
            self.att_conv = nn.Sequential(
                nn.Conv2d(mid_channel * 2, 2, kernel_size=1),
                nn.Sigmoid(),
            )
        else:
            self.att_conv = None
        nn.init.kaiming_uniform_(self.conv1[0].weight, a=1)  # pyre-ignore
        nn.init.kaiming_uniform_(self.conv2[0].weight, a=1)  # pyre-ignore
def feat_loss(source, target, margin):
    """Margin-aware partial L2 distance used by OFD (Heo et al., ICCV 2019).

    Penalizes the student activation (``source``) only where the teacher
    (``target``) carries information: when the student rises above the margin
    although the teacher is below it, when the student overshoots the teacher
    inside the (margin, 0] band, and everywhere the teacher is positive.
    """
    margin = margin.to(source)
    over_margin = ((source > margin) & (target <= margin)).float()
    in_band = ((source > target) & (target > margin) & (target <= 0)).float()
    teacher_positive = (target > 0).float()
    loss = (
        (source - margin) ** 2 * over_margin
        + (source - target) ** 2 * in_band
        + (source - target) ** 2 * teacher_positive
    )
    return torch.abs(loss).mean(dim=0).sum()
class OFD(Distiller):
    """Overhaul of Feature Distillation.

    Matches student stage features (mapped through learned 1x1-conv+BN
    "connectors") against the teacher's pre-activation stage features,
    with a per-channel margin derived from the teacher's BN statistics
    (see ``init_ofd_modules``).
    """

    def __init__(self, student, teacher, cfg):
        super(OFD, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.OFD.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.OFD.LOSS.FEAT_WEIGHT
        # Stage 0 (the stem) is excluded from distillation ([1:]).
        self.init_ofd_modules(
            tea_channels=self.teacher.get_stage_channels()[1:],
            stu_channels=self.student.get_stage_channels()[1:],
            bn_before_relu=self.teacher.get_bn_before_relu(),
            kernel_size=cfg.OFD.CONNECTOR.KERNEL_SIZE,
        )

    def init_ofd_modules(
        self, tea_channels, stu_channels, bn_before_relu, kernel_size=1
    ):
        """Build student->teacher connectors and the per-stage margin tensors."""
        tea_channels, stu_channels = self._align_list(tea_channels, stu_channels)
        self.connectors = ConnectorConvBN(
            stu_channels, tea_channels, kernel_size=kernel_size
        )

        self.margins = []
        for idx, bn in enumerate(bn_before_relu):
            margin = []
            # The BN weight is used as the per-channel std and the BN bias
            # as the mean of the teacher's pre-activation distribution.
            std = bn.weight.data
            mean = bn.bias.data
            for (s, m) in zip(std, mean):
                s = abs(s.item())
                m = m.item()
                if norm.cdf(-m / s) > 0.001:
                    # Closed-form conditional expectation of N(m, s^2)
                    # restricted to the negative half (the part ReLU drops).
                    margin.append(
                        -s
                        * math.exp(-((m / s) ** 2) / 2)
                        / math.sqrt(2 * math.pi)
                        / norm.cdf(-m / s)
                        + m
                    )
                else:
                    # Numerically safe fallback when almost no mass is below 0.
                    margin.append(-3 * s)
            margin = torch.FloatTensor(margin).to(std.device)
            # Shape (1, C, 1, 1) so it broadcasts over NCHW feature maps.
            self.margins.append(margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())

    def get_learnable_parameters(self):
        # Connectors are trained jointly with the student.
        return super().get_learnable_parameters() + list(self.connectors.parameters())

    def train(self, mode=True):
        # teacher as eval mode by default
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def get_extra_parameters(self):
        """Number of extra (connector) parameters added on top of the student."""
        num_p = 0
        for p in self.connectors.parameters():
            num_p += p.numel()
        return num_p

    def forward_train(self, image, target, **kwargs):
        """Return student logits and the CE + OFD feature loss dict."""
        logits_student, feature_student = self.student(image)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        loss_feat = self.feat_loss_weight * self.ofd_loss(
            feature_student["preact_feats"][1:], feature_teacher["preact_feats"][1:]
        )
        losses_dict = {"loss_ce": loss_ce, "loss_kd": loss_feat}
        return logits_student, losses_dict

    def ofd_loss(self, feature_student, feature_teacher):
        """Sum of margin-clipped feature losses over the aligned stages."""
        feature_student, feature_teacher = self._align_list(
            feature_student, feature_teacher
        )
        feature_student = [
            self.connectors.connectors[idx](feat)
            for idx, feat in enumerate(feature_student)
        ]

        loss_distill = 0
        feat_num = len(feature_student)
        for i in range(feat_num):
            # Deeper stages get exponentially larger weight: 1 / 2^(feat_num-i-1).
            loss_distill = loss_distill + feat_loss(
                feature_student[i],
                # Pool the teacher map down to the student's spatial size.
                F.adaptive_avg_pool2d(
                    feature_teacher[i], feature_student[i].shape[-2:]
                ).detach(),
                self.margins[i],
            ) / 2 ** (feat_num - i - 1)
        return loss_distill

    def _align_list(self, *input_list):
        # Keep only the trailing min-length slice of every list so stage
        # lists line up from the deepest stage backwards.
        min_len = min([len(l) for l in input_list])
        return [l[-min_len:] for l in input_list]
| __all__ = ["wrn"] 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 12 | super(BasicBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.conv1 = nn.Conv2d( 16 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 17 | ) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | self.relu2 = nn.ReLU(inplace=True) 20 | self.conv2 = nn.Conv2d( 21 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False 22 | ) 23 | self.droprate = dropRate 24 | self.equalInOut = in_planes == out_planes 25 | self.convShortcut = ( 26 | (not self.equalInOut) 27 | and nn.Conv2d( 28 | in_planes, 29 | out_planes, 30 | kernel_size=1, 31 | stride=stride, 32 | padding=0, 33 | bias=False, 34 | ) 35 | or None 36 | ) 37 | 38 | def forward(self, x): 39 | if not self.equalInOut: 40 | x = self.relu1(self.bn1(x)) 41 | else: 42 | out = self.relu1(self.bn1(x)) 43 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 44 | if self.droprate > 0: 45 | out = F.dropout(out, p=self.droprate, training=self.training) 46 | out = self.conv2(out) 47 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 48 | 49 | 50 | class NetworkBlock(nn.Module): 51 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 52 | super(NetworkBlock, self).__init__() 53 | self.layer = self._make_layer( 54 | block, in_planes, out_planes, nb_layers, stride, dropRate 55 | ) 56 | 57 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 58 | layers = [] 59 | for i in range(nb_layers): 60 | layers.append( 61 | block( 62 | i == 0 and in_planes or out_planes, 63 | out_planes, 64 | i == 0 and stride or 1, 65 | dropRate, 66 | ) 67 | ) 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | return self.layer(x) 72 | 73 | 74 | class WideResNet(nn.Module): 75 | def __init__(self, depth, 
class WideResNet(nn.Module):
    """Wide Residual Network (WRN-depth-widen_factor) for 32x32 inputs.

    ``depth`` must satisfy depth = 6n + 4; each of the three stages then
    stacks n pre-activation BasicBlocks. ``forward`` also returns the
    intermediate features used by the distillers in this repo.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        # Channel plan: fixed 16-wide stem, then 16/32/64 scaled by widen_factor.
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, "depth should be 6n+4"
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(
            3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False
        )
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # He-style init for convs, identity init for BN, zero only the FC bias.
        # NOTE: the loop variable n deliberately shadows the block count above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
        self.stage_channels = nChannels

    def get_feat_modules(self):
        """Feature-producing modules in forward order (consumed by distillers)."""
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.block1)
        feat_m.append(self.block2)
        feat_m.append(self.block3)
        return feat_m

    def get_bn_before_relu(self):
        """BN layers that sit directly before a ReLU, one per distilled stage."""
        bn1 = self.block2.layer[0].bn1
        bn2 = self.block3.layer[0].bn1
        bn3 = self.bn1

        return [bn1, bn2, bn3]

    def get_stage_channels(self):
        # Channel widths of [stem, stage1, stage2, stage3].
        return self.stage_channels

    def forward(self, x):
        """Return (logits, feats) where feats carries per-stage maps,
        their pre-activation counterparts, and the pooled embedding."""
        out = self.conv1(x)
        f0 = out
        out = self.block1(out)
        f1 = out
        out = self.block2(out)
        f2 = out
        out = self.block3(out)
        f3 = out
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)  # assumes 8x8 maps here, i.e. 32x32 input
        out = out.reshape(-1, self.nChannels)
        f4 = out
        out = self.fc(out)

        # Pre-activation features: each stage output pushed through the next
        # stage's leading BN (recomputed here rather than cached in-place).
        f1_pre = self.block2.layer[0].bn1(f1)
        f2_pre = self.block3.layer[0].bn1(f2)
        f3_pre = self.bn1(f3)

        feats = {}
        feats["feats"] = [f0, f1, f2, f3]
        feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre]
        feats["pooled_feat"] = f4

        return out, feats
bias=False), 21 | nn.BatchNorm2d(oup), 22 | nn.ReLU(inplace=True), 23 | ) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.blockname = None 30 | 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | self.conv = nn.Sequential( 37 | # pw 38 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU(inplace=True), 41 | # dw 42 | nn.Conv2d( 43 | inp * expand_ratio, 44 | inp * expand_ratio, 45 | 3, 46 | stride, 47 | 1, 48 | groups=inp * expand_ratio, 49 | bias=False, 50 | ), 51 | nn.BatchNorm2d(inp * expand_ratio), 52 | nn.ReLU(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 55 | nn.BatchNorm2d(oup), 56 | ) 57 | self.names = ["0", "1", "2", "3", "4", "5", "6", "7"] 58 | 59 | def forward(self, x): 60 | t = x 61 | if self.use_res_connect: 62 | return t + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | """mobilenetV2""" 69 | 70 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride 
in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | # building classifier 108 | self.classifier = nn.Sequential( 109 | # nn.Dropout(0.5), 110 | nn.Linear(self.last_channel, feature_dim), 111 | ) 112 | 113 | H = input_size // (32 // 2) 114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 115 | 116 | self._initialize_weights() 117 | print(T, width_mult) 118 | self.stage_channels = [32, 24, 32, 96, 320] 119 | self.stage_channels = [int(c * width_mult) for c in self.stage_channels] 120 | 121 | def get_bn_before_relu(self): 122 | bn1 = self.blocks[1][-1].conv[-1] 123 | bn2 = self.blocks[2][-1].conv[-1] 124 | bn3 = self.blocks[4][-1].conv[-1] 125 | bn4 = self.blocks[6][-1].conv[-1] 126 | return [bn1, bn2, bn3, bn4] 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.blocks) 132 | return feat_m 133 | 134 | def get_stage_channels(self): 135 | return self.stage_channels 136 | 137 | def forward(self, x): 138 | out = self.conv1(x) 139 | f0 = out 140 | out = self.blocks[0](out) 141 | out = self.blocks[1](out) 142 | f1 = out 143 | out = self.blocks[2](out) 144 | f2 = out 145 | out = self.blocks[3](out) 146 | out = self.blocks[4](out) 147 | f3 = out 148 | out = self.blocks[5](out) 149 | out = self.blocks[6](out) 150 | f4 = out 151 | out = self.conv2(out) 152 | 153 | if not self.remove_avg: 154 | out = self.avgpool(out) 155 | out = out.reshape(out.size(0), -1) 156 | avg = out 157 | out = self.classifier(out) 158 | 159 | feats = {} 160 | feats["feats"] = [f0, f1, f2, f3, f4] 161 | feats["pooled_feat"] = avg 162 | 163 | return out, feats 164 | 165 | def _initialize_weights(self): 166 | for m in self.modules(): 
167 | if isinstance(m, nn.Conv2d): 168 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 169 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | elif isinstance(m, nn.Linear): 176 | n = m.weight.size(1) 177 | m.weight.data.normal_(0, 0.01) 178 | m.bias.data.zero_() 179 | 180 | 181 | def mobilenetv2_T_w(T, W, feature_dim=100): 182 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 183 | return model 184 | 185 | 186 | def mobile_half(num_classes): 187 | return mobilenetv2_T_w(6, 0.5, num_classes) 188 | 189 | 190 | if __name__ == "__main__": 191 | x = torch.randn(2, 3, 32, 32) 192 | 193 | net = mobile_half(100) 194 | 195 | logit, feats = net(x) 196 | for f in feats["feats"]: 197 | print(f.shape, f.min().item()) 198 | print(logit.shape) 199 | -------------------------------------------------------------------------------- /mdistiller/dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | from PIL import Image 6 | 7 | 8 | def get_data_folder(): 9 | data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data") 10 | if not os.path.isdir(data_folder): 11 | os.makedirs(data_folder) 12 | return data_folder 13 | 14 | 15 | class CIFAR100Instance(datasets.CIFAR100): 16 | """CIFAR100Instance Dataset.""" 17 | 18 | def __getitem__(self, index): 19 | img, target = super().__getitem__(index) 20 | return img, target, index 21 | 22 | 23 | # CIFAR-100 for CRD 24 | class CIFAR100InstanceSample(datasets.CIFAR100): 25 | """ 26 | CIFAR100Instance+Sample Dataset 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root, 32 | train=True, 33 | transform=None, 34 | target_transform=None, 35 | download=False, 36 | k=4096, 37 
    def __getitem__(self, index):
        """Return a sample, optionally with CRD contrastive indices.

        Returns:
            (img, target, index) when ``is_sample`` is False, else
            (img, target, index, sample_idx) where ``sample_idx`` is one
            positive index followed by ``k`` negative indices.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if not self.is_sample:
            # directly return
            return img, target, index
        else:
            # sample contrastive examples
            if self.mode == "exact":
                # the anchor itself serves as the positive
                pos_idx = index
            elif self.mode == "relax":
                # any same-class sample may serve as the positive
                pos_idx = np.random.choice(self.cls_positive[target], 1)
                pos_idx = pos_idx[0]
            else:
                raise NotImplementedError(self.mode)
            # sample with replacement only when k exceeds the negative pool
            replace = True if self.k > len(self.cls_negative[target]) else False
            neg_idx = np.random.choice(
                self.cls_negative[target], self.k, replace=replace
            )
            sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
            return img, target, index, sample_idx
| 171 | train_set = CIFAR100InstanceSample( 172 | root=data_folder, 173 | download=True, 174 | train=True, 175 | transform=train_transform, 176 | k=k, 177 | mode=mode, 178 | is_sample=True, 179 | percent=1.0, 180 | ) 181 | num_data = len(train_set) 182 | test_set = datasets.CIFAR100( 183 | root=data_folder, download=True, train=False, transform=test_transform 184 | ) 185 | 186 | train_loader = DataLoader( 187 | train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers 188 | ) 189 | test_loader = DataLoader( 190 | test_set, 191 | batch_size=val_batch_size, 192 | shuffle=False, 193 | num_workers=num_workers, 194 | ) 195 | return train_loader, test_loader, num_data 196 | -------------------------------------------------------------------------------- /detection/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | """ 4 | A main training script. 5 | 6 | This scripts reads a given config file and runs the training or evaluation. 7 | It is an entry point that is made to train standard models in detectron2. 8 | 9 | In order to let one script support training of many models, 10 | this script contains logic that are specific to these built-in models and therefore 11 | may not be suitable for your own project. 12 | For example, your research project perhaps only needs a single "evaluator". 13 | 14 | Therefore, we recommend you to use detectron2 as an library and take 15 | this file as an example of how to use the library. 16 | You may want to write your own script with your datasets and other customizations. 
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # Panoptic datasets accumulate several evaluators in the list below;
        # the cityscapes/voc/lvis branches instead return a single evaluator
        # immediately (they do not combine with others).
        if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    output_dir=output_folder,
                )
            )
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if evaluator_type == "coco_panoptic_seg":
            evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        if evaluator_type == "cityscapes_instance":
            assert (
                torch.cuda.device_count() >= comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            assert (
                torch.cuda.device_count() >= comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesSemSegEvaluator(dataset_name)
        elif evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            # Unwrap a singleton list for callers expecting a plain evaluator.
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)
107 | logger.info("Running inference with test-time augmentation ...") 108 | model = GeneralizedRCNNWithTTA(cfg, model) 109 | evaluators = [ 110 | cls.build_evaluator( 111 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 112 | ) 113 | for name in cfg.DATASETS.TEST 114 | ] 115 | res = cls.test(cfg, model, evaluators) 116 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 117 | return res 118 | 119 | 120 | def setup(args): 121 | """ 122 | Create configs and perform basic setups. 123 | """ 124 | cfg = get_cfg() 125 | add_distillation_cfg(cfg) 126 | cfg.merge_from_file(args.config_file) 127 | cfg.merge_from_list(args.opts) 128 | cfg.freeze() 129 | default_setup(cfg, args) 130 | return cfg 131 | 132 | 133 | def main(args): 134 | cfg = setup(args) 135 | 136 | if args.eval_only: 137 | model = Trainer.build_model(cfg) 138 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 139 | cfg.MODEL.WEIGHTS, resume=args.resume 140 | ) 141 | res = Trainer.test(cfg, model) 142 | if cfg.TEST.AUG.ENABLED: 143 | res.update(Trainer.test_with_TTA(cfg, model)) 144 | if comm.is_main_process(): 145 | verify_results(cfg, res) 146 | return res 147 | 148 | """ 149 | If you'd like to do anything fancier than the standard training logic, 150 | consider writing your own training loop (see plain_train_net.py) or 151 | subclassing the trainer. 
152 | """ 153 | trainer = Trainer(cfg) 154 | trainer.resume_or_load(resume=args.resume) 155 | if cfg.TEST.AUG.ENABLED: 156 | trainer.register_hooks( 157 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 158 | ) 159 | return trainer.train() 160 | 161 | 162 | if __name__ == "__main__": 163 | args = default_argument_parser().parse_args() 164 | print("Command Line Args:", args) 165 | launch( 166 | main, 167 | args.num_gpus, 168 | num_machines=args.num_machines, 169 | machine_rank=args.machine_rank, 170 | dist_url=args.dist_url, 171 | args=(args,), 172 | ) 173 | -------------------------------------------------------------------------------- /mdistiller/engine/dot.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | import torch.optim._functional as F 5 | from torch.optim.optimizer import Optimizer, required 6 | from typing import List, Optional 7 | 8 | 9 | def check_in(t, l): 10 | for i in l: 11 | if t is i: 12 | return True 13 | return False 14 | 15 | def dot(params: List[Tensor], 16 | d_p_list: List[Tensor], 17 | momentum_buffer_list: List[Optional[Tensor]], 18 | kd_grad_buffer: List[Optional[Tensor]], 19 | kd_momentum_buffer: List[Optional[Tensor]], 20 | kd_params: List[Tensor], 21 | *, 22 | weight_decay: float, 23 | momentum: float, 24 | momentum_kd: float, 25 | lr: float, 26 | dampening: float): 27 | for i, param in enumerate(params): 28 | d_p = d_p_list[i] 29 | if weight_decay != 0: 30 | d_p = d_p.add(param, alpha=weight_decay) 31 | if momentum != 0: 32 | buf = momentum_buffer_list[i] 33 | if buf is None: 34 | buf = torch.clone(d_p).detach() 35 | momentum_buffer_list[i] = buf 36 | elif check_in(param, kd_params): 37 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 38 | else: 39 | buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening) 40 | d_p = buf 41 | # update params with task-loss grad 42 | param.add_(d_p, alpha=-lr) 43 | 
44 | for i, (d_p, buf, p) in enumerate(zip(kd_grad_buffer, kd_momentum_buffer, kd_params)): 45 | # update params with kd-loss grad 46 | if buf is None: 47 | buf = torch.clone(d_p).detach() 48 | kd_momentum_buffer[i] = buf 49 | elif check_in(p, params): 50 | buf.mul_(momentum_kd).add_(d_p, alpha=1 - dampening) 51 | else: 52 | if weight_decay != 0: 53 | d_p = d_p.add(p, alpha=weight_decay) 54 | buf.mul_((momentum_kd + momentum) / 2.).add_(d_p, alpha=1 - dampening) 55 | p.add_(buf, alpha=-lr) 56 | 57 | 58 | class DistillationOrientedTrainer(Optimizer): 59 | r""" 60 | Distillation-Oriented Trainer 61 | Usage: 62 | ... 63 | optimizer = DistillationOrientedTrainer() 64 | optimizer.zero_grad(set_to_none=True) 65 | loss_kd.backward(retain_graph=True) 66 | optimizer.step_kd() # get kd-grad and update kd-momentum 67 | optimizer.zero_grad(set_to_none=True) 68 | loss_task.backward() 69 | optimizer.step() # get task-grad and update tast-momentum, then update params. 70 | ... 71 | """ 72 | 73 | def __init__( 74 | self, 75 | params, 76 | lr=required, 77 | momentum=0, 78 | momentum_kd=0, 79 | dampening=0, 80 | weight_decay=0): 81 | if lr is not required and lr < 0.0: 82 | raise ValueError("Invalid learning rate: {}".format(lr)) 83 | if momentum < 0.0: 84 | raise ValueError("Invalid momentum value: {}".format(momentum)) 85 | if momentum_kd < 0.0: 86 | raise ValueError("Invalid momentum kd value: {}".format(momentum_kd)) 87 | if weight_decay < 0.0: 88 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 89 | defaults = dict(lr=lr, momentum=momentum, momentum_kd=momentum_kd, dampening=dampening, 90 | weight_decay=weight_decay) 91 | self.kd_grad_buffer = [] 92 | self.kd_grad_params = [] 93 | self.kd_momentum_buffer = [] 94 | super(DistillationOrientedTrainer, self).__init__(params, defaults) 95 | 96 | @torch.no_grad() 97 | def step_kd(self, closure=None): 98 | loss = None 99 | if closure is not None: 100 | with torch.enable_grad(): 101 | loss = closure() 102 | 103 
| assert len(self.param_groups) == 1, "Only implement for one-group params." 104 | for group in self.param_groups: 105 | params_with_grad = [] 106 | d_p_list = [] 107 | momentum_kd_buffer_list = [] 108 | weight_decay = group['weight_decay'] 109 | dampening = group['dampening'] 110 | lr = group['lr'] 111 | for p in group['params']: 112 | if p.grad is not None: 113 | params_with_grad.append(p) 114 | d_p_list.append(p.grad) 115 | state = self.state[p] 116 | if 'momentum_kd_buffer' not in state: 117 | momentum_kd_buffer_list.append(None) 118 | else: 119 | momentum_kd_buffer_list.append(state['momentum_kd_buffer']) 120 | 121 | self.kd_momentum_buffer = momentum_kd_buffer_list 122 | self.kd_grad_buffer = d_p_list 123 | self.kd_grad_params = params_with_grad 124 | return loss 125 | 126 | @torch.no_grad() 127 | def step(self, closure=None): 128 | loss = None 129 | if closure is not None: 130 | with torch.enable_grad(): 131 | loss = closure() 132 | assert len(self.param_groups) == 1, "Only implement for one-group params." 
133 | for group in self.param_groups: 134 | params_with_grad = [] 135 | d_p_list = [] 136 | momentum_buffer_list = [] 137 | weight_decay = group['weight_decay'] 138 | momentum = group['momentum'] 139 | momentum_kd = group['momentum_kd'] 140 | dampening = group['dampening'] 141 | lr = group['lr'] 142 | 143 | for p in group['params']: 144 | if p.grad is not None: 145 | params_with_grad.append(p) 146 | d_p_list.append(p.grad) 147 | 148 | state = self.state[p] 149 | if 'momentum_buffer' not in state: 150 | momentum_buffer_list.append(None) 151 | else: 152 | momentum_buffer_list.append(state['momentum_buffer']) 153 | dot(params_with_grad, 154 | d_p_list, 155 | momentum_buffer_list, 156 | self.kd_grad_buffer, 157 | self.kd_momentum_buffer, 158 | self.kd_grad_params, 159 | weight_decay=weight_decay, 160 | momentum=momentum, 161 | momentum_kd=momentum_kd, 162 | lr=lr, 163 | dampening=dampening) 164 | # update momentum_buffers in state 165 | for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list): 166 | state = self.state[p] 167 | state['momentum_buffer'] = momentum_buffer 168 | for p, momentum_kd_buffer in zip(self.kd_grad_params, self.kd_momentum_buffer): 169 | state = self.state[p] 170 | state['momentum_kd_buffer'] = momentum_kd_buffer 171 | self.kd_grad_buffer = [] 172 | self.kd_grad_params = [] 173 | self.kd_momentum_buffer = [] 174 | return loss 175 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ShuffleBlock(nn.Module): 7 | def __init__(self, groups=2): 8 | super(ShuffleBlock, self).__init__() 9 | self.groups = groups 10 | 11 | def forward(self, x): 12 | """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" 13 | N, C, H, W = x.size() 14 | g = self.groups 15 | return 
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShuffleBlock(nn.Module):
    """Channel shuffle: interleave channels across `groups` groups."""

    def __init__(self, groups=2):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        # [N,C,H,W] -> [N,g,C/g,H,W] -> swap group axes -> back to [N,C,H,W]
        n, c, h, w = x.size()
        g = self.groups
        grouped = x.reshape(n, g, c // g, h, w)
        return grouped.permute(0, 2, 1, 3, 4).reshape(n, c, h, w)


class SplitBlock(nn.Module):
    """Split the channel dimension at `ratio` into two tensors."""

    def __init__(self, ratio):
        super(SplitBlock, self).__init__()
        self.ratio = ratio

    def forward(self, x):
        split_at = int(x.size(1) * self.ratio)
        return x[:, :split_at], x[:, split_at:]


class BasicBlock(nn.Module):
    """Stride-1 ShuffleNetV2 unit: split -> transform one half -> concat -> shuffle.

    When `is_last` is set, also returns the concatenation taken before the
    final ReLU (the "pre-activation" feature used for distillation).
    """

    def __init__(self, in_channels, split_ratio=0.5, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.split = SplitBlock(split_ratio)
        branch_channels = int(in_channels * split_ratio)
        self.conv1 = nn.Conv2d(branch_channels, branch_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(branch_channels)
        self.conv2 = nn.Conv2d(
            branch_channels,
            branch_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=branch_channels,  # depthwise
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(branch_channels)
        self.conv3 = nn.Conv2d(branch_channels, branch_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(branch_channels)
        self.shuffle = ShuffleBlock()

    def forward(self, x):
        passthrough, branch = self.split(x)
        y = F.relu(self.bn1(self.conv1(branch)))
        y = self.bn2(self.conv2(y))
        pre_activation = self.bn3(self.conv3(y))
        activated = F.relu(pre_activation)
        merged = self.shuffle(torch.cat([passthrough, activated], 1))
        if self.is_last:
            return merged, torch.cat([passthrough, pre_activation], 1)
        return merged


class DownBlock(nn.Module):
    """Stride-2 ShuffleNetV2 unit: both branches downsample, concat -> shuffle."""

    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        half = out_channels // 2
        # left branch: depthwise stride-2 conv then pointwise projection
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, half, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half)
        # right branch: pointwise -> depthwise stride-2 -> pointwise
        self.conv3 = nn.Conv2d(in_channels, half, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(half)
        self.conv4 = nn.Conv2d(
            half,
            half,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=half,
            bias=False,
        )
        self.bn4 = nn.BatchNorm2d(half)
        self.conv5 = nn.Conv2d(half, half, kernel_size=1, bias=False)
        self.bn5 = nn.BatchNorm2d(half)

        self.shuffle = ShuffleBlock()

    def forward(self, x):
        left = self.bn1(self.conv1(x))
        left = F.relu(self.bn2(self.conv2(left)))
        right = F.relu(self.bn3(self.conv3(x)))
        right = self.bn4(self.conv4(right))
        right = F.relu(self.bn5(self.conv5(right)))
        return self.shuffle(torch.cat([left, right], 1))


class ShuffleNetV2(nn.Module):
    """CIFAR-sized ShuffleNetV2 exposing per-stage features for distillation."""

    def __init__(self, net_size, num_classes=10):
        super(ShuffleNetV2, self).__init__()
        out_channels = configs[net_size]["out_channels"]
        num_blocks = configs[net_size]["num_blocks"]

        # CIFAR variant: 1x1 stem and no initial max-pool (inputs are 32x32).
        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_channels = 24
        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
        self.conv2 = nn.Conv2d(
            out_channels[2],
            out_channels[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels[3])
        self.linear = nn.Linear(out_channels[3], num_classes)
        self.stage_channels = out_channels

    def _make_layer(self, out_channels, num_blocks):
        # One stride-2 DownBlock, then `num_blocks` stride-1 BasicBlocks; the
        # final block also emits the pre-activation feature (is_last=True).
        blocks = [DownBlock(self.in_channels, out_channels)]
        blocks.extend(
            BasicBlock(out_channels, is_last=(idx == num_blocks - 1))
            for idx in range(num_blocks)
        )
        self.in_channels = out_channels
        return nn.Sequential(*blocks)

    def get_feat_modules(self):
        return nn.ModuleList(
            [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3]
        )

    def get_bn_before_relu(self):
        raise NotImplementedError(
            'ShuffleNetV2 currently is not supported for "Overhaul" teacher'
        )

    def get_stage_channels(self):
        # Stem channels plus the widths of the three shuffle stages.
        return [24] + list(self.stage_channels[:-1])

    def forward(self, x):
        stem = F.relu(self.bn1(self.conv1(x)))
        f1, f1_pre = self.layer1(stem)
        f2, f2_pre = self.layer2(f1)
        f3, f3_pre = self.layer3(f2)
        y = F.relu(self.bn2(self.conv2(f3)))
        pooled = F.avg_pool2d(y, 4).reshape(y.size(0), -1)
        logits = self.linear(pooled)

        feats = {
            "feats": [stem, f1, f2, f3],
            "preact_feats": [stem, f1_pre, f2_pre, f3_pre],
            "pooled_feat": pooled,
        }
        return logits, feats


# net_size -> (stage widths incl. final 1x1 conv, blocks per stage)
configs = {
    0.2: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 3, 3)},
    0.3: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 7, 3)},
    0.5: {"out_channels": (48, 96, 192, 1024), "num_blocks": (3, 7, 3)},
    1: {"out_channels": (116, 232, 464, 1024), "num_blocks": (3, 7, 3)},
    1.5: {"out_channels": (176, 352, 704, 1024), "num_blocks": (3, 7, 3)},
    2: {"out_channels": (224, 488, 976, 2048), "num_blocks": (3, 7, 3)},
}


def ShuffleV2(**kwargs):
    """Build the width-1.0x ShuffleNetV2."""
    model = ShuffleNetV2(net_size=1, **kwargs)
    return model
if __name__ == "__main__":
    # Smoke test for the ShuffleNetV2 defined above: time one forward pass
    # and print the staged feature shapes.
    net = ShuffleV2(num_classes=100)
    x = torch.randn(3, 3, 32, 32)
    import time

    a = time.time()
    logit, feats = net(x)
    b = time.time()
    print(b - a)
    for f in feats["feats"]:
        print(f.shape, f.min().item())
    print(logit.shape)


# ===========================================================================
# detection/model/backbone/mobilenetv2.py
# ===========================================================================
"""
Creates a MobileNetV2 Model as defined in:
Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen. (2018).
MobileNetV2: Inverted Residuals and Linear Bottlenecks
arXiv preprint arXiv:1801.04381.
import from https://github.com/tonylins/pytorch-mobilenet-v2
"""

import torch.nn as nn
import math
from detectron2.modeling.backbone import BACKBONE_REGISTRY
from detectron2.modeling.backbone import Backbone, FPN
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

from detectron2.layers import (
    Conv2d,
    DeformConv,
    FrozenBatchNorm2d,
    ModulatedDeformConv,
    ShapeSpec,
    get_norm,
)

__all__ = ['mobilenetv2']


def _make_divisible(v, divisor, min_value=None):
    """Round `v` to the nearest multiple of `divisor` (at least `min_value`).

    Taken from the original TF repo; ensures all layer widths are divisible
    by 8:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_3x3_bn(inp, oup, stride, bn):
    """3x3 conv -> norm -> ReLU6 stem block.

    Args:
        inp, oup: input / output channel counts.
        stride: spatial stride of the convolution.
        bn: detectron2 norm spec accepted by `get_norm` (e.g. "BN", "FrozenBN").
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        get_norm(bn, oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup, bn="BN"):
    """1x1 conv -> norm -> ReLU6.

    Bug fix: the original body referenced a global name ``bn`` that was never
    defined anywhere in this module, so any call raised ``NameError``.  The
    norm spec is now an explicit parameter defaulting to plain "BN", which
    keeps the original two-argument call signature working.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        get_norm(bn, oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual: (pw expand) -> dw -> pw-linear.

    A residual connection is added only when stride == 1 and the channel
    count is unchanged.
    """

    def __init__(self, inp, oup, stride, expand_ratio, bn):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        if expand_ratio == 1:
            # First bottleneck has no pointwise expansion.
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                get_norm(bn, hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                get_norm(bn, oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                get_norm(bn, hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                get_norm(bn, hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                get_norm(bn, oup),
            )

    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)

    def freeze(self):
        """Freeze all parameters and convert BN children to FrozenBatchNorm.

        NOTE(review): `convert_frozen_batchnorm` replaces BN *children* of
        `self` in place; the returned root only differs when `self` itself is
        a BN module, which it is not here.
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


class MobileNetV2(Backbone):
    """MobileNetV2 backbone for detectron2, emitting named stages m2..m5."""

    def __init__(self, cfg, input_shape, width_mult=1.):
        super(MobileNetV2, self).__init__()
        self._out_features = cfg.MODEL.MOBILENETV2.OUT_FEATURES
        bn = cfg.MODEL.MOBILENETV2.NORM
        freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT

        # setting of inverted residual blocks:
        # t (expansion), c (channels), n (repeats), s (stride), stage name
        self.cfgs = [
            [1, 16, 1, 1, ''],
            [6, 24, 2, 2, 'm2'],
            [6, 32, 3, 2, 'm3'],
            [6, 64, 4, 2, ''],
            [6, 96, 3, 1, 'm4'],
            [6, 160, 3, 2, ''],
            [6, 320, 1, 1, 'm5'],
        ]

        # building first layer (stride-2 stem)
        input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8)
        layers = [conv_3x3_bn(input_shape.channels, input_channel, 2, bn)]
        if freeze_at >= 1:
            for p in layers[0].parameters():
                p.requires_grad = False
            layers[0] = FrozenBatchNorm2d.convert_frozen_batchnorm(layers[0])
        # building inverted residual blocks
        block = InvertedResidual
        self.stage_name = ['']  # parallel to `layers`: name of the stage a layer ends
        self._out_feature_channels = {}
        self._out_feature_strides = {}
        cur_stride = 2
        cur_stage = 2
        for t, c, n, s, name in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8)
            cur_stride = cur_stride * s
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, bn))
                if cur_stage <= freeze_at:
                    layers[-1].freeze()
                if name != '' and i == n - 1:
                    self._out_feature_channels[name] = output_channel
                    self._out_feature_strides[name] = cur_stride
                    # NOTE(review): advanced once per *named* stage (m2..m5)
                    # so it aligns with BACKBONE.FREEZE_AT numbering — confirm.
                    cur_stage += 1
                input_channel = output_channel
                self.stage_name.append(name if i == n - 1 else '')
        self.features = nn.Sequential(*layers)
        # Backbone use only: the classification head (final 1x1 conv, pooling,
        # linear classifier) of the reference implementation is omitted.

        self._initialize_weights()

    def forward(self, x):
        """Return {stage_name: feature tensor} for the requested stages."""
        output = {}
        for layer, stage in zip(self.features, self.stage_name):
            x = layer(x)
            if stage in self._out_features:
                output[stage] = x
        return output

    def _initialize_weights(self):
        # He-style init for convs, unit/zero for norms, small normal for linear.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def output_shape(self):
        """Channel/stride spec of each emitted stage (detectron2 contract)."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }


@BACKBONE_REGISTRY.register()
def build_mobilenetv2_backbone(cfg, input_shape):
    """
    Constructs a MobileNet V2 model
    """
    return MobileNetV2(cfg, input_shape)


# ===========================================================================
# mdistiller/models/cifar/vgg.py
# ===========================================================================
import torch.nn as nn
import torch.nn.functional as F
import math


__all__ = [
    "VGG",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19_bn",
    "vgg19",
]


model_urls = {
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
}
class VGG(nn.Module):
    """CIFAR-style VGG that exposes per-stage (and pre-activation) features.

    `cfg` is a list of five per-block channel lists (see the `cfg` dict
    below); each block ends *before* its final ReLU so that pre-activation
    features can be tapped for distillation.
    """

    def __init__(self, cfg, batch_norm=False, num_classes=1000):
        super(VGG, self).__init__()
        self.block0 = self._make_layers(cfg[0], batch_norm, 3)
        self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
        self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
        self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
        self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])

        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()
        self.stage_channels = [c[-1] for c in cfg]

    def get_feat_modules(self):
        stages = [
            self.block0, self.pool0,
            self.block1, self.pool1,
            self.block2, self.pool2,
            self.block3, self.pool3,
            self.block4, self.pool4,
        ]
        return nn.ModuleList(stages)

    def get_bn_before_relu(self):
        # Each block ends pre-activation, so its last module sits before ReLU.
        return [
            block[-1]
            for block in (self.block1, self.block2, self.block3, self.block4)
        ]

    def get_stage_channels(self):
        return self.stage_channels

    def forward(self, x):
        input_size = x.shape[2]
        x = F.relu(self.block0(x))
        f0 = x
        x = self.block1(self.pool0(x))
        f1_pre = x
        x = F.relu(x)
        f1 = x
        x = self.block2(self.pool1(x))
        f2_pre = x
        x = F.relu(x)
        f2 = x
        x = self.block3(self.pool2(x))
        f3_pre = x
        x = F.relu(x)
        f3 = x
        # 64px inputs (e.g. Tiny-ImageNet) get one extra pooling stage.
        if input_size == 64:
            x = self.pool3(x)
        x = self.block4(x)
        f4_pre = x
        x = F.relu(x)
        f4 = x
        x = self.pool4(x)
        x = x.reshape(x.size(0), -1)
        f5 = x
        logits = self.classifier(x)

        feats = {
            "feats": [f0, f1, f2, f3, f4],
            "preact_feats": [f0, f1_pre, f2_pre, f3_pre, f4_pre],
            "pooled_feat": f5,
        }
        return logits, feats

    @staticmethod
    def _make_layers(cfg, batch_norm=False, in_channels=3):
        layers = []
        for spec in cfg:
            if spec == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(spec))
            layers.append(nn.ReLU(inplace=True))
            in_channels = spec
        # Drop the trailing ReLU: every block ends pre-activation.
        return nn.Sequential(*layers[:-1])

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


# Per-configuration channel lists; "S" is the shallow VGG-8 variant.
cfg = {
    "A": [[64], [128], [256, 256], [512, 512], [512, 512]],
    "B": [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
    "D": [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
    "E": [
        [64, 64],
        [128, 128],
        [256, 256, 256, 256],
        [512, 512, 512, 512],
        [512, 512, 512, 512],
    ],
    "S": [[64], [128], [256], [512], [512]],
}
| pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = VGG(cfg["S"], **kwargs) 161 | return model 162 | 163 | 164 | def vgg8_bn(**kwargs): 165 | """VGG 8-layer model (configuration "S") 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = VGG(cfg["S"], batch_norm=True, **kwargs) 170 | return model 171 | 172 | 173 | def vgg11(**kwargs): 174 | """VGG 11-layer model (configuration "A") 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = VGG(cfg["A"], **kwargs) 179 | return model 180 | 181 | 182 | def vgg11_bn(**kwargs): 183 | """VGG 11-layer model (configuration "A") with batch normalization""" 184 | model = VGG(cfg["A"], batch_norm=True, **kwargs) 185 | return model 186 | 187 | 188 | def vgg13(**kwargs): 189 | """VGG 13-layer model (configuration "B") 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = VGG(cfg["B"], **kwargs) 194 | return model 195 | 196 | 197 | def vgg13_bn(**kwargs): 198 | """VGG 13-layer model (configuration "B") with batch normalization""" 199 | model = VGG(cfg["B"], batch_norm=True, **kwargs) 200 | return model 201 | 202 | 203 | def vgg16(**kwargs): 204 | """VGG 16-layer model (configuration "D") 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = VGG(cfg["D"], **kwargs) 209 | return model 210 | 211 | 212 | def vgg16_bn(**kwargs): 213 | """VGG 16-layer model (configuration "D") with batch normalization""" 214 | model = VGG(cfg["D"], batch_norm=True, **kwargs) 215 | return model 216 | 217 | 218 | def vgg19(**kwargs): 219 | """VGG 19-layer model (configuration "E") 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = VGG(cfg["E"], **kwargs) 224 | return model 225 | 226 | 227 | def vgg19_bn(**kwargs): 228 | """VGG 19-layer 
if __name__ == "__main__":
    # Smoke test for the VGG factories above.
    import torch

    x = torch.randn(2, 3, 32, 32)
    net = vgg19_bn(num_classes=100)
    logit, feats = net(x)

    for f in feats["feats"]:
        print(f.shape, f.min().item())
    print(logit.shape)


# ===========================================================================
# mdistiller/models/cifar/resnetv2.py
# ===========================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two-conv residual block; optionally returns the pre-ReLU sum."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 projection shortcut when spatial size or width changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck block; optionally returns the pre-ReLU sum."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(nn.Module):
    """CIFAR-sized ResNet (v2 layout) exposing per-stage features for KD."""

    def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        # Bug fix: widths of the four residual stages, scaled by the block
        # expansion (1 for BasicBlock, 4 for Bottleneck).  The previous
        # hard-coded [256, 512, 1024, 2048] was only correct for Bottleneck
        # nets and mismatched the actual feature channels of ResNet18/34.
        self.stage_channels = [ch * block.expansion for ch in (64, 128, 256, 512)]

    def get_feat_modules(self):
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        feat_m.append(self.layer4)
        return feat_m

    def get_bn_before_relu(self):
        """Return the last pre-ReLU BatchNorm of each stage (for OFD et al.)."""
        if isinstance(self.layer1[0], Bottleneck):
            bn1 = self.layer1[-1].bn3
            bn2 = self.layer2[-1].bn3
            bn3 = self.layer3[-1].bn3
            bn4 = self.layer4[-1].bn3
        elif isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
            bn4 = self.layer4[-1].bn2
        else:
            raise NotImplementedError("ResNet unknown block error !!!")

        return [bn1, bn2, bn3, bn4]

    def get_stage_channels(self):
        return self.stage_channels

    def _make_layer(self, block, planes, num_blocks, stride):
        # First block may downsample; the last one exposes its pre-activation.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def encode(self, x, idx, preact=False):
        """Re-encode a (pre-activation) feature through one of the last three
        stages; only idx in {-1, -2, -3} is supported."""
        if idx == -1:
            out, pre = self.layer4(F.relu(x))
        elif idx == -2:
            out, pre = self.layer3(F.relu(x))
        elif idx == -3:
            out, pre = self.layer2(F.relu(x))
        else:
            raise NotImplementedError()
        return pre

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        f0 = out
        out, f1_pre = self.layer1(out)
        f1 = out
        out, f2_pre = self.layer2(out)
        f2 = out
        out, f3_pre = self.layer3(out)
        f3 = out
        out, f4_pre = self.layer4(out)
        f4 = out
        out = self.avgpool(out)
        avg = out.reshape(out.size(0), -1)
        out = self.linear(avg)

        feats = {}
        feats["feats"] = [f0, f1, f2, f3, f4]
        feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre, f4_pre]
        feats["pooled_feat"] = avg

        return out, feats


def ResNet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def ResNet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def ResNet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def ResNet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def ResNet152(**kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


if __name__ == "__main__":
    net = ResNet18(num_classes=100)
    x = torch.randn(2, 3, 32, 32)
    logit, feats = net(x)

    for f in feats["feats"]:
        print(f.shape, f.min().item())
    print(logit.shape)
2 | 3 | This repo is 4 | 5 | (1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks, 6 | 7 | (2) the official implementation of the CVPR-2022 paper: [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679). 8 | 9 | (3) the official implementation of the ICCV-2023 paper: [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf). 10 | 11 | 12 | # DOT: A Distillation-Oriented Trainer 13 | 14 | ### Framework 15 | 16 |
17 | 18 | ### Main Benchmark Results 19 | 20 | On CIFAR-100: 21 | 22 | | Teacher
Student | ResNet32x4
ResNet8x4| VGG13
VGG8| ResNet32x4
ShuffleNet-V2| 23 | |:---------------:|:-----------------:|:-----------------:|:-----------------:| 24 | | KD | 73.33 | 72.98 | 74.45 | 25 | | **KD+DOT** | **75.12** | **73.77** | **75.55** | 26 | 27 | On Tiny-ImageNet: 28 | 29 | | Teacher
Student |ResNet18
MobileNet-V2|ResNet18
ShuffleNet-V2| 30 | |:---------------:|:-----------------:|:-----------------:| 31 | | KD | 58.35 | 62.26 | 32 | | **KD+DOT** | **64.01** | **65.75** | 33 | 34 | On ImageNet: 35 | 36 | | Teacher
Student |ResNet34
ResNet18|ResNet50
MobileNet-V1| 37 | |:---------------:|:-----------------:|:-----------------:| 38 | | KD | 71.03 | 70.50 | 39 | | **KD+DOT** | **71.72** | **73.09** | 40 | 41 | # Decoupled Knowledge Distillation 42 | 43 | ### Framework & Performance 44 | 45 |
46 | 47 | ### Main Benchmark Results 48 | 49 | On CIFAR-100: 50 | 51 | 52 | | Teacher
Student |ResNet56
ResNet20|ResNet110
ResNet32| ResNet32x4
ResNet8x4| WRN-40-2
WRN-16-2| WRN-40-2
WRN-40-1 | VGG13
VGG8| 53 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:| 54 | | KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98 | 55 | | **DKD** | **71.97** | **74.11** | **76.32** | **76.23** | **74.81** | **74.68** | 56 | 57 | 58 | | Teacher
Student |ResNet32x4
ShuffleNet-V1|WRN-40-2
ShuffleNet-V1| VGG13
MobileNet-V2| ResNet50
MobileNet-V2| ResNet32x4
MobileNet-V2| 59 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:| 60 | | KD | 74.07 | 74.83 | 67.37 | 67.35 | 74.45 | 61 | | **DKD** | **76.45** | **76.70** | **69.71** | **70.35** | **77.07** | 62 | 63 | 64 | On ImageNet: 65 | 66 | | Teacher
Student |ResNet34
ResNet18|ResNet50
MobileNet-V1|
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **DKD** | **71.70** | **72.05** |

# MDistiller

### Introduction

MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:
|Method|Paper Link|CIFAR-100|ImageNet|MS-COCO|
|:---:|:---:|:---:|:---:|:---:|
|KD| |✓|✓| |
|FitNet| |✓| | |
|AT| |✓|✓| |
|NST| |✓| | |
|PKT| |✓| | |
|KDSVD| |✓| | |
|OFD| |✓|✓| |
|RKD| |✓| | |
|VID| |✓| | |
|SP| |✓| | |
|CRD| |✓|✓| |
|ReviewKD| |✓|✓|✓|
|DKD| |✓|✓|✓|


### Installation

Environments:

- Python 3.6
- PyTorch 1.9.0
- torchvision 0.10.0

Install the package:

```
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop
```

### Getting started

0. Wandb as the logger

- The registration: .
- If you don't want wandb as your logger, set `CFG.LOG.WANDB` as `False` at `mdistiller/engine/cfg.py`.

1. Evaluation

- You can evaluate the performance of our models or models trained by yourself.

- Our models are at , please download the checkpoints to `./download_ckpts`

- To test the models on ImageNet, please download the dataset at and put it in `./data/imagenet`

```bash
# evaluate teachers
python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet

# evaluate students
python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
```


2.
Training on CIFAR-100

- Download the `cifar_teachers.tar` at and untar it to `./download_ckpts` via `tar xvf cifar_teachers.tar`.

```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml

# you can also change settings at command line
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
```

3. Training on ImageNet

- Download the dataset at and put it in `./data/imagenet`

```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
```

4. Training on MS-COCO

- see [detection.md](detection/README.md)


5. Extension: Visualizations

- Jupyter notebooks: [tsne](tools/visualizations/tsne.ipynb) and [correlation_matrices](tools/visualizations/correlation.ipynb)


### Custom Distillation Method

1. Create a python file at `mdistiller/distillers/` and define the distiller

```python
from ._base import Distiller

class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        self.hyper1 = cfg.MyDistiller.hyper1
        ...

    def forward_train(self, image, target, **kwargs):
        # return the output logits and a Dict of losses
        ...
    # rewrite the get_learnable_parameters function if there are more nn modules for distillation.
    # rewrite the get_extra_parameters if you want to obtain the extra cost.
    ...
```

2. Register the distiller in `distiller_dict` at `mdistiller/distillers/__init__.py`

3. Register the corresponding hyper-parameters at `mdistiller/engine/cfg.py`

4. Create a new config file and test it.
192 | 193 | # Citation 194 | 195 | If this repo is helpful for your research, please consider citing the paper: 196 | 197 | ```BibTeX 198 | @article{zhao2022dkd, 199 | title={Decoupled Knowledge Distillation}, 200 | author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun}, 201 | journal={arXiv preprint arXiv:2203.08679}, 202 | year={2022} 203 | } 204 | @article{zhao2023dot, 205 | title={DOT: A Distillation-Oriented Trainer}, 206 | author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun}, 207 | journal={arXiv preprint arXiv:2307.08436}, 208 | year={2023} 209 | } 210 | ``` 211 | 212 | # License 213 | 214 | MDistiller is released under the MIT license. See [LICENSE](LICENSE) for details. 215 | 216 | # Acknowledgement 217 | 218 | - Thanks for CRD and ReviewKD. We build this library based on the [CRD's codebase](https://github.com/HobbitLong/RepDistiller) and the [ReviewKD's codebase](https://github.com/dvlab-research/ReviewKD). 219 | 220 | - Thanks Yiyu Qiu and Yi Shi for the code contribution during their internship in MEGVII Technology. 221 | 222 | - Thanks Xin Jin for the discussion about DKD. 
223 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ["resnet"] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d( 12 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 13 | ) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 20 | super(BasicBlock, self).__init__() 21 | self.is_last = is_last 22 | self.conv1 = conv3x3(inplanes, planes, stride) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | preact = out 45 | out = F.relu(out) 46 | if self.is_last: 47 | return out, preact 48 | else: 49 | return out 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | expansion = 4 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 56 | super(Bottleneck, self).__init__() 57 | self.is_last = is_last 58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d( 61 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 62 | ) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, 
bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | preact = out 89 | out = F.relu(out) 90 | if self.is_last: 91 | return out, preact 92 | else: 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, depth, num_filters, block_name="BasicBlock", num_classes=10): 98 | super(ResNet, self).__init__() 99 | # Model type specifies number of layers for CIFAR-10 model 100 | if block_name.lower() == "basicblock": 101 | assert ( 102 | depth - 2 103 | ) % 6 == 0, "When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202" 104 | n = (depth - 2) // 6 105 | block = BasicBlock 106 | elif block_name.lower() == "bottleneck": 107 | assert ( 108 | depth - 2 109 | ) % 9 == 0, "When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199" 110 | n = (depth - 2) // 9 111 | block = Bottleneck 112 | else: 113 | raise ValueError("block_name shoule be Basicblock or Bottleneck") 114 | 115 | self.inplanes = num_filters[0] 116 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.layer1 = self._make_layer(block, num_filters[1], n) 120 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 121 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 122 | self.avgpool = nn.AvgPool2d(8) 123 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 124 | self.stage_channels = num_filters 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 129 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 130 | nn.init.constant_(m.weight, 1) 131 | nn.init.constant_(m.bias, 0) 132 | 133 | def _make_layer(self, block, planes, blocks, stride=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | nn.Conv2d( 138 | self.inplanes, 139 | planes * block.expansion, 140 | kernel_size=1, 141 | stride=stride, 142 | bias=False, 143 | ), 144 | nn.BatchNorm2d(planes * block.expansion), 145 | ) 146 | 147 | layers = list([]) 148 | layers.append( 149 | block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1)) 150 | ) 151 | self.inplanes = planes * block.expansion 152 | for i in range(1, blocks): 153 | layers.append(block(self.inplanes, planes, is_last=(i == blocks - 1))) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def get_feat_modules(self): 158 | feat_m = nn.ModuleList([]) 159 | feat_m.append(self.conv1) 160 | feat_m.append(self.bn1) 161 | feat_m.append(self.relu) 162 | feat_m.append(self.layer1) 163 | feat_m.append(self.layer2) 164 | 
feat_m.append(self.layer3) 165 | return feat_m 166 | 167 | def get_bn_before_relu(self): 168 | if isinstance(self.layer1[0], Bottleneck): 169 | bn1 = self.layer1[-1].bn3 170 | bn2 = self.layer2[-1].bn3 171 | bn3 = self.layer3[-1].bn3 172 | elif isinstance(self.layer1[0], BasicBlock): 173 | bn1 = self.layer1[-1].bn2 174 | bn2 = self.layer2[-1].bn2 175 | bn3 = self.layer3[-1].bn2 176 | else: 177 | raise NotImplementedError("ResNet unknown block error !!!") 178 | 179 | return [bn1, bn2, bn3] 180 | 181 | def get_stage_channels(self): 182 | return self.stage_channels 183 | 184 | def forward(self, x): 185 | x = self.conv1(x) 186 | x = self.bn1(x) 187 | x = self.relu(x) # 32x32 188 | f0 = x 189 | 190 | x, f1_pre = self.layer1(x) # 32x32 191 | f1 = x 192 | x, f2_pre = self.layer2(x) # 16x16 193 | f2 = x 194 | x, f3_pre = self.layer3(x) # 8x8 195 | f3 = x 196 | 197 | x = self.avgpool(x) 198 | avg = x.reshape(x.size(0), -1) 199 | out = self.fc(avg) 200 | 201 | feats = {} 202 | feats["feats"] = [f0, f1, f2, f3] 203 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre] 204 | feats["pooled_feat"] = avg 205 | 206 | return out, feats 207 | 208 | 209 | def resnet8(**kwargs): 210 | return ResNet(8, [16, 16, 32, 64], "basicblock", **kwargs) 211 | 212 | 213 | def resnet14(**kwargs): 214 | return ResNet(14, [16, 16, 32, 64], "basicblock", **kwargs) 215 | 216 | 217 | def resnet20(**kwargs): 218 | return ResNet(20, [16, 16, 32, 64], "basicblock", **kwargs) 219 | 220 | 221 | def resnet32(**kwargs): 222 | return ResNet(32, [16, 16, 32, 64], "basicblock", **kwargs) 223 | 224 | 225 | def resnet44(**kwargs): 226 | return ResNet(44, [16, 16, 32, 64], "basicblock", **kwargs) 227 | 228 | 229 | def resnet56(**kwargs): 230 | return ResNet(56, [16, 16, 32, 64], "basicblock", **kwargs) 231 | 232 | 233 | def resnet110(**kwargs): 234 | return ResNet(110, [16, 16, 32, 64], "basicblock", **kwargs) 235 | 236 | 237 | def resnet8x4(**kwargs): 238 | return ResNet(8, [32, 64, 128, 256], "basicblock", 
**kwargs) 239 | 240 | 241 | def resnet32x4(**kwargs): 242 | return ResNet(32, [32, 64, 128, 256], "basicblock", **kwargs) 243 | 244 | 245 | if __name__ == "__main__": 246 | import torch 247 | 248 | x = torch.randn(2, 3, 32, 32) 249 | net = resnet8x4(num_classes=20) 250 | logit, feats = net(x) 251 | 252 | for f in feats["feats"]: 253 | print(f.shape, f.min().item()) 254 | print(logit.shape) 255 | --------------------------------------------------------------------------------