├── README.md
├── configs
│   ├── cifar100
│   │   ├── at.yaml
│   │   ├── crd.yaml
│   │   ├── dkd
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── fitnet.yaml
│   │   ├── kd
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── kdsvd.yaml
│   │   ├── mlkd
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── mlkd_noaug
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── nst.yaml
│   │   ├── ofd.yaml
│   │   ├── pkt.yaml
│   │   ├── rc
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── reviewkd.yaml
│   │   ├── revision
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── rkd.yaml
│   │   ├── rld
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── sp.yaml
│   │   ├── swap
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   ├── vanilla
│   │   │   ├── ResNet50_MobileNetV2.yaml
│   │   │   ├── resnet110_resnet20.yaml
│   │   │   ├── resnet110_resnet32.yaml
│   │   │   ├── resnet32x4_ShuffleV2.yaml
│   │   │   ├── resnet32x4_resnet8x4.yaml
│   │   │   ├── resnet32x4_wrn_16_2.yaml
│   │   │   ├── resnet32x4_wrn_40_2.yaml
│   │   │   ├── resnet56_resnet20.yaml
│   │   │   ├── vgg13_MobileNetV2.yaml
│   │   │   ├── vgg13_vgg8.yaml
│   │   │   ├── wrn_40_2_MobileNetV2.yaml
│   │   │   ├── wrn_40_2_resnet8x4.yaml
│   │   │   ├── wrn_40_2_wrn_16_2.yaml
│   │   │   └── wrn_40_2_wrn_40_1.yaml
│   │   └── vid.yaml
│   └── imagenet
│       ├── r34_r18
│       │   ├── at.yaml
│       │   ├── crd.yaml
│       │   ├── dkd.yaml
│       │   ├── kd.yaml
│       │   ├── mlkd.yaml
│       │   ├── rc.yaml
│       │   ├── reviewkd.yaml
│       │   ├── revision.yaml
│       │   ├── rld.yaml
│       │   └── swap.yaml
│       └── r50_mv1
│           ├── at.yaml
│           ├── crd.yaml
│           ├── dkd.yaml
│           ├── kd.yaml
│           ├── mlkd.yaml
│           ├── ofd.yaml
│           ├── rc.yaml
│           ├── reviewkd.yaml
│           ├── revision.yaml
│           ├── rld.yaml
│           └── swap.yaml
├── mdistiller
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── cifar100.py
│   │   └── imagenet.py
│   ├── distillers
│   │   ├── AT.py
│   │   ├── CRD.py
│   │   ├── DKD.py
│   │   ├── FitNet.py
│   │   ├── KD.py
│   │   ├── KDSVD.py
│   │   ├── MLKD.py
│   │   ├── MLKD_NOAUG.py
│   │   ├── NST.py
│   │   ├── OFD.py
│   │   ├── PKT.py
│   │   ├── RC.py
│   │   ├── REVISION.py
│   │   ├── RKD.py
│   │   ├── RLD.py
│   │   ├── ReviewKD.py
│   │   ├── SP.py
│   │   ├── SWAP.py
│   │   ├── Sonly.py
│   │   ├── VID.py
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _common.py
│   │   └── loss.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── cfg.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── models
│       ├── __init__.py
│       ├── cifar
│       │   ├── ShuffleNetv1.py
│       │   ├── ShuffleNetv2.py
│       │   ├── __init__.py
│       │   ├── mobilenetv2.py
│       │   ├── resnet.py
│       │   ├── resnetv2.py
│       │   ├── vgg.py
│       │   └── wrn.py
│       └── imagenet
│           ├── __init__.py
│           ├── mobilenetv1.py
│           └── resnet.py
└── tools
    ├── eval.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
# Knowledge Distillation with Refined Logits

## Environment

Python 3.8, torch 1.7.0

Packages that may require additional installation: torchvision, tensorboardX, yacs, wandb, tqdm, scipy

## Pre-trained Teachers

Pre-trained teachers can be downloaded from [Decoupled Knowledge Distillation (CVPR 2022)](https://github.com/megvii-research/mdistiller/releases/tag/checkpoints). Download `cifar_teachers.tar` and untar it into `./download_ckpts` via `tar xvf cifar_teachers.tar`.

## Training on CIFAR-100

```sh
# Train method X (replace X with a method name, e.g. kd, dkd, rld)
CUDA_VISIBLE_DEVICES=0 python tools/train.py --cfg configs/cifar100/X/vgg13_vgg8.yaml
# Hyper-parameters can additionally be specified on the command line, for example:
# Train DKD
CUDA_VISIBLE_DEVICES=0 python tools/train.py --cfg configs/cifar100/dkd/vgg13_vgg8.yaml DKD.ALPHA 1. DKD.BETA 8. DKD.T 4.
# Train RLD
CUDA_VISIBLE_DEVICES=0 python tools/train.py --cfg configs/cifar100/rld/vgg13_vgg8.yaml --same-t RLD.ALPHA 1. RLD.BETA 8. RLD.T 4.
```

## Acknowledgment

This codebase is heavily based on [Logit Standardization in Knowledge Distillation (CVPR 2024)](https://github.com/sunshangquan/logit-standardization-KD). Sincere thanks to the authors for their excellent work.
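## A Note on Config Overrides

The trailing `KEY VALUE` pairs accepted by `tools/train.py` follow the usual yacs override convention: values from the YAML file override the built-in defaults, and command-line pairs override the YAML. The sketch below illustrates that pattern under stated assumptions; it is not the project's actual code. The real defaults and argument handling live in `mdistiller/engine/cfg.py` and `tools/train.py`, and the `DKD` defaults shown here are hypothetical.

```python
# Minimal sketch of the yacs override chain: defaults -> YAML -> command line.
# The config keys below are illustrative; the real defaults live in
# mdistiller/engine/cfg.py and may differ.
import argparse

from yacs.config import CfgNode as CN

cfg = CN()
cfg.DKD = CN()
cfg.DKD.ALPHA = 1.0  # hypothetical defaults, for illustration only
cfg.DKD.BETA = 8.0
cfg.DKD.T = 4.0
cfg.set_new_allowed(True)  # let a full config file introduce keys not defined above

parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=str, default="")
parser.add_argument("opts", nargs=argparse.REMAINDER)  # e.g. DKD.BETA 8. DKD.T 4.
args = parser.parse_args()

if args.cfg:
    cfg.merge_from_file(args.cfg)  # values from the YAML override the defaults
cfg.merge_from_list(args.opts)     # command-line pairs override everything else
cfg.freeze()                       # make the config read-only for training
```

Judging from the directory tree, ImageNet experiments presumably run the same way with the configs under `configs/imagenet` (e.g., `--cfg configs/imagenet/r34_r18/rld.yaml`), although only CIFAR-100 training is documented above.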
-------------------------------------------------------------------------------- /configs/cifar100/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "AT" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "CRD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "crd" 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 2.0 -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 2.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | 
-------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 2.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | 21 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | 
-------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/dkd/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "DKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | DKD: 19 | BETA: 6.0 20 | -------------------------------------------------------------------------------- /configs/cifar100/fitnet.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "fitnet,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "FITNET" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 
-------------------------------------------------------------------------------- /configs/cifar100/kd/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- 
/configs/cifar100/kd/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kd/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 
| TAG: "kd,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/kdsvd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kdsvd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "KDSVD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines_epoch960" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 960 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: 
"mlkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | 
NAME: "" 3 | TAG: "mlkd,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 480 12 | LR: 0.025 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | TRAINER: "ours" 19 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet110_resnet32.yaml: 
-------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- 
/configs/cifar100/mlkd_noaug/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/mlkd_noaug/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "mlkd_noaug,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "MLKD_NOAUG" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- 
/configs/cifar100/nst.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "nst,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NST" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/ofd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "ofd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "OFD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 19 | -------------------------------------------------------------------------------- /configs/cifar100/pkt.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "pkt,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "PKT" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: 
"rc,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 
| BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rc/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RC" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVIEWKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | REVIEWKD: 10 | REVIEWKD_WEIGHT: 5.0 11 | SOLVER: 12 | BATCH_SIZE: 64 13 | EPOCHS: 240 14 | LR: 0.05 15 | LR_DECAY_STAGES: [150, 180, 210] 16 | LR_DECAY_RATE: 0.1 17 | WEIGHT_DECAY: 0.0005 18 | MOMENTUM: 0.9 19 | TYPE: "SGD" 20 | 21 | -------------------------------------------------------------------------------- /configs/cifar100/revision/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | 
LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | 
LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | 
LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/revision/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "REVISION" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rkd,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RKD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | 
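Every CIFAR-100 config in this tree follows the same three-block schema: EXPERIMENT names the run, DISTILLER selects the method and the teacher/student pair, and SOLVER fixes the SGD recipe. As a rough sketch of how one of these files is consumed, the snippet below loads a config with yacs (a listed dependency) and applies a dotted-key override in the style of the README's command lines. The project's real default tree lives in mdistiller/engine/cfg.py; the defaults are re-declared here only to keep the snippet self-contained.

```python
# Minimal standalone sketch, not the project's actual cfg module:
# it re-declares just the keys used by the configs above.
from yacs.config import CfgNode as CN

cfg = CN()
cfg.EXPERIMENT = CN()
cfg.EXPERIMENT.NAME = ""
cfg.EXPERIMENT.TAG = ""
cfg.EXPERIMENT.PROJECT = ""
cfg.DISTILLER = CN()
cfg.DISTILLER.TYPE = "NONE"  # "NONE" = train the student alone (vanilla configs)
cfg.DISTILLER.TEACHER = "resnet32x4"
cfg.DISTILLER.STUDENT = "resnet8x4"
cfg.SOLVER = CN()
cfg.SOLVER.BATCH_SIZE = 64
cfg.SOLVER.EPOCHS = 240
cfg.SOLVER.LR = 0.05
cfg.SOLVER.LR_DECAY_STAGES = [150, 180, 210]
cfg.SOLVER.LR_DECAY_RATE = 0.1
cfg.SOLVER.WEIGHT_DECAY = 0.0005
cfg.SOLVER.MOMENTUM = 0.9
cfg.SOLVER.TYPE = "SGD"

# File values override the defaults, then KEY VALUE pairs override the file,
# mirroring `python tools/train.py --cfg ... RLD.ALPHA 1.` from the README.
cfg.merge_from_file("configs/cifar100/rld/resnet32x4_resnet8x4.yaml")
cfg.merge_from_list(["SOLVER.LR", 0.01])
cfg.freeze()
print(cfg.DISTILLER.TYPE, cfg.SOLVER.LR)  # -> RLD 0.01
```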
-------------------------------------------------------------------------------- /configs/cifar100/rld/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/wrn_40_2_MobileNetV2.yaml: 
-------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/rld/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "RLD" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/sp.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "sp,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: 
"swap,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: 
"SWAP" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/swap/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "SWAP" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 
0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/ResNet50_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res50,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "ResNet50" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet110_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res110,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet110_resnet32.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res110,res32" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet110" 8 | STUDENT: "resnet32" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet32x4_ShuffleV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res32x4,shuv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "ShuffleV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet32x4_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet32x4_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res32x4,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | 
LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet32x4_wrn_40_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res32x4,wrn_40_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "wrn_40_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/resnet56_resnet20.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,res56,res20" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "resnet56" 8 | STUDENT: "resnet20" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/vgg13_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,vgg13,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "vgg13" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/vgg13_vgg8.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,vgg13,vgg8" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "vgg13" 8 | STUDENT: "vgg8" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/wrn_40_2_MobileNetV2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,wrn_40_2,mv2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "MobileNetV2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.01 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/wrn_40_2_resnet8x4.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,wrn_40_2,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 
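# note: all CIFAR-100 recipes in this tree are identical except LR, which drops to 0.01 whenever the student is MobileNetV2 or ShuffleV2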
18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/wrn_40_2_wrn_16_2.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,wrn_40_2,wrn_16_2" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_16_2" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vanilla/wrn_40_2_wrn_40_1.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vanilla,wrn_40_2,wrn_40_1" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "NONE" 7 | TEACHER: "wrn_40_2" 8 | STUDENT: "wrn_40_1" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/cifar100/vid.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "vid,res32x4,res8x4" 4 | PROJECT: "cifar100_baselines" 5 | DISTILLER: 6 | TYPE: "VID" 7 | TEACHER: "resnet32x4" 8 | STUDENT: "resnet8x4" 9 | SOLVER: 10 | BATCH_SIZE: 64 11 | EPOCHS: 240 12 | LR: 0.05 13 | LR_DECAY_STAGES: [150, 180, 210] 14 | LR_DECAY_RATE: 0.1 15 | WEIGHT_DECAY: 0.0005 16 | MOMENTUM: 0.9 17 | TYPE: "SGD" 18 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "AT" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "CRD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "crd" 24 | CRD: 25 | FEAT: 26 | STUDENT_DIM: 512 27 | TEACHER_DIM: 512 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 200 31 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/dkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: 
"imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "DKD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | DKD: 27 | CE_WEIGHT: 1.0 28 | BETA: 0.5 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/kd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1.0 25 | LOSS: 26 | CE_WEIGHT: 1.0 27 | KD_WEIGHT: 0.5 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 200 31 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/mlkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd_ours,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD_ours" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "ours" 24 | KD: 25 | TEMPERATURE: 1.0 26 | LOSS: 27 | CE_WEIGHT: 0.5 28 | KD_WEIGHT: 0.5 29 | LOG: 30 | TENSORBOARD_FREQ: 50 31 | SAVE_CHECKPOINT_FREQ: 200 32 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/rc.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "RC" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | RC: 24 | T: 1.0 25 | LOG: 26 | TENSORBOARD_FREQ: 50 27 | SAVE_CHECKPOINT_FREQ: 200 28 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVIEWKD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | REVIEWKD: 
27 | CE_WEIGHT: 1.0 28 | REVIEWKD_WEIGHT: 1.0 29 | WARMUP_EPOCHS: 1 30 | SHAPES: [1,7,14,28,56] 31 | OUT_SHAPES: [1,7,14,28,56] 32 | IN_CHANNELS: [64,128,256,512,512] 33 | OUT_CHANNELS: [64,128,256,512,512] 34 | STU_PREACT: True 35 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/revision.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVISION" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/rld.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "RLD" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | RLD: 27 | CE_WEIGHT: 1.0 28 | BETA: 0.5 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- /configs/imagenet/r34_r18/swap.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "swap,res34,res18" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "SWAP" 12 | TEACHER: "ResNet34" 13 | STUDENT: "ResNet18" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1.0 25 | LOG: 26 | TENSORBOARD_FREQ: 50 27 | SAVE_CHECKPOINT_FREQ: 200 28 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/at.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "at,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "AT" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/crd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "crd,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | 
BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "CRD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "crd" 24 | CRD: 25 | FEAT: 26 | STUDENT_DIM: 1024 27 | TEACHER_DIM: 2048 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 200 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/dkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "dkd,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "DKD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | DKD: 27 | CE_WEIGHT: 1.0 28 | BETA: 2.0 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/kd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1.0 25 | LOSS: 26 | CE_WEIGHT: 0.5 27 | KD_WEIGHT: 0.5 28 | LOG: 29 | TENSORBOARD_FREQ: 50 30 | SAVE_CHECKPOINT_FREQ: 200 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/mlkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "kd_ours,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "KD_ours" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | TRAINER: "ours" 24 | KD: 25 | TEMPERATURE: 1.0 26 | LOSS: 27 | CE_WEIGHT: 0.5 28 | KD_WEIGHT: 0.5 29 | LOG: 30 | TENSORBOARD_FREQ: 50 31 | SAVE_CHECKPOINT_FREQ: 200 32 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/ofd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "ofd,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "OFD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 128 16 | EPOCHS: 100 17 | LR: 0.05 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | OFD: 24 | LOSS: 25 | FEAT_WEIGHT: 0.0001 26 | LOG: 27 | 
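# logging cadence: how often tensorboard scalars and intermediate checkpoints are written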
TENSORBOARD_FREQ: 50 28 | SAVE_CHECKPOINT_FREQ: 200 29 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/rc.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rc,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "RC" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | RC: 24 | T: 1.0 25 | LOG: 26 | TENSORBOARD_FREQ: 50 27 | SAVE_CHECKPOINT_FREQ: 200 28 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/reviewkd.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "reviewkd,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVIEWKD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 256 16 | EPOCHS: 100 17 | LR: 0.1 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | REVIEWKD: 27 | CE_WEIGHT: 1.0 28 | REVIEWKD_WEIGHT: 8.0 29 | WARMUP_EPOCHS: 1 30 | SHAPES: [1,7,14,28,56] 31 | OUT_SHAPES: [1,7,14,28,56] 32 | IN_CHANNELS: [128,256,512,1024,1024] 33 | OUT_CHANNELS: [256,512,1024,2048,2048] 34 | MAX_MID_CHANNEL: 256 35 | STU_PREACT: True 36 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/revision.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "revision,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "REVISION" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/rld.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: "rld,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "RLD" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | LOG: 24 | TENSORBOARD_FREQ: 50 25 | SAVE_CHECKPOINT_FREQ: 200 26 | RLD: 27 | CE_WEIGHT: 1.0 28 | BETA: 2.0 29 | T: 1.0 30 | WARMUP: 1 31 | -------------------------------------------------------------------------------- /configs/imagenet/r50_mv1/swap.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT: 2 | NAME: "" 3 | TAG: 
"swap,res50,mv1" 4 | PROJECT: "imagenet_baselines" 5 | DATASET: 6 | TYPE: "imagenet" 7 | NUM_WORKERS: 32 8 | TEST: 9 | BATCH_SIZE: 128 10 | DISTILLER: 11 | TYPE: "SWAP" 12 | TEACHER: "ResNet50" 13 | STUDENT: "MobileNetV1" 14 | SOLVER: 15 | BATCH_SIZE: 512 16 | EPOCHS: 100 17 | LR: 0.2 18 | LR_DECAY_STAGES: [30, 60, 90] 19 | LR_DECAY_RATE: 0.1 20 | WEIGHT_DECAY: 0.0001 21 | MOMENTUM: 0.9 22 | TYPE: "SGD" 23 | KD: 24 | TEMPERATURE: 1.0 25 | LOG: 26 | TENSORBOARD_FREQ: 50 27 | SAVE_CHECKPOINT_FREQ: 200 28 | -------------------------------------------------------------------------------- /mdistiller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju-SWJ/RLD/d54578fd77cd669d6209be4d05ae773bca4109b5/mdistiller/__init__.py -------------------------------------------------------------------------------- /mdistiller/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample, get_cifar100_dataloaders_trainval, get_cifar100_dataloaders_val_only, get_cifar100_dataloaders_train_only, get_cifar100_dataloaders_strong 2 | from .imagenet import get_imagenet_dataloaders, get_imagenet_dataloaders_sample, get_imagenet_dataloaders_strong 3 | 4 | 5 | def get_dataset(cfg): 6 | if cfg.DATASET.TYPE == "cifar100": 7 | if cfg.DISTILLER.TYPE == "CRD": 8 | train_loader, val_loader, num_data = get_cifar100_dataloaders_sample( 9 | batch_size=cfg.SOLVER.BATCH_SIZE, 10 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 11 | num_workers=cfg.DATASET.NUM_WORKERS, 12 | k=cfg.CRD.NCE.K, 13 | mode=cfg.CRD.MODE, 14 | ) 15 | else: 16 | train_loader, val_loader, num_data = get_cifar100_dataloaders( 17 | batch_size=cfg.SOLVER.BATCH_SIZE, 18 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 19 | num_workers=cfg.DATASET.NUM_WORKERS, 20 | aug=cfg.EXPERIMENT.AUG, 21 | ) 22 | num_classes = 100 23 | elif cfg.DATASET.TYPE == "imagenet": 24 | if cfg.DISTILLER.TYPE == "CRD": 25 | train_loader, val_loader, num_data = get_imagenet_dataloaders_sample( 26 | batch_size=cfg.SOLVER.BATCH_SIZE, 27 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 28 | num_workers=cfg.DATASET.NUM_WORKERS, 29 | k=cfg.CRD.NCE.K, 30 | ) 31 | else: 32 | train_loader, val_loader, num_data = get_imagenet_dataloaders( 33 | batch_size=cfg.SOLVER.BATCH_SIZE, 34 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 35 | num_workers=cfg.DATASET.NUM_WORKERS, 36 | aug=cfg.EXPERIMENT.AUG, 37 | ) 38 | num_classes = 1000 39 | else: 40 | raise NotImplementedError(cfg.DATASET.TYPE) 41 | 42 | return train_loader, val_loader, num_data, num_classes 43 | 44 | 45 | def get_dataset_strong(cfg): 46 | if cfg.DATASET.TYPE == "cifar100": 47 | if cfg.DISTILLER.TYPE == "CRD": 48 | train_loader, val_loader, num_data = get_cifar100_dataloaders_sample( 49 | batch_size=cfg.SOLVER.BATCH_SIZE, 50 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 51 | num_workers=cfg.DATASET.NUM_WORKERS, 52 | k=cfg.CRD.NCE.K, 53 | mode=cfg.CRD.MODE, 54 | ) 55 | else: 56 | train_loader, val_loader, num_data = get_cifar100_dataloaders_strong( 57 | batch_size=cfg.SOLVER.BATCH_SIZE, 58 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 59 | num_workers=cfg.DATASET.NUM_WORKERS, 60 | ) 61 | num_classes = 100 62 | elif cfg.DATASET.TYPE == "imagenet": 63 | if cfg.DISTILLER.TYPE == "CRD": 64 | train_loader, val_loader, num_data = get_imagenet_dataloaders_sample( 65 | batch_size=cfg.SOLVER.BATCH_SIZE, 66 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 67 | 
num_workers=cfg.DATASET.NUM_WORKERS, 68 | k=cfg.CRD.NCE.K, 69 | ) 70 | else: 71 | train_loader, val_loader, num_data = get_imagenet_dataloaders_strong( 72 | batch_size=cfg.SOLVER.BATCH_SIZE, 73 | val_batch_size=cfg.DATASET.TEST.BATCH_SIZE, 74 | num_workers=cfg.DATASET.NUM_WORKERS, 75 | ) 76 | num_classes = 1000 77 | else: 78 | raise NotImplementedError(cfg.DATASET.TYPE) 79 | 80 | return train_loader, val_loader, num_data, num_classes 81 | 82 | 83 | -------------------------------------------------------------------------------- /mdistiller/distillers/AT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | 8 | def single_stage_at_loss(f_s, f_t, p): 9 | def _at(feat, p): 10 | return F.normalize(feat.pow(p).mean(1).reshape(feat.size(0), -1)) 11 | 12 | s_H, t_H = f_s.shape[2], f_t.shape[2] 13 | if s_H > t_H: 14 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 15 | elif s_H < t_H: 16 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 17 | return (_at(f_s, p) - _at(f_t, p)).pow(2).mean() 18 | 19 | 20 | def at_loss(g_s, g_t, p): 21 | return sum([single_stage_at_loss(f_s, f_t, p) for f_s, f_t in zip(g_s, g_t)]) 22 | 23 | 24 | class AT(Distiller): 25 | """ 26 | Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer 27 | src code: https://github.com/szagoruyko/attention-transfer 28 | """ 29 | 30 | def __init__(self, student, teacher, cfg): 31 | super(AT, self).__init__(student, teacher) 32 | self.p = cfg.AT.P 33 | self.ce_loss_weight = cfg.AT.LOSS.CE_WEIGHT 34 | self.feat_loss_weight = cfg.AT.LOSS.FEAT_WEIGHT 35 | 36 | def forward_train(self, image, target, **kwargs): 37 | logits_student, feature_student = self.student(image) 38 | with torch.no_grad(): 39 | _, feature_teacher = self.teacher(image) 40 | 41 | # losses 42 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 43 | loss_feat = self.feat_loss_weight * at_loss( 44 | feature_student["feats"][1:], feature_teacher["feats"][1:], self.p 45 | ) 46 | losses_dict = { 47 | "loss_ce": loss_ce, 48 | "loss_kd": loss_feat, 49 | } 50 | return logits_student, losses_dict 51 | -------------------------------------------------------------------------------- /mdistiller/distillers/DKD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | def normalize(logit): 8 | mean = logit.mean(dim=-1, keepdims=True) 9 | stdv = logit.std(dim=-1, keepdims=True) 10 | return (logit - mean) / (1e-7 + stdv) 11 | 12 | def dkd_loss(logits_student_in, logits_teacher_in, target, alpha, beta, temperature, logit_stand): 13 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in 14 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in 15 | 16 | gt_mask = _get_gt_mask(logits_student, target) 17 | other_mask = _get_other_mask(logits_student, target) 18 | pred_student = F.softmax(logits_student / temperature, dim=1) 19 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 20 | pred_student = cat_mask(pred_student, gt_mask, other_mask) 21 | pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask) 22 | log_pred_student = torch.log(pred_student) 23 | tckd_loss = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean') * 
(temperature**2) 24 | 25 | pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1) 26 | log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1) 27 | nckd_loss = F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean') * (temperature**2) 28 | 29 | return alpha * tckd_loss + beta * nckd_loss 30 | 31 | 32 | def _get_gt_mask(logits, target): 33 | target = target.reshape(-1) 34 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 35 | return mask 36 | 37 | 38 | def _get_other_mask(logits, target): 39 | target = target.reshape(-1) 40 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 41 | return mask 42 | 43 | 44 | def cat_mask(t, mask1, mask2): 45 | t1 = (t * mask1).sum(dim=1, keepdims=True) 46 | t2 = (t * mask2).sum(1, keepdims=True) 47 | rt = torch.cat([t1, t2], dim=1) 48 | return rt 49 | 50 | 51 | class DKD(Distiller): 52 | """Decoupled Knowledge Distillation(CVPR 2022)""" 53 | 54 | def __init__(self, student, teacher, cfg): 55 | super(DKD, self).__init__(student, teacher) 56 | self.ce_loss_weight = cfg.DKD.CE_WEIGHT 57 | self.alpha = cfg.DKD.ALPHA 58 | self.beta = cfg.DKD.BETA 59 | self.temperature = cfg.DKD.T 60 | self.warmup = cfg.DKD.WARMUP 61 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND 62 | 63 | def forward_train(self, image, target, **kwargs): 64 | logits_student, _ = self.student(image) 65 | with torch.no_grad(): 66 | logits_teacher, _ = self.teacher(image) 67 | 68 | # losses 69 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 70 | loss_dkd = min(kwargs["epoch"] / self.warmup, 1.0) * dkd_loss( 71 | logits_student, 72 | logits_teacher, 73 | target, 74 | self.alpha, 75 | self.beta, 76 | self.temperature, 77 | self.logit_stand, 78 | ) 79 | losses_dict = { 80 | "loss_ce": loss_ce, 81 | "loss_kd": loss_dkd, 82 | } 83 | return logits_student, losses_dict 84 | -------------------------------------------------------------------------------- /mdistiller/distillers/FitNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | from ._common import ConvReg, get_feat_shapes 7 | 8 | 9 | class FitNet(Distiller): 10 | """FitNets: Hints for Thin Deep Nets""" 11 | 12 | def __init__(self, student, teacher, cfg): 13 | super(FitNet, self).__init__(student, teacher) 14 | self.ce_loss_weight = cfg.FITNET.LOSS.CE_WEIGHT 15 | self.feat_loss_weight = cfg.FITNET.LOSS.FEAT_WEIGHT 16 | self.hint_layer = cfg.FITNET.HINT_LAYER 17 | feat_s_shapes, feat_t_shapes = get_feat_shapes( 18 | self.student, self.teacher, cfg.FITNET.INPUT_SIZE 19 | ) 20 | self.conv_reg = ConvReg( 21 | feat_s_shapes[self.hint_layer], feat_t_shapes[self.hint_layer] 22 | ) 23 | 24 | def get_learnable_parameters(self): 25 | return super().get_learnable_parameters() + list(self.conv_reg.parameters()) 26 | 27 | def get_extra_parameters(self): 28 | num_p = 0 29 | for p in self.conv_reg.parameters(): 30 | num_p += p.numel() 31 | return num_p 32 | 33 | def forward_train(self, image, target, **kwargs): 34 | logits_student, feature_student = self.student(image) 35 | with torch.no_grad(): 36 | _, feature_teacher = self.teacher(image) 37 | 38 | # losses 39 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 40 | f_s = self.conv_reg(feature_student["feats"][self.hint_layer]) 41 | loss_feat = self.feat_loss_weight * 
F.mse_loss( 42 | f_s, feature_teacher["feats"][self.hint_layer] 43 | ) 44 | losses_dict = { 45 | "loss_ce": loss_ce, 46 | "loss_kd": loss_feat, 47 | } 48 | return logits_student, losses_dict 49 | -------------------------------------------------------------------------------- /mdistiller/distillers/KD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | def normalize(logit): 8 | mean = logit.mean(dim=-1, keepdims=True) 9 | stdv = logit.std(dim=-1, keepdims=True) 10 | return (logit - mean) / (1e-7 + stdv) 11 | 12 | def kd_loss(logits_student_in, logits_teacher_in, temperature, logit_stand): 13 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in 14 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in 15 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1) 16 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 17 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean() 18 | loss_kd *= temperature**2 19 | return loss_kd 20 | 21 | 22 | class KD(Distiller): 23 | """Distilling the Knowledge in a Neural Network""" 24 | 25 | def __init__(self, student, teacher, cfg): 26 | super(KD, self).__init__(student, teacher) 27 | self.temperature = cfg.KD.TEMPERATURE 28 | self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT 29 | self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT 30 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND 31 | 32 | def forward_train(self, image, target, **kwargs): 33 | logits_student, _ = self.student(image) 34 | with torch.no_grad(): 35 | logits_teacher, _ = self.teacher(image) 36 | 37 | # losses 38 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 39 | loss_kd = self.kd_loss_weight * kd_loss( 40 | logits_student, logits_teacher, self.temperature, self.logit_stand 41 | ) 42 | losses_dict = { 43 | "loss_ce": loss_ce, 44 | "loss_kd": loss_kd, 45 | } 46 | return logits_student, losses_dict 47 | -------------------------------------------------------------------------------- /mdistiller/distillers/KDSVD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | 8 | def kdsvd_loss(g_s, g_t, k): 9 | v_sb = None 10 | v_tb = None 11 | losses = [] 12 | for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t): 13 | u_t, s_t, v_t = svd(f_t, k) 14 | u_s, s_s, v_s = svd(f_s, k + 3) 15 | v_s, v_t = align_rsv(v_s, v_t) 16 | s_t = s_t.unsqueeze(1) 17 | v_t = v_t * s_t 18 | v_s = v_s * s_t 19 | 20 | if i > 0: 21 | s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8) 22 | t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8) 23 | 24 | l2loss = (s_rbf - t_rbf.detach()).pow(2) 25 | l2loss = torch.where( 26 | torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss) 27 | ) 28 | losses.append(l2loss.sum()) 29 | 30 | v_tb = v_t 31 | v_sb = v_s 32 | 33 | bsz = g_s[0].shape[0] 34 | losses = [l / bsz for l in losses] 35 | return sum(losses) 36 | 37 | 38 | def svd(feat, n=1): 39 | size = feat.shape 40 | assert len(size) == 4 41 | 42 | x = feat.view(size[0], size[1] * size[2], size[3]).float() 43 | u, s, v = torch.svd(x) 44 | 45 | u = removenan(u) 46 | s = removenan(s) 47 | v = removenan(v) 48 | 49 | if n > 0: 50 | u = F.normalize(u[:, :, :n], dim=1) 51 | 
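# the singular values s and right vectors v get the same top-n truncation and L2 normalization as u above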
s = F.normalize(s[:, :n], dim=1) 52 | v = F.normalize(v[:, :, :n], dim=1) 53 | 54 | return u, s, v 55 | 56 | 57 | def removenan(x): 58 | x = torch.where(torch.isfinite(x), x, torch.zeros_like(x)) 59 | return x 60 | 61 | 62 | def align_rsv(a, b): 63 | cosine = torch.matmul(a.transpose(-2, -1), b) 64 | max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True) 65 | mask = torch.where( 66 | torch.eq(max_abs_cosine, torch.abs(cosine)), 67 | torch.sign(cosine), 68 | torch.zeros_like(cosine), 69 | ) 70 | a = torch.matmul(a, mask) 71 | return a, b 72 | 73 | 74 | class KDSVD(Distiller): 75 | """ 76 | Self-supervised Knowledge Distillation using Singular Value Decomposition 77 | original Tensorflow code: https://github.com/sseung0703/SSKD_SVD 78 | """ 79 | 80 | def __init__(self, student, teacher, cfg): 81 | super(KDSVD, self).__init__(student, teacher) 82 | self.k = cfg.KDSVD.K 83 | self.ce_loss_weight = cfg.KDSVD.LOSS.CE_WEIGHT 84 | self.feat_loss_weight = cfg.KDSVD.LOSS.FEAT_WEIGHT 85 | 86 | def forward_train(self, image, target, **kwargs): 87 | logits_student, feature_student = self.student(image) 88 | with torch.no_grad(): 89 | _, feature_teacher = self.teacher(image) 90 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 91 | loss_feat = self.feat_loss_weight * kdsvd_loss( 92 | feature_student["feats"][1:], feature_teacher["feats"][1:], self.k 93 | ) 94 | losses_dict = { 95 | "loss_ce": loss_ce, 96 | "loss_kd": loss_feat, 97 | } 98 | return logits_student, losses_dict 99 | -------------------------------------------------------------------------------- /mdistiller/distillers/MLKD_NOAUG.py: -------------------------------------------------------------------------------- 1 | # Multi-level KD (MLKD) variant that distills only the standard weak view, 2 | # at several temperatures, with confidence-based sample/class masks. 3 | import torch 4 | import torch.fft 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from ._base import Distiller 10 | from .loss import CrossEntropyLabelSmooth 11 | 12 | def normalize(logit): 13 | mean = logit.mean(dim=-1, keepdims=True) 14 | stdv = logit.std(dim=-1, keepdims=True) 15 | return (logit - mean) / (1e-7 + stdv) 16 | 17 | def kd_loss(logits_student_in, logits_teacher_in, temperature, reduce=True, logit_stand=False): 18 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in 19 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in 20 | 21 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1) 22 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 23 | if reduce: 24 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean() 25 | else: 26 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1) 27 | loss_kd *= temperature**2 28 | return loss_kd 29 | 30 | 31 | def cc_loss(logits_student, logits_teacher, temperature, reduce=True): 32 | batch_size, class_num = logits_teacher.shape 33 | pred_student = F.softmax(logits_student / temperature, dim=1) 34 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 35 | student_matrix = torch.mm(pred_student.transpose(1, 0), pred_student) 36 | teacher_matrix = torch.mm(pred_teacher.transpose(1, 0), pred_teacher) 37 | if reduce: 38 | consistency_loss = ((teacher_matrix - student_matrix) ** 2).sum() / class_num 39 | else: 40 | consistency_loss = ((teacher_matrix - student_matrix) ** 2) / class_num 41 | return consistency_loss 42 | 43 | 44 | def bc_loss(logits_student, logits_teacher, temperature, reduce=True): 45
| batch_size, class_num = logits_teacher.shape 46 | pred_student = F.softmax(logits_student / temperature, dim=1) 47 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 48 | student_matrix = torch.mm(pred_student, pred_student.transpose(1, 0)) 49 | teacher_matrix = torch.mm(pred_teacher, pred_teacher.transpose(1, 0)) 50 | if reduce: 51 | consistency_loss = ((teacher_matrix - student_matrix) ** 2).sum() / batch_size 52 | else: 53 | consistency_loss = ((teacher_matrix - student_matrix) ** 2) / batch_size 54 | return consistency_loss 55 | 56 | 57 | def mixup_data(x, y, alpha=1.0, use_cuda=True): 58 | '''Returns mixed inputs, pairs of targets, and lambda''' 59 | if alpha > 0: 60 | lam = np.random.beta(alpha, alpha) 61 | else: 62 | lam = 1 63 | 64 | batch_size = x.size()[0] 65 | if use_cuda: 66 | index = torch.randperm(batch_size).cuda() 67 | else: 68 | index = torch.randperm(batch_size) 69 | 70 | mixed_x = lam * x + (1 - lam) * x[index, :] 71 | y_a, y_b = y, y[index] 72 | return mixed_x, y_a, y_b, lam 73 | 74 | 75 | def mixup_data_conf(x, y, lam, use_cuda=True): 76 | '''Returns mixed inputs, pairs of targets, and lambda''' 77 | lam = lam.reshape(-1,1,1,1) 78 | batch_size = x.size()[0] 79 | if use_cuda: 80 | index = torch.randperm(batch_size).cuda() 81 | else: 82 | index = torch.randperm(batch_size) 83 | 84 | mixed_x = lam * x + (1 - lam) * x[index, :] 85 | y_a, y_b = y, y[index] 86 | return mixed_x, y_a, y_b, lam 87 | 88 | 89 | class MLKD_NOAUG(Distiller): 90 | def __init__(self, student, teacher, cfg): 91 | super(MLKD_NOAUG, self).__init__(student, teacher) 92 | self.temperature = cfg.KD.TEMPERATURE 93 | self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT 94 | self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT 95 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND 96 | 97 | def forward_train(self, image, target, **kwargs): 98 | logits_student_weak, _ = self.student(image) 99 | with torch.no_grad(): 100 | logits_teacher_weak, _ = self.teacher(image) 101 | 102 | batch_size, class_num = logits_student_weak.shape 103 | 104 | pred_teacher_weak = F.softmax(logits_teacher_weak.detach(), dim=1) 105 | confidence, pseudo_labels = pred_teacher_weak.max(dim=1) 106 | confidence = confidence.detach() 107 | conf_thresh = np.percentile( 108 | confidence.cpu().numpy().flatten(), 50 109 | ) 110 | mask = confidence.le(conf_thresh).bool() 111 | 112 | class_confidence = torch.sum(pred_teacher_weak, dim=0) 113 | class_confidence = class_confidence.detach() 114 | class_confidence_thresh = np.percentile( 115 | class_confidence.cpu().numpy().flatten(), 50 116 | ) 117 | class_conf_mask = class_confidence.le(class_confidence_thresh).bool() 118 | 119 | # losses 120 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student_weak, target) 121 | loss_kd_weak = self.kd_loss_weight * ((kd_loss( 122 | logits_student_weak, 123 | logits_teacher_weak, 124 | self.temperature, 125 | # reduce=False 126 | logit_stand=self.logit_stand, 127 | ) * mask).mean()) + self.kd_loss_weight * ((kd_loss( 128 | logits_student_weak, 129 | logits_teacher_weak, 130 | 3.0, 131 | # reduce=False 132 | logit_stand=self.logit_stand, 133 | ) * mask).mean()) + self.kd_loss_weight * ((kd_loss( 134 | logits_student_weak, 135 | logits_teacher_weak, 136 | 5.0, 137 | # reduce=False 138 | logit_stand=self.logit_stand, 139 | ) * mask).mean()) + self.kd_loss_weight * ((kd_loss( 140 | logits_student_weak, 141 | logits_teacher_weak, 142 | 2.0, 143 | # reduce=False 144 | logit_stand=self.logit_stand, 145 | ) * mask).mean()) + self.kd_loss_weight * ((kd_loss( 146 | 
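# Multi-temperature KD ensemble: the same weak-view KD term is summed over
# temperatures {self.temperature, 3.0, 5.0, 2.0, 6.0}, each gated by `mask`
# (samples whose teacher confidence falls below the batch median, built above).
# Note that kd_loss defaults to reduce=True and returns a scalar, so
# `(loss * mask).mean()` scales that scalar by the kept fraction of the batch
# rather than re-weighting individual samples (the commented-out reduce=False
# shows the per-sample alternative).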
logits_student_weak, 147 | logits_teacher_weak, 148 | 6.0, 149 | # reduce=False 150 | logit_stand=self.logit_stand, 151 | ) * mask).mean()) 152 | 153 | loss_cc_weak = self.kd_loss_weight * ((cc_loss( 154 | logits_student_weak, 155 | logits_teacher_weak, 156 | self.temperature, 157 | # reduce=False 158 | ) * class_conf_mask).mean()) + self.kd_loss_weight * ((cc_loss( 159 | logits_student_weak, 160 | logits_teacher_weak, 161 | 3.0, 162 | ) * class_conf_mask).mean()) + self.kd_loss_weight * ((cc_loss( 163 | logits_student_weak, 164 | logits_teacher_weak, 165 | 5.0, 166 | ) * class_conf_mask).mean()) + self.kd_loss_weight * ((cc_loss( 167 | logits_student_weak, 168 | logits_teacher_weak, 169 | 2.0, 170 | ) * class_conf_mask).mean()) + self.kd_loss_weight * ((cc_loss( 171 | logits_student_weak, 172 | logits_teacher_weak, 173 | 6.0, 174 | ) * class_conf_mask).mean()) 175 | 176 | loss_bc_weak = self.kd_loss_weight * ((bc_loss( 177 | logits_student_weak, 178 | logits_teacher_weak, 179 | self.temperature, 180 | ) * mask).mean()) + self.kd_loss_weight * ((bc_loss( 181 | logits_student_weak, 182 | logits_teacher_weak, 183 | 3.0, 184 | ) * mask).mean()) + self.kd_loss_weight * ((bc_loss( 185 | logits_student_weak, 186 | logits_teacher_weak, 187 | 5.0, 188 | ) * mask).mean()) + self.kd_loss_weight * ((bc_loss( 189 | logits_student_weak, 190 | logits_teacher_weak, 191 | 2.0, 192 | ) * mask).mean()) + self.kd_loss_weight * ((bc_loss( 193 | logits_student_weak, 194 | logits_teacher_weak, 195 | 6.0, 196 | ) * mask).mean()) 197 | 198 | losses_dict = { 199 | "loss_ce": loss_ce, 200 | "loss_kd": loss_kd_weak, 201 | "loss_cc": loss_cc_weak, 202 | "loss_bc": loss_bc_weak 203 | } 204 | return logits_student_weak, losses_dict 205 | 206 | -------------------------------------------------------------------------------- /mdistiller/distillers/NST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | 8 | def nst_loss(g_s, g_t): 9 | return sum([single_stage_nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) 10 | 11 | 12 | def single_stage_nst_loss(f_s, f_t): 13 | s_H, t_H = f_s.shape[2], f_t.shape[2] 14 | if s_H > t_H: 15 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 16 | elif s_H < t_H: 17 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 18 | 19 | f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1) 20 | f_s = F.normalize(f_s, dim=2) 21 | f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1) 22 | f_t = F.normalize(f_t, dim=2) 23 | 24 | return ( 25 | poly_kernel(f_t, f_t).mean().detach() 26 | + poly_kernel(f_s, f_s).mean() 27 | - 2 * poly_kernel(f_s, f_t).mean() 28 | ) 29 | 30 | 31 | def poly_kernel(a, b): 32 | a = a.unsqueeze(1) 33 | b = b.unsqueeze(2) 34 | res = (a * b).sum(-1).pow(2) 35 | return res 36 | 37 | 38 | class NST(Distiller): 39 | """ 40 | Like What You Like: Knowledge Distill via Neuron Selectivity Transfer 41 | """ 42 | 43 | def __init__(self, student, teacher, cfg): 44 | super(NST, self).__init__(student, teacher) 45 | self.ce_loss_weight = cfg.NST.LOSS.CE_WEIGHT 46 | self.feat_loss_weight = cfg.NST.LOSS.FEAT_WEIGHT 47 | 48 | def forward_train(self, image, target, **kwargs): 49 | logits_student, feature_student = self.student(image) 50 | with torch.no_grad(): 51 | _, feature_teacher = self.teacher(image) 52 | 53 | # losses 54 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 55 | loss_feat = self.feat_loss_weight * nst_loss( 56 | 
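# NST transfers "neuron selectivity" patterns: each stage's feature map is
# flattened to (N, C, HW), L2-normalised along the spatial axis, and student
# and teacher are matched with a squared MMD under the degree-2 polynomial
# kernel defined in poly_kernel above; the first (stem) feature is skipped.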
feature_student["feats"][1:], feature_teacher["feats"][1:]
57 | )
58 | losses_dict = {
59 | "loss_ce": loss_ce,
60 | "loss_kd": loss_feat,
61 | }
62 | return logits_student, losses_dict
63 |
--------------------------------------------------------------------------------
/mdistiller/distillers/OFD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from scipy.stats import norm
5 | import numpy as np
6 | import math
7 |
8 | from ._base import Distiller
9 |
10 |
11 | def feat_loss(source, target, margin):
12 | margin = margin.to(source)
13 | loss = (
14 | (source - margin) ** 2 * ((source > margin) & (target <= margin)).float()
15 | + (source - target) ** 2
16 | * ((source > target) & (target > margin) & (target <= 0)).float()
17 | + (source - target) ** 2 * (target > 0).float()
18 | )
19 | return torch.abs(loss).mean(dim=0).sum()
20 |
21 |
22 | class ConnectorConvBN(nn.Module):
23 | def __init__(self, s_channels, t_channels, kernel_size=1):
24 | super(ConnectorConvBN, self).__init__()
25 | self.s_channels = s_channels
26 | self.t_channels = t_channels
27 | self.connectors = nn.ModuleList(
28 | self._make_connectors(s_channels, t_channels, kernel_size)
29 | )
30 |
31 | def _make_connectors(self, s_channels, t_channels, kernel_size):
32 | assert len(s_channels) == len(t_channels), "unequal length of feat list"
33 | connectors = nn.ModuleList(
34 | [
35 | self._build_feature_connector(t, s, kernel_size)
36 | for t, s in zip(t_channels, s_channels)
37 | ]
38 | )
39 | return connectors
40 |
41 | def _build_feature_connector(self, t_channel, s_channel, kernel_size):
42 | C = [
43 | nn.Conv2d(
44 | s_channel,
45 | t_channel,
46 | kernel_size=kernel_size,
47 | stride=1,
48 | padding=(kernel_size - 1) // 2,
49 | bias=False,
50 | ),
51 | nn.BatchNorm2d(t_channel),
52 | ]
53 | for m in C:
54 | if isinstance(m, nn.Conv2d):
55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
56 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
57 | elif isinstance(m, nn.BatchNorm2d):
58 | m.weight.data.fill_(1)
59 | m.bias.data.zero_()
60 | return nn.Sequential(*C)
61 |
62 | def forward(self, g_s):
63 | out = []
64 | for i in range(len(g_s)):
65 | out.append(self.connectors[i](g_s[i]))
66 |
67 | return out
68 |
69 |
70 | class OFD(Distiller):
71 | def __init__(self, student, teacher, cfg):
72 | super(OFD, self).__init__(student, teacher)
73 | self.ce_loss_weight = cfg.OFD.LOSS.CE_WEIGHT
74 | self.feat_loss_weight = cfg.OFD.LOSS.FEAT_WEIGHT
75 | self.init_ofd_modules(
76 | tea_channels=self.teacher.get_stage_channels()[1:],
77 | stu_channels=self.student.get_stage_channels()[1:],
78 | bn_before_relu=self.teacher.get_bn_before_relu(),
79 | kernel_size=cfg.OFD.CONNECTOR.KERNEL_SIZE,
80 | )
81 |
82 | def init_ofd_modules(
83 | self, tea_channels, stu_channels, bn_before_relu, kernel_size=1
84 | ):
85 | tea_channels, stu_channels = self._align_list(tea_channels, stu_channels)
86 | self.connectors = ConnectorConvBN(
87 | stu_channels, tea_channels, kernel_size=kernel_size
88 | )
89 |
90 | self.margins = []
91 | for idx, bn in enumerate(bn_before_relu):
92 | margin = []
93 | std = bn.weight.data
94 | mean = bn.bias.data
95 | for (s, m) in zip(std, mean):
96 | s = abs(s.item())
97 | m = m.item()
98 | if norm.cdf(-m / s) > 0.001:
99 | margin.append(
100 | -s
101 | * math.exp(-((m / s) ** 2) / 2)
102 | / math.sqrt(2 * math.pi)
103 | / norm.cdf(-m / s)
104 | + m
105 | )
106 | else:
107 | margin.append(-3 * s)
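# The closed form above is E[x | x <= 0] for a Gaussian N(m, s^2) whose mean
# and std are read off the teacher's pre-ReLU BatchNorm (bias and weight):
# margin = m - s * pdf(m/s) / cdf(-m/s). When cdf(-m/s) <= 0.001 that ratio
# is numerically unstable, so the else-branch falls back to -3 * s.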
108 | margin = torch.FloatTensor(margin).to(std.device)
109 | self.margins.append(margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
110 |
111 | def get_learnable_parameters(self):
112 | return super().get_learnable_parameters() + list(self.connectors.parameters())
113 |
114 | def train(self, mode=True):
115 | # teacher as eval mode by default
116 | if not isinstance(mode, bool):
117 | raise ValueError("training mode is expected to be boolean")
118 | self.training = mode
119 | for module in self.children():
120 | module.train(mode)
121 | self.teacher.eval()
122 | return self
123 | def get_extra_parameters(self):
124 | num_p = 0
125 | for p in self.connectors.parameters():
126 | num_p += p.numel()
127 | return num_p
128 |
129 | def forward_train(self, image, target, **kwargs):
130 | logits_student, feature_student = self.student(image)
131 | with torch.no_grad():
132 | _, feature_teacher = self.teacher(image)
133 |
134 | # losses
135 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
136 | loss_feat = self.feat_loss_weight * self.ofd_loss(
137 | feature_student["preact_feats"][1:], feature_teacher["preact_feats"][1:]
138 | )
139 | losses_dict = {"loss_ce": loss_ce, "loss_kd": loss_feat}
140 | return logits_student, losses_dict
141 |
142 | def ofd_loss(self, feature_student, feature_teacher):
143 | feature_student, feature_teacher = self._align_list(
144 | feature_student, feature_teacher
145 | )
146 | feature_student = [
147 | self.connectors.connectors[idx](feat)
148 | for idx, feat in enumerate(feature_student)
149 | ]
150 |
151 | loss_distill = 0
152 | feat_num = len(feature_student)
153 | for i in range(feat_num):
154 | loss_distill = loss_distill + feat_loss(
155 | feature_student[i],
156 | F.adaptive_avg_pool2d(
157 | feature_teacher[i], feature_student[i].shape[-2:]
158 | ).detach(),
159 | self.margins[i],
160 | ) / 2 ** (feat_num - i - 1)
161 | return loss_distill
162 |
163 | def _align_list(self, *input_list):
164 | min_len = min([len(l) for l in input_list])
165 | return [l[-min_len:] for l in input_list]
166 |
--------------------------------------------------------------------------------
/mdistiller/distillers/PKT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 |
8 | def pkt_loss(f_s, f_t, eps=1e-7):
9 | # Normalize each vector by its norm
10 | output_net_norm = torch.sqrt(torch.sum(f_s**2, dim=1, keepdim=True))
11 | f_s = f_s / (output_net_norm + eps)
12 | f_s[f_s != f_s] = 0
13 | target_net_norm = torch.sqrt(torch.sum(f_t**2, dim=1, keepdim=True))
14 | f_t = f_t / (target_net_norm + eps)
15 | f_t[f_t != f_t] = 0
16 |
17 | # Calculate the cosine similarity
18 | model_similarity = torch.mm(f_s, f_s.transpose(0, 1))
19 | target_similarity = torch.mm(f_t, f_t.transpose(0, 1))
20 | # Scale cosine similarity to 0..1
21 | model_similarity = (model_similarity + 1.0) / 2.0
22 | target_similarity = (target_similarity + 1.0) / 2.0
23 | # Transform them into probabilities
24 | model_similarity = model_similarity / torch.sum(
25 | model_similarity, dim=1, keepdim=True
26 | )
27 | target_similarity = target_similarity / torch.sum(
28 | target_similarity, dim=1, keepdim=True
29 | )
30 | # Calculate the KL-divergence
31 | loss = torch.mean(
32 | target_similarity
33 | * torch.log((target_similarity + eps) / (model_similarity + eps))
34 | )
35 | return loss
36 |
37 |
38 | class PKT(Distiller):
39 | """
40 | Probabilistic
Knowledge Transfer for deep representation learning
41 | Code from: https://github.com/passalis/probabilistic_kt
42 | """
43 |
44 | def __init__(self, student, teacher, cfg):
45 | super(PKT, self).__init__(student, teacher)
46 | self.ce_loss_weight = cfg.PKT.LOSS.CE_WEIGHT
47 | self.feat_loss_weight = cfg.PKT.LOSS.FEAT_WEIGHT
48 |
49 | def forward_train(self, image, target, **kwargs):
50 | logits_student, feature_student = self.student(image)
51 | with torch.no_grad():
52 | _, feature_teacher = self.teacher(image)
53 |
54 | # losses
55 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
56 | loss_feat = self.feat_loss_weight * pkt_loss(
57 | feature_student["pooled_feat"], feature_teacher["pooled_feat"]
58 | )
59 | losses_dict = {
60 | "loss_ce": loss_ce,
61 | "loss_kd": loss_feat,
62 | }
63 | return logits_student, losses_dict
64 |
--------------------------------------------------------------------------------
/mdistiller/distillers/RC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 | def normalize(logit):
8 | mean = logit.mean(dim=-1, keepdims=True)
9 | stdv = logit.std(dim=-1, keepdims=True)
10 | return (logit - mean) / (1e-7 + stdv)
11 |
12 | def kd_loss(logits_student_in, logits_teacher_in, temperature, logit_stand):
13 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in
14 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in
15 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
16 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
17 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
18 | loss_kd *= temperature**2
19 | return loss_kd
20 |
21 | def rc_loss(logits_student_in, target, temperature, logit_stand):
22 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in
23 |
24 | max_value, _ = torch.max(logits_student, dim=1, keepdim=True)
25 | gt_value = torch.gather(logits_student, 1, target.unsqueeze(1))
26 | src = (gt_value + max_value).detach()
27 | logits_teacher = logits_student.detach().clone().scatter_(1, target.unsqueeze(1), src)
28 |
29 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
30 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
31 | loss_rc = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
32 | loss_rc *= temperature**2
33 | return loss_rc
34 |
35 |
36 | class RC(Distiller):
37 | """Vanilla KD plus an RC term that distills the student toward a corrected copy of its own logits, with the ground-truth entry raised above the current maximum"""
38 |
39 | def __init__(self, student, teacher, cfg):
40 | super(RC, self).__init__(student, teacher)
41 | self.temperature = cfg.RC.T
42 | self.ce_loss_weight = cfg.RC.CE_WEIGHT
43 | self.kd_loss_weight = cfg.RC.KD_WEIGHT
44 | self.rc_loss_weight = cfg.RC.RC_WEIGHT
45 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND
46 |
47 | def forward_train(self, image, target, **kwargs):
48 | logits_student, _ = self.student(image)
49 | with torch.no_grad():
50 | logits_teacher, _ = self.teacher(image)
51 |
52 | # losses
53 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
54 | loss_kd = self.kd_loss_weight * kd_loss(logits_student, logits_teacher, self.temperature, self.logit_stand)
55 | loss_rc = self.rc_loss_weight * rc_loss(logits_student, target, self.temperature, self.logit_stand)
56 | losses_dict = {
57 | "loss_ce": loss_ce,
58 | "loss_kd": loss_kd,
59 | "loss_rc": loss_rc,
60 | }
61 | return logits_student, losses_dict
62 |
--------------------------------------------------------------------------------
/mdistiller/distillers/REVISION.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 | def revision_loss(logits_student_in, logits_teacher_in, target, lambda1, lambda2, eta, num_classes):
8 | _, max_index = torch.max(logits_teacher_in, dim=1)
9 | right_mask = (target == max_index)
10 | wrong_mask = (target != max_index)
11 | right_logits_teacher = logits_teacher_in[right_mask]
12 | wrong_logits_teacher = logits_teacher_in[wrong_mask]
13 | right_logits_student = logits_student_in[right_mask]
14 | wrong_logits_student = logits_student_in[wrong_mask]
15 | right_target = target[right_mask]
16 | wrong_target = target[wrong_mask]
17 | if len(right_target) > 0:
18 | ce_loss = F.cross_entropy(right_logits_student, right_target)
19 | mse_loss1 = F.mse_loss(right_logits_student, right_logits_teacher)
20 | else:
21 | ce_loss = torch.zeros([1], device=target.device)
22 | mse_loss1 = torch.zeros([1], device=target.device)
23 |
24 | if len(wrong_target) > 0:
25 | wrong_pred_teacher = F.softmax(wrong_logits_teacher, dim=1)
26 | wrong_pred_student = F.softmax(wrong_logits_student, dim=1)
27 | one_hot_wrong_target = F.one_hot(wrong_target, num_classes)
28 | max_pred, _ = torch.max(wrong_pred_teacher, dim=1, keepdim=True)
29 | gt_pred = torch.gather(wrong_pred_teacher, dim=1, index=wrong_target.unsqueeze(1))
30 | beta = eta / (max_pred - gt_pred + 1)
31 | wrong_pred = beta * wrong_pred_teacher + (1 - beta) * one_hot_wrong_target
32 | mse_loss2 = F.mse_loss(wrong_pred_student, wrong_pred)
33 | else:
34 | mse_loss2 = torch.zeros([1], device=target.device)
35 |
36 | return ce_loss + lambda1 * mse_loss1 + lambda2 * mse_loss2
37 |
38 |
39 | class REVISION(Distiller):
40 | """Distillation from revised teacher logits: exact logit matching on samples the teacher classifies correctly, and label-corrected soft targets on the rest"""
41 |
42 | def __init__(self, student, teacher, cfg):
43 | super(REVISION, self).__init__(student, teacher)
44 | self.lambda1 = 4.
45 | self.lambda2 = 1.
# lambda1 weights the logit-matching MSE on teacher-correct samples and
# lambda2 the soft-target MSE on teacher-wrong samples; eta (next line)
# bounds the share of teacher probability blended into the one-hot target.
# All three are hardcoded here rather than read from cfg.
46 | self.eta = 0.8
47 | self.num_classes = 1000 if cfg.DATASET.TYPE == "imagenet" else 100
48 |
49 | def forward_train(self, image, target, **kwargs):
50 | logits_student, _ = self.student(image)
51 | with torch.no_grad():
52 | logits_teacher, _ = self.teacher(image)
53 |
54 | # losses
55 | overall_loss = revision_loss(logits_student, logits_teacher, target, self.lambda1, self.lambda2, self.eta, self.num_classes)
56 | losses_dict = {
57 | "loss": overall_loss,
58 | }
59 | return logits_student, losses_dict
60 |
--------------------------------------------------------------------------------
/mdistiller/distillers/RKD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 |
8 | def _pdist(e, squared, eps):
9 | e_square = e.pow(2).sum(dim=1)
10 | prod = e @ e.t()
11 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
12 |
13 | if not squared:
14 | res = res.sqrt()
15 |
16 | res = res.clone()
17 | res[range(len(e)), range(len(e))] = 0
18 | return res
19 |
20 |
21 | def rkd_loss(f_s, f_t, squared=False, eps=1e-12, distance_weight=25, angle_weight=50):
22 | stu = f_s.view(f_s.shape[0], -1)
23 | tea = f_t.view(f_t.shape[0], -1)
24 |
25 | # RKD distance loss
26 | with torch.no_grad():
27 | t_d = _pdist(tea, squared, eps)
28 | mean_td = t_d[t_d > 0].mean()
29 | t_d = t_d / mean_td
30 |
31 | d = _pdist(stu, squared, eps)
32 | mean_d = d[d > 0].mean()
33 | d = d / mean_d
34 |
35 | loss_d = F.smooth_l1_loss(d, t_d)
36 |
37 | # RKD Angle loss
38 | with torch.no_grad():
39 | td = tea.unsqueeze(0) - tea.unsqueeze(1)
40 | norm_td = F.normalize(td, p=2, dim=2)
41 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
42 |
43 | sd = stu.unsqueeze(0) - stu.unsqueeze(1)
44 | norm_sd = F.normalize(sd, p=2, dim=2)
45 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
46 |
47 | loss_a = F.smooth_l1_loss(s_angle, t_angle)
48 |
49 | loss = distance_weight * loss_d + angle_weight * loss_a
50 | return loss
51 |
52 |
53 | class RKD(Distiller):
54 | """Relational Knowledge Distillation, CVPR2019"""
55 |
56 | def __init__(self, student, teacher, cfg):
57 | super(RKD, self).__init__(student, teacher)
58 | self.distance_weight = cfg.RKD.DISTANCE_WEIGHT
59 | self.angle_weight = cfg.RKD.ANGLE_WEIGHT
60 | self.ce_loss_weight = cfg.RKD.LOSS.CE_WEIGHT
61 | self.feat_loss_weight = cfg.RKD.LOSS.FEAT_WEIGHT
62 | self.eps = cfg.RKD.PDIST.EPSILON
63 | self.squared = cfg.RKD.PDIST.SQUARED
64 |
65 | def forward_train(self, image, target, **kwargs):
66 | logits_student, feature_student = self.student(image)
67 | with torch.no_grad():
68 | _, feature_teacher = self.teacher(image)
69 |
70 | # losses
71 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
72 | loss_rkd = self.feat_loss_weight * rkd_loss(
73 | feature_student["pooled_feat"],
74 | feature_teacher["pooled_feat"],
75 | self.squared,
76 | self.eps,
77 | self.distance_weight,
78 | self.angle_weight,
79 | )
80 | losses_dict = {
81 | "loss_ce": loss_ce,
82 | "loss_kd": loss_rkd,
83 | }
84 | return logits_student, losses_dict
85 |
--------------------------------------------------------------------------------
/mdistiller/distillers/RLD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 | def
normalize(logit): 8 | mean = logit.mean(dim=-1, keepdims=True) 9 | stdv = logit.std(dim=-1, keepdims=True) 10 | return (logit - mean) / (1e-7 + stdv) 11 | 12 | 13 | def rld_loss(logits_student_in, logits_teacher_in, target, alpha, beta, temperature, logit_stand, alpha_temperature): 14 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in 15 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in 16 | 17 | # scd loss 18 | student_gt_mask = _get_gt_mask(logits_student, target) 19 | student_other_mask = _get_other_mask(logits_student, target) 20 | max_index = torch.argmax(logits_teacher, dim=1) 21 | teacher_max_mask = _get_gt_mask(logits_teacher, max_index) 22 | teacher_other_mask = _get_other_mask(logits_teacher, max_index) 23 | pred_student = F.softmax(logits_student / alpha_temperature, dim=1) 24 | pred_teacher = F.softmax(logits_teacher / alpha_temperature, dim=1) 25 | pred_student = cat_mask(pred_student, student_gt_mask, student_other_mask) 26 | pred_teacher = cat_mask(pred_teacher, teacher_max_mask, teacher_other_mask) 27 | log_pred_student = torch.log(pred_student) 28 | scd_loss = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean') * (alpha_temperature**2) 29 | 30 | # mcd loss 31 | mask = _get_ge_mask(logits_teacher, target) 32 | assert mask.shape == logits_student.shape 33 | masked_student = (logits_student / temperature).masked_fill(mask, -1e9) 34 | log_pred_student_part2 = F.log_softmax(masked_student, dim=1) 35 | masked_teacher = (logits_teacher / temperature).masked_fill(mask, -1e9) 36 | pred_teacher_part2 = F.softmax(masked_teacher, dim=1) 37 | mcd_loss = F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean') * (temperature**2) 38 | 39 | return alpha * scd_loss + beta * mcd_loss 40 | 41 | 42 | def _get_ge_mask(logits, target): 43 | assert logits.dim() == 2 and target.dim() == 1 and logits.size(0) == target.size(0) 44 | gt_value = torch.gather(logits, 1, target.unsqueeze(1)) 45 | mask = torch.where(logits >= gt_value, 1, 0).bool() 46 | return mask 47 | 48 | def _get_gt_mask(logits, target): 49 | target = target.reshape(-1) 50 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 51 | return mask 52 | 53 | def _get_other_mask(logits, target): 54 | target = target.reshape(-1) 55 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 56 | return mask 57 | 58 | def cat_mask(t, mask1, mask2): 59 | t1 = (t * mask1).sum(dim=1, keepdim=True) 60 | t2 = (t * mask2).sum(dim=1, keepdim=True) 61 | rt = torch.cat([t1, t2], dim=1) 62 | return rt 63 | 64 | 65 | class RLD(Distiller): 66 | 67 | def __init__(self, student, teacher, cfg): 68 | super(RLD, self).__init__(student, teacher) 69 | self.ce_loss_weight = cfg.RLD.CE_WEIGHT 70 | self.alpha = cfg.RLD.ALPHA 71 | self.beta = cfg.RLD.BETA 72 | self.temperature = cfg.RLD.T 73 | self.warmup = cfg.RLD.WARMUP 74 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND 75 | self.alpha_temperature = cfg.RLD.ALPHA_T 76 | 77 | def forward_train(self, image, target, **kwargs): 78 | logits_student, _ = self.student(image) 79 | with torch.no_grad(): 80 | logits_teacher, _ = self.teacher(image) 81 | 82 | # losses 83 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 84 | loss_rld = min(kwargs["epoch"] / self.warmup, 1.0) * rld_loss( 85 | logits_student, 86 | logits_teacher, 87 | target, 88 | self.alpha, 89 | self.beta, 90 | self.temperature, 91 | self.logit_stand, 92 | self.alpha_temperature, 93 | ) 94 | 
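# rld_loss above combines alpha * SCD + beta * MCD: SCD matches the student's
# (ground-truth vs. rest) probability split against the teacher's (argmax vs.
# rest) split at alpha_temperature, while MCD re-distills the remaining
# classes after masking out every class whose teacher logit is at least the
# ground-truth logit. min(epoch / warmup, 1.0) linearly ramps the
# distillation weight over the first RLD.WARMUP epochs.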
losses_dict = { 95 | "loss_ce": loss_ce, 96 | "loss_kd": loss_rld, 97 | } 98 | return logits_student, losses_dict 99 | -------------------------------------------------------------------------------- /mdistiller/distillers/ReviewKD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import math 6 | import pdb 7 | 8 | from ._base import Distiller 9 | 10 | 11 | def hcl_loss(fstudent, fteacher): 12 | loss_all = 0.0 13 | for fs, ft in zip(fstudent, fteacher): 14 | n, c, h, w = fs.shape 15 | loss = F.mse_loss(fs, ft, reduction="mean") 16 | cnt = 1.0 17 | tot = 1.0 18 | for l in [4, 2, 1]: 19 | if l >= h: 20 | continue 21 | tmpfs = F.adaptive_avg_pool2d(fs, (l, l)) 22 | tmpft = F.adaptive_avg_pool2d(ft, (l, l)) 23 | cnt /= 2.0 24 | loss += F.mse_loss(tmpfs, tmpft, reduction="mean") * cnt 25 | tot += cnt 26 | loss = loss / tot 27 | loss_all = loss_all + loss 28 | return loss_all 29 | 30 | 31 | class ReviewKD(Distiller): 32 | def __init__(self, student, teacher, cfg): 33 | super(ReviewKD, self).__init__(student, teacher) 34 | self.shapes = cfg.REVIEWKD.SHAPES 35 | self.out_shapes = cfg.REVIEWKD.OUT_SHAPES 36 | in_channels = cfg.REVIEWKD.IN_CHANNELS 37 | out_channels = cfg.REVIEWKD.OUT_CHANNELS 38 | self.ce_loss_weight = cfg.REVIEWKD.CE_WEIGHT 39 | self.reviewkd_loss_weight = cfg.REVIEWKD.REVIEWKD_WEIGHT 40 | self.warmup_epochs = cfg.REVIEWKD.WARMUP_EPOCHS 41 | self.stu_preact = cfg.REVIEWKD.STU_PREACT 42 | self.max_mid_channel = cfg.REVIEWKD.MAX_MID_CHANNEL 43 | 44 | abfs = nn.ModuleList() 45 | mid_channel = min(512, in_channels[-1]) 46 | for idx, in_channel in enumerate(in_channels): 47 | abfs.append( 48 | ABF( 49 | in_channel, 50 | mid_channel, 51 | out_channels[idx], 52 | idx < len(in_channels) - 1, 53 | ) 54 | ) 55 | self.abfs = abfs[::-1] 56 | 57 | def get_learnable_parameters(self): 58 | return super().get_learnable_parameters() + list(self.abfs.parameters()) 59 | 60 | def get_extra_parameters(self): 61 | num_p = 0 62 | for p in self.abfs.parameters(): 63 | num_p += p.numel() 64 | return num_p 65 | 66 | def forward_train(self, image, target, **kwargs): 67 | logits_student, features_student = self.student(image) 68 | with torch.no_grad(): 69 | logits_teacher, features_teacher = self.teacher(image) 70 | 71 | # get features 72 | if self.stu_preact: 73 | x = features_student["preact_feats"] + [ 74 | features_student["pooled_feat"].unsqueeze(-1).unsqueeze(-1) 75 | ] 76 | else: 77 | x = features_student["feats"] + [ 78 | features_student["pooled_feat"].unsqueeze(-1).unsqueeze(-1) 79 | ] 80 | x = x[::-1] 81 | results = [] 82 | out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0]) 83 | results.append(out_features) 84 | for features, abf, shape, out_shape in zip( 85 | x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:] 86 | ): 87 | out_features, res_features = abf(features, res_features, shape, out_shape) 88 | results.insert(0, out_features) 89 | features_teacher = features_teacher["preact_feats"][1:] + [ 90 | features_teacher["pooled_feat"].unsqueeze(-1).unsqueeze(-1) 91 | ] 92 | # losses 93 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 94 | loss_reviewkd = ( 95 | self.reviewkd_loss_weight 96 | * min(kwargs["epoch"] / self.warmup_epochs, 1.0) 97 | * hcl_loss(results, features_teacher) 98 | ) 99 | losses_dict = { 100 | "loss_ce": loss_ce, 101 | "loss_kd": loss_reviewkd, 102 | } 103 | return 
logits_student, losses_dict 104 | 105 | 106 | class ABF(nn.Module): 107 | def __init__(self, in_channel, mid_channel, out_channel, fuse): 108 | super(ABF, self).__init__() 109 | self.conv1 = nn.Sequential( 110 | nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False), 111 | nn.BatchNorm2d(mid_channel), 112 | ) 113 | self.conv2 = nn.Sequential( 114 | nn.Conv2d( 115 | mid_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False 116 | ), 117 | nn.BatchNorm2d(out_channel), 118 | ) 119 | if fuse: 120 | self.att_conv = nn.Sequential( 121 | nn.Conv2d(mid_channel * 2, 2, kernel_size=1), 122 | nn.Sigmoid(), 123 | ) 124 | else: 125 | self.att_conv = None 126 | nn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignore 127 | nn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignore 128 | 129 | def forward(self, x, y=None, shape=None, out_shape=None): 130 | n, _, h, w = x.shape 131 | # transform student features 132 | x = self.conv1(x) 133 | if self.att_conv is not None: 134 | # upsample residual features 135 | y = F.interpolate(y, (shape, shape), mode="nearest") 136 | # fusion 137 | z = torch.cat([x, y], dim=1) 138 | z = self.att_conv(z) 139 | x = x * z[:, 0].view(n, 1, h, w) + y * z[:, 1].view(n, 1, h, w) 140 | # output 141 | if x.shape[-1] != out_shape: 142 | x = F.interpolate(x, (out_shape, out_shape), mode="nearest") 143 | y = self.conv2(x) 144 | return y, x 145 | -------------------------------------------------------------------------------- /mdistiller/distillers/SP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._base import Distiller 6 | 7 | 8 | def sp_loss(g_s, g_t): 9 | return sum([similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) 10 | 11 | 12 | def similarity_loss(f_s, f_t): 13 | bsz = f_s.shape[0] 14 | f_s = f_s.view(bsz, -1) 15 | f_t = f_t.view(bsz, -1) 16 | 17 | G_s = torch.mm(f_s, torch.t(f_s)) 18 | G_s = torch.nn.functional.normalize(G_s) 19 | G_t = torch.mm(f_t, torch.t(f_t)) 20 | G_t = torch.nn.functional.normalize(G_t) 21 | 22 | G_diff = G_t - G_s 23 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 24 | return loss 25 | 26 | 27 | class SP(Distiller): 28 | """Similarity-Preserving Knowledge Distillation, ICCV2019""" 29 | 30 | def __init__(self, student, teacher, cfg): 31 | super(SP, self).__init__(student, teacher) 32 | self.ce_loss_weight = cfg.SP.LOSS.CE_WEIGHT 33 | self.feat_loss_weight = cfg.SP.LOSS.FEAT_WEIGHT 34 | 35 | def forward_train(self, image, target, **kwargs): 36 | logits_student, feature_student = self.student(image) 37 | with torch.no_grad(): 38 | _, feature_teacher = self.teacher(image) 39 | 40 | # losses 41 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target) 42 | loss_feat = self.feat_loss_weight * sp_loss( 43 | [feature_student["feats"][-1]], [feature_teacher["feats"][-1]] 44 | ) 45 | losses_dict = { 46 | "loss_ce": loss_ce, 47 | "loss_kd": loss_feat, 48 | } 49 | return logits_student, losses_dict 50 | -------------------------------------------------------------------------------- /mdistiller/distillers/SWAP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | from ._base import Distiller 7 | 8 | def normalize(logit): 9 | mean = logit.mean(dim=-1, keepdims=True) 10 | stdv = logit.std(dim=-1, keepdims=True) 11 | return (logit - 
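# per-sample z-score standardization of the logits; the 1e-7 term guards
# against zero variance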
mean) / (1e-7 + stdv)
12 |
13 | def swap(logits, target):
14 | assert logits.dim() == 2 and target.dim() == 1 and logits.size(0) == target.size(0)
15 | gt_value = torch.gather(logits, 1, target.unsqueeze(1))
16 | gt_index = target.unsqueeze(1)
17 | max_value, max_index = torch.max(logits, dim=1, keepdim=True)
18 | swapped_logits = copy.deepcopy(logits).scatter_(1, max_index, gt_value).scatter_(1, gt_index, max_value)
19 | return swapped_logits
20 |
21 | def kd_loss(logits_student_in, logits_teacher_in, temperature, logit_stand):
22 | logits_student = normalize(logits_student_in) if logit_stand else logits_student_in
23 | logits_teacher = normalize(logits_teacher_in) if logit_stand else logits_teacher_in
24 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
25 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
26 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
27 | loss_kd *= temperature**2
28 | return loss_kd
29 |
30 |
31 | class SWAP(Distiller):
32 | """KD from teacher logits whose ground-truth and maximum entries have been swapped"""
33 |
34 | def __init__(self, student, teacher, cfg):
35 | super(SWAP, self).__init__(student, teacher)
36 | self.temperature = cfg.KD.TEMPERATURE
37 | self.logit_stand = cfg.EXPERIMENT.LOGIT_STAND
38 |
39 | def forward_train(self, image, target, **kwargs):
40 | logits_student, _ = self.student(image)
41 | with torch.no_grad():
42 | logits_teacher, _ = self.teacher(image)
43 | logits_teacher = swap(logits_teacher, target)
44 |
45 | # losses
46 | loss_kd = kd_loss(logits_student, logits_teacher, self.temperature, self.logit_stand)
47 | losses_dict = {
48 | "loss_kd": loss_kd,
49 | }
50 | return logits_student, losses_dict
51 |
--------------------------------------------------------------------------------
/mdistiller/distillers/Sonly.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ._base import Distiller
6 |
7 |
8 | def kd_loss(logits_student, logits_teacher, temperature):
9 | log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
10 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
11 | loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
12 | loss_kd *= temperature**2
13 | return loss_kd
14 |
15 |
16 | class Sonly(Distiller):
17 | """Student-only baseline: the KD term is computed but only the CE loss is optimized"""
18 |
19 | def __init__(self, student, teacher, cfg):
20 | super(Sonly, self).__init__(student, teacher)
21 | self.temperature = cfg.KD.TEMPERATURE
22 | self.ce_loss_weight = cfg.KD.LOSS.CE_WEIGHT
23 | self.kd_loss_weight = cfg.KD.LOSS.KD_WEIGHT
24 |
25 | def forward_train(self, image, target, **kwargs):
26 | logits_student, _ = self.student(image)
27 | with torch.no_grad():
28 | logits_teacher, _ = self.teacher(image)
29 |
30 | # losses
31 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
32 | loss_kd = self.kd_loss_weight * kd_loss(
33 | logits_student, logits_teacher, self.temperature
34 | )
35 | losses_dict = {
36 | "loss_ce": loss_ce, # loss_kd is deliberately left out of the optimized losses
37 | }
38 | return logits_student, losses_dict
39 |
--------------------------------------------------------------------------------
/mdistiller/distillers/VID.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from ._base import Distiller
7 | from ._common
import get_feat_shapes
8 |
9 |
10 | def conv1x1(in_channels, out_channels, stride=1):
11 | return nn.Conv2d(
12 | in_channels, out_channels, kernel_size=1, padding=0, bias=False, stride=stride
13 | )
14 |
15 |
16 | def vid_loss(regressor, log_scale, f_s, f_t, eps=1e-5):
17 | # pool for dimension match
18 | s_H, t_H = f_s.shape[2], f_t.shape[2]
19 | if s_H > t_H:
20 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
21 | elif s_H < t_H:
22 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
23 | else:
24 | pass
25 | pred_mean = regressor(f_s)
26 | pred_var = torch.log(1.0 + torch.exp(log_scale)) + eps
27 | pred_var = pred_var.view(1, -1, 1, 1).to(pred_mean)
28 | neg_log_prob = 0.5 * ((pred_mean - f_t) ** 2 / pred_var + torch.log(pred_var))
29 | loss = torch.mean(neg_log_prob)
30 | return loss
31 |
32 |
33 | class VID(Distiller):
34 | """
35 | Variational Information Distillation for Knowledge Transfer (CVPR 2019),
36 | code from author: https://github.com/ssahn0215/variational-information-distillation
37 | """
38 |
39 | def __init__(self, student, teacher, cfg):
40 | super(VID, self).__init__(student, teacher)
41 | self.ce_loss_weight = cfg.VID.LOSS.CE_WEIGHT
42 | self.feat_loss_weight = cfg.VID.LOSS.FEAT_WEIGHT
43 | self.init_pred_var = cfg.VID.INIT_PRED_VAR
44 | self.eps = cfg.VID.EPS
45 | feat_s_shapes, feat_t_shapes = get_feat_shapes(
46 | self.student, self.teacher, cfg.VID.INPUT_SIZE
47 | )
48 | feat_s_channels = [s[1] for s in feat_s_shapes[1:]]
49 | feat_t_channels = [s[1] for s in feat_t_shapes[1:]]
50 | self.init_vid_modules(feat_s_channels, feat_t_channels)
51 |
52 | def init_vid_modules(self, feat_s_shapes, feat_t_shapes):
53 | self.regressors = nn.ModuleList()
# note: log_scales is a plain Python list, so its entries are never
# registered as module parameters; they keep their initial value and are
# only moved to the right device inside vid_loss
54 | self.log_scales = []
55 | for s, t in zip(feat_s_shapes, feat_t_shapes):
56 | regressor = nn.Sequential(
57 | conv1x1(s, t), nn.ReLU(), conv1x1(t, t), nn.ReLU(), conv1x1(t, t)
58 | )
59 | self.regressors.append(regressor)
60 | log_scale = torch.nn.Parameter(
61 | np.log(np.exp(self.init_pred_var - self.eps) - 1.0) * torch.ones(t)
62 | )
63 | self.log_scales.append(log_scale)
64 |
65 | def get_learnable_parameters(self):
66 | parameters = super().get_learnable_parameters()
67 | for regressor in self.regressors:
68 | parameters += list(regressor.parameters())
69 | return parameters
70 |
71 | def get_extra_parameters(self):
72 | num_p = 0
73 | for regressor in self.regressors:
74 | for p in regressor.parameters():
75 | num_p += p.numel()
76 | return num_p
77 |
78 | def forward_train(self, image, target, **kwargs):
79 | logits_student, feature_student = self.student(image)
80 | with torch.no_grad():
81 | _, feature_teacher = self.teacher(image)
82 |
83 | # losses
84 | loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
85 | loss_vid = 0
86 | for i in range(len(feature_student["feats"][1:])):
87 | loss_vid += vid_loss(
88 | self.regressors[i],
89 | self.log_scales[i],
90 | feature_student["feats"][1:][i],
91 | feature_teacher["feats"][1:][i],
92 | self.eps,
93 | )
94 | loss_vid = self.feat_loss_weight * loss_vid
95 | losses_dict = {
96 | "loss_ce": loss_ce,
97 | "loss_kd": loss_vid,
98 | }
99 | return logits_student, losses_dict
100 |
--------------------------------------------------------------------------------
/mdistiller/distillers/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Vanilla
2 | from .KD import KD
3 | from .MLKD import MLKD
4 | from .AT import AT
5 | from .OFD import OFD
6 | from .RKD import RKD
7 | from .FitNet import FitNet
8 | from .KDSVD import KDSVD 9 | from .CRD import CRD 10 | from .NST import NST 11 | from .PKT import PKT 12 | from .SP import SP 13 | from .Sonly import Sonly 14 | from .VID import VID 15 | from .ReviewKD import ReviewKD 16 | from .DKD import DKD 17 | from .RLD import RLD 18 | from .SWAP import SWAP 19 | from .REVISION import REVISION 20 | from .RC import RC 21 | from .MLKD_NOAUG import MLKD_NOAUG 22 | 23 | distiller_dict = { 24 | "NONE": Vanilla, 25 | "KD": KD, 26 | "MLKD": MLKD, 27 | "MLKD_NOAUG": MLKD_NOAUG, 28 | "AT": AT, 29 | "OFD": OFD, 30 | "RKD": RKD, 31 | "FITNET": FitNet, 32 | "KDSVD": KDSVD, 33 | "CRD": CRD, 34 | "NST": NST, 35 | "PKT": PKT, 36 | "SP": SP, 37 | "Sonly": Sonly, 38 | "VID": VID, 39 | "REVIEWKD": ReviewKD, 40 | "DKD": DKD, 41 | "RLD": RLD, 42 | "SWAP": SWAP, 43 | "REVISION": REVISION, 44 | "RC": RC, 45 | } 46 | -------------------------------------------------------------------------------- /mdistiller/distillers/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Distiller(nn.Module): 7 | def __init__(self, student, teacher): 8 | super(Distiller, self).__init__() 9 | self.student = student 10 | self.teacher = teacher 11 | 12 | def train(self, mode=True): 13 | # teacher as eval mode by default 14 | if not isinstance(mode, bool): 15 | raise ValueError("training mode is expected to be boolean") 16 | self.training = mode 17 | for module in self.children(): 18 | module.train(mode) 19 | self.teacher.eval() 20 | return self 21 | 22 | def get_learnable_parameters(self): 23 | # if the method introduces extra parameters, re-impl this function 24 | return [v for k, v in self.student.named_parameters()] 25 | 26 | def get_extra_parameters(self): 27 | # calculate the extra parameters introduced by the distiller 28 | return 0 29 | 30 | def forward_train(self, **kwargs): 31 | # training function for the distillation method 32 | raise NotImplementedError() 33 | 34 | def forward_test(self, image): 35 | return self.student(image)[0] 36 | 37 | def forward(self, **kwargs): 38 | if self.training: 39 | return self.forward_train(**kwargs) 40 | return self.forward_test(kwargs["image"]) 41 | 42 | 43 | class Vanilla(nn.Module): 44 | def __init__(self, student): 45 | super(Vanilla, self).__init__() 46 | self.student = student 47 | 48 | def get_learnable_parameters(self): 49 | return [v for k, v in self.student.named_parameters()] 50 | 51 | def forward_train(self, image, target, **kwargs): 52 | logits_student, _ = self.student(image) 53 | loss = F.cross_entropy(logits_student, target) 54 | return logits_student, {"ce": loss} 55 | 56 | def forward(self, **kwargs): 57 | if self.training: 58 | return self.forward_train(**kwargs) 59 | return self.forward_test(kwargs["image"]) 60 | 61 | def forward_test(self, image): 62 | return self.student(image)[0] 63 | -------------------------------------------------------------------------------- /mdistiller/distillers/_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvReg(nn.Module): 7 | """Convolutional regression""" 8 | 9 | def __init__(self, s_shape, t_shape, use_relu=True): 10 | super(ConvReg, self).__init__() 11 | self.use_relu = use_relu 12 | s_N, s_C, s_H, s_W = s_shape 13 | t_N, t_C, t_H, t_W = t_shape 14 | if s_H == 2 * t_H: 15 | self.conv = nn.Conv2d(s_C, t_C, 
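# stride-2 3x3 conv for the case where the student map is exactly twice the
# teacher's spatial size; the branches below handle 2x upsampling and
# arbitrary size gaps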
kernel_size=3, stride=2, padding=1)
16 | elif s_H * 2 == t_H:
17 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
18 | elif s_H >= t_H:
19 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1 + s_H - t_H, 1 + s_W - t_W))
20 | else:
21 | raise NotImplementedError("student size {}, teacher size {}".format(s_H, t_H))
22 | self.bn = nn.BatchNorm2d(t_C)
23 | self.relu = nn.ReLU(inplace=True)
24 |
25 | def forward(self, x):
26 | x = self.conv(x)
27 | if self.use_relu:
28 | return self.relu(self.bn(x))
29 | else:
30 | return self.bn(x)
31 |
32 |
33 | def get_feat_shapes(student, teacher, input_size):
34 | data = torch.randn(1, 3, *input_size)
35 | with torch.no_grad():
36 | _, feat_s = student(data)
37 | _, feat_t = teacher(data)
38 | feat_s_shapes = [f.shape for f in feat_s["feats"]]
39 | feat_t_shapes = [f.shape for f in feat_t["feats"]]
40 | return feat_s_shapes, feat_t_shapes
41 |
--------------------------------------------------------------------------------
/mdistiller/distillers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class CrossEntropyLabelSmooth(nn.Module):
5 | """Cross entropy loss with label smoothing regularizer.
6 | Reference:
7 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
8 | Equation: y = (1 - epsilon) * y + epsilon / K.
9 | Args:
10 | num_classes (int): number of classes.
11 | epsilon (float): weight.
12 | """
13 |
14 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
15 | super(CrossEntropyLabelSmooth, self).__init__()
16 | self.num_classes = num_classes
17 | self.epsilon = epsilon
18 | self.use_gpu = use_gpu
19 | self.reduction = reduction
20 | self.logsoftmax = nn.LogSoftmax(dim=1)
21 |
22 | def forward(self, inputs, targets):
23 | """
24 | Args:
25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
26 | targets: ground truth labels with shape (batch_size)
27 | """
28 | log_probs = self.logsoftmax(inputs)
29 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
30 | if self.use_gpu: targets = targets.cuda()
31 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
32 | loss = (- targets * log_probs).sum(dim=1)
33 | if self.reduction:
34 | return loss.mean()
35 | else:
36 | return loss
--------------------------------------------------------------------------------
/mdistiller/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import BaseTrainer, CRDTrainer, AugTrainer
2 |
3 | trainer_dict = {
4 | "base": BaseTrainer,
5 | "crd": CRDTrainer,
6 | "ours": AugTrainer
7 | }
8 |
--------------------------------------------------------------------------------
/mdistiller/engine/cfg.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | from .utils import log_msg
3 |
4 |
5 | def show_cfg(cfg):
6 | dump_cfg = CN()
7 | dump_cfg.EXPERIMENT = cfg.EXPERIMENT
8 | dump_cfg.DATASET = cfg.DATASET
9 | dump_cfg.DISTILLER = cfg.DISTILLER
10 | dump_cfg.SOLVER = cfg.SOLVER
11 | dump_cfg.LOG = cfg.LOG
12 | if cfg.DISTILLER.TYPE in cfg:
13 | dump_cfg.update({cfg.DISTILLER.TYPE: cfg.get(cfg.DISTILLER.TYPE)})
14 | print(log_msg("CONFIG:\n{}".format(dump_cfg.dump()), "INFO"))
15 |
16 |
17 | CFG = CN()
18 |
19 | # Experiment
20 | CFG.EXPERIMENT = CN()
21 | CFG.EXPERIMENT.PROJECT = "distill"
22 |
CFG.EXPERIMENT.NAME = "" 23 | CFG.EXPERIMENT.TAG = "default" 24 | CFG.EXPERIMENT.LOGIT_STAND = False 25 | CFG.EXPERIMENT.AUG = False 26 | 27 | # Dataset 28 | CFG.DATASET = CN() 29 | CFG.DATASET.TYPE = "cifar100" 30 | CFG.DATASET.NUM_WORKERS = 2 31 | CFG.DATASET.TEST = CN() 32 | CFG.DATASET.TEST.BATCH_SIZE = 64 33 | 34 | # Distiller 35 | CFG.DISTILLER = CN() 36 | CFG.DISTILLER.TYPE = "NONE" # Vanilla as default 37 | CFG.DISTILLER.TEACHER = "ResNet50" 38 | CFG.DISTILLER.STUDENT = "resnet32" 39 | 40 | # Solver 41 | CFG.SOLVER = CN() 42 | CFG.SOLVER.TRAINER = "base" 43 | CFG.SOLVER.BATCH_SIZE = 64 44 | CFG.SOLVER.EPOCHS = 240 45 | CFG.SOLVER.LR = 0.05 46 | CFG.SOLVER.LR_DECAY_STAGES = [150, 180, 210] 47 | CFG.SOLVER.LR_DECAY_RATE = 0.1 48 | CFG.SOLVER.WEIGHT_DECAY = 0.0001 49 | CFG.SOLVER.MOMENTUM = 0.9 50 | CFG.SOLVER.TYPE = "SGD" 51 | 52 | # Log 53 | CFG.LOG = CN() 54 | CFG.LOG.TENSORBOARD_FREQ = 500 55 | CFG.LOG.SAVE_CHECKPOINT_FREQ = 1000 56 | CFG.LOG.PREFIX = "./output" 57 | CFG.LOG.WANDB = False 58 | 59 | # Distillation Methods 60 | 61 | # KD CFG 62 | CFG.KD = CN() 63 | CFG.KD.TEMPERATURE = 4.0 64 | CFG.KD.LOSS = CN() 65 | CFG.KD.LOSS.CE_WEIGHT = 0.1 66 | CFG.KD.LOSS.KD_WEIGHT = 0.9 67 | 68 | # AT CFG 69 | CFG.AT = CN() 70 | CFG.AT.P = 2 71 | CFG.AT.LOSS = CN() 72 | CFG.AT.LOSS.CE_WEIGHT = 1.0 73 | CFG.AT.LOSS.FEAT_WEIGHT = 1000.0 74 | 75 | # RKD CFG 76 | CFG.RKD = CN() 77 | CFG.RKD.DISTANCE_WEIGHT = 25 78 | CFG.RKD.ANGLE_WEIGHT = 50 79 | CFG.RKD.LOSS = CN() 80 | CFG.RKD.LOSS.CE_WEIGHT = 1.0 81 | CFG.RKD.LOSS.FEAT_WEIGHT = 1.0 82 | CFG.RKD.PDIST = CN() 83 | CFG.RKD.PDIST.EPSILON = 1e-12 84 | CFG.RKD.PDIST.SQUARED = False 85 | 86 | # FITNET CFG 87 | CFG.FITNET = CN() 88 | CFG.FITNET.HINT_LAYER = 2 # (0, 1, 2, 3, 4) 89 | CFG.FITNET.INPUT_SIZE = (32, 32) 90 | CFG.FITNET.LOSS = CN() 91 | CFG.FITNET.LOSS.CE_WEIGHT = 1.0 92 | CFG.FITNET.LOSS.FEAT_WEIGHT = 100.0 93 | 94 | # KDSVD CFG 95 | CFG.KDSVD = CN() 96 | CFG.KDSVD.K = 1 97 | CFG.KDSVD.LOSS = CN() 98 | CFG.KDSVD.LOSS.CE_WEIGHT = 1.0 99 | CFG.KDSVD.LOSS.FEAT_WEIGHT = 1.0 100 | 101 | # OFD CFG 102 | CFG.OFD = CN() 103 | CFG.OFD.LOSS = CN() 104 | CFG.OFD.LOSS.CE_WEIGHT = 1.0 105 | CFG.OFD.LOSS.FEAT_WEIGHT = 0.001 106 | CFG.OFD.CONNECTOR = CN() 107 | CFG.OFD.CONNECTOR.KERNEL_SIZE = 1 108 | 109 | # NST CFG 110 | CFG.NST = CN() 111 | CFG.NST.LOSS = CN() 112 | CFG.NST.LOSS.CE_WEIGHT = 1.0 113 | CFG.NST.LOSS.FEAT_WEIGHT = 50.0 114 | 115 | # PKT CFG 116 | CFG.PKT = CN() 117 | CFG.PKT.LOSS = CN() 118 | CFG.PKT.LOSS.CE_WEIGHT = 1.0 119 | CFG.PKT.LOSS.FEAT_WEIGHT = 30000.0 120 | 121 | # SP CFG 122 | CFG.SP = CN() 123 | CFG.SP.LOSS = CN() 124 | CFG.SP.LOSS.CE_WEIGHT = 1.0 125 | CFG.SP.LOSS.FEAT_WEIGHT = 3000.0 126 | 127 | # VID CFG 128 | CFG.VID = CN() 129 | CFG.VID.LOSS = CN() 130 | CFG.VID.LOSS.CE_WEIGHT = 1.0 131 | CFG.VID.LOSS.FEAT_WEIGHT = 1.0 132 | CFG.VID.EPS = 1e-5 133 | CFG.VID.INIT_PRED_VAR = 5.0 134 | CFG.VID.INPUT_SIZE = (32, 32) 135 | 136 | # CRD CFG 137 | CFG.CRD = CN() 138 | CFG.CRD.MODE = "exact" # ("exact", "relax") 139 | CFG.CRD.FEAT = CN() 140 | CFG.CRD.FEAT.DIM = 128 141 | CFG.CRD.FEAT.STUDENT_DIM = 256 142 | CFG.CRD.FEAT.TEACHER_DIM = 256 143 | CFG.CRD.LOSS = CN() 144 | CFG.CRD.LOSS.CE_WEIGHT = 1.0 145 | CFG.CRD.LOSS.FEAT_WEIGHT = 0.8 146 | CFG.CRD.NCE = CN() 147 | CFG.CRD.NCE.K = 16384 148 | CFG.CRD.NCE.MOMENTUM = 0.5 149 | CFG.CRD.NCE.TEMPERATURE = 0.07 150 | 151 | # ReviewKD CFG 152 | CFG.REVIEWKD = CN() 153 | CFG.REVIEWKD.CE_WEIGHT = 1.0 154 | CFG.REVIEWKD.REVIEWKD_WEIGHT = 1.0 155 | CFG.REVIEWKD.WARMUP_EPOCHS = 20 156 | 
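# SHAPES / OUT_SHAPES are the per-stage spatial sizes the ABF fusion modules
# interpolate to, and IN_CHANNELS / OUT_CHANNELS the matching student /
# teacher channel widths; the defaults below are for the CIFAR models.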
CFG.REVIEWKD.SHAPES = [1, 8, 16, 32]
157 | CFG.REVIEWKD.OUT_SHAPES = [1, 8, 16, 32]
158 | CFG.REVIEWKD.IN_CHANNELS = [64, 128, 256, 256]
159 | CFG.REVIEWKD.OUT_CHANNELS = [64, 128, 256, 256]
160 | CFG.REVIEWKD.MAX_MID_CHANNEL = 512
161 | CFG.REVIEWKD.STU_PREACT = False
162 |
163 | # DKD (Decoupled Knowledge Distillation) CFG
164 | CFG.DKD = CN()
165 | CFG.DKD.CE_WEIGHT = 1.0
166 | CFG.DKD.ALPHA = 1.0
167 | CFG.DKD.BETA = 8.0
168 | CFG.DKD.T = 4.0
169 | CFG.DKD.WARMUP = 20
170 |
171 | # RC CFG
172 | CFG.RC = CN()
173 | CFG.RC.CE_WEIGHT = 1.0
174 | CFG.RC.KD_WEIGHT = 1.0
175 | CFG.RC.RC_WEIGHT = 0.1
176 | CFG.RC.T = 4.0
177 |
178 | # RLD CFG
179 | CFG.RLD = CN()
180 | CFG.RLD.CE_WEIGHT = 1.0
181 | CFG.RLD.ALPHA = 1.0
182 | CFG.RLD.BETA = 8.0
183 | CFG.RLD.ALPHA_T = 1.0
184 | CFG.RLD.T = 4.0
185 | CFG.RLD.WARMUP = 20
--------------------------------------------------------------------------------
/mdistiller/engine/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import sys
6 | import time
7 | from tqdm import tqdm
8 |
9 |
10 | class AverageMeter(object):
11 | """Computes and stores the average and current value"""
12 |
13 | def __init__(self):
14 | self.reset()
15 |
16 | def reset(self):
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def validate(val_loader, distiller):
30 | batch_time, losses, top1, top5 = [AverageMeter() for _ in range(4)]
31 | criterion = nn.CrossEntropyLoss()
32 | num_iter = len(val_loader)
33 | pbar = tqdm(range(num_iter))
34 |
35 | distiller.eval()
36 | with torch.no_grad():
37 | start_time = time.time()
38 | for idx, (image, target) in enumerate(val_loader):
39 | image = image.float()
40 | image = image.cuda(non_blocking=True)
41 | target = target.cuda(non_blocking=True)
42 | output = distiller(image=image)
43 | loss = criterion(output, target)
44 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
45 | batch_size = image.size(0)
46 | losses.update(loss.cpu().detach().numpy().mean(), batch_size)
47 | top1.update(acc1[0], batch_size)
48 | top5.update(acc5[0], batch_size)
49 |
50 | # measure elapsed time
51 | batch_time.update(time.time() - start_time)
52 | start_time = time.time()
53 | msg = "Top-1:{top1.avg:.3f}| Top-5:{top5.avg:.3f}".format(
54 | top1=top1, top5=top5
55 | )
56 | pbar.set_description(log_msg(msg, "EVAL"))
57 | pbar.update()
58 | pbar.close()
59 | return top1.avg, top5.avg, losses.avg
60 |
61 |
62 | def validate_npy(val_loader, distiller):
63 | batch_time, losses, top1, top5 = [AverageMeter() for _ in range(4)]
64 | criterion = nn.CrossEntropyLoss()
65 | num_iter = len(val_loader)
66 | pbar = tqdm(range(num_iter))
67 |
68 | distiller.eval()
69 | with torch.no_grad():
70 | start_time = time.time()
71 | start_eval = True
72 | for idx, (image, target) in enumerate(val_loader):
73 | image = image.float()
74 | image = image.cuda(non_blocking=True)
75 | target = target.cuda(non_blocking=True)
76 | output = distiller(image=image)
77 | loss = criterion(output, target)
78 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
79 | batch_size = image.size(0)
80 | losses.update(loss.cpu().detach().numpy().mean(), batch_size)
81 | top1.update(acc1[0], batch_size)
82 | top5.update(acc5[0], batch_size)
83 | output = nn.Softmax(dim=1)(output)
84 | if start_eval:
85 |
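# accumulate every image, softmax output and label so the caller can dump
# them as .npy arrays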
all_image = image.float().cpu() 86 | all_output = output.float().cpu() 87 | all_label = target.float().cpu() 88 | start_eval = False 89 | else: 90 | all_image = torch.cat((all_image, image.float().cpu()), dim=0) 91 | all_output = torch.cat((all_output, output.float().cpu()), dim=0) 92 | all_label = torch.cat((all_label, target.float().cpu()), dim=0) 93 | 94 | # measure elapsed time 95 | batch_time.update(time.time() - start_time) 96 | start_time = time.time() 97 | msg = "Top-1:{top1.avg:.3f}| Top-5:{top5.avg:.3f}".format( 98 | top1=top1, top5=top5 99 | ) 100 | pbar.set_description(log_msg(msg, "EVAL")) 101 | pbar.update() 102 | all_image, all_output, all_label = all_image.numpy(), all_output.numpy(), all_label.numpy() 103 | pbar.close() 104 | return top1.avg, top5.avg, losses.avg, all_image, all_output, all_label 105 | 106 | 107 | def log_msg(msg, mode="INFO"): 108 | color_map = { 109 | "INFO": 36, 110 | "TRAIN": 32, 111 | "EVAL": 31, 112 | } 113 | msg = "\033[{}m[{}] {}\033[0m".format(color_map[mode], mode, msg) 114 | return msg 115 | 116 | 117 | def adjust_learning_rate(epoch, cfg, optimizer): 118 | steps = np.sum(epoch > np.asarray(cfg.SOLVER.LR_DECAY_STAGES)) 119 | if steps > 0: 120 | new_lr = cfg.SOLVER.LR * (cfg.SOLVER.LR_DECAY_RATE**steps) 121 | for param_group in optimizer.param_groups: 122 | param_group["lr"] = new_lr 123 | return new_lr 124 | return cfg.SOLVER.LR 125 | 126 | 127 | def accuracy(output, target, topk=(1,)): 128 | with torch.no_grad(): 129 | maxk = max(topk) 130 | batch_size = target.size(0) 131 | _, pred = output.topk(maxk, 1, True, True) 132 | pred = pred.t() 133 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 134 | res = [] 135 | for k in topk: 136 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 137 | res.append(correct_k.mul_(100.0 / batch_size)) 138 | return res 139 | 140 | 141 | def save_checkpoint(obj, path): 142 | with open(path, "wb") as f: 143 | torch.save(obj, f) 144 | 145 | 146 | def load_checkpoint(path): 147 | with open(path, "rb") as f: 148 | return torch.load(f, map_location="cpu") 149 | -------------------------------------------------------------------------------- /mdistiller/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import cifar_model_dict 2 | from .imagenet import imagenet_model_dict 3 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ShuffleBlock(nn.Module): 7 | def __init__(self, groups): 8 | super(ShuffleBlock, self).__init__() 9 | self.groups = groups 10 | 11 | def forward(self, x): 12 | """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" 13 | N, C, H, W = x.size() 14 | g = self.groups 15 | return x.reshape(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 16 | 17 | 18 | class Bottleneck(nn.Module): 19 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 20 | super(Bottleneck, self).__init__() 21 | self.is_last = is_last 22 | self.stride = stride 23 | 24 | mid_planes = int(out_planes / 4) 25 | g = 1 if in_planes == 24 else groups 26 | self.conv1 = nn.Conv2d( 27 | in_planes, mid_planes, kernel_size=1, groups=g, bias=False 28 | ) 29 | self.bn1 = nn.BatchNorm2d(mid_planes) 30 | self.shuffle1 = ShuffleBlock(groups=g) 31 | self.conv2 
= nn.Conv2d( 32 | mid_planes, 33 | mid_planes, 34 | kernel_size=3, 35 | stride=stride, 36 | padding=1, 37 | groups=mid_planes, 38 | bias=False, 39 | ) 40 | self.bn2 = nn.BatchNorm2d(mid_planes) 41 | self.conv3 = nn.Conv2d( 42 | mid_planes, out_planes, kernel_size=1, groups=groups, bias=False 43 | ) 44 | self.bn3 = nn.BatchNorm2d(out_planes) 45 | 46 | self.shortcut = nn.Sequential() 47 | if stride == 2: 48 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(self.conv1(x))) 52 | out = self.shuffle1(out) 53 | out = F.relu(self.bn2(self.conv2(out))) 54 | out = self.bn3(self.conv3(out)) 55 | res = self.shortcut(x) 56 | preact = torch.cat([out, res], 1) if self.stride == 2 else out + res 57 | out = F.relu(preact) 58 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 59 | if self.is_last: 60 | return out, preact 61 | else: 62 | return out 63 | 64 | 65 | class ShuffleNet(nn.Module): 66 | def __init__(self, cfg, num_classes=10): 67 | super(ShuffleNet, self).__init__() 68 | out_planes = cfg["out_planes"] 69 | num_blocks = cfg["num_blocks"] 70 | groups = cfg["groups"] 71 | 72 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(24) 74 | self.in_planes = 24 75 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 76 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 77 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 78 | self.linear = nn.Linear(out_planes[2], num_classes) 79 | # self.stage_channels = out_channels 80 | 81 | def _make_layer(self, out_planes, num_blocks, groups): 82 | layers = [] 83 | for i in range(num_blocks): 84 | stride = 2 if i == 0 else 1 85 | cat_planes = self.in_planes if i == 0 else 0 86 | layers.append( 87 | Bottleneck( 88 | self.in_planes, 89 | out_planes - cat_planes, 90 | stride=stride, 91 | groups=groups, 92 | is_last=(i == num_blocks - 1), 93 | ) 94 | ) 95 | self.in_planes = out_planes 96 | return nn.Sequential(*layers) 97 | 98 | def get_feat_modules(self): 99 | feat_m = nn.ModuleList([]) 100 | feat_m.append(self.conv1) 101 | feat_m.append(self.bn1) 102 | feat_m.append(self.layer1) 103 | feat_m.append(self.layer2) 104 | feat_m.append(self.layer3) 105 | return feat_m 106 | 107 | def get_bn_before_relu(self): 108 | raise NotImplementedError( 109 | 'ShuffleNet currently is not supported for "Overhaul" teacher' 110 | ) 111 | 112 | def forward( 113 | self, 114 | x, 115 | ): 116 | out = F.relu(self.bn1(self.conv1(x))) 117 | f0 = out 118 | out, f1_pre = self.layer1(out) 119 | f1 = out 120 | out, f2_pre = self.layer2(out) 121 | f2 = out 122 | out, f3_pre = self.layer3(out) 123 | f3 = out 124 | out = F.avg_pool2d(out, 4) 125 | out = out.reshape(out.size(0), -1) 126 | f4 = out 127 | out = self.linear(out) 128 | 129 | feats = {} 130 | feats["feats"] = [f0, f1, f2, f3] 131 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre] 132 | feats["pooled_feat"] = f4 133 | 134 | return out, feats 135 | 136 | 137 | def ShuffleV1(**kwargs): 138 | cfg = {"out_planes": [240, 480, 960], "num_blocks": [4, 8, 4], "groups": 3} 139 | return ShuffleNet(cfg, **kwargs) 140 | 141 | 142 | if __name__ == "__main__": 143 | 144 | x = torch.randn(2, 3, 32, 32) 145 | net = ShuffleV1(num_classes=100) 146 | import time 147 | 148 | a = time.time() 149 | logit, feats = net(x) 150 | b = time.time() 151 | print(b - a) 152 | for f in feats["feats"]: 153 | print(f.shape, f.min().item()) 154 | print(logit.shape) 155 | 
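The `ShuffleBlock.forward` one-liner above packs channel shuffle into a single reshape-permute-reshape, which is easy to misread. A minimal standalone sketch (illustration only, not part of the repository; `channel_shuffle` re-implements the same expression) tags each channel with its index so the resulting interleaving is visible:

```python
# Illustration of ShuffleBlock's channel shuffle; not repo code.
import torch


def channel_shuffle(x, groups):
    # [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]
    N, C, H, W = x.size()
    return x.reshape(N, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)


x = torch.arange(6.0).reshape(1, 6, 1, 1)  # channels tagged 0..5
y = channel_shuffle(x, groups=3)
print(y.flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0] -- one channel from each group, round-robin
assert torch.equal(channel_shuffle(y, groups=2), x)  # shuffling by C/g inverts a shuffle by g
```

This cross-group interleaving is what lets the grouped 1x1 convolutions in `Bottleneck` exchange information: after the shuffle, each group seen by the next grouped convolution contains channels drawn from every group of the previous one.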
-------------------------------------------------------------------------------- /mdistiller/models/cifar/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ShuffleBlock(nn.Module): 7 | def __init__(self, groups=2): 8 | super(ShuffleBlock, self).__init__() 9 | self.groups = groups 10 | 11 | def forward(self, x): 12 | """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]""" 13 | N, C, H, W = x.size() 14 | g = self.groups 15 | return x.reshape(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 16 | 17 | 18 | class SplitBlock(nn.Module): 19 | def __init__(self, ratio): 20 | super(SplitBlock, self).__init__() 21 | self.ratio = ratio 22 | 23 | def forward(self, x): 24 | c = int(x.size(1) * self.ratio) 25 | return x[:, :c, :, :], x[:, c:, :, :] 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 30 | super(BasicBlock, self).__init__() 31 | self.is_last = is_last 32 | self.split = SplitBlock(split_ratio) 33 | in_channels = int(in_channels * split_ratio) 34 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(in_channels) 36 | self.conv2 = nn.Conv2d( 37 | in_channels, 38 | in_channels, 39 | kernel_size=3, 40 | stride=1, 41 | padding=1, 42 | groups=in_channels, 43 | bias=False, 44 | ) 45 | self.bn2 = nn.BatchNorm2d(in_channels) 46 | self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(in_channels) 48 | self.shuffle = ShuffleBlock() 49 | 50 | def forward(self, x): 51 | x1, x2 = self.split(x) 52 | out = F.relu(self.bn1(self.conv1(x2))) 53 | out = self.bn2(self.conv2(out)) 54 | preact = self.bn3(self.conv3(out)) 55 | out = F.relu(preact) 56 | # out = F.relu(self.bn3(self.conv3(out))) 57 | preact = torch.cat([x1, preact], 1) 58 | out = torch.cat([x1, out], 1) 59 | out = self.shuffle(out) 60 | if self.is_last: 61 | return out, preact 62 | else: 63 | return out 64 | 65 | 66 | class DownBlock(nn.Module): 67 | def __init__(self, in_channels, out_channels): 68 | super(DownBlock, self).__init__() 69 | mid_channels = out_channels // 2 70 | # left 71 | self.conv1 = nn.Conv2d( 72 | in_channels, 73 | in_channels, 74 | kernel_size=3, 75 | stride=2, 76 | padding=1, 77 | groups=in_channels, 78 | bias=False, 79 | ) 80 | self.bn1 = nn.BatchNorm2d(in_channels) 81 | self.conv2 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) 82 | self.bn2 = nn.BatchNorm2d(mid_channels) 83 | # right 84 | self.conv3 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) 85 | self.bn3 = nn.BatchNorm2d(mid_channels) 86 | self.conv4 = nn.Conv2d( 87 | mid_channels, 88 | mid_channels, 89 | kernel_size=3, 90 | stride=2, 91 | padding=1, 92 | groups=mid_channels, 93 | bias=False, 94 | ) 95 | self.bn4 = nn.BatchNorm2d(mid_channels) 96 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False) 97 | self.bn5 = nn.BatchNorm2d(mid_channels) 98 | 99 | self.shuffle = ShuffleBlock() 100 | 101 | def forward(self, x): 102 | # left 103 | out1 = self.bn1(self.conv1(x)) 104 | out1 = F.relu(self.bn2(self.conv2(out1))) 105 | # right 106 | out2 = F.relu(self.bn3(self.conv3(x))) 107 | out2 = self.bn4(self.conv4(out2)) 108 | out2 = F.relu(self.bn5(self.conv5(out2))) 109 | # concat 110 | out = torch.cat([out1, out2], 1) 111 | out = self.shuffle(out) 112 | return out 113 |
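Before the full network below, a quick shape check makes the channel bookkeeping of the two blocks above concrete: `DownBlock` halves the spatial resolution on both branches and concatenates them to reach `out_channels`, while `BasicBlock` leaves half of its channels untouched and transforms the other half at constant width. The sketch is illustrative only and assumes the repository root is on `sys.path`; 24 and 116 are the stem width and first-stage width of the `net_size=1` configuration defined later in this file.

```python
# Shape check for the ShuffleNetV2 building blocks (assumes the repo is importable).
import torch
from mdistiller.models.cifar.ShuffleNetv2 import BasicBlock, DownBlock

x = torch.randn(2, 24, 32, 32)
down = DownBlock(24, 116)    # each branch outputs 116 // 2 = 58 channels at half resolution
print(down(x).shape)         # torch.Size([2, 116, 16, 16])
block = BasicBlock(116)      # splits 58/58, convolves one half, concatenates, shuffles
print(block(down(x)).shape)  # torch.Size([2, 116, 16, 16])
```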
114 | 115 | class ShuffleNetV2(nn.Module): 116 | def __init__(self, net_size, num_classes=10): 117 | super(ShuffleNetV2, self).__init__() 118 | out_channels = configs[net_size]["out_channels"] 119 | num_blocks = configs[net_size]["num_blocks"] 120 | 121 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 122 | # stride=1, padding=1, bias=False) 123 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 124 | self.bn1 = nn.BatchNorm2d(24) 125 | self.in_channels = 24 126 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 127 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 128 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 129 | self.conv2 = nn.Conv2d( 130 | out_channels[2], 131 | out_channels[3], 132 | kernel_size=1, 133 | stride=1, 134 | padding=0, 135 | bias=False, 136 | ) 137 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 138 | self.linear = nn.Linear(out_channels[3], num_classes) 139 | self.stage_channels = out_channels 140 | 141 | def _make_layer(self, out_channels, num_blocks): 142 | layers = [DownBlock(self.in_channels, out_channels)] 143 | for i in range(num_blocks): 144 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 145 | self.in_channels = out_channels 146 | return nn.Sequential(*layers) 147 | 148 | def get_feat_modules(self): 149 | feat_m = nn.ModuleList([]) 150 | feat_m.append(self.conv1) 151 | feat_m.append(self.bn1) 152 | feat_m.append(self.layer1) 153 | feat_m.append(self.layer2) 154 | feat_m.append(self.layer3) 155 | return feat_m 156 | 157 | def get_bn_before_relu(self): 158 | raise NotImplementedError( 159 | 'ShuffleNetV2 currently is not supported for "Overhaul" teacher' 160 | ) 161 | 162 | def get_stage_channels(self): 163 | return [24] + list(self.stage_channels[:-1]) 164 | 165 | def forward(self, x): 166 | out = F.relu(self.bn1(self.conv1(x))) 167 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 168 | f0 = out 169 | out, f1_pre = self.layer1(out) 170 | f1 = out 171 | out, f2_pre = self.layer2(out) 172 | f2 = out 173 | out, f3_pre = self.layer3(out) 174 | f3 = out 175 | out = F.relu(self.bn2(self.conv2(out))) 176 | out = F.avg_pool2d(out, 4) 177 | out = out.reshape(out.size(0), -1) 178 | avg = out 179 | f4 = out 180 | out = self.linear(out) 181 | 182 | feats = {} 183 | feats["feats"] = [f0, f1, f2, f3] 184 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre] 185 | feats["pooled_feat"] = f4 186 | 187 | return out, feats 188 | 189 | 190 | configs = { 191 | 0.2: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 3, 3)}, 192 | 0.3: {"out_channels": (40, 80, 160, 512), "num_blocks": (3, 7, 3)}, 193 | 0.5: {"out_channels": (48, 96, 192, 1024), "num_blocks": (3, 7, 3)}, 194 | 1: {"out_channels": (116, 232, 464, 1024), "num_blocks": (3, 7, 3)}, 195 | 1.5: {"out_channels": (176, 352, 704, 1024), "num_blocks": (3, 7, 3)}, 196 | 2: {"out_channels": (224, 488, 976, 2048), "num_blocks": (3, 7, 3)}, 197 | } 198 | 199 | 200 | def ShuffleV2(**kwargs): 201 | model = ShuffleNetV2(net_size=1, **kwargs) 202 | return model 203 | 204 | 205 | if __name__ == "__main__": 206 | net = ShuffleV2(num_classes=100) 207 | x = torch.randn(3, 3, 32, 32) 208 | import time 209 | 210 | a = time.time() 211 | logit, feats = net(x) 212 | b = time.time() 213 | print(b - a) 214 | for f in feats["feats"]: 215 | print(f.shape, f.min().item()) 216 | print(logit.shape) 217 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/__init__.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from .resnet import ( 3 | resnet8, 4 | resnet14, 5 | resnet20, 6 | resnet32, 7 | resnet44, 8 | resnet56, 9 | resnet110, 10 | resnet8x4, 11 | resnet32x4, 12 | ) 13 | from .resnetv2 import ResNet50, ResNet18 14 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2, wrn_16_4, wrn_28_2, wrn_28_4 15 | from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn 16 | from .mobilenetv2 import mobile_half 17 | from .ShuffleNetv1 import ShuffleV1 18 | from .ShuffleNetv2 import ShuffleV2 19 | 20 | 21 | cifar100_model_prefix = os.path.join( 22 | os.path.dirname(os.path.abspath(__file__)), 23 | "../../../download_ckpts/cifar_teachers/" 24 | ) 25 | cifar_model_dict = { 26 | # teachers 27 | "resnet56": ( 28 | resnet56, 29 | cifar100_model_prefix + "resnet56_vanilla/ckpt_epoch_240.pth", 30 | ), 31 | "resnet110": ( 32 | resnet110, 33 | cifar100_model_prefix + "resnet110_vanilla/ckpt_epoch_240.pth", 34 | ), 35 | "resnet32x4": ( 36 | resnet32x4, 37 | cifar100_model_prefix + "resnet32x4_vanilla/ckpt_epoch_240.pth", 38 | ), 39 | "ResNet50": ( 40 | ResNet50, 41 | cifar100_model_prefix + "ResNet50_vanilla/ckpt_epoch_240.pth", 42 | ), 43 | "wrn_40_2": ( 44 | wrn_40_2, 45 | cifar100_model_prefix + "wrn_40_2_vanilla/ckpt_epoch_240.pth", 46 | ), 47 | "wrn_16_4": (wrn_16_4, cifar100_model_prefix + "wrn_16_4_vanilla/student_best"), 48 | "wrn_28_2": (wrn_28_2, cifar100_model_prefix + "wrn_28_2_vanilla/student_best"), 49 | "wrn_28_4": (wrn_28_4, cifar100_model_prefix + "wrn_28_4_vanilla/student_best"), 50 | "vgg13": (vgg13_bn, cifar100_model_prefix + "vgg13_vanilla/ckpt_epoch_240.pth"), 51 | # students 52 | "resnet8": (resnet8, None), 53 | "resnet14": (resnet14, None), 54 | "resnet20": (resnet20, None), 55 | "resnet32": (resnet32, None), 56 | "resnet44": (resnet44, None), 57 | "resnet8x4": (resnet8x4, None), 58 | "ResNet18": (ResNet18, None), 59 | "wrn_16_1": (wrn_16_1, None), 60 | "wrn_16_2": (wrn_16_2, None), 61 | "wrn_40_1": (wrn_40_1, None), 62 | "vgg8": (vgg8_bn, None), 63 | "vgg11": (vgg11_bn, None), 64 | "vgg16": (vgg16_bn, None), 65 | "vgg19": (vgg19_bn, None), 66 | "MobileNetV2": (mobile_half, None), 67 | "ShuffleV1": (ShuffleV1, None), 68 | "ShuffleV2": (ShuffleV2, None), 69 | } 70 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | __all__ = ["mobilenetv2_T_w", "mobile_half"] 6 | 7 | BN = None 8 | 9 | 10 | def conv_bn(inp, oup, stride): 11 | return nn.Sequential( 12 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 13 | nn.BatchNorm2d(oup), 14 | nn.ReLU(inplace=True), 15 | ) 16 | 17 | 18 | def conv_1x1_bn(inp, oup): 19 | return nn.Sequential( 20 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 21 | nn.BatchNorm2d(oup), 22 | nn.ReLU(inplace=True), 23 | ) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.blockname = None 30 | 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | self.conv = nn.Sequential( 37 | # pw 38 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU(inplace=True), 41 | # dw 42 | nn.Conv2d( 43 | inp * expand_ratio, 44 | inp * 
expand_ratio, 45 | 3, 46 | stride, 47 | 1, 48 | groups=inp * expand_ratio, 49 | bias=False, 50 | ), 51 | nn.BatchNorm2d(inp * expand_ratio), 52 | nn.ReLU(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 55 | nn.BatchNorm2d(oup), 56 | ) 57 | self.names = ["0", "1", "2", "3", "4", "5", "6", "7"] 58 | 59 | def forward(self, x): 60 | t = x 61 | if self.use_res_connect: 62 | return t + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | """mobilenetV2""" 69 | 70 | def __init__(self, T, feature_dim, input_size=32, width_mult=1.0, remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | # building classifier 108 | self.classifier = nn.Sequential( 109 | # nn.Dropout(0.5), 110 | nn.Linear(self.last_channel, feature_dim), 111 | ) 112 | 113 | H = input_size // (32 // 2) 114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 115 | 116 | self._initialize_weights() 117 | print(T, width_mult) 118 | self.stage_channels = [32, 24, 32, 96, 320] 119 | self.stage_channels = [int(c * width_mult) for c in self.stage_channels] 120 | 121 | def get_bn_before_relu(self): 122 | bn1 = self.blocks[1][-1].conv[-1] 123 | bn2 = self.blocks[2][-1].conv[-1] 124 | bn3 = self.blocks[4][-1].conv[-1] 125 | bn4 = self.blocks[6][-1].conv[-1] 126 | return [bn1, bn2, bn3, bn4] 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.blocks) 132 | return feat_m 133 | 134 | def get_stage_channels(self): 135 | return self.stage_channels 136 | 137 | def forward(self, x): 138 | out = self.conv1(x) 139 | f0 = out 140 | out = self.blocks[0](out) 141 | out = self.blocks[1](out) 142 | f1 = out 143 | out = self.blocks[2](out) 144 | f2 = out 145 | out = self.blocks[3](out) 146 | out = self.blocks[4](out) 147 | f3 = out 148 | out = self.blocks[5](out) 149 | out = self.blocks[6](out) 150 | f4 = out 151 | out = self.conv2(out) 152 | 153 | if not self.remove_avg: 154 | out = self.avgpool(out) 155 | out = out.reshape(out.size(0), -1) 156 | avg = out 157 | out = self.classifier(out) 158 | 159 | feats = {} 160 | feats["feats"] = [f0, f1, f2, f3, f4] 161 | feats["pooled_feat"] = avg 162 | 163 | return out, feats 164 | 165 | def _initialize_weights(self): 166 | for m in self.modules(): 167 | if isinstance(m, nn.Conv2d): 168 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 169 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 170 | if m.bias is not 
None: 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | elif isinstance(m, nn.Linear): 176 | n = m.weight.size(1) 177 | m.weight.data.normal_(0, 0.01) 178 | m.bias.data.zero_() 179 | 180 | 181 | def mobilenetv2_T_w(T, W, feature_dim=100): 182 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 183 | return model 184 | 185 | 186 | def mobile_half(num_classes): 187 | return mobilenetv2_T_w(6, 0.5, num_classes) 188 | 189 | 190 | if __name__ == "__main__": 191 | x = torch.randn(2, 3, 32, 32) 192 | 193 | net = mobile_half(100) 194 | 195 | logit, feats = net(x) 196 | for f in feats["feats"]: 197 | print(f.shape, f.min().item()) 198 | print(logit.shape) 199 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/resnetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1, is_last=False): 10 | super(BasicBlock, self).__init__() 11 | self.is_last = is_last 12 | self.conv1 = nn.Conv2d( 13 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 14 | ) 15 | self.bn1 = nn.BatchNorm2d(planes) 16 | self.conv2 = nn.Conv2d( 17 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 18 | ) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion * planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d( 25 | in_planes, 26 | self.expansion * planes, 27 | kernel_size=1, 28 | stride=stride, 29 | bias=False, 30 | ), 31 | nn.BatchNorm2d(self.expansion * planes), 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | preact = out 39 | out = F.relu(out) 40 | if self.is_last: 41 | return out, preact 42 | else: 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, in_planes, planes, stride=1, is_last=False): 50 | super(Bottleneck, self).__init__() 51 | self.is_last = is_last 52 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d( 55 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 56 | ) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d( 59 | planes, self.expansion * planes, kernel_size=1, bias=False 60 | ) 61 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 62 | 63 | self.shortcut = nn.Sequential() 64 | if stride != 1 or in_planes != self.expansion * planes: 65 | self.shortcut = nn.Sequential( 66 | nn.Conv2d( 67 | in_planes, 68 | self.expansion * planes, 69 | kernel_size=1, 70 | stride=stride, 71 | bias=False, 72 | ), 73 | nn.BatchNorm2d(self.expansion * planes), 74 | ) 75 | 76 | def forward(self, x): 77 | out = F.relu(self.bn1(self.conv1(x))) 78 | out = F.relu(self.bn2(self.conv2(out))) 79 | out = self.bn3(self.conv3(out)) 80 | out += self.shortcut(x) 81 | preact = out 82 | out = F.relu(out) 83 | if self.is_last: 84 | return out, preact 85 | else: 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 91 | super(ResNet, self).__init__() 92 | self.in_planes = 64 93 | 94 | self.conv1 = 
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 95 | self.bn1 = nn.BatchNorm2d(64) 96 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 97 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 98 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 99 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | self.linear = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 106 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | # Zero-initialize the last BN in each residual branch, 111 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 112 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 113 | if zero_init_residual: 114 | for m in self.modules(): 115 | if isinstance(m, Bottleneck): 116 | nn.init.constant_(m.bn3.weight, 0) 117 | elif isinstance(m, BasicBlock): 118 | nn.init.constant_(m.bn2.weight, 0) 119 | self.stage_channels = [64 * block.expansion, 128 * block.expansion, 256 * block.expansion, 512 * block.expansion]  # correct for both BasicBlock and Bottleneck variants 120 | 121 | def get_feat_modules(self): 122 | feat_m = nn.ModuleList([]) 123 | feat_m.append(self.conv1) 124 | feat_m.append(self.bn1) 125 | feat_m.append(self.layer1) 126 | feat_m.append(self.layer2) 127 | feat_m.append(self.layer3) 128 | feat_m.append(self.layer4) 129 | return feat_m 130 | 131 | def get_bn_before_relu(self): 132 | if isinstance(self.layer1[0], Bottleneck): 133 | bn1 = self.layer1[-1].bn3 134 | bn2 = self.layer2[-1].bn3 135 | bn3 = self.layer3[-1].bn3 136 | bn4 = self.layer4[-1].bn3 137 | elif isinstance(self.layer1[0], BasicBlock): 138 | bn1 = self.layer1[-1].bn2 139 | bn2 = self.layer2[-1].bn2 140 | bn3 = self.layer3[-1].bn2 141 | bn4 = self.layer4[-1].bn2 142 | else: 143 | raise NotImplementedError("ResNet unknown block error !!!") 144 | 145 | return [bn1, bn2, bn3, bn4] 146 | 147 | def get_stage_channels(self): 148 | return self.stage_channels 149 | 150 | def _make_layer(self, block, planes, num_blocks, stride): 151 | strides = [stride] + [1] * (num_blocks - 1) 152 | layers = [] 153 | for i in range(num_blocks): 154 | stride = strides[i] 155 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 156 | self.in_planes = planes * block.expansion 157 | return nn.Sequential(*layers) 158 | 159 | def encode(self, x, idx, preact=False): 160 | if idx == -1: 161 | out, pre = self.layer4(F.relu(x)) 162 | elif idx == -2: 163 | out, pre = self.layer3(F.relu(x)) 164 | elif idx == -3: 165 | out, pre = self.layer2(F.relu(x)) 166 | else: 167 | raise NotImplementedError() 168 | return pre 169 | 170 | def forward(self, x): 171 | out = F.relu(self.bn1(self.conv1(x))) 172 | f0 = out 173 | out, f1_pre = self.layer1(out) 174 | f1 = out 175 | out, f2_pre = self.layer2(out) 176 | f2 = out 177 | out, f3_pre = self.layer3(out) 178 | f3 = out 179 | out, f4_pre = self.layer4(out) 180 | f4 = out 181 | out = self.avgpool(out) 182 | avg = out.reshape(out.size(0), -1) 183 | out = self.linear(avg) 184 | 185 | feats = {} 186 | feats["feats"] = [f0, f1, f2, f3, f4] 187 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre, f4_pre] 188 | feats["pooled_feat"] = avg 189 | 190 | return out, feats 191 | 192 | 193 | def ResNet18(**kwargs): 194 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 195
| 196 | 197 | def ResNet34(**kwargs): 198 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 199 | 200 | 201 | def ResNet50(**kwargs): 202 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 203 | 204 | 205 | def ResNet101(**kwargs): 206 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 207 | 208 | 209 | def ResNet152(**kwargs): 210 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | 212 | 213 | if __name__ == "__main__": 214 | net = ResNet18(num_classes=100) 215 | x = torch.randn(2, 3, 32, 32) 216 | logit, feats = net(x) 217 | 218 | for f in feats["feats"]: 219 | print(f.shape, f.min().item()) 220 | print(logit.shape) 221 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | __all__ = [ 7 | "VGG", 8 | "vgg11", 9 | "vgg11_bn", 10 | "vgg13", 11 | "vgg13_bn", 12 | "vgg16", 13 | "vgg16_bn", 14 | "vgg19_bn", 15 | "vgg19", 16 | ] 17 | 18 | 19 | model_urls = { 20 | "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", 21 | "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth", 22 | "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", 23 | "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", 24 | } 25 | 26 | 27 | class VGG(nn.Module): 28 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 29 | super(VGG, self).__init__() 30 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 31 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 32 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 33 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 34 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 35 | 36 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 37 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 38 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 39 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 40 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 41 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 42 | 43 | self.classifier = nn.Linear(512, num_classes) 44 | self._initialize_weights() 45 | self.stage_channels = [c[-1] for c in cfg] 46 | 47 | def get_feat_modules(self): 48 | feat_m = nn.ModuleList([]) 49 | feat_m.append(self.block0) 50 | feat_m.append(self.pool0) 51 | feat_m.append(self.block1) 52 | feat_m.append(self.pool1) 53 | feat_m.append(self.block2) 54 | feat_m.append(self.pool2) 55 | feat_m.append(self.block3) 56 | feat_m.append(self.pool3) 57 | feat_m.append(self.block4) 58 | feat_m.append(self.pool4) 59 | return feat_m 60 | 61 | def get_bn_before_relu(self): 62 | bn1 = self.block1[-1] 63 | bn2 = self.block2[-1] 64 | bn3 = self.block3[-1] 65 | bn4 = self.block4[-1] 66 | return [bn1, bn2, bn3, bn4] 67 | 68 | def get_stage_channels(self): 69 | return self.stage_channels 70 | 71 | def forward(self, x): 72 | h = x.shape[2] 73 | x = F.relu(self.block0(x)) 74 | f0 = x 75 | x = self.pool0(x) 76 | x = self.block1(x) 77 | f1_pre = x 78 | x = F.relu(x) 79 | f1 = x 80 | x = self.pool1(x) 81 | x = self.block2(x) 82 | f2_pre = x 83 | x = F.relu(x) 84 | f2 = x 85 | x = self.pool2(x) 86 | x = self.block3(x) 87 | f3_pre = x 88 | x = F.relu(x) 89 | f3 = x 90 | if h == 64: 91 | x = self.pool3(x) 92 | x = self.block4(x) 93 | f4_pre = x 94 | x = F.relu(x) 95 | f4 = x 96 | x = self.pool4(x) 97 | x = x.reshape(x.size(0), -1) 98 
| f5 = x 99 | x = self.classifier(x) 100 | 101 | feats = {} 102 | feats["feats"] = [f0, f1, f2, f3, f4] 103 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre, f4_pre] 104 | feats["pooled_feat"] = f5 105 | 106 | return x, feats 107 | 108 | @staticmethod 109 | def _make_layers(cfg, batch_norm=False, in_channels=3): 110 | layers = [] 111 | for v in cfg: 112 | if v == "M": 113 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 114 | else: 115 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 116 | if batch_norm: 117 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 118 | else: 119 | layers += [conv2d, nn.ReLU(inplace=True)] 120 | in_channels = v 121 | layers = layers[:-1] 122 | return nn.Sequential(*layers) 123 | 124 | def _initialize_weights(self): 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 129 | if m.bias is not None: 130 | m.bias.data.zero_() 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.Linear): 135 | n = m.weight.size(1) 136 | m.weight.data.normal_(0, 0.01) 137 | m.bias.data.zero_() 138 | 139 | 140 | cfg = { 141 | "A": [[64], [128], [256, 256], [512, 512], [512, 512]], 142 | "B": [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 143 | "D": [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 144 | "E": [ 145 | [64, 64], 146 | [128, 128], 147 | [256, 256, 256, 256], 148 | [512, 512, 512, 512], 149 | [512, 512, 512, 512], 150 | ], 151 | "S": [[64], [128], [256], [512], [512]], 152 | } 153 | 154 | 155 | def vgg8(**kwargs): 156 | """VGG 8-layer model (configuration "S") 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = VGG(cfg["S"], **kwargs) 161 | return model 162 | 163 | 164 | def vgg8_bn(**kwargs): 165 | """VGG 8-layer model (configuration "S") 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = VGG(cfg["S"], batch_norm=True, **kwargs) 170 | return model 171 | 172 | 173 | def vgg11(**kwargs): 174 | """VGG 11-layer model (configuration "A") 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = VGG(cfg["A"], **kwargs) 179 | return model 180 | 181 | 182 | def vgg11_bn(**kwargs): 183 | """VGG 11-layer model (configuration "A") with batch normalization""" 184 | model = VGG(cfg["A"], batch_norm=True, **kwargs) 185 | return model 186 | 187 | 188 | def vgg13(**kwargs): 189 | """VGG 13-layer model (configuration "B") 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = VGG(cfg["B"], **kwargs) 194 | return model 195 | 196 | 197 | def vgg13_bn(**kwargs): 198 | """VGG 13-layer model (configuration "B") with batch normalization""" 199 | model = VGG(cfg["B"], batch_norm=True, **kwargs) 200 | return model 201 | 202 | 203 | def vgg16(**kwargs): 204 | """VGG 16-layer model (configuration "D") 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = VGG(cfg["D"], **kwargs) 209 | return model 210 | 211 | 212 | def vgg16_bn(**kwargs): 213 | """VGG 16-layer model (configuration "D") with batch normalization""" 214 | model = VGG(cfg["D"], batch_norm=True, **kwargs) 215 | return model 216 | 217 | 218 | def vgg19(**kwargs): 219 | """VGG 19-layer model 
(configuration "E") 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = VGG(cfg["E"], **kwargs) 224 | return model 225 | 226 | 227 | def vgg19_bn(**kwargs): 228 | """VGG 19-layer model (configuration 'E') with batch normalization""" 229 | model = VGG(cfg["E"], batch_norm=True, **kwargs) 230 | return model 231 | 232 | 233 | if __name__ == "__main__": 234 | import torch 235 | 236 | x = torch.randn(2, 3, 32, 32) 237 | net = vgg19_bn(num_classes=100) 238 | logit, feats = net(x) 239 | 240 | for f in feats["feats"]: 241 | print(f.shape, f.min().item()) 242 | print(logit.shape) 243 | -------------------------------------------------------------------------------- /mdistiller/models/cifar/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | __all__ = ["wrn", "wrn_16_1", "wrn_16_2", "wrn_16_4", "wrn_28_2", "wrn_28_4", "wrn_40_1", "wrn_40_2"] 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 12 | super(BasicBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.conv1 = nn.Conv2d( 16 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 17 | ) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | self.relu2 = nn.ReLU(inplace=True) 20 | self.conv2 = nn.Conv2d( 21 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False 22 | ) 23 | self.droprate = dropRate 24 | self.equalInOut = in_planes == out_planes 25 | self.convShortcut = ( 26 | (not self.equalInOut) 27 | and nn.Conv2d( 28 | in_planes, 29 | out_planes, 30 | kernel_size=1, 31 | stride=stride, 32 | padding=0, 33 | bias=False, 34 | ) 35 | or None 36 | ) 37 | 38 | def forward(self, x): 39 | if not self.equalInOut: 40 | x = self.relu1(self.bn1(x)) 41 | else: 42 | out = self.relu1(self.bn1(x)) 43 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 44 | if self.droprate > 0: 45 | out = F.dropout(out, p=self.droprate, training=self.training) 46 | out = self.conv2(out) 47 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 48 | 49 | 50 | class NetworkBlock(nn.Module): 51 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 52 | super(NetworkBlock, self).__init__() 53 | self.layer = self._make_layer( 54 | block, in_planes, out_planes, nb_layers, stride, dropRate 55 | ) 56 | 57 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 58 | layers = [] 59 | for i in range(nb_layers): 60 | layers.append( 61 | block( 62 | i == 0 and in_planes or out_planes, 63 | out_planes, 64 | i == 0 and stride or 1, 65 | dropRate, 66 | ) 67 | ) 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | return self.layer(x) 72 | 73 | 74 | class WideResNet(nn.Module): 75 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 76 | super(WideResNet, self).__init__() 77 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 78 | assert (depth - 4) % 6 == 0, "depth should be 6n+4" 79 | n = (depth - 4) // 6 80 | block = BasicBlock 81 | # 1st conv before any network block 82 | self.conv1 = nn.Conv2d( 83 | 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False 84 | ) 85 | # 1st block 86 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 87 | # 2nd block 88 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2,
dropRate) 89 | # 3rd block 90 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 91 | # global average pooling and classifier 92 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.fc = nn.Linear(nChannels[3], num_classes) 95 | self.nChannels = nChannels[3] 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 101 | elif isinstance(m, nn.BatchNorm2d): 102 | m.weight.data.fill_(1) 103 | m.bias.data.zero_() 104 | elif isinstance(m, nn.Linear): 105 | m.bias.data.zero_() 106 | self.stage_channels = nChannels 107 | 108 | def get_feat_modules(self): 109 | feat_m = nn.ModuleList([]) 110 | feat_m.append(self.conv1) 111 | feat_m.append(self.block1) 112 | feat_m.append(self.block2) 113 | feat_m.append(self.block3) 114 | return feat_m 115 | 116 | def get_bn_before_relu(self): 117 | bn1 = self.block2.layer[0].bn1 118 | bn2 = self.block3.layer[0].bn1 119 | bn3 = self.bn1 120 | 121 | return [bn1, bn2, bn3] 122 | 123 | def get_stage_channels(self): 124 | return self.stage_channels 125 | 126 | def forward(self, x): 127 | out = self.conv1(x) 128 | f0 = out 129 | out = self.block1(out) 130 | f1 = out 131 | out = self.block2(out) 132 | f2 = out 133 | out = self.block3(out) 134 | f3 = out 135 | out = self.relu(self.bn1(out)) 136 | out = F.avg_pool2d(out, 8) 137 | out = out.reshape(-1, self.nChannels) 138 | f4 = out 139 | out = self.fc(out) 140 | 141 | f1_pre = self.block2.layer[0].bn1(f1) 142 | f2_pre = self.block3.layer[0].bn1(f2) 143 | f3_pre = self.bn1(f3) 144 | 145 | feats = {} 146 | feats["feats"] = [f0, f1, f2, f3] 147 | feats["preact_feats"] = [f0, f1_pre, f2_pre, f3_pre] 148 | feats["pooled_feat"] = f4 149 | 150 | return out, feats 151 | 152 | 153 | def wrn(**kwargs): 154 | """ 155 | Constructs a Wide Residual Networks. 
156 | """ 157 | model = WideResNet(**kwargs) 158 | return model 159 | 160 | 161 | def wrn_40_2(**kwargs): 162 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 163 | return model 164 | 165 | 166 | def wrn_40_1(**kwargs): 167 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 168 | return model 169 | 170 | 171 | def wrn_16_2(**kwargs): 172 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 173 | return model 174 | 175 | 176 | def wrn_16_1(**kwargs): 177 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 178 | return model 179 | 180 | 181 | def wrn_16_4(**kwargs): 182 | model = WideResNet(depth=16, widen_factor=4, **kwargs) 183 | return model 184 | 185 | 186 | def wrn_28_2(**kwargs): 187 | model = WideResNet(depth=28, widen_factor=2, **kwargs) 188 | return model 189 | 190 | 191 | def wrn_28_4(**kwargs): 192 | model = WideResNet(depth=28, widen_factor=4, **kwargs) 193 | return model 194 | 195 | 196 | if __name__ == "__main__": 197 | import torch 198 | 199 | x = torch.randn(2, 3, 32, 32) 200 | net = wrn_40_2(num_classes=100) 201 | logit, feats = net(x) 202 | 203 | for f in feats["feats"]: 204 | print(f.shape, f.min().item()) 205 | print(logit.shape) 206 | -------------------------------------------------------------------------------- /mdistiller/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 2 | from .mobilenetv1 import MobileNetV1 3 | 4 | 5 | imagenet_model_dict = { 6 | "ResNet18": resnet18, 7 | "ResNet34": resnet34, 8 | "ResNet50": resnet50, 9 | "ResNet101": resnet101, 10 | "MobileNetV1": MobileNetV1, 11 | } 12 | -------------------------------------------------------------------------------- /mdistiller/models/imagenet/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MobileNetV1(nn.Module): 7 | def __init__(self, **kwargs): 8 | super(MobileNetV1, self).__init__() 9 | 10 | def conv_bn(inp, oup, stride): 11 | return nn.Sequential( 12 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 13 | nn.BatchNorm2d(oup), 14 | nn.ReLU(inplace=True), 15 | ) 16 | 17 | def conv_dw(inp, oup, stride): 18 | return nn.Sequential( 19 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 20 | nn.BatchNorm2d(inp), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 23 | nn.BatchNorm2d(oup), 24 | nn.ReLU(inplace=True), 25 | ) 26 | 27 | self.model = nn.Sequential( 28 | conv_bn(3, 32, 2), 29 | conv_dw(32, 64, 1), 30 | conv_dw(64, 128, 2), 31 | conv_dw(128, 128, 1), 32 | conv_dw(128, 256, 2), 33 | conv_dw(256, 256, 1), 34 | conv_dw(256, 512, 2), 35 | conv_dw(512, 512, 1), 36 | conv_dw(512, 512, 1), 37 | conv_dw(512, 512, 1), 38 | conv_dw(512, 512, 1), 39 | conv_dw(512, 512, 1), 40 | conv_dw(512, 1024, 2), 41 | conv_dw(1024, 1024, 1), 42 | nn.AvgPool2d(7), 43 | ) 44 | self.fc = nn.Linear(1024, 1000) 45 | 46 | def forward(self, x, is_feat=False): 47 | feat1 = self.model[3][:-1](self.model[0:3](x)) 48 | feat2 = self.model[5][:-1](self.model[4:5](F.relu(feat1))) 49 | feat3 = self.model[11][:-1](self.model[6:11](F.relu(feat2))) 50 | feat4 = self.model[13][:-1](self.model[12:13](F.relu(feat3))) 51 | feat5 = self.model[14](F.relu(feat4)) 52 | avg = feat5.reshape(-1, 1024) 53 | out = self.fc(avg) 54 | 55 | feats = {} 56 | feats["pooled_feat"] = avg 57 | feats["feats"] = [F.relu(feat1), 
F.relu(feat2), F.relu(feat3), F.relu(feat4)] 58 | feats["preact_feats"] = [feat1, feat2, feat3, feat4] 59 | return out, feats 60 | 61 | def get_bn_before_relu(self): 62 | bn1 = self.model[3][-2] 63 | bn2 = self.model[5][-2] 64 | bn3 = self.model[11][-2] 65 | bn4 = self.model[13][-2] 66 | return [bn1, bn2, bn3, bn4] 67 | 68 | def get_stage_channels(self): 69 | return [128, 256, 512, 1024] 70 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | cur_path = os.path.abspath(os.path.dirname(__file__)) 4 | root_path = os.path.split(cur_path)[0] 5 | sys.path.append(root_path) 6 | import argparse 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | 10 | cudnn.benchmark = True 11 | 12 | from mdistiller.distillers import Vanilla 13 | from mdistiller.models import cifar_model_dict, imagenet_model_dict 14 | from mdistiller.dataset import get_dataset 15 | from mdistiller.dataset.imagenet import get_imagenet_val_loader 16 | from mdistiller.engine.utils import load_checkpoint, validate 17 | from mdistiller.engine.cfg import CFG as cfg 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("-m", "--model", type=str, default="") 23 | parser.add_argument("-c", "--ckpt", type=str, default="pretrain") 24 | parser.add_argument( 25 | "-d", 26 | "--dataset", 27 | type=str, 28 | default="cifar100", 29 | choices=["cifar100", "imagenet"], 30 | ) 31 | parser.add_argument("-bs", "--batch-size", type=int, default=64) 32 | args = parser.parse_args() 33 | 34 | cfg.DATASET.TYPE = args.dataset 35 | cfg.DATASET.TEST.BATCH_SIZE = args.batch_size 36 | if args.dataset == "imagenet": 37 | val_loader = get_imagenet_val_loader(args.batch_size) 38 | if args.ckpt == "pretrain": 39 | model = imagenet_model_dict[args.model](pretrained=True) 40 | else: 41 | model = imagenet_model_dict[args.model](pretrained=False) 42 | model.load_state_dict(load_checkpoint(args.ckpt)["model"]) 43 | elif args.dataset == "cifar100": 44 | train_loader, val_loader, num_data, num_classes = get_dataset(cfg) 45 | model, pretrain_model_path = cifar_model_dict[args.model] 46 | model = model(num_classes=num_classes) 47 | ckpt = pretrain_model_path if args.ckpt == "pretrain" else args.ckpt 48 | model.load_state_dict(load_checkpoint(ckpt)["model"]) 49 | model = Vanilla(model) 50 | model = model.cuda() 51 | model = torch.nn.DataParallel(model) 52 | test_acc, test_acc_top5, test_loss = validate(val_loader, model) 53 | --------------------------------------------------------------------------------
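For completeness, the CIFAR-100 branch of `tools/eval.py` can be reproduced from a Python session in a few lines. The sketch below is illustrative only: it mirrors the script above and assumes the repository root is on `sys.path`, a CUDA device is available, and the pre-trained teachers were untarred to `./download_ckpts`, so that `"resnet32x4"` (one of the teacher keys in `mdistiller/models/cifar/__init__.py`) resolves to a valid checkpoint path.

```python
# Programmatic equivalent of: python tools/eval.py -m resnet32x4 -d cifar100 -c pretrain
# Sketch under the assumptions stated above; not an official entry point.
import torch

from mdistiller.distillers import Vanilla
from mdistiller.models import cifar_model_dict
from mdistiller.dataset import get_dataset
from mdistiller.engine.utils import load_checkpoint, validate
from mdistiller.engine.cfg import CFG as cfg

cfg.DATASET.TYPE = "cifar100"
_, val_loader, _, num_classes = get_dataset(cfg)

# Each cifar_model_dict entry is a (constructor, checkpoint path) pair;
# teacher entries point into download_ckpts, student entries carry None.
net_fn, ckpt_path = cifar_model_dict["resnet32x4"]
model = net_fn(num_classes=num_classes)
model.load_state_dict(load_checkpoint(ckpt_path)["model"])
model = torch.nn.DataParallel(Vanilla(model).cuda())

test_acc, test_acc_top5, test_loss = validate(val_loader, model)
print("top-1:", test_acc, "top-5:", test_acc_top5)
```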