├── result ├── placeholder.txt ├── train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0.txt ├── train_online_kd_cifar_arch_WRN_40_1_WRN_40_1_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_WRN_40_2_WRN_40_2_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_vgg13_vgg13_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1.txt ├── train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0.txt ├── train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0.txt └── train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0.txt ├── poster ├── 765_IJCAI_poster.pdf └── 765_IJCAI_slides.pdf ├── models ├── __pycache__ │ ├── vgg.cpython-38.pyc │ ├── wrn.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── utils.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── resnetv2.cpython-38.pyc │ ├── mobilenetv2.cpython-38.pyc │ ├── ShuffleNetv1.cpython-38.pyc │ ├── ShuffleNetv2.cpython-38.pyc │ └── resnet_imagenet.cpython-38.pyc ├── __init__.py ├── utils.py ├── mutual_kd_net.py ├── ShuffleNetv1.py ├── wrn.py ├── resnetv2.py ├── ShuffleNetv2.py ├── mobilenetv2.py ├── resnet.py ├── vgg.py └── resnet_imagenet.py ├── data └── preprocess_tinyimagenet.py ├── train_baseline_cifar.sh ├── train_teacher_cifar.sh ├── train_online_kd_cifar.sh ├── train_teacher_freezed.sh ├── utils.py ├── train_student_cifar.sh ├── LICENSE ├── eval_rep.py ├── README.md ├── train_baseline_cifar.py ├── train_teacher_cifar.py └── train_student_cifar.py /result/placeholder.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /poster/765_IJCAI_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/poster/765_IJCAI_poster.pdf -------------------------------------------------------------------------------- /poster/765_IJCAI_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/poster/765_IJCAI_slides.pdf -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/vgg.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/wrn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/wrn.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/utils.cpython-38.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/resnetv2.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/mobilenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/ShuffleNetv1.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/ShuffleNetv2.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/winycg/HSAKD/HEAD/models/__pycache__/resnet_imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .resnetv2 import * 3 | from .wrn import * 4 | from .vgg import * 5 | from .ShuffleNetv1 import * 6 | from .ShuffleNetv2 import * 7 | from .mobilenetv2 import * 8 | from .resnet_imagenet import * 9 | from .mutual_kd_net import * -------------------------------------------------------------------------------- /data/preprocess_tinyimagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system('mv ./tiny-imagenet-200/val ./tiny-imagenet-200/val_original') 4 | with open('./tiny-imagenet-200/val_original/val_annotations.txt') as txt: 5 | for line in txt: 6 | img_name = line.strip('\n').split('\t')[0] 7 | label_name = line.strip('\n').split('\t')[1] 8 | 9 | if not os.path.isdir('./tiny-imagenet-200/val/'+label_name): 10 | os.makedirs('./tiny-imagenet-200/val/'+label_name) 11 | os.system('cp ./tiny-imagenet-200/val_original/images/'+img_name+' ./tiny-imagenet-200/val/'+label_name)
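The script above rebuilds TinyImageNet's flat val/ split into the class-per-folder layout that loaders such as torchvision's ImageFolder expect, but it shells out to mv/cp and is therefore Unix-only. A portable sketch of the same preprocessing using only the standard library (illustrative, assuming the stock tiny-imagenet-200 layout; not part of the repository):

import os
import shutil

root = './tiny-imagenet-200'
# Keep the original flat split around, mirroring the mv above.
os.rename(os.path.join(root, 'val'), os.path.join(root, 'val_original'))

with open(os.path.join(root, 'val_original', 'val_annotations.txt')) as txt:
    for line in txt:
        # val_annotations.txt is tab-separated: image name, class id, then bbox coords.
        img_name, label_name = line.strip('\n').split('\t')[:2]
        class_dir = os.path.join(root, 'val', label_name)
        os.makedirs(class_dir, exist_ok=True)
        shutil.copy(os.path.join(root, 'val_original', 'images', img_name), class_dir)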
-------------------------------------------------------------------------------- /train_baseline_cifar.sh: -------------------------------------------------------------------------------- 1 | python train_baseline_cifar.py --arch ShuffleV1 --data ./data/ --init-lr 0.01 --gpu 0 2 | python train_baseline_cifar.py --arch ShuffleV2 --data ./data/ --init-lr 0.01 --gpu 0 3 | python train_baseline_cifar.py --arch mobilenetV2 --data ./data/ --init-lr 0.01 --gpu 0 4 | 5 | 6 | python train_baseline_cifar.py --arch wrn_16_2 --data ./data/ --gpu 0 7 | python train_baseline_cifar.py --arch wrn_40_1 --data ./data/ --gpu 0 8 | python train_baseline_cifar.py --arch resnet8x4 --data ./data/ --gpu 0 9 | python train_baseline_cifar.py --arch resnet20 --data ./data/ --gpu 0 10 | 11 | 12 | python train_baseline_cifar.py --arch mobilenetV2 --data /home/ycg/hhd/dataset/ --init-lr 0.01 --gpu 0 13 | python train_baseline_cifar.py --arch wrn_16_2 --data /home/ycg/hhd/dataset/ --gpu 0 -------------------------------------------------------------------------------- /train_teacher_cifar.sh: -------------------------------------------------------------------------------- 1 | python train_teacher_cifar.py \ 2 | --arch wrn_40_2_aux \ 3 | --checkpoint-dir ./checkpoint \ 4 | --data ./data \ 5 | --gpu 2 --manual 0 6 | 7 | python train_teacher_cifar.py \ 8 | --arch ResNet50_aux \ 9 | --checkpoint-dir ./checkpoint \ 10 | --data ./data \ 11 | --gpu 2 --manual 0 12 | 13 | python train_teacher_cifar.py \ 14 | --arch vgg13_bn_aux \ 15 | --checkpoint-dir ./checkpoint \ 16 | --data ./data \ 17 | --gpu 2 --manual 0 18 | 19 | python train_teacher_cifar.py \ 20 | --arch resnet56_aux \ 21 | --checkpoint-dir ./checkpoint \ 22 | --data ./data \ 23 | --gpu 2 --manual 0 24 | 25 | python train_teacher_cifar.py \ 26 | --arch resnet32x4_aux \ 27 | --checkpoint-dir ./checkpoint \ 28 | --data ./data \ 29 | --gpu 2 --manual 0 -------------------------------------------------------------------------------- /train_online_kd_cifar.sh: -------------------------------------------------------------------------------- 1 | python train_online_kd_cifar.py \ 2 | --arch WRN_40_2_WRN_40_2 \ 3 | --checkpoint-dir ./checkpoint 4 | 5 | python train_online_kd_cifar.py \ 6 | --arch WRN_40_1_WRN_40_1 \ 7 | --checkpoint-dir ./checkpoint 8 | 9 | python train_online_kd_cifar.py \ 10 | --arch resnet56_resnet56 \ 11 | --checkpoint-dir ./checkpoint 12 | 13 | 14 | python train_online_kd_cifar.py \ 15 | --arch resnet32x4_resnet32x4 \ 16 | --checkpoint-dir ./checkpoint 17 | 18 | python train_online_kd_cifar.py \ 19 | --arch vgg13_vgg13 \ 20 | --checkpoint-dir ./checkpoint 21 | 22 | 23 | python train_online_kd_cifar.py \ 24 | --arch mbv2_mbv2 \ 25 | --init-lr 0.01 \ 26 | --checkpoint-dir ./checkpoint 27 | 28 | 29 | python train_online_kd_cifar.py \ 30 | --arch shufflev1_shufflev1 \ 31 | --init-lr 0.01 \ 32 | --checkpoint-dir ./checkpoint 33 | 34 | python train_online_kd_cifar.py \ 35 | --arch shufflev2_shufflev2 \ 36 | --init-lr 0.01 \ 37 | --checkpoint-dir ./checkpoint 38 | -------------------------------------------------------------------------------- /result/train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='resnet18_resnet18', batch_size=2, checkpoint_dir='./checkpoint', data='/dev/shm/', dataset='imagenet', dist_backend='nccl', dist_url='tcp://127.0.0.1:2222', epochs=100, evaluate=False, gpu_id='0,1', init_lr=0.1, kd_T=3, log_dir='train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0', log_txt='result/train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0.txt', manual_seed=0, milestones=[30, 60, 90], multiprocessing_distributed=True, num_workers=16, rank=0, resume=False, traindir='/dev/shm/train', valdir='/dev/shm/val', warmup_epoch=0, weight_decay=0.0001, world_size=1) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='resnet18_resnet18', batch_size=2, 
checkpoint_dir='./checkpoint', data='/dev/shm/ImageNet/', dataset='imagenet', dist_backend='nccl', dist_url='tcp://127.0.0.1:2222', epochs=100, evaluate=False, gpu_id='0,1', init_lr=0.1, kd_T=3, log_dir='train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0', log_txt='result/train_online_kd_imagenet_arch_resnet18_resnet18_dataset_imagenet_seed0.txt', manual_seed=0, milestones=[30, 60, 90], multiprocessing_distributed=True, num_workers=16, rank=0, resume=False, traindir='/dev/shm/ImageNet/train', valdir='/dev/shm/ImageNet/val', warmup_epoch=0, weight_decay=0.0001, world_size=1) 6 | ========== 7 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_WRN_40_1_WRN_40_1_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='WRN_40_1_WRN_40_1', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_WRN_40_1_WRN_40_1_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_WRN_40_1_WRN_40_1_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_WRN_40_1_WRN_40_1_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | Epoch:0 lr:0.01000 duration:213.741 5 | train_loss:43.85046 train_loss_cls:43.64826 train_loss_div:0.20220 6 | Train Top-1 ss_accuracy: [[0.0077, 0.0081, 0.0087], [0.0069, 0.0084, 0.0087]] 7 | Train Top-1 class_accuracy: [[0.0314, 0.0334, 0.0326, 0.0434], [0.0301, 0.0338, 0.0325, 0.0498]] 8 | test epoch:0 test_loss_cls:7.95888 9 | Top-1 ss_accuracy: [[0.0133, 0.0136, 0.0144], [0.0115, 0.0152, 0.0168]] 10 | Top-1 class_accuracy: [[0.0476, 0.0573, 0.0535, 0.0779], [0.0463, 0.0556, 0.0582, 0.0839]] 11 | test epoch:0 test_loss_cls:7.95888 12 | Top-1 ss_accuracy: [[0.0133, 0.0136, 0.0144], [0.0115, 0.0152, 0.0168]] 13 | Top-1 class_accuracy: [[0.0476, 0.0573, 0.0535, 0.0779], [0.0463, 0.0556, 0.0582, 0.0839]] 14 | best_accuracy: 0.0839 15 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_WRN_40_2_WRN_40_2_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='WRN_40_2_WRN_40_2', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_WRN_40_2_WRN_40_2_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_WRN_40_2_WRN_40_2_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_WRN_40_2_WRN_40_2_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | Epoch:0 lr:0.01000 duration:323.147 5 | train_loss:42.88193 train_loss_cls:42.70993 train_loss_div:0.17200 6 | Train Top-1 ss_accuracy: [[0.0106, 0.0116, 0.0138], [0.0115, 0.0119, 0.0134]] 7 | Train Top-1 class_accuracy: [[0.043, 0.0449, 0.0499, 0.0661], [0.043, 0.0462, 0.0495, 0.0659]] 8 | test epoch:0 test_loss_cls:7.63474 9 | Top-1 ss_accuracy: [[0.0173, 0.0198, 0.0232], [0.0188, 0.0205, 0.0231]] 10 | Top-1 class_accuracy: [[0.0659, 0.0706, 0.0726, 0.1055], 
[0.0671, 0.0696, 0.0753, 0.0976]] 11 | test epoch:0 test_loss_cls:7.63474 12 | Top-1 ss_accuracy: [[0.0173, 0.0198, 0.0232], [0.0188, 0.0205, 0.0231]] 13 | Top-1 class_accuracy: [[0.0659, 0.0706, 0.0726, 0.1055], [0.0671, 0.0696, 0.0753, 0.0976]] 14 | best_accuracy: 0.1055 15 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='mbv2_mbv2', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | Epoch:0 lr:0.01000 duration:335.616 5 | train_loss:51.97935 train_loss_cls:51.51671 train_loss_div:0.46264 6 | Train Top-1 ss_accuracy: [[0.0307, 0.0321, 0.0339, 0.035], [0.0255, 0.0289, 0.0325, 0.0328]] 7 | Train Top-1 class_accuracy: [[0.0482, 0.0472, 0.0503, 0.0514, 0.0834], [0.0436, 0.0458, 0.0489, 0.0507, 0.0809]] 8 | test epoch:0 test_loss_cls:7.15583 9 | Top-1 ss_accuracy: [[0.0675, 0.0703, 0.0733, 0.0718], [0.0619, 0.0698, 0.0712, 0.073]] 10 | Top-1 class_accuracy: [[0.0792, 0.0809, 0.0842, 0.0856, 0.1432], [0.0736, 0.0813, 0.0837, 0.0859, 0.1471]] 11 | test epoch:0 test_loss_cls:7.15583 12 | Top-1 ss_accuracy: [[0.0675, 0.0703, 0.0733, 0.0718], [0.0619, 0.0698, 0.0712, 0.073]] 13 | Top-1 class_accuracy: [[0.0792, 0.0809, 0.0842, 0.0856, 0.1432], [0.0736, 0.0813, 0.0837, 0.0859, 0.1471]] 14 | best_accuracy: 0.1471 15 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_vgg13_vgg13_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='vgg13_vgg13', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_vgg13_vgg13_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_vgg13_vgg13_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_vgg13_vgg13_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | Epoch:0 lr:0.01000 duration:328.339 5 | train_loss:47.54712 train_loss_cls:47.17569 train_loss_div:0.37144 6 | Train Top-1 ss_accuracy: [[0.0723, 0.0707, 0.0689, 0.0554], [0.0712, 0.0725, 0.0684, 0.0544]] 7 | Train Top-1 class_accuracy: [[0.1073, 0.1076, 0.105, 0.0972, 0.1358], [0.108, 0.1085, 0.1067, 0.0967, 0.1347]] 8 | test epoch:0 test_loss_cls:6.37198 9 | Top-1 ss_accuracy: [[0.1212, 0.1165, 0.1168, 0.1013], [0.1273, 0.1228, 0.1222, 0.1049]] 10 | Top-1 class_accuracy: [[0.1454, 0.1431, 0.1458, 0.1465, 0.2097], [0.1505, 0.1495, 0.1487, 0.1491, 0.2144]] 11 | test epoch:0 test_loss_cls:6.37198 12 | Top-1 ss_accuracy: [[0.1212, 0.1165, 0.1168, 0.1013], [0.1273, 0.1228, 0.1222, 0.1049]] 13 | Top-1 class_accuracy: [[0.1454, 0.1431, 0.1458, 0.1465, 0.2097], [0.1505, 0.1495, 0.1487, 
0.1491, 0.2144]] 14 | best_accuracy: 0.2144 15 | -------------------------------------------------------------------------------- /train_teacher_freezed.sh: -------------------------------------------------------------------------------- 1 | python train_teacher_cifar.py \ 2 | --arch wrn_40_2_aux \ 3 | --milestones 30 60 90 --epochs 100 \ 4 | --checkpoint-dir ./checkpoint \ 5 | --data ./data \ 6 | --gpu 2 --manual 0 \ 7 | --pretrained-backbone ./pretrained_backbones/wrn_40_2.pth \ 8 | --freezed 9 | 10 | python train_teacher_cifar.py \ 11 | --arch resnet56_aux \ 12 | --milestones 30 60 90 --epochs 100 \ 13 | --checkpoint-dir ./checkpoint \ 14 | --data ./data \ 15 | --gpu 2 --manual 0 \ 16 | --pretrained-backbone ./pretrained_backbones/resnet56.pth \ 17 | --freezed 18 | 19 | python train_teacher_cifar.py \ 20 | --arch ResNet50_aux \ 21 | --milestones 30 60 90 --epochs 100 \ 22 | --checkpoint-dir ./checkpoint \ 23 | --data ./data \ 24 | --gpu 2 --manual 0 \ 25 | --pretrained-backbone ./pretrained_backbones/ResNet50.pth \ 26 | --freezed 27 | 28 | python train_teacher_cifar.py \ 29 | --arch vgg13_bn_aux \ 30 | --milestones 30 60 90 --epochs 100 \ 31 | --checkpoint-dir ./checkpoint \ 32 | --data ./data \ 33 | --gpu 2 --manual 0 \ 34 | --pretrained-backbone ./pretrained_backbones/vgg13_bn.pth \ 35 | --freezed 36 | 37 | python train_teacher_cifar.py \ 38 | --arch resnet32x4_aux \ 39 | --milestones 30 60 90 --epochs 100 \ 40 | --checkpoint-dir ./checkpoint \ 41 | --data ./data \ 42 | --gpu 2 --manual 0 \ 43 | --pretrained-backbone ./pretrained_backbones/resnet32x4.pth \ 44 | --freezed -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='mbv2_mbv2', batch_size=64, checkpoint_dir='/home/ycg/hhd/winycg/checkpoints/hsakd_checkpoints/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', data='/home/user/hhd/dataset/', dataset='cifar100', epochs=240, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', log_txt='result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1.txt', lr_type='multistep', manual_seed=1, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='mbv2_mbv2', batch_size=64, checkpoint_dir='/home/ycg/hhd/winycg/checkpoints/hsakd_checkpoints/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=240, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', log_txt='result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1.txt', lr_type='multistep', manual_seed=1, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 6 | ========== 7 | ========== 8 | Args:Namespace(arch='mbv2_mbv2', batch_size=64, checkpoint_dir='/home/ycg/hhd/winycg/checkpoints/hsakd_checkpoints/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=240, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1', log_txt='result/train_online_kd_cifar_arch_mbv2_mbv2_dataset_cifar100_seed1.txt', 
lr_type='multistep', manual_seed=1, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 9 | ========== 10 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='shufflev1_shufflev1', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='shufflev1_shufflev1', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_shufflev1_shufflev1_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 6 | ========== 7 | Epoch:0 lr:0.01000 duration:471.176 8 | train_loss:41.14524 train_loss_cls:39.77589 train_loss_div:1.36935 9 | Train Top-1 ss_accuracy: [[0.0482, 0.0515, 0.0502], [0.0493, 0.052, 0.0525]] 10 | Train Top-1 class_accuracy: [[0.0715, 0.0724, 0.0696, 0.1025], [0.0714, 0.0738, 0.0734, 0.1037]] 11 | test epoch:0 test_loss_cls:7.38271 12 | Top-1 ss_accuracy: [[0.0891, 0.0968, 0.0977], [0.0886, 0.0939, 0.0943]] 13 | Top-1 class_accuracy: [[0.1152, 0.1251, 0.1203, 0.1644], [0.1167, 0.116, 0.1113, 0.137]] 14 | test epoch:0 test_loss_cls:7.38271 15 | Top-1 ss_accuracy: [[0.0891, 0.0968, 0.0977], [0.0886, 0.0939, 0.0943]] 16 | Top-1 class_accuracy: [[0.1152, 0.1251, 0.1203, 0.1644], [0.1167, 0.116, 0.1113, 0.137]] 17 | best_accuracy: 0.1644 18 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='resnet32x4_resnet32x4', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='resnet32x4_resnet32x4', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, 
kd_T=3, log_dir='train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_resnet32x4_resnet32x4_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 6 | ========== 7 | Epoch:0 lr:0.01000 duration:692.951 8 | train_loss:39.45676 train_loss_cls:38.96839 train_loss_div:0.48837 9 | Train Top-1 ss_accuracy: [[0.0442, 0.0458, 0.0483], [0.0444, 0.0445, 0.0477]] 10 | Train Top-1 class_accuracy: [[0.0776, 0.0786, 0.0776, 0.1091], [0.0787, 0.0779, 0.0772, 0.1058]] 11 | test epoch:0 test_loss_cls:7.24091 12 | Top-1 ss_accuracy: [[0.0859, 0.0813, 0.0756], [0.0878, 0.0848, 0.0807]] 13 | Top-1 class_accuracy: [[0.116, 0.109, 0.0962, 0.1511], [0.1181, 0.1088, 0.0996, 0.163]] 14 | test epoch:0 test_loss_cls:7.24091 15 | Top-1 ss_accuracy: [[0.0859, 0.0813, 0.0756], [0.0878, 0.0848, 0.0807]] 16 | Top-1 class_accuracy: [[0.116, 0.109, 0.0962, 0.1511], [0.1181, 0.1088, 0.0996, 0.163]] 17 | best_accuracy: 0.163 18 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def cal_param_size(model): 4 | return sum([i.numel() for i in model.parameters()]) 5 | 6 | 7 | count_ops = 0 8 | def measure_layer(layer, x, multi_add=1): 9 | delta_ops = 0 10 | type_name = str(layer)[:str(layer).find('(')].strip() 11 | # print(type_name) 12 | if type_name in ['Conv2d']: 13 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) // 14 | layer.stride[0] + 1) 15 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) // 16 | layer.stride[1] + 1) 17 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 18 | layer.kernel_size[1] * out_h * out_w // layer.groups * multi_add 19 | 20 | elif type_name in ['Linear']: 21 | weight_ops = layer.weight.numel() * multi_add 22 | #bias_ops = layer.bias.numel() 23 | delta_ops = weight_ops + 0#bias_ops 24 | 25 | global count_ops 26 | count_ops += delta_ops 27 | return 28 | 29 | def is_leaf(module): 30 | return sum(1 for x in module.children()) == 0 31 | 32 | 33 | def should_measure(module): 34 | if str(module).startswith('Sequential'): 35 | return False 36 | if is_leaf(module): 37 | return True 38 | return False 39 | 40 | def cal_multi_adds(model, shape=(2,3,32,32)): 41 | global count_ops 42 | count_ops = 0 43 | data = torch.zeros(shape) 44 | 45 | def new_forward(m): 46 | def lambda_forward(x): 47 | measure_layer(m, x) 48 | return m.old_forward(x) 49 | return lambda_forward 50 | 51 | def modify_forward(model): 52 | for child in model.children(): 53 | if should_measure(child): 54 | child.old_forward = child.forward 55 | child.forward = new_forward(child) 56 | else: 57 | modify_forward(child) 58 | 59 | def restore_forward(model): 60 | for child in model.children(): 61 | if is_leaf(child) and hasattr(child, 'old_forward'): 62 | child.forward = child.old_forward 63 | child.old_forward = None 64 | else: 65 | restore_forward(child) 66 | 67 | modify_forward(model) 68 | model.forward(data) 69 | restore_forward(model) 70 | 71 | return count_ops 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | 
Args:Namespace(arch='resnet56_resnet56', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='resnet56_resnet56', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 6 | ========== 7 | ========== 8 | Args:Namespace(arch='resnet56_resnet56', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_resnet56_resnet56_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 9 | ========== 10 | Epoch:0 lr:0.01000 duration:294.369 11 | train_loss:43.33140 train_loss_cls:42.94350 train_loss_div:0.38790 12 | Train Top-1 ss_accuracy: [[0.0092, 0.0084, 0.0119], [0.0095, 0.0094, 0.0127]] 13 | Train Top-1 class_accuracy: [[0.033, 0.0319, 0.0367, 0.0465], [0.0351, 0.0345, 0.0379, 0.0482]] 14 | test epoch:0 test_loss_cls:7.83578 15 | Top-1 ss_accuracy: [[0.0217, 0.0182, 0.0276], [0.0219, 0.0203, 0.0309]] 16 | Top-1 class_accuracy: [[0.0632, 0.061, 0.0645, 0.0875], [0.066, 0.0651, 0.0686, 0.0874]] 17 | test epoch:0 test_loss_cls:7.83578 18 | Top-1 ss_accuracy: [[0.0217, 0.0182, 0.0276], [0.0219, 0.0203, 0.0309]] 19 | Top-1 class_accuracy: [[0.0632, 0.061, 0.0645, 0.0875], [0.066, 0.0651, 0.0686, 0.0874]] 20 | best_accuracy: 0.0875 21 | -------------------------------------------------------------------------------- /result/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0.txt: -------------------------------------------------------------------------------- 1 | ========== 2 | Args:Namespace(arch='shufflev2_shufflev2', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 3 | ========== 4 | ========== 5 | Args:Namespace(arch='shufflev2_shufflev2', batch_size=64, 
checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 6 | ========== 7 | ========== 8 | Args:Namespace(arch='shufflev2_shufflev2', batch_size=64, checkpoint_dir='./checkpoint/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', data='/home/ycg/hhd/dataset/', dataset='cifar100', epochs=1, evaluate=False, gpu_id='0', init_lr=0.01, kd_T=3, log_dir='train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0', log_txt='result/train_online_kd_cifar_arch_shufflev2_shufflev2_dataset_cifar100_seed0.txt', lr_type='multistep', manual_seed=0, milestones=[150, 180, 210], num_workers=8, resume=False, sgdr_t=300, warmup_epoch=0, weight_decay=0.0005) 9 | ========== 10 | Epoch:0 lr:0.01000 duration:378.254 11 | train_loss:41.96235 train_loss_cls:41.70390 train_loss_div:0.25844 12 | Train Top-1 ss_accuracy: [[0.0159, 0.0213, 0.0194], [0.012, 0.0183, 0.0184]] 13 | Train Top-1 class_accuracy: [[0.0444, 0.0504, 0.0495, 0.0711], [0.0417, 0.0502, 0.0492, 0.0687]] 14 | test epoch:0 test_loss_cls:7.30475 15 | Top-1 ss_accuracy: [[0.0456, 0.0568, 0.049], [0.0334, 0.0502, 0.0483]] 16 | Top-1 class_accuracy: [[0.0936, 0.1011, 0.0943, 0.1218], [0.0897, 0.0921, 0.0894, 0.1208]] 17 | test epoch:0 test_loss_cls:7.30475 18 | Top-1 ss_accuracy: [[0.0456, 0.0568, 0.049], [0.0334, 0.0502, 0.0483]] 19 | Top-1 class_accuracy: [[0.0936, 0.1011, 0.0943, 0.1218], [0.0897, 0.0921, 0.0894, 0.1208]] 20 | best_accuracy: 0.1218 21 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import math 5 | import os 6 | import sys 7 | import time 8 | import math 9 | 10 | import operator 11 | from functools import reduce 12 | import torch.nn as nn 13 | import torch 14 | import torch.nn.init as init 15 | 16 | 17 | def cal_param_size(model): 18 | return sum([i.numel() for i in model.parameters()]) 19 | 20 | 21 | count_ops = 0 22 | def measure_layer(layer, x, multi_add=1): 23 | delta_ops = 0 24 | type_name = str(layer)[:str(layer).find('(')].strip() 25 | 26 | if type_name in ['Conv2d']: 27 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) // 28 | layer.stride[0] + 1) 29 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) // 30 | layer.stride[1] + 1) 31 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 32 | layer.kernel_size[1] * out_h * out_w // layer.groups * multi_add 33 | 34 | ### ops_linear 35 | elif type_name in ['Linear']: 36 | weight_ops = layer.weight.numel() * multi_add 37 | bias_ops = 0 38 | delta_ops = weight_ops + bias_ops 39 | 40 | global count_ops 41 | count_ops += delta_ops 42 | return 43 | 44 | 45 | def is_leaf(module): 46 | return sum(1 for x in module.children()) == 0 47 | 48 | 49 | def should_measure(module): 50 | if str(module).startswith('Sequential'): 51 | return False 52 | if is_leaf(module): 53 | return True 54 | return False 55 | 56 | 57 | def 
cal_multi_adds(model, shape=(2,3,32,32)): 58 | global count_ops 59 | count_ops = 0 60 | data = torch.zeros(shape) 61 | 62 | def new_forward(m): 63 | def lambda_forward(x): 64 | measure_layer(m, x) 65 | return m.old_forward(x) 66 | return lambda_forward 67 | 68 | def modify_forward(model): 69 | for child in model.children(): 70 | if should_measure(child): 71 | child.old_forward = child.forward 72 | child.forward = new_forward(child) 73 | else: 74 | modify_forward(child) 75 | 76 | def restore_forward(model): 77 | for child in model.children(): 78 | if is_leaf(child) and hasattr(child, 'old_forward'): 79 | child.forward = child.old_forward 80 | child.old_forward = None 81 | else: 82 | restore_forward(child) 83 | 84 | modify_forward(model) 85 | model.forward(data) 86 | restore_forward(model) 87 | 88 | return count_ops 89 | 90 | -------------------------------------------------------------------------------- /train_student_cifar.sh: -------------------------------------------------------------------------------- 1 | 2 | # using the teacher network of the version of a frozen backbone 3 | 4 | python train_student_cifar.py \ 5 | --tarch wrn_40_2_aux \ 6 | --arch wrn_16_2_aux \ 7 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed0/wrn_40_2_aux.pth.tar \ 8 | --checkpoint-dir ./checkpoint \ 9 | --data ./data \ 10 | --gpu 0 --manual 0 11 | 12 | python train_student_cifar.py --tarch wrn_40_2_aux --arch wrn_40_1_aux \ 13 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed0/wrn_40_2_aux.pth.tar \ 14 | --checkpoint-dir ./checkpoint \ 15 | --data /home/ycg/hhd/dataset/ \ 16 | --gpu 2 --manual 0 17 | 18 | 19 | python train_student_cifar.py --tarch wrn_40_2_aux --arch ShuffleV1_aux \ 20 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed0/wrn_40_2_aux.pth.tar \ 21 | --checkpoint-dir ./checkpoint \ 22 | --data /home/ycg/hhd/dataset/ \ 23 | --init-lr 0.01 \ 24 | --gpu 2 --manual 0 25 | 26 | python train_student_cifar.py --tarch resnet32x4_aux --arch resnet8x4_aux \ 27 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet32x4_aux_dataset_cifar100_seed0/resnet32x4_aux.pth.tar \ 28 | --checkpoint-dir ./checkpoint \ 29 | --data /home/ycg/hhd/dataset/ \ 30 | --gpu 2 --manual 0 31 | 32 | python train_student_cifar.py --tarch resnet56_aux --arch resnet20_aux \ 33 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet56_aux_dataset_cifar100_seed0/resnet56_aux.pth.tar \ 34 | --checkpoint-dir ./checkpoint \ 35 | --data /home/ycg/hhd/dataset/ \ 36 | --gpu 2 --manual 0 37 | 38 | python train_student_cifar.py --tarch vgg13_bn_aux --arch mobilenetV2_aux \ 39 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_vgg13_bn_aux_dataset_cifar100_seed0/vgg13_bn_aux.pth.tar \ 40 | --checkpoint-dir ./checkpoint \ 41 | --data /home/ycg/hhd/dataset/ \ 42 | --init-lr 0.01 \ 43 | --gpu 2 --manual 0 44 | 45 | python train_student_cifar.py --tarch ResNet50_aux --arch mobilenetV2_aux \ 46 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_ResNet50_aux_dataset_cifar100_seed0/ResNet50_aux.pth.tar \ 47 | --checkpoint-dir ./checkpoint \ 48 | --data /home/ycg/hhd/dataset/ \ 49 | --init-lr 0.01 \ 50 | --gpu 2 --manual 0 51 | 52 | python train_student_cifar.py --tarch resnet32x4_aux --arch ShuffleV2_aux \ 53 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet32x4_aux_dataset_cifar100_seed0/resnet32x4_aux.pth.tar \ 54 | --checkpoint-dir ./checkpoint \ 55 | --data /home/ycg/hhd/dataset/ \ 56 | --init-lr 0.01 \ 57 
| --gpu 2 --manual 0 58 | 59 | 60 | # using the teacher network of the joint training version 61 | python train_student_cifar.py --tarch wrn_40_2_aux --arch wrn_40_1_aux \ 62 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed1/wrn_40_2_aux.pth.tar \ 63 | --checkpoint-dir ./checkpoint \ 64 | --data /home/ycg/hhd/dataset/ \ 65 | --gpu 2 --manual 0 66 | 67 | python train_student_cifar.py \ 68 | --tarch wrn_40_2_aux \ 69 | --arch wrn_16_2_aux \ 70 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed1/wrn_40_2_aux.pth.tar \ 71 | --checkpoint-dir ./checkpoint \ 72 | --data ./data \ 73 | --gpu 0 --manual 0 74 | 75 | 76 | python train_student_cifar.py --tarch wrn_40_2_aux --arch ShuffleV1_aux \ 77 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed1/wrn_40_2_aux.pth.tar \ 78 | --checkpoint-dir ./checkpoint \ 79 | --data /home/ycg/hhd/dataset/ \ 80 | --init-lr 0.01 \ 81 | --gpu 2 --manual 0 82 | 83 | python train_student_cifar.py --tarch resnet32x4_aux --arch resnet8x4_aux \ 84 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet32x4_aux_dataset_cifar100_seed1/resnet32x4_aux.pth.tar \ 85 | --checkpoint-dir ./checkpoint \ 86 | --data /home/ycg/hhd/dataset/ \ 87 | --gpu 2 --manual 0 88 | 89 | python train_student_cifar.py --tarch resnet56_aux --arch resnet20_aux \ 90 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet56_aux_dataset_cifar100_seed1/resnet56_aux.pth.tar \ 91 | --checkpoint-dir ./checkpoint \ 92 | --data /home/ycg/hhd/dataset/ \ 93 | --gpu 2 --manual 0 94 | 95 | python train_student_cifar.py --tarch vgg13_bn_aux --arch mobilenetV2_aux \ 96 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_vgg13_bn_aux_dataset_cifar100_seed1/vgg13_bn_aux.pth.tar \ 97 | --checkpoint-dir ./checkpoint \ 98 | --data /home/ycg/hhd/dataset/ \ 99 | --init-lr 0.01 \ 100 | --gpu 2 --manual 0 101 | 102 | python train_student_cifar.py --tarch ResNet50_aux --arch mobilenetV2_aux \ 103 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_ResNet50_aux_dataset_cifar100_seed1/ResNet50_aux.pth.tar \ 104 | --checkpoint-dir ./checkpoint \ 105 | --data /home/ycg/hhd/dataset/ \ 106 | --init-lr 0.01 \ 107 | --gpu 2 --manual 0 108 | 109 | python train_student_cifar.py --tarch resnet32x4_aux --arch ShuffleV2_aux \ 110 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet32x4_aux_dataset_cifar100_seed1/resnet32x4_aux.pth.tar \ 111 | --checkpoint-dir ./checkpoint \ 112 | --data /home/ycg/hhd/dataset/ \ 113 | --init-lr 0.01 \ 114 | --gpu 2 --manual 0 115 | -------------------------------------------------------------------------------- /models/mutual_kd_net.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append('.') 5 | sys.path.append(os.getcwd()) 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | from .resnet import * 11 | from .resnetv2 import * 12 | from .wrn import * 13 | from .vgg import * 14 | from .ShuffleNetv1 import * 15 | from .ShuffleNetv2 import * 16 | from .mobilenetv2 import * 17 | from .resnet_imagenet import * 18 | 19 | 20 | __all__ = ['mbv2_mbv2', 'WRN_40_2_WRN_40_2', 'WRN_40_1_WRN_40_1', 'resnet56_resnet56', 21 | 'shufflev1_shufflev1', 'shufflev2_shufflev2', 22 | 'resnet18_resnet18', 'resnet32x4_resnet32x4', 23 | 'vgg13_vgg13'] 24 | 25 | 26 | class mbv2_mbv2(nn.Module): 27 | def __init__(self,num_classes): 28 | super(mbv2_mbv2, self).__init__() 29 | self.net1 = mobilenetV2_aux(num_classes) 
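        # net1/net2 are two architecturally identical peers trained jointly for
        # online mutual KD. Each *_aux model wraps a backbone plus auxiliary
        # classifiers and returns (logit, ss_logits); the ss_logits cover the
        # num_classes * 4 joint label space (the four rotation-augmented views
        # used as self-supervision in HSAKD).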
30 | self.net2 = mobilenetV2_aux(num_classes) 31 | 32 | def forward(self, x, grad=True): 33 | logit1, ss_logits1 = self.net1(x, grad=grad) 34 | logit2, ss_logits2 = self.net2(x, grad=grad) 35 | return [logit1, logit2], [ss_logits1, ss_logits2] 36 | 37 | 38 | class WRN_40_2_WRN_40_2(nn.Module): 39 | def __init__(self,num_classes): 40 | super(WRN_40_2_WRN_40_2, self).__init__() 41 | self.net1 = wrn_40_2_aux(num_classes=num_classes) 42 | self.net2 = wrn_40_2_aux(num_classes=num_classes) 43 | 44 | def forward(self, x, grad=True): 45 | logit1, ss_logits1 = self.net1(x, grad=grad) 46 | logit2, ss_logits2 = self.net2(x, grad=grad) 47 | return [logit1, logit2], [ss_logits1, ss_logits2] 48 | 49 | 50 | class WRN_16_2_WRN_16_2(nn.Module): 51 | def __init__(self,num_classes): 52 | super(WRN_16_2_WRN_16_2, self).__init__() 53 | self.net1 = wrn_16_2_aux(num_classes=num_classes) 54 | self.net2 = wrn_16_2_aux(num_classes=num_classes) 55 | 56 | def forward(self, x, grad=True): 57 | logit1, ss_logits1 = self.net1(x, grad=grad) 58 | logit2, ss_logits2 = self.net2(x, grad=grad) 59 | return [logit1, logit2], [ss_logits1, ss_logits2] 60 | 61 | 62 | class resnet20_resnet20(nn.Module): 63 | def __init__(self,num_classes): 64 | super(resnet20_resnet20, self).__init__() 65 | self.net1 = resnet20_aux(num_classes=num_classes) 66 | self.net2 = resnet20_aux(num_classes=num_classes) 67 | 68 | def forward(self, x, grad=True): 69 | logit1, ss_logits1 = self.net1(x, grad=grad) 70 | logit2, ss_logits2 = self.net2(x, grad=grad) 71 | return [logit1, logit2], [ss_logits1, ss_logits2] 72 | 73 | 74 | class resnet56_resnet56(nn.Module): 75 | def __init__(self,num_classes): 76 | super(resnet56_resnet56, self).__init__() 77 | self.net1 = resnet56_aux(num_classes=num_classes) 78 | self.net2 = resnet56_aux(num_classes=num_classes) 79 | 80 | 81 | def forward(self, x, grad=True): 82 | logit1, ss_logits1 = self.net1(x, grad=grad) 83 | logit2, ss_logits2 = self.net2(x, grad=grad) 84 | return [logit1, logit2], [ss_logits1, ss_logits2] 85 | 86 | 87 | class resnet32x4_resnet32x4(nn.Module): 88 | def __init__(self,num_classes): 89 | super(resnet32x4_resnet32x4, self).__init__() 90 | self.net1 = resnet32x4_aux(num_classes=num_classes) 91 | self.net2 = resnet32x4_aux(num_classes=num_classes) 92 | 93 | def forward(self, x, grad=True): 94 | logit1, ss_logits1 = self.net1(x, grad=grad) 95 | logit2, ss_logits2 = self.net2(x, grad=grad) 96 | return [logit1, logit2], [ss_logits1, ss_logits2] 97 | 98 | 99 | class resnet8x4_resnet8x4(nn.Module): 100 | def __init__(self,num_classes): 101 | super(resnet8x4_resnet8x4, self).__init__() 102 | self.net1 = resnet8x4_aux(num_classes=num_classes) 103 | self.net2 = resnet8x4_aux(num_classes=num_classes) 104 | 105 | def forward(self, x, grad=True): 106 | logit1, ss_logits1 = self.net1(x, grad=grad) 107 | logit2, ss_logits2 = self.net2(x, grad=grad) 108 | return [logit1, logit2], [ss_logits1, ss_logits2] 109 | 110 | class vgg13_vgg13(nn.Module): 111 | def __init__(self,num_classes): 112 | super(vgg13_vgg13, self).__init__() 113 | self.net1 = vgg13_bn_aux(num_classes=num_classes) 114 | self.net2 = vgg13_bn_aux(num_classes=num_classes) 115 | 116 | def forward(self, x, grad=True): 117 | logit1, ss_logits1 = self.net1(x, grad=grad) 118 | logit2, ss_logits2 = self.net2(x, grad=grad) 119 | return [logit1, logit2], [ss_logits1, ss_logits2] 120 | 121 | 122 | class shufflev1_shufflev1(nn.Module): 123 | def __init__(self,num_classes): 124 | super(shufflev1_shufflev1, self).__init__() 125 | self.net1 = 
ShuffleV1_aux(num_classes=num_classes) 126 | self.net2 = ShuffleV1_aux(num_classes=num_classes) 127 | 128 | def forward(self, x, grad=True): 129 | logit1, ss_logits1 = self.net1(x, grad=grad) 130 | logit2, ss_logits2 = self.net2(x, grad=grad) 131 | return [logit1, logit2], [ss_logits1, ss_logits2] 132 | 133 | 134 | class shufflev2_shufflev2(nn.Module): 135 | def __init__(self,num_classes): 136 | super(shufflev2_shufflev2, self).__init__() 137 | self.net1 = ShuffleV2_aux(num_classes=num_classes) 138 | self.net2 = ShuffleV2_aux(num_classes=num_classes) 139 | 140 | def forward(self, x, grad=True): 141 | logit1, ss_logits1 = self.net1(x, grad=grad) 142 | logit2, ss_logits2 = self.net2(x, grad=grad) 143 | return [logit1, logit2], [ss_logits1, ss_logits2] 144 | 145 | 146 | class WRN_40_1_WRN_40_1(nn.Module): 147 | def __init__(self,num_classes): 148 | super(WRN_40_1_WRN_40_1, self).__init__() 149 | self.net1 = wrn_40_1_aux(num_classes=num_classes) 150 | self.net2 = wrn_40_1_aux(num_classes=num_classes) 151 | 152 | def forward(self, x, grad=True): 153 | logit1, ss_logits1 = self.net1(x, grad=grad) 154 | logit2, ss_logits2 = self.net2(x, grad=grad) 155 | return [logit1, logit2], [ss_logits1, ss_logits2] 156 | 157 | 158 | 159 | 160 | class resnet18_resnet18(nn.Module): 161 | def __init__(self,num_classes): 162 | super(resnet18_resnet18, self).__init__() 163 | self.net1 = resnet18_imagenet_aux(num_classes=num_classes) 164 | self.net2 = resnet18_imagenet_aux(num_classes=num_classes) 165 | 166 | def forward(self, x, grad=True): 167 | logit1, ss_logits1 = self.net1(x, grad=grad) 168 | logit2, ss_logits2 = self.net2(x, grad=grad) 169 | return [logit1, logit2], [ss_logits1, ss_logits2] 170 | 171 | 172 | 173 | if __name__ == '__main__': 174 | x = torch.randn(2, 3, 32, 32) 175 | net = mbv2_mbv2(100) 176 | logit, ss_logits = net(x) 177 | print(logit) 178 | print(ss_logits) 179 | from utils import cal_param_size, cal_multi_adds 180 | print('Params: %.2fM, Multi-adds: %.3fM' 181 | % (cal_param_size(net) / 1e6, cal_multi_adds(net, (2, 3, 32, 32)) / 1e6)) -------------------------------------------------------------------------------- /models/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | 10 | __all__ = ['ShuffleV1_aux', 'ShuffleV1'] 11 | class ShuffleBlock(nn.Module): 12 | def __init__(self, groups): 13 | super(ShuffleBlock, self).__init__() 14 | self.groups = groups 15 | 16 | def forward(self, x): 17 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 18 | N,C,H,W = x.size() 19 | g = self.groups 20 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 21 | 22 | 23 | class Bottleneck(nn.Module): 24 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 25 | super(Bottleneck, self).__init__() 26 | self.is_last = is_last 27 | self.stride = stride 28 | 29 | mid_planes = int(out_planes/4) 30 | g = 1 if in_planes == 24 else groups 31 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 32 | self.bn1 = nn.BatchNorm2d(mid_planes) 33 | self.shuffle1 = ShuffleBlock(groups=g) 34 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 35 | self.bn2 = nn.BatchNorm2d(mid_planes) 36 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 37 | self.bn3 = nn.BatchNorm2d(out_planes) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride == 2: 41 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 42 | 43 | def forward(self, x): 44 | out = F.relu(self.bn1(self.conv1(x))) 45 | out = self.shuffle1(out) 46 | out = F.relu(self.bn2(self.conv2(out))) 47 | out = self.bn3(self.conv3(out)) 48 | res = self.shortcut(x) 49 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res 50 | out = F.relu(preact) 51 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 52 | return out 53 | 54 | 55 | class ShuffleNet(nn.Module): 56 | def __init__(self, cfg, num_classes=10): 57 | super(ShuffleNet, self).__init__() 58 | out_planes = cfg['out_planes'] 59 | num_blocks = cfg['num_blocks'] 60 | groups = cfg['groups'] 61 | 62 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(24) 64 | self.in_planes = 24 65 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 66 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 67 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 68 | self.linear = nn.Linear(out_planes[2], num_classes) 69 | 70 | def _make_layer(self, out_planes, num_blocks, groups): 71 | layers = [] 72 | for i in range(num_blocks): 73 | stride = 2 if i == 0 else 1 74 | cat_planes = self.in_planes if i == 0 else 0 75 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 76 | stride=stride, 77 | groups=groups, 78 | is_last=(i == num_blocks - 1))) 79 | self.in_planes = out_planes 80 | return nn.Sequential(*layers) 81 | 82 | def get_feat_modules(self): 83 | feat_m = nn.ModuleList([]) 84 | feat_m.append(self.conv1) 85 | feat_m.append(self.bn1) 86 | feat_m.append(self.layer1) 87 | feat_m.append(self.layer2) 88 | feat_m.append(self.layer3) 89 | return feat_m 90 | 91 | def get_bn_before_relu(self): 92 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 93 | 94 | def forward(self, x, is_feat=False, preact=False): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = self.layer1(out) 97 | f1 = out 98 | out = self.layer2(out) 99 | f2 = out 100 | out = self.layer3(out) 101 | f3 = out 102 | out = F.avg_pool2d(out, 4) 103 | 
out = out.view(out.size(0), -1) 104 | f4 = out 105 | out = self.linear(out) 106 | 107 | if is_feat: 108 | return [f1, f2, f3], out 109 | else: 110 | return out 111 | 112 | 113 | class Auxiliary_Classifier(nn.Module): 114 | def __init__(self, cfg, num_classes=10): 115 | super(Auxiliary_Classifier, self).__init__() 116 | out_planes = cfg['out_planes'] 117 | num_blocks = cfg['num_blocks'] 118 | groups = cfg['groups'] 119 | 120 | self.in_planes = out_planes[0] 121 | self.block_extractor1 = nn.Sequential(*[self._make_layer(out_planes[1], num_blocks[1], groups), 122 | self._make_layer(out_planes[2], num_blocks[2], groups)]) 123 | self.in_planes = out_planes[1] 124 | self.block_extractor2 = nn.Sequential(*[self._make_layer(out_planes[2], num_blocks[2], groups)]) 125 | 126 | self.in_planes = out_planes[2] 127 | self.block_extractor3 = nn.Sequential(*[self._make_layer(out_planes[2], num_blocks[2], groups, downsample=False)]) 128 | 129 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 130 | self.fc1 = nn.Linear(out_planes[2], num_classes) 131 | self.fc2 = nn.Linear(out_planes[2], num_classes) 132 | self.fc3 = nn.Linear(out_planes[2], num_classes) 133 | 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | elif isinstance(m, nn.BatchNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | elif isinstance(m, nn.Linear): 142 | m.bias.data.zero_() 143 | 144 | def _make_layer(self, out_planes, num_blocks, groups, downsample=True): 145 | layers = [] 146 | for i in range(num_blocks): 147 | stride = 2 if i == 0 and downsample is True else 1 148 | cat_planes = self.in_planes if i == 0 and downsample is True else 0 149 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 150 | stride=stride, 151 | groups=groups, 152 | is_last=(i == num_blocks - 1))) 153 | self.in_planes = out_planes 154 | return nn.Sequential(*layers) 155 | 156 | 157 | def forward(self, x): 158 | ss_logits = [] 159 | for i in range(len(x)): 160 | idx = i + 1 161 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 162 | out = self.avg_pool(out) 163 | out = out.view(out.size(0), -1) 164 | out = getattr(self, 'fc'+str(idx))(out) 165 | ss_logits.append(out) 166 | return ss_logits 167 | 168 | 169 | class ShuffleNet_Auxiliary(nn.Module): 170 | def __init__(self, cfg, num_classes=100): 171 | super(ShuffleNet_Auxiliary, self).__init__() 172 | self.backbone = ShuffleNet(cfg, num_classes=num_classes) 173 | self.auxiliary_classifier = Auxiliary_Classifier(cfg, num_classes=num_classes * 4) 174 | 175 | def forward(self, x, grad=False): 176 | feats, logit = self.backbone(x, is_feat=True) 177 | if grad is False: 178 | for i in range(len(feats)): 179 | feats[i] = feats[i].detach() 180 | ss_logits = self.auxiliary_classifier(feats) 181 | return logit, ss_logits 182 | 183 | 184 | def ShuffleV1(**kwargs): 185 | cfg = { 186 | 'out_planes': [240, 480, 960], 187 | 'num_blocks': [4, 8, 4], 188 | 'groups': 3 189 | } 190 | return ShuffleNet(cfg, **kwargs) 191 | 192 | 193 | def ShuffleV1_aux(**kwargs): 194 | cfg = { 195 | 'out_planes': [240, 480, 960], 196 | 'num_blocks': [4, 8, 4], 197 | 'groups': 3 198 | } 199 | return ShuffleNet_Auxiliary(cfg, **kwargs) 200 | 201 | 202 | if __name__ == '__main__': 203 | x = torch.randn(2, 3, 32, 32) 204 | net = ShuffleV1_aux(num_classes=100) 205 | logit, ss_logits = net(x) 206 | print(logit.size()) 207 | 
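ShuffleBlock above implements channel shuffle purely with a view/permute/reshape round trip, so that after a grouped convolution the channels of each group are redistributed across all groups. A minimal standalone check of that index pattern (illustrative sketch; six channels and g=2 chosen arbitrarily, not part of the repository):

import torch

# Six 'channels' labeled 0..5 in an [N, C, H, W] tensor.
x = torch.arange(6).view(1, 6, 1, 1)
g = 2  # number of groups
N, C, H, W = x.size()
# The same view/permute/reshape sequence as ShuffleBlock.forward:
shuffled = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
print(shuffled.flatten().tolist())  # [0, 3, 1, 4, 2, 5]: channels from the two groups interleave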
-------------------------------------------------------------------------------- /models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Original Author: Wei Yang 8 | """ 9 | 10 | __all__ = ['wrn', 'wrn_40_2_aux', 'wrn_16_2_aux', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2', 11 | 'wrn_40_1_aux'] 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 16 | super(BasicBlock, self).__init__() 17 | self.bn1 = nn.BatchNorm2d(in_planes) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(out_planes) 22 | self.relu2 = nn.ReLU(inplace=True) 23 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 24 | padding=1, bias=False) 25 | self.droprate = dropRate 26 | self.equalInOut = (in_planes == out_planes) 27 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 28 | padding=0, bias=False) or None 29 | 30 | def forward(self, x): 31 | if not self.equalInOut: 32 | x = self.relu1(self.bn1(x)) 33 | else: 34 | out = self.relu1(self.bn1(x)) 35 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 36 | if self.droprate > 0: 37 | out = F.dropout(out, p=self.droprate, training=self.training) 38 | out = self.conv2(out) 39 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 40 | 41 | 42 | class NetworkBlock(nn.Module): 43 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 44 | super(NetworkBlock, self).__init__() 45 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 46 | 47 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 48 | layers = [] 49 | for i in range(nb_layers): 50 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 51 | return nn.Sequential(*layers) 52 | 53 | def forward(self, x): 54 | return self.layer(x) 55 | 56 | 57 | class WideResNet(nn.Module): 58 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 59 | super(WideResNet, self).__init__() 60 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 61 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 62 | n = (depth - 4) // 6 63 | block = BasicBlock 64 | # 1st conv before any network block 65 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 66 | padding=1, bias=False) 67 | # 1st block 68 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 69 | # 2nd block 70 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 71 | # 3rd block 72 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 73 | # global average pooling and classifier 74 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.fc = nn.Linear(nChannels[3], num_classes) 77 | self.nChannels = nChannels[3] 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 82 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 83 | elif isinstance(m, nn.BatchNorm2d): 84 | m.weight.data.fill_(1) 85 | m.bias.data.zero_() 86 | elif isinstance(m, nn.Linear): 87 | m.bias.data.zero_() 88 | 89 | def get_feat_modules(self): 90 | feat_m = nn.ModuleList([]) 91 | feat_m.append(self.conv1) 92 | feat_m.append(self.block1) 93 | feat_m.append(self.block2) 94 | feat_m.append(self.block3) 95 | return feat_m 96 | 97 | def get_bn_before_relu(self): 98 | bn1 = self.block2.layer[0].bn1 99 | bn2 = self.block3.layer[0].bn1 100 | bn3 = self.bn1 101 | 102 | return [bn1, bn2, bn3] 103 | 104 | def forward(self, x, is_feat=False, preact=False): 105 | out = self.conv1(x) 106 | f0 = out 107 | out = self.block1(out) 108 | f1 = out 109 | out = self.block2(out) 110 | f2 = out 111 | out = self.block3(out) 112 | f3 = out 113 | out = self.relu(self.bn1(out)) 114 | out = F.avg_pool2d(out, 8) 115 | out = out.view(-1, self.nChannels) 116 | f4 = out 117 | out = self.fc(out) 118 | if is_feat: 119 | if preact: 120 | f1 = self.block2.layer[0].bn1(f1) 121 | f2 = self.block3.layer[0].bn1(f2) 122 | f3 = self.bn1(f3) 123 | return [f1, f2, f3], out 124 | else: 125 | return out 126 | 127 | 128 | class Auxiliary_Classifier(nn.Module): 129 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 130 | super(Auxiliary_Classifier, self).__init__() 131 | self.nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 132 | block = BasicBlock 133 | n = (depth - 4) // 6 134 | self.block_extractor1 = nn.Sequential(*[NetworkBlock(n, self.nChannels[1], self.nChannels[2], block, 2), 135 | NetworkBlock(n, self.nChannels[2], self.nChannels[3], block, 2),]) 136 | self.block_extractor2 = nn.Sequential(*[NetworkBlock(n, self.nChannels[2], self.nChannels[3], block, 2)]) 137 | self.block_extractor3 = nn.Sequential(*[NetworkBlock(n, self.nChannels[3], self.nChannels[3], block, 1)]) 138 | 139 | self.bn1 = nn.BatchNorm2d(self.nChannels[3]) 140 | self.bn2 = nn.BatchNorm2d(self.nChannels[3]) 141 | self.bn3 = nn.BatchNorm2d(self.nChannels[3]) 142 | 143 | 144 | self.relu = nn.ReLU(inplace=True) 145 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 146 | self.fc1 = nn.Linear(self.nChannels[3], num_classes) 147 | self.fc2 = nn.Linear(self.nChannels[3], num_classes) 148 | self.fc3 = nn.Linear(self.nChannels[3], num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 153 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 154 | elif isinstance(m, nn.BatchNorm2d): 155 | m.weight.data.fill_(1) 156 | m.bias.data.zero_() 157 | elif isinstance(m, nn.Linear): 158 | m.bias.data.zero_() 159 | 160 | def forward(self, x): 161 | ss_logits = [] 162 | ss_feats = [] 163 | for i in range(len(x)): 164 | idx = i + 1 165 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 166 | out = self.relu(getattr(self, 'bn'+str(idx))(out)) 167 | out = self.avg_pool(out) 168 | out = out.view(-1, self.nChannels[3]) 169 | ss_feats.append(out) 170 | out = getattr(self, 'fc'+str(idx))(out) 171 | ss_logits.append(out) 172 | return ss_logits 173 | 174 | 175 | class WideResNet_Auxiliary(nn.Module): 176 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 177 | super(WideResNet_Auxiliary, self).__init__() 178 | self.backbone = WideResNet(depth, num_classes, widen_factor=widen_factor) 179 | self.auxiliary_classifier = Auxiliary_Classifier(depth=depth, num_classes=num_classes * 4, widen_factor=widen_factor) 180 | 181 | def forward(self, x, grad=False): 182 | feats, logit = self.backbone(x, is_feat=True) 183 | if grad is False: 184 | for i in range(len(feats)): 185 | feats[i] = feats[i].detach() 186 | ss_logits = self.auxiliary_classifier(feats) 187 | 188 | return logit, ss_logits 189 | 190 | 191 | def wrn(**kwargs): 192 | """ 193 | Constructs a Wide Residual Networks. 194 | """ 195 | model = WideResNet(**kwargs) 196 | return model 197 | 198 | 199 | def wrn_40_2(**kwargs): 200 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 201 | return model 202 | 203 | def wrn_40_2_aux(**kwargs): 204 | model = WideResNet_Auxiliary(depth=40, widen_factor=2, **kwargs) 205 | return model 206 | 207 | 208 | def wrn_40_1(**kwargs): 209 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 210 | return model 211 | 212 | def wrn_40_1_aux(**kwargs): 213 | model = WideResNet_Auxiliary(depth=40, widen_factor=1, **kwargs) 214 | return model 215 | 216 | def wrn_16_2(**kwargs): 217 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 218 | return model 219 | 220 | def wrn_16_2_aux(**kwargs): 221 | model = WideResNet_Auxiliary(depth=16, widen_factor=2, **kwargs) 222 | return model 223 | 224 | 225 | def wrn_16_1(**kwargs): 226 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 227 | return model 228 | 229 | 230 | if __name__ == '__main__': 231 | import torch 232 | x = torch.randn(2, 3, 32, 32) 233 | net = wrn_40_2_aux(num_classes=100) 234 | logit, ss_logits = net(x) 235 | print(logit.size()) -------------------------------------------------------------------------------- /models/resnetv2.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | __all__ = ['ResNet50_aux'] 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | return out 70 | 71 | 72 | class ResNet(nn.Module): 73 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 74 | super(ResNet, self).__init__() 75 | self.in_planes = 64 76 | 77 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(64) 79 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 80 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 81 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 82 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 83 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 84 | self.linear = nn.Linear(512 * block.expansion, num_classes) 85 | 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 89 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | # Zero-initialize the last BN in each residual branch, 94 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
95 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 96 | if zero_init_residual: 97 | for m in self.modules(): 98 | if isinstance(m, Bottleneck): 99 | nn.init.constant_(m.bn3.weight, 0) 100 | elif isinstance(m, BasicBlock): 101 | nn.init.constant_(m.bn2.weight, 0) 102 | 103 | def get_feat_modules(self): 104 | feat_m = nn.ModuleList([]) 105 | feat_m.append(self.conv1) 106 | feat_m.append(self.bn1) 107 | feat_m.append(self.layer1) 108 | feat_m.append(self.layer2) 109 | feat_m.append(self.layer3) 110 | feat_m.append(self.layer4) 111 | return feat_m 112 | 113 | def get_bn_before_relu(self): 114 | if isinstance(self.layer1[0], Bottleneck): 115 | bn1 = self.layer1[-1].bn3 116 | bn2 = self.layer2[-1].bn3 117 | bn3 = self.layer3[-1].bn3 118 | bn4 = self.layer4[-1].bn3 119 | elif isinstance(self.layer1[0], BasicBlock): 120 | bn1 = self.layer1[-1].bn2 121 | bn2 = self.layer2[-1].bn2 122 | bn3 = self.layer3[-1].bn2 123 | bn4 = self.layer4[-1].bn2 124 | else: 125 | raise NotImplementedError('ResNet unknown block error !!!') 126 | 127 | return [bn1, bn2, bn3, bn4] 128 | 129 | def _make_layer(self, block, planes, num_blocks, stride): 130 | strides = [stride] + [1] * (num_blocks - 1) 131 | layers = [] 132 | for i in range(num_blocks): 133 | stride = strides[i] 134 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 135 | self.in_planes = planes * block.expansion 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x, is_feat=False, preact=False): 139 | out = F.relu(self.bn1(self.conv1(x))) 140 | f0 = out 141 | out = self.layer1(out) 142 | f1 = out 143 | out = self.layer2(out) 144 | f2 = out 145 | out = self.layer3(out) 146 | f3 = out 147 | out = self.layer4(out) 148 | f4 = out 149 | out = self.avgpool(out) 150 | out = out.view(out.size(0), -1) 151 | f5 = out 152 | out = self.linear(out) 153 | if is_feat: 154 | return [f1, f2, f3, f4], out 155 | else: 156 | return out 157 | 158 | 159 | class Auxiliary_Classifier(nn.Module): 160 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 161 | super(Auxiliary_Classifier, self).__init__() 162 | 163 | 164 | self.in_planes = 64 * block.expansion 165 | self.block_extractor1 = nn.Sequential(*[self._make_layer(block, 128, num_blocks[1], stride=2), 166 | self._make_layer(block, 256, num_blocks[2], stride=2), 167 | self._make_layer(block, 512, num_blocks[3], stride=2)]) 168 | 169 | self.in_planes = 128 * block.expansion 170 | self.block_extractor2 = nn.Sequential(*[self._make_layer(block, 256, num_blocks[2], stride=2), 171 | self._make_layer(block, 512, num_blocks[3], stride=2)]) 172 | 173 | self.in_planes = 256 * block.expansion 174 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, 512, num_blocks[3], stride=2)]) 175 | 176 | self.in_planes = 512 * block.expansion 177 | self.block_extractor4 = nn.Sequential(*[self._make_layer(block, 512, num_blocks[3], stride=1)]) 178 | 179 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 180 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 181 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 182 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 183 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 184 | 185 | for m in self.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 188 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 189 | nn.init.constant_(m.weight, 1) 190 | nn.init.constant_(m.bias, 0) 191 | 192 | 
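    # Note on the repeated self.in_planes resets above: each block_extractor_i
    # resumes the backbone from the output of stage i, so in_planes must first
    # be set to the width that stage emits (64/128/256/512 times
    # block.expansion). _make_layer then mutates in_planes as it stacks blocks,
    # hence the reset before building every extractor.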
def _make_layer(self, block, planes, num_blocks, stride): 193 | strides = [stride] + [1] * (num_blocks - 1) 194 | layers = [] 195 | for i in range(num_blocks): 196 | stride = strides[i] 197 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 198 | self.in_planes = planes * block.expansion 199 | return nn.Sequential(*layers) 200 | 201 | 202 | def forward(self, x): 203 | ss_logits = [] 204 | for i in range(len(x)): 205 | idx = i + 1 206 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 207 | out = self.avg_pool(out) 208 | out = out.view(out.size(0), -1) 209 | out = getattr(self, 'fc'+str(idx))(out) 210 | ss_logits.append(out) 211 | return ss_logits 212 | 213 | 214 | class ResNet_Auxiliary(nn.Module): 215 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 216 | super(ResNet_Auxiliary, self).__init__() 217 | self.backbone = ResNet(block, num_blocks, num_classes=num_classes, zero_init_residual=zero_init_residual) 218 | self.auxiliary_classifier = Auxiliary_Classifier(block, num_blocks, num_classes=num_classes*4, zero_init_residual=zero_init_residual) 219 | 220 | def forward(self, x, grad=False): 221 | if grad is False: 222 | feats, logit = self.backbone(x, is_feat=True) 223 | for i in range(len(feats)): 224 | feats[i] = feats[i].detach() 225 | else: 226 | feats, logit = self.backbone(x, is_feat=True) 227 | 228 | ss_logits = self.auxiliary_classifier(feats) 229 | return logit, ss_logits 230 | 231 | 232 | def ResNet18(**kwargs): 233 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 234 | 235 | 236 | def ResNet34(**kwargs): 237 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 238 | 239 | 240 | def ResNet50(**kwargs): 241 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 242 | 243 | def ResNet50_aux(**kwargs): 244 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 245 | 246 | def ResNet101(**kwargs): 247 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 248 | 249 | 250 | def ResNet152(**kwargs): 251 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 252 | 253 | 254 | if __name__ == '__main__': 255 | import torch 256 | x = torch.randn(2, 3, 32, 32) 257 | net = ResNet50_aux(num_classes=100) 258 | logit, ss_logits = net(x) 259 | print(logit.size()) 260 | -------------------------------------------------------------------------------- /models/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | __all__ = ['ShuffleV2_aux', 'ShuffleV2'] 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 34 | super(BasicBlock, self).__init__() 35 | self.is_last = is_last 36 | self.split = SplitBlock(split_ratio) 37 | in_channels = int(in_channels * split_ratio) 38 | self.conv1 = nn.Conv2d(in_channels, in_channels, 39 | kernel_size=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(in_channels) 41 | self.conv2 = nn.Conv2d(in_channels, in_channels, 42 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 43 | self.bn2 = nn.BatchNorm2d(in_channels) 44 | self.conv3 = nn.Conv2d(in_channels, in_channels, 45 | kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(in_channels) 47 | self.shuffle = ShuffleBlock() 48 | 49 | def forward(self, x): 50 | x1, x2 = self.split(x) 51 | out = F.relu(self.bn1(self.conv1(x2))) 52 | out = self.bn2(self.conv2(out)) 53 | preact = self.bn3(self.conv3(out)) 54 | out = F.relu(preact) 55 | # out = F.relu(self.bn3(self.conv3(out))) 56 | preact = torch.cat([x1, preact], 1) 57 | out = torch.cat([x1, out], 1) 58 | out = self.shuffle(out) 59 | return out 60 | 61 | 62 | class DownBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, stride=2): 64 | super(DownBlock, self).__init__() 65 | mid_channels = out_channels // 2 66 | # left 67 | self.conv1 = nn.Conv2d(in_channels, in_channels, 68 | kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False) 69 | self.bn1 = nn.BatchNorm2d(in_channels) 70 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn2 = nn.BatchNorm2d(mid_channels) 73 | # right 74 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 75 | kernel_size=1, bias=False) 76 | self.bn3 = nn.BatchNorm2d(mid_channels) 77 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 78 | kernel_size=3, stride=stride, padding=1, groups=mid_channels, bias=False) 79 | self.bn4 = nn.BatchNorm2d(mid_channels) 80 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 81 | kernel_size=1, bias=False) 82 | self.bn5 = nn.BatchNorm2d(mid_channels) 83 | 84 | self.shuffle = ShuffleBlock() 85 | 86 | def forward(self, x): 87 | # left 88 | out1 = self.bn1(self.conv1(x)) 89 | out1 = F.relu(self.bn2(self.conv2(out1))) 90 | # right 91 | out2 = F.relu(self.bn3(self.conv3(x))) 92 | out2 = self.bn4(self.conv4(out2)) 93 | out2 = F.relu(self.bn5(self.conv5(out2))) 94 | # concat 95 | out = torch.cat([out1, out2], 1) 96 | out = self.shuffle(out) 97 | return out 98 | 99 | 100 | class ShuffleNetV2(nn.Module): 101 | def __init__(self, net_size, num_classes=100): 102 | super(ShuffleNetV2, self).__init__() 103 | out_channels = configs[net_size]['out_channels'] 104 | num_blocks = configs[net_size]['num_blocks'] 105 | 106 | # self.conv1 = nn.Conv2d(3, 24, 
kernel_size=3, 107 | # stride=1, padding=1, bias=False) 108 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(24) 110 | self.in_channels = 24 111 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 112 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 113 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 114 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 115 | kernel_size=1, stride=1, padding=0, bias=False) 116 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 117 | self.linear = nn.Linear(out_channels[3], num_classes) 118 | 119 | def _make_layer(self, out_channels, num_blocks): 120 | layers = [DownBlock(self.in_channels, out_channels)] 121 | for i in range(num_blocks): 122 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 123 | self.in_channels = out_channels 124 | return nn.Sequential(*layers) 125 | 126 | def get_feat_modules(self): 127 | feat_m = nn.ModuleList([]) 128 | feat_m.append(self.conv1) 129 | feat_m.append(self.bn1) 130 | feat_m.append(self.layer1) 131 | feat_m.append(self.layer2) 132 | feat_m.append(self.layer3) 133 | return feat_m 134 | 135 | def get_bn_before_relu(self): 136 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 137 | 138 | def forward(self, x, is_feat=False, preact=False): 139 | out = F.relu(self.bn1(self.conv1(x))) 140 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 141 | f0 = out 142 | out = self.layer1(out) 143 | f1 = out 144 | out = self.layer2(out) 145 | f2 = out 146 | out = self.layer3(out) 147 | f3 = out 148 | out = F.relu(self.bn2(self.conv2(out))) 149 | out = F.avg_pool2d(out, 4) 150 | out = out.view(out.size(0), -1) 151 | f4 = out 152 | out = self.linear(out) 153 | if is_feat: 154 | return [f1, f2, f3], out 155 | else: 156 | return out 157 | 158 | 159 | class Auxiliary_Classifier(nn.Module): 160 | def __init__(self, net_size, num_classes=100): 161 | super(Auxiliary_Classifier, self).__init__() 162 | out_channels = configs[net_size]['out_channels'] 163 | num_blocks = configs[net_size]['num_blocks'] 164 | 165 | self.in_channels = out_channels[0] 166 | self.block_extractor1 = nn.Sequential(*[self._make_layer(out_channels[1], num_blocks[1]), 167 | self._make_layer(out_channels[2], num_blocks[2]), 168 | nn.Conv2d(out_channels[2], out_channels[3], 169 | kernel_size=1, stride=1, padding=0, bias=False), 170 | nn.BatchNorm2d(out_channels[3]), 171 | nn.ReLU(inplace=True)]) 172 | 173 | self.in_channels = out_channels[1] 174 | self.block_extractor2 = nn.Sequential(*[self._make_layer(out_channels[2], num_blocks[2]), 175 | nn.Conv2d(out_channels[2], out_channels[3], 176 | kernel_size=1, stride=1, padding=0, bias=False), 177 | nn.BatchNorm2d(out_channels[3]), 178 | nn.ReLU(inplace=True)]) 179 | 180 | self.in_channels = out_channels[2] 181 | self.block_extractor3 = nn.Sequential(*[self._make_layer(out_channels[2], num_blocks[2], stride=1), 182 | nn.Conv2d(out_channels[2], out_channels[3], 183 | kernel_size=1, stride=1, padding=0, bias=False), 184 | nn.BatchNorm2d(out_channels[3]), 185 | nn.ReLU(inplace=True)]) 186 | 187 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc1 = nn.Linear(out_channels[3], num_classes) 189 | self.fc2 = nn.Linear(out_channels[3], num_classes) 190 | self.fc3 = nn.Linear(out_channels[3], num_classes) 191 | 192 | def _make_layer(self, out_channels, num_blocks, stride=2): 193 | layers = [DownBlock(self.in_channels, out_channels, stride=stride)] 194 | for i in range(num_blocks): 
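            # After the initial DownBlock, the loop stacks stride-1 BasicBlocks;
            # is_last merely tags the final block of the stage.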
195 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 196 | self.in_channels = out_channels 197 | return nn.Sequential(*layers) 198 | 199 | 200 | def forward(self, x): 201 | ss_logits = [] 202 | for i in range(len(x)): 203 | idx = i + 1 204 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 205 | out = self.avg_pool(out) 206 | out = out.view(out.size(0), -1) 207 | out = getattr(self, 'fc'+str(idx))(out) 208 | ss_logits.append(out) 209 | return ss_logits 210 | 211 | 212 | class ShuffleNetV2_Auxiliary(nn.Module): 213 | def __init__(self, net_size, num_classes=100): 214 | super(ShuffleNetV2_Auxiliary, self).__init__() 215 | self.backbone = ShuffleNetV2(net_size, num_classes=num_classes) 216 | self.auxiliary_classifier = Auxiliary_Classifier(net_size, num_classes=num_classes * 4) 217 | 218 | def forward(self, x, grad=False): 219 | feats, logit = self.backbone(x, is_feat=True) 220 | if grad is False: 221 | for i in range(len(feats)): 222 | feats[i] = feats[i].detach() 223 | ss_logits = self.auxiliary_classifier(feats) 224 | return logit, ss_logits 225 | 226 | 227 | configs = { 228 | 0.2: { 229 | 'out_channels': (40, 80, 160, 512), 230 | 'num_blocks': (3, 3, 3) 231 | }, 232 | 233 | 0.3: { 234 | 'out_channels': (40, 80, 160, 512), 235 | 'num_blocks': (3, 7, 3) 236 | }, 237 | 238 | 0.5: { 239 | 'out_channels': (48, 96, 192, 1024), 240 | 'num_blocks': (3, 7, 3) 241 | }, 242 | 243 | 1: { 244 | 'out_channels': (116, 232, 464, 1024), 245 | 'num_blocks': (3, 7, 3) 246 | }, 247 | 1.5: { 248 | 'out_channels': (176, 352, 704, 1024), 249 | 'num_blocks': (3, 7, 3) 250 | }, 251 | 2: { 252 | 'out_channels': (224, 488, 976, 2048), 253 | 'num_blocks': (3, 7, 3) 254 | } 255 | } 256 | 257 | 258 | def ShuffleV2(**kwargs): 259 | model = ShuffleNetV2(net_size=1, **kwargs) 260 | return model 261 | 262 | def ShuffleV2_aux(**kwargs): 263 | model = ShuffleNetV2_Auxiliary(net_size=1, **kwargs) 264 | return model 265 | 266 | if __name__ == '__main__': 267 | x = torch.randn(2, 3, 32, 32) 268 | net = ShuffleV2_aux(num_classes=100) 269 | logit, ss_logits = net(x) 270 | print(logit.size()) 271 | 272 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNetV2 implementation used in 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['mobilenetv2_T_w', 'mobilenetV2', 'mobilenetV2_aux'] 11 | 12 | BN = None 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | 23 | def conv_1x1_bn(inp, oup): 24 | return nn.Sequential( 25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 26 | nn.BatchNorm2d(oup), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio): 33 | super(InvertedResidual, self).__init__() 34 | self.blockname = None 35 | 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | 
nn.ReLU(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] 55 | 56 | def forward(self, x): 57 | t = x 58 | 59 | if self.use_res_connect: 60 | return t + self.conv(x) 61 | else: 62 | return self.conv(x) 63 | 64 | 65 | class MobileNetV2(nn.Module): 66 | """mobilenetV2""" 67 | def __init__(self, T, 68 | feature_dim, 69 | input_size=32, 70 | width_mult=1., 71 | remove_avg=False): 72 | super(MobileNetV2, self).__init__() 73 | self.remove_avg = remove_avg 74 | 75 | # setting of inverted residual blocks 76 | self.interverted_residual_setting = [ 77 | # t, c, n, s 78 | [1, 16, 1, 1], 79 | [T, 24, 2, 1], 80 | 81 | [T, 32, 3, 2], 82 | 83 | [T, 64, 4, 2], 84 | [T, 96, 3, 1], 85 | 86 | [T, 160, 3, 2], 87 | [T, 320, 1, 1], 88 | ] 89 | 90 | # building first layer 91 | assert input_size % 32 == 0 92 | input_channel = int(32 * width_mult) 93 | self.conv1 = conv_bn(3, input_channel,1) 94 | 95 | # building inverted residual blocks 96 | self.blocks = nn.ModuleList([]) 97 | for t, c, n, s in self.interverted_residual_setting: 98 | output_channel = int(c * width_mult) 99 | layers = [] 100 | strides = [s] + [1] * (n - 1) 101 | for stride in strides: 102 | layers.append( 103 | InvertedResidual(input_channel, output_channel, stride, t) 104 | ) 105 | input_channel = output_channel 106 | self.blocks.append(nn.Sequential(*layers)) 107 | 108 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 109 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 110 | 111 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 112 | 113 | # building classifier 114 | #self.classifier = nn.Sequential( 115 | # # nn.Dropout(0.5), 116 | # nn.Linear(self.last_channel, feature_dim), 117 | #) 118 | self.classifier = nn.Linear(self.last_channel, feature_dim) 119 | 120 | self._initialize_weights() 121 | 122 | def get_bn_before_relu(self): 123 | bn1 = self.blocks[1][-1].conv[-1] 124 | bn2 = self.blocks[2][-1].conv[-1] 125 | bn3 = self.blocks[4][-1].conv[-1] 126 | bn4 = self.blocks[6][-1].conv[-1] 127 | return [bn1, bn2, bn3, bn4] 128 | 129 | def get_feat_modules(self): 130 | feat_m = nn.ModuleList([]) 131 | feat_m.append(self.conv1) 132 | feat_m.append(self.blocks) 133 | return feat_m 134 | 135 | def forward(self, x, is_feat=False, preact=False): 136 | 137 | out = self.conv1(x) 138 | out = self.blocks[0](out) 139 | out = self.blocks[1](out) 140 | f1 = out 141 | out = self.blocks[2](out) 142 | f2 = out 143 | out = self.blocks[3](out) 144 | out = self.blocks[4](out) 145 | f3 = out 146 | out = self.blocks[5](out) 147 | out = self.blocks[6](out) 148 | f4 = out 149 | 150 | out = self.conv2(out) 151 | 152 | if not self.remove_avg: 153 | out = self.avgpool(out) 154 | out = out.view(out.size(0), -1) 155 | f5 = out 156 | out = self.classifier(out) 157 | 158 | if is_feat: 159 | return [f1, f2, f3, f4], out 160 | else: 161 | return out 162 | 163 | def _initialize_weights(self): 164 | for m in self.modules(): 165 | if isinstance(m, nn.Conv2d): 166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 167 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 168 | if m.bias is not None: 169 | m.bias.data.zero_() 170 | elif isinstance(m, nn.BatchNorm2d): 171 | m.weight.data.fill_(1) 172 | m.bias.data.zero_() 173 | elif isinstance(m, nn.Linear): 174 | n = m.weight.size(1) 175 | m.weight.data.normal_(0, 0.01) 176 | m.bias.data.zero_() 177 | 178 | 179 | class Auxiliary_Classifier(nn.Module): 180 | def __init__(self, T, 181 | feature_dim, 182 | input_size=32, 183 | width_mult=1., 184 | remove_avg=False): 185 | super(Auxiliary_Classifier, self).__init__() 186 | 187 | self.remove_avg = remove_avg 188 | self.width_mult = width_mult 189 | # setting of inverted residual blocks 190 | interverted_residual_setting1 = [ 191 | [T, 32, 3, 2], 192 | 193 | [T, 64, 4, 2], 194 | [T, 96, 3, 1], 195 | 196 | [T, 160, 3, 2], 197 | [T, 320, 1, 1], 198 | ] 199 | self.block_extractor1 = self._make_layer(input_channel=12, interverted_residual_setting=interverted_residual_setting1) 200 | 201 | interverted_residual_setting2 = [ 202 | [T, 64, 4, 2], 203 | [T, 96, 3, 1], 204 | 205 | [T, 160, 3, 2], 206 | [T, 320, 1, 1], 207 | ] 208 | self.block_extractor2 = self._make_layer(input_channel=16, interverted_residual_setting=interverted_residual_setting2) 209 | 210 | interverted_residual_setting3 = [ 211 | 212 | [T, 160, 3, 2], 213 | [T, 320, 1, 1], 214 | ] 215 | self.block_extractor3 = self._make_layer(input_channel=48, interverted_residual_setting=interverted_residual_setting3) 216 | 217 | interverted_residual_setting4 = [ 218 | 219 | [T, 160, 3, 1], 220 | [T, 320, 1, 1], 221 | ] 222 | self.block_extractor4 = self._make_layer(input_channel=160, interverted_residual_setting=interverted_residual_setting4) 223 | 224 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 225 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 226 | 227 | self.conv2_1 = conv_1x1_bn(160, self.last_channel) 228 | self.conv2_2 = conv_1x1_bn(160, self.last_channel) 229 | self.conv2_3 = conv_1x1_bn(160, self.last_channel) 230 | self.conv2_4 = conv_1x1_bn(160, self.last_channel) 231 | 232 | 233 | self.fc1 = nn.Linear(self.last_channel, feature_dim) 234 | self.fc2 = nn.Linear(self.last_channel, feature_dim) 235 | self.fc3 = nn.Linear(self.last_channel, feature_dim) 236 | self.fc4 = nn.Linear(self.last_channel, feature_dim) 237 | 238 | self._initialize_weights() 239 | 240 | def _make_layer(self, input_channel, interverted_residual_setting): 241 | # building inverted residual blocks 242 | blocks = [] 243 | for t, c, n, s in interverted_residual_setting: 244 | output_channel = int(c * self.width_mult) 245 | layers = [] 246 | strides = [s] + [1] * (n - 1) 247 | for stride in strides: 248 | layers.append( 249 | InvertedResidual(input_channel, output_channel, stride, t) 250 | ) 251 | input_channel = output_channel 252 | blocks.append(nn.Sequential(*layers)) 253 | 254 | return nn.Sequential(*blocks) 255 | 256 | def _initialize_weights(self): 257 | for m in self.modules(): 258 | if isinstance(m, nn.Conv2d): 259 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 260 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 261 | if m.bias is not None: 262 | m.bias.data.zero_() 263 | elif isinstance(m, nn.BatchNorm2d): 264 | m.weight.data.fill_(1) 265 | m.bias.data.zero_() 266 | elif isinstance(m, nn.Linear): 267 | n = m.weight.size(1) 268 | m.weight.data.normal_(0, 0.01) 269 | m.bias.data.zero_() 270 | 271 | 272 | def forward(self, x): 273 | ss_logits = [] 274 | ss_feats = [] 275 | 276 | for i in range(len(x)): 277 | idx = i + 1 278 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 279 | out = getattr(self, 'conv2_'+str(idx))(out) 280 | out = self.avg_pool(out) 281 | out = out.view(out.size(0), -1) 282 | ss_feats.append(out) 283 | out = getattr(self, 'fc'+str(idx))(out) 284 | ss_logits.append(out) 285 | 286 | return ss_feats, ss_logits 287 | 288 | 289 | class MobileNetv2_Auxiliary(nn.Module): 290 | def __init__(self, T, W, feature_dim=100): 291 | super(MobileNetv2_Auxiliary, self).__init__() 292 | self.backbone = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 293 | self.auxiliary_classifier = Auxiliary_Classifier(T=T, feature_dim=4 * feature_dim, width_mult=W) 294 | 295 | def forward(self, x, grad=False, att=False): 296 | feats, logit = self.backbone(x, is_feat=True) 297 | if grad is False: 298 | for i in range(len(feats)): 299 | feats[i] = feats[i].detach() 300 | ss_feats, ss_logits = self.auxiliary_classifier(feats) 301 | if att is False: 302 | return logit, ss_logits 303 | else: 304 | return logit, ss_logits, feats 305 | 306 | 307 | def mobilenetv2_T_w(T, W, feature_dim=100): 308 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 309 | return model 310 | 311 | 312 | def mobilenetV2(num_classes): 313 | return mobilenetv2_T_w(6, 0.5, num_classes) 314 | 315 | def mobilenetV2_aux(num_classes): 316 | return MobileNetv2_Auxiliary(6, 0.5, num_classes) 317 | 318 | 319 | if __name__ == '__main__': 320 | x = torch.randn(2, 3, 32, 32) 321 | net = mobilenetV2_aux(100) 322 | logit, ss_logits = net(x) 323 | print(logit) 324 | print(ss_logits) 325 | from utils import cal_param_size, cal_multi_adds 326 | print('Params: %.2fM, Multi-adds: %.3fM' 327 | % (cal_param_size(net) / 1e6, cal_multi_adds(net, (2, 3, 32, 32)) / 1e6)) 328 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
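eval_rep.py below evaluates representation transferability by linear probing: it loads a distilled student backbone, freezes every parameter outside the final classifier, and trains only that classifier on STL-10 or TinyImageNet. A hypothetical invocation consistent with the argparse options the script defines (the --s-path value is a placeholder; the checkpoint must have been saved from one of the *_aux wrappers, since its keys are read with a 'backbone.' prefix):

python eval_rep.py --arch mobilenetV2 --dataset STL-10 --data ./data/ --s-path ./checkpoint/student_mobilenetV2.pth.tar --gpu-id 0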
202 | -------------------------------------------------------------------------------- /eval_rep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.functional as F 6 | 7 | import os 8 | import shutil 9 | import argparse 10 | import numpy as np 11 | 12 | 13 | import models 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | from utils import cal_param_size, cal_multi_adds 17 | 18 | 19 | from bisect import bisect_right 20 | import time 21 | import math 22 | 23 | 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 26 | parser.add_argument('--data', default='./data/', type=str, help='Dataset directory') 27 | parser.add_argument('--dataset', default='STL-10', type=str, help='Dataset name') 28 | parser.add_argument('--arch', default='mobilenetV2', type=str, help='network architecture') 29 | parser.add_argument('--s-path', default='ShuffleV1', type=str, help='pre-trained checkpoint dir') 30 | parser.add_argument('--init-lr', default=0.1, type=float, help='learning rate') 31 | parser.add_argument('--weight-decay', default=0, type=float, help='weight decay') 32 | parser.add_argument('--milestones', default=[30, 60, 90], type=list, help='milestones for lr-multistep') 33 | parser.add_argument('--warmup-epoch', default=0, type=int, help='warmup epoch') 34 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train') 35 | parser.add_argument('--batch-size', type=int, default=64, help='batch size') 36 | parser.add_argument('--num-workers', type=int, default=8, help='number workers') 37 | parser.add_argument('--gpu-id', type=str, default='1') 38 | parser.add_argument('--manual_seed', type=int, default=0) 39 | parser.add_argument('--kd_T', type=float, default=3, help='temperature for KD distillation') 40 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 41 | parser.add_argument('--evaluate', '-e', action='store_true', help='evaluate model') 42 | parser.add_argument('--checkpoint-dir', default='./checkpoint', type=str, help='checkpoint dir') 43 | 44 | 45 | # global hyperparameter set 46 | args = parser.parse_args() 47 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 48 | log_txt = 'result/'+ str(os.path.basename(__file__).split('.')[0]) + '_'+\ 49 | 'arch' + '_' + args.arch + '_'+\ 50 | 'dataset' + '_' + args.dataset + '_'+\ 51 | 'seed'+ str(args.manual_seed) +'.txt' 52 | 53 | log_dir = str(os.path.basename(__file__).split('.')[0]) + '_'+\ 54 | 'arch'+ '_' + args.arch + '_'+\ 55 | 'dataset' + '_' + args.dataset + '_'+\ 56 | 'seed'+ str(args.manual_seed) 57 | 58 | 59 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, log_dir) 60 | if not os.path.isdir(args.checkpoint_dir): 61 | os.makedirs(args.checkpoint_dir) 62 | 63 | 64 | if args.resume is False and args.evaluate is False: 65 | with open(log_txt, 'a+') as f: 66 | f.seek(0) 67 | f.truncate() 68 | f.write("==========\nArgs:{}\n==========".format(args) + '\n') 69 | 70 | np.random.seed(args.manual_seed) 71 | torch.manual_seed(args.manual_seed) 72 | torch.cuda.manual_seed_all(args.manual_seed) 73 | torch.set_printoptions(precision=4) 74 | 75 | 76 | if args.dataset == 'STL-10': 77 | num_classes = 10 78 | trainset = torchvision.datasets.STL10(root=args.data, split='train', download=True, 79 | transform=transforms.Compose([ 80 | transforms.RandomResizedCrop(32), 
81 | transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 84 | ])) 85 | 86 | testset = torchvision.datasets.STL10(root=args.data, split='test', download=True, 87 | transform=transforms.Compose([ 88 | transforms.Resize(32), 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 91 | ])) 92 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 93 | pin_memory=(torch.cuda.is_available())) 94 | 95 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 96 | pin_memory=(torch.cuda.is_available())) 97 | 98 | elif args.dataset == 'TinyImageNet': 99 | num_classes = 200 100 | train_set = torchvision.datasets.ImageFolder( 101 | root=os.path.join(args.data, "train"), 102 | transform=transforms.Compose([ 103 | transforms.RandomResizedCrop(32), 104 | transforms.RandomHorizontalFlip(), 105 | transforms.ToTensor(), 106 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 107 | ])) 108 | 109 | test_set = torchvision.datasets.ImageFolder( 110 | root=os.path.join(args.data, "val"), 111 | transform=transforms.Compose([ 112 | transforms.Resize(32), 113 | transforms.ToTensor(), 114 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 115 | ])) 116 | 117 | trainloader = torch.utils.data.DataLoader( 118 | train_set, batch_size=args.batch_size, shuffle=True, 119 | num_workers=args.num_workers, pin_memory=True) 120 | 121 | testloader = torch.utils.data.DataLoader( 122 | test_set, batch_size=args.batch_size, shuffle=False, 123 | num_workers=args.num_workers, pin_memory=True) 124 | # -------------------------------------------------------------------------------------------- 125 | # Model 126 | print('==> Building model..') 127 | net = getattr(models, args.arch)(num_classes=num_classes) 128 | net.eval() 129 | resolution = (1, 3, 32, 32) 130 | print('Arch: %s, Params: %.2fM, Multi-adds: %.2fG' 131 | % (args.arch, cal_param_size(net)/1e6, cal_multi_adds(net, resolution)/1e9)) 132 | del(net) 133 | 134 | 135 | print('load pre-trained weights from: {}'.format(args.s_path)) 136 | 137 | model = getattr(models, args.arch) 138 | net = model(num_classes=num_classes).cuda() 139 | net_dict = net.state_dict() 140 | checkpoint = torch.load(args.s_path, 141 | map_location=torch.device('cpu'))['net'] 142 | for key in net.state_dict().keys(): 143 | if key.startswith('classifier'): 144 | continue 145 | net_dict[key] = checkpoint['backbone.'+key] 146 | 147 | 148 | net.load_state_dict(net_dict) 149 | optimizer = optim.SGD(net.classifier.parameters(), lr=args.init_lr, 150 | momentum=0.9, weight_decay=args.weight_decay, nesterov=True) 151 | net.eval() 152 | for name, p in net.named_parameters(): 153 | if name.startswith('classifier') is False: 154 | p.requires_grad = False 155 | 156 | net = torch.nn.DataParallel(net) 157 | cudnn.benchmark = True 158 | 159 | 160 | class DistillKL(nn.Module): 161 | """Distilling the Knowledge in a Neural Network""" 162 | def __init__(self, T): 163 | super(DistillKL, self).__init__() 164 | self.T = T 165 | 166 | def forward(self, y_s, y_t): 167 | p_s = F.log_softmax(y_s/self.T, dim=1) 168 | p_t = F.softmax(y_t/self.T, dim=1) 169 | loss = F.kl_div(p_s, p_t, reduction='batchmean') * (self.T**2) 170 | return loss 171 | 172 | 173 | def correct_num(output, target, topk=(1,)): 174 | """Computes the precision@k for the specified values of k""" 175 | maxk = max(topk) 176 | batch_size = target.size(0) 
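    # output.topk below returns the indices of the maxk largest logits per
    # sample; eq() against the broadcast targets gives a (batch, maxk) boolean
    # matrix whose first k columns are summed to count the top-k hits.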
177 | 178 | _, pred = output.topk(maxk, 1, True, True) 179 | correct = pred.eq(target.view(-1, 1).expand_as(pred)) 180 | 181 | res = [] 182 | for k in topk: 183 | correct_k = correct[:, :k].float().sum() 184 | res.append(correct_k) 185 | return res 186 | 187 | 188 | def adjust_lr(optimizer, epoch, args, step=0, all_iters_per_epoch=0, eta_min=0.): 189 | cur_lr = 0. 190 | if epoch < args.warmup_epoch: 191 | cur_lr = args.init_lr * float(1 + step + epoch*all_iters_per_epoch)/(args.warmup_epoch *all_iters_per_epoch) 192 | else: 193 | epoch = epoch - args.warmup_epoch 194 | cur_lr = args.init_lr * 0.1 ** bisect_right(args.milestones, epoch) 195 | 196 | for idx, param_group in enumerate(optimizer.param_groups): 197 | param_group['lr'] = cur_lr 198 | return cur_lr 199 | 200 | 201 | # Training 202 | def train(epoch, criterion_list, optimizer): 203 | train_loss = 0. 204 | train_loss_cls = 0. 205 | 206 | top1_num = 0 207 | top5_num = 0 208 | total = 0 209 | 210 | if epoch >= args.warmup_epoch: 211 | lr = adjust_lr(optimizer, epoch, args) 212 | 213 | start_time = time.time() 214 | criterion_cls = criterion_list[0] 215 | criterion_div = criterion_list[1] 216 | 217 | net.module.classifier.train() 218 | for batch_idx, (input, target) in enumerate(trainloader): 219 | batch_start_time = time.time() 220 | input = input.float().cuda() 221 | target = target.cuda() 222 | 223 | if epoch < args.warmup_epoch: 224 | lr = adjust_lr(optimizer, epoch, args, batch_idx, len(trainloader)) 225 | 226 | optimizer.zero_grad() 227 | logits = net(input) 228 | 229 | loss_cls = criterion_cls(logits, target) 230 | 231 | loss = loss_cls 232 | loss.backward() 233 | optimizer.step() 234 | 235 | train_loss += loss.item() / len(trainloader) 236 | train_loss_cls += loss_cls.item() / len(trainloader) 237 | 238 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 239 | top1_num += top1 240 | top5_num += top5 241 | 242 | total += target.size(0) 243 | 244 | print('Epoch:{}, batch_idx:{}/{}, lr:{:.5f}, Duration:{:.2f}'.format(epoch, batch_idx, len(trainloader), lr, time.time()-batch_start_time)) 245 | 246 | with open(log_txt, 'a+') as f: 247 | f.write('Epoch:{}\t lr:{:.4f}\t duration:{:.3f}' 248 | '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}' 249 | '\nTrain Top-1 accuracy: {}\n' 250 | .format(epoch, lr, time.time() - start_time, 251 | train_loss, train_loss_cls, 252 | str(top1_num.item()/total))) 253 | 254 | 255 | 256 | def test(epoch, criterion_cls): 257 | global best_acc 258 | test_loss_cls = 0. 259 | 260 | top1_num = 0 261 | top5_num = 0 262 | total = 0 263 | 264 | net.eval() 265 | with torch.no_grad(): 266 | for batch_idx, (inputs, target) in enumerate(testloader): 267 | batch_start_time = time.time() 268 | input, target = inputs.cuda(), target.cuda() 269 | 270 | logits = net(input) 271 | loss_cls = torch.tensor(0.).cuda() 272 | loss_cls = criterion_cls(logits, target) 273 | 274 | test_loss_cls += loss_cls.item()/ len(testloader) 275 | 276 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 277 | top1_num += top1 278 | top5_num += top5 279 | total += target.size(0) 280 | 281 | 282 | print('Epoch:{}, batch_idx:{}/{}, Duration:{:.2f}'.format(epoch, batch_idx, len(testloader), time.time()-batch_start_time)) 283 | 284 | with open(log_txt, 'a+') as f: 285 | f.write('test epoch:{}\t test_loss_cls:{:.5f}\nTest Top-1 accuracy: {}\n' 286 | .format(epoch, test_loss_cls, str(top1_num.item()/total))) 287 | 288 | return top1_num.item()/total 289 | 290 | 291 | if __name__ == '__main__': 292 | best_acc = 0. 
# best test accuracy 293 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 294 | criterion_cls = nn.CrossEntropyLoss() 295 | criterion_div = DistillKL(args.kd_T) 296 | 297 | if args.evaluate: 298 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 299 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 300 | map_location=torch.device('cpu')) 301 | 302 | net.module.load_state_dict(checkpoint['net']) 303 | best_acc = checkpoint['acc'] 304 | test(start_epoch, criterion_cls) 305 | else: 306 | criterion_list = nn.ModuleList([]) 307 | criterion_list.append(criterion_cls) # classification loss 308 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation 309 | criterion_list.cuda() 310 | 311 | for epoch in range(start_epoch, args.epochs): 312 | train(epoch, criterion_list, optimizer) 313 | acc = test(epoch, criterion_cls) 314 | if acc > best_acc: 315 | best_acc = acc 316 | 317 | with open(log_txt, 'a+') as f: 318 | f.write('best_acc:{}'.format(best_acc)) 319 | print('best_acc:{}'.format(best_acc)) 320 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 4 | Ported from 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | 15 | __all__ = ['resnet56_aux', 'resnet20_aux', 'resnet32x4_aux', 'resnet8x4_aux', 'resnet8', 'resnet8x4', 'resnet20'] 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 28 | super(BasicBlock, self).__init__() 29 | self.is_last = is_last 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | 40 | residual = x 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | preact = out 53 | out = F.relu(out) 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 61 | super(Bottleneck, self).__init__() 62 | self.is_last = is_last 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride
= stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | preact = out 93 | out = F.relu(out) 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | 99 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 100 | super(ResNet, self).__init__() 101 | # Model type specifies number of layers for CIFAR-10 model 102 | if block_name.lower() == 'basicblock': 103 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 104 | n = (depth - 2) // 6 105 | block = BasicBlock 106 | elif block_name.lower() == 'bottleneck': 107 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 108 | n = (depth - 2) // 9 109 | block = Bottleneck 110 | else: 111 | raise ValueError('block_name should be BasicBlock or Bottleneck') 112 | 113 | self.inplanes = num_filters[0] 114 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, 115 | bias=False) 116 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.layer1 = self._make_layer(block, num_filters[1], n) 119 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 120 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 121 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 122 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 127 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | 131 | def _make_layer(self, block, planes, blocks, stride=1): 132 | downsample = None 133 | if stride != 1 or self.inplanes != planes * block.expansion: 134 | downsample = nn.Sequential( 135 | nn.Conv2d(self.inplanes, planes * block.expansion, 136 | kernel_size=1, stride=stride, bias=False), 137 | nn.BatchNorm2d(planes * block.expansion), 138 | ) 139 | 140 | layers = list([]) 141 | layers.append(block(self.inplanes, planes, stride, downsample)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def get_feat_modules(self): 149 | feat_m = nn.ModuleList([]) 150 | feat_m.append(self.conv1) 151 | feat_m.append(self.bn1) 152 | feat_m.append(self.relu) 153 | feat_m.append(self.layer1) 154 | feat_m.append(self.layer2) 155 | feat_m.append(self.layer3) 156 | return feat_m 157 | 158 | def get_bn_before_relu(self): 159 | if isinstance(self.layer1[0], Bottleneck): 160 | bn1 = self.layer1[-1].bn3 161 | bn2 = self.layer2[-1].bn3 162 | bn3 = self.layer3[-1].bn3 163 | elif isinstance(self.layer1[0], BasicBlock): 164 | bn1 = self.layer1[-1].bn2 165 | bn2 = self.layer2[-1].bn2 166 | bn3 = self.layer3[-1].bn2 167 | else: 168 | raise NotImplementedError('ResNet unknown block error !!!') 169 | 170 | return [bn1, bn2, bn3] 171 | 172 | def forward(self, x, is_feat=False): 173 | x = self.conv1(x) 174 | x = self.bn1(x) 175 | x = self.relu(x) # 32x32 176 | f0 = x 177 | 178 | x = self.layer1(x) # 32x32 179 | f1
= x 180 | x = self.layer2(x) # 16x16 181 | f2 = x 182 | x = self.layer3(x) # 8x8 183 | f3 = x 184 | 185 | x = self.avgpool(x) 186 | x = x.view(x.size(0), -1) 187 | f4 = x 188 | x = self.fc(x) 189 | 190 | if is_feat: 191 | return [f1, f2, f3], x 192 | else: 193 | return x 194 | 195 | class Auxiliary_Classifier(nn.Module): 196 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=100): 197 | super(Auxiliary_Classifier, self).__init__() 198 | if block_name.lower() == 'basicblock': 199 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 200 | n = (depth - 2) // 6 201 | block = BasicBlock 202 | elif block_name.lower() == 'bottleneck': 203 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 204 | n = (depth - 2) // 9 205 | block = Bottleneck 206 | else: 207 | raise ValueError('block_name should be BasicBlock or Bottleneck') 208 | 209 | self.inplanes = num_filters[1] * block.expansion 210 | self.block_extractor1 = nn.Sequential(*[self._make_layer(block, num_filters[2], n, stride=2), 211 | self._make_layer(block, num_filters[3], n, stride=2)]) 212 | self.inplanes = num_filters[2] * block.expansion 213 | self.block_extractor2 = nn.Sequential(*[self._make_layer(block, num_filters[3], n, stride=2)]) 214 | self.inplanes = num_filters[3] * block.expansion 215 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, num_filters[3], n, stride=1)]) 216 | 217 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 218 | self.fc1 = nn.Linear(num_filters[3] * block.expansion, num_classes) 219 | self.fc2 = nn.Linear(num_filters[3] * block.expansion, num_classes) 220 | self.fc3 = nn.Linear(num_filters[3] * block.expansion, num_classes) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 225 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 226 | elif isinstance(m, nn.BatchNorm2d): 227 | m.weight.data.fill_(1) 228 | m.bias.data.zero_() 229 | elif isinstance(m, nn.Linear): 230 | m.bias.data.zero_() 231 | 232 | 233 | def _make_layer(self, block, planes, blocks, stride=1): 234 | downsample = None 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | downsample = nn.Sequential( 237 | nn.Conv2d(self.inplanes, planes * block.expansion, 238 | kernel_size=1, stride=stride, bias=False), 239 | nn.BatchNorm2d(planes * block.expansion), 240 | ) 241 | 242 | layers = list([]) 243 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) 244 | self.inplanes = planes * block.expansion 245 | for i in range(1, blocks): 246 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) 247 | 248 | return nn.Sequential(*layers) 249 | 250 | 251 | def forward(self, x): 252 | ss_logits = [] 253 | ss_feats = [] 254 | for i in range(len(x)): 255 | idx = i + 1 256 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 257 | out = self.avg_pool(out) 258 | out = out.view(out.size(0), -1) 259 | ss_feats.append(out) 260 | out = getattr(self, 'fc'+str(idx))(out) 261 | ss_logits.append(out) 262 | 263 | return ss_feats, ss_logits 264 | 265 | 266 | class ResNet_Auxiliary(nn.Module): 267 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 268 | super(ResNet_Auxiliary, self).__init__() 269 | self.backbone = ResNet(depth, num_filters, block_name, num_classes) 270 | self.auxiliary_classifier = Auxiliary_Classifier(depth, num_filters, block_name, num_classes=num_classes * 4) 271 | 272 | def forward(self, x, grad=False, att=False): 273 | feats, logit = self.backbone(x, is_feat=True) 274 | if grad is False: 275 | for i in range(len(feats)): 276 | feats[i] = feats[i].detach() 277 | ss_feats, ss_logits = self.auxiliary_classifier(feats) 278 | if att is False: 279 | return logit, ss_logits 280 | else: 281 | return logit, ss_logits, feats 282 | 283 | 284 | def resnet8(**kwargs): 285 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) 286 | 287 | 288 | def resnet14(**kwargs): 289 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 290 | 291 | def resnet20(**kwargs): 292 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 293 | 294 | def resnet20_aux(**kwargs): 295 | return ResNet_Auxiliary(20, [16, 16, 32, 64], 'basicblock', **kwargs) 296 | 297 | 298 | def resnet14x05(**kwargs): 299 | return ResNet(14, [8, 8, 16, 32], 'basicblock', **kwargs) 300 | 301 | def resnet20x05(**kwargs): 302 | return ResNet(20, [8, 8, 16, 32], 'basicblock', **kwargs) 303 | 304 | def resnet20x0375(**kwargs): 305 | return ResNet(20, [6, 6, 12, 24], 'basicblock', **kwargs) 306 | 307 | 308 | 309 | def resnet32(**kwargs): 310 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 311 | 312 | 313 | def resnet44(**kwargs): 314 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 315 | 316 | 317 | def resnet56(**kwargs): 318 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 319 | 320 | def resnet56_aux(**kwargs): 321 | return ResNet_Auxiliary(56, [16, 16, 32, 64], 'basicblock', **kwargs) 322 | 323 | def resnet110(**kwargs): 324 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 325 | 326 | 327 | def resnet8x4(**kwargs): 328 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 329 | 330 | def resnet8x4_aux(**kwargs): 331 | return ResNet_Auxiliary(8, [32, 64, 128, 256], 'basicblock', **kwargs) 332 | 333 | def resnet32x4(**kwargs): 
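    # "x4" marks the widened channel plan [32, 64, 128, 256] (vs. the base [16, 16, 32, 64]).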
334 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 335 | 336 | 337 | def resnet32x4_aux(**kwargs): 338 | return ResNet_Auxiliary(32, [32, 64, 128, 256], 'basicblock', **kwargs) 339 | 340 | 341 | if __name__ == '__main__': 342 | import torch 343 | x = torch.randn(2, 3, 32, 32) 344 | net = resnet32x4_aux(num_classes=100) 345 | logit, ss_logits = net(x) 346 | print(logit.size()) 347 | 348 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = ['vgg13_bn_aux'] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | } 18 | 19 | 20 | class VGG(nn.Module): 21 | 22 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 23 | super(VGG, self).__init__() 24 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 25 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 26 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 27 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 28 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 29 | 30 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 31 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 32 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 35 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | 37 | self.classifier = nn.Linear(512, num_classes) 38 | self._initialize_weights() 39 | 40 | def get_feat_modules(self): 41 | feat_m = nn.ModuleList([]) 42 | feat_m.append(self.block0) 43 | feat_m.append(self.pool0) 44 | feat_m.append(self.block1) 45 | feat_m.append(self.pool1) 46 | feat_m.append(self.block2) 47 | feat_m.append(self.pool2) 48 | feat_m.append(self.block3) 49 | feat_m.append(self.pool3) 50 | feat_m.append(self.block4) 51 | feat_m.append(self.pool4) 52 | return feat_m 53 | 54 | def get_bn_before_relu(self): 55 | bn1 = self.block1[-1] 56 | bn2 = self.block2[-1] 57 | bn3 = self.block3[-1] 58 | bn4 = self.block4[-1] 59 | return [bn1, bn2, bn3, bn4] 60 | 61 | def forward(self, x, is_feat=False, preact=False): 62 | h = x.shape[2] 63 | x = F.relu(self.block0(x)) 64 | f0 = x 65 | 66 | x = self.pool0(x) 67 | x = self.block1(x) 68 | f1_pre = x 69 | x = F.relu(x) 70 | f1 = x 71 | 72 | x = self.pool1(x) 73 | x = self.block2(x) 74 | f2_pre = x 75 | x = F.relu(x) 76 | f2 = x 77 | 78 | 79 | x = self.pool2(x) 80 | x = self.block3(x) 81 | f3_pre = x 82 | x = F.relu(x) 83 | if h == 64: 84 | x = self.pool3(x) 85 | x = self.block4(x) 86 | f4_pre = x 87 | x = F.relu(x) 88 | f3 = x 89 | 90 | x = self.pool4(x) 91 | x = x.view(x.size(0), -1) 92 | x = self.classifier(x) 93 | 94 | if is_feat: 95 | return [f0, f1, f2, f3], x 96 | else: 97 | return x 98 | 99 | @staticmethod 100 | def _make_layers(cfg, batch_norm=False, in_channels=3): 101 | layers = [] 102 | for v in cfg: 103 | if v == 'M': 104 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 105 | else: 106 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 
107 | if batch_norm: 108 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 109 | else: 110 | layers += [conv2d, nn.ReLU(inplace=True)] 111 | in_channels = v 112 | layers = layers[:-1] 113 | return nn.Sequential(*layers) 114 | 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. / n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | 131 | 132 | class Auxiliary_Classifier(nn.Module): 133 | def __init__(self, cfg, batch_norm=False, num_classes=100): 134 | super(Auxiliary_Classifier, self).__init__() 135 | 136 | self.block_extractor1 = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=2), 137 | self._make_layers(cfg[1], batch_norm, cfg[0][-1]), 138 | nn.ReLU(inplace=True), 139 | nn.MaxPool2d(kernel_size=2, stride=2), 140 | self._make_layers(cfg[2], batch_norm, cfg[1][-1]), 141 | nn.ReLU(inplace=True), 142 | nn.MaxPool2d(kernel_size=2, stride=2), 143 | self._make_layers(cfg[3], batch_norm, cfg[2][-1]), 144 | nn.ReLU(inplace=True), 145 | self._make_layers(cfg[4], batch_norm, cfg[3][-1]), 146 | nn.ReLU(inplace=True), 147 | nn.AdaptiveAvgPool2d((1, 1))]) 148 | 149 | self.block_extractor2 = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=2), 150 | self._make_layers(cfg[2], batch_norm, cfg[1][-1]), 151 | nn.ReLU(inplace=True), 152 | nn.MaxPool2d(kernel_size=2, stride=2), 153 | self._make_layers(cfg[3], batch_norm, cfg[2][-1]), 154 | nn.ReLU(inplace=True), 155 | self._make_layers(cfg[4], batch_norm, cfg[3][-1]), 156 | nn.ReLU(inplace=True), 157 | nn.AdaptiveAvgPool2d((1, 1))]) 158 | 159 | self.block_extractor3 = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=2), 160 | self._make_layers(cfg[3], batch_norm, cfg[2][-1]), 161 | nn.ReLU(inplace=True), 162 | self._make_layers(cfg[4], batch_norm, cfg[3][-1]), 163 | nn.ReLU(inplace=True), 164 | nn.AdaptiveAvgPool2d((1, 1))]) 165 | 166 | self.block_extractor4 = nn.Sequential(*[self._make_layers(cfg[3], batch_norm, cfg[4][-1]), 167 | nn.ReLU(inplace=True), 168 | self._make_layers(cfg[4], batch_norm, cfg[3][-1]), 169 | nn.ReLU(inplace=True), 170 | nn.AdaptiveAvgPool2d((1, 1))]) 171 | 172 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 173 | self.fc1 = nn.Linear(512, num_classes) 174 | self.fc2 = nn.Linear(512, num_classes) 175 | self.fc3 = nn.Linear(512, num_classes) 176 | self.fc4 = nn.Linear(512, num_classes) 177 | 178 | def _initialize_weights(self): 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 183 | if m.bias is not None: 184 | m.bias.data.zero_() 185 | elif isinstance(m, nn.BatchNorm2d): 186 | m.weight.data.fill_(1) 187 | m.bias.data.zero_() 188 | elif isinstance(m, nn.Linear): 189 | n = m.weight.size(1) 190 | m.weight.data.normal_(0, 0.01) 191 | m.bias.data.zero_() 192 | 193 | 194 | @staticmethod 195 | def _make_layers(cfg, batch_norm=False, in_channels=3): 196 | layers = [] 197 | for v in cfg: 198 | if v == 'M': 199 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 200 | else: 201 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 202 | if batch_norm: 203 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 204 | else: 205 | layers += [conv2d, nn.ReLU(inplace=True)] 206 | in_channels = v 207 | layers = layers[:-1] 208 | return nn.Sequential(*layers) 209 | 210 | 211 | def forward(self, x): 212 | ss_logits = [] 213 | for i in range(len(x)): 214 | idx = i + 1 215 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 216 | out = out.view(-1, 512) 217 | out = getattr(self, 'fc'+str(idx))(out) 218 | ss_logits.append(out) 219 | return ss_logits 220 | 221 | 222 | class VGG_Auxiliary(nn.Module): 223 | def __init__(self, cfg, batch_norm=False, num_classes=100): 224 | super(VGG_Auxiliary, self).__init__() 225 | self.backbone = VGG(cfg, batch_norm=batch_norm, num_classes=num_classes) 226 | self.auxiliary_classifier = Auxiliary_Classifier(cfg, batch_norm=batch_norm, num_classes=num_classes * 4) 227 | 228 | def forward(self, x, grad=False): 229 | feats, logit = self.backbone(x, is_feat=True) 230 | if grad is False: 231 | for i in range(len(feats)): 232 | feats[i] = feats[i].detach() 233 | ss_logits = self.auxiliary_classifier(feats) 234 | return logit, ss_logits 235 | 236 | 237 | cfg = { 238 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 239 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 240 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 241 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 242 | 'S': [[64], [128], [256], [512], [512]], 243 | } 244 | 245 | 246 | def vgg8(**kwargs): 247 | """VGG 8-layer model (configuration "S") 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = VGG(cfg['S'], **kwargs) 252 | return model 253 | 254 | 255 | def vgg8_bn(**kwargs): 256 | """VGG 8-layer model (configuration "S") 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | """ 260 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 261 | return model 262 | 263 | 264 | def vgg8_bn_aux(**kwargs): 265 | """VGG 8-layer model (configuration "S") 266 | Args: 267 | pretrained (bool): If True, returns a model pre-trained on ImageNet 268 | """ 269 | model = VGG_Auxiliary(cfg['S'], batch_norm=True, **kwargs) 270 | return model 271 | 272 | 273 | def vgg11(**kwargs): 274 | """VGG 11-layer model (configuration "A") 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | """ 278 | model = VGG(cfg['A'], **kwargs) 279 | return model 280 | 281 | 282 | def vgg11_bn(**kwargs): 283 | """VGG 11-layer model (configuration "A") with batch normalization""" 284 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 285 | return model 286 | 287 | 288 | def vgg13(**kwargs): 289 | """VGG 13-layer model (configuration "B") 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | """ 293 | model = VGG(cfg['B'], **kwargs) 294 | 
return model 295 | 296 | 297 | def vgg13_bn(**kwargs): 298 | """VGG 13-layer model (configuration "B") with batch normalization""" 299 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 300 | return model 301 | 302 | def vgg13_bn_aux(**kwargs): 303 | """VGG 13-layer model (configuration "B") with batch normalization""" 304 | model = VGG_Auxiliary(cfg['B'], batch_norm=True, **kwargs) 305 | return model 306 | 307 | 308 | def vgg16(**kwargs): 309 | """VGG 16-layer model (configuration "D") 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | """ 313 | model = VGG(cfg['D'], **kwargs) 314 | return model 315 | 316 | 317 | def vgg16_bn(**kwargs): 318 | """VGG 16-layer model (configuration "D") with batch normalization""" 319 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 320 | return model 321 | 322 | 323 | def vgg19(**kwargs): 324 | """VGG 19-layer model (configuration "E") 325 | Args: 326 | pretrained (bool): If True, returns a model pre-trained on ImageNet 327 | """ 328 | model = VGG(cfg['E'], **kwargs) 329 | return model 330 | 331 | 332 | def vgg19_bn(**kwargs): 333 | """VGG 19-layer model (configuration 'E') with batch normalization""" 334 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 335 | return model 336 | 337 | 338 | if __name__ == '__main__': 339 | import torch 340 | x = torch.randn(2, 3, 32, 32) 341 | net = vgg13_bn_aux(num_classes=100) 342 | logit, ss_logits = net(x) 343 | print(logit.size()) 344 | -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | __all__ = ['resnet18_imagenet', 'resnet18_imagenet_aux', 'resnet34_imagenet', 7 | 'resnet34_imagenet_aux', 'resnet50_imagenet','resnet50_imagenet_aux'] 8 | 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=dilation, groups=groups, bias=False, dilation=dilation) 15 | 16 | 17 | def conv1x1(in_planes, out_planes, stride=1): 18 | """1x1 convolution""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 26 | base_width=64, dilation=1, norm_layer=None): 27 | super(BasicBlock, self).__init__() 28 | if norm_layer is None: 29 | norm_layer = nn.BatchNorm2d 30 | if groups != 1 or base_width != 64: 31 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 32 | if dilation > 1: 33 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 34 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = norm_layer(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = norm_layer(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | 56 | out += identity 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class 
Bottleneck(nn.Module): 63 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 64 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 65 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 66 | # This variant is also known as ResNet V1.5 and improves accuracy according to 67 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 68 | 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 72 | base_width=64, dilation=1, norm_layer=None): 73 | super(Bottleneck, self).__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | width = int(planes * (base_width / 64.)) * groups 77 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 78 | self.conv1 = conv1x1(inplanes, width) 79 | self.bn1 = norm_layer(width) 80 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 81 | self.bn2 = norm_layer(width) 82 | self.conv3 = conv1x1(width, planes * self.expansion) 83 | self.bn3 = norm_layer(planes * self.expansion) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv3(out) 100 | out = self.bn3(out) 101 | 102 | if self.downsample is not None: 103 | identity = self.downsample(x) 104 | 105 | out += identity 106 | out = self.relu(out) 107 | 108 | return out 109 | 110 | 111 | class ResNet(nn.Module): 112 | 113 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 114 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 115 | norm_layer=None): 116 | super(ResNet, self).__init__() 117 | if norm_layer is None: 118 | norm_layer = nn.BatchNorm2d 119 | self._norm_layer = norm_layer 120 | 121 | self.inplanes = 64 122 | self.dilation = 1 123 | if replace_stride_with_dilation is None: 124 | # each element in the tuple indicates if we should replace 125 | # the 2x2 stride with a dilated convolution instead 126 | replace_stride_with_dilation = [False, False, False] 127 | if len(replace_stride_with_dilation) != 3: 128 | raise ValueError("replace_stride_with_dilation should be None " 129 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 130 | self.groups = groups 131 | self.base_width = width_per_group 132 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 133 | bias=False) 134 | self.bn1 = norm_layer(self.inplanes) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0]) 138 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 139 | dilate=replace_stride_with_dilation[0]) 140 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 141 | dilate=replace_stride_with_dilation[1]) 142 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 143 | dilate=replace_stride_with_dilation[2]) 144 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 145 | self.fc = nn.Linear(512 * block.expansion, num_classes) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 150 | elif 
isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | # Zero-initialize the last BN in each residual branch, 155 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 156 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 157 | if zero_init_residual: 158 | for m in self.modules(): 159 | if isinstance(m, Bottleneck): 160 | nn.init.constant_(m.bn3.weight, 0) 161 | elif isinstance(m, BasicBlock): 162 | nn.init.constant_(m.bn2.weight, 0) 163 | 164 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 165 | norm_layer = self._norm_layer 166 | downsample = None 167 | previous_dilation = self.dilation 168 | if dilate: 169 | self.dilation *= stride 170 | stride = 1 171 | if stride != 1 or self.inplanes != planes * block.expansion: 172 | downsample = nn.Sequential( 173 | conv1x1(self.inplanes, planes * block.expansion, stride), 174 | norm_layer(planes * block.expansion), 175 | ) 176 | 177 | layers = [] 178 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 179 | self.base_width, previous_dilation, norm_layer)) 180 | self.inplanes = planes * block.expansion 181 | for _ in range(1, blocks): 182 | layers.append(block(self.inplanes, planes, groups=self.groups, 183 | base_width=self.base_width, dilation=self.dilation, 184 | norm_layer=norm_layer)) 185 | 186 | return nn.Sequential(*layers) 187 | 188 | def forward(self, x, is_feat=False): 189 | # See note [TorchScript super()] 190 | x = self.conv1(x) 191 | x = self.bn1(x) 192 | x = self.relu(x) 193 | x = self.maxpool(x) 194 | 195 | x = self.layer1(x) 196 | f1 = x 197 | x = self.layer2(x) 198 | f2 = x 199 | x = self.layer3(x) 200 | f3 = x 201 | x = self.layer4(x) 202 | f4 = x 203 | 204 | x = self.avgpool(x) 205 | x = torch.flatten(x, 1) 206 | x = self.fc(x) 207 | 208 | if is_feat: 209 | return [f1, f2, f3, f4], x 210 | else: 211 | return x 212 | 213 | 214 | class Auxiliary_Classifier(nn.Module): 215 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 216 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 217 | norm_layer=None): 218 | super(Auxiliary_Classifier, self).__init__() 219 | 220 | self.dilation = 1 221 | self.groups = groups 222 | self.base_width = width_per_group 223 | self.inplanes = 64 * block.expansion 224 | self.block_extractor1 = nn.Sequential(*[self._make_layer(block, 128, layers[1], stride=2), 225 | self._make_layer(block, 256, layers[2], stride=2), 226 | self._make_layer(block, 512, layers[3], stride=2)]) 227 | 228 | self.inplanes = 128 * block.expansion 229 | self.block_extractor2 = nn.Sequential(*[self._make_layer(block, 256, layers[2], stride=2), 230 | self._make_layer(block, 512, layers[3], stride=2)]) 231 | 232 | self.inplanes = 256 * block.expansion 233 | self.block_extractor3 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=2)]) 234 | 235 | self.inplanes = 512 * block.expansion 236 | self.block_extractor4 = nn.Sequential(*[self._make_layer(block, 512, layers[3], stride=1)]) 237 | 238 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 239 | self.fc1 = nn.Linear(512 * block.expansion, num_classes) 240 | self.fc2 = nn.Linear(512 * block.expansion, num_classes) 241 | self.fc3 = nn.Linear(512 * block.expansion, num_classes) 242 | self.fc4 = nn.Linear(512 * block.expansion, num_classes) 243 | 244 | for m in self.modules(): 245 | if isinstance(m, nn.Conv2d): 246 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 247 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 248 | nn.init.constant_(m.weight, 1) 249 | nn.init.constant_(m.bias, 0) 250 | 251 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 252 | norm_layer = nn.BatchNorm2d 253 | downsample = None 254 | previous_dilation = self.dilation 255 | if dilate: 256 | self.dilation *= stride 257 | stride = 1 258 | if stride != 1 or self.inplanes != planes * block.expansion: 259 | downsample = nn.Sequential( 260 | conv1x1(self.inplanes, planes * block.expansion, stride), 261 | norm_layer(planes * block.expansion), 262 | ) 263 | 264 | layers = [] 265 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 266 | self.base_width, previous_dilation, norm_layer)) 267 | self.inplanes = planes * block.expansion 268 | for _ in range(1, blocks): 269 | layers.append(block(self.inplanes, planes, groups=self.groups, 270 | base_width=self.base_width, dilation=self.dilation, 271 | norm_layer=norm_layer)) 272 | 273 | return nn.Sequential(*layers) 274 | 275 | 276 | def forward(self, x): 277 | ss_logits = [] 278 | for i in range(len(x)): 279 | idx = i + 1 280 | 281 | out = getattr(self, 'block_extractor'+str(idx))(x[i]) 282 | out = self.avg_pool(out) 283 | out = out.view(out.size(0), -1) 284 | out = getattr(self, 'fc'+str(idx))(out) 285 | ss_logits.append(out) 286 | return ss_logits 287 | 288 | 289 | class ResNet_Auxiliary(nn.Module): 290 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 291 | super(ResNet_Auxiliary, self).__init__() 292 | self.backbone = ResNet(block, layers, num_classes=num_classes, zero_init_residual=zero_init_residual) 293 | self.auxiliary_classifier = Auxiliary_Classifier(block, layers, num_classes=num_classes*4, zero_init_residual=zero_init_residual) 294 | 295 | def forward(self, x, grad=False): 296 | if grad is False: 297 | feats, logit = self.backbone(x, is_feat=True) 298 | for i in range(len(feats)): 299 | feats[i] = feats[i].detach() 300 | else: 301 | feats, logit = self.backbone(x, is_feat=True) 302 | 303 | ss_logits = self.auxiliary_classifier(feats) 304 | return logit, ss_logits 305 | 306 | 307 | def resnet18_imagenet(**kwargs): 308 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 309 | 310 | def resnet18_imagenet_aux(**kwargs): 311 | return ResNet_Auxiliary(BasicBlock, [2, 2, 2, 2], **kwargs) 312 | 313 | def resnet34_imagenet(**kwargs): 314 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 315 | def resnet34_imagenet_aux(**kwargs): 316 | return ResNet_Auxiliary(BasicBlock, [3, 4, 6, 3], **kwargs) 317 | 318 | def resnet50_imagenet(**kwargs): 319 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 320 | 321 | def resnet50_imagenet_aux(**kwargs): 322 | return ResNet_Auxiliary(Bottleneck, [3, 4, 6, 3], **kwargs) 323 | 324 | 325 | if __name__ == '__main__': 326 | net = resnet34_imagenet_aux(num_classes=1000) 327 | from utils import cal_param_size, cal_multi_adds 328 | print('Params: %.2fM, Multi-adds: %.3fM' 329 | % (cal_param_size(net) / 1e6, cal_multi_adds(net, (2, 3, 224, 224)) / 1e6)) 330 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Hierarchical Self-supervised Augmented Knowledge Distillation 3 | 4 | - This project provides source code for our Hierarchical Self-supervised Augmented Knowledge Distillation (HSAKD). 
5 | 6 | - This paper is publicly available at the official IJCAI proceedings: [https://www.ijcai.org/proceedings/2021/0168.pdf](https://www.ijcai.org/proceedings/2021/0168.pdf) 7 | 8 | - Our poster presentation is publicly available at [765_IJCAI_poster.pdf](https://github.com/winycg/HSAKD/tree/main/poster/765_IJCAI_poster.pdf) 9 | 10 | - Our slides of the oral presentation are publicly available at [765_IJCAI_slides.pdf](https://github.com/winycg/HSAKD/tree/main/poster/765_IJCAI_slides.pdf) 11 | 12 | - 🏆 __SOTA of Knowledge Distillation for student ResNet-18 trained by teacher ResNet-34 on ImageNet.__ 13 | 14 | - The extended version of HSAKD is accepted by IEEE Transactions on Neural Networks and Learning Systems. 15 | 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hierarchical-self-supervised-augmented/knowledge-distillation-on-imagenet)](https://paperswithcode.com/sota/knowledge-distillation-on-imagenet?p=hierarchical-self-supervised-augmented) 17 | 18 | ## Installation 19 | 20 | ### Requirements 21 | 22 | Ubuntu 18.04 LTS 23 | 24 | Python 3.8 ([Anaconda](https://www.anaconda.com/) is recommended) 25 | 26 | CUDA 11.1 27 | 28 | PyTorch 1.6.0 29 | 30 | NCCL for CUDA 11.1 31 | 32 | 33 | ## Perform Offline KD experiments on CIFAR-100 dataset 34 | #### Dataset 35 | CIFAR-100 : [download](http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz) 36 | 37 | unzip to the `./data` folder 38 | 39 | #### Training baselines 40 | ``` 41 | python train_baseline_cifar.py --arch wrn_16_2 --data ./data/ --gpu 0 42 | ``` 43 | More commands for training various architectures can be found in [train_baseline_cifar.sh](https://github.com/winycg/HSAKD/blob/main/train_baseline_cifar.sh) 44 | 45 | #### Training teacher networks 46 | (1) Use a pre-trained backbone and train all auxiliary classifiers. 47 | 48 | The pre-trained backbone weights are the .pth files downloaded from the repositories of [CRD](https://github.com/HobbitLong/RepDistiller) and [SSKD](https://github.com/xuguodong03/SSKD). 49 | 50 | You should download them from [Google Drive](https://drive.google.com/drive/folders/18hhFrGtJmpJ8J54yCI6KmgmD17xyqkjg?usp=sharing) before training a teacher network that needs a pre-trained backbone. 51 | ``` 52 | python train_teacher_cifar.py \ 53 | --arch wrn_40_2_aux \ 54 | --milestones 30 60 90 --epochs 100 \ 55 | --checkpoint-dir ./checkpoint \ 56 | --data ./data \ 57 | --gpu 2 --manual 0 \ 58 | --pretrained-backbone ./pretrained_backbones/wrn_40_2.pth \ 59 | --freezed 60 | ``` 61 | 62 | 63 | More commands for training various teacher networks with frozen backbones can be found in [train_teacher_freezed.sh](https://github.com/winycg/HSAKD/blob/main/train_teacher_freezed.sh) 64 | 65 | The pre-trained teacher networks can be downloaded from [Google Drive](https://drive.google.com/drive/folders/10t6ehp_9qL8iXTLL2k7gAQeyHMCCwnxD?usp=sharing) 66 | 67 |
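For reference, every `*_aux` architecture used in these commands wraps a backbone with hierarchical auxiliary classifiers and returns a pair `(logit, ss_logits)`, where each auxiliary head predicts `4 * num_classes` joint classes (see `models/resnet.py`). A minimal shape check (illustrative; reading the x4 factor as the four self-supervised augmented views follows the paper):

```
import torch
import models

net = models.resnet56_aux(num_classes=100)   # backbone + 3 hierarchical auxiliary classifiers
logit, ss_logits = net(torch.randn(2, 3, 32, 32))
print(logit.shape)                   # torch.Size([2, 100]): ordinary class logits
print([s.shape for s in ss_logits])  # three tensors of torch.Size([2, 400]) = 4 * 100 joint classes
```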
68 | (2) Train the backbone and all auxiliary classifiers jointly from scratch. In this case, we no longer need a pre-trained teacher backbone. 69 | 70 | In our empirical study, this leads to a better accuracy for the teacher backbone. 71 | ``` 72 | python train_teacher_cifar.py \ 73 | --arch wrn_40_2_aux \ 74 | --checkpoint-dir ./checkpoint \ 75 | --data ./data \ 76 | --gpu 2 --manual 1 77 | ``` 78 | 79 | The pre-trained teacher networks can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1TIzxjUQ1MKUjdZux5EBoaLI3QbuFnDyy?usp=sharing) 80 | 81 | To differentiate (1) and (2), we use `--manual 0` to indicate case (1) and `--manual 1` to indicate case (2). 82 | #### Training student networks 83 | (1) Train baselines of student networks 84 | ``` 85 | python train_baseline_cifar.py --arch wrn_16_2 --data ./data/ --gpu 0 86 | ``` 87 | More commands for training various student architectures can be found in [train_baseline_cifar.sh](https://github.com/winycg/HSAKD/blob/main/train_baseline_cifar.sh) 88 | 89 | (2) Train student networks with a pre-trained teacher network 90 | 91 | Note that the specific teacher network should be pre-trained before training the student networks. 92 | 93 | ``` 94 | python train_student_cifar.py \ 95 | --tarch wrn_40_2_aux \ 96 | --arch wrn_16_2_aux \ 97 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed0/wrn_40_2_aux.pth.tar \ 98 | --checkpoint-dir ./checkpoint \ 99 | --data ./data \ 100 | --gpu 0 --manual 0 101 | ``` 102 | 103 | More commands for training various teacher-student pairs can be found in [train_student_cifar.sh](https://github.com/winycg/HSAKD/blob/main/train_student_cifar.sh) 104 |
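As a reading aid, the student objective implied by the commands above combines the task cross-entropy with `DistillKL` terms on both the ordinary class logits and every hierarchical auxiliary branch. The sketch below is illustrative only; the function name and the equal weighting of all terms are assumptions, not the exact code of `train_student_cifar.py`:

```
def hsakd_student_loss(s_logit, s_ss_logits, t_logit, t_ss_logits,
                       target, criterion_cls, criterion_div):
    # criterion_cls: nn.CrossEntropyLoss, criterion_div: DistillKL (both defined in this repo)
    loss = criterion_cls(s_logit, target)                    # task loss
    loss = loss + criterion_div(s_logit, t_logit.detach())   # KD on class logits
    for s_q, t_q in zip(s_ss_logits, t_ss_logits):           # hierarchical auxiliary branches
        loss = loss + criterion_div(s_q, t_q.detach())       # KD on the 4*C-way augmented logits
    return loss
```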
105 | #### Results of the same architecture style between teacher and student networks 106 | 107 | |Teacher <br> Student | WRN-40-2 <br> WRN-16-2 | WRN-40-2 <br> WRN-40-1 | ResNet-56 <br> ResNet-20 | ResNet32x4 <br> ResNet8x4 | 108 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:| 109 | | Teacher <br> Teacher* | 76.45 <br> 80.70 | 76.45 <br> 80.70 | 73.44 <br> 77.20 | 79.63 <br> 83.73 | 110 | | Student | 73.57±0.23 | 71.95±0.59 | 69.62±0.26 | 72.95±0.24 | 111 | | HSAKD | 77.20±0.17 | 77.00±0.21 | 72.58±0.33 | 77.26±0.14 | 112 | | HSAKD* | **78.67**±0.20 | **78.12**±0.25 | **73.73**±0.10 | **77.69**±0.05 | 113 | 114 | #### Results of different architecture styles between teacher and student networks 115 | 116 | |Teacher <br> Student | VGG13 <br> MobileNetV2 | ResNet50 <br> MobileNetV2 | WRN-40-2 <br> ShuffleNetV1 | ResNet32x4 <br> ShuffleNetV2 | 117 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:| 118 | | Teacher <br> Teacher* | 74.64 <br> 78.48 | 76.34 <br> 83.85 | 76.45 <br> 80.70 | 79.63 <br> 83.73 | 119 | | Student | 73.51±0.26 | 73.51±0.26 | 71.74±0.35 | 72.96±0.33 | 120 | | HSAKD | 77.45±0.21 | 78.79±0.11 | 78.51±0.20 | 79.93±0.11 | 121 | | HSAKD* | **79.27**±0.12 | **79.43**±0.24 | **80.11**±0.32 | **80.86**±0.15 | 122 | 123 | - `Teacher` : training teacher networks by (1). 124 | - `Teacher*` : training teacher networks by (2). 125 | - `HSAKD` : training student networks by `Teacher`. 126 | - `HSAKD*` : training student networks by `Teacher*`. 127 | 128 | #### Training student networks under few-shot scenario 129 | ``` 130 | python train_student_few_shot.py \ 131 | --tarch resnet56_aux \ 132 | --arch resnet20_aux \ 133 | --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet56_aux_dataset_cifar100_seed0/resnet56_aux.pth.tar \ 134 | --checkpoint-dir ./checkpoint \ 135 | --data ./data/ \ 136 | --few-ratio 0.25 \ 137 | --gpu 2 --manual 0 138 | ``` 139 | 140 | `--few-ratio`: the percentage of training samples used 141 | 142 | | Percentage | 25% | 50% | 75% | 100% | 143 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:-----------------:| 144 | | Student | 68.50±0.24 | 72.18±0.41 | 73.26±0.11 | 73.73±0.10 | 145 | 146 | 147 | ## Perform transfer experiments on STL-10 and TinyImageNet dataset 148 | ### Dataset 149 | STL-10: [download](http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz) 150 | 151 | unzip to the `./data` folder 152 | 153 | TinyImageNet : [download](http://tiny-imagenet.herokuapp.com/) 154 | 155 | unzip to the `./data` folder 156 | 157 | Prepare the TinyImageNet validation dataset as follows 158 | ``` 159 | cd data 160 | python preprocess_tinyimagenet.py 161 | ``` 162 | #### Linear classification on STL-10 163 | ``` 164 | python eval_rep.py \ 165 | --arch mobilenetV2 \ 166 | --dataset STL-10 \ 167 | --data ./data/ \ 168 | --s-path ./checkpoint/train_student_cifar_tarch_vgg13_bn_aux_arch_mobilenetV2_aux_dataset_cifar100_seed0/mobilenetV2_aux.pth.tar 169 | ``` 170 | #### Linear classification on TinyImageNet 171 | ``` 172 | python eval_rep.py \ 173 | --arch mobilenetV2 \ 174 | --dataset TinyImageNet \ 175 | --data ./data/tiny-imagenet-200/ \ 176 | --s-path ./checkpoint/train_student_cifar_tarch_vgg13_bn_aux_arch_mobilenetV2_aux_dataset_cifar100_seed0/mobilenetV2_aux.pth.tar 177 | ``` 178 | | Transferred Dataset | CIFAR-100 → STL-10 | CIFAR-100 → TinyImageNet | 179 | |:---------------:|:-----------------:|:-----------------:| 180 | | Student | 74.66 | 42.57 | 181 | 182 | ## Perform Offline KD experiments on ImageNet dataset 183 | 184 | ### Dataset preparation 185 | 186 | - Download the ImageNet dataset to YOUR_IMAGENET_PATH and move validation images to labeled subfolders 187 | - The [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) may be helpful. 188 | 189 | - Create a datasets subfolder and a symlink to the ImageNet dataset 190 | 191 | ``` 192 | $ ln -s PATH_TO_YOUR_IMAGENET ./data/ 193 | ``` 194 | Folder of ImageNet Dataset: 195 | ``` 196 | data/ImageNet 197 | ├── train 198 | ├── val 199 | ``` 200 |
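Optionally, a quick sanity check (not part of this repo) that the symlinked folder has the layout torchvision expects, before launching a long distributed run:

```
import torchvision.datasets as datasets

# Both splits should expose 1000 class subfolders once valprep.sh has been applied.
train_set = datasets.ImageFolder('./data/ImageNet/train')
val_set = datasets.ImageFolder('./data/ImageNet/val')
print(len(train_set.classes), len(val_set.classes))  # expected: 1000 1000
```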
201 | ### Training teacher networks 202 | (1) Use a pre-trained backbone and train all auxiliary classifiers. 203 | 204 | The pre-trained backbone weights of ResNet-34 are the official PyTorch weights [resnet34-333f7ec4.pth](https://download.pytorch.org/models/resnet34-333f7ec4.pth). 205 | ``` 206 | python train_teacher_imagenet.py \ 207 | --dist-url 'tcp://127.0.0.1:55515' \ 208 | --data ./data/ImageNet/ \ 209 | --dist-backend 'nccl' \ 210 | --multiprocessing-distributed \ 211 | --checkpoint-dir ./checkpoint/ \ 212 | --pretrained-backbone ./pretrained_backbones/resnet34-333f7ec4.pth \ 213 | --freezed \ 214 | --gpu 0,1,2,3,4,5,6,7 \ 215 | --world-size 1 --rank 0 --manual_seed 0 216 | ``` 217 | 218 | (2) Train the backbone and all auxiliary classifiers jointly from scratch. In this case, we no longer need a pre-trained teacher backbone. 219 | 220 | In our empirical study, this leads to a better accuracy for the teacher backbone. 221 | ``` 222 | python train_teacher_imagenet.py \ 223 | --dist-url 'tcp://127.0.0.1:2222' \ 224 | --data ./data/ImageNet/ \ 225 | --dist-backend 'nccl' \ 226 | --multiprocessing-distributed \ 227 | --checkpoint-dir ./checkpoint/ \ 228 | --gpu 0,1,2,3,4,5,6,7 \ 229 | --world-size 1 --rank 0 --manual_seed 1 230 | ``` 231 | 232 | ### Training student networks 233 | 234 | (1) Using the teacher network with a frozen backbone 235 | ``` 236 | python train_student_imagenet.py \ 237 | --data ./data/ImageNet/ \ 238 | --arch resnet18_imagenet_aux \ 239 | --tarch resnet34_imagenet_aux \ 240 | --tcheckpoint ./checkpoint/train_teacher_imagenet_arch_resnet34_aux_dataset_imagenet_seed0/resnet34_imagenet_aux_best.pth.tar \ 241 | --dist-url 'tcp://127.0.0.1:2222' \ 242 | --dist-backend 'nccl' \ 243 | --multiprocessing-distributed \ 244 | --gpu-id 0,1,2,3,4,5,6,7 \ 245 | --world-size 1 --rank 0 --manual_seed 0 246 | ``` 247 | 248 | (2) Using the jointly trained teacher network 249 | ``` 250 | python train_student_imagenet.py \ 251 | --data ./data/ImageNet/ \ 252 | --arch resnet18_imagenet_aux \ 253 | --tarch resnet34_imagenet_aux \ 254 | --tcheckpoint ./checkpoint/train_teacher_imagenet_arch_resnet34_aux_dataset_imagenet_seed1/resnet34_imagenet_aux_best.pth.tar \ 255 | --dist-url 'tcp://127.0.0.1:2222' \ 256 | --dist-backend 'nccl' \ 257 | --multiprocessing-distributed \ 258 | --gpu-id 0,1,2,3,4,5,6,7 \ 259 | --world-size 1 --rank 0 --manual_seed 1 260 | ``` 261 | 262 | #### Results on the teacher-student pair of ResNet-34 and ResNet-18 263 | 264 | | Accuracy |Teacher |Teacher* | Student | HSAKD | HSAKD* | 265 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:|:--------------------:| 266 | | Top-1 | 73.31 | 75.48 | 69.75 | 72.16 | **72.39** | 267 | | Top-5 | 91.42 | 92.67 | 89.07 | 90.85 | **91.00** | 268 | | Pretrained Models | [resnet34_0](https://drive.google.com/file/d/1FojZDTafcQj4-vu9TKuq_ss36mbCp3UJ/view?usp=sharing) | [resnet34_1](https://drive.google.com/file/d/1LtZSKtAVr30xn8Yb3FIm-xd5HTsSAwX1/view?usp=sharing) | [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth) | [resnet18_0](https://drive.google.com/file/d/1jqoNEAkNgHpX6HS6nyWegMfHMisl2020/view?usp=sharing)| [resnet18_1](https://drive.google.com/file/d/1O5yM-3rJvsU6nrAqCncVMNSZ6HsJTKQN/view?usp=sharing) | 269 | 270 | ## Perform Online Mutual KD experiments on CIFAR-100 dataset 271 | 272 | Online Mutual KD trains two networks of the same architecture to teach each other, as sketched below.
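A minimal sketch of one training step (illustrative only; the actual peer networks live in `models/mutual_kd_net.py`, and only the class-logit exchange is shown here):

```
def online_mutual_kd_step(net_a, net_b, x, target, criterion_cls, criterion_div):
    # Two peers of the same architecture train simultaneously; each adds a
    # DistillKL term toward the other's softened, detached predictions.
    logit_a = net_a(x)
    logit_b = net_b(x)
    loss_a = criterion_cls(logit_a, target) + criterion_div(logit_a, logit_b.detach())
    loss_b = criterion_cls(logit_b, target) + criterion_div(logit_b, logit_a.detach())
    return loss_a + loss_b
```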
More commands for training various student architectures can be found in [train_online_kd_cifar.sh](https://github.com/winycg/HSAKD/blob/main/train_online_kd_cifar.sh) 273 | 274 | | Network | Baseline | HSSAKD (Online) | 275 | |:---------------:|:-----------------:|:-----------------:| 276 | | WRN-40-2 | 76.44±0.20 | 82.58±0.21 | 277 | | WRN-40-1 | 71.95±0.59 | 76.67±0.41 | 278 | | ResNet-56 | 73.00±0.17 | 78.16±0.56 | 279 | | ResNet-32x4 | 79.56±0.23 | 84.91±0.19 | 280 | | VGG-13 | 75.35±0.21 | 80.44±0.05 | 281 | | MobileNetV2 | 73.51±0.26 | 78.85±0.13 | 282 | | ShuffleNetV1 | 71.74±0.35 | 78.34±0.03 | 283 | | ShuffleNetV2 | 72.96±0.33 | 79.98±0.12 | 284 | 285 | ## Perform Online Mutual KD experiments on ImageNet dataset 286 | 287 | ``` 288 | python train_online_kd_imagenet.py \ 289 | --data ./data/ImageNet/ \ 290 | --arch resnet18_resnet18 \ 291 | --dist-url 'tcp://127.0.0.1:2222' \ 292 | --dist-backend 'nccl' \ 293 | --multiprocessing-distributed \ 294 | --gpu-id 0,1,2,3,4,5,6,7 \ 295 | --world-size 1 --rank 0 296 | ``` 297 | 298 | | Network | Baseline | HSSAKD (Online) | 299 | |:---------------:|:-----------------:|:-----------------:| 300 | | ResNet-18 | 69.75 | 71.49 | 301 | 302 | 303 | 304 | ## Citation 305 | 306 | ``` 307 | @inproceedings{yang2021hsakd, 308 | title={Hierarchical Self-supervised Augmented Knowledge Distillation}, 309 | author={Chuanguang Yang and Zhulin An and Linhang Cai and Yongjun Xu}, 310 | booktitle={Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence (IJCAI)}, 311 | pages = {1217--1223}, 312 | year={2021} 313 | } 314 | 315 | @article{yang2022hssakd, 316 | author={Yang, Chuanguang and An, Zhulin and Cai, Linhang and Xu, Yongjun}, 317 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 318 | title={Knowledge Distillation Using Hierarchical Self-Supervision Augmented Distribution}, 319 | year={2022}, 320 | pages={1-15}, 321 | doi={10.1109/TNNLS.2022.3186807}} 322 | ``` 323 |
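Both CIFAR training scripts below define the same `DistillKL` criterion: a temperature-softened KL divergence scaled by T². A quick numeric check of its behavior (a standalone re-implementation for illustration, using the scripts' default `--kd_T 3`):

```
import torch
import torch.nn.functional as F

def distill_kl(y_s, y_t, T=3.0):
    # Same computation as the DistillKL module in the scripts below.
    p_s = F.log_softmax(y_s / T, dim=1)
    p_t = F.softmax(y_t / T, dim=1)
    return F.kl_div(p_s, p_t, reduction='batchmean') * (T ** 2)

teacher = torch.tensor([[5.0, 1.0, 1.0]])
print(distill_kl(teacher, teacher).item())            # 0.0: identical logits give no KD signal
print(distill_kl(torch.zeros(1, 3), teacher).item())  # > 0: a uniform student is pushed toward the teacher
```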
-------------------------------------------------------------------------------- /train_baseline_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.functional as F 6 | 7 | import os 8 | import shutil 9 | import argparse 10 | import numpy as np 11 | 12 | 13 | import models 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | from utils import cal_param_size, cal_multi_adds 17 | 18 | 19 | from bisect import bisect_right 20 | import time 21 | import math 22 | 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 25 | parser.add_argument('--data', default='./data/', type=str, help='Dataset directory') 26 | parser.add_argument('--dataset', default='cifar100', type=str, help='Dataset name') 27 | parser.add_argument('--arch', default='resnet18', type=str, help='network architecture') 28 | parser.add_argument('--init-lr', default=0.05, type=float, help='learning rate') 29 | parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay') 30 | parser.add_argument('--lr-type', default='multistep', type=str, help='learning rate strategy') 31 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210], help='milestones for lr-multistep') 32 | parser.add_argument('--sgdr-t', default=300, type=int, dest='sgdr_t', help='SGDR T_0') 33 | parser.add_argument('--warmup-epoch', default=0, type=int, help='warmup epoch') 34 | parser.add_argument('--epochs', type=int, default=240, help='number of epochs to train') 35 | parser.add_argument('--batch-size', type=int, default=64, help='batch size') 36 | parser.add_argument('--num-workers', type=int, default=8, help='number of workers') 37 | parser.add_argument('--gpu-id', type=str, default='0') 38 | parser.add_argument('--manual_seed', type=int, default=0) 39 | parser.add_argument('--kd_T', type=float, default=3, help='temperature for KD distillation') 40 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 41 | parser.add_argument('--evaluate', '-e', action='store_true', help='evaluate model') 42 | parser.add_argument('--checkpoint-dir', default='./checkpoint', type=str, help='checkpoint dir') 43 | 44 | 45 | # global hyperparameter set 46 | args = parser.parse_args() 47 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 48 | 49 | log_txt = 'result/'+ str(os.path.basename(__file__).split('.')[0]) + '_'+\ 50 | 'arch' + '_' + args.arch + '_'+\ 51 | 'dataset' + '_' + args.dataset + '_'+\ 52 | 'seed'+ str(args.manual_seed) +'.txt' 53 | 54 | log_dir = str(os.path.basename(__file__).split('.')[0]) + '_'+\ 55 | 'arch'+ '_' + args.arch + '_'+\ 56 | 'dataset' + '_' + args.dataset + '_'+\ 57 | 'seed'+ str(args.manual_seed) 58 | 59 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, log_dir) 60 | if not os.path.isdir(args.checkpoint_dir): 61 | os.makedirs(args.checkpoint_dir) 62 | 63 | if args.resume is False and args.evaluate is False: 64 | with open(log_txt, 'a+') as f: 65 | f.seek(0) 66 | f.truncate() 67 | 68 | np.random.seed(args.manual_seed) 69 | torch.manual_seed(args.manual_seed) 70 | torch.cuda.manual_seed_all(args.manual_seed) 71 | torch.set_printoptions(precision=4) 72 | # ----------------------------------------------------------------------------------------- 73 | 74 | num_classes = 100 75 | trainset = torchvision.datasets.CIFAR100(root=args.data, train=True, download=True, 76 | transform=transforms.Compose([ 77 | transforms.RandomCrop(32, padding=4), 78 | transforms.RandomHorizontalFlip(), 79 | transforms.ToTensor(), 80 | transforms.Normalize([0.5071, 0.4867, 0.4408], 81 | [0.2675, 0.2565, 0.2761]) 82 | ])) 83 | 84 | testset = torchvision.datasets.CIFAR100(root=args.data, train=False, download=True, 85 | transform=transforms.Compose([ 86 | transforms.ToTensor(), 87 | transforms.Normalize([0.5071, 0.4867, 0.4408], 88 | [0.2675, 0.2565, 0.2761]), 89 | ])) 90 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 91 | num_workers=args.num_workers, pin_memory=(torch.cuda.is_available())) 92 | 93 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 94 | num_workers=args.num_workers, pin_memory=(torch.cuda.is_available())) 95 | 96 | # -------------------------------------------------------------------------------------------- 97 | 98 | # Model 99 | print('==> Building model..') 100 | model = getattr(models, args.arch) 101 | net = model(num_classes=num_classes) 102 | net.eval() 103 | if args.dataset.startswith('cifar'): 104 | resolution = (1, 3, 32, 32) 105 | elif args.dataset.startswith('imagenet'): 106 | resolution = (1, 3, 224, 224) 107 | print('Arch: %s, Params: %.2fM, Multi-adds: %.2fG' 108 | % (args.arch, cal_param_size(net)/1e6, cal_multi_adds(net, resolution)/1e9)) 109 | del(net) 110 | 111 | 112 | net = model(num_classes=num_classes).cuda() 113 | net = torch.nn.DataParallel(net) 114 | cudnn.benchmark = True 115 | 116 | 117 | class DistillKL(nn.Module): 118 | """Distilling the Knowledge
in a Neural Network""" 119 | def __init__(self, T): 120 | super(DistillKL, self).__init__() 121 | self.T = T 122 | 123 | def forward(self, y_s, y_t): 124 | p_s = F.log_softmax(y_s/self.T, dim=1) 125 | p_t = F.softmax(y_t/self.T, dim=1) 126 | loss = F.kl_div(p_s, p_t, reduction='batchmean') * (self.T**2) 127 | return loss 128 | 129 | 130 | def correct_num(output, target, topk=(1,)): 131 | """Computes the precision@k for the specified values of k""" 132 | maxk = max(topk) 133 | batch_size = target.size(0) 134 | 135 | _, pred = output.topk(maxk, 1, True, True) 136 | correct = pred.eq(target.view(-1, 1).expand_as(pred)) 137 | 138 | res = [] 139 | for k in topk: 140 | correct_k = correct[:, :k].float().sum() 141 | res.append(correct_k) 142 | return res 143 | 144 | 145 | def adjust_lr(optimizer, epoch, args, step=0, all_iters_per_epoch=0, eta_min=0.): 146 | cur_lr = 0. 147 | if epoch < args.warmup_epoch: 148 | cur_lr = args.init_lr * float(1 + step + epoch*all_iters_per_epoch)/(args.warmup_epoch *all_iters_per_epoch) 149 | else: 150 | epoch = epoch - args.warmup_epoch 151 | if args.lr_type == 'SGDR': 152 | i = int(math.log2(epoch / args.sgdr_t + 1)) 153 | T_cur = epoch - args.sgdr_t * (2 ** (i) - 1) 154 | T_i = (args.sgdr_t * 2 ** i) 155 | 156 | cur_lr = eta_min + 0.5 * (args.init_lr - eta_min) * (1 + np.cos(np.pi * T_cur / T_i)) 157 | 158 | elif args.lr_type == 'multistep': 159 | cur_lr = args.init_lr * 0.1 ** bisect_right(args.milestones, epoch) 160 | 161 | for param_group in optimizer.param_groups: 162 | param_group['lr'] = cur_lr 163 | return cur_lr 164 | 165 | 166 | def train(epoch, criterion_list, optimizer): 167 | train_loss = 0. 168 | train_loss_cls = 0. 169 | train_loss_div = 0. 170 | train_loss_kd = 0. 171 | 172 | top1_num = 0 173 | top5_num = 0 174 | total = 0 175 | 176 | if epoch >= args.warmup_epoch: 177 | lr = adjust_lr(optimizer, epoch, args) 178 | 179 | start_time = time.time() 180 | criterion_cls = criterion_list[0] 181 | criterion_div = criterion_list[1] 182 | 183 | net.train() 184 | for batch_idx, (input, target) in enumerate(trainloader): 185 | batch_start_time = time.time() 186 | input = input.float().cuda() 187 | target = target.cuda() 188 | 189 | if epoch < args.warmup_epoch: 190 | lr = adjust_lr(optimizer, epoch, args, batch_idx, len(trainloader)) 191 | 192 | optimizer.zero_grad() 193 | logits = net(input) 194 | 195 | loss_cls = criterion_cls(logits, target) 196 | 197 | loss = loss_cls 198 | loss.backward() 199 | optimizer.step() 200 | 201 | train_loss += loss.item() / len(trainloader) 202 | train_loss_cls += loss_cls.item() / len(trainloader) 203 | 204 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 205 | top1_num += top1 206 | top5_num += top5 207 | total += target.size(0) 208 | 209 | print('Epoch:{}, batch_idx:{}/{}, lr:{:.5f}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 210 | epoch, batch_idx, len(trainloader), lr, time.time()-batch_start_time, (top1_num/(total)).item())) 211 | 212 | 213 | acc1 = round((top1_num/total).item(), 4) 214 | acc5 = round((top5_num/total).item(), 4) 215 | 216 | 217 | 218 | with open(log_txt, 'a+') as f: 219 | f.write('Epoch:{}\t lr:{:.4f}\t duration:{:.3f}' 220 | '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}' 221 | '\nTop-1 accuracy: {} \n' 222 | .format(epoch, lr, time.time() - start_time, 223 | train_loss, train_loss_cls, 224 | str(acc1))) 225 | if args.dataset == 'imagenet': 226 | f.write('Top-5 accuracy:{}\n'.format(str(acc5))) 227 | 228 | 229 | 230 | def test(epoch, criterion_cls): 231 | net.eval() 232 | global best_acc 233 | 
test_loss_cls = 0. 234 | 235 | top1_num = 0 236 | top5_num = 0 237 | total = 0 238 | 239 | 240 | with torch.no_grad(): 241 | for batch_idx, (inputs, target) in enumerate(testloader): 242 | batch_start_time = time.time() 243 | inputs, target = inputs.cuda(), target.cuda() 244 | logits = net(inputs) 245 | 246 | loss_cls = criterion_cls(logits, target) 247 | 248 | test_loss_cls += loss_cls.item()/ len(testloader) 249 | 250 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 251 | top1_num += top1 252 | top5_num += top5 253 | total += target.size(0) 254 | 255 | print('Epoch:{}, batch_idx:{}/{}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 256 | epoch, batch_idx, len(testloader), time.time()-batch_start_time, (top1_num/(total)).item())) 257 | 258 | acc1 = round((top1_num/total).item(), 4) 259 | acc5 = round((top5_num/total).item(), 4) 260 | 261 | with open(log_txt, 'a+') as f: 262 | f.write('test epoch:{}\t test_loss_cls:{:.5f}\t Top-1 accuracy:{}\n' 263 | .format(epoch, test_loss_cls, str(acc1))) 264 | if args.dataset == 'imagenet': 265 | f.write('Top-5 accuracy:{}\n'.format(str(acc5))) 266 | print('test epoch:{}\t accuracy:{}\n'.format(epoch, str(acc1))) 267 | 268 | return acc1 269 | 270 | 271 | if __name__ == '__main__': 272 | best_acc = 0. # best test accuracy 273 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 274 | criterion_cls = nn.CrossEntropyLoss() 275 | criterion_div = DistillKL(args.kd_T) 276 | 277 | if args.evaluate: 278 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'))) 279 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'), 280 | map_location=torch.device('cpu')) 281 | net.module.load_state_dict(checkpoint['net']) 282 | best_acc = checkpoint['acc'] 283 | start_epoch = checkpoint['epoch'] + 1 284 | test(start_epoch, criterion_cls) 285 | else: 286 | trainable_list = nn.ModuleList([]) 287 | trainable_list.append(net) 288 | optimizer = optim.SGD(trainable_list.parameters(), 289 | lr=args.init_lr, momentum=0.9, weight_decay=args.weight_decay, nesterov=True) 290 | 291 | criterion_list = nn.ModuleList([]) 292 | criterion_list.append(criterion_cls) # classification loss 293 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation 294 | criterion_list.cuda() 295 | 296 | if args.resume: 297 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 298 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 299 | map_location=torch.device('cpu')) 300 | net.module.load_state_dict(checkpoint['net']) 301 | optimizer.load_state_dict(checkpoint['optimizer']) 302 | best_acc = checkpoint['acc'] 303 | start_epoch = checkpoint['epoch'] + 1 304 | 305 | for epoch in range(start_epoch, args.epochs): 306 | train(epoch, criterion_list, optimizer) 307 | acc = test(epoch, criterion_cls) 308 | 309 | state = { 310 | 'net': net.module.state_dict(), 311 | 'acc': acc, 312 | 'epoch': epoch, 313 | 'optimizer': optimizer.state_dict() 314 | } 315 | torch.save(state, os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar')) 316 | 317 | is_best = False 318 | if best_acc < acc: 319 | best_acc = acc 320 | is_best = True 321 | 322 | if is_best: 323 | shutil.copyfile(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 324 | os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar')) 325 | 326 | print('Evaluate the
best model:') 327 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'))) 328 | args.evaluate = True 329 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'), 330 | map_location=torch.device('cpu')) 331 | net.module.load_state_dict(checkpoint['net']) 332 | start_epoch = checkpoint['epoch'] 333 | top1_acc = test(start_epoch, criterion_cls) 334 | 335 | with open(log_txt, 'a+') as f: 336 | f.write('best_accuracy: {} \n'.format(best_acc)) 337 | print('best_accuracy: {} \n'.format(best_acc)) 338 | os.system('cp ' + log_txt + ' ' + args.checkpoint_dir) -------------------------------------------------------------------------------- /train_teacher_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.functional as F 6 | 7 | import os 8 | import shutil 9 | import argparse 10 | import numpy as np 11 | 12 | 13 | import models 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | from utils import cal_param_size, cal_multi_adds 17 | 18 | 19 | from bisect import bisect_right 20 | import time 21 | import math 22 | 23 | 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 26 | parser.add_argument('--data', default='./data/', type=str, help='Dataset directory') 27 | parser.add_argument('--dataset', default='cifar100', type=str, help='Dataset name') 28 | parser.add_argument('--arch', default='wrn_40_2_aux', type=str, help='network architecture') 29 | parser.add_argument('--init-lr', default=0.05, type=float, help='learning rate') 30 | parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay') 31 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210], help='milestones for lr-multistep') 32 | parser.add_argument('--warmup-epoch', default=0, type=int, help='warmup epoch') 33 | parser.add_argument('--epochs', type=int, default=240, help='number of epochs') 34 | parser.add_argument('--batch-size', type=int, default=64, help='batch size') 35 | parser.add_argument('--num-workers', type=int, default=8, help='the number of workers') 36 | parser.add_argument('--gpu-id', type=str, default='0') 37 | parser.add_argument('--manual_seed', type=int, default=0) 38 | parser.add_argument('--kd_T', type=float, default=3, help='temperature for KD distillation') 39 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 40 | parser.add_argument('--evaluate', '-e', action='store_true', help='evaluate model') 41 | parser.add_argument('--freezed', action='store_true', help='freezing backbone') 42 | parser.add_argument('--checkpoint-dir', default='./checkpoint', type=str, help='directory for storing checkpoint files') 43 | parser.add_argument('--pretrained-backbone', default='./pretrained_models/wrn_40_2_aux.pth', type=str, help='pretrained weights of backbone') 44 | 45 | 46 | args = parser.parse_args() 47 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 48 | 49 | log_txt = 'result/'+ str(os.path.basename(__file__).split('.')[0]) + '_'+\ 50 | 'arch' + '_' + args.arch + '_'+\ 51 | 'dataset' + '_' + args.dataset + '_'+\ 52 | 'seed'+ str(args.manual_seed) +'.txt' 53 | 54 | 55 | log_dir = str(os.path.basename(__file__).split('.')[0]) + '_'+\ 56 | 'arch'+ '_' + args.arch + '_'+\ 57 | 'dataset' + '_' + 
args.dataset + '_'+\ 58 | 'seed'+ str(args.manual_seed) 59 | 60 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, log_dir) 61 | if not os.path.isdir(args.checkpoint_dir): 62 | os.makedirs(args.checkpoint_dir) 63 | 64 | if args.resume is False and args.evaluate is False: 65 | with open(log_txt, 'a+') as f: 66 | f.write("==========\nArgs:{}\n==========".format(args) + '\n') 67 | 68 | np.random.seed(args.manual_seed) 69 | torch.manual_seed(args.manual_seed) 70 | torch.cuda.manual_seed_all(args.manual_seed) 71 | torch.set_printoptions(precision=4) 72 | 73 | 74 | 75 | num_classes = 100 76 | trainset = torchvision.datasets.CIFAR100(root=args.data, train=True, download=True, 77 | transform=transforms.Compose([ 78 | transforms.RandomCrop(32, padding=4), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.5071, 0.4867, 0.4408], 82 | [0.2675, 0.2565, 0.2761]) 83 | ])) 84 | 85 | testset = torchvision.datasets.CIFAR100(root=args.data, train=False, download=True, 86 | transform=transforms.Compose([ 87 | transforms.ToTensor(), 88 | transforms.Normalize([0.5071, 0.4867, 0.4408], 89 | [0.2675, 0.2565, 0.2761]), 90 | ])) 91 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 92 | pin_memory=(torch.cuda.is_available())) 93 | 94 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 95 | pin_memory=(torch.cuda.is_available())) 96 | 97 | 98 | print('==> Building model..') 99 | net = getattr(models, args.arch)(num_classes=num_classes) 100 | net.eval() 101 | resolution = (1, 3, 32, 32) 102 | print('Arch: %s, Params: %.2fM, Multi-adds: %.2fG' 103 | % (args.arch, cal_param_size(net)/1e6, cal_multi_adds(net, resolution)/1e9)) 104 | del(net) 105 | 106 | 107 | model = getattr(models, args.arch) 108 | net = model(num_classes=num_classes).cuda() 109 | optimizer = optim.SGD([{'params':net.backbone.parameters(), 'lr':args.init_lr}, 110 | {'params':net.auxiliary_classifier.parameters(), 'lr':args.init_lr}], 111 | momentum=0.9, weight_decay=args.weight_decay, nesterov=True) 112 | 113 | if args.freezed: 114 | print('load pre-trained backbone weights from: {}'.format(args.pretrained_backbone)) 115 | checkpoint = torch.load(args.pretrained_backbone, map_location=torch.device('cpu')) 116 | net.backbone.load_state_dict(checkpoint['state_dict']) 117 | for n,p in net.backbone.named_parameters(): 118 | p.requires_grad = False 119 | net.backbone.eval() 120 | 121 | 122 | net = torch.nn.DataParallel(net) 123 | _, ss_logits = net(torch.randn(2, 3, 32, 32)) 124 | num_auxiliary_branches = len(ss_logits) 125 | cudnn.benchmark = True 126 | 127 | 128 | class DistillKL(nn.Module): 129 | """Distilling the Knowledge in a Neural Network""" 130 | def __init__(self, T): 131 | super(DistillKL, self).__init__() 132 | self.T = T 133 | 134 | def forward(self, y_s, y_t): 135 | p_s = F.log_softmax(y_s/self.T, dim=1) 136 | p_t = F.softmax(y_t/self.T, dim=1) 137 | loss = F.kl_div(p_s, p_t, reduction='batchmean') * (self.T**2) 138 | return loss 139 | 140 | 141 | def correct_num(output, target, topk=(1,)): 142 | """Computes the precision@k for the specified values of k""" 143 | maxk = max(topk) 144 | batch_size = target.size(0) 145 | 146 | _, pred = output.topk(maxk, 1, True, True) 147 | correct = pred.eq(target.view(-1, 1).expand_as(pred)) 148 | 149 | res = [] 150 | for k in topk: 151 | correct_k = correct[:, :k].float().sum() 152 | res.append(correct_k) 153 | return res 154 | 155 | 156 | def adjust_lr(optimizer, 
epoch, args, step=0, all_iters_per_epoch=0, eta_min=0.): 157 | cur_lr = 0. 158 | if epoch < args.warmup_epoch: 159 | cur_lr = args.init_lr * float(1 + step + epoch*all_iters_per_epoch)/(args.warmup_epoch *all_iters_per_epoch) 160 | else: 161 | epoch = epoch - args.warmup_epoch 162 | cur_lr = args.init_lr * 0.1 ** bisect_right(args.milestones, epoch) 163 | 164 | for param_group in optimizer.param_groups: 165 | param_group['lr'] = cur_lr 166 | return cur_lr 167 | 168 | 169 | def train(epoch, criterion_list, optimizer): 170 | train_loss = 0. 171 | train_loss_cls = 0. 172 | train_loss_div = 0. 173 | train_loss_kd = 0. 174 | 175 | ss_top1_num = [0] * num_auxiliary_branches 176 | ss_top5_num = [0] * num_auxiliary_branches 177 | class_top1_num = [0] * num_auxiliary_branches 178 | class_top5_num = [0] * num_auxiliary_branches 179 | top1_num = 0 180 | top5_num = 0 181 | total = 0 182 | 183 | if epoch >= args.warmup_epoch: 184 | lr = adjust_lr(optimizer, epoch, args) 185 | 186 | start_time = time.time() 187 | criterion_cls = criterion_list[0] 188 | criterion_div = criterion_list[1] 189 | 190 | if args.freezed: 191 | net.module.auxiliary_classifier.train() 192 | else: 193 | net.train() 194 | for batch_idx, (input, target) in enumerate(trainloader): 195 | batch_start_time = time.time() 196 | input = input.float().cuda() 197 | target = target.cuda() 198 | 199 | size = input.shape[1:] 200 | input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size) 201 | labels = torch.stack([target*4+i for i in range(4)], 1).view(-1) 202 | 203 | if epoch < args.warmup_epoch: 204 | lr = adjust_lr(optimizer, epoch, args, batch_idx, len(trainloader)) 205 | 206 | optimizer.zero_grad() 207 | 208 | if args.freezed: 209 | logits, ss_logits = net(input, grad=False) 210 | else: 211 | logits, ss_logits = net(input, grad=True) 212 | 213 | 214 | loss_cls = torch.tensor(0.).cuda() 215 | for i in range(len(ss_logits)): 216 | loss_cls = loss_cls + criterion_cls(ss_logits[i], labels) 217 | loss_cls = loss_cls + criterion_cls(logits[0::4], target) 218 | 219 | 220 | loss = loss_cls 221 | loss.backward() 222 | optimizer.step() 223 | 224 | train_loss += loss.item() / len(trainloader) 225 | train_loss_cls += loss_cls.item() / len(trainloader) 226 | 227 | for i in range(len(ss_logits)): 228 | top1, top5 = correct_num(ss_logits[i], labels, topk=(1, 5)) 229 | ss_top1_num[i] += top1 230 | ss_top5_num[i] += top5 231 | 232 | class_logits = [torch.stack(torch.split(ss_logits[i], split_size_or_sections=4, dim=1), dim=1).sum(dim=2) for i in range(len(ss_logits))] 233 | multi_target = target.view(-1, 1).repeat(1, 4).view(-1) 234 | for i in range(len(class_logits)): 235 | top1, top5 = correct_num(class_logits[i], multi_target, topk=(1, 5)) 236 | class_top1_num[i] += top1 237 | class_top5_num[i] += top5 238 | 239 | logits = logits.view(-1, 4, num_classes)[:, 0, :] 240 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 241 | top1_num += top1 242 | top5_num += top5 243 | total += target.size(0) 244 | 245 | print('Epoch:{}, batch_idx:{}/{}, lr:{:.5f}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 246 | epoch, batch_idx, len(trainloader), lr, time.time()-batch_start_time, (top1_num/(total)).item())) 247 | 248 | 249 | ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] 250 | ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] 251 | class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + 
[round((top1_num/(total)).item(), 4)] 252 | class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)] 253 | 254 | print('Train epoch:{}\nTrain Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n'.format(epoch, str(ss_acc1), str(class_acc1))) 255 | 256 | with open(log_txt, 'a+') as f: 257 | f.write('Epoch:{}\t lr:{:.4f}\t duration:{:.3f}' 258 | '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}' 259 | '\nTrain Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n' 260 | .format(epoch, lr, time.time() - start_time, 261 | train_loss, train_loss_cls, 262 | str(ss_acc1), str(class_acc1))) 263 | 264 | 265 | 266 | def test(epoch, criterion_cls): 267 | global best_acc 268 | test_loss_cls = 0. 269 | 270 | ss_top1_num = [0] * (num_auxiliary_branches) 271 | ss_top5_num = [0] * (num_auxiliary_branches) 272 | class_top1_num = [0] * num_auxiliary_branches 273 | class_top5_num = [0] * num_auxiliary_branches 274 | top1_num = 0 275 | top5_num = 0 276 | total = 0 277 | 278 | net.eval() 279 | with torch.no_grad(): 280 | for batch_idx, (inputs, target) in enumerate(testloader): 281 | batch_start_time = time.time() 282 | input, target = inputs.cuda(), target.cuda() 283 | 284 | size = input.shape[1:] 285 | input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size) 286 | labels = torch.stack([target*4+i for i in range(4)], 1).view(-1) 287 | 288 | logits, ss_logits = net(input) 289 | loss_cls = torch.tensor(0.).cuda() 290 | for i in range(len(ss_logits)): 291 | loss_cls = loss_cls + criterion_cls(ss_logits[i], labels) 292 | loss_cls = loss_cls + criterion_cls(logits[0::4], target) 293 | 294 | test_loss_cls += loss_cls.item()/ len(testloader) 295 | 296 | batch_size = logits.size(0) // 4 297 | for i in range(len(ss_logits)): 298 | top1, top5 = correct_num(ss_logits[i], labels, topk=(1, 5)) 299 | ss_top1_num[i] += top1 300 | ss_top5_num[i] += top5 301 | 302 | class_logits = [torch.stack(torch.split(ss_logits[i], split_size_or_sections=4, dim=1), dim=1).sum(dim=2) for i in range(len(ss_logits))] 303 | multi_target = target.view(-1, 1).repeat(1, 4).view(-1) 304 | for i in range(len(class_logits)): 305 | top1, top5 = correct_num(class_logits[i], multi_target, topk=(1, 5)) 306 | class_top1_num[i] += top1 307 | class_top5_num[i] += top5 308 | 309 | logits = logits.view(-1, 4, num_classes)[:, 0, :] 310 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 311 | top1_num += top1 312 | top5_num += top5 313 | total += target.size(0) 314 | 315 | 316 | print('Epoch:{}, batch_idx:{}/{}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 317 | epoch, batch_idx, len(testloader), time.time()-batch_start_time, (top1_num/(total)).item())) 318 | 319 | ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(len(ss_logits))] 320 | ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(len(ss_logits))] 321 | class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top1_num/(total)).item(), 4)] 322 | class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)] 323 | with open(log_txt, 'a+') as f: 324 | f.write('test epoch:{}\t test_loss_cls:{:.5f}\nTop-1 ss_accuracy: {}\nTop-1 class_accuracy: {}\n' 325 | .format(epoch, test_loss_cls, str(ss_acc1), str(class_acc1))) 326 | print('test epoch:{}\tTest Top-1 ss_accuracy: {}\nTest Top-1 class_accuracy: {}\n'.format(epoch, 
str(ss_acc1), str(class_acc1))) 327 | 328 | return max(ss_acc1) 329 | 330 | 331 | if __name__ == '__main__': 332 | best_acc = 0. # best test accuracy 333 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 334 | criterion_cls = nn.CrossEntropyLoss() 335 | criterion_div = DistillKL(args.kd_T) 336 | 337 | if args.evaluate: 338 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 339 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 340 | map_location=torch.device('cpu')) 341 | net.module.load_state_dict(checkpoint['net']) 342 | best_acc = checkpoint['acc'] 343 | test(start_epoch, criterion_cls) 344 | else: 345 | criterion_list = nn.ModuleList([]) 346 | criterion_list.append(criterion_cls) # classification loss 347 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation 348 | criterion_list.cuda() 349 | 350 | if args.resume: 351 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 352 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 353 | map_location=torch.device('cpu')) 354 | net.module.load_state_dict(checkpoint['net']) 355 | optimizer.load_state_dict(checkpoint['optimizer']) 356 | best_acc = checkpoint['acc'] 357 | start_epoch = checkpoint['epoch'] + 1 358 | 359 | for epoch in range(start_epoch, args.epochs): 360 | train(epoch, criterion_list, optimizer) 361 | acc = test(epoch, criterion_cls) 362 | 363 | state = { 364 | 'net': net.module.state_dict(), 365 | 'acc': acc, 366 | 'epoch': epoch, 367 | 'optimizer': optimizer.state_dict() 368 | } 369 | torch.save(state, os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar')) 370 | 371 | is_best = False 372 | if best_acc < acc: 373 | best_acc = acc 374 | is_best = True 375 | 376 | if is_best: 377 | shutil.copyfile(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 378 | os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar')) 379 | 380 | print('Evaluate the best model:') 381 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'))) 382 | args.evaluate = True 383 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'), 384 | map_location=torch.device('cpu')) 385 | net.module.load_state_dict(checkpoint['net']) 386 | start_epoch = checkpoint['epoch'] 387 | top1_acc = test(start_epoch, criterion_cls) 388 | 389 | with open(log_txt, 'a+') as f: 390 | f.write('best_accuracy: {} \n'.format(best_acc)) 391 | print('best_accuracy: {} \n'.format(best_acc)) 392 | os.system('cp ' + log_txt + ' ' + args.checkpoint_dir) 393 | -------------------------------------------------------------------------------- /train_student_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.functional as F 6 | 7 | import os 8 | import shutil 9 | import argparse 10 | import numpy as np 11 | 12 | 13 | import models 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | from utils import cal_param_size, cal_multi_adds 17 | 18 | 19 | from bisect import bisect_right 20 | import time 21 | import math 22 | 23 | 24 | 25 | parser = 
argparse.ArgumentParser(description='PyTorch CIFAR Training') 26 | parser.add_argument('--data', default='./data/', type=str, help='Dataset directory') 27 | parser.add_argument('--dataset', default='cifar100', type=str, help='Dataset name') 28 | parser.add_argument('--arch', default='wrn_16_2_aux', type=str, help='student network architecture') 29 | parser.add_argument('--tarch', default='wrn_40_2_aux', type=str, help='teacher network architecture') 30 | parser.add_argument('--tcheckpoint', default='wrn_40_2_aux.pth.tar', type=str, help='pre-trained weights of teacher') 31 | parser.add_argument('--init-lr', default=0.05, type=float, help='learning rate') 32 | parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay') 33 | parser.add_argument('--lr-type', default='multistep', type=str, help='learning rate strategy') 34 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210], help='milestones for lr-multistep') 35 | parser.add_argument('--sgdr-t', default=300, type=int, dest='sgdr_t', help='SGDR T_0') 36 | parser.add_argument('--warmup-epoch', default=0, type=int, help='warmup epoch') 37 | parser.add_argument('--epochs', type=int, default=240, help='number of epochs to train') 38 | parser.add_argument('--batch-size', type=int, default=64, help='batch size') 39 | parser.add_argument('--num-workers', type=int, default=8, help='the number of workers') 40 | parser.add_argument('--gpu-id', type=str, default='0') 41 | parser.add_argument('--manual_seed', type=int, default=0) 42 | parser.add_argument('--kd_T', type=float, default=3, help='temperature for KD distillation') 43 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 44 | parser.add_argument('--evaluate', '-e', action='store_true', help='evaluate model') 45 | parser.add_argument('--checkpoint-dir', default='./checkpoint', type=str, help='directory for storing checkpoints') 46 | 47 | # global hyperparameter set 48 | args = parser.parse_args() 49 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 50 | 51 | log_txt = 'result/'+ str(os.path.basename(__file__).split('.')[0]) + '_'+\ 52 | 'tarch' + '_' + args.tarch + '_'+\ 53 | 'arch' + '_' + args.arch + '_'+\ 54 | 'dataset' + '_' + args.dataset + '_'+\ 55 | 'seed'+ str(args.manual_seed) +'.txt' 56 | 57 | log_dir = str(os.path.basename(__file__).split('.')[0]) + '_'+\ 58 | 'tarch' + '_' + args.tarch + '_'+\ 59 | 'arch'+ '_' + args.arch + '_'+\ 60 | 'dataset' + '_' + args.dataset + '_'+\ 61 | 'seed'+ str(args.manual_seed) 62 | 63 | args.checkpoint_dir = os.path.join(args.checkpoint_dir, log_dir) 64 | if not os.path.isdir(args.checkpoint_dir): 65 | os.makedirs(args.checkpoint_dir) 66 | 67 | if args.resume is False and args.evaluate is False: 68 | with open(log_txt, 'a+') as f: 69 | f.write("==========\nArgs:{}\n==========".format(args) + '\n') 70 | 71 | np.random.seed(args.manual_seed) 72 | torch.manual_seed(args.manual_seed) 73 | torch.cuda.manual_seed_all(args.manual_seed) 74 | torch.set_printoptions(precision=4) 75 | 76 | 77 | num_classes = 100 78 | trainset = torchvision.datasets.CIFAR100(root=args.data, train=True, download=True, 79 | transform=transforms.Compose([ 80 | transforms.RandomCrop(32, padding=4), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.5071, 0.4867, 0.4408], 84 | [0.2675, 0.2565, 0.2761]) 85 | ])) 86 | 87 | testset = torchvision.datasets.CIFAR100(root=args.data, train=False, download=True, 88 | transform=transforms.Compose([ 89 |
transforms.ToTensor(), 90 | transforms.Normalize([0.5071, 0.4867, 0.4408], 91 | [0.2675, 0.2565, 0.2761]), 92 | ])) 93 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 94 | pin_memory=(torch.cuda.is_available())) 95 | 96 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, 97 | pin_memory=(torch.cuda.is_available())) 98 | 99 | print('==> Building model..') 100 | net = getattr(models, args.tarch)(num_classes=num_classes) 101 | net.eval() 102 | resolution = (1, 3, 32, 32) 103 | print('Teacher Arch: %s, Params: %.2fM, Multi-adds: %.2fG' 104 | % (args.tarch, cal_param_size(net)/1e6, cal_multi_adds(net, resolution)/1e9)) 105 | del(net) 106 | net = getattr(models, args.arch)(num_classes=num_classes) 107 | net.eval() 108 | resolution = (1, 3, 32, 32) 109 | print('Student Arch: %s, Params: %.2fM, Multi-adds: %.2fG' 110 | % (args.arch, cal_param_size(net)/1e6, cal_multi_adds(net, resolution)/1e9)) 111 | del(net) 112 | 113 | 114 | print('load pre-trained teacher weights from: {}'.format(args.tcheckpoint)) 115 | checkpoint = torch.load(args.tcheckpoint, map_location=torch.device('cpu')) 116 | 117 | model = getattr(models, args.arch) 118 | net = model(num_classes=num_classes).cuda() 119 | net = torch.nn.DataParallel(net) 120 | 121 | tmodel = getattr(models, args.tarch) 122 | tnet = tmodel(num_classes=num_classes).cuda() 123 | tnet.load_state_dict(checkpoint['net']) 124 | tnet.eval() 125 | tnet = torch.nn.DataParallel(tnet) 126 | 127 | _, ss_logits = net(torch.randn(2, 3, 32, 32)) 128 | num_auxiliary_branches = len(ss_logits) 129 | cudnn.benchmark = True 130 | 131 | 132 | class DistillKL(nn.Module): 133 | """Distilling the Knowledge in a Neural Network""" 134 | def __init__(self, T): 135 | super(DistillKL, self).__init__() 136 | self.T = T 137 | 138 | def forward(self, y_s, y_t): 139 | p_s = F.log_softmax(y_s/self.T, dim=1) 140 | p_t = F.softmax(y_t/self.T, dim=1) 141 | loss = F.kl_div(p_s, p_t, reduction='batchmean') * (self.T**2) 142 | return loss 143 | 144 | 145 | def correct_num(output, target, topk=(1,)): 146 | """Computes the precision@k for the specified values of k""" 147 | maxk = max(topk) 148 | batch_size = target.size(0) 149 | 150 | _, pred = output.topk(maxk, 1, True, True) 151 | correct = pred.eq(target.view(-1, 1).expand_as(pred)) 152 | 153 | res = [] 154 | for k in topk: 155 | correct_k = correct[:, :k].float().sum() 156 | res.append(correct_k) 157 | return res 158 | 159 | 160 | def adjust_lr(optimizer, epoch, args, step=0, all_iters_per_epoch=0): 161 | cur_lr = 0. 162 | if epoch < args.warmup_epoch: 163 | cur_lr = args.init_lr * float(1 + step + epoch*all_iters_per_epoch)/(args.warmup_epoch *all_iters_per_epoch) 164 | else: 165 | epoch = epoch - args.warmup_epoch 166 | cur_lr = args.init_lr * 0.1 ** bisect_right(args.milestones, epoch) 167 | 168 | for param_group in optimizer.param_groups: 169 | param_group['lr'] = cur_lr 170 | return cur_lr 171 | 172 | 173 | def train(epoch, criterion_list, optimizer): 174 | train_loss = 0. 175 | train_loss_cls = 0. 176 | train_loss_div = 0. 
177 | 178 | ss_top1_num = [0] * num_auxiliary_branches 179 | ss_top5_num = [0] * num_auxiliary_branches 180 | class_top1_num = [0] * num_auxiliary_branches 181 | class_top5_num = [0] * num_auxiliary_branches 182 | top1_num = 0 183 | top5_num = 0 184 | total = 0 185 | 186 | if epoch >= args.warmup_epoch: 187 | lr = adjust_lr(optimizer, epoch, args) 188 | 189 | start_time = time.time() 190 | criterion_cls = criterion_list[0] 191 | criterion_div = criterion_list[1] 192 | 193 | net.train() 194 | for batch_idx, (input, target) in enumerate(trainloader): 195 | batch_start_time = time.time() 196 | input = input.float().cuda() 197 | target = target.cuda() 198 | 199 | size = input.shape[1:] 200 | input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size) 201 | labels = torch.stack([target*4+i for i in range(4)], 1).view(-1) 202 | 203 | if epoch < args.warmup_epoch: 204 | lr = adjust_lr(optimizer, epoch, args, batch_idx, len(trainloader)) 205 | 206 | optimizer.zero_grad() 207 | logits, ss_logits = net(input, grad=True) 208 | with torch.no_grad(): 209 | t_logits, t_ss_logits = tnet(input) 210 | 211 | loss_cls = torch.tensor(0.).cuda() 212 | loss_div = torch.tensor(0.).cuda() 213 | 214 | loss_cls = loss_cls + criterion_cls(logits[0::4], target) 215 | for i in range(len(ss_logits)): 216 | loss_div = loss_div + criterion_div(ss_logits[i], t_ss_logits[i].detach()) 217 | 218 | loss_div = loss_div + criterion_div(logits, t_logits.detach()) 219 | 220 | 221 | loss = loss_cls + loss_div 222 | loss.backward() 223 | optimizer.step() 224 | 225 | 226 | train_loss += loss.item() / len(trainloader) 227 | train_loss_cls += loss_cls.item() / len(trainloader) 228 | train_loss_div += loss_div.item() / len(trainloader) 229 | 230 | for i in range(len(ss_logits)): 231 | top1, top5 = correct_num(ss_logits[i], labels, topk=(1, 5)) 232 | ss_top1_num[i] += top1 233 | ss_top5_num[i] += top5 234 | 235 | class_logits = [torch.stack(torch.split(ss_logits[i], split_size_or_sections=4, dim=1), dim=1).sum(dim=2) for i in range(len(ss_logits))] 236 | multi_target = target.view(-1, 1).repeat(1, 4).view(-1) 237 | for i in range(len(class_logits)): 238 | top1, top5 = correct_num(class_logits[i], multi_target, topk=(1, 5)) 239 | class_top1_num[i] += top1 240 | class_top5_num[i] += top5 241 | 242 | logits = logits.view(-1, 4, num_classes)[:, 0, :] 243 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 244 | top1_num += top1 245 | top5_num += top5 246 | total += target.size(0) 247 | 248 | print('Epoch:{}, batch_idx:{}/{}, lr:{:.5f}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 249 | epoch, batch_idx, len(trainloader), lr, time.time()-batch_start_time, (top1_num/(total)).item())) 250 | 251 | 252 | ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] 253 | ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] 254 | class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top1_num/(total)).item(), 4)] 255 | class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)] 256 | 257 | print('Train epoch:{}\nTrain Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n'.format(epoch, str(ss_acc1), str(class_acc1))) 258 | 259 | with open(log_txt, 'a+') as f: 260 | f.write('Epoch:{}\t lr:{:.5f}\t duration:{:.3f}' 261 | '\n train_loss:{:.5f}\t train_loss_cls:{:.5f}\t train_loss_div:{:.5f}' 262 | '\nTrain 
Top-1 ss_accuracy: {}\nTrain Top-1 class_accuracy: {}\n' 263 | .format(epoch, lr, time.time() - start_time, 264 | train_loss, train_loss_cls, train_loss_div, 265 | str(ss_acc1), str(class_acc1))) 266 | 267 | 268 | 269 | def test(epoch, criterion_cls, net): 270 | global best_acc 271 | test_loss_cls = 0. 272 | 273 | ss_top1_num = [0] * (num_auxiliary_branches) 274 | ss_top5_num = [0] * (num_auxiliary_branches) 275 | class_top1_num = [0] * num_auxiliary_branches 276 | class_top5_num = [0] * num_auxiliary_branches 277 | top1_num = 0 278 | top5_num = 0 279 | total = 0 280 | 281 | net.eval() 282 | with torch.no_grad(): 283 | for batch_idx, (inputs, target) in enumerate(testloader): 284 | batch_start_time = time.time() 285 | input, target = inputs.cuda(), target.cuda() 286 | 287 | size = input.shape[1:] 288 | input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size) 289 | labels = torch.stack([target*4+i for i in range(4)], 1).view(-1) 290 | 291 | logits, ss_logits = net(input) 292 | loss_cls = torch.tensor(0.).cuda() 293 | loss_cls = loss_cls + criterion_cls(logits[0::4], target) 294 | 295 | test_loss_cls += loss_cls.item()/ len(testloader) 296 | 297 | batch_size = logits.size(0) // 4 298 | for i in range(len(ss_logits)): 299 | top1, top5 = correct_num(ss_logits[i], labels, topk=(1, 5)) 300 | ss_top1_num[i] += top1 301 | ss_top5_num[i] += top5 302 | 303 | class_logits = [torch.stack(torch.split(ss_logits[i], split_size_or_sections=4, dim=1), dim=1).sum(dim=2) for i in range(len(ss_logits))] 304 | multi_target = target.view(-1, 1).repeat(1, 4).view(-1) 305 | for i in range(len(class_logits)): 306 | top1, top5 = correct_num(class_logits[i], multi_target, topk=(1, 5)) 307 | class_top1_num[i] += top1 308 | class_top5_num[i] += top5 309 | 310 | logits = logits.view(-1, 4, num_classes)[:, 0, :] 311 | top1, top5 = correct_num(logits, target, topk=(1, 5)) 312 | top1_num += top1 313 | top5_num += top5 314 | total += target.size(0) 315 | 316 | 317 | print('Epoch:{}, batch_idx:{}/{}, Duration:{:.2f}, Top-1 Acc:{:.4f}'.format( 318 | epoch, batch_idx, len(testloader), time.time()-batch_start_time, (top1_num/(total)).item())) 319 | 320 | ss_acc1 = [round((ss_top1_num[i]/(total*4)).item(), 4) for i in range(len(ss_logits))] 321 | ss_acc5 = [round((ss_top5_num[i]/(total*4)).item(), 4) for i in range(len(ss_logits))] 322 | class_acc1 = [round((class_top1_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top1_num/(total)).item(), 4)] 323 | class_acc5 = [round((class_top5_num[i]/(total*4)).item(), 4) for i in range(num_auxiliary_branches)] + [round((top5_num/(total)).item(), 4)] 324 | with open(log_txt, 'a+') as f: 325 | f.write('test epoch:{}\t test_loss_cls:{:.5f}\nTop-1 ss_accuracy: {}\nTop-1 class_accuracy: {}\n' 326 | .format(epoch, test_loss_cls, str(ss_acc1), str(class_acc1))) 327 | print('test epoch:{}\nTest Top-1 ss_accuracy: {}\nTest Top-1 class_accuracy: {}\n'.format(epoch, str(ss_acc1), str(class_acc1))) 328 | 329 | return class_acc1[-1] 330 | 331 | 332 | if __name__ == '__main__': 333 | best_acc = 0. 
# best test accuracy 334 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 335 | criterion_cls = nn.CrossEntropyLoss() 336 | criterion_div = DistillKL(args.kd_T) 337 | 338 | if args.evaluate: 339 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 340 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 341 | map_location=torch.device('cpu')) 342 | net.module.load_state_dict(checkpoint['net']) 343 | best_acc = checkpoint['acc'] 344 | test(start_epoch, criterion_cls, net) 345 | else: 346 | print('Evaluate Teacher:') 347 | acc = test(0, criterion_cls, tnet) 348 | print('Teacher Acc:', acc) 349 | 350 | trainable_list = nn.ModuleList([]) 351 | trainable_list.append(net) 352 | optimizer = optim.SGD(trainable_list.parameters(), 353 | lr=0.1, momentum=0.9, weight_decay=args.weight_decay, nesterov=True) 354 | 355 | criterion_list = nn.ModuleList([]) 356 | criterion_list.append(criterion_cls) # classification loss 357 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation 358 | criterion_list.cuda() 359 | 360 | 361 | if args.resume: 362 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'))) 363 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 364 | map_location=torch.device('cpu')) 365 | net.module.load_state_dict(checkpoint['net']) 366 | optimizer.load_state_dict(checkpoint['optimizer']) 367 | best_acc = checkpoint['acc'] 368 | start_epoch = checkpoint['epoch'] + 1 369 | 370 | for epoch in range(start_epoch, args.epochs): 371 | train(epoch, criterion_list, optimizer) 372 | acc = test(epoch, criterion_cls, net) 373 | 374 | state = { 375 | 'net': net.module.state_dict(), 376 | 'acc': acc, 377 | 'epoch': epoch, 378 | 'optimizer': optimizer.state_dict() 379 | } 380 | torch.save(state, os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar')) 381 | 382 | is_best = False 383 | if best_acc < acc: 384 | best_acc = acc 385 | is_best = True 386 | 387 | if is_best: 388 | shutil.copyfile(os.path.join(args.checkpoint_dir, str(model.__name__) + '.pth.tar'), 389 | os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar')) 390 | 391 | print('Evaluate the best model:') 392 | print('load pre-trained weights from: {}'.format(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'))) 393 | args.evaluate = True 394 | checkpoint = torch.load(os.path.join(args.checkpoint_dir, str(model.__name__) + '_best.pth.tar'), 395 | map_location=torch.device('cpu')) 396 | net.module.load_state_dict(checkpoint['net']) 397 | start_epoch = checkpoint['epoch'] 398 | top1_acc = test(start_epoch, criterion_cls, net) 399 | 400 | with open(log_txt, 'a+') as f: 401 | f.write('best_accuracy: {} \n'.format(best_acc)) 402 | print('best_accuracy: {} \n'.format(best_acc)) 403 | os.system('cp ' + log_txt + ' ' + args.checkpoint_dir) 404 | --------------------------------------------------------------------------------
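The scripts above share a few compact but non-obvious mechanics; the short sketches below illustrate them in isolation. All tensors, sizes, and values in these sketches are hypothetical stand-ins chosen for illustration, not values taken from the repository. First, the rotation-based self-supervised augmentation used in train() and test(): every batch is expanded into its four 90-degree rotations and the labels are remapped into a joint 4*num_classes space, which is why the main-branch accuracy is later read off logits[0::4].

import torch

# Minimal shape sketch of the rotation expansion used in train() and test().
# The batch and labels here are hypothetical stand-ins.
input = torch.randn(2, 3, 32, 32)   # two CIFAR-sized images
target = torch.tensor([5, 17])      # their class labels

size = input.shape[1:]
input = torch.stack([torch.rot90(input, k, (2, 3)) for k in range(4)], 1).view(-1, *size)
labels = torch.stack([target * 4 + i for i in range(4)], 1).view(-1)

print(input.shape)   # torch.Size([8, 3, 32, 32]): four rotations per image, grouped per sample
print(labels)        # tensor([20, 21, 22, 23, 68, 69, 70, 71]): class*4 + rotation index
# logits[0::4] therefore selects the prediction on the unrotated copy of each image.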
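The class_logits line in both scripts undoes that expansion for accuracy bookkeeping: it marginalizes the 4*num_classes joint logits back to num_classes class logits by summing the four rotation bins of each class. A sketch with assumed sizes:

import torch

# Sketch (assumed sizes): collapse the 4*num_classes joint logits back to
# num_classes class logits by summing the four rotation bins of each class,
# matching the class_logits line in both scripts.
num_classes = 100
ss_logit = torch.randn(8, 4 * num_classes)   # joint logits for 8 rotated inputs

# split into num_classes chunks of width 4 (one per class), stack, sum over rotations
class_logit = torch.stack(torch.split(ss_logit, split_size_or_sections=4, dim=1), dim=1).sum(dim=2)
print(class_logit.shape)                     # torch.Size([8, 100])

# An equivalent reshape-based form, useful for checking the intent:
assert torch.allclose(class_logit, ss_logit.view(8, num_classes, 4).sum(dim=2))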
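DistillKL, defined identically in both scripts, is the classic Hinton-style distillation loss: softmax distributions softened by temperature T, batch-mean KL divergence, rescaled by T**2 so gradient magnitudes stay comparable across temperatures. A minimal sanity sketch with hypothetical logits:

import torch
import torch.nn.functional as F

# Sanity sketch of the DistillKL computation. Logits are hypothetical.
T = 3.0                      # the --kd_T default
y_s = torch.randn(8, 100)    # student logits
y_t = torch.randn(8, 100)    # teacher logits

p_s = F.log_softmax(y_s / T, dim=1)
p_t = F.softmax(y_t / T, dim=1)
loss = F.kl_div(p_s, p_t, reduction='batchmean') * (T ** 2)

# When the student matches the teacher exactly, the loss vanishes:
self_loss = F.kl_div(F.log_softmax(y_t / T, dim=1), p_t, reduction='batchmean') * (T ** 2)
assert self_loss.item() < 1e-5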
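adjust_lr implements linear warmup for the first --warmup-epoch epochs followed by multistep decay, where bisect_right counts how many milestones the current epoch has passed and each one multiplies the rate by 0.1. The decay part in isolation, under the default settings:

from bisect import bisect_right

# Sketch of the multistep part of adjust_lr with the default settings
# (init-lr 0.05, milestones [150, 180, 210], no warmup).
init_lr = 0.05
milestones = [150, 180, 210]

def multistep_lr(epoch):
    # bisect_right(milestones, epoch) counts milestones <= epoch, i.e. the decay exponent
    return init_lr * 0.1 ** bisect_right(milestones, epoch)

for epoch in (0, 149, 150, 180, 210, 239):
    print(epoch, multistep_lr(epoch))   # 0.05, 0.05, 0.005, 5e-4, 5e-5, 5e-5 (up to float rounding)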
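correct_num returns correct-prediction counts (not percentages) for each requested k; the training and test loops accumulate these into top1_num/top5_num and divide by total at the end. A worked example with hypothetical logits:

import torch

# Worked example of correct_num's counting logic with hypothetical logits.
output = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.05]])
target = torch.tensor([1, 2])

maxk = 2
_, pred = output.topk(maxk, 1, True, True)             # top-2 class indices per row
correct = pred.eq(target.view(-1, 1).expand_as(pred))  # marks where the true class appears
top1 = correct[:, :1].float().sum()   # 1.0: only the first row is right at k=1
top2 = correct[:, :2].float().sum()   # 1.0: class 2 is outside the second row's top-2
print(top1.item(), top2.item())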
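Finally, all three training scripts share the same checkpoint pattern: save a rolling {net, acc, epoch, optimizer} state every epoch and copy it aside when accuracy improves. A condensed sketch of that pattern (the function name and arguments here are hypothetical, not part of the repository):

import os
import shutil
import torch

# Condensed sketch of the per-epoch checkpoint/best-copy pattern.
def save_checkpoint(net, optimizer, acc, epoch, best_acc, checkpoint_dir, name):
    state = {
        'net': net.state_dict(),   # the scripts store net.module.state_dict() under DataParallel
        'acc': acc,
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
    }
    path = os.path.join(checkpoint_dir, name + '.pth.tar')
    torch.save(state, path)
    if acc > best_acc:             # keep a side copy of the best epoch so far
        shutil.copyfile(path, os.path.join(checkpoint_dir, name + '_best.pth.tar'))
        best_acc = acc
    return best_acc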