├── Cifar ├── learners │ ├── __init__.py │ └── basic_learner.py ├── utils │ ├── __init__.py │ ├── helper.py │ ├── dist_utils.py │ ├── model_profiling.py │ ├── compute_flops.py │ ├── pytorch_utils.py │ └── utils.py ├── run_scripts │ ├── feature_extract │ │ ├── extract_resnet20_20m.sh │ │ └── extract_resnet56_60m.sh │ ├── evaluate │ │ ├── evaluate_resnet20_cifar.sh │ │ └── evaluate_resnet56_cifar.sh │ └── train │ │ ├── start_train_resnet20_cifar.sh │ │ └── start_train_resnet56_cifar.sh ├── nets │ ├── se_module.py │ ├── model_helper.py │ ├── __init__.py │ ├── base_models.py │ ├── resnet_20s_decode.py │ └── resnet_cifar.py ├── data_providers │ ├── __init__.py │ └── cifar10.py ├── feature_extract.py ├── data_loader.py └── main.py ├── .gitignore ├── Imagenet ├── utils │ ├── __init__.py │ └── pytorch_utils.py ├── run_scripts │ ├── evaluate │ │ ├── evaluate_mobilenetv2_150m.sh │ │ ├── evaluate_mobilenetv2_220m.sh │ │ ├── evaluate_mobilenet_50m.sh │ │ ├── evaluate_mobilenet_150m.sh │ │ ├── evaluate_mobilenet_285m.sh │ │ ├── evaluate_resnet50_2000m.sh │ │ └── evaluate_resnet50_2300m.sh │ ├── feature_extract │ │ ├── extract_mobilenet_150m.sh │ │ ├── extract_mobilenet_285m.sh │ │ ├── extract_mobilenet_50m.sh │ │ ├── extract_resnet50_1800m.sh │ │ ├── extract_resnet50_2000m.sh │ │ ├── extract_resnet50_2300m.sh │ │ ├── extract_mobilenetv2_150m.sh │ │ └── extract_mobilenetv2_220m.sh │ └── train │ │ ├── train_resnet50_base.sh │ │ ├── train_mobilenet_base.sh │ │ ├── train_mobilenetv2_base.sh │ │ ├── train_mobilenetv2_150m.sh │ │ ├── train_mobilenetv2_220m.sh │ │ ├── train_mobilenet_50m.sh │ │ ├── train_mobilenet_150m.sh │ │ ├── train_mobilenet_285m.sh │ │ ├── train_resnet50_1800m.sh │ │ ├── train_resnet50_2000m.sh │ │ └── train_resnet50_2300m.sh ├── prep_imagenet.sh ├── data_providers │ ├── __init__.py │ └── imagenet.py ├── evaluate.py ├── models │ ├── __init__.py │ └── base_models.py ├── cal_FLOPs.py ├── feature_extract.py └── train.py ├── ToyExample ├── utils │ ├── __init__.py │ 
└── pytorch_utils.py ├── run_scripts │ ├── evaluate │ │ └── evaluate_vgg_100m.sh │ ├── feature_extract │ │ └── extract_vgg_100m.sh │ └── train │ │ └── train_vgg_100m.sh ├── data_providers │ ├── __init__.py │ └── cifar10.py ├── evaluate.py ├── models │ ├── __init__.py │ ├── base_models.py │ └── vgg.py ├── cal_FLOPs.py ├── feature_extract.py └── train.py ├── requirements.txt ├── README.md └── CKA └── cka.py /Cifar/learners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | best_models 2 | Exp_base 3 | .ipynb_checkpoints 4 | __pycache__ -------------------------------------------------------------------------------- /Cifar/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_utils import * 2 | from .get_data_iter import * 3 | -------------------------------------------------------------------------------- /Imagenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_utils import * 2 | from .get_data_iter import * 3 | -------------------------------------------------------------------------------- /ToyExample/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_utils import * 2 | from .get_data_iter import * 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | numpy==1.19.2 3 | torchvision==0.8.2 4 | matplotlib==3.3.2 5 | Pillow==8.0.1 6 | tensorboardX==2.1 7 | scipy==1.5.3 8 | scikit-learn==0.23.2 -------------------------------------------------------------------------------- 
/Cifar/run_scripts/feature_extract/extract_resnet20_20m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Run feature_extract.py on the CIFAR-10 resnet20 base model (Exp_base/resnet20_base) at
3 | # --target_flops 20M; replace 'your_data_path' with the dataset location.
4 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
5 |     --model "resnet20" \
6 |     --path "Exp_base/resnet20_base" \
7 |     --dataset "cifar10" \
8 |     --save_path 'your_data_path' \
9 |     --target_flops 20000000 \
10 |     --beta 246
--------------------------------------------------------------------------------
/Cifar/run_scripts/feature_extract/extract_resnet56_60m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Run feature_extract.py on the CIFAR-10 resnet56 base model (Exp_base/resnet56_base) at
3 | # --target_flops 60M; replace 'your_data_path' with the dataset location.
4 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
5 |     --model "resnet56" \
6 |     --path "Exp_base/resnet56_base" \
7 |     --dataset "cifar10" \
8 |     --save_path 'your_data_path' \
9 |     --target_flops 60000000 \
10 |     --beta 500
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the mobilenetv2 150M checkpoint; --cfg is the same list as in
3 | # run_scripts/train/train_mobilenetv2_150m.sh.
4 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
5 |     --model "mobilenetv2" \
6 |     --path "pretrain_models/train_mobilenetv2_150m_19509" \
7 |     --dataset "imagenet" \
8 |     --save_path 'your_data_path' \
9 |     --cfg "[29, 12, 16, 24, 43, 50, 98, 273]"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the mobilenetv2 220M checkpoint; --cfg is the same list as in
3 | # run_scripts/train/train_mobilenetv2_220m.sh.
4 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
5 |     --model "mobilenetv2" \
6 |     --path "pretrain_models/train_mobilenetv2_220m_19738" \
7 |     --dataset "imagenet" \
8 |     --save_path 'your_data_path' \
9 |     --cfg "[31, 14, 20, 28, 54, 74, 130, 298]"
--------------------------------------------------------------------------------
/ToyExample/run_scripts/evaluate/evaluate_vgg_100m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the VGG 100M CIFAR-10 checkpoint; --cfg is the same list as in
3 | # run_scripts/train/train_vgg_100m.sh.
4 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
5 |     --model "vgg" \
6 |     --path "pretrain_models/train_vgg_100m_22031" \
7 |     --dataset "cifar10" \
8 |     --save_path 'your_data_path' \
9 |     --cfg "[44, 35, 71, 71, 145, 104, 145, 290, 205, 325, 432, 429, 469]"
10 | 
--------------------------------------------------------------------------------
/ToyExample/run_scripts/feature_extract/extract_vgg_100m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Run feature_extract.py on the VGG base model (Exp_base/vgg_base) at --target_flops 100M.
3 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
4 |     --model "vgg" \
5 |     --path "Exp_base/vgg_base" \
6 |     --dataset "cifar10" \
7 |     --save_path 'your_data_path' \
8 |     --target_flops 100000000 \
9 |     --beta 231 \
10 |     --test_batch_size 1024
11 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_50m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the mobilenet 50M checkpoint; --cfg is the list from run_scripts/train/train_mobilenet_50m.sh.
3 | # FIXME(review): --path is the 150M run (identical to evaluate_mobilenet_150m.sh), whose channel
4 | # counts differ from this --cfg, so loading will most likely fail. Point it at the 50M run directory.
5 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
6 |     --model "mobilenet" \
7 |     --path "pretrain_models/train_mobilenet_150m_31752" \
8 |     --dataset "imagenet" \
9 |     --save_path 'your_data_path' \
10 |     --cfg "[18, 31, 30, 38, 71, 76, 173, 52, 78, 78, 51, 178, 250, 522]"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the mobilenet 150M checkpoint; --cfg is the list from run_scripts/train/train_mobilenet_150m.sh.
3 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
4 |     --model "mobilenet" \
5 |     --path "pretrain_models/train_mobilenet_150m_31752" \
6 |     --dataset "imagenet" \
7 |     --save_path 'your_data_path' \
8 |     --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_285m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the mobilenet 285M checkpoint.
3 | # NOTE(review): the 13th --cfg entry is 733 here but 734 in run_scripts/train/train_mobilenet_285m.sh;
4 | # one of the two is probably a typo. Confirm against the checkpoint's channel counts.
5 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
6 |     --model "mobilenet" \
7 |     --path "pretrain_models/train_mobilenet_285m_3956" \
8 |     --dataset "imagenet" \
9 |     --save_path 'your_data_path' \
10 |     --cfg "[28, 52, 88, 90, 179, 182, 362, 311, 314, 314, 312, 367, 733, 1024]"
--------------------------------------------------------------------------------
/ToyExample/run_scripts/train/train_vgg_100m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Train VGG with the 100M --cfg for 300 epochs; --base_path points at the base run Exp_base/vgg_base.
3 | python3 -m torch.distributed.launch --nproc_per_node=1 train.py --train \
4 |     --model "vgg" \
5 |     --cfg "[44, 35, 71, 71, 145, 104, 145, 290, 205, 325, 432, 429, 469]" \
6 |     --path "Exp_train/train_vgg_100m_${RANDOM}" \
7 |     --dataset "cifar10" \
8 |     --save_path 'your_data_path' \
9 |     --base_path "Exp_base/vgg_base" \
10 |     --warm_epoch 1 \
11 |     --n_epochs 300 \
12 |     --sync_bn
13 | 
14 | 
--------------------------------------------------------------------------------
/Cifar/run_scripts/evaluate/evaluate_resnet20_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ## ============================= resnet cifar10 decode ============================
3 | python main.py \
4 |     --gpu_id 0 \
5 |     --exec_mode eval \
6 |     --learner vanilla \
7 |     --dataset cifar10 \
8 |     --data_path ~/datasets \
9 |     --model_type resnet_decode \
10 |     --load_path /ITPruner/Cifar/models/best_models/resnet20_20m/model.pt \
11 |     --cfg 16,11,11,11,11,11,11,22,22,21,21,22,22,45,45,43,43,48,48
12 | 
13 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_resnet50_2000m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the resnet50 2000M checkpoint; --cfg is the list from run_scripts/train/train_resnet50_2000m.sh.
3 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
4 |     --model "resnet50" \
5 |     --path "pretrain_models/train_resnet50_2000m_2229" \
6 |     --dataset "imagenet" \
7 |     --save_path 'your_data_path' \
8 |     --cfg "[50, 53, 54, 125, 54, 54, 54, 54, 101, 108, 179, 107, 107, 107, 107, 107, 107, 203, 216, 148, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 407, 429, 1341, 429, 429, 429, 428]"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_resnet50_2300m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Evaluate the resnet50 2300M checkpoint; --cfg is the list from run_scripts/train/train_resnet50_2300m.sh.
3 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
4 |     --model "resnet50" \
5 |     --path "pretrain_models/train_resnet50_2300m_12067" \
6 |     --dataset "imagenet" \
7 |     --save_path 'your_data_path' \
8 |     --cfg "[52, 55, 55, 146, 55, 55, 55, 55, 105, 111, 232, 111, 111, 111, 111, 111, 111, 211, 222, 290, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 423, 442, 1453, 442, 442, 442, 442]"
--------------------------------------------------------------------------------
/Cifar/run_scripts/train/start_train_resnet20_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ## ============================= resnet cifar10 decode ============================
3 | python main.py \
4 |     --gpu_id 0 \
5 |     --exec_mode train \
6 |     --learner vanilla \
7 |     --dataset cifar10 \
8 |     --data_path ~/datasets \
9 |     --model_type resnet_decode \
10 |     --lr 0.1 \
11 |     --lr_min 0. \
12 |     --lr_decy_type cosine \
13 |     --weight_decay 5e-4 \
14 |     --nesterov \
15 |     --epochs 300 \
16 |     --cfg 16,11,11,11,11,11,11,22,22,21,21,22,22,45,45,43,43,48,48
17 | 
18 | 
--------------------------------------------------------------------------------
/Imagenet/prep_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ILSVRC2012 under /dev/shm/imagenet: extract the train/val tars from /userhome/data and
3 | # run the helper scripts unzip.sh / valprep.sh (copied from /userhome) inside each split directory.
4 | # mount -t tmpfs -o size=160G tmpfs /userhome/memory_data
5 | root_dir="/dev/shm/imagenet"
6 | mkdir -p "${root_dir}/train"
7 | mkdir -p "${root_dir}/val"
8 | tar -xvf /userhome/data/ILSVRC2012_img_train.tar -C "${root_dir}/train"
9 | cp /userhome/unzip.sh "${root_dir}/train/"
10 | # Abort instead of running the helper scripts from whatever directory we happen to be in.
11 | cd "${root_dir}/train" || exit 1
12 | chmod +x unzip.sh
13 | ./unzip.sh
14 | tar -xvf /userhome/data/ILSVRC2012_img_val.tar -C "${root_dir}/val"
15 | cp /userhome/valprep.sh "${root_dir}/val/"
16 | cd "${root_dir}/val" || exit 1
17 | chmod +x valprep.sh
18 | ./valprep.sh
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for mobilenet at --target_flops 150M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "mobilenet" \
11 |     --path "Exp_base/mobilenet_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 150000000 \
15 |     --beta 243
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_285m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for mobilenet at --target_flops 285M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "mobilenet" \
11 |     --path "Exp_base/mobilenet_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 285000000 \
15 |     --beta 233
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_50m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for mobilenet at --target_flops 50M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "mobilenet" \
11 |     --path "Exp_base/mobilenet_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 50000000 \
15 |     --beta 243
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_1800m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for resnet50 at --target_flops 1800M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "resnet50" \
11 |     --path "Exp_base/resnet50_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 1800000000 \
15 |     --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_2000m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for resnet50 at --target_flops 2000M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "resnet50" \
11 |     --path "Exp_base/resnet50_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 2000000000 \
15 |     --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_2300m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for resnet50 at --target_flops 2300M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "resnet50" \
11 |     --path "Exp_base/resnet50_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 2300000000 \
15 |     --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for mobilenetv2 at --target_flops 150M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "mobilenetv2" \
11 |     --path "Exp_base/mobilenetv2_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 150000000 \
15 |     --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet (prep_imagenet.sh), then run feature_extract.py for mobilenetv2 at --target_flops 220M.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
10 |     --model "mobilenetv2" \
11 |     --path "Exp_base/mobilenetv2_base" \
12 |     --dataset "imagenet" \
13 |     --save_path 'your_data_path' \
14 |     --target_flops 220000000 \
15 |     --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_base.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train the resnet50 base model (no --cfg) for 120 epochs on 4 processes.
3 | # NOTE(review): ${RANDOM} suffixes the run directory, while the extract/train scripts read
4 | # Exp_base/resnet50_base; the finished run presumably has to be renamed or symlinked first.
5 | code_path='/ITPruner/Imagenet/'
6 | chmod +x ${code_path}/prep_imagenet.sh
7 | cd ${code_path}
8 | echo "preparing data"
9 | bash ${code_path}/prep_imagenet.sh >> /dev/null
10 | echo "preparing data finished"
11 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
12 |     --model "resnet50" \
13 |     --path "Exp_base/resnet50_base_${RANDOM}" \
14 |     --dataset "imagenet" \
15 |     --save_path 'your_data_path' \
16 |     --sync_bn \
17 |     --warm_epoch 1 \
18 |     --n_epochs 120
19 | 
--------------------------------------------------------------------------------
/Cifar/run_scripts/evaluate/evaluate_resnet56_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ## ============================= resnet cifar10 decode ============================
3 | python main.py \
4 |     --gpu_id 0 \
5 |     --exec_mode eval \
6 |     --learner vanilla \
7 |     --dataset cifar10 \
8 |     --data_path ~/datasets \
9 |     --model_type resnet_decode \
10 |     --load_path /ITPruner/Cifar/models/best_models/resnet56_60m/model.pt \
11 |     --cfg 16,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,11,11,22,22,21,21,21,21,21,21,21,21,21,21,21,21,21,21,22,22,44,44,42,42,42,42,42,42,42,42,42,42,42,42,64,64,64,64
12 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_base.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train the mobilenet base model (no --cfg) for 250 epochs on 4 processes.
3 | # NOTE(review): ${RANDOM} suffixes the run directory, while the extract/train scripts read
4 | # Exp_base/mobilenet_base; the finished run presumably has to be renamed or symlinked first.
5 | code_path='/ITPruner/Imagenet/'
6 | chmod +x ${code_path}/prep_imagenet.sh
7 | cd ${code_path}
8 | echo "preparing data"
9 | bash ${code_path}/prep_imagenet.sh >> /dev/null
10 | echo "preparing data finished"
11 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
12 |     --model "mobilenet" \
13 |     --path "Exp_base/mobilenet_base_${RANDOM}" \
14 |     --dataset "imagenet" \
15 |     --save_path 'your_data_path' \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_base.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train the mobilenetv2 base model (no --cfg) for 250 epochs on 4 processes.
3 | # NOTE(review): ${RANDOM} suffixes the run directory, while the extract/train scripts read
4 | # Exp_base/mobilenetv2_base; the finished run presumably has to be renamed or symlinked first.
5 | code_path='/ITPruner/Imagenet/'
6 | chmod +x ${code_path}/prep_imagenet.sh
7 | cd ${code_path}
8 | echo "preparing data"
9 | bash ${code_path}/prep_imagenet.sh >> /dev/null
10 | echo "preparing data finished"
11 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
12 |     --model "mobilenetv2" \
13 |     --path "Exp_base/mobilenetv2_base_${RANDOM}" \
14 |     --dataset "imagenet" \
15 |     --save_path 'your_data_path' \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Cifar/run_scripts/train/start_train_resnet56_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ## ============================= resnet cifar10 decode ============================
3 | python main.py \
4 |     --gpu_id 0 \
5 |     --exec_mode train \
6 |     --learner vanilla \
7 |     --dataset cifar10 \
8 |     --data_path ~/datasets \
9 |     --model_type resnet_decode \
10 |     --lr 0.1 \
11 |     --lr_min 0. \
12 |     --lr_decy_type cosine \
13 |     --weight_decay 5e-4 \
14 |     --nesterov \
15 |     --epochs 300 \
16 |     --cfg 16,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,11,11,22,22,21,21,21,21,21,21,21,21,21,21,21,21,21,21,22,22,44,44,42,42,42,42,42,42,42,42,42,42,42,42,64,64,64,64
17 | 
18 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train mobilenetv2 with the 150M --cfg for 250 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "mobilenetv2" \
11 |     --cfg "[29, 12, 16, 24, 43, 50, 98, 273]" \
12 |     --path "Exp_train/train_mobilenetv2_150m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --base_path "Exp_base/mobilenetv2_base" \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train mobilenetv2 with the 220M --cfg for 250 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "mobilenetv2" \
11 |     --cfg "[31, 14, 20, 28, 54, 74, 130, 298]" \
12 |     --path "Exp_train/train_mobilenetv2_220m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --base_path "Exp_base/mobilenetv2_base" \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_50m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train mobilenet with the 50M --cfg for 250 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "mobilenet" \
11 |     --cfg "[18, 31, 30, 38, 71, 76, 173, 52, 78, 78, 51, 178, 250, 522]" \
12 |     --path "Exp_train/train_mobilenet_50m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --base_path "Exp_base/mobilenet_base" \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_150m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train mobilenet with the 150M --cfg for 250 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "mobilenet" \
11 |     --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]" \
12 |     --path "Exp_train/train_mobilenet_150m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --base_path "Exp_base/mobilenet_base" \
16 |     --warm_epoch 1 \
17 |     --sync_bn \
18 |     --n_epochs 250 \
19 |     --label_smoothing 0.1
20 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_285m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train mobilenet with the 285M --cfg for 250 epochs on 4 processes.
3 | # NOTE(review): the 13th --cfg entry is 734 here but 733 in run_scripts/evaluate/evaluate_mobilenet_285m.sh;
4 | # one of the two is probably a typo. Confirm which one matches the released checkpoint.
5 | code_path='/ITPruner/Imagenet/'
6 | chmod +x ${code_path}/prep_imagenet.sh
7 | cd ${code_path}
8 | echo "preparing data"
9 | bash ${code_path}/prep_imagenet.sh >> /dev/null
10 | echo "preparing data finished"
11 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
12 |     --model "mobilenet" \
13 |     --cfg "[28, 52, 88, 90, 179, 182, 362, 311, 314, 314, 312, 367, 734, 1024]" \
14 |     --path "Exp_train/train_mobilenet_285m_${RANDOM}" \
15 |     --dataset "imagenet" \
16 |     --save_path 'your_data_path' \
17 |     --base_path "Exp_base/mobilenet_base" \
18 |     --warm_epoch 1 \
19 |     --sync_bn \
20 |     --n_epochs 250 \
21 |     --label_smoothing 0.1
22 | 
--------------------------------------------------------------------------------
/Cifar/nets/se_module.py:
--------------------------------------------------------------------------------
1 | """ Taken from https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py """
2 | 
3 | from torch import nn
4 | 
5 | class SELayer(nn.Module):
6 |     """Squeeze-and-Excitation block that rescales the channels of a 4-D input.
7 | 
8 |     Each channel is average-pooled to a scalar, passed through
9 |     Linear(channel -> channel // reduction) -> ReLU -> Linear(-> channel) -> Sigmoid,
10 |     and the resulting per-channel weight in (0, 1) multiplies the input.
11 | 
12 |     Args:
13 |         channel: number of input channels (second dimension of the input).
14 |         reduction: divisor giving the hidden width (channel // reduction) of the gate MLP.
15 |     """
16 | 
17 |     def __init__(self, channel, reduction=16):
18 |         super(SELayer, self).__init__()
19 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
20 |         self.fc = nn.Sequential(
21 |             nn.Linear(channel, channel // reduction, bias=False),
22 |             nn.ReLU(inplace=True),
23 |             nn.Linear(channel // reduction, channel, bias=False),
24 |             nn.Sigmoid()
25 |         )
26 | 
27 |     def forward(self, x):
28 |         """Return ``x`` scaled channel-wise; ``x`` is unpacked as (b, c, h, w)."""
29 |         b, c, _, _ = x.size()
30 |         # Squeeze (b, c, h, w) -> (b, c), then reshape the gates to (b, c, 1, 1) for broadcasting.
31 |         y = self.avg_pool(x).view(b, c)
32 |         y = self.fc(y).view(b, c, 1, 1)
33 |         return x * y.expand_as(x)
34 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_1800m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train resnet50 with the 1800M --cfg for 120 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "resnet50" \
11 |     --cfg "[48, 52, 52, 110, 52, 52, 52, 52, 98, 105, 142, 105, 105, 105, 105, 105, 105, 197, 211, 48, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 397, 420, 1264, 419, 419, 419, 418]" \
12 |     --path "Exp_train/train_resnet50_1800m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --sync_bn \
16 |     --warm_epoch 1 \
17 |     --base_path "Exp_base/resnet50_base" \
18 |     --n_epochs 120
19 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_2000m.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Stage ImageNet, then train resnet50 with the 2000M --cfg for 120 epochs on 4 processes.
3 | code_path='/ITPruner/Imagenet/'
4 | chmod +x ${code_path}/prep_imagenet.sh
5 | cd ${code_path}
6 | echo "preparing data"
7 | bash ${code_path}/prep_imagenet.sh >> /dev/null
8 | echo "preparing data finished"
9 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
10 |     --model "resnet50" \
11 |     --cfg "[50, 53, 54, 125, 54, 54, 54, 54, 101, 108, 179, 107, 107, 107, 107, 107, 107, 203, 216, 148, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 407, 429, 1341, 429, 429, 429, 428]" \
12 |     --path "Exp_train/train_resnet50_2000m_${RANDOM}" \
13 |     --dataset "imagenet" \
14 |     --save_path 'your_data_path' \
15 |     --base_path "Exp_base/resnet50_base" \
16 |     --sync_bn \
17 |     --warm_epoch 1 \
18 |     --n_epochs 120
19 | 
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_2300m.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | code_path='/ITPruner/Imagenet/' 3 | chmod +x ${code_path}/prep_imagenet.sh 4 | cd ${code_path} 5 | echo "preparing data" 6 | bash ${code_path}/prep_imagenet.sh >> /dev/null 7 | echo "preparing data finished" 8 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \ 9 | --model "resnet50" \ 10 | --cfg "[52, 55, 55, 146, 55, 55, 55, 55, 105, 111, 232, 111, 111, 111, 111, 111, 111, 211, 222, 290, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 423, 442, 1453, 442, 442, 442, 442]" \ 11 | --path "Exp_train/train_resnet50_2300m_${RANDOM}" \ 12 | --dataset "imagenet" \ 13 | --save_path 'your_data_path' \ 14 | --base_path "Exp_base/resnet50_base" \ 15 | --sync_bn \ 16 | --warm_epoch 1 \ 17 | --n_epochs 120 18 | -------------------------------------------------------------------------------- /Cifar/data_providers/__init__.py: -------------------------------------------------------------------------------- 1 | class DataProvider: 2 | VALID_SEED = 0 # random seed for the validation set 3 | 4 | @staticmethod 5 | def name(): 6 | """ Return name of the dataset """ 7 | raise NotImplementedError 8 | 9 | @property 10 | def data_shape(self): 11 | """ Return shape as python list of one data entry """ 12 | raise NotImplementedError 13 | 14 | @property 15 | def n_classes(self): 16 | """ Return `int` of num classes """ 17 | raise NotImplementedError 18 | 19 | @property 20 | def save_path(self): 21 | """ local path to save the data """ 22 | raise NotImplementedError 23 | 24 | @property 25 | def data_url(self): 26 | """ link to download the data """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /Imagenet/data_providers/__init__.py: -------------------------------------------------------------------------------- 1 | class DataProvider: 2 | VALID_SEED = 0 # random seed for the validation set 
3 | 4 | @staticmethod 5 | def name(): 6 | """ Return name of the dataset """ 7 | raise NotImplementedError 8 | 9 | @property 10 | def data_shape(self): 11 | """ Return shape as python list of one data entry """ 12 | raise NotImplementedError 13 | 14 | @property 15 | def n_classes(self): 16 | """ Return `int` of num classes """ 17 | raise NotImplementedError 18 | 19 | @property 20 | def save_path(self): 21 | """ local path to save the data """ 22 | raise NotImplementedError 23 | 24 | @property 25 | def data_url(self): 26 | """ link to download the data """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /ToyExample/data_providers/__init__.py: -------------------------------------------------------------------------------- 1 | class DataProvider: 2 | VALID_SEED = 0 # random seed for the validation set 3 | 4 | @staticmethod 5 | def name(): 6 | """ Return name of the dataset """ 7 | raise NotImplementedError 8 | 9 | @property 10 | def data_shape(self): 11 | """ Return shape as python list of one data entry """ 12 | raise NotImplementedError 13 | 14 | @property 15 | def n_classes(self): 16 | """ Return `int` of num classes """ 17 | raise NotImplementedError 18 | 19 | @property 20 | def save_path(self): 21 | """ local path to save the data """ 22 | raise NotImplementedError 23 | 24 | @property 25 | def data_url(self): 26 | """ link to download the data """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /Cifar/utils/helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import string 5 | 6 | 7 | def init_logging(log_path): 8 | 9 | if not os.path.isdir(os.path.dirname(log_path)): 10 | print("Log path does not exist. 
Create a new one.") 11 | os.makedirs(os.path.dirname(log_path)) 12 | 13 | log = logging.getLogger() 14 | log.setLevel(logging.INFO) 15 | logFormatter = logging.Formatter('%(asctime)s [%(levelname)s]: %(message)s') 16 | 17 | fileHandler = logging.FileHandler(log_path) 18 | fileHandler.setFormatter(logFormatter) 19 | log.addHandler(fileHandler) 20 | 21 | consoleHandler = logging.StreamHandler() 22 | consoleHandler.setFormatter(logFormatter) 23 | log.addHandler(consoleHandler) 24 | 25 | 26 | def print_args(args): 27 | for k, v in zip(args.keys(), args.values()): 28 | logging.info("{0}: {1}".format(k, v)) 29 | 30 | 31 | def generate_job_id(): 32 | return ''.join(random.sample(string.ascii_letters+string.digits, 8)) 33 | -------------------------------------------------------------------------------- /Cifar/data_providers/cifar10.py: -------------------------------------------------------------------------------- 1 | from utils import get_cifar_iter 2 | from data_providers import DataProvider 3 | 4 | 5 | class CifarDataProvider(DataProvider): 6 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=500, valid_size=None, 7 | n_worker=8, manual_seed = 12, local_rank=0, world_size=1, **kwargs): 8 | 9 | self._save_path = save_path 10 | self.valid = None 11 | self.train = get_cifar_iter('train', self.save_path, train_batch_size, n_worker, cutout=16, manual_seed=manual_seed) 12 | self.test = get_cifar_iter('val', self.save_path, test_batch_size, n_worker, cutout=16, manual_seed=manual_seed) 13 | if self.valid is None: 14 | self.valid = self.test 15 | 16 | @staticmethod 17 | def name(): 18 | return 'cifar10' 19 | 20 | @property 21 | def data_shape(self): 22 | return 3, 32, 32 # C, H, W 23 | 24 | @property 25 | def n_classes(self): 26 | return 10 27 | 28 | @property 29 | def save_path(self): 30 | if self._save_path is None: 31 | self._save_path = '/userhome/data/cifar10' 32 | return self._save_path 33 | 
-------------------------------------------------------------------------------- /ToyExample/data_providers/cifar10.py: -------------------------------------------------------------------------------- 1 | from utils import get_cifar_iter 2 | from data_providers import DataProvider 3 | 4 | 5 | class CifarDataProvider(DataProvider): 6 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=500, valid_size=None, 7 | n_worker=8, manual_seed = 12, local_rank=0, world_size=1, **kwargs): 8 | 9 | self._save_path = save_path 10 | self.valid = None 11 | self.train = get_cifar_iter('train', self.save_path, train_batch_size, n_worker, cutout=16, manual_seed=manual_seed) 12 | self.test = get_cifar_iter('val', self.save_path, test_batch_size, n_worker, cutout=16, manual_seed=manual_seed) 13 | if self.valid is None: 14 | self.valid = self.test 15 | 16 | @staticmethod 17 | def name(): 18 | return 'cifar10' 19 | 20 | @property 21 | def data_shape(self): 22 | return 3, 32, 32 # C, H, W 23 | 24 | @property 25 | def n_classes(self): 26 | return 10 27 | 28 | @property 29 | def save_path(self): 30 | if self._save_path is None: 31 | self._save_path = '/userhome/data/cifar10' 32 | return self._save_path 33 | -------------------------------------------------------------------------------- /Cifar/nets/model_helper.py: -------------------------------------------------------------------------------- 1 | from nets.resnet_20s_decode import * 2 | from utils.compute_flops import * 3 | import torch 4 | 5 | 6 | class ModelHelper(object): 7 | def __init__(self): 8 | self.model_dict = {'resnet_decode': resnet_decode} 9 | 10 | def get_model(self, args): 11 | if args.model_type not in self.model_dict.keys(): 12 | raise ValueError('Wrong model type.') 13 | 14 | num_classes = self.__get_number_class(args.dataset) 15 | 16 | if 'decode' in args.model_type: 17 | if args.cfg == '': 18 | raise ValueError('Running decoding model. 
Empty cfg!') 19 | cfg = [int(v) for v in args.cfg.split(',')] 20 | model = self.model_dict[args.model_type](cfg, num_classes, args.se, args.se_reduction) 21 | if args.rank == 0: 22 | print(model) 23 | print("Flops and params:") 24 | resol = 32 if args.dataset == 'cifar10' else 224 25 | print_model_param_nums(model) 26 | print_model_param_flops(model, resol, multiply_adds=False) 27 | 28 | return model 29 | 30 | def __get_number_class(self, dataset): 31 | # determine the number of classes 32 | if dataset == 'cifar10': 33 | num_classes = 10 34 | elif dataset == 'cifar100': 35 | num_classes = 100 36 | elif dataset == 'ilsvrc_12': 37 | num_classes = 1000 38 | return num_classes 39 | 40 | def test(): 41 | mh = ModelHelper() 42 | for k in mh.model_dict.keys(): 43 | print(mh.get_model(k)) 44 | 45 | 46 | if __name__ == '__main__': 47 | test() 48 | 49 | -------------------------------------------------------------------------------- /Cifar/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.multiprocessing as mp 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | 7 | 8 | def init_dist(backend='nccl', 9 | master_ip='tcp://127.0.0.1', 10 | port=6669): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | os.environ['MASTER_ADDR'] = master_ip 14 | os.environ['MASTER_PORT'] = str(port) 15 | rank = int(os.environ['RANK']) 16 | world_size = int(os.environ['WORLD_SIZE']) 17 | num_gpus = torch.cuda.device_count() 18 | torch.cuda.set_device(rank % num_gpus) 19 | dist_url = master_ip + ':' + str(port) 20 | dist.init_process_group(backend=backend, init_method=dist_url, \ 21 | world_size=world_size, rank=rank) 22 | return rank, world_size 23 | 24 | 25 | def average_gradients(model): 26 | for param in model.parameters(): 27 | if param.requires_grad and not (param.grad is None): 28 | dist.all_reduce(param.grad.data) 29 | 30 | 31 | def 
average_group_gradients(group_params): 32 | for param in group_params: 33 | if param.requires_grad and not (param.grad is None): 34 | dist.all_reduce(param.grad.data) 35 | 36 | 37 | def sync_bn_stat(model, world_size): 38 | if world_size == 1: 39 | return 40 | for mod in model.modules(): 41 | if type(mod) == nn.BatchNorm2d: 42 | dist.all_reduce(mod.running_mean.data) 43 | mod.running_mean.data.div_(world_size) 44 | dist.all_reduce(mod.running_var.data) 45 | mod.running_var.data.div_(world_size) 46 | 47 | 48 | def broadcast_params(model): 49 | for p in model.state_dict().values(): 50 | dist.broadcast(p, 0) 51 | 52 | 53 | def reduce_vars(vars_list, world_size): 54 | if world_size == 1: 55 | return 56 | vars_result = [] 57 | for var in vars_list: 58 | var = var / world_size 59 | dist.all_reduce(var) 60 | vars_result.append(var) 61 | return vars_result 62 | 63 | 64 | def sync_adam_optimizer(opt, world_size): 65 | # use this when using Adam for parallel training 66 | pass 67 | -------------------------------------------------------------------------------- /ToyExample/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | import torch 7 | from run_manager import RunManager 8 | from models import VGG_CIFAR, TrainRunConfig 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | """ model config """ 13 | parser.add_argument('--path', type=str) 14 | parser.add_argument('--model', type=str, default="vgg", 15 | choices=['vgg', 'resnet56', 'resnet110', 'resnet18', 'resnet34', 'resnet50', 'mobilenetv2', 'mobilenet']) 16 | parser.add_argument('--cfg', type=str, default="None") 17 | 18 | """ dataset config """ 19 | parser.add_argument('--dataset', type=str, default='cifar10', 20 | choices=['cifar10', 'imagenet']) 21 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 22 | 23 | """ runtime config """ 24 | parser.add_argument('--gpu', 
help='gpu available', default='0') 25 | parser.add_argument('--test_batch_size', type=int, default=250) 26 | parser.add_argument('--n_worker', type=int, default=24) 27 | parser.add_argument("--local_rank", default=0, type=int) 28 | 29 | if __name__ == '__main__': 30 | args = parser.parse_args() 31 | 32 | # distributed setting 33 | torch.distributed.init_process_group(backend='nccl', 34 | init_method='env://') 35 | args.world_size = torch.distributed.get_world_size() 36 | 37 | # prepare run config 38 | run_config_path = '%s/run.config' % args.path 39 | 40 | run_config = TrainRunConfig( 41 | **args.__dict__ 42 | ) 43 | if args.local_rank == 0: 44 | print('Run config:') 45 | for k, v in args.__dict__.items(): 46 | print('\t%s: %s' % (k, v)) 47 | 48 | if args.model == "vgg": 49 | assert args.dataset == 'cifar10', 'vgg only supports cifar10 dataset' 50 | net = VGG_CIFAR( 51 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 52 | 53 | # build run manager 54 | run_manager = RunManager(args.path, net, run_config) 55 | 56 | # load checkpoints 57 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 58 | assert os.path.isfile(best_model_path), 'wrong path' 59 | if torch.cuda.is_available(): 60 | checkpoint = torch.load(best_model_path) 61 | else: 62 | checkpoint = torch.load(best_model_path, map_location='cpu') 63 | if 'state_dict' in checkpoint: 64 | checkpoint = checkpoint['state_dict'] 65 | run_manager.net.load_state_dict(checkpoint) 66 | 67 | output_dict = {} 68 | 69 | # test 70 | print('Test on test set') 71 | loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True) 72 | log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5) 73 | run_manager.write_log(log, prefix='test') 74 | output_dict = { 75 | **output_dict, 76 | 'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5 77 | } 78 | if args.local_rank == 0: 79 | print(output_dict) 80 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ITPruner: An Information Theory-inspired Strategy for Automatic Network Pruning 2 | 3 | This repository contains all the experiments of our paper "An Information Theory-inspired Strategy for Automatic Network Pruning". It also includes some of the [pretrained models](https://drive.google.com/drive/folders/12B741haMD8qGV6x2UMe0yjnIYOTuIokU?usp=sharing) listed in the paper. 4 | 5 | # Requirements 6 | 7 | * [DALI](https://github.com/NVIDIA/DALI) 8 | * [Apex](https://github.com/NVIDIA/apex) 9 | * [torchprofile](https://github.com/zhijian-liu/torchprofile) 10 | * other requirements, installed from requirements.txt: 11 | 12 | ```bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | 17 | 18 | # Running 19 | 20 | **Feature extraction** 21 | 22 | You need to download the [base models](https://drive.google.com/drive/folders/1jAXw4pEFmaw6fSgqNGRhhvtsM-ZUhZsr?usp=sharing) and set "--path" to the location of the downloaded model. 23 | 24 | ```bash 25 | # [optional] cache the ImageNet dataset in RAM to accelerate I/O 26 | code_path='/ITPruner/Imagenet/' 27 | chmod +x ${code_path}/prep_imagenet.sh 28 | cd ${code_path} 29 | echo "preparing data" 30 | bash ${code_path}/prep_imagenet.sh >> /dev/null 31 | echo "preparing data finished" 32 | 33 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \ 34 | --model "mobilenet" \ 35 | --path "Exp_base/mobilenet_base" \ 36 | --dataset "imagenet" \ 37 | --save_path 'your_data_path' \ 38 | --target_flops 150000000 \ 39 | --beta 243 40 | ``` 41 | 42 | or 43 | 44 | ```bash 45 | bash ./run_scripts/feature_extract/extract_mobilenet_150m.sh 46 | ``` 47 | 48 | 49 | 50 | **Train** 51 | 52 | Because of random seeds, the cfg obtained through feature extraction may differ slightly from ours. Our cfgs are given in the .sh files.
53 | 54 | ```bash 55 | # [optional] cache the ImageNet dataset in RAM to accelerate I/O 56 | code_path='/ITPruner/Imagenet/' 57 | chmod +x ${code_path}/prep_imagenet.sh 58 | cd ${code_path} 59 | echo "preparing data" 60 | bash ${code_path}/prep_imagenet.sh >> /dev/null 61 | echo "preparing data finished" 62 | 63 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \ 64 | --model "mobilenet" \ 65 | --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]" \ 66 | --path "Exp_train/train_mobilenet_150m_${RANDOM}" \ 67 | --dataset "imagenet" \ 68 | --save_path 'your_data_path' \ 69 | --base_path "Exp_base/mobilenet_base" \ 70 | --warm_epoch 1 \ 71 | --sync_bn \ 72 | --n_epochs 250 \ 73 | --label_smoothing 0.1 74 | ``` 75 | 76 | or 77 | 78 | ```bash 79 | bash ./run_scripts/train/train_mobilenet_150m.sh 80 | ``` 81 | 82 | 83 | 84 | **Evaluate** 85 | 86 | We provide some of the [pretrained models](https://drive.google.com/drive/folders/12B741haMD8qGV6x2UMe0yjnIYOTuIokU?usp=sharing) listed in the paper.
87 | 88 | ```python 89 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \ 90 | --model "mobilenet" \ 91 | --path "pretrain_models/train_mobilenet_150m_31752" \ 92 | --dataset "imagenet" \ 93 | --save_path 'your_data_path' \ 94 | --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]" 95 | ``` 96 | 97 | or 98 | 99 | ```python 100 | bash ./run_scripts/evaluate/evaluate_mobilenet_150m.sh 101 | ``` 102 | 103 | -------------------------------------------------------------------------------- /Imagenet/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | import torch 7 | from run_manager import RunManager 8 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | """ model config """ 13 | parser.add_argument('--path', type=str) 14 | parser.add_argument('--model', type=str, default="vgg", 15 | choices=['vgg', 'resnet56', 'resnet110', 'resnet18', 'resnet34', 'resnet50', 'mobilenetv2', 'mobilenet']) 16 | parser.add_argument('--cfg', type=str, default="None") 17 | 18 | """ dataset config """ 19 | parser.add_argument('--dataset', type=str, default='cifar10', 20 | choices=['cifar10', 'imagenet']) 21 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 22 | 23 | """ runtime config """ 24 | parser.add_argument('--gpu', help='gpu available', default='0') 25 | parser.add_argument('--test_batch_size', type=int, default=64) 26 | parser.add_argument('--n_worker', type=int, default=24) 27 | parser.add_argument("--local_rank", default=0, type=int) 28 | 29 | if __name__ == '__main__': 30 | args = parser.parse_args() 31 | 32 | # distributed setting 33 | torch.distributed.init_process_group(backend='nccl', 34 | init_method='env://') 35 | args.world_size = torch.distributed.get_world_size() 36 | 37 | # prepare run config 38 | 
run_config_path = '%s/run.config' % args.path 39 | 40 | run_config = TrainRunConfig( 41 | **args.__dict__ 42 | ) 43 | if args.local_rank == 0: 44 | print('Run config:') 45 | for k, v in args.__dict__.items(): 46 | print('\t%s: %s' % (k, v)) 47 | 48 | if args.model == "resnet50": 49 | assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset' 50 | net = ResNet_ImageNet( 51 | depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 52 | elif args.model == "mobilenetv2": 53 | assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset' 54 | net = MobileNetV2( 55 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 56 | elif args.model == "mobilenet": 57 | assert args.dataset == 'imagenet', 'mobilenet only supports imagenet dataset' 58 | net = MobileNet( 59 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 60 | 61 | # build run manager 62 | run_manager = RunManager(args.path, net, run_config) 63 | 64 | # load checkpoints 65 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 66 | assert os.path.isfile(best_model_path), 'wrong path' 67 | if torch.cuda.is_available(): 68 | checkpoint = torch.load(best_model_path) 69 | else: 70 | checkpoint = torch.load(best_model_path, map_location='cpu') 71 | if 'state_dict' in checkpoint: 72 | checkpoint = checkpoint['state_dict'] 73 | run_manager.net.load_state_dict(checkpoint) 74 | 75 | output_dict = {} 76 | 77 | # test 78 | print('Test on test set') 79 | loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True) 80 | log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5) 81 | run_manager.write_log(log, prefix='test') 82 | output_dict = { 83 | **output_dict, 84 | 'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5 85 | } 86 | if args.local_rank == 0: 87 | print(output_dict) 88 | -------------------------------------------------------------------------------- 
/ToyExample/models/__init__.py: -------------------------------------------------------------------------------- 1 | from run_manager import RunConfig 2 | from .base_models import * 3 | from .vgg import * 4 | 5 | 6 | class TrainRunConfig(RunConfig): 7 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 8 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 9 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 10 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 11 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, 12 | warm_epoch=5, save_path=None, base_path=None, **kwargs): 13 | super(TrainRunConfig, self).__init__( 14 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 15 | dataset, train_batch_size, test_batch_size, 16 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 17 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 18 | warm_epoch 19 | ) 20 | 21 | self.n_worker = n_worker 22 | self.save_path = save_path 23 | self.base_path = base_path 24 | 25 | print(kwargs.keys()) 26 | 27 | @property 28 | def data_config(self): 29 | return { 30 | 'train_batch_size': self.train_batch_size, 31 | 'test_batch_size': self.test_batch_size, 32 | 'n_worker': self.n_worker, 33 | 'local_rank': self.local_rank, 34 | 'world_size': self.world_size, 35 | 'save_path': self.save_path, 36 | } 37 | 38 | class SearchRunConfig(RunConfig): 39 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 40 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 41 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 42 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 43 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, 
save_path=None, 44 | search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs): 45 | super(SearchRunConfig, self).__init__( 46 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 47 | dataset, train_batch_size, test_batch_size, 48 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 49 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 50 | warm_epoch 51 | ) 52 | 53 | self.div = div 54 | self.n_worker = n_worker 55 | self.save_path = save_path 56 | self.search_epoch = search_epoch 57 | self.target_flops = target_flops 58 | self.n_remove = n_remove 59 | self.n_best = n_best 60 | self.n_populations = n_populations 61 | self.n_generations = n_generations 62 | 63 | print(kwargs.keys()) 64 | 65 | @property 66 | def data_config(self): 67 | return { 68 | 'train_batch_size': self.train_batch_size, 69 | 'test_batch_size': self.test_batch_size, 70 | 'n_worker': self.n_worker, 71 | 'local_rank': self.local_rank, 72 | 'world_size': self.world_size, 73 | 'save_path': self.save_path 74 | } -------------------------------------------------------------------------------- /Imagenet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from run_manager import RunConfig 2 | 3 | from .base_models import * 4 | from .resnet_imagenet import * 5 | from .mobilenet_imagenet import * 6 | 7 | 8 | class TrainRunConfig(RunConfig): 9 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 10 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 11 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 12 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 13 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, 14 | warm_epoch=5, save_path=None, base_path=None, **kwargs): 15 | 
super(TrainRunConfig, self).__init__( 16 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 17 | dataset, train_batch_size, test_batch_size, 18 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 19 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 20 | warm_epoch 21 | ) 22 | 23 | self.n_worker = n_worker 24 | self.save_path = save_path 25 | self.base_path = base_path 26 | 27 | print(kwargs.keys()) 28 | 29 | @property 30 | def data_config(self): 31 | return { 32 | 'train_batch_size': self.train_batch_size, 33 | 'test_batch_size': self.test_batch_size, 34 | 'n_worker': self.n_worker, 35 | 'local_rank': self.local_rank, 36 | 'world_size': self.world_size, 37 | 'save_path': self.save_path, 38 | } 39 | 40 | class SearchRunConfig(RunConfig): 41 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 42 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 43 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 44 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 45 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None, 46 | search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs): 47 | super(SearchRunConfig, self).__init__( 48 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 49 | dataset, train_batch_size, test_batch_size, 50 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 51 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 52 | warm_epoch 53 | ) 54 | 55 | self.div = div 56 | self.n_worker = n_worker 57 | self.save_path = save_path 58 | self.search_epoch = search_epoch 59 | self.target_flops = target_flops 60 | self.n_remove = n_remove 61 | self.n_best = n_best 62 | self.n_populations = 
n_populations 63 | self.n_generations = n_generations 64 | 65 | print(kwargs.keys()) 66 | 67 | @property 68 | def data_config(self): 69 | return { 70 | 'train_batch_size': self.train_batch_size, 71 | 'test_batch_size': self.test_batch_size, 72 | 'n_worker': self.n_worker, 73 | 'local_rank': self.local_rank, 74 | 'world_size': self.world_size, 75 | 'save_path': self.save_path 76 | } -------------------------------------------------------------------------------- /Cifar/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from run_manager import RunConfig 2 | from .resnet_cifar import * 3 | 4 | 5 | class TrainRunConfig(RunConfig): 6 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 7 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 8 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 9 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 10 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, 11 | warm_epoch=5, save_path=None, base_path=None, manual_seed=12, **kwargs): 12 | super(TrainRunConfig, self).__init__( 13 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 14 | dataset, train_batch_size, test_batch_size, 15 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 16 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 17 | warm_epoch 18 | ) 19 | 20 | self.n_worker = n_worker 21 | self.save_path = save_path 22 | self.base_path = base_path 23 | self.manual_seed = manual_seed 24 | 25 | print(kwargs.keys()) 26 | 27 | @property 28 | def data_config(self): 29 | return { 30 | 'train_batch_size': self.train_batch_size, 31 | 'test_batch_size': self.test_batch_size, 32 | 'n_worker': self.n_worker, 33 | 'local_rank': self.local_rank, 34 | 'world_size': self.world_size, 35 | 'save_path': self.save_path, 36 
| 'manual_seed': self.manual_seed, 37 | } 38 | 39 | 40 | class SearchRunConfig(RunConfig): 41 | def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 42 | dataset='imagenet', train_batch_size=256, test_batch_size=500, 43 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn', 44 | model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10, 45 | n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None, 46 | search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs): 47 | super(SearchRunConfig, self).__init__( 48 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 49 | dataset, train_batch_size, test_batch_size, 50 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 51 | model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn, 52 | warm_epoch 53 | ) 54 | 55 | self.div = div 56 | self.n_worker = n_worker 57 | self.save_path = save_path 58 | self.search_epoch = search_epoch 59 | self.target_flops = target_flops 60 | self.n_remove = n_remove 61 | self.n_best = n_best 62 | self.n_populations = n_populations 63 | self.n_generations = n_generations 64 | 65 | print(kwargs.keys()) 66 | 67 | @property 68 | def data_config(self): 69 | return { 70 | 'train_batch_size': self.train_batch_size, 71 | 'test_batch_size': self.test_batch_size, 72 | 'n_worker': self.n_worker, 73 | 'local_rank': self.local_rank, 74 | 'world_size': self.world_size, 75 | 'save_path': self.save_path 76 | } 77 | -------------------------------------------------------------------------------- /Imagenet/cal_FLOPs.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torch 3 | import torchvision 4 | 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as 
models 8 | 9 | import numpy as np 10 | 11 | def print_model_parm_flops(one_shot_model): 12 | prods = {} 13 | def save_hook(name): 14 | def hook_per(self, input, output): 15 | prods[name] = np.prod(input[0].shape) 16 | return hook_per 17 | 18 | list_1=[] 19 | def simple_hook(self, input, output): 20 | list_1.append(np.prod(input[0].shape)) 21 | list_2={} 22 | def simple_hook2(self, input, output): 23 | list_2['names'] = np.prod(input[0].shape) 24 | 25 | 26 | multiply_adds = False 27 | list_conv=[] 28 | def conv_hook(self, input, output): 29 | batch_size, input_channels, input_height, input_width = input[0].size() 30 | output_channels, output_height, output_width = output[0].size() 31 | 32 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 33 | bias_ops = 1 if self.bias is not None else 0 34 | 35 | params = output_channels * (kernel_ops + bias_ops) 36 | flops = batch_size * params * output_height * output_width 37 | 38 | list_conv.append(flops) 39 | 40 | 41 | list_linear=[] 42 | def linear_hook(self, input, output): 43 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 44 | 45 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 46 | bias_ops = self.bias.nelement() 47 | 48 | flops = batch_size * (weight_ops + bias_ops) 49 | list_linear.append(flops) 50 | 51 | list_bn=[] 52 | def bn_hook(self, input, output): 53 | list_bn.append(input[0].nelement()) 54 | 55 | list_relu=[] 56 | def relu_hook(self, input, output): 57 | list_relu.append(input[0].nelement()) 58 | 59 | list_pooling=[] 60 | def pooling_hook(self, input, output): 61 | batch_size, input_channels, input_height, input_width = input[0].size() 62 | output_channels, output_height, output_width = output[0].size() 63 | 64 | kernel_ops = self.kernel_size * self.kernel_size 65 | bias_ops = 0 66 | params = output_channels * (kernel_ops + bias_ops) 67 | flops = batch_size * params * output_height * output_width 68 | 69 | 
list_pooling.append(flops) 70 | 71 | def foo(net): 72 | childrens = list(net.children()) 73 | if not childrens: 74 | if isinstance(net, torch.nn.Conv2d): 75 | net.register_forward_hook(conv_hook) 76 | if isinstance(net, torch.nn.Linear): 77 | net.register_forward_hook(linear_hook) 78 | if isinstance(net, torch.nn.BatchNorm2d): 79 | net.register_forward_hook(bn_hook) 80 | if isinstance(net, torch.nn.ReLU): 81 | net.register_forward_hook(relu_hook) 82 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 83 | net.register_forward_hook(pooling_hook) 84 | return 85 | for c in childrens: 86 | foo(c) 87 | 88 | foo(one_shot_model) 89 | input = Variable(torch.rand(3,224,224).unsqueeze(0), requires_grad = True).cuda() 90 | out = one_shot_model(input) 91 | 92 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)) 93 | M_flops = total_flops / 1e6 94 | #print(' + Number of FLOPs: %.2fM' % (M_flops)) 95 | 96 | return M_flops 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /ToyExample/cal_FLOPs.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torch 3 | import torchvision 4 | 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | 9 | import numpy as np 10 | 11 | def print_model_parm_flops(one_shot_model): 12 | prods = {} 13 | def save_hook(name): 14 | def hook_per(self, input, output): 15 | prods[name] = np.prod(input[0].shape) 16 | return hook_per 17 | 18 | list_1=[] 19 | def simple_hook(self, input, output): 20 | list_1.append(np.prod(input[0].shape)) 21 | list_2={} 22 | def simple_hook2(self, input, output): 23 | list_2['names'] = np.prod(input[0].shape) 24 | 25 | 26 | multiply_adds = False 27 | list_conv=[] 28 | def conv_hook(self, input, output): 29 | batch_size, input_channels, input_height, input_width = 
input[0].size() 30 | output_channels, output_height, output_width = output[0].size() 31 | 32 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 33 | bias_ops = 1 if self.bias is not None else 0 34 | 35 | params = output_channels * (kernel_ops + bias_ops) 36 | flops = batch_size * params * output_height * output_width 37 | 38 | list_conv.append(flops) 39 | 40 | 41 | list_linear=[] 42 | def linear_hook(self, input, output): 43 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 44 | 45 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 46 | bias_ops = self.bias.nelement() 47 | 48 | flops = batch_size * (weight_ops + bias_ops) 49 | list_linear.append(flops) 50 | 51 | list_bn=[] 52 | def bn_hook(self, input, output): 53 | list_bn.append(input[0].nelement()) 54 | 55 | list_relu=[] 56 | def relu_hook(self, input, output): 57 | list_relu.append(input[0].nelement()) 58 | 59 | list_pooling=[] 60 | def pooling_hook(self, input, output): 61 | batch_size, input_channels, input_height, input_width = input[0].size() 62 | output_channels, output_height, output_width = output[0].size() 63 | 64 | kernel_ops = self.kernel_size * self.kernel_size 65 | bias_ops = 0 66 | params = output_channels * (kernel_ops + bias_ops) 67 | flops = batch_size * params * output_height * output_width 68 | 69 | list_pooling.append(flops) 70 | 71 | def foo(net): 72 | childrens = list(net.children()) 73 | if not childrens: 74 | if isinstance(net, torch.nn.Conv2d): 75 | net.register_forward_hook(conv_hook) 76 | if isinstance(net, torch.nn.Linear): 77 | net.register_forward_hook(linear_hook) 78 | if isinstance(net, torch.nn.BatchNorm2d): 79 | net.register_forward_hook(bn_hook) 80 | if isinstance(net, torch.nn.ReLU): 81 | net.register_forward_hook(relu_hook) 82 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 83 | net.register_forward_hook(pooling_hook) 84 | return 85 | for c in 
childrens: 86 | foo(c) 87 | 88 | foo(one_shot_model) 89 | input = Variable(torch.rand(3,224,224).unsqueeze(0), requires_grad = True).cuda() 90 | out = one_shot_model(input) 91 | 92 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)) 93 | M_flops = total_flops / 1e6 94 | #print(' + Number of FLOPs: %.2fM' % (M_flops)) 95 | 96 | return M_flops 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /Cifar/nets/base_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from utils import count_parameters 5 | 6 | class MyNetwork(nn.Module): 7 | def forward(self, x): 8 | raise NotImplementedError 9 | 10 | def feature_extract(self, x): 11 | raise NotImplementedError 12 | 13 | @property 14 | def config(self): # should include name/cfg/cfg_base/dataset 15 | raise NotImplementedError 16 | 17 | def cfg2params(self, cfg): 18 | raise NotImplementedError 19 | 20 | def cfg2flops(self, cfg): 21 | raise NotImplementedError 22 | 23 | def set_bn_param(self, momentum, eps): 24 | for m in self.modules(): 25 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 26 | m.momentum = momentum 27 | m.eps = eps 28 | return 29 | 30 | def get_bn_param(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 33 | return { 34 | 'momentum': m.momentum, 35 | 'eps': m.eps, 36 | } 37 | return None 38 | 39 | def init_model(self, model_init, init_div_groups=False): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | if model_init == 'he_fout': 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | if init_div_groups: 45 | n /= m.groups 46 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 47 | elif model_init == 'he_fin': 48 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 49 | if init_div_groups: 50 | n /= m.groups 51 | m.weight.data.normal_(0, math.sqrt(2. / n)) 52 | elif model_init == 'xavier_normal': 53 | nn.init.xavier_normal_(m.weight.data) 54 | elif model_init == 'xavier_uniform': 55 | nn.init.xavier_uniform_(m.weight.data) 56 | else: 57 | raise NotImplementedError 58 | elif isinstance(m, nn.BatchNorm2d): 59 | m.weight.data.fill_(1) 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.Linear): 62 | stdv = 1. / math.sqrt(m.weight.size(1)) 63 | m.weight.data.uniform_(-stdv, stdv) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm1d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | def get_parameters(self, keys=None, mode='include'): 71 | if keys is None: 72 | for name, param in self.named_parameters(): 73 | yield param 74 | elif mode == 'include': 75 | for name, param in self.named_parameters(): 76 | flag = False 77 | for key in keys: 78 | if key in name: 79 | flag = True 80 | break 81 | if flag: 82 | yield param 83 | elif mode == 'exclude': 84 | for name, param in self.named_parameters(): 85 | flag = True 86 | for key in keys: 87 | if key in name: 88 | flag = False 89 | break 90 | if flag: 91 | yield param 92 | else: 93 | raise ValueError('do not support: %s' % mode) 94 | 95 | def weight_parameters(self): 96 | return self.get_parameters() 97 | -------------------------------------------------------------------------------- /Imagenet/models/base_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from utils import count_parameters 5 | 6 | class MyNetwork(nn.Module): 7 | def forward(self, x): 8 | raise NotImplementedError 9 | 10 | def feature_extract(self, x): 11 | raise NotImplementedError 12 | 13 | @property 14 | def config(self): # should include name/cfg/cfg_base/dataset 15 
| raise NotImplementedError 16 | 17 | def cfg2params(self, cfg): 18 | raise NotImplementedError 19 | 20 | def cfg2flops(self, cfg): 21 | raise NotImplementedError 22 | 23 | def set_bn_param(self, momentum, eps): 24 | for m in self.modules(): 25 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 26 | m.momentum = momentum 27 | m.eps = eps 28 | return 29 | 30 | def get_bn_param(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 33 | return { 34 | 'momentum': m.momentum, 35 | 'eps': m.eps, 36 | } 37 | return None 38 | 39 | def init_model(self, model_init, init_div_groups=False): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | if model_init == 'he_fout': 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | if init_div_groups: 45 | n /= m.groups 46 | m.weight.data.normal_(0, math.sqrt(2. / n)) 47 | elif model_init == 'he_fin': 48 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 49 | if init_div_groups: 50 | n /= m.groups 51 | m.weight.data.normal_(0, math.sqrt(2. / n)) 52 | elif model_init == 'xavier_normal': 53 | nn.init.xavier_normal_(m.weight.data) 54 | elif model_init == 'xavier_uniform': 55 | nn.init.xavier_uniform_(m.weight.data) 56 | else: 57 | raise NotImplementedError 58 | elif isinstance(m, nn.BatchNorm2d): 59 | m.weight.data.fill_(1) 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.Linear): 62 | stdv = 1. 
/ math.sqrt(m.weight.size(1)) 63 | m.weight.data.uniform_(-stdv, stdv) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm1d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | def get_parameters(self, keys=None, mode='include'): 71 | if keys is None: 72 | for name, param in self.named_parameters(): 73 | yield param 74 | elif mode == 'include': 75 | for name, param in self.named_parameters(): 76 | flag = False 77 | for key in keys: 78 | if key in name: 79 | flag = True 80 | break 81 | if flag: 82 | yield param 83 | elif mode == 'exclude': 84 | for name, param in self.named_parameters(): 85 | flag = True 86 | for key in keys: 87 | if key in name: 88 | flag = False 89 | break 90 | if flag: 91 | yield param 92 | else: 93 | raise ValueError('do not support: %s' % mode) 94 | 95 | def weight_parameters(self): 96 | return self.get_parameters() 97 | -------------------------------------------------------------------------------- /ToyExample/models/base_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from utils import count_parameters 5 | 6 | class MyNetwork(nn.Module): 7 | def forward(self, x): 8 | raise NotImplementedError 9 | 10 | def feature_extract(self, x): 11 | raise NotImplementedError 12 | 13 | @property 14 | def config(self): # should include name/cfg/cfg_base/dataset 15 | raise NotImplementedError 16 | 17 | def cfg2params(self, cfg): 18 | raise NotImplementedError 19 | 20 | def cfg2flops(self, cfg): 21 | raise NotImplementedError 22 | 23 | def set_bn_param(self, momentum, eps): 24 | for m in self.modules(): 25 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 26 | m.momentum = momentum 27 | m.eps = eps 28 | return 29 | 30 | def get_bn_param(self): 31 | for m in self.modules(): 32 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 33 | return { 34 | 'momentum': m.momentum, 
35 | 'eps': m.eps, 36 | } 37 | return None 38 | 39 | def init_model(self, model_init, init_div_groups=False): 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | if model_init == 'he_fout': 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | if init_div_groups: 45 | n /= m.groups 46 | m.weight.data.normal_(0, math.sqrt(2. / n)) 47 | elif model_init == 'he_fin': 48 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 49 | if init_div_groups: 50 | n /= m.groups 51 | m.weight.data.normal_(0, math.sqrt(2. / n)) 52 | elif model_init == 'xavier_normal': 53 | nn.init.xavier_normal_(m.weight.data) 54 | elif model_init == 'xavier_uniform': 55 | nn.init.xavier_uniform_(m.weight.data) 56 | else: 57 | raise NotImplementedError 58 | elif isinstance(m, nn.BatchNorm2d): 59 | m.weight.data.fill_(1) 60 | m.bias.data.zero_() 61 | elif isinstance(m, nn.Linear): 62 | stdv = 1. / math.sqrt(m.weight.size(1)) 63 | m.weight.data.uniform_(-stdv, stdv) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.BatchNorm1d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | def get_parameters(self, keys=None, mode='include'): 71 | if keys is None: 72 | for name, param in self.named_parameters(): 73 | yield param 74 | elif mode == 'include': 75 | for name, param in self.named_parameters(): 76 | flag = False 77 | for key in keys: 78 | if key in name: 79 | flag = True 80 | break 81 | if flag: 82 | yield param 83 | elif mode == 'exclude': 84 | for name, param in self.named_parameters(): 85 | flag = True 86 | for key in keys: 87 | if key in name: 88 | flag = False 89 | break 90 | if flag: 91 | yield param 92 | else: 93 | raise ValueError('do not support: %s' % mode) 94 | 95 | def weight_parameters(self): 96 | return self.get_parameters() 97 | -------------------------------------------------------------------------------- /Imagenet/data_providers/imagenet.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import torch.utils.data 5 | 6 | import utils 7 | from data_providers import DataProvider 8 | 9 | 10 | def make_imagenet_subset(path2subset, n_sub_classes, path2imagenet='/userhome/memory_data/imagenet'): 11 | imagenet_train_folder = os.path.join(path2imagenet, 'train') 12 | imagenet_val_folder = os.path.join(path2imagenet, 'val') 13 | 14 | subfolders = sorted([f.path for f in os.scandir(imagenet_train_folder) if f.is_dir()]) 15 | # np.random.seed(DataProvider.VALID_SEED) 16 | np.random.shuffle(subfolders) 17 | 18 | chosen_train_folders = subfolders[:n_sub_classes] 19 | class_name_list = [] 20 | for train_folder in chosen_train_folders: 21 | class_name = train_folder.split('/')[-1] 22 | class_name_list.append(class_name) 23 | 24 | print('=> Start building subset%d' % n_sub_classes) 25 | for cls_name in class_name_list: 26 | src_train_folder = os.path.join(imagenet_train_folder, cls_name) 27 | target_train_folder = os.path.join(path2subset, 'train/%s' % cls_name) 28 | shutil.copytree(src_train_folder, target_train_folder) 29 | print('Train: %s -> %s' % (src_train_folder, target_train_folder)) 30 | 31 | src_val_folder = os.path.join(imagenet_val_folder, cls_name) 32 | target_val_folder = os.path.join(path2subset, 'val/%s' % cls_name) 33 | shutil.copytree(src_val_folder, target_val_folder) 34 | print('Val: %s -> %s' % (src_val_folder, target_val_folder)) 35 | print('=> Finish building subset%d' % n_sub_classes) 36 | 37 | 38 | class ImagenetDataProvider(DataProvider): 39 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, 40 | n_worker=24, manual_seed = 12, load_type='dali', local_rank=0, world_size=1, **kwargs): 41 | 42 | self._save_path = save_path 43 | self.valid = None 44 | if valid_size is not None: 45 | pass 46 | else: 47 | self.train = utils.get_imagenet_iter(data_type='train', 
image_dir=self.train_path, 48 | batch_size=train_batch_size, num_threads=n_worker, 49 | device_id=local_rank, manual_seed=manual_seed, 50 | num_gpus=torch.cuda.device_count(), crop=self.image_size, 51 | val_size=self.image_size, world_size=world_size, local_rank=local_rank) 52 | self.test = utils.get_imagenet_iter(data_type='val', image_dir=self.valid_path, manual_seed=manual_seed, 53 | batch_size=test_batch_size, num_threads=n_worker, device_id=local_rank, 54 | num_gpus=torch.cuda.device_count(), crop=self.image_size, 55 | val_size=256, world_size=world_size, local_rank=local_rank) 56 | if self.valid is None: 57 | self.valid = self.test 58 | 59 | @staticmethod 60 | def name(): 61 | return 'imagenet' 62 | 63 | @property 64 | def data_shape(self): 65 | return 3, self.image_size, self.image_size # C, H, W 66 | 67 | @property 68 | def n_classes(self): 69 | return 1000 70 | 71 | @property 72 | def save_path(self): 73 | if self._save_path is None: 74 | self._save_path = '/userhome/data/imagenet' 75 | return self._save_path 76 | 77 | @property 78 | def data_url(self): 79 | raise ValueError('unable to download ImageNet') 80 | 81 | @property 82 | def train_path(self): 83 | return os.path.join(self.save_path, 'train') 84 | 85 | @property 86 | def valid_path(self): 87 | return os.path.join(self._save_path, 'val') 88 | 89 | @property 90 | def resize_value(self): 91 | return 256 92 | 93 | @property 94 | def image_size(self): 95 | return 224 96 | -------------------------------------------------------------------------------- /Cifar/nets/resnet_20s_decode.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Decode model from alpha 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.init as init 9 | from torch.autograd import Variable 10 | from utils.compute_flops import * 11 | from nets.se_module import SELayer 12 | import argparse 13 | from pdb import set_trace as br 14 | 15 
| __all__ = ['resnet_decode'] 16 | 17 | 18 | def _weights_init(m): 19 | classname = m.__class__.__name__ 20 | # print(classname) 21 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 22 | init.kaiming_normal_(m.weight) 23 | 24 | 25 | class DShortCut(nn.Module): 26 | def __init__(self, cin, cout, has_avg, has_BN, affine=True): 27 | super(DShortCut, self).__init__() 28 | self.conv = nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0, bias=False) 29 | if has_avg: 30 | self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 31 | else: 32 | self.avg = None 33 | 34 | if has_BN: 35 | self.bn = nn.BatchNorm2d(cout, affine=affine) 36 | else: 37 | self.bn = None 38 | 39 | def forward(self, x): 40 | if self.avg: 41 | out = self.avg(x) 42 | else: 43 | out = x 44 | 45 | out = self.conv(out) 46 | if self.bn: 47 | out = self.bn(out) 48 | return out 49 | 50 | 51 | def ChannelWiseInterV2(inputs, oC, downsample=False): 52 | assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size()) 53 | batch, C, H, W = inputs.size() 54 | 55 | if downsample: 56 | inputs = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)(inputs) 57 | H, W = inputs.shape[2], inputs.shape[3] 58 | if C == oC: 59 | return inputs 60 | else: 61 | return nn.functional.adaptive_avg_pool3d(inputs, (oC,H,W)) 62 | 63 | 64 | class BasicBlock(nn.Module): 65 | # expansion = 1 66 | 67 | def __init__(self, in_planes, cfg, stride=1, affine=True, \ 68 | se=False, se_reduction=-1): 69 | """ Args: 70 | in_planes: an int, the input chanels; 71 | cfg: a list of int, the mid and output channels; 72 | se: whether use SE module or not 73 | se_reduction: the mid width for se module 74 | """ 75 | super(BasicBlock, self).__init__() 76 | assert len(cfg) == 2, 'wrong cfg length' 77 | mid_planes, planes = cfg[0], cfg[1] 78 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False) 79 | self.bn1 = nn.BatchNorm2d(mid_planes, affine=affine) 80 | self.relu1 = nn.ReLU() 81 | 
self.conv2 = nn.Conv2d(mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn2 = nn.BatchNorm2d(planes, affine=affine) 83 | self.relu2 = nn.ReLU() 84 | 85 | self.se = se 86 | self.se_reduction = se_reduction 87 | if self.se: 88 | assert se_reduction > 0, "Must specify se reduction > 0" 89 | self.se_module = SELayer(planes, se_reduction) 90 | 91 | if stride == 2: 92 | self.shortcut = DShortCut(in_planes, planes, has_avg=True, has_BN=False) 93 | elif in_planes != planes: 94 | self.shortcut = DShortCut(in_planes, planes, has_avg=False, has_BN=True) 95 | else: 96 | self.shortcut = nn.Sequential() 97 | 98 | def forward(self, x): 99 | out = self.relu1(self.bn1(self.conv1(x))) 100 | out = self.bn2(self.conv2(out)) 101 | if self.se: 102 | out = self.se_module(out) 103 | 104 | out += self.shortcut(x) 105 | out = self.relu2(out) 106 | return out 107 | 108 | 109 | class ResNetDecode(nn.Module): 110 | """ 111 | Re-build the resnet based on the cfg obtained from alpha. 112 | """ 113 | def __init__(self, block, cfg, num_classes=10, affine=True, se=False, se_reduction=-1): 114 | super(ResNetDecode, self).__init__() 115 | # self.in_planes = 16 116 | self.cfg = cfg 117 | self.affine = affine 118 | self.se = se 119 | self.se_reduction = se_reduction 120 | 121 | self.conv1 = nn.Conv2d(3, self.cfg[0], kernel_size=3, stride=1, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(self.cfg[0], affine=self.affine) 123 | self.relu1 = nn.ReLU() 124 | num_blocks = [(len(self.cfg) - 1) // 2 // 3] * 3 # [3] * 3 125 | 126 | count = 1 127 | self.layer1 = self._make_layer(block, self.cfg[count-1:count+num_blocks[0]*2-1], self.cfg[count:count+num_blocks[0]*2], stride=1) 128 | count += num_blocks[0]*2 129 | self.layer2 = self._make_layer(block, self.cfg[count-1:count+num_blocks[1]*2-1], self.cfg[count:count+num_blocks[1]*2], stride=2) 130 | count += num_blocks[1]*2 131 | self.layer3 = self._make_layer(block, self.cfg[count-1:count+num_blocks[2]*2-1], 
self.cfg[count:count+num_blocks[2]*2], stride=2) 132 | count += num_blocks[2]*2 133 | assert count == len(self.cfg) 134 | self.linear = nn.Linear(self.cfg[-1], num_classes) 135 | self.apply(_weights_init) 136 | 137 | def _make_layer(self, block, planes, cfg, stride): 138 | num_block = len(cfg) // 2 139 | strides = [stride] + [1]*(num_block-1) 140 | layers = nn.ModuleList() 141 | count = 0 142 | for idx, stride in enumerate(strides): 143 | layers.append(block(planes[count], cfg[count:count+2], stride, affine=self.affine, se=self.se, se_reduction=self.se_reduction)) 144 | count += 2 145 | assert count == len(cfg), 'cfg and block num mismatch' 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | out = self.relu1(self.bn1(self.conv1(x))) 150 | out = self.layer1(out) 151 | out = self.layer2(out) 152 | out = self.layer3(out) 153 | avgpool = nn.AvgPool2d(out.size()[3]) 154 | out = avgpool(out) 155 | 156 | out = out.view(out.size(0), -1) 157 | out = self.linear(out) 158 | return out 159 | 160 | def resnet_decode(cfg, num_classes, se=False, se_reduction=-1): 161 | return ResNetDecode(BasicBlock, cfg, num_classes=num_classes, se=se, se_reduction=se_reduction) 162 | -------------------------------------------------------------------------------- /ToyExample/feature_extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import math 5 | from scipy import optimize 6 | import random, sys 7 | sys.path.append("..") 8 | from CKA import cka 9 | import torch 10 | import torch.nn as nn 11 | from run_manager import RunManager 12 | from models import VGG_CIFAR, TrainRunConfig 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | """ model config """ 17 | parser.add_argument('--path', type=str) 18 | parser.add_argument('--model', type=str, default="vgg", 19 | choices=['vgg']) 20 | parser.add_argument('--cfg', type=str, default="None") 21 | 
parser.add_argument('--manual_seed', default=0, type=int) 22 | parser.add_argument("--target_flops", default=0, type=int) 23 | parser.add_argument("--beta", default=1, type=int) 24 | 25 | """ dataset config """ 26 | parser.add_argument('--dataset', type=str, default='cifar10', 27 | choices=['cifar10']) 28 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 29 | 30 | """ runtime config """ 31 | parser.add_argument('--gpu', help='gpu available', default='0') 32 | parser.add_argument('--test_batch_size', type=int, default=100) 33 | parser.add_argument('--n_worker', type=int, default=24) 34 | parser.add_argument("--local_rank", default=0, type=int) 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 40 | torch.cuda.set_device(0) 41 | 42 | random.seed(args.manual_seed) 43 | torch.manual_seed(args.manual_seed) 44 | torch.cuda.manual_seed_all(args.manual_seed) 45 | np.random.seed(args.manual_seed) 46 | # distributed setting 47 | torch.distributed.init_process_group(backend='nccl', 48 | init_method='env://') 49 | args.world_size = torch.distributed.get_world_size() 50 | 51 | # prepare run config 52 | run_config_path = '%s/run.config' % args.path 53 | 54 | run_config = TrainRunConfig( 55 | **args.__dict__ 56 | ) 57 | if args.local_rank == 0: 58 | print('Run config:') 59 | for k, v in args.__dict__.items(): 60 | print('\t%s: %s' % (k, v)) 61 | 62 | if args.model == "vgg": 63 | assert args.dataset == 'cifar10', 'vgg only supports cifar10 dataset' 64 | net = VGG_CIFAR( 65 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 66 | 67 | # build run manager 68 | run_manager = RunManager(args.path, net, run_config) 69 | 70 | # load checkpoints 71 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 72 | assert os.path.isfile(best_model_path), 'wrong path' 73 | if torch.cuda.is_available(): 74 | checkpoint = torch.load(best_model_path) 75 | else: 
76 | checkpoint = torch.load(best_model_path, map_location='cpu') 77 | if 'state_dict' in checkpoint: 78 | checkpoint = checkpoint['state_dict'] 79 | run_manager.net.load_state_dict(checkpoint) 80 | output_dict = {} 81 | 82 | # feature extract 83 | data_loader = run_manager.run_config.test_loader 84 | data = next(iter(data_loader)) 85 | data = data[0] 86 | n = data.size()[0] 87 | 88 | with torch.no_grad(): 89 | feature = net.feature_extract(data) 90 | 91 | for i in range(len(feature)): 92 | feature[i] = feature[i].view(n, -1) 93 | feature[i] = feature[i].data.cpu().numpy() 94 | 95 | similar_matrix = np.zeros((len(feature), len(feature))) 96 | 97 | for i in range(len(feature)): 98 | for j in range(len(feature)): 99 | with torch.no_grad(): 100 | similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j])) 101 | 102 | def sum_list(a, j): 103 | b = 0 104 | for i in range(len(a)): 105 | if i != j: 106 | b += a[i] 107 | return b 108 | 109 | important = [] 110 | temp = [] 111 | flops = [] 112 | 113 | for i in range(len(feature)): 114 | temp.append( sum_list(similar_matrix[i], i) ) 115 | 116 | if args.beta !=1: 117 | b = sum_list(temp, -1) 118 | temp = [x / b for x in temp] 119 | print("hahahahahahahahahahahah") 120 | 121 | for i in range(len(feature)): 122 | important.append(math.exp(-1 * args.beta * temp[i])) 123 | 124 | length = len(net.cfg) 125 | flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length) 126 | important = np.array(important) 127 | important = np.negative(important) 128 | 129 | # Objective function 130 | def func(x, sign=1.0): 131 | """ Objective function """ 132 | global important,length 133 | sum_fuc =[] 134 | for i in range(length): 135 | sum_fuc.append(x[i]*important[i]) 136 | return sum(sum_fuc) 137 | 138 | # Derivative function of objective function 139 | def func_deriv(x, sign=1.0): 140 | """ Derivative of objective function """ 141 | global important 142 | diff = [] 143 | for i in 
range(len(important)): 144 | diff.append(sign * (important[i])) 145 | return np.array(diff) 146 | 147 | # Constraint function 148 | def constrain_func(x): 149 | """ constrain function """ 150 | global flops_singlecfg, flops_doublecfg, flops_squarecfg, length 151 | a = [] 152 | for i in range(length): 153 | a.append(x[i] * flops_singlecfg[i]) 154 | a.append(flops_squarecfg[i] * x[i] * x[i]) 155 | for i in range(1,length): 156 | for j in range(i): 157 | a.append(x[i] * x[j] * flops_doublecfg[i][j]) 158 | return np.array([args.target_flops - sum(a)]) 159 | 160 | 161 | bnds = [] 162 | for i in range(length): 163 | bnds.append((0,1)) 164 | 165 | bnds = tuple(bnds) 166 | cons = ({'type': 'ineq', 167 | 'fun': constrain_func}) 168 | 169 | result = optimize.minimize(func,x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons) 170 | prun_cfg = np.around(np.array(net.cfg)*result.x) 171 | 172 | optimize_cfg = [] 173 | for i in range(len(prun_cfg)): 174 | b = list(prun_cfg)[i].tolist() 175 | optimize_cfg.append(int(b)) 176 | print(optimize_cfg) 177 | print(net.cfg2flops(prun_cfg)) 178 | print(net.cfg2flops(net.cfg)) 179 | 180 | 181 | -------------------------------------------------------------------------------- /ToyExample/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | from run_manager import RunManager 8 | from models import VGG_CIFAR, TrainRunConfig 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | """ basic config """ 13 | parser.add_argument('--path', type=str, default='Exp/debug') 14 | parser.add_argument('--gpu', help='gpu available', default='0,1,2,3') 15 | parser.add_argument('--train', action='store_true', default=True) 16 | parser.add_argument('--manual_seed', default=13, type=int) 17 | parser.add_argument('--resume', action='store_true') 18 | 
parser.add_argument('--validation_frequency', type=int, default=1) 19 | parser.add_argument('--print_frequency', type=int, default=100) 20 | 21 | """ optimizer config """ 22 | parser.add_argument('--n_epochs', type=int, default=300) 23 | parser.add_argument('--init_lr', type=float, default=0.1) 24 | parser.add_argument('--lr_schedule_type', type=str, default='cosine') 25 | parser.add_argument('--warm_epoch', type=int, default=5) 26 | parser.add_argument('--opt_type', type=str, default='sgd', choices=['sgd']) 27 | parser.add_argument('--momentum', type=float, default=0.9) # opt_param 28 | parser.add_argument('--no_nesterov', action='store_true') # opt_param 29 | parser.add_argument('--weight_decay', type=float, default=4e-5) 30 | parser.add_argument('--label_smoothing', type=float, default=0.1) 31 | parser.add_argument('--no_decay_keys', type=str, 32 | default='bn#bias', choices=[None, 'bn', 'bn#bias']) 33 | 34 | """ dataset config """ 35 | parser.add_argument('--dataset', type=str, default='cifar10', 36 | choices=['cifar10', 'imagenet', 'imagenet10', 'imagenet100']) 37 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 38 | parser.add_argument('--train_batch_size', type=int, default=256) 39 | parser.add_argument('--test_batch_size', type=int, default=250) 40 | 41 | """ model config """ 42 | parser.add_argument('--model', type=str, default="vgg", 43 | choices=['vgg', 'resnet20', 'resnet56', 'resnet110', 'resnet18', 'resnet34', 'resnet50_TAS', 'mobilenetv2_autoslim', 'resnet50', 'mobilenetv2', 'mobilenet']) 44 | parser.add_argument('--cfg', type=str, default="None") 45 | parser.add_argument('--base_path', type=str, default=None) 46 | parser.add_argument('--dropout', type=float, default=0.0) 47 | parser.add_argument('--model_init', type=str, 48 | default='he_fout', choices=['he_fin', 'he_fout']) 49 | parser.add_argument('--init_div_groups', action='store_true') 50 | 51 | """ runtime config """ 52 | parser.add_argument('--n_worker', 
type=int, default=24) 53 | parser.add_argument('--sync_bn', action='store_true', 54 | help='enabling apex sync BN.') 55 | parser.add_argument("--local_rank", default=0, type=int) 56 | 57 | 58 | if __name__ == '__main__': 59 | args = parser.parse_args() 60 | 61 | # random.seed(args.manual_seed) 62 | # torch.manual_seed(args.manual_seed) 63 | # torch.cuda.manual_seed(args.manual_seed) 64 | # torch.cuda.manual_seed_all(args.manual_seed) 65 | # np.random.seed(args.manual_seed) 66 | torch.backends.cudnn.benchmark = True 67 | # torch.backends.cudnn.deterministic = True 68 | # torch.backends.cudnn.enabled = True 69 | os.makedirs(args.path, exist_ok=True) 70 | 71 | # distributed setting 72 | torch.distributed.init_process_group(backend='nccl', 73 | init_method='env://') 74 | args.world_size = torch.distributed.get_world_size() 75 | 76 | # prepare run config 77 | run_config_path = '%s/run.config' % args.path 78 | # build run config from args 79 | args.lr_schedule_param = None 80 | args.opt_param = { 81 | 'momentum': args.momentum, 82 | 'nesterov': not args.no_nesterov, 83 | } 84 | run_config = TrainRunConfig( 85 | **args.__dict__ 86 | ) 87 | 88 | if args.local_rank == 0: 89 | print('Run config:') 90 | for k, v in run_config.config.items(): 91 | print('\t%s: %s' % (k, v)) 92 | 93 | weight_path = None 94 | net_origin = None 95 | if args.model == "vgg": 96 | assert args.dataset == 'cifar10', 'vgg only supports cifar10 dataset' 97 | net = VGG_CIFAR( 98 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 99 | if args.base_path is not None: 100 | weight_path = args.base_path + '/checkpoint/model_best.pth.tar' 101 | net_origin = nn.DataParallel( 102 | VGG_CIFAR(num_classes=run_config.data_provider.n_classes)) 103 | 104 | # build run manager 105 | run_manager = RunManager(args.path, net, run_config) 106 | if args.local_rank == 0: 107 | run_manager.save_config(print_info=True) 108 | 109 | # load checkpoints 110 | if args.base_path is not None: 111 | weight_path = 
args.base_path+'/checkpoint/model_best.pth.tar' 112 | if args.resume: 113 | run_manager.load_model() 114 | if args.train and run_manager.best_acc == 0: 115 | loss, acc1, acc5 = run_manager.validate( 116 | is_test=True, return_top5=True) 117 | run_manager.best_acc = acc1 118 | elif weight_path is not None and os.path.isfile(weight_path): 119 | assert net_origin is not None, "original network is None" 120 | net_origin.load_state_dict(torch.load(weight_path)['state_dict']) 121 | net_origin = net_origin.module 122 | if args.model == 'vgg': 123 | run_manager.reset_model( 124 | VGG_CIFAR(cfg=eval(args.cfg), cutout=True), net_origin.cpu()) 125 | else: 126 | print('Random initialization') 127 | 128 | # train 129 | if args.train: 130 | print('Start training: %d' % args.local_rank) 131 | if args.local_rank == 0: 132 | print(net) 133 | run_manager.train(print_top5=True) 134 | if args.local_rank == 0: 135 | run_manager.save_model() 136 | 137 | output_dict = {} 138 | 139 | # test 140 | print('Test on test set') 141 | loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True) 142 | log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5) 143 | run_manager.write_log(log, prefix='test') 144 | output_dict = { 145 | **output_dict, 146 | 'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5 147 | } 148 | json.dump(output_dict, open('%s/output' % args.path, 'w'), indent=4) 149 | -------------------------------------------------------------------------------- /Cifar/feature_extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import math 5 | from scipy import optimize 6 | import random, sys 7 | sys.path.append("..") 8 | from CKA import cka 9 | import torch 10 | import torch.nn as nn 11 | from run_manager import RunManager 12 | from nets import ResNet_CIFAR, TrainRunConfig 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | """ 
model config """ 17 | parser.add_argument('--path', type=str) 18 | parser.add_argument('--model', type=str, default="vgg", 19 | choices=['resnet20', 'resnet56']) 20 | parser.add_argument('--cfg', type=str, default="None") 21 | parser.add_argument('--manual_seed', default=0, type=int) 22 | parser.add_argument("--target_flops", default=0, type=int) 23 | parser.add_argument("--beta", default=1, type=int) 24 | 25 | """ dataset config """ 26 | parser.add_argument('--dataset', type=str, default='cifar10', 27 | choices=['cifar10', 'imagenet']) 28 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 29 | 30 | """ runtime config """ 31 | parser.add_argument('--gpu', help='gpu available', default='0') 32 | parser.add_argument('--test_batch_size', type=int, default=100) 33 | parser.add_argument('--n_worker', type=int, default=24) 34 | parser.add_argument("--local_rank", default=0, type=int) 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 40 | torch.cuda.set_device(0) 41 | 42 | random.seed(args.manual_seed) 43 | torch.manual_seed(args.manual_seed) 44 | torch.cuda.manual_seed_all(args.manual_seed) 45 | np.random.seed(args.manual_seed) 46 | # distributed setting 47 | torch.distributed.init_process_group(backend='nccl', 48 | init_method='env://') 49 | args.world_size = torch.distributed.get_world_size() 50 | 51 | # prepare run config 52 | run_config_path = '%s/run.config' % args.path 53 | 54 | run_config = TrainRunConfig( 55 | **args.__dict__ 56 | ) 57 | if args.local_rank == 0: 58 | print('Run config:') 59 | for k, v in args.__dict__.items(): 60 | print('\t%s: %s' % (k, v)) 61 | 62 | if args.model == "resnet20": 63 | assert args.dataset == 'cifar10', 'resnet20 only supports cifar10 dataset' 64 | net = ResNet_CIFAR( 65 | depth=20, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 66 | elif args.model == "resnet56": 67 | assert args.dataset == 'cifar10', 
'resnet56 only supports cifar10 dataset' 68 | net = ResNet_CIFAR( 69 | depth=56, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 70 | 71 | # build run manager 72 | run_manager = RunManager(args.path, net, run_config) 73 | 74 | # load checkpoints 75 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 76 | assert os.path.isfile(best_model_path), 'wrong path' 77 | if torch.cuda.is_available(): 78 | checkpoint = torch.load(best_model_path) 79 | else: 80 | checkpoint = torch.load(best_model_path, map_location='cpu') 81 | if 'state_dict' in checkpoint: 82 | checkpoint = checkpoint['state_dict'] 83 | run_manager.net.load_state_dict(checkpoint) 84 | output_dict = {} 85 | 86 | # feature extract 87 | data_loader = run_manager.run_config.test_loader 88 | data = next(iter(data_loader)) 89 | data = data[0] 90 | n = data.size()[0] 91 | 92 | with torch.no_grad(): 93 | feature = net.feature_extract(data) 94 | 95 | for i in range(len(feature)): 96 | feature[i] = feature[i].view(n, -1) 97 | feature[i] = feature[i].data.cpu().numpy() 98 | 99 | similar_matrix = np.zeros((len(feature), len(feature))) 100 | 101 | for i in range(len(feature)): 102 | for j in range(len(feature)): 103 | with torch.no_grad(): 104 | similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j])) 105 | 106 | def sum_list(a, j): 107 | b = 0 108 | for i in range(len(a)): 109 | if i != j: 110 | b += a[i] 111 | return b 112 | 113 | important = [] 114 | temp = [] 115 | flops = [] 116 | 117 | for i in range(len(feature)): 118 | temp.append( sum_list(similar_matrix[i], i) ) 119 | 120 | b = sum_list(temp, -1) 121 | temp = [x/b for x in temp] 122 | 123 | for i in range(len(feature)): 124 | important.append( math.exp(-1* args.beta *temp[i] ) ) 125 | 126 | length = len(net.cfg) 127 | flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length) 128 | important = np.array(important) 129 | important = np.negative(important) 130 | 131 
| # Objective function 132 | def func(x, sign=1.0): 133 | """ Objective function """ 134 | global important,length 135 | sum_fuc =[] 136 | for i in range(length): 137 | sum_fuc.append(x[i]*important[i]) 138 | return sum(sum_fuc) 139 | 140 | # Derivative function of objective function 141 | def func_deriv(x, sign=1.0): 142 | """ Derivative of objective function """ 143 | global important 144 | diff = [] 145 | for i in range(len(important)): 146 | diff.append(sign * (important[i])) 147 | return np.array(diff) 148 | 149 | # Constraint function 150 | def constrain_func(x): 151 | """ constrain function """ 152 | global flops_singlecfg, flops_doublecfg, flops_squarecfg, length 153 | a = [] 154 | for i in range(length): 155 | a.append(x[i] * flops_singlecfg[i]) 156 | a.append(flops_squarecfg[i] * x[i] * x[i]) 157 | for i in range(1,length): 158 | for j in range(i): 159 | a.append(x[i] * x[j] * flops_doublecfg[i][j]) 160 | return np.array([args.target_flops - sum(a)]) 161 | 162 | 163 | bnds = [] 164 | for i in range(length): 165 | bnds.append((0,1)) 166 | 167 | bnds = tuple(bnds) 168 | cons = ({'type': 'ineq', 169 | 'fun': constrain_func}) 170 | 171 | result = optimize.minimize(func,x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons) 172 | prun_cfg = np.around(np.array(net.cfg)*result.x) 173 | 174 | optimize_cfg = [] 175 | for i in range(len(prun_cfg)): 176 | b = list(prun_cfg)[i].tolist() 177 | optimize_cfg.append(int(b)) 178 | print(optimize_cfg) 179 | print(net.cfg2flops(prun_cfg)) 180 | print(net.cfg2flops(net.cfg)) 181 | 182 | 183 | -------------------------------------------------------------------------------- /Cifar/utils/model_profiling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | model_profiling_hooks = [] 8 | model_profiling_speed_hooks = [] 9 | 10 | name_space = 95 11 | params_space = 15 
12 | macs_space = 15 13 | seconds_space = 15 14 | 15 | num_forwards = 10 16 | 17 | 18 | class Timer(object): 19 | def __init__(self, verbose=False): 20 | self.verbose = verbose 21 | self.start = None 22 | self.end = None 23 | 24 | def __enter__(self): 25 | self.start = time.time() 26 | return self 27 | 28 | def __exit__(self, *args): 29 | self.end = time.time() 30 | self.time = self.end - self.start 31 | if self.verbose: 32 | print('Elapsed time: %f ms.' % self.time) 33 | 34 | 35 | def get_params(self): 36 | """get number of params in module""" 37 | return np.sum( 38 | [np.prod(list(w.size())) for w in self.parameters()]) 39 | 40 | 41 | def run_forward(self, input): 42 | with Timer() as t: 43 | for _ in range(num_forwards): 44 | self.forward(*input) 45 | torch.cuda.synchronize() 46 | return int(t.time * 1e9 / num_forwards) 47 | 48 | 49 | def conv_module_name_filter(name): 50 | """filter module name to have a short view""" 51 | filters = { 52 | 'kernel_size': 'k', 53 | 'stride': 's', 54 | 'padding': 'pad', 55 | 'bias': 'b', 56 | 'groups': 'g', 57 | } 58 | for k in filters: 59 | name = name.replace(k, filters[k]) 60 | return name 61 | 62 | 63 | def module_profiling(self, input, output, verbose): 64 | ins = input[0].size() 65 | outs = output.size() 66 | # NOTE: There are some difference between type and isinstance, thus please 67 | # be careful. 
68 | t = type(self) 69 | if isinstance(self, nn.Conv2d): 70 | self.n_macs = (ins[1] * outs[1] * 71 | self.kernel_size[0] * self.kernel_size[1] * 72 | outs[2] * outs[3] // self.groups) * outs[0] 73 | self.n_params = get_params(self) 74 | self.n_seconds = run_forward(self, input) 75 | self.name = conv_module_name_filter(self.__repr__()) 76 | elif isinstance(self, nn.ConvTranspose2d): 77 | self.n_macs = (ins[1] * outs[1] * 78 | self.kernel_size[0] * self.kernel_size[1] * 79 | outs[2] * outs[3] // self.groups) * outs[0] 80 | self.n_params = get_params(self) 81 | self.n_seconds = run_forward(self, input) 82 | self.name = conv_module_name_filter(self.__repr__()) 83 | elif isinstance(self, nn.Linear): 84 | self.n_macs = ins[1] * outs[1] * outs[0] 85 | self.n_params = get_params(self) 86 | self.n_seconds = run_forward(self, input) 87 | self.name = self.__repr__() 88 | elif isinstance(self, nn.AvgPool2d): 89 | # NOTE: this function is correct only when stride == kernel size 90 | self.n_macs = ins[1] * ins[2] * ins[3] * ins[0] 91 | self.n_params = 0 92 | self.n_seconds = run_forward(self, input) 93 | self.name = self.__repr__() 94 | elif isinstance(self, nn.AdaptiveAvgPool2d): 95 | # NOTE: this function is correct only when stride == kernel size 96 | self.n_macs = ins[1] * ins[2] * ins[3] * ins[0] 97 | self.n_params = 0 98 | self.n_seconds = run_forward(self, input) 99 | self.name = self.__repr__() 100 | else: 101 | # This works only in depth-first travel of modules. 
102 | self.n_macs = 0 103 | self.n_params = 0 104 | self.n_seconds = 0 105 | num_children = 0 106 | for m in self.children(): 107 | self.n_macs += getattr(m, 'n_macs', 0) 108 | self.n_params += getattr(m, 'n_params', 0) 109 | self.n_seconds += getattr(m, 'n_seconds', 0) 110 | num_children += 1 111 | ignore_zeros_t = [ 112 | nn.BatchNorm2d, nn.Dropout2d, nn.Dropout, nn.Sequential, 113 | nn.ReLU6, nn.ReLU, nn.MaxPool2d, 114 | nn.modules.padding.ZeroPad2d, nn.modules.activation.Sigmoid, 115 | ] 116 | if (not getattr(self, 'ignore_model_profiling', False) and 117 | self.n_macs == 0 and 118 | t not in ignore_zeros_t): 119 | print( 120 | 'WARNING: leaf module {} has zero n_macs.'.format(type(self))) 121 | return 122 | if verbose: 123 | print( 124 | self.name.ljust(name_space, ' ') + 125 | '{:,}'.format(self.n_params).rjust(params_space, ' ') + 126 | '{:,}'.format(self.n_macs).rjust(macs_space, ' ') + 127 | '{:,}'.format(self.n_seconds).rjust(seconds_space, ' ')) 128 | return 129 | 130 | 131 | def add_profiling_hooks(m, verbose): 132 | global model_profiling_hooks 133 | model_profiling_hooks.append( 134 | m.register_forward_hook(lambda m, input, output: module_profiling( 135 | m, input, output, verbose=verbose))) 136 | 137 | 138 | def remove_profiling_hooks(): 139 | global model_profiling_hooks 140 | for h in model_profiling_hooks: 141 | h.remove() 142 | model_profiling_hooks = [] 143 | 144 | 145 | def model_profiling(model, height, width, batch=1, channel=3, use_cuda=True, 146 | verbose=True): 147 | """ Pytorch model profiling with input image size 148 | (batch, channel, height, width). 149 | The function exams the number of multiply-accumulates (n_macs). 
150 | 151 | Args: 152 | model: pytorch model 153 | height: int 154 | width: int 155 | batch: int 156 | channel: int 157 | use_cuda: bool 158 | 159 | Returns: 160 | macs: int 161 | params: int 162 | 163 | """ 164 | model.eval() 165 | data = torch.rand(batch, channel, height, width) 166 | device = torch.device("cuda:0" if use_cuda else "cpu") 167 | model = model.to(device) 168 | data = data.to(device) 169 | model.apply(lambda m: add_profiling_hooks(m, verbose=verbose)) 170 | print( 171 | 'Item'.ljust(name_space, ' ') + 172 | 'params'.rjust(macs_space, ' ') + 173 | 'macs'.rjust(macs_space, ' ') + 174 | 'nanosecs'.rjust(seconds_space, ' ')) 175 | if verbose: 176 | print(''.center( 177 | name_space + params_space + macs_space + seconds_space, '-')) 178 | model(data) 179 | if verbose: 180 | print(''.center( 181 | name_space + params_space + macs_space + seconds_space, '-')) 182 | print( 183 | 'Total'.ljust(name_space, ' ') + 184 | '{:,}'.format(model.n_params).rjust(params_space, ' ') + 185 | '{:,}'.format(model.n_macs).rjust(macs_space, ' ') + 186 | '{:,}'.format(model.n_seconds).rjust(seconds_space, ' ')) 187 | remove_profiling_hooks() 188 | return model.n_macs, model.n_params 189 | -------------------------------------------------------------------------------- /Imagenet/feature_extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import math 5 | from scipy import optimize 6 | import random, sys 7 | sys.path.append("../../ITPruner") 8 | from CKA import cka 9 | import torch 10 | import torch.nn as nn 11 | from run_manager import RunManager 12 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | """ model config """ 17 | parser.add_argument('--path', type=str) 18 | parser.add_argument('--model', type=str, default="vgg", 19 | choices=['resnet50', 'mobilenetv2', 'mobilenet']) 20 | 
parser.add_argument('--cfg', type=str, default="None") 21 | parser.add_argument('--manual_seed', default=0, type=int) 22 | parser.add_argument("--target_flops", default=0, type=int) 23 | parser.add_argument("--beta", default=1, type=int) 24 | 25 | """ dataset config """ 26 | parser.add_argument('--dataset', type=str, default='cifar10', 27 | choices=['cifar10', 'imagenet']) 28 | parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10') 29 | 30 | """ runtime config """ 31 | parser.add_argument('--gpu', help='gpu available', default='0') 32 | parser.add_argument('--test_batch_size', type=int, default=100) 33 | parser.add_argument('--n_worker', type=int, default=24) 34 | parser.add_argument("--local_rank", default=0, type=int) 35 | 36 | if __name__ == '__main__': 37 | args = parser.parse_args() 38 | 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 40 | torch.cuda.set_device(0) 41 | 42 | random.seed(args.manual_seed) 43 | torch.manual_seed(args.manual_seed) 44 | torch.cuda.manual_seed_all(args.manual_seed) 45 | np.random.seed(args.manual_seed) 46 | # distributed setting 47 | torch.distributed.init_process_group(backend='nccl', 48 | init_method='env://') 49 | args.world_size = torch.distributed.get_world_size() 50 | 51 | # prepare run config 52 | run_config_path = '%s/run.config' % args.path 53 | 54 | run_config = TrainRunConfig( 55 | **args.__dict__ 56 | ) 57 | if args.local_rank == 0: 58 | print('Run config:') 59 | for k, v in args.__dict__.items(): 60 | print('\t%s: %s' % (k, v)) 61 | 62 | if args.model == "resnet50": 63 | assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset' 64 | net = ResNet_ImageNet( 65 | depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 66 | elif args.model == "mobilenetv2": 67 | assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset' 68 | net = MobileNetV2( 69 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 70 | elif 
args.model == "mobilenet": 71 | assert args.dataset == 'imagenet', 'mobilenet only supports imagenet dataset' 72 | net = MobileNet( 73 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 74 | 75 | # build run manager 76 | run_manager = RunManager(args.path, net, run_config) 77 | 78 | # load checkpoints 79 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path 80 | assert os.path.isfile(best_model_path), 'wrong path' 81 | if torch.cuda.is_available(): 82 | checkpoint = torch.load(best_model_path) 83 | else: 84 | checkpoint = torch.load(best_model_path, map_location='cpu') 85 | if 'state_dict' in checkpoint: 86 | checkpoint = checkpoint['state_dict'] 87 | run_manager.net.load_state_dict(checkpoint) 88 | output_dict = {} 89 | 90 | # feature extract 91 | data_loader = run_manager.run_config.test_loader 92 | data = next(iter(data_loader)) 93 | data = data[0] 94 | n = data.size()[0] 95 | 96 | with torch.no_grad(): 97 | feature = net.feature_extract(data) 98 | 99 | for i in range(len(feature)): 100 | feature[i] = feature[i].view(n, -1) 101 | feature[i] = feature[i].data.cpu().numpy() 102 | 103 | similar_matrix = np.zeros((len(feature), len(feature))) 104 | 105 | for i in range(len(feature)): 106 | for j in range(len(feature)): 107 | with torch.no_grad(): 108 | similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j])) 109 | 110 | def sum_list(a, j): 111 | b = 0 112 | for i in range(len(a)): 113 | if i != j: 114 | b += a[i] 115 | return b 116 | 117 | important = [] 118 | temp = [] 119 | flops = [] 120 | 121 | for i in range(len(feature)): 122 | temp.append( sum_list(similar_matrix[i], i) ) 123 | 124 | b = sum_list(temp, -1) 125 | temp = [x/b for x in temp] 126 | 127 | for i in range(len(feature)): 128 | important.append( math.exp(-1* args.beta *temp[i] ) ) 129 | 130 | length = len(net.cfg) 131 | flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length) 132 | important = 
np.array(important) 133 | important = np.negative(important) 134 | 135 | # Objective function 136 | def func(x, sign=1.0): 137 | """ Objective function """ 138 | global important,length 139 | sum_fuc =[] 140 | for i in range(length): 141 | sum_fuc.append(x[i]*important[i]) 142 | return sum(sum_fuc) 143 | 144 | # Derivative function of objective function 145 | def func_deriv(x, sign=1.0): 146 | """ Derivative of objective function """ 147 | global important 148 | diff = [] 149 | for i in range(len(important)): 150 | diff.append(sign * (important[i])) 151 | return np.array(diff) 152 | 153 | # Constraint function 154 | def constrain_func(x): 155 | """ constrain function """ 156 | global flops_singlecfg, flops_doublecfg, flops_squarecfg, length 157 | a = [] 158 | for i in range(length): 159 | a.append(x[i] * flops_singlecfg[i]) 160 | a.append(flops_squarecfg[i] * x[i] * x[i]) 161 | for i in range(1,length): 162 | for j in range(i): 163 | a.append(x[i] * x[j] * flops_doublecfg[i][j]) 164 | return np.array([args.target_flops - sum(a)]) 165 | 166 | 167 | bnds = [] 168 | for i in range(length): 169 | bnds.append((0,1)) 170 | 171 | bnds = tuple(bnds) 172 | cons = ({'type': 'ineq', 173 | 'fun': constrain_func}) 174 | 175 | result = optimize.minimize(func,x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons) 176 | prun_cfg = np.around(np.array(net.cfg)*result.x) 177 | 178 | optimize_cfg = [] 179 | for i in range(len(prun_cfg)): 180 | b = list(prun_cfg)[i].tolist() 181 | optimize_cfg.append(int(b)) 182 | print(optimize_cfg) 183 | print(net.cfg2flops(prun_cfg)) 184 | print(net.cfg2flops(net.cfg)) 185 | 186 | 187 | -------------------------------------------------------------------------------- /ToyExample/models/vgg.py: -------------------------------------------------------------------------------- 1 | # https://github.com/pytorch/vision/blob/master/torchvision/models 2 | import torch 3 | import torch.nn as nn 4 | import 
torch.nn.functional as F 5 | from collections import OrderedDict 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | from utils import cutout_batch 9 | from models.base_models import MyNetwork 10 | import numpy as np 11 | 12 | cifar_cfg = { 13 | 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 14 | 13: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 15 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 16 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 17 | } 18 | 19 | 20 | class VGG_CIFAR(MyNetwork): 21 | def __init__(self, cfg=None, cutout=True, num_classes=10): 22 | super(VGG_CIFAR, self).__init__() 23 | if cfg is None: 24 | cfg = [64, 64, 128, 128, 256, 256, 25 | 256, 512, 512, 512, 512, 512, 512] 26 | self.cutout = cutout 27 | self.cfg = cfg 28 | _cfg = list(cfg) 29 | _cfg.insert(2, 'M') 30 | _cfg.insert(5, 'M') 31 | _cfg.insert(9, 'M') 32 | _cfg.insert(13, 'M') 33 | self._cfg = _cfg 34 | self.feature = self.make_layers(_cfg, True) 35 | self.avgpool = nn.AvgPool2d(2) 36 | self.classifier = nn.Sequential( 37 | nn.Linear(self.cfg[-1], 512), 38 | # nn.BatchNorm1d(512), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(512, num_classes) 41 | ) 42 | self.num_classes = num_classes 43 | self.classifier_param = ( 44 | self.cfg[-1] + 1) * 512 + (512 + 1) * num_classes 45 | 46 | def make_layers(self, cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | pool_index = 0 50 | conv_index = 0 51 | for v in cfg: 52 | if v == 'M': 53 | layers += [('maxpool_%d' % pool_index, 54 | nn.MaxPool2d(kernel_size=2, stride=2))] 55 | pool_index += 1 56 | else: 57 | conv2d = nn.Conv2d( 58 | in_channels, v, kernel_size=3, padding=1, bias=False) 59 | conv_index += 1 60 | if batch_norm: 61 | bn = nn.BatchNorm2d(v) 62 | layers += [('conv_%d' % conv_index, conv2d), ('bn_%d' % conv_index, bn), 63 | ('relu_%d' % conv_index, nn.ReLU(inplace=True))] 64 | 
else: 65 | layers += [('conv_%d' % conv_index, conv2d), 66 | ('relu_%d' % conv_index, nn.ReLU(inplace=True))] 67 | in_channels = v 68 | self.conv_num = conv_index 69 | return nn.Sequential(OrderedDict(layers)) 70 | 71 | def cfg2param(self, cfg): 72 | total_param = self.classifier_param 73 | c_in = 3 74 | _cfg = list(cfg) 75 | _cfg.insert(2, 'M') 76 | _cfg.insert(5, 'M') 77 | _cfg.insert(9, 'M') 78 | _cfg.insert(13, 'M') 79 | for c_out in _cfg: 80 | if isinstance(c_out, int): 81 | total_param += 3 * 3 * c_in * c_out + 2 * c_out 82 | c_in = c_out 83 | total_param += 512 * (_cfg[-1] + 1) 84 | return total_param 85 | 86 | def cfg2flops(self, cfg): # to simplify, only count convolution flops 87 | _cfg = list(cfg) 88 | _cfg.insert(2, 'M') 89 | _cfg.insert(5, 'M') 90 | _cfg.insert(9, 'M') 91 | _cfg.insert(13, 'M') 92 | input_size = 32 93 | num_classes = 10 94 | total_flops = 0 95 | c_in = 3 96 | for c_out in _cfg: 97 | if c_out !='M': 98 | total_flops += 3 * 3 * c_in * c_out * input_size * input_size # conv 99 | total_flops += 4 * c_out * input_size * input_size # bn 100 | total_flops += input_size * input_size * c_out # relu 101 | c_in = c_out 102 | else: 103 | input_size /= 2 104 | total_flops += 2 * input_size * input_size * c_in # max pool 105 | input_size /= 2 106 | total_flops += 3 * input_size * input_size * _cfg[-1] # avg pool 107 | total_flops += (2 * input_size * input_size * 108 | _cfg[-1] - 1) * input_size * input_size * 512 # fc 109 | total_flops += 4 * _cfg[-1] * input_size * input_size # bn_1d 110 | total_flops += input_size * input_size * 512 # relu 111 | total_flops += (2 * input_size * input_size * 512 - 1) * \ 112 | input_size * input_size * num_classes # fc 113 | return total_flops 114 | 115 | def cfg2flops_perlayer(self, cfg, length): # to simplify, only count convolution flops 116 | _cfg = list(cfg) 117 | _cfg.insert(2, 'M') 118 | _cfg.insert(5, 'M') 119 | _cfg.insert(9, 'M') 120 | _cfg.insert(13, 'M') 121 | input_size = 32 122 | num_classes = 10 123 
| flops_singlecfg = [0 for j in range(length)] 124 | flops_doublecfg = np.zeros((length, length)) 125 | flops_squarecfg = [0 for j in range(length)] 126 | c_in = 3 127 | count = 0 128 | for c_out in _cfg: 129 | if c_out !='M' and count ==0: 130 | flops_singlecfg[count] += 3 * 3 * c_in * c_out * input_size * input_size # conv 131 | flops_singlecfg[count] += 4 * c_out * input_size * input_size # bn 132 | flops_singlecfg[count] += input_size * input_size * c_out # relu 133 | c_in = c_out 134 | count += 1 135 | elif c_out !='M' and count !=0: 136 | flops_doublecfg[count-1][count] += 3 * 3 * c_in * c_out * input_size * input_size # conv 137 | flops_doublecfg[count][count-1] += 3 * 3 * c_in * c_out * input_size * input_size 138 | flops_singlecfg[count] += 4 * c_out * input_size * input_size # bn 139 | flops_singlecfg[count] += input_size * input_size * c_out # relu 140 | c_in = c_out 141 | count += 1 142 | else: 143 | input_size /= 2 144 | flops_singlecfg[count-1] += 2 * input_size * input_size * c_in # max pool 145 | input_size /= 2 146 | flops_singlecfg[count-1] += 3 * input_size * input_size * _cfg[-1] # avg pool 147 | flops_singlecfg[count-1] += (2 * input_size * input_size * _cfg[-1] ) * input_size * input_size * 512 # fc 148 | flops_singlecfg[count-1] += 4 * _cfg[-1] * input_size * input_size # bn_1d 149 | 150 | return flops_singlecfg, flops_doublecfg, flops_squarecfg 151 | 152 | def forward(self, x): 153 | if self.training and self.cutout: 154 | with torch.no_grad(): 155 | x = cutout_batch(x, 16) 156 | x = self.feature(x) 157 | x = self.avgpool(x) 158 | x = x.view(x.size(0), -1) 159 | y = self.classifier(x) 160 | return y 161 | 162 | def feature_extract(self, x): 163 | tensor = [] 164 | if self.training and self.cutout: 165 | with torch.no_grad(): 166 | x = cutout_batch(x, 16) 167 | for _layer in self.feature: 168 | x = _layer(x) 169 | if type(_layer) is nn.ReLU: 170 | tensor.append(x) 171 | return tensor 172 | 173 | @property 174 | def config(self): 175 | return 
{ 176 | 'name': self.__class__.__name__, 177 | 'cfg': self.cfg, 178 | 'cfg_base': [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512], 179 | 'dataset': 'cifar10', 180 | } 181 | 182 | 183 | -------------------------------------------------------------------------------- /Cifar/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | import torch.utils.data.distributed 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import os 9 | from torch.utils.data import DataLoader, Subset, sampler 10 | import pickle 11 | import pdb 12 | import numpy as np 13 | from PIL import Image 14 | 15 | 16 | def cifar10_loader_search(args, num_workers=4): 17 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 18 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 19 | if args.data_aug: 20 | with torch.no_grad(): 21 | transform_train = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Lambda(lambda x: F.pad( 24 | Variable(x.unsqueeze(0), requires_grad=False), 25 | (4,4,4,4),mode='reflect').data.squeeze()), 26 | transforms.ToPILImage(), 27 | transforms.RandomCrop(32), 28 | transforms.RandomHorizontalFlip(), 29 | transforms.ToTensor(), 30 | normalize, 31 | ]) 32 | transform_test = transforms.Compose([ 33 | transforms.ToTensor(), 34 | normalize 35 | ]) 36 | else: 37 | transform_train = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 40 | ]) 41 | transform_test = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | ]) 45 | trainset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train) 46 | testset = 
datasets.CIFAR10(root=args.data_path, train=False, transform=transform_test) 47 | 48 | num_train = len(trainset) 49 | indices = list(range(num_train)) 50 | split = int(np.floor(args.train_portion * num_train * args.total_portion)) # we now add a portion_total arg to just use smaller number of data 51 | split_upb = int(num_train * args.total_portion) 52 | 53 | train_loader = DataLoader(trainset, batch_size=args.batch_size, 54 | sampler=sampler.SubsetRandomSampler(indices[:split]), 55 | num_workers=num_workers, pin_memory=True) 56 | valid_loader = DataLoader(trainset, batch_size=args.batch_size, 57 | sampler=sampler.SubsetRandomSampler(indices[split:split_upb]), 58 | num_workers=num_workers, pin_memory=True) 59 | test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 60 | return train_loader, valid_loader, test_loader 61 | 62 | def cifar10_loader_train(args, num_workers=4): 63 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 64 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 65 | if args.data_aug: 66 | with torch.no_grad(): 67 | transform_train = transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Lambda(lambda x: F.pad(Variable(x.unsqueeze(0), requires_grad=False), 70 | (4,4,4,4),mode='reflect').data.squeeze()), 71 | transforms.ToPILImage(), 72 | transforms.RandomCrop(32), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | normalize, 76 | ]) 77 | transform_test = transforms.Compose([ 78 | transforms.ToTensor(), 79 | normalize 80 | ]) 81 | else: 82 | transform_train = transforms.Compose([ 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 85 | ]) 86 | transform_test = transforms.Compose([ 87 | transforms.ToTensor(), 88 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 89 | ]) 90 | 91 | trainset = datasets.CIFAR10(root=args.data_path, train=True, download=True, 
transform=transform_train) 92 | testset = datasets.CIFAR10(root=args.data_path, train=False, transform=transform_test) 93 | 94 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, 95 | num_workers=num_workers, pin_memory=True) 96 | test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, 97 | num_workers=num_workers, pin_memory=True) 98 | return train_loader, test_loader 99 | 100 | def cifar100_loader(args, num_workers=4): 101 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 102 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 103 | if args.data_aug: 104 | with torch.no_grad(): 105 | transform_train = transforms.Compose([ 106 | transforms.ToTensor(), 107 | transforms.Lambda(lambda x: F.pad( 108 | Variable(x.unsqueeze(0), requires_grad=False), 109 | (4,4,4,4),mode='reflect').data.squeeze()), 110 | transforms.ToPILImage(), 111 | transforms.RandomCrop(32), 112 | transforms.RandomHorizontalFlip(), 113 | transforms.ToTensor(), 114 | normalize, 115 | ]) 116 | transform_test = transforms.Compose([ 117 | transforms.ToTensor(), 118 | normalize 119 | ]) 120 | else: 121 | transform_train = transforms.Compose([ 122 | transforms.ToTensor(), 123 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 124 | ]) 125 | transform_test = transforms.Compose([ 126 | transforms.ToTensor(), 127 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 128 | ]) 129 | trainset = datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform_train) 130 | testset = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_test) 131 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 132 | test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 133 | return train_loader, test_loader 134 | 135 | 136 | 
def ilsvrc12_loader_train(args, num_workers=4): 137 | traindir = os.path.join(args.data_path, 'train') 138 | testdir = os.path.join(args.data_path, 'val') 139 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 140 | std=[0.229, 0.224, 0.225]) 141 | 142 | train_dataset = datasets.ImageFolder( 143 | traindir, 144 | transforms.Compose([ 145 | transforms.RandomResizedCrop(224), 146 | transforms.RandomHorizontalFlip(), 147 | transforms.ToTensor(), 148 | normalize, 149 | ])) 150 | 151 | if args.distributed: 152 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 153 | else: 154 | train_sampler = None 155 | 156 | train_loader = torch.utils.data.DataLoader( 157 | train_dataset, batch_size=args.batch_size // args.world_size, shuffle=(train_sampler is None), 158 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 159 | 160 | test_loader = torch.utils.data.DataLoader( 161 | datasets.ImageFolder(testdir, transforms.Compose([ 162 | transforms.Resize(256), 163 | transforms.CenterCrop(224), 164 | transforms.ToTensor(), 165 | normalize, 166 | ])), 167 | batch_size=args.batch_size // args.world_size, shuffle=False, 168 | num_workers=num_workers, pin_memory=True) 169 | 170 | return [train_loader, test_loader], train_sampler 171 | 172 | -------------------------------------------------------------------------------- /Imagenet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | from run_manager import RunManager 8 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | """ basic config """ 13 | parser.add_argument('--path', type=str, default='Exp/debug') 14 | parser.add_argument('--gpu', help='gpu available', default='0,1,2,3') 15 | parser.add_argument('--train', action='store_true', default=True) 16 | 
# Register the rest of the "basic config" options, then the optimizer,
# dataset and model sections. Each entry is (flag, add_argument keywords).
# Entries are registered in order, so the `--help` listing is unchanged.
for _flag, _kwargs in (
        ('--manual_seed', dict(type=int, default=13)),
        ('--resume', dict(action='store_true')),
        ('--validation_frequency', dict(type=int, default=1)),
        ('--print_frequency', dict(type=int, default=100)),
        # --- optimizer config ---
        ('--n_epochs', dict(type=int, default=300)),
        ('--init_lr', dict(type=float, default=0.1)),
        ('--lr_schedule_type', dict(type=str, default='cosine')),
        ('--warm_epoch', dict(type=int, default=5)),
        ('--opt_type', dict(type=str, default='sgd', choices=['sgd'])),
        # the next two are folded into args.opt_param by the __main__ section
        ('--momentum', dict(type=float, default=0.9)),
        ('--no_nesterov', dict(action='store_true')),
        ('--weight_decay', dict(type=float, default=4e-5)),
        ('--label_smoothing', dict(type=float, default=0.1)),
        ('--no_decay_keys', dict(type=str, default='bn#bias',
                                 choices=[None, 'bn', 'bn#bias'])),
        # --- dataset config ---
        ('--dataset', dict(type=str, default='cifar10',
                           choices=['cifar10', 'imagenet', 'imagenet10', 'imagenet100'])),
        ('--save_path', dict(type=str, default='/userhome/data/cifar10')),
        ('--train_batch_size', dict(type=int, default=256)),
        ('--test_batch_size', dict(type=int, default=250)),
        # --- model config ---
        ('--model', dict(type=str, default='vgg',
                         choices=['vgg', 'resnet20', 'resnet56', 'resnet110', 'resnet18',
                                  'resnet34', 'resnet50_TAS', 'mobilenetv2_autoslim',
                                  'resnet50', 'mobilenetv2', 'mobilenet'])),
        ('--cfg', dict(type=str, default='None')),
        ('--base_path', dict(type=str, default=None)),
        ('--dropout', dict(type=float, default=0.0)),
        ('--model_init', dict(type=str, default='he_fout',
                              choices=['he_fin', 'he_fout'])),
):
    parser.add_argument(_flag, **_kwargs)
# keep the loop variables out of the module namespace
del _flag, _kwargs
parser.add_argument('--init_div_groups', action='store_true') 50 | 51 | """ runtime config """ 52 | parser.add_argument('--n_worker', type=int, default=24) 53 | parser.add_argument('--sync_bn', action='store_true', 54 | help='enabling apex sync BN.') 55 | parser.add_argument("--local_rank", default=0, type=int) 56 | 57 | 58 | if __name__ == '__main__': 59 | args = parser.parse_args() 60 | 61 | # random.seed(args.manual_seed) 62 | # torch.manual_seed(args.manual_seed) 63 | # torch.cuda.manual_seed(args.manual_seed) 64 | # torch.cuda.manual_seed_all(args.manual_seed) 65 | # np.random.seed(args.manual_seed) 66 | torch.backends.cudnn.benchmark = True 67 | # torch.backends.cudnn.deterministic = True 68 | # torch.backends.cudnn.enabled = True 69 | os.makedirs(args.path, exist_ok=True) 70 | 71 | # distributed setting 72 | torch.distributed.init_process_group(backend='nccl', 73 | init_method='env://') 74 | args.world_size = torch.distributed.get_world_size() 75 | 76 | # prepare run config 77 | run_config_path = '%s/run.config' % args.path 78 | # build run config from args 79 | args.lr_schedule_param = None 80 | args.opt_param = { 81 | 'momentum': args.momentum, 82 | 'nesterov': not args.no_nesterov, 83 | } 84 | run_config = TrainRunConfig( 85 | **args.__dict__ 86 | ) 87 | 88 | if args.local_rank == 0: 89 | print('Run config:') 90 | for k, v in run_config.config.items(): 91 | print('\t%s: %s' % (k, v)) 92 | 93 | weight_path = None 94 | net_origin = None 95 | if args.model == "resnet50": 96 | assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset' 97 | net = ResNet_ImageNet( 98 | depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 99 | if args.base_path is not None: 100 | weight_path = args.base_path+'/checkpoint/model_best.pth.tar' 101 | net_origin = nn.DataParallel(ResNet_ImageNet( 102 | depth=50, num_classes=run_config.data_provider.n_classes)) 103 | elif args.model == "mobilenetv2": 104 | assert args.dataset == 'imagenet', 
'mobilenetv2 only supports imagenet dataset' 105 | net = MobileNetV2( 106 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 107 | if args.base_path is not None: 108 | weight_path = args.base_path+'/checkpoint/model_best.pth.tar' 109 | net_origin = nn.DataParallel(MobileNetV2( 110 | num_classes=run_config.data_provider.n_classes)) 111 | elif args.model == "mobilenet": 112 | assert args.dataset == 'imagenet', 'mobilenet only supports imagenet dataset' 113 | net = MobileNet( 114 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)) 115 | if args.base_path is not None: 116 | weight_path = args.base_path+'/checkpoint/model_best.pth.tar' 117 | net_origin = nn.DataParallel( 118 | MobileNet(num_classes=run_config.data_provider.n_classes)) 119 | 120 | # build run manager 121 | run_manager = RunManager(args.path, net, run_config) 122 | if args.local_rank == 0: 123 | run_manager.save_config(print_info=True) 124 | 125 | # load checkpoints 126 | if args.base_path is not None: 127 | weight_path = args.base_path+'/checkpoint/model_best.pth.tar' 128 | if args.resume: 129 | run_manager.load_model() 130 | if args.train and run_manager.best_acc == 0: 131 | loss, acc1, acc5 = run_manager.validate( 132 | is_test=True, return_top5=True) 133 | run_manager.best_acc = acc1 134 | elif weight_path is not None and os.path.isfile(weight_path): 135 | assert net_origin is not None, "original network is None" 136 | net_origin.load_state_dict(torch.load(weight_path)['state_dict']) 137 | net_origin = net_origin.module 138 | if args.model == 'resnet50': 139 | run_manager.reset_model(ResNet_ImageNet( 140 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg), depth=50), net_origin.cpu()) 141 | elif args.model == 'mobilenet': 142 | run_manager.reset_model(MobileNet( 143 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)), net_origin.cpu()) 144 | elif args.model == 'mobilenetv2': 145 | run_manager.reset_model(MobileNetV2( 146 | 
num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg)), net_origin.cpu()) 147 | else: 148 | print('Random initialization') 149 | 150 | # train 151 | if args.train: 152 | print('Start training: %d' % args.local_rank) 153 | if args.local_rank == 0: 154 | print(net) 155 | run_manager.train(print_top5=True) 156 | if args.local_rank == 0: 157 | run_manager.save_model() 158 | 159 | output_dict = {} 160 | 161 | # test 162 | print('Test on test set') 163 | loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True) 164 | log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5) 165 | run_manager.write_log(log, prefix='test') 166 | output_dict = { 167 | **output_dict, 168 | 'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5 169 | } 170 | json.dump(output_dict, open('%s/output' % args.path, 'w'), indent=4) 171 | -------------------------------------------------------------------------------- /Cifar/utils/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 
2 | import numpy as np 3 | import pdb 4 | import torch 5 | import torchvision 6 | import torch.nn as nn 7 | import numpy as np 8 | import random 9 | 10 | 11 | def lookup_table_flops(model, candidate_width, alphas=None, input_res=32, multiply_adds=False): 12 | if alphas is None: 13 | for n, v in model.named_parameters(): 14 | if 'alphas' in n: 15 | alphas = v 16 | num_conv = alphas.shape[0] 17 | device = alphas.device 18 | 19 | # obtain the feature map sizes 20 | list_bn=[] 21 | def bn_hook(self, input, output): 22 | if input[0].ndimension() == 4: 23 | list_bn.append(input[0].shape) 24 | 25 | def foo(net): 26 | childrens = list(net.children()) 27 | if not childrens: 28 | if isinstance(net, torch.nn.BatchNorm2d): 29 | net.register_forward_hook(bn_hook) 30 | return 31 | for c in childrens: 32 | foo(c) 33 | 34 | foo(model) 35 | input_ = torch.rand(1, 3, input_res, input_res).to(alphas.device) 36 | input_.requires_grad = True 37 | # print('alphas:', alphas) 38 | # print('inputs:', input_) 39 | if torch.cuda.device_count() > 1: 40 | model.module.register_buffer('alphas_tmp', alphas.data) 41 | else: 42 | model.register_buffer('alphas_tmp', alphas.data) 43 | out = model(input_) 44 | 45 | # TODO: only appliable for resnet_20s: 2 convs followed by 1 shortcut 46 | list_main_bn = [] 47 | num_width = len(candidate_width) 48 | 49 | for i, b in enumerate(list_bn): 50 | if i//num_width == 0 or \ 51 | ((i-num_width)//(num_width**2) >= 0 and ((i-num_width)//(num_width)**2) % 3 != 2): 52 | list_main_bn.append(b) 53 | assert len(list_main_bn) == (num_width + num_width ** 2 * (num_conv-1)), 'wrong list of feature map length' 54 | 55 | # start compute flops for each branch 56 | # first obtain the kernel shapes, a list of length: num_width + num_width**2 * num_conv 57 | 58 | def kernel_shape_types(candidate_width): 59 | kshape_types = [] 60 | first_kshape_types = [] 61 | for i in candidate_width: 62 | first_kshape_types.append((i, 3, 3, 3)) 63 | 64 | for i in candidate_width: 65 | for j 
in candidate_width: 66 | kshape_types.append((i, j, 3, 3)) # [co, ci, k, k] 67 | return kshape_types, first_kshape_types 68 | 69 | kshape_types, first_kshape_types = kernel_shape_types(candidate_width) 70 | k_shapes = [] 71 | layer_idx = 0 72 | for v in model.parameters(): 73 | if v.ndimension() == 4 and v.shape[2] == 3: 74 | if layer_idx == 0: 75 | k_shapes += first_kshape_types 76 | else: 77 | k_shapes += kshape_types 78 | layer_idx += 1 79 | 80 | # compute flops 81 | flops = [] # a list of length: num_width + num_width**2 * num_conv 82 | for idx, a_shape in enumerate(list_main_bn): 83 | n, ci, h, w = a_shape 84 | k_shape = k_shapes[idx] 85 | co, ci, k, _ = k_shape 86 | flop = co * ci * k * k * h * w 87 | flops.append(flop) 88 | 89 | # reshape flops back to list. len == num_conv 90 | table_flops = [] 91 | table_flops.append(torch.Tensor(flops[:num_width]).to(device)) 92 | for layer_idx in range(num_conv-1): 93 | tmp = flops[num_width + layer_idx*num_width**2:\ 94 | num_width + (layer_idx+1)*num_width**2] 95 | assert len(tmp) == num_width ** 2, 'need have %d elements in %d layer'%(num_width**2, layer_idx+1) 96 | table_flops.append(torch.Tensor(tmp).to(device)) 97 | 98 | return table_flops 99 | 100 | 101 | def print_model_param_nums(model, multiply_adds=False): 102 | total = sum([param.nelement() for param in model.parameters()]) 103 | print(' + Number of original params: %.8fM' % (total / 1e6)) 104 | return total 105 | 106 | 107 | def print_model_param_flops(model, input_res, multiply_adds=False): 108 | 109 | prods = {} 110 | def save_hook(name): 111 | def hook_per(self, input, output): 112 | prods[name] = np.prod(input[0].shape) 113 | return hook_per 114 | 115 | list_conv=[] 116 | def conv_hook(self, input, output): 117 | batch_size, input_channels, input_height, input_width = input[0].size() 118 | output_channels, output_height, output_width = output[0].size() 119 | 120 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 
121 | bias_ops = 1 if self.bias is not None else 0 122 | 123 | params = output_channels * (kernel_ops + bias_ops) 124 | flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 125 | list_conv.append(flops) 126 | 127 | list_linear=[] 128 | def linear_hook(self, input, output): 129 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 130 | 131 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 132 | bias_ops = self.bias.nelement() if self.bias is not None else 0 133 | 134 | flops = batch_size * (weight_ops + bias_ops) 135 | list_linear.append(flops) 136 | 137 | list_bn=[] 138 | def bn_hook(self, input, output): 139 | list_bn.append(input[0].nelement() * 2) 140 | 141 | list_relu=[] 142 | def relu_hook(self, input, output): 143 | list_relu.append(input[0].nelement()) 144 | 145 | list_pooling=[] 146 | def pooling_hook(self, input, output): 147 | batch_size, input_channels, input_height, input_width = input[0].size() 148 | output_channels, output_height, output_width = output[0].size() 149 | 150 | kernel_ops = self.kernel_size * self.kernel_size 151 | bias_ops = 0 152 | params = 0 153 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 154 | list_pooling.append(flops) 155 | 156 | list_upsample=[] 157 | # For bilinear upsample 158 | def upsample_hook(self, input, output): 159 | batch_size, input_channels, input_height, input_width = input[0].size() 160 | output_channels, output_height, output_width = output[0].size() 161 | 162 | flops = output_height * output_width * output_channels * batch_size * 12 163 | list_upsample.append(flops) 164 | 165 | def foo(net): 166 | childrens = list(net.children()) 167 | if not childrens: 168 | if isinstance(net, torch.nn.Conv2d): 169 | net.register_forward_hook(conv_hook) 170 | if isinstance(net, torch.nn.Linear): 171 | net.register_forward_hook(linear_hook) 172 | # if isinstance(net, 
torch.nn.BatchNorm2d): 173 | # net.register_forward_hook(bn_hook) 174 | # if isinstance(net, torch.nn.ReLU): 175 | # net.register_forward_hook(relu_hook) 176 | # if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 177 | # net.register_forward_hook(pooling_hook) 178 | # if isinstance(net, torch.nn.Upsample): 179 | # net.register_forward_hook(upsample_hook) 180 | return 181 | for c in childrens: 182 | foo(c) 183 | 184 | model = model.cuda() 185 | foo(model) 186 | input_ = torch.rand(3, 3, input_res, input_res).cuda() 187 | input_.requires_grad = True 188 | out = model(input_) 189 | 190 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 191 | total_flops /= 3 192 | 193 | print(' + Number of FLOPs of original model: %.8fG' % (total_flops / 1e9)) 194 | # print('list_conv', list_conv) 195 | # print('list_linear', list_linear) 196 | # print('list_bn', list_bn) 197 | # print('list_relu', list_relu) 198 | # print('list_pooling', list_pooling) 199 | 200 | return total_flops 201 | 202 | -------------------------------------------------------------------------------- /Cifar/utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import sys 5 | 6 | def _make_divisible(v, divisor=8, min_value=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | :param v: 13 | :param divisor: 14 | :param min_value: 15 | :return: 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
21 | if new_v < 0.9 * v: 22 | new_v += divisor 23 | return new_v 24 | 25 | def build_activation(act_func, inplace=True): 26 | if act_func == 'relu': 27 | return nn.ReLU(inplace=inplace) 28 | elif act_func == 'relu6': 29 | return nn.ReLU6(inplace=inplace) 30 | elif act_func == 'tanh': 31 | return nn.Tanh() 32 | elif act_func == 'sigmoid': 33 | return nn.Sigmoid() 34 | elif act_func is None: 35 | return None 36 | else: 37 | raise ValueError('do not support: %s' % act_func) 38 | 39 | def raw2cfg(model, raw_ratios, flops, p=False, div=8): 40 | left = 0 41 | right = 50 42 | scale = 0 43 | cfg = None 44 | current_flops = 0 45 | base_channels = model.config['cfg_base'] 46 | cnt = 0 47 | while (True): 48 | cnt += 1 49 | scale = (left + right) / 2 50 | scaled_ratios = raw_ratios * scale 51 | for i in range(len(scaled_ratios)): 52 | scaled_ratios[i] = max(0.1, scaled_ratios[i]) 53 | scaled_ratios[i] = min(1, scaled_ratios[i]) 54 | cfg = (base_channels * scaled_ratios).astype(int).tolist() 55 | for i in range(len(cfg)): 56 | cfg[i] = _make_divisible(cfg[i], div) # 8 divisible channels 57 | current_flops = model.cfg2flops(cfg) 58 | if cnt > 20: 59 | break 60 | if abs(current_flops - flops) / flops < 0.01: 61 | break 62 | if p: 63 | print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg)) 64 | if current_flops < flops: 65 | left = scale 66 | elif current_flops > flops: 67 | right = scale 68 | else: 69 | break 70 | return cfg 71 | 72 | def weight2mask(weight, keep_c): # simple L1 pruning 73 | weight_copy = weight.abs().clone() 74 | L1_norm = torch.sum(weight_copy, dim=(1,2,3)) 75 | arg_max = torch.argsort(L1_norm, descending=True) 76 | arg_max_rev = arg_max[:keep_c].tolist() 77 | mask = np.zeros(weight.shape[0]) 78 | mask[arg_max_rev] = 1 79 | return mask 80 | 81 | def get_unpruned_weights(model, model_origin): 82 | masks = [] 83 | for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()): 84 | if 
isinstance(m0[1], nn.Conv2d): 85 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 86 | flag = False 87 | if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]: 88 | assert len(masks)>0, "masks is empty!" 89 | if m0[0].endswith('downsample.conv'): 90 | if model.config['depth']>=50: 91 | mask = masks[-4] 92 | else: 93 | mask = masks[-3] 94 | else: 95 | mask = masks[-1] 96 | idx = np.squeeze(np.argwhere(mask)) 97 | if idx.size == 1: 98 | idx = np.resize(idx, (1,)) 99 | w = m0[1].weight.data[:, idx.tolist(), :, :].clone() 100 | flag = True 101 | if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]: 102 | masks.append(None) 103 | if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]: 104 | if m0[0].endswith('downsample.conv'): 105 | mask = masks[-1] 106 | else: 107 | if flag: 108 | mask = weight2mask(w.clone(), m1[1].weight.data.shape[0]) 109 | else: 110 | mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0]) 111 | idx = np.squeeze(np.argwhere(mask)) 112 | if idx.size == 1: 113 | idx = np.resize(idx, (1,)) 114 | if flag: 115 | w = w[idx.tolist(), :, :, :].clone() 116 | else: 117 | w = m0[1].weight.data[idx.tolist(), :, :, :].clone() 118 | m1[1].weight.data = w.clone() 119 | masks.append(mask) 120 | continue 121 | else: 122 | m1[1].weight.data = m0[1].weight.data.clone() 123 | masks.append(None) 124 | elif isinstance(m0[1], nn.BatchNorm2d): 125 | assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here." 
126 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 127 | mask = masks[-1] 128 | idx = np.squeeze(np.argwhere(mask)) 129 | if idx.size == 1: 130 | idx = np.resize(idx, (1,)) 131 | m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone() 132 | m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone() 133 | m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone() 134 | m1[1].running_var = m0[1].running_var[idx.tolist()].clone() 135 | continue 136 | m1[1].weight.data = m0[1].weight.data.clone() 137 | m1[1].bias.data = m0[1].bias.data.clone() 138 | m1[1].running_mean = m0[1].running_mean.clone() 139 | m1[1].running_var = m0[1].running_var.clone() 140 | 141 | # noinspection PyUnresolvedReferences 142 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1): 143 | logsoftmax = nn.LogSoftmax(dim=1) 144 | n_classes = pred.size(1) 145 | # convert to one-hot 146 | target = torch.unsqueeze(target, 1) 147 | soft_target = torch.zeros_like(pred) 148 | soft_target.scatter_(1, target, 1) 149 | # label smoothing 150 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 151 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 152 | 153 | 154 | def count_parameters(model): 155 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 156 | return total_params 157 | 158 | 159 | def detach_variable(inputs): 160 | if isinstance(inputs, tuple): 161 | return tuple([detach_variable(x) for x in inputs]) 162 | else: 163 | x = inputs.detach() 164 | x.requires_grad = inputs.requires_grad 165 | return x 166 | 167 | 168 | def accuracy(output, target, topk=(1,)): 169 | """ Computes the precision@k for the specified values of k """ 170 | maxk = max(topk) 171 | batch_size = target.size(0) 172 | 173 | _, pred = output.topk(maxk, 1, True, True) 174 | pred = pred.t() 175 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 176 | 177 | res = [] 178 | for k in topk: 179 | correct_k = 
correct[:k].reshape(-1).float().sum(0, keepdim=True) 180 | res.append(correct_k.mul_(100.0 / batch_size)) 181 | return res 182 | 183 | 184 | class AverageMeter(object): 185 | """ 186 | Computes and stores the average and current value 187 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 188 | """ 189 | 190 | def __init__(self): 191 | self.val = 0 192 | self.avg = 0 193 | self.sum = 0 194 | self.count = 0 195 | 196 | def reset(self): 197 | self.val = 0 198 | self.avg = 0 199 | self.sum = 0 200 | self.count = 0 201 | 202 | def update(self, val, n=1): 203 | self.val = val 204 | self.sum += val * n 205 | self.count += n 206 | self.avg = self.sum / self.count 207 | -------------------------------------------------------------------------------- /Imagenet/utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import sys 5 | 6 | def _make_divisible(v, divisor=8, min_value=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | :param v: 13 | :param divisor: 14 | :param min_value: 15 | :return: 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
21 | if new_v < 0.9 * v: 22 | new_v += divisor 23 | return new_v 24 | 25 | def build_activation(act_func, inplace=True): 26 | if act_func == 'relu': 27 | return nn.ReLU(inplace=inplace) 28 | elif act_func == 'relu6': 29 | return nn.ReLU6(inplace=inplace) 30 | elif act_func == 'tanh': 31 | return nn.Tanh() 32 | elif act_func == 'sigmoid': 33 | return nn.Sigmoid() 34 | elif act_func is None: 35 | return None 36 | else: 37 | raise ValueError('do not support: %s' % act_func) 38 | 39 | def raw2cfg(model, raw_ratios, flops, p=False, div=8): 40 | left = 0 41 | right = 50 42 | scale = 0 43 | cfg = None 44 | current_flops = 0 45 | base_channels = model.config['cfg_base'] 46 | cnt = 0 47 | while (True): 48 | cnt += 1 49 | scale = (left + right) / 2 50 | scaled_ratios = raw_ratios * scale 51 | for i in range(len(scaled_ratios)): 52 | scaled_ratios[i] = max(0.1, scaled_ratios[i]) 53 | scaled_ratios[i] = min(1, scaled_ratios[i]) 54 | cfg = (base_channels * scaled_ratios).astype(int).tolist() 55 | for i in range(len(cfg)): 56 | cfg[i] = _make_divisible(cfg[i], div) # 8 divisible channels 57 | current_flops = model.cfg2flops(cfg) 58 | if cnt > 20: 59 | break 60 | if abs(current_flops - flops) / flops < 0.01: 61 | break 62 | if p: 63 | print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg)) 64 | if current_flops < flops: 65 | left = scale 66 | elif current_flops > flops: 67 | right = scale 68 | else: 69 | break 70 | return cfg 71 | 72 | def weight2mask(weight, keep_c): # simple L1 pruning 73 | weight_copy = weight.abs().clone() 74 | L1_norm = torch.sum(weight_copy, dim=(1,2,3)) 75 | arg_max = torch.argsort(L1_norm, descending=True) 76 | arg_max_rev = arg_max[:keep_c].tolist() 77 | mask = np.zeros(weight.shape[0]) 78 | mask[arg_max_rev] = 1 79 | return mask 80 | 81 | def get_unpruned_weights(model, model_origin): 82 | masks = [] 83 | for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()): 84 | if 
isinstance(m0[1], nn.Conv2d): 85 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 86 | flag = False 87 | if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]: 88 | assert len(masks)>0, "masks is empty!" 89 | if m0[0].endswith('downsample.conv'): 90 | if model.config['depth']>=50: 91 | mask = masks[-4] 92 | else: 93 | mask = masks[-3] 94 | else: 95 | mask = masks[-1] 96 | idx = np.squeeze(np.argwhere(mask)) 97 | if idx.size == 1: 98 | idx = np.resize(idx, (1,)) 99 | w = m0[1].weight.data[:, idx.tolist(), :, :].clone() 100 | flag = True 101 | if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]: 102 | masks.append(None) 103 | if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]: 104 | if m0[0].endswith('downsample.conv'): 105 | mask = masks[-1] 106 | else: 107 | if flag: 108 | mask = weight2mask(w.clone(), m1[1].weight.data.shape[0]) 109 | else: 110 | mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0]) 111 | idx = np.squeeze(np.argwhere(mask)) 112 | if idx.size == 1: 113 | idx = np.resize(idx, (1,)) 114 | if flag: 115 | w = w[idx.tolist(), :, :, :].clone() 116 | else: 117 | w = m0[1].weight.data[idx.tolist(), :, :, :].clone() 118 | m1[1].weight.data = w.clone() 119 | masks.append(mask) 120 | continue 121 | else: 122 | m1[1].weight.data = m0[1].weight.data.clone() 123 | masks.append(None) 124 | elif isinstance(m0[1], nn.BatchNorm2d): 125 | assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here." 
126 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 127 | mask = masks[-1] 128 | idx = np.squeeze(np.argwhere(mask)) 129 | if idx.size == 1: 130 | idx = np.resize(idx, (1,)) 131 | m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone() 132 | m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone() 133 | m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone() 134 | m1[1].running_var = m0[1].running_var[idx.tolist()].clone() 135 | continue 136 | m1[1].weight.data = m0[1].weight.data.clone() 137 | m1[1].bias.data = m0[1].bias.data.clone() 138 | m1[1].running_mean = m0[1].running_mean.clone() 139 | m1[1].running_var = m0[1].running_var.clone() 140 | 141 | # noinspection PyUnresolvedReferences 142 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1): 143 | logsoftmax = nn.LogSoftmax(dim=1) 144 | n_classes = pred.size(1) 145 | # convert to one-hot 146 | target = torch.unsqueeze(target, 1) 147 | soft_target = torch.zeros_like(pred) 148 | soft_target.scatter_(1, target, 1) 149 | # label smoothing 150 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 151 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 152 | 153 | 154 | def count_parameters(model): 155 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 156 | return total_params 157 | 158 | 159 | def detach_variable(inputs): 160 | if isinstance(inputs, tuple): 161 | return tuple([detach_variable(x) for x in inputs]) 162 | else: 163 | x = inputs.detach() 164 | x.requires_grad = inputs.requires_grad 165 | return x 166 | 167 | 168 | def accuracy(output, target, topk=(1,)): 169 | """ Computes the precision@k for the specified values of k """ 170 | maxk = max(topk) 171 | batch_size = target.size(0) 172 | 173 | _, pred = output.topk(maxk, 1, True, True) 174 | pred = pred.t() 175 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 176 | 177 | res = [] 178 | for k in topk: 179 | correct_k = 
correct[:k].reshape(-1).float().sum(0, keepdim=True) 180 | res.append(correct_k.mul_(100.0 / batch_size)) 181 | return res 182 | 183 | 184 | class AverageMeter(object): 185 | """ 186 | Computes and stores the average and current value 187 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 188 | """ 189 | 190 | def __init__(self): 191 | self.val = 0 192 | self.avg = 0 193 | self.sum = 0 194 | self.count = 0 195 | 196 | def reset(self): 197 | self.val = 0 198 | self.avg = 0 199 | self.sum = 0 200 | self.count = 0 201 | 202 | def update(self, val, n=1): 203 | self.val = val 204 | self.sum += val * n 205 | self.count += n 206 | self.avg = self.sum / self.count 207 | -------------------------------------------------------------------------------- /ToyExample/utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import sys 5 | 6 | def _make_divisible(v, divisor=8, min_value=None): 7 | """ 8 | This function is taken from the original tf repo. 9 | It ensures that all layers have a channel number that is divisible by 8 10 | It can be seen here: 11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 12 | :param v: 13 | :param divisor: 14 | :param min_value: 15 | :return: 16 | """ 17 | if min_value is None: 18 | min_value = divisor 19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 20 | # Make sure that round down does not go down by more than 10%. 
21 | if new_v < 0.9 * v: 22 | new_v += divisor 23 | return new_v 24 | 25 | def build_activation(act_func, inplace=True): 26 | if act_func == 'relu': 27 | return nn.ReLU(inplace=inplace) 28 | elif act_func == 'relu6': 29 | return nn.ReLU6(inplace=inplace) 30 | elif act_func == 'tanh': 31 | return nn.Tanh() 32 | elif act_func == 'sigmoid': 33 | return nn.Sigmoid() 34 | elif act_func is None: 35 | return None 36 | else: 37 | raise ValueError('do not support: %s' % act_func) 38 | 39 | def raw2cfg(model, raw_ratios, flops, p=False, div=8): 40 | left = 0 41 | right = 50 42 | scale = 0 43 | cfg = None 44 | current_flops = 0 45 | base_channels = model.config['cfg_base'] 46 | cnt = 0 47 | while (True): 48 | cnt += 1 49 | scale = (left + right) / 2 50 | scaled_ratios = raw_ratios * scale 51 | for i in range(len(scaled_ratios)): 52 | scaled_ratios[i] = max(0.1, scaled_ratios[i]) 53 | scaled_ratios[i] = min(1, scaled_ratios[i]) 54 | cfg = (base_channels * scaled_ratios).astype(int).tolist() 55 | for i in range(len(cfg)): 56 | cfg[i] = _make_divisible(cfg[i], div) # 8 divisible channels 57 | current_flops = model.cfg2flops(cfg) 58 | if cnt > 20: 59 | break 60 | if abs(current_flops - flops) / flops < 0.01: 61 | break 62 | if p: 63 | print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg)) 64 | if current_flops < flops: 65 | left = scale 66 | elif current_flops > flops: 67 | right = scale 68 | else: 69 | break 70 | return cfg 71 | 72 | def weight2mask(weight, keep_c): # simple L1 pruning 73 | weight_copy = weight.abs().clone() 74 | L1_norm = torch.sum(weight_copy, dim=(1,2,3)) 75 | arg_max = torch.argsort(L1_norm, descending=True) 76 | arg_max_rev = arg_max[:keep_c].tolist() 77 | mask = np.zeros(weight.shape[0]) 78 | mask[arg_max_rev] = 1 79 | return mask 80 | 81 | def get_unpruned_weights(model, model_origin): 82 | masks = [] 83 | for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()): 84 | if 
isinstance(m0[1], nn.Conv2d): 85 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 86 | flag = False 87 | if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]: 88 | assert len(masks)>0, "masks is empty!" 89 | if m0[0].endswith('downsample.conv'): 90 | if model.config['depth']>=50: 91 | mask = masks[-4] 92 | else: 93 | mask = masks[-3] 94 | else: 95 | mask = masks[-1] 96 | idx = np.squeeze(np.argwhere(mask)) 97 | if idx.size == 1: 98 | idx = np.resize(idx, (1,)) 99 | w = m0[1].weight.data[:, idx.tolist(), :, :].clone() 100 | flag = True 101 | if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]: 102 | masks.append(None) 103 | if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]: 104 | if m0[0].endswith('downsample.conv'): 105 | mask = masks[-1] 106 | else: 107 | if flag: 108 | mask = weight2mask(w.clone(), m1[1].weight.data.shape[0]) 109 | else: 110 | mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0]) 111 | idx = np.squeeze(np.argwhere(mask)) 112 | if idx.size == 1: 113 | idx = np.resize(idx, (1,)) 114 | if flag: 115 | w = w[idx.tolist(), :, :, :].clone() 116 | else: 117 | w = m0[1].weight.data[idx.tolist(), :, :, :].clone() 118 | m1[1].weight.data = w.clone() 119 | masks.append(mask) 120 | continue 121 | else: 122 | m1[1].weight.data = m0[1].weight.data.clone() 123 | masks.append(None) 124 | elif isinstance(m0[1], nn.BatchNorm2d): 125 | assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here." 
126 | if m0[1].weight.data.shape!=m1[1].weight.data.shape: 127 | mask = masks[-1] 128 | idx = np.squeeze(np.argwhere(mask)) 129 | if idx.size == 1: 130 | idx = np.resize(idx, (1,)) 131 | m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone() 132 | m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone() 133 | m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone() 134 | m1[1].running_var = m0[1].running_var[idx.tolist()].clone() 135 | continue 136 | m1[1].weight.data = m0[1].weight.data.clone() 137 | m1[1].bias.data = m0[1].bias.data.clone() 138 | m1[1].running_mean = m0[1].running_mean.clone() 139 | m1[1].running_var = m0[1].running_var.clone() 140 | 141 | # noinspection PyUnresolvedReferences 142 | def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1): 143 | logsoftmax = nn.LogSoftmax(dim=1) 144 | n_classes = pred.size(1) 145 | # convert to one-hot 146 | target = torch.unsqueeze(target, 1) 147 | soft_target = torch.zeros_like(pred) 148 | soft_target.scatter_(1, target, 1) 149 | # label smoothing 150 | soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes 151 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) 152 | 153 | 154 | def count_parameters(model): 155 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 156 | return total_params 157 | 158 | 159 | def detach_variable(inputs): 160 | if isinstance(inputs, tuple): 161 | return tuple([detach_variable(x) for x in inputs]) 162 | else: 163 | x = inputs.detach() 164 | x.requires_grad = inputs.requires_grad 165 | return x 166 | 167 | 168 | def accuracy(output, target, topk=(1,)): 169 | """ Computes the precision@k for the specified values of k """ 170 | maxk = max(topk) 171 | batch_size = target.size(0) 172 | 173 | _, pred = output.topk(maxk, 1, True, True) 174 | pred = pred.t() 175 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 176 | 177 | res = [] 178 | for k in topk: 179 | correct_k = 
correct[:k].reshape(-1).float().sum(0, keepdim=True) 180 | res.append(correct_k.mul_(100.0 / batch_size)) 181 | return res 182 | 183 | 184 | class AverageMeter(object): 185 | """ 186 | Computes and stores the average and current value 187 | Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py 188 | """ 189 | 190 | def __init__(self): 191 | self.val = 0 192 | self.avg = 0 193 | self.sum = 0 194 | self.count = 0 195 | 196 | def reset(self): 197 | self.val = 0 198 | self.avg = 0 199 | self.sum = 0 200 | self.count = 0 201 | 202 | def update(self, val, n=1): 203 | self.val = val 204 | self.sum += val * n 205 | self.count += n 206 | self.avg = self.sum / self.count 207 | -------------------------------------------------------------------------------- /CKA/cka.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import sys 4 | 5 | def gram_linear(x): 6 | """Compute Gram (kernel) matrix for a linear kernel. 7 | 8 | Args: 9 | x: A num_examples x num_features matrix of features. 10 | 11 | Returns: 12 | A num_examples x num_examples Gram matrix of examples. 13 | """ 14 | return x.dot(x.T) 15 | 16 | 17 | def gram_rbf(x, threshold=1.0): 18 | """Compute Gram (kernel) matrix for an RBF kernel. 19 | 20 | Args: 21 | x: A num_examples x num_features matrix of features. 22 | threshold: Fraction of median Euclidean distance to use as RBF kernel 23 | bandwidth. (This is the heuristic we use in the paper. There are other 24 | possible ways to set the bandwidth; we didn't try them.) 25 | 26 | Returns: 27 | A num_examples x num_examples Gram matrix of examples. 
28 | """ 29 | dot_products = x.dot(x.T) 30 | sq_norms = np.diag(dot_products) 31 | sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :] 32 | sq_median_distance = np.median(sq_distances) 33 | return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance)) 34 | 35 | 36 | def center_gram(gram, unbiased=False): 37 | """Center a symmetric Gram matrix. 38 | 39 | This is equvialent to centering the (possibly infinite-dimensional) features 40 | induced by the kernel before computing the Gram matrix. 41 | 42 | Args: 43 | gram: A num_examples x num_examples symmetric matrix. 44 | unbiased: Whether to adjust the Gram matrix in order to compute an unbiased 45 | estimate of HSIC. Note that this estimator may be negative. 46 | 47 | Returns: 48 | A symmetric matrix with centered columns and rows. 49 | """ 50 | if not np.allclose(gram, gram.T): 51 | raise ValueError('Input must be a symmetric matrix.') 52 | gram = gram.copy() 53 | 54 | if unbiased: 55 | # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M. 56 | # L. (2014). Partial distance correlation with methods for dissimilarities. 57 | # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically 58 | # stable than the alternative from Song et al. (2007). 59 | n = gram.shape[0] 60 | np.fill_diagonal(gram, 0) 61 | means = np.sum(gram, 0, dtype=np.float64) / (n - 2) 62 | means -= np.sum(means) / (2 * (n - 1)) 63 | gram -= means[:, None] 64 | gram -= means[None, :] 65 | np.fill_diagonal(gram, 0) 66 | else: 67 | means = np.mean(gram, 0, dtype=np.float64) 68 | means -= np.mean(means) / 2 69 | gram -= means[:, None] 70 | gram -= means[None, :] 71 | 72 | return gram 73 | 74 | 75 | def cka(gram_x, gram_y, debiased=False): 76 | """Compute CKA. 77 | 78 | Args: 79 | gram_x: A num_examples x num_examples Gram matrix. 80 | gram_y: A num_examples x num_examples Gram matrix. 81 | debiased: Use unbiased estimator of HSIC. CKA may still be biased. 
82 | 83 | Returns: 84 | The value of CKA between X and Y. 85 | """ 86 | gram_x = center_gram(gram_x, unbiased=debiased) 87 | gram_y = center_gram(gram_y, unbiased=debiased) 88 | 89 | # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or 90 | # n*(n-3) (unbiased variant), but this cancels for CKA. 91 | scaled_hsic = gram_x.ravel().dot(gram_y.ravel()) 92 | 93 | normalization_x = np.linalg.norm(gram_x) 94 | normalization_y = np.linalg.norm(gram_y) 95 | return scaled_hsic / (normalization_x * normalization_y) 96 | 97 | 98 | def _debiased_dot_product_similarity_helper( 99 | xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, 100 | n): 101 | """Helper for computing debiased dot product similarity (i.e. linear HSIC).""" 102 | # This formula can be derived by manipulating the unbiased estimator from 103 | # Song et al. (2007). 104 | return ( 105 | xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y) 106 | + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2))) 107 | 108 | 109 | # def feature_space_linear_cka(features_x, features_y, debiased=False): 110 | # """Compute CKA with a linear kernel, in feature space. 111 | # 112 | # This is typically faster than computing the Gram matrix when there are fewer 113 | # features than examples. 114 | # 115 | # Args: 116 | # features_x: A num_examples x num_features matrix of features. 117 | # features_y: A num_examples x num_features matrix of features. 118 | # debiased: Use unbiased estimator of dot product similarity. CKA may still be 119 | # biased. Note that this estimator may be negative. 120 | # 121 | # Returns: 122 | # The value of CKA between X and Y. 
123 | # """ 124 | # features_x = features_x - torch.mean(features_x, 0, keepdim=True) 125 | # features_y = features_y - torch.mean(features_y, 0, keepdim=True) 126 | # 127 | # a = torch.mm(features_x.t(), features_y) 128 | # b = torch.mm(features_x.t(), features_x) 129 | # c = torch.mm(features_y.t(), features_y) 130 | # dot_product_similarity = torch.linalg.norm(a) ** 2 131 | # normalization_x = torch.linalg.norm(b) 132 | # normalization_y = torch.linalg.norm(c) 133 | # 134 | # if debiased: 135 | # n = features_x.shape[0] 136 | # # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 137 | # sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x) 138 | # sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y) 139 | # squared_norm_x = np.sum(sum_squared_rows_x) 140 | # squared_norm_y = np.sum(sum_squared_rows_y) 141 | # 142 | # dot_product_similarity = _debiased_dot_product_similarity_helper( 143 | # dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y, 144 | # squared_norm_x, squared_norm_y, n) 145 | # normalization_x = np.sqrt(_debiased_dot_product_similarity_helper( 146 | # normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x, 147 | # squared_norm_x, squared_norm_x, n)) 148 | # normalization_y = np.sqrt(_debiased_dot_product_similarity_helper( 149 | # normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y, 150 | # squared_norm_y, squared_norm_y, n)) 151 | # 152 | # return dot_product_similarity / (normalization_x * normalization_y) 153 | def feature_space_linear_cka(features_x, features_y, debiased=False): 154 | """Compute CKA with a linear kernel, in feature space. 155 | 156 | This is typically faster than computing the Gram matrix when there are fewer 157 | features than examples. 158 | 159 | Args: 160 | features_x: A num_examples x num_features matrix of features. 161 | features_y: A num_examples x num_features matrix of features. 
162 | debiased: Use unbiased estimator of dot product similarity. CKA may still be 163 | biased. Note that this estimator may be negative. 164 | 165 | Returns: 166 | The value of CKA between X and Y. 167 | """ 168 | features_x = features_x - np.mean(features_x, 0, keepdims=True) 169 | features_y = features_y - np.mean(features_y, 0, keepdims=True) 170 | 171 | dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2 172 | normalization_x = np.linalg.norm(features_x.T.dot(features_x)) 173 | normalization_y = np.linalg.norm(features_y.T.dot(features_y)) 174 | 175 | if debiased: 176 | n = features_x.shape[0] 177 | # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array. 178 | sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x) 179 | sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y) 180 | squared_norm_x = np.sum(sum_squared_rows_x) 181 | squared_norm_y = np.sum(sum_squared_rows_y) 182 | 183 | dot_product_similarity = _debiased_dot_product_similarity_helper( 184 | dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y, 185 | squared_norm_x, squared_norm_y, n) 186 | normalization_x = np.sqrt(_debiased_dot_product_similarity_helper( 187 | normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x, 188 | squared_norm_x, squared_norm_x, n)) 189 | normalization_y = np.sqrt(_debiased_dot_product_similarity_helper( 190 | normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y, 191 | squared_norm_y, squared_norm_y, n)) 192 | 193 | return dot_product_similarity / (normalization_x * normalization_y) -------------------------------------------------------------------------------- /Cifar/nets/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | import torch.nn.functional as F 5 | from nets.base_models import * 6 | import torch.nn.init as init 7 | import 
torch.utils.model_zoo as model_zoo 8 | 9 | def _weights_init(m): 10 | classname = m.__class__.__name__ 11 | # print(classname) 12 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 13 | init.kaiming_normal(m.weight) 14 | 15 | 16 | class LambdaLayer(nn.Module): 17 | def __init__(self, lambd): 18 | super(LambdaLayer, self).__init__() 19 | self.lambd = lambd 20 | 21 | def forward(self, x): 22 | return self.lambd(x) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | def __init__(self, in_planes, planes, stride=1, affine=True): 28 | super(BasicBlock, self).__init__() 29 | conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 30 | bn1 = nn.BatchNorm2d(planes, affine=affine) 31 | conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 32 | bn2 = nn.BatchNorm2d(planes, affine=affine) 33 | 34 | self.conv_bn1 = nn.Sequential(OrderedDict([('conv', conv1), ('bn', bn1)])) 35 | self.conv_bn2 = nn.Sequential(OrderedDict([('conv', conv2), ('bn', bn2)])) 36 | self.shortcut = nn.Sequential() 37 | if stride != 1 or in_planes != planes: 38 | if stride != 1: 39 | self.shortcut = LambdaLayer( 40 | lambda x: F.pad(x[:, :, ::2, ::2], 41 | (0, 0, 0, 0, (planes - in_planes) // 2, 42 | planes - in_planes - (planes - in_planes) // 2), "constant", 0)) 43 | else: 44 | self.shortcut = LambdaLayer( 45 | lambda x: F.pad(x[:, :, :, :], 46 | (0, 0, 0, 0, (planes - in_planes) // 2, 47 | planes - in_planes - (planes - in_planes) // 2), "constant", 0)) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.conv_bn1(x)) 51 | out = self.conv_bn2(out) 52 | out += self.shortcut(x) 53 | out = F.relu(out) 54 | return out 55 | 56 | 57 | class ResNet_CIFAR(MyNetwork): 58 | def __init__(self, depth=20, num_classes=10, cfg=None, cutout=False): 59 | super(ResNet_CIFAR, self).__init__() 60 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 61 | cfg_base = [] 62 | n = (depth-2)//6 63 | for i in [16, 32, 64]: 64 | for j in 
range(n): 65 | cfg_base.append(i) 66 | 67 | if cfg is None: 68 | cfg = cfg_base 69 | num_blocks = [] 70 | if depth==20: 71 | num_blocks = [3, 3, 3] 72 | elif depth==32: 73 | num_blocks = [5, 5, 5] 74 | elif depth==44: 75 | num_blocks = [7, 7, 7] 76 | elif depth==56: 77 | num_blocks = [9, 9, 9] 78 | elif depth==110: 79 | num_blocks = [18, 18, 18] 80 | block = BasicBlock 81 | self.cfg_base = cfg_base 82 | self.num_classes = num_classes 83 | self.num_blocks = num_blocks 84 | self.cutout = cutout 85 | self.cfg = cfg 86 | self.in_planes = 16 87 | conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 88 | bn1 = nn.BatchNorm2d(16) 89 | self.conv_bn = nn.Sequential(OrderedDict([('conv', conv1), ('bn', bn1)])) 90 | self.layer1 = self._make_layer(block, cfg[0:n], num_blocks[0], stride=1) 91 | self.layer2 = self._make_layer(block, cfg[n:2*n], num_blocks[1], stride=2) 92 | self.layer3 = self._make_layer(block, cfg[2*n:], num_blocks[2], stride=2) 93 | self.pool = nn.AdaptiveAvgPool2d(1) 94 | self.linear = nn.Linear(cfg[-1], num_classes) 95 | self.apply(_weights_init) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1]*(num_blocks-1) 99 | layers = [] 100 | for i in range(len(strides)): 101 | layers.append(('block_%d'%i, block(self.in_planes, planes[i], strides[i]))) 102 | self.in_planes = planes[i] 103 | return nn.Sequential(OrderedDict(layers)) 104 | 105 | def forward(self, x): 106 | if self.training and self.cutout: 107 | with torch.no_grad(): 108 | x = cutout_batch(x, 16) 109 | out = F.relu(self.conv_bn(x)) 110 | out = self.layer1(out) 111 | out = self.layer2(out) 112 | out = self.layer3(out) 113 | out = self.pool(out) 114 | out = out.view(out.size(0), -1) 115 | out = self.linear(out) 116 | return out 117 | 118 | def feature_extract(self, x): 119 | if self.training and self.cutout: 120 | with torch.no_grad(): 121 | x = cutout_batch(x, 16) 122 | tensor = [] 123 | out = F.relu(self.conv_bn(x)) 124 | for i in 
[self.layer1, self.layer2, self.layer3]: 125 | for _layer in i: 126 | out = _layer(out) 127 | if type(_layer) is BasicBlock: 128 | tensor.append(out) 129 | 130 | return tensor 131 | 132 | def cfg2params(self, cfg): 133 | params = 0 134 | params += (3 * 3 * 3 * 16 + 16 * 2) # conv1+bn1 135 | in_c = 16 136 | cfg_idx = 0 137 | for i in range(3): 138 | num_blocks = self.num_blocks[i] 139 | for j in range(num_blocks): 140 | c = cfg[cfg_idx] 141 | params += (in_c * 3 * 3 * c + 2 * c + c * 3 * 3 * c + 2 * c) # per block params 142 | if in_c != c: 143 | params += in_c * c # shortcut 144 | in_c = c 145 | cfg_idx += 1 146 | params += (self.cfg[-1] + 1) * self.num_classes # fc layer 147 | return params 148 | 149 | def cfg2flops(self, cfg): # to simplify, only count convolution flops 150 | size = 32 151 | flops = 0 152 | flops += (3 * 3 * 3 * 16 * 32 * 32 + 16 * 32 * 32 * 4) # conv1+bn1 153 | in_c = 16 154 | cfg_idx = 0 155 | for i in range(3): 156 | num_blocks = self.num_blocks[i] 157 | if i==1 or i==2: 158 | size = size // 2 159 | for j in range(num_blocks): 160 | c = cfg[cfg_idx] 161 | flops += (in_c * 3 * 3 * c * size * size + c * size * size * 4 + c * 3 * 3 * c * size * size + c * size * size * 4) # per block flops 162 | if in_c != c: 163 | flops += in_c * c * size * size # shortcut 164 | in_c = c 165 | cfg_idx += 1 166 | flops += (2 * self.cfg[-1] + 1) * self.num_classes # fc layer 167 | return flops 168 | 169 | def cfg2flops_perlayer(self, cfg, length): # to simplify, only count convolution flops 170 | size = 32 171 | flops_singlecfg = [0 for j in range(length)] 172 | flops_doublecfg = np.zeros((length, length)) 173 | flops_squarecfg = [0 for j in range(length)] 174 | 175 | in_c = 16 176 | cfg_idx = 0 177 | for i in range(3): 178 | num_blocks = self.num_blocks[i] 179 | if i==1 or i==2: 180 | size = size // 2 181 | for j in range(num_blocks): 182 | c = cfg[cfg_idx] 183 | if i==0 and j==0: 184 | flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4 + in_c 
* 3 * 3 * c * size * size) 185 | flops_squarecfg[cfg_idx] += c * 3 * 3 * c * size * size 186 | else: 187 | flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4) 188 | flops_doublecfg[cfg_idx-1][cfg_idx] += in_c * 3 * 3 * c * size * size 189 | flops_doublecfg[cfg_idx][cfg_idx-1] += in_c * 3 * 3 * c * size * size 190 | flops_squarecfg[cfg_idx] += (c * 3 * 3 * c * size * size ) 191 | if in_c != c: 192 | flops_doublecfg[cfg_idx][cfg_idx-1] += in_c * c * size * size # shortcut 193 | flops_doublecfg[cfg_idx-1][cfg_idx] += in_c * c * size * size 194 | in_c = c 195 | cfg_idx += 1 196 | 197 | flops_singlecfg[-1] += 2 * self.cfg[-1] * self.num_classes # fc layer 198 | return flops_singlecfg, flops_doublecfg, flops_squarecfg 199 | 200 | @property 201 | def config(self): 202 | return { 203 | 'name': self.__class__.__name__, 204 | 'cfg': self.cfg, 205 | 'cfg_base': self.cfg_base, 206 | 'dataset': 'cifar10', 207 | } -------------------------------------------------------------------------------- /Cifar/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ This file stores some common functions for learners """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import pdb 6 | import math 7 | import numpy as np 8 | import logging, sys 9 | import os 10 | import torch.distributed as dist 11 | import torch.nn.functional as F 12 | from collections import defaultdict 13 | 14 | 15 | class CrossEntropyLabelSmooth(nn.Module): 16 | def __init__(self, num_classes, epsilon): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.logsoftmax = nn.LogSoftmax(dim=1) 21 | 22 | def forward(self, inputs, targets): 23 | log_probs = self.logsoftmax(inputs) 24 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 25 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 26 | loss = (-targets * log_probs).mean(0).sum() 27 
| return loss 28 | 29 | 30 | class AverageMeter(object): 31 | """Computes and stores the average and current value""" 32 | def __init__(self, name, fmt=':f'): 33 | self.name = name 34 | self.fmt = fmt 35 | self.reset() 36 | 37 | def reset(self): 38 | self.val = 0 39 | self.avg = 0 40 | self.sum = 0 41 | self.count = 0 42 | 43 | def update(self, val, n=1): 44 | self.val = val 45 | self.sum += val * n 46 | self.count += n 47 | self.avg = self.sum / self.count 48 | 49 | def __str__(self): 50 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 51 | return fmtstr.format(**self.__dict__) 52 | 53 | 54 | class ProgressMeter(object): 55 | """ Code taken from torchvision """ 56 | def __init__(self, num_batches, *meters, prefix=""): 57 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 58 | self.meters = meters 59 | self.prefix = prefix 60 | 61 | def show(self, batch): 62 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 63 | entries += [str(meter) for meter in self.meters] 64 | logging.info(' '.join(entries)) 65 | # print(' '.join(entries)) 66 | 67 | def _get_batch_fmtstr(self, num_batches): 68 | num_digits = len(str(num_batches // 1)) 69 | fmt = '{:' + str(num_digits) + 'd}' 70 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 71 | 72 | 73 | def accuracy(output, target, topk=(1,)): 74 | """Computes the accuracy over the k top predictions for the specified values of k 75 | Code taken from torchvision. 
76 | """ 77 | 78 | with torch.no_grad(): 79 | maxk = max(topk) 80 | batch_size = target.size(0) 81 | 82 | _, pred = output.topk(maxk, 1, True, True) 83 | pred = pred.t() 84 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 85 | 86 | res = [] 87 | for k in topk: 88 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 89 | res.append(correct_k.mul_(100.0 / batch_size)) 90 | return res 91 | 92 | def accuracy_seen_unseen(output, target, seen_classes, topk=(1.)): 93 | """ Compute the acc of seen classes and unseen classes seperately.""" 94 | seen_classes = list(range(seen_classes)) 95 | with torch.no_grad(): 96 | maxk = max(topk) 97 | batch_size = target.size(0) 98 | 99 | _, pred = output.topk(maxk, 1, True, True) 100 | pred = pred.t() 101 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 102 | 103 | seen_ind = torch.zeros(target.shape[0]).byte().cuda() 104 | unseen_ind = torch.ones(target.shape[0]).byte().cuda() 105 | for v in seen_classes: 106 | seen_ind += (target == v) 107 | unseen_ind -= seen_ind 108 | 109 | seen_correct = pred[:, seen_ind].eq(target[seen_ind].view(1, -1).expand_as(pred[:, seen_ind])) 110 | unseen_correct = pred[:, unseen_ind].eq(target[unseen_ind].view(1, -1).expand_as(pred[:, unseen_ind])) 111 | 112 | seen_num = seen_correct.shape[1] 113 | res = [] 114 | seen_accs = [] 115 | unseen_accs = [] 116 | for k in topk: 117 | seen_correct_k = seen_correct[:k].view(-1).float().sum(0, keepdim=True) 118 | seen_accs.append(seen_correct_k.mul_(100.0 / (seen_num+1e-10))) 119 | unseen_correct_k = unseen_correct[:k].view(-1).float().sum(0, keepdim=True) 120 | unseen_accs.append(unseen_correct_k.mul_(100.0 / (batch_size-seen_num+1e-10))) 121 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 122 | res.append(correct_k.mul_(100.0 / batch_size)) 123 | 124 | return res, seen_accs, unseen_accs, seen_num 125 | 126 | def is_last_layer(layer): 127 | W = layer.weight 128 | if W.ndimension() == 2 and (W.shape[0] == 10 or W.shape[0] 
== 100 or W.shape[0] == 1000): 129 | return True 130 | else: 131 | return False 132 | 133 | 134 | def is_first_layer(layer): 135 | if isinstance(layer, nn.Conv2d): 136 | W = layer.weight 137 | if W.ndimension() == 4 and (W.shape[1] == 3 or W.shape[1] == 1): 138 | return True 139 | else: 140 | return False 141 | else: 142 | return False 143 | 144 | 145 | def reinitialize_conv_weights(model, init_first=False): 146 | """ Only re-initialize the kernels for conv layers """ 147 | print("re-intializing conv weights done. evaluating...") 148 | for W in model.parameters(): 149 | if W.ndimension() == 4: 150 | if W.shape[1] == 3 and not init_first: 151 | continue # do not init first conv layer 152 | nn.init.kaiming_uniform_(W, a = math.sqrt(5)) 153 | 154 | 155 | def weights_init(m): 156 | """ default param initializer in pytorch. """ 157 | if isinstance(m, nn.Conv2d): 158 | n = m.in_channels 159 | for k in m.kernel_size: 160 | n *= k 161 | stdv = 1. / math.sqrt(n) 162 | m.weight.data.uniform_(-stdv, stdv) 163 | if m.bias is not None: 164 | m.bias.data.uniform_(-stdv, stdv) 165 | elif isinstance(m, nn.Linear): 166 | stdv = 1. 
/ math.sqrt(m.weight.size(1)) 167 | m.weight.data.uniform_(-stdv, stdv) 168 | if m.bias is not None: 169 | m.bias.data.uniform_(-stdv, stdv) 170 | 171 | 172 | class keydefaultdict(defaultdict): 173 | def __missing__(self, key): 174 | if self.default_factory is None: 175 | raise KeyError(key) 176 | else: 177 | ret = self[key] = self.default_factory(key) 178 | return ret 179 | 180 | 181 | class WarmUpCosineLRScheduler: 182 | """ 183 | update lr every step 184 | """ 185 | def __init__(self, optimizer, T_max, eta_min, base_lr, warmup_lr, warmup_steps, last_iter=-1): 186 | # Attach optimizer 187 | if not isinstance(optimizer, torch.optim.Optimizer): 188 | raise TypeError('{} is not an Optimizer'.format( 189 | type(optimizer).__name__)) 190 | self.optimizer = optimizer 191 | 192 | # Initialize step and base learning rates 193 | if last_iter == -1: 194 | for group in optimizer.param_groups: 195 | group.setdefault('initial_lr', group['lr']) 196 | else: 197 | for i, group in enumerate(optimizer.param_groups): 198 | if 'initial_lr' not in group: 199 | raise KeyError("param 'initial_lr' is not specified " 200 | "in param_groups[{}] when resuming an optimizer".format(i)) 201 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 202 | self.last_iter = last_iter 203 | 204 | assert warmup_steps < T_max 205 | self.T_max = T_max 206 | self.eta_min = eta_min 207 | 208 | # warmup settings 209 | self.base_lr = base_lr 210 | self.warmup_lr = warmup_lr 211 | self.warmup_step = warmup_steps 212 | self.warmup_k = None 213 | 214 | def state_dict(self): 215 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 216 | 217 | def load_state_dict(self, state_dict): 218 | self.__dict__.update(state_dict) 219 | 220 | def get_lr(self): 221 | return list(map(lambda group: group['lr'], self.optimizer.param_groups)) 222 | 223 | def step(self, this_iter=None): 224 | if this_iter is None: 225 | this_iter = self.last_iter + 1 226 | 
self.last_iter = this_iter 227 | 228 | # get lr during warmup stage 229 | if self.warmup_step > 0 and this_iter < self.warmup_step: 230 | if self.warmup_k is None: 231 | self.warmup_k = (self.warmup_lr - self.base_lr) / self.warmup_step 232 | scale = (self.warmup_k * this_iter + self.base_lr) / self.base_lr 233 | # get lr during cosine annealing 234 | else: 235 | step_ratio = (this_iter - self.warmup_step) / (self.T_max - self.warmup_step) 236 | scale = self.eta_min + (self.warmup_lr - self.eta_min) * (1 + math.cos(math.pi * step_ratio)) / 2 237 | scale /= self.base_lr 238 | 239 | values = [scale * lr for lr in self.base_lrs] 240 | for param_group, lr in zip(self.optimizer.param_groups, values): 241 | param_group['lr'] = lr 242 | -------------------------------------------------------------------------------- /Cifar/learners/basic_learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import logging 5 | import os 6 | import pdb 7 | import torch.distributed as dist 8 | from tensorboardX import SummaryWriter 9 | from utils.utils import * 10 | from utils.dist_utils import * 11 | from timeit import default_timer as timer 12 | import time 13 | import copy 14 | 15 | 16 | class BasicLearner(object): 17 | """ Performs vanilla training """ 18 | def __init__(self, model, loaders, args, device): 19 | self.args = args 20 | self.device = device 21 | self.model = model 22 | self.__build_path() 23 | self.train_loader, self.test_loader = loaders 24 | self.setup_optim() 25 | self.criterion = nn.CrossEntropyLoss().cuda() 26 | if self.check_is_primary(): 27 | self.writer = SummaryWriter(os.path.dirname(self.save_path)) 28 | # self.add_graph() 29 | 30 | def train(self, train_sampler=None): 31 | self.warm_up_lr() 32 | for epoch in range(self.args.epochs): 33 | if self.args.distributed: 34 | assert train_sampler != None 35 | train_sampler.set_epoch(epoch) 36 | 37 | 
self.model.train() 38 | if self.check_is_primary(): 39 | logging.info("Training at Epoch: %d" % epoch) 40 | train_acc, train_loss = self.epoch(True) 41 | 42 | if self.check_is_primary(): 43 | self.writer.add_scalar('train_acc', train_acc, epoch) 44 | self.writer.add_scalar('train_loss', train_loss, epoch) 45 | 46 | if self.lr_scheduler: 47 | self.lr_scheduler.step() 48 | 49 | if (epoch+1) % self.args.eval_epoch == 0: 50 | # evaluate every GPU, but we only show the results on a single one.! 51 | if self.check_is_primary(): 52 | logging.info("Evaluation at Epoch: %d" % epoch) 53 | self.evaluate(True, epoch) 54 | 55 | if self.check_is_primary(): 56 | self.save_model() 57 | 58 | def evaluate(self, is_train=False, epoch=None): 59 | self.model.eval() 60 | # NOTE: syncronizing the BN statistics 61 | if self.args.distributed: 62 | sync_bn_stat(self.model, self.args.world_size) 63 | 64 | if not is_train: 65 | self.load_model() 66 | 67 | with torch.no_grad(): 68 | test_acc, test_loss = self.epoch(False) 69 | 70 | if is_train and epoch and self.check_is_primary(): 71 | self.writer.add_scalar('test_acc', test_acc, epoch) 72 | self.writer.add_scalar('test_loss', test_loss, epoch) 73 | return test_acc, test_loss 74 | 75 | def finetune(self, train_sampler): 76 | self.load_model() 77 | self.evaluate() 78 | 79 | for epoch in range(self.args.epochs): 80 | if self.args.distributed: 81 | assert train_sampler != None 82 | train_sampler.set_epoch(epoch) 83 | 84 | self.model.train() 85 | 86 | # NOTE: use the preset learning rate for all epochs. 
87 | ft_acc, ft_loss = self.epoch(True) 88 | 89 | if self.check_is_primary(): 90 | self.writer.add_scalar('ft_acc', ft_acc, epoch) 91 | self.writer.add_scalar('ft_loss', ft_loss, epoch) 92 | 93 | # evaluate every k step 94 | if (epoch+1) % self.args.eval_epoch == 0: 95 | if self.check_is_primary(): 96 | logging.info("Evaluation at Epoch: %d" % epoch) 97 | self.evaluate(True, epoch) 98 | 99 | # save the model 100 | if self.check_is_primary(): 101 | self.save_model() 102 | 103 | def misc(self): 104 | raise NotImplementedError("Misc functions are implemented in sub classes") 105 | 106 | def epoch(self, is_train): 107 | """ Rewrite this function if necessary in the sub-classes. """ 108 | 109 | loader = self.train_loader if is_train else self.test_loader 110 | 111 | # setup statistics 112 | batch_time = AverageMeter('Time', ':3.3f') 113 | # data_time = AverageMeter('Data', ':6.3f') 114 | losses = AverageMeter('Loss', ':.4e') 115 | top1 = AverageMeter('Acc@1', ':3.3f') 116 | top5 = AverageMeter('Acc@5', ':3.3f') 117 | metrics = [batch_time, top1, top5, losses] 118 | 119 | loader_len = len(loader) 120 | progress = ProgressMeter(loader_len, *metrics, prefix='Job id: %s, ' % self.args.job_id) 121 | end = time.time() 122 | 123 | for idx, (X, y) in enumerate(loader): 124 | 125 | # data_time.update(time.time() - end) 126 | X, y = X.to(self.device), y.to(self.device) 127 | yp = self.model(X) 128 | loss = self.criterion(yp, y) / self.args.world_size 129 | 130 | acc1, acc5 = accuracy(yp, y, topk=(1, 5)) 131 | 132 | reduced_loss = loss.data.clone() 133 | reduced_acc1 = acc1.clone() / self.args.world_size 134 | reduced_acc5 = acc5.clone() / self.args.world_size 135 | 136 | if self.args.distributed: 137 | dist.all_reduce(reduced_loss) 138 | dist.all_reduce(reduced_acc1) 139 | dist.all_reduce(reduced_acc5) 140 | 141 | if is_train: 142 | self.opt.zero_grad() 143 | loss.backward() 144 | if self.args.distributed: 145 | average_gradients(self.model) # NOTE: important 146 | 
self.opt.step() 147 | 148 | # update statistics 149 | top1.update(reduced_acc1[0].item(), X.shape[0]) 150 | top5.update(reduced_acc5[0].item(), X.shape[0]) 151 | losses.update(reduced_loss.item(), X.shape[0]) 152 | batch_time.update(time.time() - end) 153 | end = time.time() 154 | 155 | # show the training/evaluating statistics 156 | if self.check_is_primary() and ((idx % self.args.print_freq == 0) or (idx+1) % loader_len == 0): 157 | progress.show(idx) 158 | 159 | return top1.avg, losses.avg 160 | 161 | def setup_optim(self): 162 | if self.args.model_type.startswith('model_'): 163 | self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr, \ 164 | momentum=self.args.momentum, nesterov=self.args.nesterov, \ 165 | weight_decay=self.args.weight_decay) 166 | self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.opt, milestones=[int(self.args.epochs*0.5), int(self.args.epochs*0.75)]) 167 | 168 | elif self.args.model_type.startswith('resnet_') or self.args.model_type.startswith('vgg_'): 169 | self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr, \ 170 | momentum=self.args.momentum, nesterov=self.args.nesterov, \ 171 | weight_decay=self.args.weight_decay) 172 | if self.args.lr_decy_type == 'cosine': 173 | self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.opt, self.args.epochs, eta_min=0) 174 | elif self.args.lr_decy_type == 'multi_step': 175 | self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.opt, milestones=[int(self.args.epochs*0.5), int(self.args.epochs*0.75)]) 176 | else: 177 | raise ValueError("Unknown decy type") 178 | 179 | elif self.args.model_type.startswith('wideresnet'): 180 | self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr, \ 181 | momentum=self.args.momentum, nesterov=self.args.nesterov, \ 182 | weight_decay=self.args.weight_decay) 183 | self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.opt, milestones=[int(self.args.epochs*0.3), int(self.args.epochs*0.6), int(self.args.epochs*0.8)], gamma=0.2) 184 | 185 
| elif self.args.model_type.startswith('mobilenet_v2'): 186 | # default: 150 epochs, 5e-2 lr with cosine, 4e-5 wd, no dropout 187 | self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr, \ 188 | momentum=self.args.momentum, nesterov=self.args.nesterov, \ 189 | weight_decay=self.args.weight_decay) 190 | self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.opt, self.args.epochs, eta_min=0) 191 | else: 192 | raise ValueError("Unknown model, failed to initalize optim") 193 | 194 | def add_graph(self): 195 | # create dummy input 196 | x = torch.randn(self.args.batch_size, 3, 32, 32) 197 | with self.writer: 198 | self.writer.add_graph(self.model, (x,)) 199 | 200 | def __build_path(self): 201 | if self.args.exec_mode == 'finetune': 202 | self.load_path = self.args.load_path 203 | self.save_path = os.path.join(os.path.dirname(self.load_path), 'model_ft.pt') 204 | elif self.args.exec_mode == 'train': 205 | self.save_path = os.path.join(self.args.save_path, '_'.join([self.args.model_type, self.args.learner]), self.args.job_id, 'model.pt') 206 | self.load_path = self.save_path 207 | else: 208 | self.load_path = self.args.load_path 209 | self.save_path = self.load_path 210 | 211 | def warm_up_lr(self): 212 | if self.args.model_type.endswith('1202') or self.args.model_type.endswith('110'): 213 | for param in self.opt.param_groups: 214 | param['lr'] = 0.1 * self.args.lr 215 | else: 216 | pass 217 | 218 | def check_is_primary(self): 219 | if (self.args.distributed and self.args.rank == 0) or \ 220 | not self.args.distributed: 221 | return True 222 | else: 223 | return False 224 | 225 | def save_model(self): 226 | state = {'state_dict': self.model.state_dict(), \ 227 | 'optimizer': self.opt.state_dict()} 228 | torch.save(state, self.save_path) 229 | logging.info("Model stored at: " + self.save_path) 230 | 231 | def load_model(self): 232 | if self.args.distributed: 233 | # read parameters to each GPU seperately 234 | loc = 
'cuda:{}'.format(torch.cuda.current_device()) 235 | checkpoint = torch.load(self.load_path, map_location=loc) 236 | else: 237 | checkpoint = torch.load(self.load_path) 238 | 239 | self.model.load_state_dict(checkpoint['state_dict']) 240 | self.opt.load_state_dict(checkpoint['optimizer']) 241 | logging.info("Model succesfully restored from %s" % self.load_path) 242 | 243 | if self.args.distributed: 244 | broadcast_params(self.model) 245 | 246 | -------------------------------------------------------------------------------- /Cifar/main.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import argparse 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | from nets.model_helper import ModelHelper 6 | from learners.basic_learner import BasicLearner 7 | from data_loader import * 8 | from utils.helper import * 9 | import torch.multiprocessing as mp 10 | import torch.distributed as dist 11 | from utils.dist_utils import * 12 | import os 13 | import numpy as np 14 | import random 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--fix_random', action='store_true', help='fix randomness') 18 | parser.add_argument('--seed', default=1234, type=int) 19 | parser.add_argument('--gpu_id', default='0', help='gpu id') 20 | parser.add_argument('--job_id', default='', help='generated automatically') 21 | parser.add_argument('--exec_mode', default='', choices=['train', 'eval', 'finetune', 'misc']) 22 | parser.add_argument('--print_freq', default=30, type=int, help='training printing frequency') 23 | parser.add_argument('--model_type', default='', help='what model do you want to use') 24 | parser.add_argument('--data_path', default='TODO', help='path to the dataset') 25 | parser.add_argument('--data_aug', action='store_false', help='whether to use data augmentation') 26 | parser.add_argument('--save_path', default='./models', help='save path for models') 27 | parser.add_argument('--load_path', 
default='', help='path to load the model') 28 | parser.add_argument('--load_path_log', default='record.log', help='suffix to name record.log') 29 | parser.add_argument('--eval_epoch', type=int, default=1, help='perform eval at intervals') 30 | 31 | # multi-process gpu 32 | parser.add_argument('--dist-url', default='tcp://127.0.0.1', type=str, 33 | help='url used to set up distributed training') 34 | parser.add_argument('--num_workers', type=int, default=4, help='num of workers data loaders') 35 | parser.add_argument('--port', default=6669, type=int, help='port of server') 36 | parser.add_argument('--world-size', default=1, type=int, 37 | help='number of nodes for distributed training') 38 | parser.add_argument('--rank', default=0, type=int, 39 | help='node rank for distributed training') 40 | parser.add_argument('--local_rank', type=int, 41 | help='local node rank for distributed learning (useless for now)') 42 | parser.add_argument('--dist-backend', default='nccl', type=str, 43 | help='distributed backend') 44 | parser.add_argument('--distributed', default=False, type=bool, help='use distributed or not. 
Do NOT SET THIS') 45 | 46 | # learner config 47 | parser_learner = parser.add_argument_group('Learner') 48 | parser_learner.add_argument('--learner', default='', choices=['vanilla', 'trick', 'chann_cifar', 'chann_ilsvrc']) #'chann_search' is excluded for simplicity 49 | parser_learner.add_argument('--lr', type=float, default=1e-1, help='initial learning rate') 50 | parser_learner.add_argument('--lr_decy_type', default='multi_step', choices=['multi_step', 'cosine'], help='lr decy type for model params') 51 | parser_learner.add_argument('--lr_min', type=float, default=2e-4, help='minimal learning rate for cosine decy') 52 | parser_learner.add_argument('--momentum', type=float, default=0.9, help='momentum value') 53 | parser_learner.add_argument('--nesterov', action='store_true') 54 | parser_learner.add_argument('--weight_decay', type=float, default=2e-4, help='weight decay for resnet') 55 | parser_learner.add_argument('--dropout_rate', default=0.0, type=float, help='dropout rate for mobilenet_v2, default: 0') 56 | 57 | parser.add_argument('--gamma', type=float, default=0.99, help='time decay for baseline update') 58 | parser_learner.add_argument('--flops', action='store_true', help='add flops regularization') 59 | parser_learner.add_argument('--orthg_weight', type=float, default=0.001, help='orthognal regularization weight') 60 | parser_learner.add_argument('--ft_schedual' , default='', choices=['follow_meta','fixed']) 61 | parser_learner.add_argument('--ft_proj_lr' , type=float, default=0.001, help='lr in fixed ft_schedual') 62 | parser_learner.add_argument('--norm_constraint' , default='', choices=['regularization','constraint']) 63 | parser_learner.add_argument('--norm_weight' , type=float, default=0.01, help='norm 1 regularization weight') 64 | parser_learner.add_argument('--updt_proj', type=int, default=10, help='intervals to update orthogonal loss') 65 | parser_learner.add_argument('--ortho_type', default='l1', choices=['l1', 'l2'], help='type of norm for 
orthogonal regularization') 66 | 67 | parser_arc = parser.add_argument_group('Architecture') 68 | parser_arc.add_argument('--cfg', default='', help='decoding cfgs obtained by one-shot model') 69 | 70 | # dataset params 71 | parser_dataset = parser.add_argument_group('Dataset') 72 | parser_dataset.add_argument('--dataset', choices=['cifar10', 'cifar100', 'ilsvrc_12'], default='cifar10', help='dataset to use') 73 | parser_dataset.add_argument('--batch_size', type=int, default=128, help='batch_size for dataset') 74 | parser_dataset.add_argument('--epochs', type=int, default=200, help='resnet:200') 75 | parser_dataset.add_argument('--warmup_epochs', type=int, default=-1, help='epochs training weights only') 76 | parser_dataset.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') 77 | parser_dataset.add_argument('--total_portion', type=float, default=1.0, help='portion of data will be used in searching') 78 | 79 | # trick params 80 | parser_trick = parser.add_argument_group("Tricks") 81 | parser_trick.add_argument('--mixup', action='store_true') 82 | parser_trick.add_argument('--mixup_alpha', type=float, default=1.0) 83 | parser_trick.add_argument('--label_smooth', action='store_true') 84 | parser_trick.add_argument('--label_smooth_eps', type=float, default=0.1) 85 | parser_trick.add_argument('--se', action='store_true') 86 | parser_trick.add_argument('--se_reduction', type=int, default=-1) 87 | 88 | def set_loader(args): 89 | sampler = None 90 | 91 | if args.dataset == 'cifar10': 92 | if args.learner in ['vanilla']: 93 | loaders = cifar10_loader_train(args, num_workers=4) 94 | else: 95 | raise ValueError("Unknown dataset") 96 | 97 | return loaders, sampler 98 | 99 | 100 | def set_learner(args): 101 | if args.learner == 'vanilla': 102 | return BasicLearner 103 | else: 104 | raise ValueError("Unknown learner") 105 | 106 | 107 | def main(): 108 | ##################### Preliminaries ################## 109 | args = parser.parse_args() 
110 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 111 | 112 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 113 | ngpus_per_node = torch.cuda.device_count() 114 | args.distributed = True if ngpus_per_node > 1 else False 115 | ##################################################### 116 | 117 | try: 118 | if args.distributed: 119 | rank, world_size = init_dist(backend=args.dist_backend, master_ip=args.dist_url, port=args.port) 120 | args.rank = rank 121 | args.world_size = world_size 122 | print("Distributed Enabled. Rank %d initalized" % args.rank) 123 | else: 124 | print("Single model training...") 125 | 126 | if args.fix_random: 127 | # init seed within each thread 128 | manualSeed = args.seed 129 | np.random.seed(manualSeed) 130 | random.seed(manualSeed) 131 | torch.manual_seed(manualSeed) 132 | torch.cuda.manual_seed(manualSeed) 133 | torch.cuda.manual_seed_all(manualSeed) 134 | # NOTE: literally you should uncomment the following, but slower 135 | cudnn.deterministic = True 136 | cudnn.benchmark = False 137 | print('Warning: You have chosen to seed training. ' 138 | 'This will turn on the CUDNN deterministic setting, ' 139 | 'which can slow down your training considerably! 
' 140 | 'You may see unexpected behavior when restarting ' 141 | 'from checkpoints.') 142 | else: 143 | cudnn.benchmark = True 144 | 145 | if args.rank == 0: 146 | # either single GPU or multi GPU with rank = 0 147 | print("+++++++++++++++++++++++++") 148 | print("torch version:", torch.__version__) 149 | print("+++++++++++++++++++++++++") 150 | 151 | # setup logging 152 | if args.exec_mode in ['train', 'misc']: 153 | args.job_id = generate_job_id() 154 | init_logging(os.path.join(args.save_path, '_'.join([args.model_type, args.learner]), args.job_id, 'record.log')) 155 | elif args.exec_mode == 'finetune': 156 | args.job_id = generate_job_id() 157 | init_logging(os.path.join(os.path.dirname(args.load_path), 'ft_record.log')) 158 | else: 159 | init_logging(os.path.join(os.path.dirname(args.load_path), args.load_path_log)) 160 | 161 | print_args(vars(args)) 162 | logging.info("Using GPU: "+args.gpu_id) 163 | 164 | # create model 165 | model_helper = ModelHelper() 166 | model = model_helper.get_model(args) 167 | model.cuda() 168 | 169 | if args.distributed: 170 | # share the same initialization 171 | broadcast_params(model) 172 | 173 | loaders, sampler = set_loader(args) 174 | learner_fn = set_learner(args) 175 | learner = learner_fn(model, loaders, args, device) 176 | 177 | if args.exec_mode == 'train': 178 | learner.train(sampler) 179 | elif args.exec_mode == 'eval': 180 | learner.evaluate() 181 | elif args.exec_mode == 'finetune': 182 | learner.finetune(sampler) 183 | elif args.exec_mode == 'misc': 184 | learner.misc() 185 | 186 | if args.distributed: 187 | dist.destroy_process_group() 188 | return 0 189 | 190 | except: 191 | traceback.print_exc() 192 | if args.distributed: 193 | dist.destroy_process_group() 194 | return 1 195 | 196 | if __name__ == '__main__': 197 | main() 198 | --------------------------------------------------------------------------------