├── Cifar
├── learners
│ ├── __init__.py
│ └── basic_learner.py
├── utils
│ ├── __init__.py
│ ├── helper.py
│ ├── dist_utils.py
│ ├── model_profiling.py
│ ├── compute_flops.py
│ ├── pytorch_utils.py
│ └── utils.py
├── run_scripts
│ ├── feature_extract
│ │ ├── extract_resnet20_20m.sh
│ │ └── extract_resnet56_60m.sh
│ ├── evaluate
│ │ ├── evaluate_resnet20_cifar.sh
│ │ └── evaluate_resnet56_cifar.sh
│ └── train
│ │ ├── start_train_resnet20_cifar.sh
│ │ └── start_train_resnet56_cifar.sh
├── nets
│ ├── se_module.py
│ ├── model_helper.py
│ ├── __init__.py
│ ├── base_models.py
│ ├── resnet_20s_decode.py
│ └── resnet_cifar.py
├── data_providers
│ ├── __init__.py
│ └── cifar10.py
├── feature_extract.py
├── data_loader.py
└── main.py
├── .gitignore
├── Imagenet
├── utils
│ ├── __init__.py
│ └── pytorch_utils.py
├── run_scripts
│ ├── evaluate
│ │ ├── evaluate_mobilenetv2_150m.sh
│ │ ├── evaluate_mobilenetv2_220m.sh
│ │ ├── evaluate_mobilenet_50m.sh
│ │ ├── evaluate_mobilenet_150m.sh
│ │ ├── evaluate_mobilenet_285m.sh
│ │ ├── evaluate_resnet50_2000m.sh
│ │ └── evaluate_resnet50_2300m.sh
│ ├── feature_extract
│ │ ├── extract_mobilenet_150m.sh
│ │ ├── extract_mobilenet_285m.sh
│ │ ├── extract_mobilenet_50m.sh
│ │ ├── extract_resnet50_1800m.sh
│ │ ├── extract_resnet50_2000m.sh
│ │ ├── extract_resnet50_2300m.sh
│ │ ├── extract_mobilenetv2_150m.sh
│ │ └── extract_mobilenetv2_220m.sh
│ └── train
│ │ ├── train_resnet50_base.sh
│ │ ├── train_mobilenet_base.sh
│ │ ├── train_mobilenetv2_base.sh
│ │ ├── train_mobilenetv2_150m.sh
│ │ ├── train_mobilenetv2_220m.sh
│ │ ├── train_mobilenet_50m.sh
│ │ ├── train_mobilenet_150m.sh
│ │ ├── train_mobilenet_285m.sh
│ │ ├── train_resnet50_1800m.sh
│ │ ├── train_resnet50_2000m.sh
│ │ └── train_resnet50_2300m.sh
├── prep_imagenet.sh
├── data_providers
│ ├── __init__.py
│ └── imagenet.py
├── evaluate.py
├── models
│ ├── __init__.py
│ └── base_models.py
├── cal_FLOPs.py
├── feature_extract.py
└── train.py
├── ToyExample
├── utils
│ ├── __init__.py
│ └── pytorch_utils.py
├── run_scripts
│ ├── evaluate
│ │ └── evaluate_vgg_100m.sh
│ ├── feature_extract
│ │ └── extract_vgg_100m.sh
│ └── train
│ │ └── train_vgg_100m.sh
├── data_providers
│ ├── __init__.py
│ └── cifar10.py
├── evaluate.py
├── models
│ ├── __init__.py
│ ├── base_models.py
│ └── vgg.py
├── cal_FLOPs.py
├── feature_extract.py
└── train.py
├── requirements.txt
├── README.md
└── CKA
└── cka.py
/Cifar/learners/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | best_models
2 | Exp_base
3 | .ipynb_checkpoints
4 | __pycache__
--------------------------------------------------------------------------------
/Cifar/utils/__init__.py:
--------------------------------------------------------------------------------
"""Re-export the public helpers of this ``utils`` package."""
from .pytorch_utils import *

# get_data_iter.py is not present in this package directory in the repository
# listing, and an unconditional import of it makes ``import utils`` fail.
# Import it only when the module exists; any ModuleNotFoundError raised for a
# different module (e.g. a dependency of get_data_iter) still propagates.
try:
    from .get_data_iter import *
except ModuleNotFoundError as err:
    if err.name != __name__ + ".get_data_iter":
        raise
3 |
--------------------------------------------------------------------------------
/Imagenet/utils/__init__.py:
--------------------------------------------------------------------------------
"""Re-export the public helpers of this ``utils`` package."""
from .pytorch_utils import *

# get_data_iter.py is not present in this package directory in the repository
# listing, and an unconditional import of it makes ``import utils`` fail.
# Import it only when the module exists; any ModuleNotFoundError raised for a
# different module (e.g. a dependency of get_data_iter) still propagates.
try:
    from .get_data_iter import *
except ModuleNotFoundError as err:
    if err.name != __name__ + ".get_data_iter":
        raise
3 |
--------------------------------------------------------------------------------
/ToyExample/utils/__init__.py:
--------------------------------------------------------------------------------
"""Re-export the public helpers of this ``utils`` package."""
from .pytorch_utils import *

# get_data_iter.py is not present in this package directory in the repository
# listing, and an unconditional import of it makes ``import utils`` fail.
# Import it only when the module exists; any ModuleNotFoundError raised for a
# different module (e.g. a dependency of get_data_iter) still propagates.
try:
    from .get_data_iter import *
except ModuleNotFoundError as err:
    if err.name != __name__ + ".get_data_iter":
        raise
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | numpy==1.19.2
3 | torchvision==0.8.2
4 | matplotlib==3.3.2
5 | Pillow==8.0.1
6 | tensorboardX==2.1
7 | scipy==1.5.3
8 | scikit-learn==0.23.2
--------------------------------------------------------------------------------
/Cifar/run_scripts/feature_extract/extract_resnet20_20m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run feature_extract.py on the CIFAR-10 resnet20 baseline with a
# 20,000,000 FLOPs target.
extract_args=(
  --model "resnet20"
  --path "Exp_base/resnet20_base"
  --dataset "cifar10"
  --save_path 'your_data_path'
  --target_flops 20000000
  --beta 246
)
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py "${extract_args[@]}"
--------------------------------------------------------------------------------
/Cifar/run_scripts/feature_extract/extract_resnet56_60m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run feature_extract.py on the CIFAR-10 resnet56 baseline with a
# 60,000,000 FLOPs target.
extract_args=(
  --model "resnet56"
  --path "Exp_base/resnet56_base"
  --dataset "cifar10"
  --save_path 'your_data_path'
  --target_flops 60000000
  --beta 500
)
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py "${extract_args[@]}"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the mobilenetv2 150m checkpoint on ImageNet with its channel cfg.
eval_args=(
  --model "mobilenetv2"
  --path "pretrain_models/train_mobilenetv2_150m_19509"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[29, 12, 16, 24, 43, 50, 98, 273]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the mobilenetv2 220m checkpoint on ImageNet with its channel cfg.
eval_args=(
  --model "mobilenetv2"
  --path "pretrain_models/train_mobilenetv2_220m_19738"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[31, 14, 20, 28, 54, 74, 130, 298]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/ToyExample/run_scripts/evaluate/evaluate_vgg_100m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the vgg 100m checkpoint on CIFAR-10 with its channel cfg.
eval_args=(
  --model "vgg"
  --path "pretrain_models/train_vgg_100m_22031"
  --dataset "cifar10"
  --save_path 'your_data_path'
  --cfg "[44, 35, 71, 71, 145, 104, 145, 290, 205, 325, 432, 429, 469]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
9 |
--------------------------------------------------------------------------------
/ToyExample/run_scripts/feature_extract/extract_vgg_100m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run feature_extract.py on the CIFAR-10 vgg baseline with a
# 100,000,000 FLOPs target.
extract_args=(
  --model "vgg"
  --path "Exp_base/vgg_base"
  --dataset "cifar10"
  --save_path 'your_data_path'
  --target_flops 100000000
  --beta 231
  --test_batch_size 1024
)
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py "${extract_args[@]}"
10 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_50m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the mobilenet 50m checkpoint on ImageNet.
#
# The --cfg below is the 50m configuration (identical to the one in
# train_mobilenet_50m.sh). Do not point --path at
# pretrain_models/train_mobilenet_150m_31752: that is the 150m checkpoint used
# by evaluate_mobilenet_150m.sh, whose channel widths differ from this cfg.
# Set MODEL_PATH to the 50m checkpoint directory, or keep exactly one
# pretrain_models/train_mobilenet_50m_* directory to have it picked up.
if [[ -n "${MODEL_PATH:-}" ]]; then
  model_path="${MODEL_PATH}"
else
  shopt -s nullglob
  candidates=(pretrain_models/train_mobilenet_50m_*)
  shopt -u nullglob
  if (( ${#candidates[@]} != 1 )); then
    echo "set MODEL_PATH to the 50m mobilenet checkpoint directory" >&2
    exit 1
  fi
  model_path="${candidates[0]}"
fi
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
  --model "mobilenet" \
  --path "${model_path}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --cfg "[18, 31, 30, 38, 71, 76, 173, 52, 78, 78, 51, 178, 250, 522]"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the mobilenet 150m checkpoint on ImageNet with its channel cfg.
eval_args=(
  --model "mobilenet"
  --path "pretrain_models/train_mobilenet_150m_31752"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_mobilenet_285m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the mobilenet 285m checkpoint on ImageNet with its channel cfg.
# NOTE(review): the 13th cfg entry is 733 here but 734 in
# train_mobilenet_285m.sh; confirm which one the checkpoint was trained with.
eval_args=(
  --model "mobilenet"
  --path "pretrain_models/train_mobilenet_285m_3956"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[28, 52, 88, 90, 179, 182, 362, 311, 314, 314, 312, 367, 733, 1024]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/ToyExample/run_scripts/train/train_vgg_100m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Train the vgg 100m cfg on CIFAR-10 from the vgg baseline; each run writes
# to a fresh Exp_train directory suffixed with $RANDOM.
train_args=(
  --train
  --model "vgg"
  --cfg "[44, 35, 71, 71, 145, 104, 145, 290, 205, 325, 432, 429, 469]"
  --path "Exp_train/train_vgg_100m_${RANDOM}"
  --dataset "cifar10"
  --save_path 'your_data_path'
  --base_path "Exp_base/vgg_base"
  --warm_epoch 1
  --n_epochs 300
  --sync_bn
)
python3 -m torch.distributed.launch --nproc_per_node=1 train.py "${train_args[@]}"
12 |
13 |
--------------------------------------------------------------------------------
/Cifar/run_scripts/evaluate/evaluate_resnet20_cifar.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
## ============================= resnet cifar10 decode ============================
# Evaluate the resnet20_20m checkpoint (main.py, eval mode) on CIFAR-10.
python main.py \
  --gpu_id 0 \
  --exec_mode eval \
  --learner vanilla \
  --dataset cifar10 \
  --data_path ~/datasets \
  --model_type resnet_decode \
  --load_path /ITPruner/Cifar/models/best_models/resnet20_20m/model.pt \
  --cfg 16,11,11,11,11,11,11,22,22,21,21,22,22,45,45,43,43,48,48
11 |
12 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_resnet50_2000m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the resnet50 2000m checkpoint on ImageNet with its channel cfg.
eval_args=(
  --model "resnet50"
  --path "pretrain_models/train_resnet50_2000m_2229"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[50, 53, 54, 125, 54, 54, 54, 54, 101, 108, 179, 107, 107, 107, 107, 107, 107, 203, 216, 148, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 407, 429, 1341, 429, 429, 429, 428]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/Imagenet/run_scripts/evaluate/evaluate_resnet50_2300m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Evaluate the resnet50 2300m checkpoint on ImageNet with its channel cfg.
eval_args=(
  --model "resnet50"
  --path "pretrain_models/train_resnet50_2300m_12067"
  --dataset "imagenet"
  --save_path 'your_data_path'
  --cfg "[52, 55, 55, 146, 55, 55, 55, 55, 105, 111, 232, 111, 111, 111, 111, 111, 111, 211, 222, 290, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 423, 442, 1453, 442, 442, 442, 442]"
)
python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py "${eval_args[@]}"
--------------------------------------------------------------------------------
/Cifar/run_scripts/train/start_train_resnet20_cifar.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
## ============================= resnet cifar10 decode ============================
# Train the resnet20 cfg on CIFAR-10 (main.py, train mode) for 300 epochs with
# a cosine learning-rate schedule.
python main.py \
  --gpu_id 0 \
  --exec_mode train \
  --learner vanilla \
  --dataset cifar10 \
  --data_path ~/datasets \
  --model_type resnet_decode \
  --lr 0.1 \
  --lr_min 0. \
  --lr_decy_type cosine \
  --weight_decay 5e-4 \
  --nesterov \
  --epochs 300 \
  --cfg 16,11,11,11,11,11,11,22,22,21,21,22,22,45,45,43,43,48,48
16 |
17 |
--------------------------------------------------------------------------------
/Imagenet/prep_imagenet.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# mount -t tmpfs -o size=160G tmpfs /userhome/memory_data
# Extract the ILSVRC2012 train/val tarballs under /dev/shm/imagenet, then run
# the helper scripts /userhome/unzip.sh and /userhome/valprep.sh inside the
# extracted train/ and val/ directories respectively.
# Abort at the first failing step: a missing tarball or a failed cd must not
# let ./unzip.sh or ./valprep.sh run in the wrong directory.
set -eu
root_dir="/dev/shm/imagenet"
mkdir -p "${root_dir}/train"
mkdir -p "${root_dir}/val"
tar -xvf /userhome/data/ILSVRC2012_img_train.tar -C "${root_dir}/train"
cp /userhome/unzip.sh "${root_dir}/train/"
cd "${root_dir}/train"
chmod +x unzip.sh
./unzip.sh
tar -xvf /userhome/data/ILSVRC2012_img_val.tar -C "${root_dir}/val"
cp /userhome/valprep.sh "${root_dir}/val/"
cd "${root_dir}/val"
chmod +x valprep.sh
./valprep.sh
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# mobilenet baseline with a 150,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "mobilenet" \
  --path "Exp_base/mobilenet_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 150000000 \
  --beta 243
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_285m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# mobilenet baseline with a 285,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "mobilenet" \
  --path "Exp_base/mobilenet_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 285000000 \
  --beta 233
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenet_50m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# mobilenet baseline with a 50,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "mobilenet" \
  --path "Exp_base/mobilenet_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 50000000 \
  --beta 243
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_1800m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# resnet50 baseline with a 1,800,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "resnet50" \
  --path "Exp_base/resnet50_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 1800000000 \
  --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_2000m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# resnet50 baseline with a 2,000,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "resnet50" \
  --path "Exp_base/resnet50_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 2000000000 \
  --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_resnet50_2300m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# resnet50 baseline with a 2,300,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "resnet50" \
  --path "Exp_base/resnet50_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 2300000000 \
  --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# mobilenetv2 baseline with a 150,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "mobilenetv2" \
  --path "Exp_base/mobilenetv2_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 150000000 \
  --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/feature_extract/extract_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then run feature_extract.py on the
# mobilenetv2 baseline with a 220,000,000 FLOPs target.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
  --model "mobilenetv2" \
  --path "Exp_base/mobilenetv2_base" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --target_flops 220000000 \
  --beta 2000
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_base.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the unpruned resnet50
# baseline on 4 processes for 120 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "resnet50" \
  --path "Exp_base/resnet50_base_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --sync_bn \
  --warm_epoch 1 \
  --n_epochs 120
16 |
--------------------------------------------------------------------------------
/Cifar/run_scripts/evaluate/evaluate_resnet56_cifar.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
## ============================= resnet cifar10 decode ============================
# Evaluate the resnet56_60m checkpoint (main.py, eval mode) on CIFAR-10.
python main.py \
  --gpu_id 0 \
  --exec_mode eval \
  --learner vanilla \
  --dataset cifar10 \
  --data_path ~/datasets \
  --model_type resnet_decode \
  --load_path /ITPruner/Cifar/models/best_models/resnet56_60m/model.pt \
  --cfg 16,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,11,11,22,22,21,21,21,21,21,21,21,21,21,21,21,21,21,21,22,22,44,44,42,42,42,42,42,42,42,42,42,42,42,42,64,64,64,64
11 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_base.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the unpruned mobilenet
# baseline on 4 processes for 250 epochs with label smoothing 0.1.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenet" \
  --path "Exp_base/mobilenet_base_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
17 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_base.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the unpruned mobilenetv2
# baseline on 4 processes for 250 epochs with label smoothing 0.1.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenetv2" \
  --path "Exp_base/mobilenetv2_base_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
17 |
--------------------------------------------------------------------------------
/Cifar/run_scripts/train/start_train_resnet56_cifar.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
## ============================= resnet cifar10 decode ============================
# Train the resnet56 cfg on CIFAR-10 (main.py, train mode) for 300 epochs with
# a cosine learning-rate schedule.
python main.py \
  --gpu_id 0 \
  --exec_mode train \
  --learner vanilla \
  --dataset cifar10 \
  --data_path ~/datasets \
  --model_type resnet_decode \
  --lr 0.1 \
  --lr_min 0. \
  --lr_decy_type cosine \
  --weight_decay 5e-4 \
  --nesterov \
  --epochs 300 \
  --cfg 16,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,11,11,22,22,21,21,21,21,21,21,21,21,21,21,21,21,21,21,22,22,44,44,42,42,42,42,42,42,42,42,42,42,42,42,64,64,64,64
16 |
17 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the mobilenetv2 150m cfg
# from Exp_base/mobilenetv2_base on 4 processes for 250 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenetv2" \
  --cfg "[29, 12, 16, 24, 43, 50, 98, 273]" \
  --path "Exp_train/train_mobilenetv2_150m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/mobilenetv2_base" \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
19 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenetv2_220m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the mobilenetv2 220m cfg
# from Exp_base/mobilenetv2_base on 4 processes for 250 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenetv2" \
  --cfg "[31, 14, 20, 28, 54, 74, 130, 298]" \
  --path "Exp_train/train_mobilenetv2_220m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/mobilenetv2_base" \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
19 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_50m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the mobilenet 50m cfg
# from Exp_base/mobilenet_base on 4 processes for 250 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenet" \
  --cfg "[18, 31, 30, 38, 71, 76, 173, 52, 78, 78, 51, 178, 250, 522]" \
  --path "Exp_train/train_mobilenet_50m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/mobilenet_base" \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
19 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_150m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the mobilenet 150m cfg
# from Exp_base/mobilenet_base on 4 processes for 250 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenet" \
  --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]" \
  --path "Exp_train/train_mobilenet_150m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/mobilenet_base" \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
19 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_mobilenet_285m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the mobilenet 285m cfg
# from Exp_base/mobilenet_base on 4 processes for 250 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
# NOTE(review): the 13th cfg entry is 734 here but 733 in
# evaluate_mobilenet_285m.sh; confirm which width the released model uses.
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "mobilenet" \
  --cfg "[28, 52, 88, 90, 179, 182, 362, 311, 314, 314, 312, 367, 734, 1024]" \
  --path "Exp_train/train_mobilenet_285m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/mobilenet_base" \
  --warm_epoch 1 \
  --sync_bn \
  --n_epochs 250 \
  --label_smoothing 0.1
19 |
--------------------------------------------------------------------------------
/Cifar/nets/se_module.py:
--------------------------------------------------------------------------------
1 | """ Taken from https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py """
2 |
3 | from torch import nn
4 |
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Average-pools every channel to a single value, feeds the pooled vector
    through a bias-free bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid)
    and multiplies the input channels by the resulting gates in (0, 1).

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio; the hidden width is
            ``channel // reduction``, clamped to at least 1.
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        # Pruned networks can have fewer channels than ``reduction``. Without
        # the clamp the bottleneck would be zero-width: the layer would have
        # no trainable parameters and gate every channel by sigmoid(0) = 0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Rescale a 4-D (batch, channel, H, W) tensor channel-wise."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
21 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_1800m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the resnet50 1800m cfg
# from Exp_base/resnet50_base on 4 processes for 120 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "resnet50" \
  --cfg "[48, 52, 52, 110, 52, 52, 52, 52, 98, 105, 142, 105, 105, 105, 105, 105, 105, 197, 211, 48, 211, 211, 211, 211, 211, 211, 211, 211, 211, 211, 397, 420, 1264, 419, 419, 419, 418]" \
  --path "Exp_train/train_resnet50_1800m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --sync_bn \
  --warm_epoch 1 \
  --base_path "Exp_base/resnet50_base" \
  --n_epochs 120
18 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_2000m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the resnet50 2000m cfg
# from Exp_base/resnet50_base on 4 processes for 120 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "resnet50" \
  --cfg "[50, 53, 54, 125, 54, 54, 54, 54, 101, 108, 179, 107, 107, 107, 107, 107, 107, 203, 216, 148, 215, 215, 215, 215, 215, 215, 215, 215, 215, 215, 407, 429, 1341, 429, 429, 429, 428]" \
  --path "Exp_train/train_resnet50_2000m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/resnet50_base" \
  --sync_bn \
  --warm_epoch 1 \
  --n_epochs 120
18 |
--------------------------------------------------------------------------------
/Imagenet/run_scripts/train/train_resnet50_2300m.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Stage ImageNet with prep_imagenet.sh, then train the resnet50 2300m cfg
# from Exp_base/resnet50_base on 4 processes for 120 epochs.
code_path='/ITPruner/Imagenet'
chmod +x "${code_path}/prep_imagenet.sh"
# Stop rather than launch from an unexpected working directory.
cd "${code_path}" || exit 1
echo "preparing data"
# Only stdout of the prep script is discarded; its exit status is not checked.
bash "${code_path}/prep_imagenet.sh" > /dev/null
echo "preparing data finished"
python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
  --model "resnet50" \
  --cfg "[52, 55, 55, 146, 55, 55, 55, 55, 105, 111, 232, 111, 111, 111, 111, 111, 111, 211, 222, 290, 222, 222, 222, 222, 222, 222, 222, 222, 222, 222, 423, 442, 1453, 442, 442, 442, 442]" \
  --path "Exp_train/train_resnet50_2300m_${RANDOM}" \
  --dataset "imagenet" \
  --save_path 'your_data_path' \
  --base_path "Exp_base/resnet50_base" \
  --sync_bn \
  --warm_epoch 1 \
  --n_epochs 120
18 |
--------------------------------------------------------------------------------
/Cifar/data_providers/__init__.py:
--------------------------------------------------------------------------------
class DataProvider:
    """Interface implemented by the dataset providers of this package.

    Every member below only raises ``NotImplementedError``; concrete
    providers are expected to override them.
    """

    # Fixed random seed associated with the validation set.
    VALID_SEED = 0

    @staticmethod
    def name():
        """Name of the dataset (to be supplied by a subclass)."""
        raise NotImplementedError()

    @property
    def data_shape(self):
        """Shape of a single data entry, as a Python list."""
        raise NotImplementedError()

    @property
    def n_classes(self):
        """Number of classes, as an ``int``."""
        raise NotImplementedError()

    @property
    def save_path(self):
        """Local path where the data is stored."""
        raise NotImplementedError()

    @property
    def data_url(self):
        """URL from which the data can be downloaded."""
        raise NotImplementedError()
28 |
--------------------------------------------------------------------------------
/Imagenet/data_providers/__init__.py:
--------------------------------------------------------------------------------
class DataProvider:
    """Interface implemented by the dataset providers of this package.

    Every member below only raises ``NotImplementedError``; concrete
    providers are expected to override them.
    """

    # Fixed random seed associated with the validation set.
    VALID_SEED = 0

    @staticmethod
    def name():
        """Name of the dataset (to be supplied by a subclass)."""
        raise NotImplementedError()

    @property
    def data_shape(self):
        """Shape of a single data entry, as a Python list."""
        raise NotImplementedError()

    @property
    def n_classes(self):
        """Number of classes, as an ``int``."""
        raise NotImplementedError()

    @property
    def save_path(self):
        """Local path where the data is stored."""
        raise NotImplementedError()

    @property
    def data_url(self):
        """URL from which the data can be downloaded."""
        raise NotImplementedError()
28 |
--------------------------------------------------------------------------------
/ToyExample/data_providers/__init__.py:
--------------------------------------------------------------------------------
class DataProvider:
    """Interface implemented by every dataset provider in this package.

    Each member below must be overridden by a concrete provider; the base
    versions only raise ``NotImplementedError``.
    """

    # Random seed for the validation set.
    VALID_SEED = 0

    @staticmethod
    def name():
        """Identifier string of the dataset."""
        raise NotImplementedError()

    @property
    def data_shape(self):
        """Shape of a single data entry, e.g. ``(C, H, W)``."""
        raise NotImplementedError()

    @property
    def n_classes(self):
        """Number of target classes, as an ``int``."""
        raise NotImplementedError()

    @property
    def save_path(self):
        """Local directory where the dataset is stored."""
        raise NotImplementedError()

    @property
    def data_url(self):
        """URL from which the dataset can be downloaded."""
        raise NotImplementedError()
28 |
--------------------------------------------------------------------------------
/Cifar/utils/helper.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | import string
5 |
6 |
def init_logging(log_path):
    """Send INFO-and-above root-logger records to *log_path* and the console.

    The parent directory of *log_path* is created when it does not exist yet.
    Both handlers use the format ``'%(asctime)s [%(levelname)s]: %(message)s'``.

    Note: each call attaches a new file handler and a new console handler to
    the root logger, so calling this twice duplicates every log line.

    Args:
        log_path: path of the log file; may be a bare file name, in which case
            it is created in the current working directory.
    """
    log_dir = os.path.dirname(log_path)
    # A bare file name has an empty dirname; os.makedirs('') would raise.
    if log_dir and not os.path.isdir(log_dir):
        print("Log path does not exist. Create a new one.")
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(log_dir, exist_ok=True)

    log = logging.getLogger()
    log.setLevel(logging.INFO)
    log_formatter = logging.Formatter('%(asctime)s [%(levelname)s]: %(message)s')

    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(log_formatter)
    log.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    log.addHandler(console_handler)
24 |
25 |
def print_args(args):
    """Log every option in *args* as ``"<key>: <value>"`` at INFO level.

    Args:
        args: a mapping (anything with ``.items()``, e.g. a ``dict``) or an
            object exposing its options as attributes, such as an
            ``argparse.Namespace``.
    """
    items = args.items() if hasattr(args, "items") else vars(args).items()
    for key, value in items:
        logging.info("{0}: {1}".format(key, value))
29 |
30 |
def generate_job_id():
    """Return a random 8-character id drawn from ASCII letters and digits.

    Characters are sampled without replacement, so no character repeats.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.sample(alphabet, 8)
    return ''.join(picked)
33 |
--------------------------------------------------------------------------------
/Cifar/data_providers/cifar10.py:
--------------------------------------------------------------------------------
1 | from utils import get_cifar_iter
2 | from data_providers import DataProvider
3 |
4 |
class CifarDataProvider(DataProvider):
    """CIFAR-10 provider exposing ``train``, ``test`` and ``valid`` iterators.

    Both iterators are created by ``get_cifar_iter``. No separate validation
    split is built: ``valid`` is the same object as ``test``. The
    ``valid_size``, ``local_rank``, ``world_size`` and extra keyword arguments
    are accepted for interface compatibility but are not used.
    """

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=500, valid_size=None,
                 n_worker=8, manual_seed=12, local_rank=0, world_size=1, **kwargs):
        self._save_path = save_path
        loader_opts = dict(cutout=16, manual_seed=manual_seed)
        # NOTE(review): cutout=16 is also requested for the 'val' split; whether
        # get_cifar_iter actually applies cutout to evaluation data is not
        # visible here -- confirm.
        self.train = get_cifar_iter('train', self.save_path, train_batch_size, n_worker, **loader_opts)
        self.test = get_cifar_iter('val', self.save_path, test_batch_size, n_worker, **loader_opts)
        self.valid = self.test

    @staticmethod
    def name():
        """Dataset identifier."""
        return 'cifar10'

    @property
    def data_shape(self):
        """Shape of one image as ``(C, H, W)``."""
        return (3, 32, 32)

    @property
    def n_classes(self):
        """CIFAR-10 has ten classes."""
        return 10

    @property
    def save_path(self):
        """Dataset root; falls back to (and caches) '/userhome/data/cifar10' when unset."""
        path = self._save_path
        if path is None:
            path = self._save_path = '/userhome/data/cifar10'
        return path
33 |
--------------------------------------------------------------------------------
/ToyExample/data_providers/cifar10.py:
--------------------------------------------------------------------------------
1 | from utils import get_cifar_iter
2 | from data_providers import DataProvider
3 |
4 |
class CifarDataProvider(DataProvider):
    """CIFAR-10 provider exposing ``train``, ``test`` and ``valid`` iterators.

    Both iterators are created by ``get_cifar_iter``. No separate validation
    split is built: ``valid`` is the same object as ``test``. The
    ``valid_size``, ``local_rank``, ``world_size`` and extra keyword arguments
    are accepted for interface compatibility but are not used.
    """

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=500, valid_size=None,
                 n_worker=8, manual_seed=12, local_rank=0, world_size=1, **kwargs):
        self._save_path = save_path
        loader_opts = dict(cutout=16, manual_seed=manual_seed)
        # NOTE(review): cutout=16 is also requested for the 'val' split; whether
        # get_cifar_iter actually applies cutout to evaluation data is not
        # visible here -- confirm.
        self.train = get_cifar_iter('train', self.save_path, train_batch_size, n_worker, **loader_opts)
        self.test = get_cifar_iter('val', self.save_path, test_batch_size, n_worker, **loader_opts)
        self.valid = self.test

    @staticmethod
    def name():
        """Dataset identifier."""
        return 'cifar10'

    @property
    def data_shape(self):
        """Shape of one image as ``(C, H, W)``."""
        return (3, 32, 32)

    @property
    def n_classes(self):
        """CIFAR-10 has ten classes."""
        return 10

    @property
    def save_path(self):
        """Dataset root; falls back to (and caches) '/userhome/data/cifar10' when unset."""
        path = self._save_path
        if path is None:
            path = self._save_path = '/userhome/data/cifar10'
        return path
33 |
--------------------------------------------------------------------------------
/Cifar/nets/model_helper.py:
--------------------------------------------------------------------------------
1 | from nets.resnet_20s_decode import *
2 | from utils.compute_flops import *
3 | import torch
4 |
5 |
class ModelHelper(object):
    """Build the network requested by command-line arguments.

    ``model_dict`` maps a model-type name to a constructor that is called as
    ``constructor(cfg, num_classes, se, se_reduction)``.
    """

    # Number of output classes for each supported dataset name.
    _NUM_CLASSES = {'cifar10': 10, 'cifar100': 100, 'ilsvrc_12': 1000}

    def __init__(self):
        self.model_dict = {'resnet_decode': resnet_decode}

    def get_model(self, args):
        """Instantiate the model described by *args*.

        Args:
            args: namespace with ``model_type``, ``dataset``, ``cfg``, ``se``,
                ``se_reduction`` and ``rank``. ``cfg`` is a comma-separated
                list of ints, optionally wrapped in ``[...]`` (the form used
                by the run scripts).

        Returns:
            The constructed model. On rank 0 the model, its parameter count
            and its FLOPs are also printed.

        Raises:
            ValueError: unknown model type or dataset, empty cfg, or a model
                type that is not a decoding model.
        """
        if args.model_type not in self.model_dict:
            raise ValueError('Wrong model type.')

        num_classes = self.__get_number_class(args.dataset)

        if 'decode' not in args.model_type:
            # Only cfg-driven (decoding) models can be built here; previously
            # this path ended in an UnboundLocalError on `model`.
            raise ValueError('Model type %s is not a decoding model.' % args.model_type)

        cfg = self._parse_cfg(args.cfg)
        model = self.model_dict[args.model_type](cfg, num_classes, args.se, args.se_reduction)
        if args.rank == 0:
            print(model)
            print("Flops and params:")
            resol = self._input_resolution(args.dataset)
            print_model_param_nums(model)
            print_model_param_flops(model, resol, multiply_adds=False)

        return model

    @staticmethod
    def _parse_cfg(cfg_text):
        """Parse ``'16,32,64'`` or ``'[16, 32, 64]'`` into a list of ints.

        Raises:
            ValueError: if no number is present or an entry is not an integer.
        """
        values = [int(v) for v in cfg_text.strip().strip('[]').split(',') if v.strip()]
        if not values:
            raise ValueError('Running decoding model. Empty cfg!')
        return values

    @staticmethod
    def _input_resolution(dataset):
        """Input side length for the FLOPs estimate: 32 for CIFAR-10/100, else 224."""
        return 32 if dataset in ('cifar10', 'cifar100') else 224

    def __get_number_class(self, dataset):
        """Return the number of classes of *dataset*.

        Raises:
            ValueError: for an unknown dataset name (previously this raised
                an UnboundLocalError).
        """
        try:
            return self._NUM_CLASSES[dataset]
        except KeyError:
            raise ValueError('Unknown dataset: %s' % dataset) from None
39 |
def test(cfg='', dataset='cifar10', se=False, se_reduction=16):
    """Smoke test: build and print every model registered in ``ModelHelper``.

    ``get_model`` reads attributes from an argparse-style namespace, so one is
    assembled here from the keyword arguments. Passing the bare model-type
    string, as the old version did, always failed with AttributeError.

    Args:
        cfg: comma-separated channel configuration. The default empty string
            makes ``get_model`` raise its "Empty cfg!" ValueError.
        dataset: dataset name used to pick the number of classes.
        se: forwarded to the model constructor.
        se_reduction: forwarded to the model constructor.
            NOTE(review): 16 is an assumed default; confirm against
            resnet_20s_decode.
    """
    import argparse

    mh = ModelHelper()
    for model_type in mh.model_dict:
        args = argparse.Namespace(model_type=model_type, dataset=dataset, cfg=cfg,
                                  se=se, se_reduction=se_reduction, rank=0)
        print(mh.get_model(args))


if __name__ == '__main__':
    import sys

    # Optional first CLI argument: the cfg string, e.g. "16,16,32,...".
    test(cfg=sys.argv[1] if len(sys.argv) > 1 else '')
48 |
49 |
--------------------------------------------------------------------------------
/Cifar/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.multiprocessing as mp
4 | import torch.distributed as dist
5 | import torch.nn as nn
6 |
7 |
def init_dist(backend='nccl',
              master_ip='tcp://127.0.0.1',
              port=6669):
    """Initialise ``torch.distributed`` from the RANK / WORLD_SIZE env variables.

    Also sets the multiprocessing start method to 'spawn' if none has been
    chosen yet, exports MASTER_ADDR / MASTER_PORT, and pins this process to
    GPU ``rank % num_gpus`` when CUDA devices are present.

    Args:
        backend: process-group backend, e.g. 'nccl' or 'gloo'.
        master_ip: address of the rank-0 host, with or without a scheme
            (e.g. 'tcp://10.0.0.1' or '10.0.0.1'); 'tcp://' is assumed when
            no scheme is given.
        port: TCP port of the rank-0 host.

    Returns:
        tuple: ``(rank, world_size)``.

    Raises:
        KeyError: if RANK or WORLD_SIZE is not set in the environment.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if '://' in master_ip:
        base_url = master_ip
        host = master_ip.split('://', 1)[1]
    else:
        base_url = 'tcp://' + master_ip
        host = master_ip
    # MASTER_ADDR must be a bare host/IP, not a URL.
    os.environ['MASTER_ADDR'] = host
    os.environ['MASTER_PORT'] = str(port)

    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    num_gpus = torch.cuda.device_count()
    # Without visible GPUs `rank % num_gpus` would divide by zero.
    if num_gpus > 0:
        torch.cuda.set_device(rank % num_gpus)

    dist_url = base_url + ':' + str(port)
    dist.init_process_group(backend=backend, init_method=dist_url,
                            world_size=world_size, rank=rank)
    return rank, world_size
23 |
24 |
def average_gradients(model):
    """All-reduce, in place, the gradient of every trainable parameter of *model*.

    Parameters that do not require grad or have no gradient yet are skipped.

    NOTE(review): despite the name, nothing is divided here -- ``all_reduce``
    sums by default, so callers presumably pre-scale the loss by
    1/world_size. Confirm before relying on it as an average.
    """
    grads = (p.grad for p in model.parameters()
             if p.requires_grad and p.grad is not None)
    for grad in grads:
        dist.all_reduce(grad.data)
29 |
30 |
def average_group_gradients(group_params):
    """All-reduce, in place, the gradients of the given parameters.

    Same as ``average_gradients`` but takes an iterable of parameters, for
    example one optimizer param group. Parameters that do not require grad
    or have no gradient are skipped. The reduction sums; no division is done.
    """
    grads = (p.grad for p in group_params
             if p.requires_grad and p.grad is not None)
    for grad in grads:
        dist.all_reduce(grad.data)
35 |
36 |
def sync_bn_stat(model, world_size):
    """Average BatchNorm2d running statistics across all processes, in place.

    Only modules whose type is exactly ``nn.BatchNorm2d`` are touched;
    subclasses are skipped. Nothing happens when ``world_size == 1``.
    """
    if world_size == 1:
        return
    bn_layers = [m for m in model.modules() if type(m) is nn.BatchNorm2d]
    for bn in bn_layers:
        # Sum over ranks, then divide: running_mean first, then running_var.
        for stat in (bn.running_mean, bn.running_var):
            dist.all_reduce(stat.data)
            stat.data.div_(world_size)
46 |
47 |
def broadcast_params(model):
    """Overwrite every ``state_dict`` tensor of *model* with rank 0's copy, in place."""
    src_rank = 0
    for tensor in model.state_dict().values():
        dist.broadcast(tensor, src_rank)
51 |
52 |
def reduce_vars(vars_list, world_size):
    """Average each tensor in *vars_list* across all processes.

    Each input is first divided by *world_size* (creating a new tensor, so the
    caller's tensors are not modified), then summed over ranks with
    ``all_reduce``.

    Args:
        vars_list: iterable of tensors (e.g. loss / accuracy values).
        world_size: number of participating processes.

    Returns:
        list: the averaged values, in input order. When ``world_size == 1``
        the inputs are returned unchanged in a new list. Previously this case
        returned ``None``, which broke callers that unpack the result.
    """
    if world_size == 1:
        return list(vars_list)
    vars_result = []
    for var in vars_list:
        var = var / world_size
        dist.all_reduce(var)
        vars_result.append(var)
    return vars_result
62 |
63 |
def sync_adam_optimizer(opt, world_size):
    """Placeholder for synchronising Adam optimizer state in parallel training.

    Intended to be used when training with Adam across processes. It is
    currently a no-op: both arguments are ignored and ``None`` is returned.
    """
    return None
67 |
--------------------------------------------------------------------------------
/ToyExample/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | import torch
7 | from run_manager import RunManager
8 | from models import VGG_CIFAR, TrainRunConfig
9 |
parser = argparse.ArgumentParser()

# --- model config ---
parser.add_argument('--path', type=str)
parser.add_argument(
    '--model',
    type=str,
    default='vgg',
    choices=['vgg', 'resnet56', 'resnet110', 'resnet18', 'resnet34',
             'resnet50', 'mobilenetv2', 'mobilenet'],
)
parser.add_argument('--cfg', type=str, default='None')

# --- dataset config ---
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')

# --- runtime config ---
parser.add_argument('--gpu', default='0', help='gpu available')
parser.add_argument('--test_batch_size', type=int, default=250)
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument('--local_rank', type=int, default=0)
28 |
if __name__ == '__main__':
    args = parser.parse_args()

    # Distributed setting. init_method='env://' reads MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment (torch.distributed.launch sets them).
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    # Every parsed option (plus world_size) is forwarded to the run config.
    run_config = TrainRunConfig(**args.__dict__)
    if args.local_rank == 0:
        print('Run config:')
        for k, v in args.__dict__.items():
            print('\t%s: %s' % (k, v))

    # Only VGG is implemented by this script, although --model accepts more
    # choices; fail clearly instead of hitting a NameError on `net` below.
    if args.model != 'vgg':
        raise ValueError('model %s is not supported by this evaluation script' % args.model)
    if args.dataset != 'cifar10':
        raise ValueError('vgg only supports cifar10 dataset')
    # NOTE(review): eval() executes arbitrary Python taken from --cfg. It is kept
    # because cfg strings may be written as expressions (e.g. "[64]*2"), which
    # ast.literal_eval would reject; never pass an untrusted value here.
    net = VGG_CIFAR(
        num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))

    # build run manager
    run_manager = RunManager(args.path, net, run_config)

    # Load the best checkpoint saved under <path>/checkpoint/.
    best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
    if not os.path.isfile(best_model_path):
        raise FileNotFoundError('wrong path: %s' % best_model_path)
    if torch.cuda.is_available():
        checkpoint = torch.load(best_model_path)
    else:
        checkpoint = torch.load(best_model_path, map_location='cpu')
    # The file may hold a dict with a 'state_dict' entry or the state dict itself.
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    run_manager.net.load_state_dict(checkpoint)

    # test
    if args.local_rank == 0:
        print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict = {
        'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5,
    }
    if args.local_rank == 0:
        print(output_dict)
80 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ITPruner: An Information Theory-inspired Strategy for Automatic Network Pruning
2 |
3 | This repository contains all the experiments from our paper "An Information Theory-inspired Strategy for Automatic Network Pruning". It also links to the [pretrained models](https://drive.google.com/drive/folders/12B741haMD8qGV6x2UMe0yjnIYOTuIokU?usp=sharing) listed in the paper.
4 |
5 | # Requirements
6 |
7 | * [DALI](https://github.com/NVIDIA/DALI)
8 | * [Apex](https://github.com/NVIDIA/apex)
9 | * [torchprofile](https://github.com/zhijian-liu/torchprofile)
10 | * other requirements, installed from requirements.txt:
11 |
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 |
17 |
18 | # Running
19 |
20 | **Feature extraction**
21 |
22 | Download the [base models](https://drive.google.com/drive/folders/1jAXw4pEFmaw6fSgqNGRhhvtsM-ZUhZsr?usp=sharing) and pass their directory to `--path`.
23 |
24 | ```bash
25 | # [optional] cache the ImageNet dataset in RAM to accelerate I/O
26 | code_path='/ITPruner/Imagenet/'
27 | chmod +x ${code_path}/prep_imagenet.sh
28 | cd ${code_path}
29 | echo "preparing data"
30 | bash ${code_path}/prep_imagenet.sh >> /dev/null
31 | echo "preparing data finished"
32 |
33 | python3 -m torch.distributed.launch --nproc_per_node=1 feature_extract.py \
34 | --model "mobilenet" \
35 | --path "Exp_base/mobilenet_base" \
36 | --dataset "imagenet" \
37 | --save_path 'your_data_path' \
38 | --target_flops 150000000 \
39 | --beta 243
40 | ```
41 |
42 | or
43 |
44 | ```bash
45 | bash ./run_scripts/feature_extract/extract_mobilenet_150m.sh
46 | ```
47 |
48 |
49 |
50 | **Train**
51 |
52 | Because of the random seed, the cfg obtained through feature extraction may differ slightly from ours. Our cfgs are given in the .sh files.
53 |
54 | ```bash
55 | # [optional] cache the ImageNet dataset in RAM to accelerate I/O
56 | code_path='/ITPruner/Imagenet/'
57 | chmod +x ${code_path}/prep_imagenet.sh
58 | cd ${code_path}
59 | echo "preparing data"
60 | bash ${code_path}/prep_imagenet.sh >> /dev/null
61 | echo "preparing data finished"
62 |
63 | python3 -m torch.distributed.launch --nproc_per_node=4 train.py --train \
64 | --model "mobilenet" \
65 | --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]" \
66 | --path "Exp_train/train_mobilenet_150m_${RANDOM}" \
67 | --dataset "imagenet" \
68 | --save_path 'your_data_path' \
69 | --base_path "Exp_base/mobilenet_base" \
70 | --warm_epoch 1 \
71 | --sync_bn \
72 | --n_epochs 250 \
73 | --label_smoothing 0.1
74 | ```
75 |
76 | or
77 |
78 | ```bash
79 | bash ./run_scripts/train/train_mobilenet_150m.sh
80 | ```
81 |
82 |
83 |
84 | **Evaluate**
85 |
86 | We provide the [pretrained models](https://drive.google.com/drive/folders/12B741haMD8qGV6x2UMe0yjnIYOTuIokU?usp=sharing) listed in the paper.
87 |
88 | ```bash
89 | python3 -m torch.distributed.launch --nproc_per_node=1 evaluate.py \
90 | --model "mobilenet" \
91 | --path "pretrain_models/train_mobilenet_150m_31752" \
92 | --dataset "imagenet" \
93 | --save_path 'your_data_path' \
94 | --cfg "[23, 42, 63, 67, 132, 134, 275, 194, 202, 202, 194, 277, 522, 687]"
95 | ```
96 |
97 | or
98 |
99 | ```bash
100 | bash ./run_scripts/evaluate/evaluate_mobilenet_150m.sh
101 | ```
102 |
103 |
--------------------------------------------------------------------------------
/Imagenet/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | import torch
7 | from run_manager import RunManager
8 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig
9 |
parser = argparse.ArgumentParser()

# --- model config ---
parser.add_argument('--path', type=str)
parser.add_argument(
    '--model',
    type=str,
    default='vgg',
    choices=['vgg', 'resnet56', 'resnet110', 'resnet18', 'resnet34',
             'resnet50', 'mobilenetv2', 'mobilenet'],
)
parser.add_argument('--cfg', type=str, default='None')

# --- dataset config ---
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')

# --- runtime config ---
parser.add_argument('--gpu', default='0', help='gpu available')
parser.add_argument('--test_batch_size', type=int, default=64)
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument('--local_rank', type=int, default=0)
28 |
if __name__ == '__main__':
    args = parser.parse_args()

    # Distributed setting. init_method='env://' reads MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment (torch.distributed.launch sets them).
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    # Every parsed option (plus world_size) is forwarded to the run config.
    run_config = TrainRunConfig(**args.__dict__)
    if args.local_rank == 0:
        print('Run config:')
        for k, v in args.__dict__.items():
            print('\t%s: %s' % (k, v))

    # Only these three architectures are implemented here, although --model
    # accepts more choices (including its default, 'vgg'). Fail clearly
    # instead of hitting a NameError on `net` below.
    if args.model not in ('resnet50', 'mobilenetv2', 'mobilenet'):
        raise ValueError('model %s is not supported by this evaluation script' % args.model)
    if args.dataset != 'imagenet':
        raise ValueError('%s only supports imagenet dataset' % args.model)
    # NOTE(review): eval() executes arbitrary Python taken from --cfg. It is kept
    # because cfg strings may be written as expressions (e.g. "[64]*2"), which
    # ast.literal_eval would reject; never pass an untrusted value here.
    cfg = eval(args.cfg)
    n_classes = run_config.data_provider.n_classes
    if args.model == 'resnet50':
        net = ResNet_ImageNet(depth=50, num_classes=n_classes, cfg=cfg)
    elif args.model == 'mobilenetv2':
        net = MobileNetV2(num_classes=n_classes, cfg=cfg)
    else:
        net = MobileNet(num_classes=n_classes, cfg=cfg)

    # build run manager
    run_manager = RunManager(args.path, net, run_config)

    # Load the best checkpoint saved under <path>/checkpoint/.
    best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
    if not os.path.isfile(best_model_path):
        raise FileNotFoundError('wrong path: %s' % best_model_path)
    if torch.cuda.is_available():
        checkpoint = torch.load(best_model_path)
    else:
        checkpoint = torch.load(best_model_path, map_location='cpu')
    # The file may hold a dict with a 'state_dict' entry or the state dict itself.
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    run_manager.net.load_state_dict(checkpoint)

    # test
    if args.local_rank == 0:
        print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict = {
        'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5,
    }
    if args.local_rank == 0:
        print(output_dict)
88 |
--------------------------------------------------------------------------------
/ToyExample/models/__init__.py:
--------------------------------------------------------------------------------
1 | from run_manager import RunConfig
2 | from .base_models import *
3 | from .vgg import *
4 |
5 |
class TrainRunConfig(RunConfig):
    """Run configuration for training or evaluating a network with a fixed cfg.

    The optimisation, schedule and runtime settings are handed to ``RunConfig``
    positionally (their order matters). ``n_worker``, ``save_path`` and
    ``base_path`` are stored on this object. Extra keyword arguments are
    accepted and only their names are printed.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True,
                 warm_epoch=5, save_path=None, base_path=None, **kwargs):
        base_args = (
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency,
            local_rank, world_size, sync_bn, warm_epoch,
        )
        super().__init__(*base_args)

        self.n_worker = n_worker
        self.save_path = save_path
        self.base_path = base_path

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return dict(
            train_batch_size=self.train_batch_size,
            test_batch_size=self.test_batch_size,
            n_worker=self.n_worker,
            local_rank=self.local_rank,
            world_size=self.world_size,
            save_path=self.save_path,
        )
37 |
class SearchRunConfig(RunConfig):
    """Run configuration for the cfg search stage.

    Besides the settings forwarded positionally to ``RunConfig`` (order
    matters), it stores the data-loading options (``n_worker``,
    ``save_path``) and the search parameters (``div``, ``search_epoch``,
    ``target_flops``, ``n_remove``, ``n_best``, ``n_populations``,
    ``n_generations``). Extra keyword arguments are accepted and only their
    names are printed.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None,
                 search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs):
        base_args = (
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency,
            local_rank, world_size, sync_bn, warm_epoch,
        )
        super().__init__(*base_args)

        self.div = div
        self.n_worker = n_worker
        self.save_path = save_path
        # Search hyper-parameters.
        self.search_epoch = search_epoch
        self.target_flops = target_flops
        self.n_remove = n_remove
        self.n_best = n_best
        self.n_populations = n_populations
        self.n_generations = n_generations

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return dict(
            train_batch_size=self.train_batch_size,
            test_batch_size=self.test_batch_size,
            n_worker=self.n_worker,
            local_rank=self.local_rank,
            world_size=self.world_size,
            save_path=self.save_path,
        )
--------------------------------------------------------------------------------
/Imagenet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from run_manager import RunConfig
2 |
3 | from .base_models import *
4 | from .resnet_imagenet import *
5 | from .mobilenet_imagenet import *
6 |
7 |
class TrainRunConfig(RunConfig):
    """Run configuration for training or evaluating a network with a fixed cfg.

    The optimisation, schedule and runtime settings are handed to ``RunConfig``
    positionally (their order matters). ``n_worker``, ``save_path`` and
    ``base_path`` are stored on this object. Extra keyword arguments are
    accepted and only their names are printed.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True,
                 warm_epoch=5, save_path=None, base_path=None, **kwargs):
        base_args = (
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency,
            local_rank, world_size, sync_bn, warm_epoch,
        )
        super().__init__(*base_args)

        self.n_worker = n_worker
        self.save_path = save_path
        self.base_path = base_path

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return dict(
            train_batch_size=self.train_batch_size,
            test_batch_size=self.test_batch_size,
            n_worker=self.n_worker,
            local_rank=self.local_rank,
            world_size=self.world_size,
            save_path=self.save_path,
        )
39 |
class SearchRunConfig(RunConfig):
    """Run configuration for the cfg search stage.

    Besides the settings forwarded positionally to ``RunConfig`` (order
    matters), it stores the data-loading options (``n_worker``,
    ``save_path``) and the search parameters (``div``, ``search_epoch``,
    ``target_flops``, ``n_remove``, ``n_best``, ``n_populations``,
    ``n_generations``). Extra keyword arguments are accepted and only their
    names are printed.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None,
                 search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8, **kwargs):
        base_args = (
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency,
            local_rank, world_size, sync_bn, warm_epoch,
        )
        super().__init__(*base_args)

        self.div = div
        self.n_worker = n_worker
        self.save_path = save_path
        # Search hyper-parameters.
        self.search_epoch = search_epoch
        self.target_flops = target_flops
        self.n_remove = n_remove
        self.n_best = n_best
        self.n_populations = n_populations
        self.n_generations = n_generations

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return dict(
            train_batch_size=self.train_batch_size,
            test_batch_size=self.test_batch_size,
            n_worker=self.n_worker,
            local_rank=self.local_rank,
            world_size=self.world_size,
            save_path=self.save_path,
        )
--------------------------------------------------------------------------------
/Cifar/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from run_manager import RunConfig
2 | from .resnet_cifar import *
3 |
4 |
class TrainRunConfig(RunConfig):
    """Run configuration for training or evaluating a CIFAR network with a fixed cfg.

    The optimisation, schedule and runtime settings are handed to ``RunConfig``
    positionally (their order matters). ``n_worker``, ``save_path``,
    ``base_path`` and ``manual_seed`` are stored on this object, and
    ``manual_seed`` is also exposed through ``data_config``. Extra keyword
    arguments are accepted and only their names are printed.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True,
                 warm_epoch=5, save_path=None, base_path=None, manual_seed=12, **kwargs):
        base_args = (
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency,
            local_rank, world_size, sync_bn, warm_epoch,
        )
        super().__init__(*base_args)

        self.n_worker = n_worker
        self.save_path = save_path
        self.base_path = base_path
        self.manual_seed = manual_seed

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return dict(
            train_batch_size=self.train_batch_size,
            test_batch_size=self.test_batch_size,
            n_worker=self.n_worker,
            local_rank=self.local_rank,
            world_size=self.world_size,
            save_path=self.save_path,
            manual_seed=self.manual_seed,
        )
38 |
39 |
class SearchRunConfig(RunConfig):
    """Run configuration for the CIFAR cfg search stage.

    Besides the settings forwarded positionally to ``RunConfig`` (order
    matters), it stores the data-loading options (``n_worker``, ``save_path``,
    ``manual_seed``) and the search parameters (``div``, ``search_epoch``,
    ``target_flops``, ``n_remove``, ``n_best``, ``n_populations``,
    ``n_generations``). Extra keyword arguments are accepted and only their
    names are printed.

    ``manual_seed`` is accepted and forwarded through ``data_config``, just as
    in ``TrainRunConfig`` of this module. Previously it was swallowed by
    ``**kwargs``, so the data provider always fell back to its own default
    seed. The default of 12 keeps existing behaviour for callers that do not
    pass it.
    """

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys='bn',
                 model_init='he_fout', init_div_groups=False, validation_frequency=1, print_frequency=10,
                 n_worker=32, local_rank=0, world_size=1, sync_bn=True, warm_epoch=5, save_path=None,
                 search_epoch=10, target_flops=1000, n_remove=2, n_best=3, n_populations=8, n_generations=25, div=8,
                 manual_seed=12, **kwargs):
        super(SearchRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            model_init, init_div_groups, validation_frequency, print_frequency, local_rank, world_size, sync_bn,
            warm_epoch
        )

        self.div = div
        self.n_worker = n_worker
        self.save_path = save_path
        self.manual_seed = manual_seed
        # Search hyper-parameters.
        self.search_epoch = search_epoch
        self.target_flops = target_flops
        self.n_remove = n_remove
        self.n_best = n_best
        self.n_populations = n_populations
        self.n_generations = n_generations

        # Show the names of keyword arguments this signature did not consume.
        print(kwargs.keys())

    @property
    def data_config(self):
        """Settings for the data provider (presumably passed as keyword arguments -- see RunConfig)."""
        return {
            'train_batch_size': self.train_batch_size,
            'test_batch_size': self.test_batch_size,
            'n_worker': self.n_worker,
            'local_rank': self.local_rank,
            'world_size': self.world_size,
            'save_path': self.save_path,
            'manual_seed': self.manual_seed,
        }
77 |
--------------------------------------------------------------------------------
/Imagenet/cal_FLOPs.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | import torch
3 | import torchvision
4 |
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import torchvision.models as models
8 |
9 | import numpy as np
10 |
def print_model_parm_flops(one_shot_model, input_res=224, multiply_adds=False):
    """Estimate the forward-pass cost of *one_shot_model* in millions of FLOPs.

    Forward hooks are attached to every leaf Conv2d, Linear, BatchNorm2d, ReLU,
    MaxPool2d and AvgPool2d module. A single random ``1 x 3 x input_res x
    input_res`` image is then pushed through the model and the per-layer
    counts are summed. The hooks are always removed afterwards. The old
    version left them attached, so every later forward pass kept appending
    to the internal lists.

    Counting rules, with ``m = 2 if multiply_adds else 1``:
      * Conv2d: out_ch * (k_h * k_w * in_ch / groups * m + has_bias) * out_h * out_w
      * Linear: weight elements * m + bias elements (0 when bias=False)
      * BatchNorm2d / ReLU: number of input elements
      * Max/AvgPool2d: out_ch * k_h * k_w * out_h * out_w (int or tuple kernel)
    Other module types (e.g. nn.ReLU6, adaptive pooling) are not counted.

    Args:
        one_shot_model: the ``nn.Module`` to profile. The forward pass runs in
            the model's current train/eval mode, so in train mode BatchNorm
            running statistics are updated by the random input.
        input_res: side length of the square input image.
        multiply_adds: count each multiply-add as two operations.

    Returns:
        float: total FLOPs divided by 1e6.
    """
    mult = 2 if multiply_adds else 1

    def _pair(value):
        # Pooling kernel_size may be an int or an (h, w) tuple.
        if isinstance(value, (tuple, list)):
            return value[0], value[1]
        return value, value

    list_conv, list_linear, list_bn, list_relu, list_pooling = [], [], [], [], []

    def conv_hook(self, input, output):
        batch_size = input[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * mult
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        list_conv.append(batch_size * params * output_height * output_width)

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * mult
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size = input[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_h, kernel_w = _pair(self.kernel_size)
        params = output_channels * (kernel_h * kernel_w)
        list_pooling.append(batch_size * params * output_height * output_width)

    handles = []

    def _register(net):
        # Hooks go on leaf modules only; containers are recursed into.
        children = list(net.children())
        if children:
            for child in children:
                _register(child)
            return
        if isinstance(net, torch.nn.Conv2d):
            handles.append(net.register_forward_hook(conv_hook))
        if isinstance(net, torch.nn.Linear):
            handles.append(net.register_forward_hook(linear_hook))
        if isinstance(net, torch.nn.BatchNorm2d):
            handles.append(net.register_forward_hook(bn_hook))
        if isinstance(net, torch.nn.ReLU):
            handles.append(net.register_forward_hook(relu_hook))
        if isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
            handles.append(net.register_forward_hook(pooling_hook))

    _register(one_shot_model)
    try:
        # Sample on the CPU (same RNG draw as the former torch.rand(...).cuda()),
        # then move the input to the model's device instead of assuming CUDA.
        first_param = next(one_shot_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        sample = torch.rand(3, input_res, input_res).unsqueeze(0).to(device)
        with torch.no_grad():
            one_shot_model(sample)
    finally:
        for handle in handles:
            handle.remove()

    total_flops = sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)
    return total_flops / 1e6
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/ToyExample/cal_FLOPs.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | import torch
3 | import torchvision
4 |
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import torchvision.models as models
8 |
9 | import numpy as np
10 |
def print_model_parm_flops(one_shot_model, input_res=224, multiply_adds=False):
    """Estimate the forward-pass FLOPs of ``one_shot_model`` in millions.

    Forward hooks are attached to every leaf ``Conv2d``, ``Linear``,
    ``BatchNorm2d``, ``ReLU``, ``MaxPool2d`` and ``AvgPool2d`` module, one
    random ``(1, 3, input_res, input_res)`` input is pushed through the model,
    and the per-layer operation counts recorded by the hooks are summed.
    The hooks are removed again and the model's train/eval mode is restored
    before returning, so later forward passes are unaffected.

    Args:
        one_shot_model: the ``nn.Module`` to profile.
        input_res: spatial size of the square dummy input (default 224).
        multiply_adds: if True, conv and linear multiply-adds count as two
            operations each.

    Returns:
        Total estimated FLOPs divided by 1e6 (float).
    """
    list_conv = []
    list_linear = []
    list_bn = []
    list_relu = []
    list_pooling = []
    mac_factor = 2 if multiply_adds else 1

    def conv_hook(self, input, output):
        batch_size = input[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = (self.kernel_size[0] * self.kernel_size[1]
                      * (self.in_channels // self.groups) * mac_factor)
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        list_conv.append(batch_size * params * output_height * output_width)

    def linear_hook(self, input, output):
        # Only a 2-D input is treated as carrying a batch dimension.
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * mac_factor
        # nn.Linear(..., bias=False) stores bias as None.
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size = input[0].size(0)
        output_channels, output_height, output_width = output[0].size()
        # kernel_size may be stored as an int or as an (h, w) tuple.
        kernel = self.kernel_size
        if isinstance(kernel, (tuple, list)):
            kernel_ops = kernel[0] * kernel[1]
        else:
            kernel_ops = kernel * kernel
        params = output_channels * kernel_ops
        list_pooling.append(batch_size * params * output_height * output_width)

    hook_for_type = (
        (torch.nn.Conv2d, conv_hook),
        (torch.nn.Linear, linear_hook),
        (torch.nn.BatchNorm2d, bn_hook),
        (torch.nn.ReLU, relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
    )
    handles = []
    for module in one_shot_model.modules():
        if next(module.children(), None) is not None:
            continue  # only leaf modules are profiled
        for module_type, hook in hook_for_type:
            if isinstance(module, module_type):
                handles.append(module.register_forward_hook(hook))

    # Build the dummy input on the model's own device instead of forcing CUDA.
    try:
        device = next(one_shot_model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    was_training = one_shot_model.training
    # eval() keeps BatchNorm running statistics from being updated by the
    # random input; no_grad() avoids building an autograd graph.
    one_shot_model.eval()
    try:
        with torch.no_grad():
            one_shot_model(torch.rand(1, 3, input_res, input_res, device=device))
    finally:
        for handle in handles:
            handle.remove()
        one_shot_model.train(was_training)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn)
                   + sum(list_relu) + sum(list_pooling))
    return total_flops / 1e6
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/Cifar/nets/base_models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from utils import count_parameters
5 |
class MyNetwork(nn.Module):
    """Base class for the prunable networks in this project.

    Subclasses implement ``forward``, ``feature_extract``, ``config``,
    ``cfg2params`` and ``cfg2flops``. This class provides shared helpers for
    BatchNorm hyper-parameters, weight initialisation and selecting
    parameters by name.
    """

    def forward(self, x):
        raise NotImplementedError

    def feature_extract(self, x):
        raise NotImplementedError

    @property
    def config(self):  # should include name/cfg/cfg_base/dataset
        raise NotImplementedError

    def cfg2params(self, cfg):
        raise NotImplementedError

    def cfg2flops(self, cfg):
        raise NotImplementedError

    def set_bn_param(self, momentum, eps):
        """Set ``momentum`` and ``eps`` on every BatchNorm1d/BatchNorm2d module."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = momentum
                m.eps = eps

    def get_bn_param(self):
        """Return ``{'momentum', 'eps'}`` of the first BatchNorm1d/2d module, or None."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                return {
                    'momentum': m.momentum,
                    'eps': m.eps,
                }
        return None

    def init_model(self, model_init, init_div_groups=False):
        """Initialise Conv2d, BatchNorm and Linear parameters in place.

        Args:
            model_init: scheme for Conv2d weights. 'he_fout' / 'he_fin' draw
                from N(0, sqrt(2 / n)) with n = kernel area times the
                output / input channel count; 'xavier_normal' and
                'xavier_uniform' use the matching ``nn.init`` functions.
            init_div_groups: for the He schemes, divide n by ``groups``.

        Raises:
            NotImplementedError: ``model_init`` is unknown (only detected when
                the model contains a Conv2d).

        BatchNorm weights are set to 1 and biases to 0; layers created with
        ``affine=False`` have no such parameters and are skipped. Linear
        weights are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features))
        and their biases zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if model_init in ('he_fout', 'he_fin'):
                    fan = m.out_channels if model_init == 'he_fout' else m.in_channels
                    n = m.kernel_size[0] * m.kernel_size[1] * fan
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif model_init == 'xavier_normal':
                    nn.init.xavier_normal_(m.weight.data)
                elif model_init == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data)
                else:
                    raise NotImplementedError('unknown model_init: %s' % model_init)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None when the layer was built with affine=False.
                if m.affine:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()

    def get_parameters(self, keys=None, mode='include'):
        """Yield parameters selected by substring match on their names.

        Args:
            keys: iterable of substrings; None yields every parameter.
            mode: 'include' yields parameters whose name contains any key;
                'exclude' yields those whose name contains none of them.

        Raises:
            ValueError: unsupported ``mode`` (raised on first iteration).
        """
        if keys is None:
            for name, param in self.named_parameters():
                yield param
        elif mode == 'include':
            for name, param in self.named_parameters():
                if any(key in name for key in keys):
                    yield param
        elif mode == 'exclude':
            for name, param in self.named_parameters():
                if not any(key in name for key in keys):
                    yield param
        else:
            raise ValueError('do not support: %s' % mode)

    def weight_parameters(self):
        """Return a generator over all parameters of the network."""
        return self.get_parameters()
97 |
--------------------------------------------------------------------------------
/Imagenet/models/base_models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from utils import count_parameters
5 |
class MyNetwork(nn.Module):
    """Base class for the prunable networks in this project.

    Subclasses implement ``forward``, ``feature_extract``, ``config``,
    ``cfg2params`` and ``cfg2flops``. This class provides shared helpers for
    BatchNorm hyper-parameters, weight initialisation and selecting
    parameters by name.
    """

    def forward(self, x):
        raise NotImplementedError

    def feature_extract(self, x):
        raise NotImplementedError

    @property
    def config(self):  # should include name/cfg/cfg_base/dataset
        raise NotImplementedError

    def cfg2params(self, cfg):
        raise NotImplementedError

    def cfg2flops(self, cfg):
        raise NotImplementedError

    def set_bn_param(self, momentum, eps):
        """Set ``momentum`` and ``eps`` on every BatchNorm1d/BatchNorm2d module."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = momentum
                m.eps = eps

    def get_bn_param(self):
        """Return ``{'momentum', 'eps'}`` of the first BatchNorm1d/2d module, or None."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                return {
                    'momentum': m.momentum,
                    'eps': m.eps,
                }
        return None

    def init_model(self, model_init, init_div_groups=False):
        """Initialise Conv2d, BatchNorm and Linear parameters in place.

        Args:
            model_init: scheme for Conv2d weights. 'he_fout' / 'he_fin' draw
                from N(0, sqrt(2 / n)) with n = kernel area times the
                output / input channel count; 'xavier_normal' and
                'xavier_uniform' use the matching ``nn.init`` functions.
            init_div_groups: for the He schemes, divide n by ``groups``.

        Raises:
            NotImplementedError: ``model_init`` is unknown (only detected when
                the model contains a Conv2d).

        BatchNorm weights are set to 1 and biases to 0; layers created with
        ``affine=False`` have no such parameters and are skipped. Linear
        weights are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features))
        and their biases zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if model_init in ('he_fout', 'he_fin'):
                    fan = m.out_channels if model_init == 'he_fout' else m.in_channels
                    n = m.kernel_size[0] * m.kernel_size[1] * fan
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif model_init == 'xavier_normal':
                    nn.init.xavier_normal_(m.weight.data)
                elif model_init == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data)
                else:
                    raise NotImplementedError('unknown model_init: %s' % model_init)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None when the layer was built with affine=False.
                if m.affine:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()

    def get_parameters(self, keys=None, mode='include'):
        """Yield parameters selected by substring match on their names.

        Args:
            keys: iterable of substrings; None yields every parameter.
            mode: 'include' yields parameters whose name contains any key;
                'exclude' yields those whose name contains none of them.

        Raises:
            ValueError: unsupported ``mode`` (raised on first iteration).
        """
        if keys is None:
            for name, param in self.named_parameters():
                yield param
        elif mode == 'include':
            for name, param in self.named_parameters():
                if any(key in name for key in keys):
                    yield param
        elif mode == 'exclude':
            for name, param in self.named_parameters():
                if not any(key in name for key in keys):
                    yield param
        else:
            raise ValueError('do not support: %s' % mode)

    def weight_parameters(self):
        """Return a generator over all parameters of the network."""
        return self.get_parameters()
97 |
--------------------------------------------------------------------------------
/ToyExample/models/base_models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from utils import count_parameters
5 |
class MyNetwork(nn.Module):
    """Base class for the prunable networks in this project.

    Subclasses implement ``forward``, ``feature_extract``, ``config``,
    ``cfg2params`` and ``cfg2flops``. This class provides shared helpers for
    BatchNorm hyper-parameters, weight initialisation and selecting
    parameters by name.
    """

    def forward(self, x):
        raise NotImplementedError

    def feature_extract(self, x):
        raise NotImplementedError

    @property
    def config(self):  # should include name/cfg/cfg_base/dataset
        raise NotImplementedError

    def cfg2params(self, cfg):
        raise NotImplementedError

    def cfg2flops(self, cfg):
        raise NotImplementedError

    def set_bn_param(self, momentum, eps):
        """Set ``momentum`` and ``eps`` on every BatchNorm1d/BatchNorm2d module."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = momentum
                m.eps = eps

    def get_bn_param(self):
        """Return ``{'momentum', 'eps'}`` of the first BatchNorm1d/2d module, or None."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                return {
                    'momentum': m.momentum,
                    'eps': m.eps,
                }
        return None

    def init_model(self, model_init, init_div_groups=False):
        """Initialise Conv2d, BatchNorm and Linear parameters in place.

        Args:
            model_init: scheme for Conv2d weights. 'he_fout' / 'he_fin' draw
                from N(0, sqrt(2 / n)) with n = kernel area times the
                output / input channel count; 'xavier_normal' and
                'xavier_uniform' use the matching ``nn.init`` functions.
            init_div_groups: for the He schemes, divide n by ``groups``.

        Raises:
            NotImplementedError: ``model_init`` is unknown (only detected when
                the model contains a Conv2d).

        BatchNorm weights are set to 1 and biases to 0; layers created with
        ``affine=False`` have no such parameters and are skipped. Linear
        weights are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features))
        and their biases zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if model_init in ('he_fout', 'he_fin'):
                    fan = m.out_channels if model_init == 'he_fout' else m.in_channels
                    n = m.kernel_size[0] * m.kernel_size[1] * fan
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif model_init == 'xavier_normal':
                    nn.init.xavier_normal_(m.weight.data)
                elif model_init == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data)
                else:
                    raise NotImplementedError('unknown model_init: %s' % model_init)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None when the layer was built with affine=False.
                if m.affine:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()

    def get_parameters(self, keys=None, mode='include'):
        """Yield parameters selected by substring match on their names.

        Args:
            keys: iterable of substrings; None yields every parameter.
            mode: 'include' yields parameters whose name contains any key;
                'exclude' yields those whose name contains none of them.

        Raises:
            ValueError: unsupported ``mode`` (raised on first iteration).
        """
        if keys is None:
            for name, param in self.named_parameters():
                yield param
        elif mode == 'include':
            for name, param in self.named_parameters():
                if any(key in name for key in keys):
                    yield param
        elif mode == 'exclude':
            for name, param in self.named_parameters():
                if not any(key in name for key in keys):
                    yield param
        else:
            raise ValueError('do not support: %s' % mode)

    def weight_parameters(self):
        """Return a generator over all parameters of the network."""
        return self.get_parameters()
97 |
--------------------------------------------------------------------------------
/Imagenet/data_providers/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import numpy as np
4 | import torch.utils.data
5 |
6 | import utils
7 | from data_providers import DataProvider
8 |
9 |
def make_imagenet_subset(path2subset, n_sub_classes, path2imagenet='/userhome/memory_data/imagenet',
                         seed=None):
    """Copy a random subset of ImageNet classes into ``path2subset``.

    ``n_sub_classes`` class folders are drawn at random from
    ``<path2imagenet>/train``; each chosen class folder is copied from both
    ``<path2imagenet>/train`` and ``<path2imagenet>/val`` to
    ``<path2subset>/train/<class>`` and ``<path2subset>/val/<class>``.

    Args:
        path2subset: destination root of the subset.
        n_sub_classes: number of classes to copy (fewer if not available).
        path2imagenet: root of the full dataset.
        seed: optional seed for the class draw. When None (default) the
            global NumPy RNG is used, as before.

    Raises:
        ``shutil.copytree`` raises if a target class folder already exists.
    """
    imagenet_train_folder = os.path.join(path2imagenet, 'train')
    imagenet_val_folder = os.path.join(path2imagenet, 'val')

    # Close the scandir iterator deterministically.
    with os.scandir(imagenet_train_folder) as entries:
        subfolders = sorted(entry.path for entry in entries if entry.is_dir())
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(subfolders)

    class_name_list = [os.path.basename(folder) for folder in subfolders[:n_sub_classes]]

    print('=> Start building subset%d' % n_sub_classes)
    for cls_name in class_name_list:
        src_train_folder = os.path.join(imagenet_train_folder, cls_name)
        target_train_folder = os.path.join(path2subset, 'train/%s' % cls_name)
        shutil.copytree(src_train_folder, target_train_folder)
        print('Train: %s -> %s' % (src_train_folder, target_train_folder))

        src_val_folder = os.path.join(imagenet_val_folder, cls_name)
        target_val_folder = os.path.join(path2subset, 'val/%s' % cls_name)
        shutil.copytree(src_val_folder, target_val_folder)
        print('Val: %s -> %s' % (src_val_folder, target_val_folder))
    print('=> Finish building subset%d' % n_sub_classes)
36 |
37 |
class ImagenetDataProvider(DataProvider):
    """ImageNet train/val loaders built with ``utils.get_imagenet_iter``.

    ``save_path`` (default '/userhome/data/imagenet') is expected to contain
    ``train`` and ``val`` sub-folders. ``self.valid`` is the same object as
    ``self.test``.
    """

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None,
                 n_worker=24, manual_seed=12, load_type='dali', local_rank=0, world_size=1, **kwargs):
        self._save_path = save_path
        if valid_size is not None:
            # Carving a validation split out of the training set is not
            # implemented. This branch used to be a bare `pass`, which left
            # self.train / self.test undefined; the loaders are now always
            # built and the val split doubles as the validation set.
            import warnings
            warnings.warn('ImagenetDataProvider: valid_size=%r is not supported and is ignored; '
                          'the val split is used for validation.' % (valid_size,))
        num_gpus = torch.cuda.device_count()
        self.train = utils.get_imagenet_iter(data_type='train', image_dir=self.train_path,
                                             batch_size=train_batch_size, num_threads=n_worker,
                                             device_id=local_rank, manual_seed=manual_seed,
                                             num_gpus=num_gpus, crop=self.image_size,
                                             val_size=self.image_size, world_size=world_size,
                                             local_rank=local_rank)
        self.test = utils.get_imagenet_iter(data_type='val', image_dir=self.valid_path,
                                            manual_seed=manual_seed, batch_size=test_batch_size,
                                            num_threads=n_worker, device_id=local_rank,
                                            num_gpus=num_gpus, crop=self.image_size,
                                            val_size=self.resize_value, world_size=world_size,
                                            local_rank=local_rank)
        self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.image_size, self.image_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        """Dataset root; falls back to '/userhome/data/imagenet' when unset."""
        if self._save_path is None:
            self._save_path = '/userhome/data/imagenet'
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download ImageNet')

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        # Go through the property so the default root is applied even if
        # train_path has not been accessed first.
        return os.path.join(self.save_path, 'val')

    @property
    def resize_value(self):
        return 256

    @property
    def image_size(self):
        return 224
96 |
--------------------------------------------------------------------------------
/Cifar/nets/resnet_20s_decode.py:
--------------------------------------------------------------------------------
1 | '''
2 | Decode model from alpha
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.nn.init as init
9 | from torch.autograd import Variable
10 | from utils.compute_flops import *
11 | from nets.se_module import SELayer
12 | import argparse
13 | from pdb import set_trace as br
14 |
# Public API of this module: only the factory function is exported.
__all__ = ["resnet_decode"]
16 |
17 |
18 | def _weights_init(m):
19 | classname = m.__class__.__name__
20 | # print(classname)
21 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
22 | init.kaiming_normal_(m.weight)
23 |
24 |
class DShortCut(nn.Module):
    """Projection shortcut: optional 2x2 average pool, 1x1 conv, optional BN.

    Args:
        cin: input channels of the 1x1 convolution.
        cout: output channels of the 1x1 convolution.
        has_avg: apply ``AvgPool2d(kernel_size=2, stride=2)`` before the conv.
        has_BN: apply ``BatchNorm2d(cout)`` after the conv.
        affine: forwarded to the BatchNorm layer.
    """

    def __init__(self, cin, cout, has_avg, has_BN, affine=True):
        super(DShortCut, self).__init__()
        self.conv = nn.Conv2d(cin, cout, kernel_size=1, stride=1, padding=0, bias=False)
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) if has_avg else None
        self.bn = nn.BatchNorm2d(cout, affine=affine) if has_BN else None

    def forward(self, x):
        # Compare against None explicitly rather than relying on module truthiness.
        out = self.avg(x) if self.avg is not None else x
        out = self.conv(out)
        if self.bn is not None:
            out = self.bn(out)
        return out
49 |
50 |
def ChannelWiseInterV2(inputs, oC, downsample=False):
    """Resize a 4-D feature map to ``oC`` channels by adaptive average pooling.

    Args:
        inputs: 4-D tensor whose second dimension is treated as channels.
        oC: target number of channels.
        downsample: if True, first apply a 2x2, stride-2 average pool to the
            two trailing (spatial) dimensions.

    Returns:
        ``inputs`` (after optional downsampling) unchanged when it already has
        ``oC`` channels; otherwise a tensor with ``oC`` channels obtained by
        averaging contiguous ranges of input channels.
    """
    assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size())
    num_channels = inputs.size(1)

    if downsample:
        # Functional form: no throw-away nn.AvgPool2d module per call.
        inputs = F.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0)
    if num_channels == oC:
        return inputs
    height, width = inputs.shape[2], inputs.shape[3]
    # A 4-D input is read by adaptive_avg_pool3d as an unbatched (C, D, H, W)
    # volume, so the channel axis plays the role of depth and is resized to oC.
    return F.adaptive_avg_pool3d(inputs, (oC, height, width))
62 |
63 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BN layers with configurable widths.

    Args:
        in_planes: number of input channels.
        cfg: two-element list ``[mid_planes, planes]`` with the output
            channels of the first and the second convolution.
        stride: stride of the first convolution.
        affine: forwarded to both BatchNorm layers.
        se: apply an ``SELayer`` to the residual branch when True.
        se_reduction: reduction argument for ``SELayer``; must be > 0 when
            ``se`` is True.

    The shortcut is the identity unless the block has stride 2 (average pool
    then 1x1 conv, no BN) or changes the channel count (1x1 conv then BN).
    """

    def __init__(self, in_planes, cfg, stride=1, affine=True,
                 se=False, se_reduction=-1):
        super(BasicBlock, self).__init__()
        assert len(cfg) == 2, 'wrong cfg length'
        mid_planes = cfg[0]
        planes = cfg[1]

        # Submodules are created in this exact order; it fixes both the
        # state_dict layout and the RNG consumption during initialisation.
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes, affine=affine)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine)
        self.relu2 = nn.ReLU()

        self.se = se
        self.se_reduction = se_reduction
        if se:
            assert se_reduction > 0, "Must specify se reduction > 0"
            self.se_module = SELayer(planes, se_reduction)

        is_strided = stride == 2
        if not is_strided and in_planes == planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = DShortCut(in_planes, planes, has_avg=is_strided, has_BN=not is_strided)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu1(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        if self.se:
            branch = self.se_module(branch)

        branch += self.shortcut(x)
        return self.relu2(branch)
107 |
108 |
class ResNetDecode(nn.Module):
    """CIFAR-style ResNet rebuilt from a flat channel configuration ``cfg``.

    ``cfg[0]`` is the width of the 3x3 stem convolution. The remaining
    entries are consumed two per block (``[mid_planes, planes]``) and split
    evenly across three stages with strides 1, 2 and 2, so ``len(cfg) - 1``
    must be a multiple of 6. Conv2d/Linear weights are re-initialised with
    ``_weights_init``.
    """

    def __init__(self, block, cfg, num_classes=10, affine=True, se=False, se_reduction=-1):
        super(ResNetDecode, self).__init__()
        self.cfg = cfg
        self.affine = affine
        self.se = se
        self.se_reduction = se_reduction

        # Creation order is kept as-is: it fixes the state_dict layout and
        # the RNG consumption during initialisation.
        self.conv1 = nn.Conv2d(3, self.cfg[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.cfg[0], affine=self.affine)
        self.relu1 = nn.ReLU()

        blocks_per_stage = (len(self.cfg) - 1) // 2 // 3
        span = 2 * blocks_per_stage  # cfg entries consumed by one stage
        count = 1
        # `planes` is the stage's cfg slice shifted back by one entry, so
        # planes[2k] is the output width of the preceding block (or the stem),
        # i.e. the input width of block k.
        self.layer1 = self._make_layer(block, self.cfg[count - 1:count + span - 1],
                                       self.cfg[count:count + span], stride=1)
        count += span
        self.layer2 = self._make_layer(block, self.cfg[count - 1:count + span - 1],
                                       self.cfg[count:count + span], stride=2)
        count += span
        self.layer3 = self._make_layer(block, self.cfg[count - 1:count + span - 1],
                                       self.cfg[count:count + span], stride=2)
        count += span
        assert count == len(self.cfg)
        self.linear = nn.Linear(self.cfg[-1], num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, cfg, stride):
        """Stack ``len(cfg) // 2`` blocks; only the first one uses ``stride``.

        Block k receives input width ``planes[2k]`` and configuration
        ``cfg[2k:2k + 2]``.
        """
        num_block = len(cfg) // 2
        strides = [stride] + [1] * (num_block - 1)
        layers = []
        for idx, block_stride in enumerate(strides):
            offset = 2 * idx
            layers.append(block(planes[offset], cfg[offset:offset + 2], block_stride, affine=self.affine,
                                se=self.se, se_reduction=self.se_reduction))
        assert 2 * len(strides) == len(cfg), 'cfg and block num mismatch'
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Average pool with a kernel equal to the feature-map width (size(3));
        # functional form avoids allocating a pooling module on every call.
        out = F.avg_pool2d(out, out.size(3))

        out = out.view(out.size(0), -1)
        return self.linear(out)
159 |
def resnet_decode(cfg, num_classes, se=False, se_reduction=-1):
    """Build a ``ResNetDecode`` composed of ``BasicBlock`` units.

    ``affine`` is left at ``ResNetDecode``'s default.
    """
    model = ResNetDecode(BasicBlock, cfg, num_classes=num_classes,
                         se=se, se_reduction=se_reduction)
    return model
162 |
--------------------------------------------------------------------------------
/ToyExample/feature_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math
5 | from scipy import optimize
6 | import random, sys
7 | sys.path.append("..")
8 | from CKA import cka
9 | import torch
10 | import torch.nn as nn
11 | from run_manager import RunManager
12 | from models import VGG_CIFAR, TrainRunConfig
13 |
parser = argparse.ArgumentParser()

# ---- model config ----
parser.add_argument('--path', type=str)
parser.add_argument('--model', type=str, choices=['vgg'], default="vgg")
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--manual_seed', type=int, default=0)
parser.add_argument("--target_flops", type=int, default=0)
parser.add_argument("--beta", type=int, default=1)

# ---- dataset config ----
parser.add_argument('--dataset', type=str, choices=['cifar10'], default='cifar10')
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')

# ---- runtime config ----
parser.add_argument('--gpu', default='0', help='gpu available')
parser.add_argument('--test_batch_size', type=int, default=100)
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument("--local_rank", type=int, default=0)
35 |
if __name__ == '__main__':
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    torch.cuda.set_device(0)

    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    # distributed setting
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    # prepare run config
    run_config = TrainRunConfig(
        **args.__dict__
    )
    if args.local_rank == 0:
        print('Run config:')
        for k, v in args.__dict__.items():
            print('\t%s: %s' % (k, v))

    if args.model == "vgg":
        assert args.dataset == 'cifar10', 'vgg only supports cifar10 dataset'
        # NOTE(review): eval() executes the --cfg string; only pass trusted input.
        net = VGG_CIFAR(
            num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))

    # build run manager
    run_manager = RunManager(args.path, net, run_config)

    # load checkpoints
    best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
    assert os.path.isfile(best_model_path), 'wrong path'
    if torch.cuda.is_available():
        checkpoint = torch.load(best_model_path)
    else:
        checkpoint = torch.load(best_model_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    run_manager.net.load_state_dict(checkpoint)

    # feature extract: a single batch from the test loader
    data_loader = run_manager.run_config.test_loader
    data = next(iter(data_loader))[0]
    n = data.size()[0]

    with torch.no_grad():
        feature = net.feature_extract(data)

    # Flatten every feature to (n, -1) and convert it to a NumPy array.
    for i in range(len(feature)):
        feature[i] = feature[i].view(n, -1)
        feature[i] = feature[i].data.cpu().numpy()

    # Pairwise linear-CKA similarity between all extracted features. Each Gram
    # matrix depends on a single feature, so it is computed once per feature
    # instead of twice per (i, j) pair. Copies are passed in case cka.cka
    # modifies its arguments (its implementation is not visible here).
    num_features = len(feature)
    grams = [cka.gram_linear(f) for f in feature]
    similar_matrix = np.zeros((num_features, num_features))
    for i in range(num_features):
        for j in range(num_features):
            similar_matrix[i][j] = cka.cka(np.copy(grams[i]), np.copy(grams[j]))

    def sum_list(a, j):
        """Return the sum of ``a`` without the element at index ``j``."""
        b = 0
        for i in range(len(a)):
            if i != j:
                b += a[i]
        return b

    # Per-layer score: total similarity to every other layer.
    temp = [sum_list(similar_matrix[i], i) for i in range(num_features)]

    if args.beta != 1:
        # Normalise the scores (index -1 never matches, so nothing is excluded).
        b = sum_list(temp, -1)
        temp = [x / b for x in temp]

    important = [math.exp(-1 * args.beta * t) for t in temp]

    length = len(net.cfg)
    flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length)
    # Negated so that minimising the objective maximises total importance.
    important = np.negative(np.array(important))

    # Objective function
    def func(x, sign=1.0):
        """Objective: sum over layers of x[i] * important[i]."""
        return sum(x[i] * important[i] for i in range(length))

    # Derivative function of objective function
    def func_deriv(x, sign=1.0):
        """Gradient of ``func``: the (sign-scaled) importance vector."""
        return np.array([sign * imp for imp in important])

    # Constraint function
    def constrain_func(x):
        """target_flops minus the FLOPs estimate built from cfg2flops_perlayer terms."""
        terms = []
        for i in range(length):
            terms.append(x[i] * flops_singlecfg[i])
            terms.append(flops_squarecfg[i] * x[i] * x[i])
        for i in range(1, length):
            for j in range(i):
                terms.append(x[i] * x[j] * flops_doublecfg[i][j])
        return np.array([args.target_flops - sum(terms)])

    bnds = tuple((0, 1) for _ in range(length))
    cons = {'type': 'ineq', 'fun': constrain_func}

    result = optimize.minimize(func, x0=[1 for _ in range(length)], jac=func_deriv,
                               method='SLSQP', bounds=bnds, constraints=cons)
    if not result.success:
        print('Warning: SLSQP did not converge: %s' % result.message)
    prun_cfg = np.around(np.array(net.cfg) * result.x)

    optimize_cfg = [int(v) for v in prun_cfg]
    print(optimize_cfg)
    print(net.cfg2flops(prun_cfg))
    print(net.cfg2flops(net.cfg))
180 |
181 |
--------------------------------------------------------------------------------
/ToyExample/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | import torch.nn as nn
7 | from run_manager import RunManager
8 | from models import VGG_CIFAR, TrainRunConfig
9 |
parser = argparse.ArgumentParser()

# basic config
parser.add_argument('--path', type=str, default='Exp/debug')
parser.add_argument('--gpu', help='gpu available', default='0,1,2,3')
# `--train` is store_true with default=True, so on its own it can never turn
# training off. It is kept for backward compatibility; `--no_train` writes the
# same `train` attribute and disables training (only the final test runs).
parser.add_argument('--train', action='store_true', default=True)
parser.add_argument('--no_train', dest='train', action='store_false',
                    help='skip training and only run the final test')
parser.add_argument('--manual_seed', default=13, type=int)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--validation_frequency', type=int, default=1)
parser.add_argument('--print_frequency', type=int, default=100)

# optimizer config
parser.add_argument('--n_epochs', type=int, default=300)
parser.add_argument('--init_lr', type=float, default=0.1)
parser.add_argument('--lr_schedule_type', type=str, default='cosine')
parser.add_argument('--warm_epoch', type=int, default=5)
parser.add_argument('--opt_type', type=str, default='sgd', choices=['sgd'])
parser.add_argument('--momentum', type=float, default=0.9)  # opt_param
parser.add_argument('--no_nesterov', action='store_true')  # opt_param
parser.add_argument('--weight_decay', type=float, default=4e-5)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--no_decay_keys', type=str,
                    default='bn#bias', choices=[None, 'bn', 'bn#bias'])

# dataset config
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet', 'imagenet10', 'imagenet100'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')
parser.add_argument('--train_batch_size', type=int, default=256)
parser.add_argument('--test_batch_size', type=int, default=250)

# model config
parser.add_argument('--model', type=str, default="vgg",
                    choices=['vgg', 'resnet20', 'resnet56', 'resnet110', 'resnet18', 'resnet34', 'resnet50_TAS', 'mobilenetv2_autoslim', 'resnet50', 'mobilenetv2', 'mobilenet'])
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--base_path', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--model_init', type=str,
                    default='he_fout', choices=['he_fin', 'he_fout'])
parser.add_argument('--init_div_groups', action='store_true')

# runtime config
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')
parser.add_argument("--local_rank", default=0, type=int)
56 |
57 |
if __name__ == '__main__':
    args = parser.parse_args()

    # This script does not seed random/torch/numpy itself, and cudnn.benchmark
    # is enabled, so runs are not bit-reproducible.
    torch.backends.cudnn.benchmark = True
    os.makedirs(args.path, exist_ok=True)

    # distributed setting
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    # build run config from args
    args.lr_schedule_param = None
    args.opt_param = {
        'momentum': args.momentum,
        'nesterov': not args.no_nesterov,
    }
    run_config = TrainRunConfig(
        **args.__dict__
    )

    if args.local_rank == 0:
        print('Run config:')
        for k, v in run_config.config.items():
            print('\t%s: %s' % (k, v))

    net_origin = None
    if args.model == "vgg":
        assert args.dataset == 'cifar10', 'vgg only supports cifar10 dataset'
        # NOTE(review): eval() executes the --cfg string; only pass trusted input.
        net = VGG_CIFAR(
            num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
        if args.base_path is not None:
            net_origin = nn.DataParallel(
                VGG_CIFAR(num_classes=run_config.data_provider.n_classes))

    # build run manager
    run_manager = RunManager(args.path, net, run_config)
    if args.local_rank == 0:
        run_manager.save_config(print_info=True)

    # load checkpoints
    weight_path = None
    if args.base_path is not None:
        weight_path = args.base_path + '/checkpoint/model_best.pth.tar'
    if args.resume:
        run_manager.load_model()
        if args.train and run_manager.best_acc == 0:
            loss, acc1, acc5 = run_manager.validate(
                is_test=True, return_top5=True)
            run_manager.best_acc = acc1
    elif weight_path is not None and os.path.isfile(weight_path):
        assert net_origin is not None, "original network is None"
        net_origin.load_state_dict(torch.load(weight_path)['state_dict'])
        net_origin = net_origin.module
        if args.model == 'vgg':
            run_manager.reset_model(
                VGG_CIFAR(cfg=eval(args.cfg), cutout=True), net_origin.cpu())
    else:
        print('Random initialization')

    # train
    if args.train:
        print('Start training: %d' % args.local_rank)
        if args.local_rank == 0:
            print(net)
        run_manager.train(print_top5=True)
        if args.local_rank == 0:
            run_manager.save_model()

    # test (every rank runs validation; presumably required by the
    # distributed evaluation inside RunManager)
    print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict = {
        'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5
    }
    # Only rank 0 writes the file (mirroring save_model above) so ranks do not
    # race on the same path; `with` guarantees the handle is closed.
    if args.local_rank == 0:
        with open('%s/output' % args.path, 'w') as f:
            json.dump(output_dict, f, indent=4)
149 |
--------------------------------------------------------------------------------
/Cifar/feature_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math
5 | from scipy import optimize
6 | import random, sys
7 | sys.path.append("..")
8 | from CKA import cka
9 | import torch
10 | import torch.nn as nn
11 | from run_manager import RunManager
12 | from nets import ResNet_CIFAR, TrainRunConfig
13 |
parser = argparse.ArgumentParser()

# model config
parser.add_argument('--path', type=str)
# The default must be one of ``choices``: argparse does not check defaults
# against ``choices``, and a default outside them would leave no network built.
parser.add_argument('--model', type=str, default='resnet20',
                    choices=['resnet20', 'resnet56'])
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--manual_seed', default=0, type=int)
parser.add_argument("--target_flops", default=0, type=int)
# Temperature of the exp(-beta * similarity) importance mapping. Parsed as a
# float so fractional values are accepted; integer inputs behave as before.
parser.add_argument("--beta", default=1, type=float)

# dataset config
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')

# runtime config
parser.add_argument('--gpu', help='gpu available', default='0')
parser.add_argument('--test_batch_size', type=int, default=100)
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument("--local_rank", default=0, type=int)
35 |
36 | if __name__ == '__main__':
37 | args = parser.parse_args()
38 |
39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
40 | torch.cuda.set_device(0)
41 |
42 | random.seed(args.manual_seed)
43 | torch.manual_seed(args.manual_seed)
44 | torch.cuda.manual_seed_all(args.manual_seed)
45 | np.random.seed(args.manual_seed)
46 | # distributed setting
47 | torch.distributed.init_process_group(backend='nccl',
48 | init_method='env://')
49 | args.world_size = torch.distributed.get_world_size()
50 |
51 | # prepare run config
52 | run_config_path = '%s/run.config' % args.path
53 |
54 | run_config = TrainRunConfig(
55 | **args.__dict__
56 | )
57 | if args.local_rank == 0:
58 | print('Run config:')
59 | for k, v in args.__dict__.items():
60 | print('\t%s: %s' % (k, v))
61 |
62 | if args.model == "resnet20":
63 | assert args.dataset == 'cifar10', 'resnet20 only supports cifar10 dataset'
64 | net = ResNet_CIFAR(
65 | depth=20, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
66 | elif args.model == "resnet56":
67 | assert args.dataset == 'cifar10', 'resnet56 only supports cifar10 dataset'
68 | net = ResNet_CIFAR(
69 | depth=56, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
70 |
71 | # build run manager
72 | run_manager = RunManager(args.path, net, run_config)
73 |
74 | # load checkpoints
75 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
76 | assert os.path.isfile(best_model_path), 'wrong path'
77 | if torch.cuda.is_available():
78 | checkpoint = torch.load(best_model_path)
79 | else:
80 | checkpoint = torch.load(best_model_path, map_location='cpu')
81 | if 'state_dict' in checkpoint:
82 | checkpoint = checkpoint['state_dict']
83 | run_manager.net.load_state_dict(checkpoint)
84 | output_dict = {}
85 |
86 | # feature extract
87 | data_loader = run_manager.run_config.test_loader
88 | data = next(iter(data_loader))
89 | data = data[0]
90 | n = data.size()[0]
91 |
92 | with torch.no_grad():
93 | feature = net.feature_extract(data)
94 |
95 | for i in range(len(feature)):
96 | feature[i] = feature[i].view(n, -1)
97 | feature[i] = feature[i].data.cpu().numpy()
98 |
99 | similar_matrix = np.zeros((len(feature), len(feature)))
100 |
101 | for i in range(len(feature)):
102 | for j in range(len(feature)):
103 | with torch.no_grad():
104 | similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j]))
105 |
106 | def sum_list(a, j):
107 | b = 0
108 | for i in range(len(a)):
109 | if i != j:
110 | b += a[i]
111 | return b
112 |
113 | important = []
114 | temp = []
115 | flops = []
116 |
117 | for i in range(len(feature)):
118 | temp.append( sum_list(similar_matrix[i], i) )
119 |
120 | b = sum_list(temp, -1)
121 | temp = [x/b for x in temp]
122 |
123 | for i in range(len(feature)):
124 | important.append( math.exp(-1* args.beta *temp[i] ) )
125 |
126 | length = len(net.cfg)
127 | flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length)
128 | important = np.array(important)
129 | important = np.negative(important)
130 |
131 | # Objective function
132 | def func(x, sign=1.0):
133 | """ Objective function """
134 | global important,length
135 | sum_fuc =[]
136 | for i in range(length):
137 | sum_fuc.append(x[i]*important[i])
138 | return sum(sum_fuc)
139 |
140 | # Derivative function of objective function
141 | def func_deriv(x, sign=1.0):
142 | """ Derivative of objective function """
143 | global important
144 | diff = []
145 | for i in range(len(important)):
146 | diff.append(sign * (important[i]))
147 | return np.array(diff)
148 |
149 | # Constraint function
150 | def constrain_func(x):
151 | """ constrain function """
152 | global flops_singlecfg, flops_doublecfg, flops_squarecfg, length
153 | a = []
154 | for i in range(length):
155 | a.append(x[i] * flops_singlecfg[i])
156 | a.append(flops_squarecfg[i] * x[i] * x[i])
157 | for i in range(1,length):
158 | for j in range(i):
159 | a.append(x[i] * x[j] * flops_doublecfg[i][j])
160 | return np.array([args.target_flops - sum(a)])
161 |
162 |
163 | bnds = []
164 | for i in range(length):
165 | bnds.append((0,1))
166 |
167 | bnds = tuple(bnds)
168 | cons = ({'type': 'ineq',
169 | 'fun': constrain_func})
170 |
171 | result = optimize.minimize(func,x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons)
172 | prun_cfg = np.around(np.array(net.cfg)*result.x)
173 |
174 | optimize_cfg = []
175 | for i in range(len(prun_cfg)):
176 | b = list(prun_cfg)[i].tolist()
177 | optimize_cfg.append(int(b))
178 | print(optimize_cfg)
179 | print(net.cfg2flops(prun_cfg))
180 | print(net.cfg2flops(net.cfg))
181 |
182 |
183 |
--------------------------------------------------------------------------------
/Cifar/utils/model_profiling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import torch
4 | import torch.nn as nn
5 |
6 |
# Registries of forward-hook handles, so the hooks can be detached again.
model_profiling_hooks = []
model_profiling_speed_hooks = []

# Column widths of the printed profiling table.
name_space, params_space, macs_space, seconds_space = 95, 15, 15, 15

# How many forward passes run_forward averages over.
num_forwards = 10
17 |
class Timer(object):
    """Context manager that measures the elapsed time of its body.

    After the ``with`` block exits, ``time`` holds the elapsed time in
    seconds. With ``verbose=True`` it is also printed (in milliseconds).
    """

    def __init__(self, verbose=False):
        self.verbose = verbose
        self.start = None
        self.end = None
        self.time = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so intervals are unaffected by wall-clock adjustments.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.time = self.end - self.start
        if self.verbose:
            # ``self.time`` is in seconds; convert so the "ms" label is true.
            print('Elapsed time: %f ms.' % (self.time * 1e3))
34 |
def get_params(self):
    """Return the number of scalar parameters of module ``self`` as an ``int``.

    Parameter-free modules yield ``0`` (an int, not ``0.0``).
    """
    return sum(p.numel() for p in self.parameters())
40 |
def run_forward(self, input):
    """Time ``num_forwards`` calls of ``self.forward(*input)``.

    ``self.forward`` is called directly rather than ``self(...)``, so
    registered forward hooks do not fire again.

    Returns:
        Mean time per forward pass in nanoseconds (int).
    """
    # CUDA kernels are asynchronous: synchronise before starting and before
    # stopping the timer so queued work lands in the right interval. Only do
    # so for GPU inputs; torch.cuda.synchronize() fails without CUDA.
    on_cuda = any(torch.is_tensor(x) and x.is_cuda for x in input)
    if on_cuda:
        torch.cuda.synchronize()
    # Timing only: no autograd graph is needed.
    with torch.no_grad():
        with Timer() as t:
            for _ in range(num_forwards):
                self.forward(*input)
            if on_cuda:
                torch.cuda.synchronize()
    return int(t.time * 1e9 / num_forwards)
48 |
def conv_module_name_filter(name):
    """Shorten conv keyword names in a module repr for a compact table row.

    Replacements are applied in a fixed order: kernel_size, stride, padding,
    bias, groups.
    """
    abbreviations = (
        ('kernel_size', 'k'),
        ('stride', 's'),
        ('padding', 'pad'),
        ('bias', 'b'),
        ('groups', 'g'),
    )
    short = name
    for long_form, short_form in abbreviations:
        short = short.replace(long_form, short_form)
    return short
62 |
def module_profiling(self, input, output, verbose):
    """Forward hook that records profiling statistics on module ``self``.

    Sets ``self.n_macs``, ``self.n_params`` and ``self.n_seconds`` (mean
    nanoseconds per forward, from ``run_forward``).

    - Conv2d, ConvTranspose2d, Linear, AvgPool2d and AdaptiveAvgPool2d are
      measured directly. When ``verbose`` is set, a table row is printed.
    - Any other module sums the values already stored on its direct children.
      This relies on the children's hooks having fired first (depth-first).

    Args:
        self: the module the hook fired on.
        input: tuple of positional inputs passed to the module.
        output: the module's output.
        verbose: print a row for each directly measured module.
    """
    if isinstance(self, (nn.Conv2d, nn.ConvTranspose2d)):
        ins = input[0].size()
        outs = output.size()
        if isinstance(self, nn.ConvTranspose2d):
            # A transposed conv scatters every *input* pixel through its
            # kernel, so the spatial factor is the input size. Using the
            # (upsampled) output size over-counts whenever stride > 1.
            spatial = ins[2] * ins[3]
        else:
            spatial = outs[2] * outs[3]
        self.n_macs = (ins[1] * outs[1] *
                       self.kernel_size[0] * self.kernel_size[1] *
                       spatial // self.groups) * outs[0]
        self.n_params = get_params(self)
        self.name = conv_module_name_filter(self.__repr__())
    elif isinstance(self, nn.Linear):
        # in_features MACs per output element. This equals
        # ins[1] * outs[1] * outs[0] for 2-D inputs and stays correct when
        # the input has extra leading dimensions.
        self.n_macs = self.in_features * output.numel()
        self.n_params = get_params(self)
        self.name = self.__repr__()
    elif isinstance(self, (nn.AvgPool2d, nn.AdaptiveAvgPool2d)):
        # NOTE: this count is correct only when stride == kernel size.
        ins = input[0].size()
        self.n_macs = ins[1] * ins[2] * ins[3] * ins[0]
        self.n_params = 0
        self.name = self.__repr__()
    else:
        # Container or unsupported module: aggregate the children's stats.
        # The output is deliberately not inspected here, so modules that
        # return tuples/lists are handled too.
        self.n_macs = 0
        self.n_params = 0
        self.n_seconds = 0
        for m in self.children():
            self.n_macs += getattr(m, 'n_macs', 0)
            self.n_params += getattr(m, 'n_params', 0)
            self.n_seconds += getattr(m, 'n_seconds', 0)
        # NOTE: exact-type membership (not isinstance) on purpose: subclasses
        # of these layers still trigger the warning.
        ignore_zeros_t = [
            nn.BatchNorm2d, nn.Dropout2d, nn.Dropout, nn.Sequential,
            nn.ReLU6, nn.ReLU, nn.MaxPool2d,
            nn.modules.padding.ZeroPad2d, nn.modules.activation.Sigmoid,
        ]
        if (not getattr(self, 'ignore_model_profiling', False) and
                self.n_macs == 0 and
                type(self) not in ignore_zeros_t):
            print(
                'WARNING: leaf module {} has zero n_macs.'.format(type(self)))
        return
    self.n_seconds = run_forward(self, input)
    if verbose:
        print(
            self.name.ljust(name_space, ' ') +
            '{:,}'.format(self.n_params).rjust(params_space, ' ') +
            '{:,}'.format(self.n_macs).rjust(macs_space, ' ') +
            '{:,}'.format(self.n_seconds).rjust(seconds_space, ' '))
    return
130 |
def add_profiling_hooks(m, verbose):
    """Attach ``module_profiling`` as a forward hook on module ``m``.

    The hook handle is appended to ``model_profiling_hooks`` so that
    ``remove_profiling_hooks`` can detach it afterwards.
    """
    global model_profiling_hooks

    def _profiling_hook(module, input, output):
        return module_profiling(module, input, output, verbose=verbose)

    handle = m.register_forward_hook(_profiling_hook)
    model_profiling_hooks.append(handle)
137 |
def remove_profiling_hooks():
    """Detach every hook recorded in ``model_profiling_hooks``, then reset it."""
    global model_profiling_hooks
    for handle in model_profiling_hooks:
        handle.remove()
    model_profiling_hooks = list()
144 |
def model_profiling(model, height, width, batch=1, channel=3, use_cuda=True,
                    verbose=True):
    """Profile a PyTorch model on one random input.

    The input has shape (batch, channel, height, width). The function
    registers ``module_profiling`` hooks on every submodule and runs a
    single forward pass. It then reads the aggregated multiply-accumulates
    (n_macs) and parameter counts from the top-level module.

    Args:
        model: pytorch model. It is switched to eval mode and moved to the
            chosen device.
        height: int
        width: int
        batch: int
        channel: int
        use_cuda: bool, run on ``cuda:0`` instead of the CPU.
        verbose: bool, print the per-layer table and the totals.

    Returns:
        macs: int
        params: int
    """
    model.eval()
    data = torch.rand(batch, channel, height, width)
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model = model.to(device)
    data = data.to(device)
    separator = ''.center(
        name_space + params_space + macs_space + seconds_space, '-')
    model.apply(lambda m: add_profiling_hooks(m, verbose=verbose))
    try:
        if verbose:
            # Header only when rows follow; previously it printed on its own.
            print(
                'Item'.ljust(name_space, ' ') +
                'params'.rjust(params_space, ' ') +
                'macs'.rjust(macs_space, ' ') +
                'nanosecs'.rjust(seconds_space, ' '))
            print(separator)
        # Profiling needs no gradients; this avoids building autograd graphs.
        with torch.no_grad():
            model(data)
    finally:
        # Always detach the hooks, even when the forward pass raises, so the
        # model is not left instrumented.
        remove_profiling_hooks()
    if verbose:
        print(separator)
        print(
            'Total'.ljust(name_space, ' ') +
            '{:,}'.format(model.n_params).rjust(params_space, ' ') +
            '{:,}'.format(model.n_macs).rjust(macs_space, ' ') +
            '{:,}'.format(model.n_seconds).rjust(seconds_space, ' '))
    return model.n_macs, model.n_params
189 |
--------------------------------------------------------------------------------
/Imagenet/feature_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import math
5 | from scipy import optimize
6 | import random, sys
7 | sys.path.append("../../ITPruner")
8 | from CKA import cka
9 | import torch
10 | import torch.nn as nn
11 | from run_manager import RunManager
12 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig
13 |
parser = argparse.ArgumentParser()

# model config
parser.add_argument('--path', type=str)
# The default must be one of ``choices``: argparse does not check defaults
# against ``choices``, and a default outside them would leave no network built.
parser.add_argument('--model', type=str, default='resnet50',
                    choices=['resnet50', 'mobilenetv2', 'mobilenet'])
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--manual_seed', default=0, type=int)
parser.add_argument("--target_flops", default=0, type=int)
# Temperature of the exp(-beta * similarity) importance mapping. Parsed as a
# float so fractional values are accepted; integer inputs behave as before.
parser.add_argument("--beta", default=1, type=float)

# dataset config
# Every model in this script requires ImageNet, so that is the default.
parser.add_argument('--dataset', type=str, default='imagenet',
                    choices=['cifar10', 'imagenet'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')

# runtime config
parser.add_argument('--gpu', help='gpu available', default='0')
parser.add_argument('--test_batch_size', type=int, default=100)
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument("--local_rank", default=0, type=int)
35 |
36 | if __name__ == '__main__':
37 | args = parser.parse_args()
38 |
39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
40 | torch.cuda.set_device(0)
41 |
42 | random.seed(args.manual_seed)
43 | torch.manual_seed(args.manual_seed)
44 | torch.cuda.manual_seed_all(args.manual_seed)
45 | np.random.seed(args.manual_seed)
46 | # distributed setting
47 | torch.distributed.init_process_group(backend='nccl',
48 | init_method='env://')
49 | args.world_size = torch.distributed.get_world_size()
50 |
51 | # prepare run config
52 | run_config_path = '%s/run.config' % args.path
53 |
54 | run_config = TrainRunConfig(
55 | **args.__dict__
56 | )
57 | if args.local_rank == 0:
58 | print('Run config:')
59 | for k, v in args.__dict__.items():
60 | print('\t%s: %s' % (k, v))
61 |
62 | if args.model == "resnet50":
63 | assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset'
64 | net = ResNet_ImageNet(
65 | depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
66 | elif args.model == "mobilenetv2":
67 | assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset'
68 | net = MobileNetV2(
69 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
70 | elif args.model == "mobilenet":
71 | assert args.dataset == 'imagenet', 'mobilenet only supports imagenet dataset'
72 | net = MobileNet(
73 | num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
74 |
75 | # build run manager
76 | run_manager = RunManager(args.path, net, run_config)
77 |
78 | # load checkpoints
79 | best_model_path = '%s/checkpoint/model_best.pth.tar' % args.path
80 | assert os.path.isfile(best_model_path), 'wrong path'
81 | if torch.cuda.is_available():
82 | checkpoint = torch.load(best_model_path)
83 | else:
84 | checkpoint = torch.load(best_model_path, map_location='cpu')
85 | if 'state_dict' in checkpoint:
86 | checkpoint = checkpoint['state_dict']
87 | run_manager.net.load_state_dict(checkpoint)
88 | output_dict = {}
89 |
90 | # feature extract
91 | data_loader = run_manager.run_config.test_loader
92 | data = next(iter(data_loader))
93 | data = data[0]
94 | n = data.size()[0]
95 |
96 | with torch.no_grad():
97 | feature = net.feature_extract(data)
98 |
99 | for i in range(len(feature)):
100 | feature[i] = feature[i].view(n, -1)
101 | feature[i] = feature[i].data.cpu().numpy()
102 |
103 | similar_matrix = np.zeros((len(feature), len(feature)))
104 |
105 | for i in range(len(feature)):
106 | for j in range(len(feature)):
107 | with torch.no_grad():
108 | similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j]))
109 |
110 | def sum_list(a, j):
111 | b = 0
112 | for i in range(len(a)):
113 | if i != j:
114 | b += a[i]
115 | return b
116 |
117 | important = []
118 | temp = []
119 | flops = []
120 |
121 | for i in range(len(feature)):
122 | temp.append( sum_list(similar_matrix[i], i) )
123 |
124 | b = sum_list(temp, -1)
125 | temp = [x/b for x in temp]
126 |
127 | for i in range(len(feature)):
128 | important.append( math.exp(-1* args.beta *temp[i] ) )
129 |
130 | length = len(net.cfg)
131 | flops_singlecfg, flops_doublecfg, flops_squarecfg = net.cfg2flops_perlayer(net.cfg, length)
132 | important = np.array(important)
133 | important = np.negative(important)
134 |
135 | # Objective function
136 | def func(x, sign=1.0):
137 | """ Objective function """
138 | global important,length
139 | sum_fuc =[]
140 | for i in range(length):
141 | sum_fuc.append(x[i]*important[i])
142 | return sum(sum_fuc)
143 |
144 | # Derivative function of objective function
145 | def func_deriv(x, sign=1.0):
146 | """ Derivative of objective function """
147 | global important
148 | diff = []
149 | for i in range(len(important)):
150 | diff.append(sign * (important[i]))
151 | return np.array(diff)
152 |
153 | # Constraint function
154 | def constrain_func(x):
155 | """ constrain function """
156 | global flops_singlecfg, flops_doublecfg, flops_squarecfg, length
157 | a = []
158 | for i in range(length):
159 | a.append(x[i] * flops_singlecfg[i])
160 | a.append(flops_squarecfg[i] * x[i] * x[i])
161 | for i in range(1,length):
162 | for j in range(i):
163 | a.append(x[i] * x[j] * flops_doublecfg[i][j])
164 | return np.array([args.target_flops - sum(a)])
165 |
166 |
167 | bnds = []
168 | for i in range(length):
169 | bnds.append((0,1))
170 |
171 | bnds = tuple(bnds)
172 | cons = ({'type': 'ineq',
173 | 'fun': constrain_func})
174 |
175 | result = optimize.minimize(func,x0=[1 for i in range(length)], jac=func_deriv, method='SLSQP', bounds=bnds, constraints=cons)
176 | prun_cfg = np.around(np.array(net.cfg)*result.x)
177 |
178 | optimize_cfg = []
179 | for i in range(len(prun_cfg)):
180 | b = list(prun_cfg)[i].tolist()
181 | optimize_cfg.append(int(b))
182 | print(optimize_cfg)
183 | print(net.cfg2flops(prun_cfg))
184 | print(net.cfg2flops(net.cfg))
185 |
186 |
187 |
--------------------------------------------------------------------------------
/ToyExample/models/vgg.py:
--------------------------------------------------------------------------------
1 | # https://github.com/pytorch/vision/blob/master/torchvision/models
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | from utils import cutout_batch
9 | from models.base_models import MyNetwork
10 | import numpy as np
11 |
# Reference VGG layouts for CIFAR, keyed by depth. Ints are conv output
# channels and 'M' marks a 2x2 max-pool (the encoding make_layers consumes).
cifar_cfg = {}
cifar_cfg[11] = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
cifar_cfg[13] = [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
                 512, 512]
cifar_cfg[16] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                 512, 512, 512, 'M', 512, 512, 512]
cifar_cfg[19] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
                 512, 512, 512, 512, 'M', 512, 512, 512, 512]
18 |
19 |
class VGG_CIFAR(MyNetwork):
    """VGG-style classifier for 32x32 CIFAR images with prunable conv widths.

    ``cfg`` lists the output channels of the 13 conv-BN-ReLU layers.

    - Four 2x2 max-pools are inserted after the 2nd, 4th, 7th and 10th conv.
    - The trunk is followed by a 2x2 average pool.
    - The classifier is ``Linear(cfg[-1], 512) -> ReLU -> Linear(512, num_classes)``.
    """

    # Indices at which _with_pools inserts 'M' (max-pool) markers. They are
    # applied in order, so each index already accounts for earlier markers.
    _POOL_POSITIONS = (2, 5, 9, 13)
    # Unpruned widths of the 13 conv layers.
    _BASE_CFG = (64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512)

    def __init__(self, cfg=None, cutout=True, num_classes=10):
        super(VGG_CIFAR, self).__init__()
        if cfg is None:
            cfg = list(self._BASE_CFG)
        self.cutout = cutout
        self.cfg = cfg
        self._cfg = self._with_pools(cfg)
        self.feature = self.make_layers(self._cfg, True)
        self.avgpool = nn.AvgPool2d(2)
        self.classifier = nn.Sequential(
            nn.Linear(self.cfg[-1], 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )
        self.num_classes = num_classes
        # Weights + biases of both classifier Linears.
        self.classifier_param = (
            self.cfg[-1] + 1) * 512 + (512 + 1) * num_classes

    @staticmethod
    def _with_pools(cfg):
        """Return a copy of ``cfg`` with 'M' markers at the max-pool positions."""
        layers = list(cfg)
        for position in VGG_CIFAR._POOL_POSITIONS:
            layers.insert(position, 'M')
        return layers

    def make_layers(self, cfg, batch_norm=False):
        """Build the conv trunk from a pool-annotated cfg ('M' = 2x2 max-pool).

        Each int ``v`` becomes a 3x3 conv (no bias) with ``v`` outputs, an
        optional BatchNorm2d, and a ReLU. Sets ``self.conv_num`` to the
        number of convs built.
        """
        layers = []
        in_channels = 3
        pool_index = 0
        conv_index = 0
        for v in cfg:
            if v == 'M':
                layers += [('maxpool_%d' % pool_index,
                            nn.MaxPool2d(kernel_size=2, stride=2))]
                pool_index += 1
            else:
                conv2d = nn.Conv2d(
                    in_channels, v, kernel_size=3, padding=1, bias=False)
                conv_index += 1
                if batch_norm:
                    bn = nn.BatchNorm2d(v)
                    layers += [('conv_%d' % conv_index, conv2d), ('bn_%d' % conv_index, bn),
                               ('relu_%d' % conv_index, nn.ReLU(inplace=True))]
                else:
                    layers += [('conv_%d' % conv_index, conv2d),
                               ('relu_%d' % conv_index, nn.ReLU(inplace=True))]
                in_channels = v
        self.conv_num = conv_index
        return nn.Sequential(OrderedDict(layers))

    def cfg2param(self, cfg):
        """Count the parameters of the network that ``cfg`` would build.

        Counted terms:

        - Each conv has no bias: 3*3*c_in*c_out weights.
        - Each conv's BatchNorm adds 2*c_out affine parameters.
        - The classifier counts Linear(cfg[-1], 512) and
          Linear(512, num_classes), each once, including biases.

        The previous version counted the first Linear twice. It also skipped
        widths that were not Python ``int`` (e.g. numpy integers).
        """
        total_param = 0
        c_in = 3
        for c_out in cfg:  # pools carry no parameters, so iterate the widths
            total_param += 3 * 3 * c_in * c_out + 2 * c_out
            c_in = c_out
        num_classes = getattr(self, 'num_classes', 10)
        total_param += (cfg[-1] + 1) * 512 + (512 + 1) * num_classes
        return total_param

    def cfg2flops(self, cfg):
        """Estimate the FLOPs of one 32x32 forward pass for conv widths ``cfg``.

        Counted terms:

        - convs: 9*c_in*c_out per pixel
        - BN: 4 per element
        - ReLU: 1 per element
        - max-pools and the final average pool
        - the two classifier Linears, using ``self.num_classes``

        Resolutions stay integral (32 -> 16 -> 8 -> 4 -> 2 -> 1).
        """
        layers = self._with_pools(cfg)
        input_size = 32
        num_classes = getattr(self, 'num_classes', 10)
        total_flops = 0
        c_in = 3
        for c_out in layers:
            if isinstance(c_out, str):  # 'M': 2x2 max-pool halves the resolution
                input_size //= 2
                total_flops += 2 * input_size * input_size * c_in  # max pool
            else:
                hw = input_size * input_size
                total_flops += 3 * 3 * c_in * c_out * hw  # conv
                total_flops += 4 * c_out * hw  # bn
                total_flops += hw * c_out  # relu
                c_in = c_out
        input_size //= 2  # 2x2 average pool
        hw = input_size * input_size
        c_last = layers[-1]
        total_flops += 3 * hw * c_last  # avg pool
        total_flops += (2 * hw * c_last - 1) * hw * 512  # fc
        # NOTE(review): the classifier defines no BatchNorm1d, yet these FLOPs
        # are still counted (kept so reported numbers stay comparable).
        total_flops += 4 * c_last * hw  # bn_1d
        total_flops += hw * 512  # relu
        total_flops += (2 * hw * 512 - 1) * hw * num_classes  # fc
        return total_flops

    def cfg2flops_perlayer(self, cfg, length):
        """Split the cfg2flops estimate into per-layer coefficient tables.

        Returns ``(single, double, square)``. For width ratios ``x`` the FLOPs
        are modelled as::

            sum_i single[i]*x[i] + square[i]*x[i]**2
            + sum_{j<i} double[i][j]*x[i]*x[j]

        - Terms that scale with one layer's width go to ``single``: BN, ReLU,
          the pools, the first conv (fixed 3 input channels) and the first
          Linear.
        - A conv between two prunable layers goes to ``double``, stored
          symmetrically.
        - ``square`` is all zeros for this network.

        The width-independent last Linear is not represented, so the tables
        sum to slightly less than ``cfg2flops``.
        """
        input_size = 32
        flops_singlecfg = [0 for _ in range(length)]
        flops_doublecfg = np.zeros((length, length))
        flops_squarecfg = [0 for _ in range(length)]
        layers = self._with_pools(cfg)
        c_in = 3
        count = 0
        for c_out in layers:
            if isinstance(c_out, str):  # 'M': 2x2 max-pool halves the resolution
                input_size //= 2
                flops_singlecfg[count - 1] += 2 * input_size * input_size * c_in  # max pool
                continue
            hw = input_size * input_size
            conv_flops = 3 * 3 * c_in * c_out * hw
            if count == 0:
                flops_singlecfg[count] += conv_flops  # conv
            else:
                flops_doublecfg[count - 1][count] += conv_flops  # conv
                flops_doublecfg[count][count - 1] += conv_flops
            flops_singlecfg[count] += 4 * c_out * hw  # bn
            flops_singlecfg[count] += hw * c_out  # relu
            c_in = c_out
            count += 1
        input_size //= 2  # 2x2 average pool
        hw = input_size * input_size
        c_last = layers[-1]
        flops_singlecfg[count - 1] += 3 * hw * c_last  # avg pool
        flops_singlecfg[count - 1] += (2 * hw * c_last) * hw * 512  # fc
        flops_singlecfg[count - 1] += 4 * c_last * hw  # bn_1d

        return flops_singlecfg, flops_doublecfg, flops_squarecfg

    def forward(self, x):
        """Classify ``x``; cutout augmentation is applied only in training mode."""
        if self.training and self.cutout:
            with torch.no_grad():
                x = cutout_batch(x, 16)
        x = self.feature(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def feature_extract(self, x):
        """Return the output of every ReLU in the conv trunk, in order."""
        tensor = []
        if self.training and self.cutout:
            with torch.no_grad():
                x = cutout_batch(x, 16)
        for _layer in self.feature:
            x = _layer(x)
            if type(_layer) is nn.ReLU:
                tensor.append(x)
        return tensor

    @property
    def config(self):
        return {
            'name': self.__class__.__name__,
            'cfg': self.cfg,
            'cfg_base': list(self._BASE_CFG),
            'dataset': 'cifar10',
        }
181 |
182 |
183 |
--------------------------------------------------------------------------------
/Cifar/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 | import torch.utils.data.distributed
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | import os
9 | from torch.utils.data import DataLoader, Subset, sampler
10 | import pickle
11 | import pdb
12 | import numpy as np
13 | from PIL import Image
14 |
15 |
def _cifar_transforms(data_aug):
    """Build the ``(train, test)`` transforms shared by the CIFAR loaders.

    With ``data_aug`` set:

    - Training images are reflect-padded by 4 pixels, randomly cropped back
      to 32x32, randomly flipped and normalised.
    - Test images are only normalised, using the same statistics.

    Without ``data_aug``, both splits are converted to tensors and
    normalised with the fixed CIFAR statistics below.
    """
    if data_aug:
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose([
            # Reflect padding applied directly to the PIL image: the same
            # augmentation as the earlier ToTensor -> F.pad(reflect) ->
            # ToPILImage chain, without the tensor round trip or the
            # deprecated Variable wrapper.
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    else:
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    return transform_train, transform_test


def cifar10_loader_search(args, num_workers=4):
    """CIFAR-10 loaders for architecture search.

    The first ``total_portion`` of the training set is used. Of that, the
    leading ``train_portion`` trains and the rest validates. Both are sampled
    randomly within their index ranges.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    transform_train, transform_test = _cifar_transforms(args.data_aug)
    trainset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(root=args.data_path, train=False, transform=transform_test)

    num_train = len(trainset)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train * args.total_portion))
    split_upb = int(num_train * args.total_portion)

    train_loader = DataLoader(trainset, batch_size=args.batch_size,
                              sampler=sampler.SubsetRandomSampler(indices[:split]),
                              num_workers=num_workers, pin_memory=True)
    valid_loader = DataLoader(trainset, batch_size=args.batch_size,
                              sampler=sampler.SubsetRandomSampler(indices[split:split_upb]),
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, valid_loader, test_loader


def cifar10_loader_train(args, num_workers=4):
    """CIFAR-10 loaders for training.

    The full training set is shuffled; the test set is not.

    Returns:
        ``(train_loader, test_loader)``.
    """
    transform_train, transform_test = _cifar_transforms(args.data_aug)
    trainset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(root=args.data_path, train=False, transform=transform_test)

    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


def cifar100_loader(args, num_workers=4):
    """CIFAR-100 loaders.

    The training set is shuffled; the test set is not.

    Returns:
        ``(train_loader, test_loader)``.
    """
    transform_train, transform_test = _cifar_transforms(args.data_aug)
    trainset = datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
134 |
135 |
def ilsvrc12_loader_train(args, num_workers=4):
    """Build ImageNet (ILSVRC-2012) loaders from ``args.data_path``/{train,val}.

    ``args.batch_size`` is split evenly across ``args.world_size``
    processes. ``world_size`` defaults to 1 when absent, and the result is
    never below 1.

    Returns:
        ``([train_loader, test_loader], train_sampler)``. ``train_sampler``
        is a ``DistributedSampler`` when ``args.distributed`` is set and
        ``None`` otherwise.
    """
    traindir = os.path.join(args.data_path, 'train')
    testdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # Per-process batch size. ``world_size`` may be missing in single-process
    # runs, and a global batch smaller than the world size must not become 0.
    batch_size = max(1, args.batch_size // getattr(args, 'world_size', 1))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=num_workers, pin_memory=True, sampler=train_sampler)

    # NOTE(review): the validation loader is not sharded, so in distributed
    # runs every process evaluates the full validation set.
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(testdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    return [train_loader, test_loader], train_sampler
171 |
172 |
--------------------------------------------------------------------------------
/Imagenet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | import torch.nn as nn
7 | from run_manager import RunManager
8 | from models import ResNet_ImageNet, MobileNet, MobileNetV2, TrainRunConfig
9 |
parser = argparse.ArgumentParser()

# ---- basic config ----
parser.add_argument('--path', type=str, default='Exp/debug')
parser.add_argument('--gpu', help='gpu available', default='0,1,2,3')
# '--train' is store_true with default=True, so on its own it can never switch
# training off. '--no_train' writes to the same dest and disables training
# (evaluation only). '--train' is kept so existing launch scripts still parse.
parser.add_argument('--train', action='store_true', default=True)
parser.add_argument('--no_train', dest='train', action='store_false',
                    help='skip training and only evaluate on the test set')
parser.add_argument('--manual_seed', default=13, type=int)
parser.add_argument('--resume', action='store_true')
parser.add_argument('--validation_frequency', type=int, default=1)
parser.add_argument('--print_frequency', type=int, default=100)

# ---- optimizer config ----
parser.add_argument('--n_epochs', type=int, default=300)
parser.add_argument('--init_lr', type=float, default=0.1)
parser.add_argument('--lr_schedule_type', type=str, default='cosine')
parser.add_argument('--warm_epoch', type=int, default=5)
parser.add_argument('--opt_type', type=str, default='sgd', choices=['sgd'])
parser.add_argument('--momentum', type=float, default=0.9)  # opt_param
parser.add_argument('--no_nesterov', action='store_true')  # opt_param
parser.add_argument('--weight_decay', type=float, default=4e-5)
parser.add_argument('--label_smoothing', type=float, default=0.1)
parser.add_argument('--no_decay_keys', type=str,
                    default='bn#bias', choices=[None, 'bn', 'bn#bias'])

# ---- dataset config ----
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'imagenet', 'imagenet10', 'imagenet100'])
parser.add_argument('--save_path', type=str, default='/userhome/data/cifar10')
parser.add_argument('--train_batch_size', type=int, default=256)
parser.add_argument('--test_batch_size', type=int, default=250)

# ---- model config ----
parser.add_argument('--model', type=str, default="vgg",
                    choices=['vgg', 'resnet20', 'resnet56', 'resnet110', 'resnet18',
                             'resnet34', 'resnet50_TAS', 'mobilenetv2_autoslim',
                             'resnet50', 'mobilenetv2', 'mobilenet'])
parser.add_argument('--cfg', type=str, default="None")
parser.add_argument('--base_path', type=str, default=None)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--model_init', type=str,
                    default='he_fout', choices=['he_fin', 'he_fout'])
parser.add_argument('--init_div_groups', action='store_true')

# ---- runtime config ----
parser.add_argument('--n_worker', type=int, default=24)
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')
parser.add_argument("--local_rank", default=0, type=int)
56 |
57 |
# Architectures this script can build; any other --model value is rejected up front.
_IMAGENET_MODELS = ('resnet50', 'mobilenetv2', 'mobilenet')


def _parse_cfg(cfg_str):
    """Turn the ``--cfg`` string (e.g. ``"None"`` or a list of widths) into a Python object.

    NOTE(review): this uses ``eval`` exactly as before, so ``--cfg`` must come
    from a trusted command line; consider ``ast.literal_eval`` if only literal
    values are ever passed.
    """
    return eval(cfg_str)


def _build_model(model_name, n_classes, **kwargs):
    """Instantiate one of the supported ImageNet architectures.

    Extra keyword arguments (e.g. ``cfg``) are forwarded unchanged. When they
    are omitted, the constructor's own defaults apply (used for the unpruned
    "origin" network).

    Raises:
        ValueError: if ``model_name`` is not in ``_IMAGENET_MODELS``.
    """
    if model_name == 'resnet50':
        return ResNet_ImageNet(depth=50, num_classes=n_classes, **kwargs)
    if model_name == 'mobilenetv2':
        return MobileNetV2(num_classes=n_classes, **kwargs)
    if model_name == 'mobilenet':
        return MobileNet(num_classes=n_classes, **kwargs)
    raise ValueError('unsupported --model for ImageNet training: %s' % model_name)


if __name__ == '__main__':
    args = parser.parse_args()

    # Seeding is deliberately disabled. For reproducible runs, seed random / numpy /
    # torch with args.manual_seed and set torch.backends.cudnn.deterministic = True.
    torch.backends.cudnn.benchmark = True
    os.makedirs(args.path, exist_ok=True)

    # distributed setting: env:// reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    # build run config from args (every attribute of args is forwarded)
    args.lr_schedule_param = None
    args.opt_param = {
        'momentum': args.momentum,
        'nesterov': not args.no_nesterov,
    }
    run_config = TrainRunConfig(
        **args.__dict__
    )

    if args.local_rank == 0:
        print('Run config:')
        for k, v in run_config.config.items():
            print('\t%s: %s' % (k, v))

    # Previously an unsupported --model (e.g. the default 'vgg') left `net`
    # undefined and crashed later with a NameError; fail early and clearly instead.
    if args.model not in _IMAGENET_MODELS:
        raise ValueError('unsupported --model for ImageNet training: %s' % args.model)
    assert args.dataset == 'imagenet', '%s only supports imagenet dataset' % args.model
    n_classes = run_config.data_provider.n_classes

    net = _build_model(args.model, n_classes, cfg=_parse_cfg(args.cfg))
    weight_path = None
    net_origin = None
    if args.base_path is not None:
        weight_path = args.base_path+'/checkpoint/model_best.pth.tar'
        # Unpruned network (constructor-default cfg) whose weights seed the pruned one.
        net_origin = nn.DataParallel(_build_model(args.model, n_classes))

    # build run manager
    run_manager = RunManager(args.path, net, run_config)
    if args.local_rank == 0:
        run_manager.save_config(print_info=True)

    # load checkpoints: resume > initialise from unpruned weights > random init
    if args.resume:
        run_manager.load_model()
        if args.train and run_manager.best_acc == 0:
            _, acc1, _ = run_manager.validate(
                is_test=True, return_top5=True)
            run_manager.best_acc = acc1
    elif weight_path is not None and os.path.isfile(weight_path):
        assert net_origin is not None, "original network is None"
        # Every rank reads the checkpoint. map_location='cpu' avoids each process
        # allocating the tensors on the GPU they were saved from; net_origin lives
        # on the CPU anyway.
        checkpoint = torch.load(weight_path, map_location='cpu')
        net_origin.load_state_dict(checkpoint['state_dict'])
        net_origin = net_origin.module
        run_manager.reset_model(
            _build_model(args.model, n_classes, cfg=_parse_cfg(args.cfg)),
            net_origin.cpu())
    else:
        print('Random initialization')

    # train
    if args.train:
        print('Start training: %d' % args.local_rank)
        if args.local_rank == 0:
            print(net)
        run_manager.train(print_top5=True)
        if args.local_rank == 0:
            run_manager.save_model()

    # test
    print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict = {
        'test_loss': '%f' % loss, 'test_acc1': '%f' % acc1, 'test_acc5': '%f' % acc5
    }
    # Context manager so the output file is flushed and closed deterministically.
    with open('%s/output' % args.path, 'w') as output_file:
        json.dump(output_dict, output_file, indent=4)
171 |
--------------------------------------------------------------------------------
/Cifar/utils/compute_flops.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/simochen/model-tools.
2 | import numpy as np
3 | import pdb
4 | import torch
5 | import torchvision
6 | import torch.nn as nn
7 | import numpy as np
8 | import random
9 |
10 |
def lookup_table_flops(model, candidate_width, alphas=None, input_res=32, multiply_adds=False):
    """Build per-layer FLOPs lookup tables for every candidate-width combination.

    A dummy ``1 x 3 x input_res x input_res`` forward pass is run while
    temporary forward hooks record the input shape of every leaf
    ``BatchNorm2d``. Those shapes supply the spatial size (h, w) of each conv
    output. One 3x3 conv with ``co`` output and ``ci`` input channels costs
    ``co * ci * 3 * 3 * h * w``.

    Side effect: registers (or overwrites) an ``'alphas_tmp'`` buffer holding
    ``alphas.data`` on the model (on ``model.module`` when the model is wrapped
    in DataParallel / DistributedDataParallel).

    Args:
        model: width-search supernet. If ``alphas`` is None, the last parameter
            whose name contains ``'alphas'`` is used.
        candidate_width: candidate channel counts.
        alphas: tensor whose first dimension is the number of searchable convs.
        input_res: spatial resolution of the dummy input.
        multiply_adds: unused; kept for signature compatibility.

    Returns:
        list of ``alphas.shape[0]`` float tensors on ``alphas.device``:
        - The first has ``len(candidate_width)`` entries (stem conv, 3 input
          channels).
        - Each following one has ``len(candidate_width) ** 2`` entries, ordered
          out-width outer, in-width inner.

    Raises:
        ValueError: if ``alphas`` is None and no parameter name contains 'alphas'.
    """
    if alphas is None:
        for name, param in model.named_parameters():
            if 'alphas' in name:
                alphas = param
        if alphas is None:
            raise ValueError("lookup_table_flops: model has no parameter whose name contains 'alphas'")
    num_conv = alphas.shape[0]
    device = alphas.device

    # obtain the feature map sizes: input shape of every 4-D tensor entering a leaf BN
    list_bn = []

    def bn_hook(module, inputs, output):
        if inputs[0].ndimension() == 4:
            list_bn.append(inputs[0].shape)

    hook_handles = []

    def register_bn_hooks(net):
        children = list(net.children())
        if not children:
            if isinstance(net, torch.nn.BatchNorm2d):
                hook_handles.append(net.register_forward_hook(bn_hook))
            return
        for child in children:
            register_bn_hooks(child)

    register_bn_hooks(model)
    input_ = torch.rand(1, 3, input_res, input_res).to(device)
    input_.requires_grad = True
    # Decide on the wrapper type rather than the GPU count. The old
    # `torch.cuda.device_count() > 1` check crashed on unwrapped models on
    # multi-GPU hosts, and put the buffer on the wrapper on single-GPU hosts.
    wrapped = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
    (model.module if wrapped else model).register_buffer('alphas_tmp', alphas.data)
    try:
        model(input_)
    finally:
        # Remove the temporary hooks so repeated calls do not pile them up on the model.
        for handle in hook_handles:
            handle.remove()

    # TODO: only appliable for resnet_20s: 2 convs followed by 1 shortcut
    # Keep the first num_width BN shapes (stem); afterwards BNs come in groups of
    # num_width**2 and every third group (the shortcut) is skipped.
    num_width = len(candidate_width)
    list_main_bn = []
    for i, shape in enumerate(list_bn):
        if i//num_width == 0 or \
                ((i-num_width)//(num_width**2) >= 0 and ((i-num_width)//(num_width)**2) % 3 != 2):
            list_main_bn.append(shape)
    assert len(list_main_bn) == (num_width + num_width ** 2 * (num_conv-1)), 'wrong list of feature map length'

    # start compute flops for each branch
    # first obtain the kernel shapes, a list of length: num_width + num_width**2 * num_conv
    def kernel_shape_types(widths):
        kshape_types = []
        first_kshape_types = []
        for i in widths:
            first_kshape_types.append((i, 3, 3, 3))
        for i in widths:
            for j in widths:
                kshape_types.append((i, j, 3, 3))  # [co, ci, k, k]
        return kshape_types, first_kshape_types

    kshape_types, first_kshape_types = kernel_shape_types(candidate_width)
    k_shapes = []
    layer_idx = 0
    for v in model.parameters():
        # Only 3x3 conv kernels participate; the first one is the stem (3 input channels).
        if v.ndimension() == 4 and v.shape[2] == 3:
            if layer_idx == 0:
                k_shapes += first_kshape_types
            else:
                k_shapes += kshape_types
            layer_idx += 1

    # compute flops: a list of length num_width + num_width**2 * (num_conv - 1)
    flops = []
    for idx, shape in enumerate(list_main_bn):
        _, _, h, w = shape
        co, ci, k, _ = k_shapes[idx]
        flops.append(co * ci * k * k * h * w)

    # reshape flops back to list. len == num_conv
    table_flops = [torch.Tensor(flops[:num_width]).to(device)]
    for layer_idx in range(num_conv-1):
        tmp = flops[num_width + layer_idx*num_width**2:
                    num_width + (layer_idx+1)*num_width**2]
        assert len(tmp) == num_width ** 2, 'need have %d elements in %d layer' % (num_width**2, layer_idx+1)
        table_flops.append(torch.Tensor(tmp).to(device))

    return table_flops
99 |
100 |
def print_model_param_nums(model, multiply_adds=False):
    """Print (in millions) and return the total element count of ``model.parameters()``.

    ``multiply_adds`` is accepted for signature compatibility and ignored.
    """
    n_params = 0
    for tensor in model.parameters():
        n_params += tensor.nelement()
    print(' + Number of original params: %.8fM' % (n_params / 1e6))
    return n_params
105 |
106 |
def print_model_param_flops(model, input_res, multiply_adds=False, device=None):
    """Print (in G) and return the per-image FLOPs of one forward pass of *model*.

    Temporary forward hooks on leaf ``Conv2d`` and ``Linear`` modules count
    operations during a dummy pass over a batch of 3 random
    ``3 x input_res x input_res`` inputs. The total is divided by 3. Counters
    for BatchNorm / ReLU / pooling / upsample are defined but not registered
    (see the commented lines in ``foo``).

    Args:
        model: network to profile. It is moved to *device* in place
            (``Module.to``), as the previous unconditional ``model.cuda()`` did.
        input_res: spatial size of the dummy input.
        multiply_adds: count each multiply-accumulate as 2 FLOPs instead of 1.
        device: device for the profiling pass. Defaults to ``'cuda'`` when
            available (the previous hard-coded behaviour), else ``'cpu'``.

    Returns:
        per-image FLOPs as a float.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    list_conv = []
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0

        flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
        list_conv.append(flops)

    list_linear = []
    def linear_hook(self, input, output):
        # 2-D input -> one row per sample; anything else is counted once.
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []
    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement() * 2)

    list_relu = []
    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []
    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size
        list_pooling.append(flops)

    list_upsample = []
    # For bilinear upsample
    def upsample_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        flops = output_height * output_width * output_channels * batch_size * 12
        list_upsample.append(flops)

    hook_handles = []

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                hook_handles.append(net.register_forward_hook(conv_hook))
            if isinstance(net, torch.nn.Linear):
                hook_handles.append(net.register_forward_hook(linear_hook))
            # if isinstance(net, torch.nn.BatchNorm2d):
            #     hook_handles.append(net.register_forward_hook(bn_hook))
            # if isinstance(net, torch.nn.ReLU):
            #     hook_handles.append(net.register_forward_hook(relu_hook))
            # if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
            #     hook_handles.append(net.register_forward_hook(pooling_hook))
            # if isinstance(net, torch.nn.Upsample):
            #     hook_handles.append(net.register_forward_hook(upsample_hook))
            return
        for c in childrens:
            foo(c)

    model = model.to(device)
    foo(model)
    input_ = torch.rand(3, 3, input_res, input_res).to(device)
    input_.requires_grad = True
    try:
        model(input_)
    finally:
        # Remove the temporary hooks so repeated profiling does not accumulate them.
        for handle in hook_handles:
            handle.remove()

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))
    total_flops /= 3  # the dummy batch holds 3 images

    print(' + Number of FLOPs of original model: %.8fG' % (total_flops / 1e9))

    return total_flops
201 |
202 |
--------------------------------------------------------------------------------
/Cifar/utils/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import sys
5 |
6 | def _make_divisible(v, divisor=8, min_value=None):
7 | """
8 | This function is taken from the original tf repo.
9 | It ensures that all layers have a channel number that is divisible by 8
10 | It can be seen here:
11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
12 | :param v:
13 | :param divisor:
14 | :param min_value:
15 | :return:
16 | """
17 | if min_value is None:
18 | min_value = divisor
19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
20 | # Make sure that round down does not go down by more than 10%.
21 | if new_v < 0.9 * v:
22 | new_v += divisor
23 | return new_v
24 |
def build_activation(act_func, inplace=True):
    """Return a fresh activation module for the name *act_func*.

    'relu' and 'relu6' honour *inplace*; 'tanh' and 'sigmoid' ignore it.
    ``None`` means "no activation" and yields ``None``. Any other value raises
    ``ValueError``.
    """
    builders = (
        ('relu', lambda: nn.ReLU(inplace=inplace)),
        ('relu6', lambda: nn.ReLU6(inplace=inplace)),
        ('tanh', nn.Tanh),
        ('sigmoid', nn.Sigmoid),
    )
    for name, make in builders:
        if act_func == name:
            return make()
    if act_func is None:
        return None
    raise ValueError('do not support: %s' % act_func)
38 |
def raw2cfg(model, raw_ratios, flops, p=False, div=8, *, tol=0.01, max_iter=20,
            min_ratio=0.1, max_scale=50):
    """Binary-search a global scale so that scaled per-layer width ratios meet a FLOPs target.

    Each step multiplies ``raw_ratios`` by ``scale``, clamps every ratio into
    ``[min_ratio, 1]``, converts to channel counts via ``model.config['cfg_base']``
    and rounds each with ``_make_divisible(c, div)``.

    Args:
        model: object exposing ``config['cfg_base']`` (base channel counts) and
            ``cfg2flops(cfg)``.
        raw_ratios: per-layer width ratios. Expected to be a NumPy array: it is
            scaled, assigned element-wise and the product with ``cfg_base`` is
            converted with ``.astype(int)``.
        flops: target FLOPs (must be non-zero; used as a relative-error divisor).
        p: print search progress.
        div: channel divisor passed to ``_make_divisible``.
        tol: stop when ``|current - flops| / flops < tol`` (default 0.01).
        max_iter: give up after ``max_iter + 1`` evaluations (default 20 -> 21).
        min_ratio: lower clamp for each scaled ratio (default 0.1).
        max_scale: initial upper bound of the searched scale (default 50).

    Returns:
        list of int channel counts from the last evaluated step (not
        necessarily within ``tol`` if the iteration budget ran out).
    """
    left = 0
    right = max_scale
    base_channels = model.config['cfg_base']
    cfg = None
    cnt = 0
    while True:
        cnt += 1
        scale = (left + right) / 2
        scaled_ratios = raw_ratios * scale
        # Clamp each ratio into [min_ratio, 1].
        for i in range(len(scaled_ratios)):
            scaled_ratios[i] = max(min_ratio, scaled_ratios[i])
            scaled_ratios[i] = min(1, scaled_ratios[i])
        cfg = (base_channels * scaled_ratios).astype(int).tolist()
        cfg = [_make_divisible(c, div) for c in cfg]  # div-divisible channels
        current_flops = model.cfg2flops(cfg)
        if cnt > max_iter:
            break
        if abs(current_flops - flops) / flops < tol:
            break
        if p:
            print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg))
        if current_flops < flops:
            left = scale
        elif current_flops > flops:
            right = scale
        else:
            break
    return cfg
71 |
def weight2mask(weight, keep_c):  # simple L1 pruning
    """Return a 0/1 NumPy mask keeping the ``keep_c`` output channels with the largest L1 norm.

    ``weight`` is indexed along dim 0 (output channels). The L1 norm is taken
    over all remaining dims. 4-D conv kernels behave exactly as before, and
    2-D (linear) or 1-D weights are now accepted as well.

    Returns:
        float64 NumPy array of length ``weight.shape[0]`` with 1 for kept channels.
    """
    magnitude = weight.abs()
    if magnitude.dim() > 1:
        magnitude = torch.sum(magnitude, dim=tuple(range(1, magnitude.dim())))
    order = torch.argsort(magnitude, descending=True)
    mask = np.zeros(weight.shape[0])
    mask[order[:keep_c].tolist()] = 1
    return mask
80 |
def get_unpruned_weights(model, model_origin):
    """Initialise a pruned (narrower) network in place from its unpruned original.

    Walks ``model_origin.named_modules()`` and ``model.named_modules()`` in
    lockstep. The two models are assumed to list their modules in the same
    order (TODO confirm for every architecture this is used with).

    - Conv2d: only ``weight`` is transferred (conv bias is never touched).
      Pruned output channels are picked with ``weight2mask`` (largest L1 norm).
      Pruned input channels follow a mask recorded for an earlier conv.
    - BatchNorm2d: weight/bias/running stats are copied, sliced with the most
      recent mask when shapes differ.
    - Other module types (e.g. ``nn.Linear``) are not copied.

    NOTE(review): the branch nesting below was reconstructed. Under this
    reading, a conv whose input channels differ but whose output channels match
    gets ``w`` computed but never assigned to ``m1`` -- confirm against the
    upstream file.

    Args:
        model: pruned network; modified in place. ``model.config['depth']`` is
            read for ``downsample.conv`` layers.
        model_origin: unpruned network providing the weights.
    """
    # One entry per visited Conv2d: a 0/1 NumPy mask of the output channels kept,
    # or None when that conv's output channels were not pruned.
    masks = []
    for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()):
        # m0 / m1 are (qualified_name, module) pairs from the original / pruned model.
        if isinstance(m0[1], nn.Conv2d):
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                flag = False  # becomes True once w holds input-channel-sliced weights
                if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]:
                    # Input channels were pruned: reuse a mask recorded for an earlier conv.
                    assert len(masks)>0, "masks is empty!"
                    if m0[0].endswith('downsample.conv'):
                        # Presumably selects the mask of the block's input (3 convs precede
                        # the downsample in a depth>=50 block, 2 otherwise) -- confirm
                        # against the model definition.
                        if model.config['depth']>=50:
                            mask = masks[-4]
                        else:
                            mask = masks[-3]
                    else:
                        mask = masks[-1]
                    idx = np.squeeze(np.argwhere(mask))
                    # squeeze() gives a 0-d array for a single kept channel; make it 1-d so
                    # .tolist() yields a list and the indexed dim is preserved.
                    if idx.size == 1:
                        idx = np.resize(idx, (1,))
                    w = m0[1].weight.data[:, idx.tolist(), :, :].clone()
                    flag = True
                    if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]:
                        masks.append(None)
                if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]:
                    # Output channels were pruned.
                    if m0[0].endswith('downsample.conv'):
                        # Reuses the previous conv's output mask (presumably so both
                        # residual branches keep the same channels -- confirm).
                        mask = masks[-1]
                    else:
                        # Keep the largest-L1 output channels, measured on the already
                        # input-sliced weights when flag is set.
                        if flag:
                            mask = weight2mask(w.clone(), m1[1].weight.data.shape[0])
                        else:
                            mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0])
                    idx = np.squeeze(np.argwhere(mask))
                    if idx.size == 1:
                        idx = np.resize(idx, (1,))
                    if flag:
                        w = w[idx.tolist(), :, :, :].clone()
                    else:
                        w = m0[1].weight.data[idx.tolist(), :, :, :].clone()
                    m1[1].weight.data = w.clone()
                    masks.append(mask)
                    continue
            else:
                # Same shape: copy the weights unchanged.
                m1[1].weight.data = m0[1].weight.data.clone()
                masks.append(None)
        elif isinstance(m0[1], nn.BatchNorm2d):
            assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here."
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                # Slice affine params and running statistics with the last conv's mask.
                mask = masks[-1]
                idx = np.squeeze(np.argwhere(mask))
                if idx.size == 1:
                    idx = np.resize(idx, (1,))
                m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone()
                m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone()
                m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone()
                m1[1].running_var = m0[1].running_var[idx.tolist()].clone()
                continue
            m1[1].weight.data = m0[1].weight.data.clone()
            m1[1].bias.data = m0[1].bias.data.clone()
            m1[1].running_mean = m0[1].running_mean.clone()
            m1[1].running_var = m0[1].running_var.clone()
140 |
# noinspection PyUnresolvedReferences
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """Mean cross-entropy between logits *pred* and label-smoothed one-hot targets.

    The target distribution puts ``1 - label_smoothing + label_smoothing / C``
    on the true class and ``label_smoothing / C`` on each other class, where
    ``C = pred.size(1)``. *target* holds class indices, one per row of *pred*.
    """
    num_classes = pred.size(1)
    one_hot = torch.zeros_like(pred)
    one_hot.scatter_(1, target.unsqueeze(1), 1)
    smoothed = one_hot * (1 - label_smoothing) + label_smoothing / num_classes
    log_probs = nn.LogSoftmax(dim=1)(pred)
    per_sample = (-smoothed * log_probs).sum(1)
    return per_sample.mean()
152 |
153 |
def count_parameters(model):
    """Return the number of scalar entries in *model*'s parameters that require grad."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
157 |
158 |
def detach_variable(inputs):
    """Detach a tensor (or a possibly nested tuple of tensors) from the graph.

    Each result is a new leaf tensor that keeps the original's
    ``requires_grad`` flag. Tuples are rebuilt recursively.
    """
    if not isinstance(inputs, tuple):
        detached = inputs.detach()
        detached.requires_grad = inputs.requires_grad
        return detached
    return tuple(detach_variable(item) for item in inputs)
166 |
167 |
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy in percent for each k in *topk*.

    *output* holds per-class scores (classes along dim 1). *target* holds
    class indices. Each result is a one-element tensor
    ``100 * (#samples whose target is among the k best scores) / batch_size``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, ranked = output.topk(largest_k, 1, True, True)
    ranked = ranked.t()  # (largest_k, batch): row r holds every sample's r-th best class
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    pct_per_hit = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(pct_per_hit) for k in topk]
182 |
183 |
class AverageMeter(object):
    """
    Track the latest value and the running (weighted) average of a metric.

    Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean of *n* samples and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
207 |
--------------------------------------------------------------------------------
/Imagenet/utils/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import sys
5 |
6 | def _make_divisible(v, divisor=8, min_value=None):
7 | """
8 | This function is taken from the original tf repo.
9 | It ensures that all layers have a channel number that is divisible by 8
10 | It can be seen here:
11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
12 | :param v:
13 | :param divisor:
14 | :param min_value:
15 | :return:
16 | """
17 | if min_value is None:
18 | min_value = divisor
19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
20 | # Make sure that round down does not go down by more than 10%.
21 | if new_v < 0.9 * v:
22 | new_v += divisor
23 | return new_v
24 |
def build_activation(act_func, inplace=True):
    """Return a fresh activation module for the name *act_func*.

    'relu' and 'relu6' honour *inplace*; 'tanh' and 'sigmoid' ignore it.
    ``None`` means "no activation" and yields ``None``. Any other value raises
    ``ValueError``.
    """
    builders = (
        ('relu', lambda: nn.ReLU(inplace=inplace)),
        ('relu6', lambda: nn.ReLU6(inplace=inplace)),
        ('tanh', nn.Tanh),
        ('sigmoid', nn.Sigmoid),
    )
    for name, make in builders:
        if act_func == name:
            return make()
    if act_func is None:
        return None
    raise ValueError('do not support: %s' % act_func)
38 |
def raw2cfg(model, raw_ratios, flops, p=False, div=8, *, tol=0.01, max_iter=20,
            min_ratio=0.1, max_scale=50):
    """Binary-search a global scale so that scaled per-layer width ratios meet a FLOPs target.

    Each step multiplies ``raw_ratios`` by ``scale``, clamps every ratio into
    ``[min_ratio, 1]``, converts to channel counts via ``model.config['cfg_base']``
    and rounds each with ``_make_divisible(c, div)``.

    Args:
        model: object exposing ``config['cfg_base']`` (base channel counts) and
            ``cfg2flops(cfg)``.
        raw_ratios: per-layer width ratios. Expected to be a NumPy array: it is
            scaled, assigned element-wise and the product with ``cfg_base`` is
            converted with ``.astype(int)``.
        flops: target FLOPs (must be non-zero; used as a relative-error divisor).
        p: print search progress.
        div: channel divisor passed to ``_make_divisible``.
        tol: stop when ``|current - flops| / flops < tol`` (default 0.01).
        max_iter: give up after ``max_iter + 1`` evaluations (default 20 -> 21).
        min_ratio: lower clamp for each scaled ratio (default 0.1).
        max_scale: initial upper bound of the searched scale (default 50).

    Returns:
        list of int channel counts from the last evaluated step (not
        necessarily within ``tol`` if the iteration budget ran out).
    """
    left = 0
    right = max_scale
    base_channels = model.config['cfg_base']
    cfg = None
    cnt = 0
    while True:
        cnt += 1
        scale = (left + right) / 2
        scaled_ratios = raw_ratios * scale
        # Clamp each ratio into [min_ratio, 1].
        for i in range(len(scaled_ratios)):
            scaled_ratios[i] = max(min_ratio, scaled_ratios[i])
            scaled_ratios[i] = min(1, scaled_ratios[i])
        cfg = (base_channels * scaled_ratios).astype(int).tolist()
        cfg = [_make_divisible(c, div) for c in cfg]  # div-divisible channels
        current_flops = model.cfg2flops(cfg)
        if cnt > max_iter:
            break
        if abs(current_flops - flops) / flops < tol:
            break
        if p:
            print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg))
        if current_flops < flops:
            left = scale
        elif current_flops > flops:
            right = scale
        else:
            break
    return cfg
71 |
def weight2mask(weight, keep_c):  # simple L1 pruning
    """Return a 0/1 NumPy mask keeping the ``keep_c`` output channels with the largest L1 norm.

    ``weight`` is indexed along dim 0 (output channels). The L1 norm is taken
    over all remaining dims. 4-D conv kernels behave exactly as before, and
    2-D (linear) or 1-D weights are now accepted as well.

    Returns:
        float64 NumPy array of length ``weight.shape[0]`` with 1 for kept channels.
    """
    magnitude = weight.abs()
    if magnitude.dim() > 1:
        magnitude = torch.sum(magnitude, dim=tuple(range(1, magnitude.dim())))
    order = torch.argsort(magnitude, descending=True)
    mask = np.zeros(weight.shape[0])
    mask[order[:keep_c].tolist()] = 1
    return mask
80 |
def get_unpruned_weights(model, model_origin):
    """Initialise a pruned (narrower) network in place from its unpruned original.

    Walks ``model_origin.named_modules()`` and ``model.named_modules()`` in
    lockstep. The two models are assumed to list their modules in the same
    order (TODO confirm for every architecture this is used with).

    - Conv2d: only ``weight`` is transferred (conv bias is never touched).
      Pruned output channels are picked with ``weight2mask`` (largest L1 norm).
      Pruned input channels follow a mask recorded for an earlier conv.
    - BatchNorm2d: weight/bias/running stats are copied, sliced with the most
      recent mask when shapes differ.
    - Other module types (e.g. ``nn.Linear``) are not copied.

    NOTE(review): the branch nesting below was reconstructed. Under this
    reading, a conv whose input channels differ but whose output channels match
    gets ``w`` computed but never assigned to ``m1`` -- confirm against the
    upstream file.

    Args:
        model: pruned network; modified in place. ``model.config['depth']`` is
            read for ``downsample.conv`` layers.
        model_origin: unpruned network providing the weights.
    """
    # One entry per visited Conv2d: a 0/1 NumPy mask of the output channels kept,
    # or None when that conv's output channels were not pruned.
    masks = []
    for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()):
        # m0 / m1 are (qualified_name, module) pairs from the original / pruned model.
        if isinstance(m0[1], nn.Conv2d):
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                flag = False  # becomes True once w holds input-channel-sliced weights
                if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]:
                    # Input channels were pruned: reuse a mask recorded for an earlier conv.
                    assert len(masks)>0, "masks is empty!"
                    if m0[0].endswith('downsample.conv'):
                        # Presumably selects the mask of the block's input (3 convs precede
                        # the downsample in a depth>=50 block, 2 otherwise) -- confirm
                        # against the model definition.
                        if model.config['depth']>=50:
                            mask = masks[-4]
                        else:
                            mask = masks[-3]
                    else:
                        mask = masks[-1]
                    idx = np.squeeze(np.argwhere(mask))
                    # squeeze() gives a 0-d array for a single kept channel; make it 1-d so
                    # .tolist() yields a list and the indexed dim is preserved.
                    if idx.size == 1:
                        idx = np.resize(idx, (1,))
                    w = m0[1].weight.data[:, idx.tolist(), :, :].clone()
                    flag = True
                    if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]:
                        masks.append(None)
                if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]:
                    # Output channels were pruned.
                    if m0[0].endswith('downsample.conv'):
                        # Reuses the previous conv's output mask (presumably so both
                        # residual branches keep the same channels -- confirm).
                        mask = masks[-1]
                    else:
                        # Keep the largest-L1 output channels, measured on the already
                        # input-sliced weights when flag is set.
                        if flag:
                            mask = weight2mask(w.clone(), m1[1].weight.data.shape[0])
                        else:
                            mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0])
                    idx = np.squeeze(np.argwhere(mask))
                    if idx.size == 1:
                        idx = np.resize(idx, (1,))
                    if flag:
                        w = w[idx.tolist(), :, :, :].clone()
                    else:
                        w = m0[1].weight.data[idx.tolist(), :, :, :].clone()
                    m1[1].weight.data = w.clone()
                    masks.append(mask)
                    continue
            else:
                # Same shape: copy the weights unchanged.
                m1[1].weight.data = m0[1].weight.data.clone()
                masks.append(None)
        elif isinstance(m0[1], nn.BatchNorm2d):
            assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here."
            if m0[1].weight.data.shape!=m1[1].weight.data.shape:
                # Slice affine params and running statistics with the last conv's mask.
                mask = masks[-1]
                idx = np.squeeze(np.argwhere(mask))
                if idx.size == 1:
                    idx = np.resize(idx, (1,))
                m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone()
                m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone()
                m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone()
                m1[1].running_var = m0[1].running_var[idx.tolist()].clone()
                continue
            m1[1].weight.data = m0[1].weight.data.clone()
            m1[1].bias.data = m0[1].bias.data.clone()
            m1[1].running_mean = m0[1].running_mean.clone()
            m1[1].running_var = m0[1].running_var.clone()
140 |
# noinspection PyUnresolvedReferences
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """Mean cross-entropy between logits *pred* and label-smoothed one-hot targets.

    The target distribution puts ``1 - label_smoothing + label_smoothing / C``
    on the true class and ``label_smoothing / C`` on each other class, where
    ``C = pred.size(1)``. *target* holds class indices, one per row of *pred*.
    """
    num_classes = pred.size(1)
    one_hot = torch.zeros_like(pred)
    one_hot.scatter_(1, target.unsqueeze(1), 1)
    smoothed = one_hot * (1 - label_smoothing) + label_smoothing / num_classes
    log_probs = nn.LogSoftmax(dim=1)(pred)
    per_sample = (-smoothed * log_probs).sum(1)
    return per_sample.mean()
152 |
153 |
def count_parameters(model):
    """Return the number of scalar entries in *model*'s parameters that require grad."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
157 |
158 |
def detach_variable(inputs):
    """Detach a tensor (or a possibly nested tuple of tensors) from the graph.

    Each result is a new leaf tensor that keeps the original's
    ``requires_grad`` flag. Tuples are rebuilt recursively.
    """
    if not isinstance(inputs, tuple):
        detached = inputs.detach()
        detached.requires_grad = inputs.requires_grad
        return detached
    return tuple(detach_variable(item) for item in inputs)
166 |
167 |
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy in percent for each k in *topk*.

    *output* holds per-class scores (classes along dim 1). *target* holds
    class indices. Each result is a one-element tensor
    ``100 * (#samples whose target is among the k best scores) / batch_size``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    _, ranked = output.topk(largest_k, 1, True, True)
    ranked = ranked.t()  # (largest_k, batch): row r holds every sample's r-th best class
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    pct_per_hit = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(pct_per_hit) for k in topk]
182 |
183 |
class AverageMeter(object):
    """
    Track the latest value and the running (weighted) average of a metric.

    Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean of *n* samples and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
207 |
--------------------------------------------------------------------------------
/ToyExample/utils/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import sys
5 |
6 | def _make_divisible(v, divisor=8, min_value=None):
7 | """
8 | This function is taken from the original tf repo.
9 | It ensures that all layers have a channel number that is divisible by 8
10 | It can be seen here:
11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
12 | :param v:
13 | :param divisor:
14 | :param min_value:
15 | :return:
16 | """
17 | if min_value is None:
18 | min_value = divisor
19 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
20 | # Make sure that round down does not go down by more than 10%.
21 | if new_v < 0.9 * v:
22 | new_v += divisor
23 | return new_v
24 |
def build_activation(act_func, inplace=True):
    """Return a fresh activation module for the name *act_func*.

    'relu' and 'relu6' honour *inplace*; 'tanh' and 'sigmoid' ignore it.
    ``None`` means "no activation" and yields ``None``. Any other value raises
    ``ValueError``.
    """
    builders = (
        ('relu', lambda: nn.ReLU(inplace=inplace)),
        ('relu6', lambda: nn.ReLU6(inplace=inplace)),
        ('tanh', nn.Tanh),
        ('sigmoid', nn.Sigmoid),
    )
    for name, make in builders:
        if act_func == name:
            return make()
    if act_func is None:
        return None
    raise ValueError('do not support: %s' % act_func)
38 |
def raw2cfg(model, raw_ratios, flops, p=False, div=8, *, tol=0.01, max_iter=20,
            min_ratio=0.1, max_scale=50):
    """Binary-search a global scale so that scaled per-layer width ratios meet a FLOPs target.

    Each step multiplies ``raw_ratios`` by ``scale``, clamps every ratio into
    ``[min_ratio, 1]``, converts to channel counts via ``model.config['cfg_base']``
    and rounds each with ``_make_divisible(c, div)``.

    Args:
        model: object exposing ``config['cfg_base']`` (base channel counts) and
            ``cfg2flops(cfg)``.
        raw_ratios: per-layer width ratios. Expected to be a NumPy array: it is
            scaled, assigned element-wise and the product with ``cfg_base`` is
            converted with ``.astype(int)``.
        flops: target FLOPs (must be non-zero; used as a relative-error divisor).
        p: print search progress.
        div: channel divisor passed to ``_make_divisible``.
        tol: stop when ``|current - flops| / flops < tol`` (default 0.01).
        max_iter: give up after ``max_iter + 1`` evaluations (default 20 -> 21).
        min_ratio: lower clamp for each scaled ratio (default 0.1).
        max_scale: initial upper bound of the searched scale (default 50).

    Returns:
        list of int channel counts from the last evaluated step (not
        necessarily within ``tol`` if the iteration budget ran out).
    """
    left = 0
    right = max_scale
    base_channels = model.config['cfg_base']
    cfg = None
    cnt = 0
    while True:
        cnt += 1
        scale = (left + right) / 2
        scaled_ratios = raw_ratios * scale
        # Clamp each ratio into [min_ratio, 1].
        for i in range(len(scaled_ratios)):
            scaled_ratios[i] = max(min_ratio, scaled_ratios[i])
            scaled_ratios[i] = min(1, scaled_ratios[i])
        cfg = (base_channels * scaled_ratios).astype(int).tolist()
        cfg = [_make_divisible(c, div) for c in cfg]  # div-divisible channels
        current_flops = model.cfg2flops(cfg)
        if cnt > max_iter:
            break
        if abs(current_flops - flops) / flops < tol:
            break
        if p:
            print(str(current_flops)+'---'+str(flops)+'---left: '+str(left)+'---right: '+str(right)+'---cfg: '+str(cfg))
        if current_flops < flops:
            left = scale
        elif current_flops > flops:
            right = scale
        else:
            break
    return cfg
71 |
def weight2mask(weight, keep_c):  # simple L1 pruning
    """Build a 0/1 numpy mask that keeps the ``keep_c`` filters with largest L1 norm.

    Args:
        weight: 4-D conv weight tensor; the L1 norm is taken over dims 1, 2, 3,
            so dim 0 indexes output channels.
        keep_c: number of output channels to keep.

    Returns:
        A float numpy array of length ``weight.shape[0]`` with ones at kept channels.
    """
    l1_scores = weight.abs().sum(dim=(1, 2, 3))
    kept = torch.argsort(l1_scores, descending=True)[:keep_c].tolist()
    mask = np.zeros(weight.shape[0])
    mask[kept] = 1
    return mask
80 |
81 | def get_unpruned_weights(model, model_origin):  # Copy Conv2d/BatchNorm2d params from model_origin into model, slicing out the surviving channels wherever shapes differ.
82 | masks = []  # output-channel masks recorded by conv layers (None = output channels not pruned); read back with negative indices below
83 | for [m0, m1] in zip(model_origin.named_modules(), model.named_modules()):  # zip pairs modules positionally: assumes both models enumerate modules in the same order
84 | if isinstance(m0[1], nn.Conv2d):
85 | if m0[1].weight.data.shape!=m1[1].weight.data.shape:  # this conv was pruned (input and/or output channels)
86 | flag = False  # becomes True once w holds an input-channel-sliced copy of the weight
87 | if m0[1].weight.data.shape[1]!=m1[1].weight.data.shape[1]:  # input channels differ: reuse a mask recorded by an earlier conv
88 | assert len(masks)>0, "masks is empty!"
89 | if m0[0].endswith('downsample.conv'):  # NOTE(review): the -4 / -3 offsets below assume a fixed conv order inside a residual block; confirm against the model definition
90 | if model.config['depth']>=50:
91 | mask = masks[-4]
92 | else:
93 | mask = masks[-3]
94 | else:
95 | mask = masks[-1]  # ordinary conv: input mask is the most recently recorded one
96 | idx = np.squeeze(np.argwhere(mask))  # indices of the kept channels
97 | if idx.size == 1:  # np.squeeze gives a 0-d array for a single index; make it 1-d again
98 | idx = np.resize(idx, (1,))
99 | w = m0[1].weight.data[:, idx.tolist(), :, :].clone()  # keep only the selected input channels
100 | flag = True
101 | if m0[1].weight.data.shape[0]==m1[1].weight.data.shape[0]:  # output channels unchanged
102 | masks.append(None)
103 | if m0[1].weight.data.shape[0]!=m1[1].weight.data.shape[0]:  # output channels differ: choose which filters survive
104 | if m0[0].endswith('downsample.conv'):
105 | mask = masks[-1]  # downsample conv reuses the most recently recorded mask
106 | else:
107 | if flag:
108 | mask = weight2mask(w.clone(), m1[1].weight.data.shape[0])  # largest-L1 filters, scored on the input-sliced weight
109 | else:
110 | mask = weight2mask(m0[1].weight.data, m1[1].weight.data.shape[0])  # largest-L1 filters of the original weight
111 | idx = np.squeeze(np.argwhere(mask))
112 | if idx.size == 1:
113 | idx = np.resize(idx, (1,))
114 | if flag:
115 | w = w[idx.tolist(), :, :, :].clone()
116 | else:
117 | w = m0[1].weight.data[idx.tolist(), :, :, :].clone()
118 | m1[1].weight.data = w.clone()  # NOTE(review): indentation was lost in this dump; whether lines 118-120 sit inside the output-pruning branch cannot be told from here - confirm in the original source
119 | masks.append(mask)
120 | continue
121 | else:
122 | m1[1].weight.data = m0[1].weight.data.clone()  # same shape: copy verbatim
123 | masks.append(None)
124 | elif isinstance(m0[1], nn.BatchNorm2d):
125 | assert isinstance(m1[1], nn.BatchNorm2d), "There should not be bn layer here."
126 | if m0[1].weight.data.shape!=m1[1].weight.data.shape:  # BN width changed: slice params and stats with the last mask recorded by a conv
127 | mask = masks[-1]
128 | idx = np.squeeze(np.argwhere(mask))
129 | if idx.size == 1:
130 | idx = np.resize(idx, (1,))
131 | m1[1].weight.data = m0[1].weight.data[idx.tolist()].clone()
132 | m1[1].bias.data = m0[1].bias.data[idx.tolist()].clone()
133 | m1[1].running_mean = m0[1].running_mean[idx.tolist()].clone()  # running stats are assigned directly, not through .data
134 | m1[1].running_var = m0[1].running_var[idx.tolist()].clone()
135 | continue
136 | m1[1].weight.data = m0[1].weight.data.clone()  # same width: copy affine params and running stats verbatim
137 | m1[1].bias.data = m0[1].bias.data.clone()
138 | m1[1].running_mean = m0[1].running_mean.clone()
139 | m1[1].running_var = m0[1].running_var.clone()
140 |
# noinspection PyUnresolvedReferences
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """Mean cross-entropy against label-smoothed one-hot targets.

    Args:
        pred: logits; classes along dim 1.
        target: integer class indices, one per row of ``pred``.
        label_smoothing: probability mass spread uniformly over all classes.

    Returns:
        Scalar tensor: the per-sample smoothed cross-entropy averaged over the batch.
    """
    num_classes = pred.size(1)
    one_hot = torch.zeros_like(pred)
    one_hot.scatter_(1, target.unsqueeze(1), 1)
    smoothed = one_hot * (1 - label_smoothing) + label_smoothing / num_classes
    log_probs = torch.log_softmax(pred, dim=1)
    per_sample = (-smoothed * log_probs).sum(dim=1)
    return per_sample.mean()
152 |
153 |
def count_parameters(model):
    """Return the total number of scalar parameters with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
157 |
158 |
def detach_variable(inputs):
    """Detach a tensor (or a possibly nested tuple of tensors) from the graph.

    Each detached tensor keeps the ``requires_grad`` flag of its source, so it
    becomes a fresh leaf that can still collect gradients.
    """
    if not isinstance(inputs, tuple):
        detached = inputs.detach()
        detached.requires_grad = inputs.requires_grad
        return detached
    return tuple(detach_variable(item) for item in inputs)
166 |
167 |
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, as a percentage of the batch, for each k in ``topk``.

    Args:
        output: 2-D score tensor, ranked along dim 1 (``.t()`` below requires 2-D).
        target: label tensor; ``target.size(0)`` is taken as the batch size.
        topk: iterable of k values.

    Returns:
        A list with one 1-element float tensor per k.
    """
    batch_size = target.size(0)
    _, ranked = output.topk(max(topk), 1, True, True)
    ranked = ranked.t()  # rows: rank position, columns: samples
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scores = []
    for k in topk:
        n_hits = hits[:k].reshape(-1).float().sum(0, keepdim=True)
        scores.append(n_hits.mul_(100.0 / batch_size))
    return scores
182 |
183 |
class AverageMeter(object):
    """
    Track the most recent value and the running (count-weighted) mean of a metric.
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running mean, running sum and sample count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the average over ``n`` new samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/CKA/cka.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import sys
4 |
def gram_linear(x):
    """Return the linear-kernel Gram matrix ``x x^T``.

    Args:
      x: A num_examples x num_features matrix of features.

    Returns:
      A num_examples x num_examples Gram matrix of examples.
    """
    return np.dot(x, x.T)
15 |
16 |
def gram_rbf(x, threshold=1.0):
    """Return the RBF-kernel Gram matrix of ``x``.

    The squared bandwidth is ``threshold**2`` times the median of all squared
    pairwise distances (diagonal zeros included), i.e. the median-distance
    heuristic used in the CKA paper; other bandwidth choices were not tried.

    Args:
      x: A num_examples x num_features matrix of features.
      threshold: Fraction of the median distance to use as kernel bandwidth.

    Returns:
      A num_examples x num_examples Gram matrix of examples.
    """
    inner = x.dot(x.T)
    norms_sq = np.diag(inner)
    # ||a - b||^2 = -2 <a, b> + ||a||^2 + ||b||^2 (same evaluation order as before)
    dist_sq = -2 * inner + norms_sq[:, None] + norms_sq[None, :]
    median_dist_sq = np.median(dist_sq)
    bandwidth_term = 2 * threshold ** 2 * median_dist_sq
    return np.exp(-dist_sq / bandwidth_term)
34 |
35 |
def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional)
    features induced by the kernel before computing the Gram matrix.

    Args:
      gram: A num_examples x num_examples symmetric matrix.
      unbiased: Whether to adjust the Gram matrix in order to compute an
        unbiased estimate of HSIC. Note that this estimator may be negative.

    Returns:
      A symmetric matrix with centered columns and rows (a new array).

    Raises:
      ValueError: if ``gram`` is not symmetric.
    """
    if not np.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')
    centered = gram.copy()

    if unbiased:
        # U-statistic formulation from Szekely & Rizzo (2014), "Partial distance
        # correlation with methods for dissimilarities", Ann. Statist. 42(6);
        # reportedly more numerically stable than Song et al. (2007).
        n = centered.shape[0]
        np.fill_diagonal(centered, 0)
        means = np.sum(centered, 0, dtype=np.float64) / (n - 2)
        means -= np.sum(means) / (2 * (n - 1))
    else:
        means = np.mean(centered, 0, dtype=np.float64)
        means -= np.mean(means) / 2

    centered -= means[:, None]
    centered -= means[None, :]
    if unbiased:
        np.fill_diagonal(centered, 0)
    return centered
73 |
74 |
def cka(gram_x, gram_y, debiased=False):
    """Compute CKA between two Gram matrices.

    Args:
      gram_x: A num_examples x num_examples Gram matrix.
      gram_y: A num_examples x num_examples Gram matrix.
      debiased: Use unbiased estimator of HSIC. CKA may still be biased.

    Returns:
      The value of CKA between X and Y.
    """
    centered_x = center_gram(gram_x, unbiased=debiased)
    centered_y = center_gram(gram_y, unbiased=debiased)

    # HSIC up to a factor of (n-1)**2 (biased) or n*(n-3) (unbiased); the
    # factor cancels in the normalized ratio, so it is never applied.
    scaled_hsic = np.dot(centered_x.ravel(), centered_y.ravel())
    denominator = np.linalg.norm(centered_x) * np.linalg.norm(centered_y)
    return scaled_hsic / denominator
96 |
97 |
98 | def _debiased_dot_product_similarity_helper(
99 | xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y,
100 | n):
101 | """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
102 | # This formula can be derived by manipulating the unbiased estimator from
103 | # Song et al. (2007).
104 | return (
105 | xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
106 | + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))
107 |
108 |
109 | # def feature_space_linear_cka(features_x, features_y, debiased=False):
110 | # """Compute CKA with a linear kernel, in feature space.
111 | #
112 | # This is typically faster than computing the Gram matrix when there are fewer
113 | # features than examples.
114 | #
115 | # Args:
116 | # features_x: A num_examples x num_features matrix of features.
117 | # features_y: A num_examples x num_features matrix of features.
118 | # debiased: Use unbiased estimator of dot product similarity. CKA may still be
119 | # biased. Note that this estimator may be negative.
120 | #
121 | # Returns:
122 | # The value of CKA between X and Y.
123 | # """
124 | # features_x = features_x - torch.mean(features_x, 0, keepdim=True)
125 | # features_y = features_y - torch.mean(features_y, 0, keepdim=True)
126 | #
127 | # a = torch.mm(features_x.t(), features_y)
128 | # b = torch.mm(features_x.t(), features_x)
129 | # c = torch.mm(features_y.t(), features_y)
130 | # dot_product_similarity = torch.linalg.norm(a) ** 2
131 | # normalization_x = torch.linalg.norm(b)
132 | # normalization_y = torch.linalg.norm(c)
133 | #
134 | # if debiased:
135 | # n = features_x.shape[0]
136 | # # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.
137 | # sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)
138 | # sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)
139 | # squared_norm_x = np.sum(sum_squared_rows_x)
140 | # squared_norm_y = np.sum(sum_squared_rows_y)
141 | #
142 | # dot_product_similarity = _debiased_dot_product_similarity_helper(
143 | # dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
144 | # squared_norm_x, squared_norm_y, n)
145 | # normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(
146 | # normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
147 | # squared_norm_x, squared_norm_x, n))
148 | # normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(
149 | # normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
150 | # squared_norm_y, squared_norm_y, n))
151 | #
152 | # return dot_product_similarity / (normalization_x * normalization_y)
def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Compute linear-kernel CKA directly in feature space.

    This is typically faster than building Gram matrices when there are fewer
    features than examples.

    Args:
      features_x: A num_examples x num_features matrix of features.
      features_y: A num_examples x num_features matrix of features.
      debiased: Use unbiased estimator of dot product similarity. CKA may still
        be biased. Note that this estimator may be negative.

    Returns:
      The value of CKA between X and Y.
    """
    x = features_x - np.mean(features_x, 0, keepdims=True)
    y = features_y - np.mean(features_y, 0, keepdims=True)

    similarity = np.linalg.norm(x.T.dot(y)) ** 2
    norm_x = np.linalg.norm(x.T.dot(x))
    norm_y = np.linalg.norm(y.T.dot(y))

    if debiased:
        n = x.shape[0]
        # Row-wise squared norms, same as np.sum(x ** 2, 1) without the temporary.
        row_sq_x = np.einsum('ij,ij->i', x, x)
        row_sq_y = np.einsum('ij,ij->i', y, y)
        total_sq_x = np.sum(row_sq_x)
        total_sq_y = np.sum(row_sq_y)

        similarity = _debiased_dot_product_similarity_helper(
            similarity, row_sq_x, row_sq_y, total_sq_x, total_sq_y, n)
        norm_x = np.sqrt(_debiased_dot_product_similarity_helper(
            norm_x ** 2, row_sq_x, row_sq_x, total_sq_x, total_sq_x, n))
        norm_y = np.sqrt(_debiased_dot_product_similarity_helper(
            norm_y ** 2, row_sq_y, row_sq_y, total_sq_y, total_sq_y, n))

    return similarity / (norm_x * norm_y)
--------------------------------------------------------------------------------
/Cifar/nets/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | import torch.nn.functional as F
5 | from nets.base_models import *
6 | import torch.nn.init as init
7 | import torch.utils.model_zoo as model_zoo
8 |
9 | def _weights_init(m):
10 | classname = m.__class__.__name__
11 | # print(classname)
12 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
13 | init.kaiming_normal(m.weight)
14 |
15 |
class LambdaLayer(nn.Module):
    """Parameter-free module that applies a stored callable in ``forward``.

    Lets an ad-hoc tensor function (such as BasicBlock's padding shortcut)
    sit inside a module tree.
    """

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
23 |
24 |
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus a parameter-free shortcut.

    When ``stride != 1`` or the channel count changes, the shortcut takes every
    other pixel (only if ``stride != 1``) and pads the channel dimension with
    zeros, split as evenly as possible between front and back, so its width
    equals ``planes``.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, affine=True):
        super(BasicBlock, self).__init__()
        # Creation order (conv1, bn1, conv2, bn2) is kept so RNG-based init matches.
        self.conv_bn1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)),
            ('bn', nn.BatchNorm2d(planes, affine=affine)),
        ]))
        self.conv_bn2 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)),
            ('bn', nn.BatchNorm2d(planes, affine=affine)),
        ]))

        self.shortcut = nn.Sequential()
        reshape_needed = stride != 1 or in_planes != planes
        if reshape_needed:
            step = 2 if stride != 1 else 1
            pad_front = (planes - in_planes) // 2
            pad_back = planes - in_planes - pad_front
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::step, ::step],
                                (0, 0, 0, 0, pad_front, pad_back), "constant", 0))

    def forward(self, x):
        out = self.conv_bn2(F.relu(self.conv_bn1(x)))
        out += self.shortcut(x)
        return F.relu(out)
55 |
56 |
class ResNet_CIFAR(MyNetwork):
    """CIFAR ResNet of depth 6n + 2 whose per-block widths are given by ``cfg``.

    Args:
        depth: network depth; ``(depth - 2)`` must be divisible by 6.
        num_classes: number of outputs of the final linear layer.
        cfg: one output-channel count per BasicBlock, listed stage by stage
            (``n = (depth - 2) // 6`` entries per stage, three stages).
            Defaults to ``cfg_base`` = ``[16] * n + [32] * n + [64] * n``.
        cutout: if True, ``cutout_batch(x, 16)`` is applied to inputs in
            training mode.
    """

    def __init__(self, depth=20, num_classes=10, cfg=None, cutout=False):
        super(ResNet_CIFAR, self).__init__()
        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        cfg_base = []
        n = (depth-2)//6
        for i in [16, 32, 64]:
            for j in range(n):
                cfg_base.append(i)

        if cfg is None:
            cfg = cfg_base
        # Each of the three stages holds n blocks. This reproduces the former
        # per-depth table (20 -> 3, 32 -> 5, 44 -> 7, 56 -> 9, 110 -> 18) and
        # also covers every other 6n+2 depth, for which the table left
        # num_blocks empty and layer construction crashed with IndexError.
        num_blocks = [n, n, n]
        block = BasicBlock
        self.cfg_base = cfg_base
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.cutout = cutout
        self.cfg = cfg
        self.in_planes = 16
        conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        bn1 = nn.BatchNorm2d(16)
        self.conv_bn = nn.Sequential(OrderedDict([('conv', conv1), ('bn', bn1)]))
        self.layer1 = self._make_layer(block, cfg[0:n], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, cfg[n:2*n], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, cfg[2*n:], num_blocks[2], stride=2)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(cfg[-1], num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks with widths ``planes``.

        Only the first block uses ``stride``. ``self.in_planes`` is advanced so
        the next stage starts from this stage's last width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for i in range(len(strides)):
            layers.append(('block_%d'%i, block(self.in_planes, planes[i], strides[i])))
            self.in_planes = planes[i]
        return nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        if self.training and self.cutout:
            with torch.no_grad():
                x = cutout_batch(x, 16)
        out = F.relu(self.conv_bn(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def feature_extract(self, x):
        """Return the output of every BasicBlock, in network order, as a list."""
        if self.training and self.cutout:
            with torch.no_grad():
                x = cutout_batch(x, 16)
        tensor = []
        out = F.relu(self.conv_bn(x))
        for i in [self.layer1, self.layer2, self.layer3]:
            for _layer in i:
                out = _layer(out)
                if type(_layer) is BasicBlock:
                    tensor.append(out)

        return tensor

    def cfg2params(self, cfg):
        """Estimate the parameter count of this architecture built with ``cfg``."""
        params = 0
        params += (3 * 3 * 3 * 16 + 16 * 2) # conv1+bn1
        in_c = 16
        cfg_idx = 0
        for i in range(3):
            num_blocks = self.num_blocks[i]
            for j in range(num_blocks):
                c = cfg[cfg_idx]
                params += (in_c * 3 * 3 * c + 2 * c + c * 3 * 3 * c + 2 * c) # per block params
                if in_c != c:
                    # NOTE(review): BasicBlock's shortcut is a parameter-free
                    # LambdaLayer (zero padding), so this term has no real
                    # counterpart; kept to preserve existing estimates.
                    params += in_c * c # shortcut
                in_c = c
                cfg_idx += 1
        # fc layer: the model builds nn.Linear(cfg[-1], num_classes), so the
        # input width is the last entry of the *given* cfg (was self.cfg[-1]).
        params += (cfg[-1] + 1) * self.num_classes
        return params

    def cfg2flops(self, cfg):  # to simplify, only count convolution flops
        """Estimate FLOPs of this architecture built with ``cfg`` (32x32 input).

        Counts 3x3 convolutions, 4 ops per element for each BN, the channel
        shortcut term and the fc layer.
        """
        size = 32
        flops = 0
        flops += (3 * 3 * 3 * 16 * 32 * 32 + 16 * 32 * 32 * 4) # conv1+bn1
        in_c = 16
        cfg_idx = 0
        for i in range(3):
            num_blocks = self.num_blocks[i]
            if i==1 or i==2:
                size = size // 2
            for j in range(num_blocks):
                c = cfg[cfg_idx]
                flops += (in_c * 3 * 3 * c * size * size + c * size * size * 4 + c * 3 * 3 * c * size * size + c * size * size * 4) # per block flops
                if in_c != c:
                    # NOTE(review): the padding shortcut performs no multiply-adds;
                    # term kept so FLOPs targets calibrated on it stay valid.
                    flops += in_c * c * size * size # shortcut
                in_c = c
                cfg_idx += 1
        # fc layer sized by the given cfg, consistent with every other term
        # (previously used self.cfg[-1], ignoring the candidate config).
        flops += (2 * cfg[-1] + 1) * self.num_classes
        return flops

    def cfg2flops_perlayer(self, cfg, length):  # to simplify, only count convolution flops
        """Split the FLOPs estimate for ``cfg`` into per-layer terms.

        Returns:
            (flops_singlecfg, flops_doublecfg, flops_squarecfg): terms involving
            one cfg entry, a pair of adjacent entries (stored symmetrically in a
            ``length x length`` array), and the square of one entry.
        """
        size = 32
        flops_singlecfg = [0 for j in range(length)]
        flops_doublecfg = np.zeros((length, length))
        flops_squarecfg = [0 for j in range(length)]

        in_c = 16
        cfg_idx = 0
        for i in range(3):
            num_blocks = self.num_blocks[i]
            if i==1 or i==2:
                size = size // 2
            for j in range(num_blocks):
                c = cfg[cfg_idx]
                if i==0 and j==0:
                    flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4 + in_c * 3 * 3 * c * size * size)
                    flops_squarecfg[cfg_idx] += c * 3 * 3 * c * size * size
                else:
                    flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4)
                    flops_doublecfg[cfg_idx-1][cfg_idx] += in_c * 3 * 3 * c * size * size
                    flops_doublecfg[cfg_idx][cfg_idx-1] += in_c * 3 * 3 * c * size * size
                    flops_squarecfg[cfg_idx] += (c * 3 * 3 * c * size * size )
                    # NOTE(review): indentation was lost in the source dump; this
                    # shortcut term is placed inside the else-branch so the first
                    # block never indexes cfg_idx-1 == -1. Confirm against the
                    # original file.
                    if in_c != c:
                        flops_doublecfg[cfg_idx][cfg_idx-1] += in_c * c * size * size # shortcut
                        flops_doublecfg[cfg_idx-1][cfg_idx] += in_c * c * size * size
                in_c = c
                cfg_idx += 1

        # fc layer term uses the given cfg, like every other term here.
        flops_singlecfg[-1] += 2 * cfg[-1] * self.num_classes
        return flops_singlecfg, flops_doublecfg, flops_squarecfg

    @property
    def config(self):
        return {
            'name': self.__class__.__name__,
            'cfg': self.cfg,
            'cfg_base': self.cfg_base,
            'dataset': 'cifar10',
        }
--------------------------------------------------------------------------------
/Cifar/utils/utils.py:
--------------------------------------------------------------------------------
1 | """ This file stores some common functions for learners """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import pdb
6 | import math
7 | import numpy as np
8 | import logging, sys
9 | import os
10 | import torch.distributed as dist
11 | import torch.nn.functional as F
12 | from collections import defaultdict
13 |
14 |
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    The loss is the per-class negative log-likelihood averaged over the batch
    (dim 0) and then summed over classes.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        per_class = -smoothed * log_probs
        return per_class.mean(0).sum()
28 |
29 |
class AverageMeter(object):
    """Track a named metric's latest value and running (count-weighted) mean.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"`` using the ``fmt`` spec
    (e.g. ``':.4e'``) for both numbers.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
52 |
53 |
class ProgressMeter(object):
    """ Code taken from torchvision """

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def show(self, batch):
        """Log ``prefix[batch/total]`` followed by every meter's string form."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        logging.info(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``'[{:Nd}/total]'`` template padded to the width of ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))
71 |
72 |
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k
    Code taken from torchvision.

    Returns one 1-element tensor per k: the percentage of samples whose label
    is among the k highest-scoring classes (scores ranked along dim 1).
    """
    with torch.no_grad():
        batch_size = target.size(0)
        _, top_idx = output.topk(max(topk), 1, True, True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]
91 |
def accuracy_seen_unseen(output, target, seen_classes, topk=(1,)):
    """ Compute the acc of seen classes and unseen classes separately.

    A sample is "seen" when its label lies in ``range(seen_classes)``; every
    other sample is "unseen".

    Args:
        output: 2-D score tensor, classes ranked along dim 1.
        target: integer label tensor, one entry per row of ``output``.
        seen_classes: number of seen classes (labels ``0 .. seen_classes - 1``).
        topk: iterable of k values. The old default ``(1.)`` was the float 1.0,
            not a tuple, so calls relying on it raised TypeError in ``max``.

    Returns:
        (res, seen_accs, unseen_accs, seen_num): three lists of 1-element
        percentage tensors (overall / seen / unseen, one per k) and the number
        of seen samples in the batch.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # Boolean masks on target's own device. The previous version built
        # uint8 masks (deprecated for indexing) and forced them onto CUDA.
        seen_ind = (target >= 0) & (target < seen_classes)
        unseen_ind = ~seen_ind

        seen_correct = correct[:, seen_ind]
        unseen_correct = correct[:, unseen_ind]

        seen_num = seen_correct.shape[1]
        res = []
        seen_accs = []
        unseen_accs = []
        for k in topk:
            # reshape, not view: slices of the transposed `correct` need not be
            # contiguous (same fix as in `accuracy`).
            seen_correct_k = seen_correct[:k].reshape(-1).float().sum(0, keepdim=True)
            seen_accs.append(seen_correct_k.mul_(100.0 / (seen_num+1e-10)))
            unseen_correct_k = unseen_correct[:k].reshape(-1).float().sum(0, keepdim=True)
            unseen_accs.append(unseen_correct_k.mul_(100.0 / (batch_size-seen_num+1e-10)))
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

    return res, seen_accs, unseen_accs, seen_num
125 |
def is_last_layer(layer):
    """Heuristically detect a classifier head.

    True when ``layer.weight`` is 2-D with 10, 100 or 1000 output rows.
    """
    weight = layer.weight
    return weight.ndimension() == 2 and weight.shape[0] in (10, 100, 1000)
132 |
133 |
def is_first_layer(layer):
    """Heuristically detect an input conv: a Conv2d with 3 or 1 input channels."""
    return (isinstance(layer, nn.Conv2d)
            and layer.weight.ndimension() == 4
            and layer.weight.shape[1] in (3, 1))
143 |
144 |
def reinitialize_conv_weights(model, init_first=False):
    """ Only re-initialize the kernels for conv layers

    Every 4-D parameter is re-drawn with ``kaiming_uniform_(a=sqrt(5))``.
    Kernels with 3 input channels (taken to be the first conv) are skipped
    unless ``init_first`` is True. Note the status message is printed before
    the work is done.
    """
    print("re-intializing conv weights done. evaluating...")
    conv_kernels = (w for w in model.parameters() if w.ndimension() == 4)
    for kernel in conv_kernels:
        looks_like_first = kernel.shape[1] == 3
        if looks_like_first and not init_first:
            continue  # do not init first conv layer
        nn.init.kaiming_uniform_(kernel, a=math.sqrt(5))
153 |
154 |
def weights_init(m):
    """ default param initializer in pytorch.

    Conv2d and Linear modules get weight (then bias, if present) drawn from
    U(-1/sqrt(fan_in), 1/sqrt(fan_in)); all other modules are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        fan_in = m.in_channels
        for k in m.kernel_size:
            fan_in *= k
    elif isinstance(m, nn.Linear):
        fan_in = m.weight.size(1)
    else:
        return
    bound = 1. / math.sqrt(fan_in)
    m.weight.data.uniform_(-bound, bound)
    if m.bias is not None:
        m.bias.data.uniform_(-bound, bound)
170 |
171 |
class keydefaultdict(defaultdict):
    """A defaultdict whose factory is called with the missing key.

    The produced value is stored under that key and returned. Raises KeyError
    when no factory is set.
    """

    def __missing__(self, key):
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        value = factory(key)
        self[key] = value
        return value
179 |
180 |
class WarmUpCosineLRScheduler:
    """
    update lr every step

    For ``this_iter < warmup_steps`` the target LR rises linearly from
    ``base_lr`` towards ``warmup_lr``; afterwards it follows a half cosine from
    ``warmup_lr`` down to ``eta_min``, reached at ``T_max``. Each param group's
    LR is set to ``initial_lr * target / base_lr`` (so it presumably expects
    ``initial_lr == base_lr`` - confirm with callers).
    """

    def __init__(self, optimizer, T_max, eta_min, base_lr, warmup_lr, warmup_steps, last_iter=-1):
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        groups = optimizer.param_groups
        if last_iter == -1:
            # fresh start: remember every group's starting lr
            for group in groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # resuming: the optimizer must already carry initial_lr
            for i, group in enumerate(groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in groups]
        self.last_iter = last_iter

        assert warmup_steps < T_max
        self.T_max = T_max
        self.eta_min = eta_min

        # warmup settings (attribute names are persisted by state_dict)
        self.base_lr = base_lr
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_steps
        self.warmup_k = None  # warm-up slope, computed on the first warm-up step

    def state_dict(self):
        return {name: value for name, value in vars(self).items() if name != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

    def step(self, this_iter=None):
        if this_iter is None:
            this_iter = self.last_iter + 1
        self.last_iter = this_iter

        in_warmup = self.warmup_step > 0 and this_iter < self.warmup_step
        if in_warmup:
            if self.warmup_k is None:
                self.warmup_k = (self.warmup_lr - self.base_lr) / self.warmup_step
            target_lr = self.warmup_k * this_iter + self.base_lr
        else:
            progress = (this_iter - self.warmup_step) / (self.T_max - self.warmup_step)
            target_lr = self.eta_min + (self.warmup_lr - self.eta_min) * (1 + math.cos(math.pi * progress)) / 2
        scale = target_lr / self.base_lr

        for group, initial_lr in zip(self.optimizer.param_groups, self.base_lrs):
            group['lr'] = scale * initial_lr
242 |
--------------------------------------------------------------------------------
/Cifar/learners/basic_learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import logging
5 | import os
6 | import pdb
7 | import torch.distributed as dist
8 | from tensorboardX import SummaryWriter
9 | from utils.utils import *
10 | from utils.dist_utils import *
11 | from timeit import default_timer as timer
12 | import time
13 | import copy
14 |
15 |
16 | class BasicLearner(object):
17 | """ Performs vanilla training """
18 | def __init__(self, model, loaders, args, device):
19 | self.args = args
20 | self.device = device
21 | self.model = model
22 | self.__build_path()
23 | self.train_loader, self.test_loader = loaders
24 | self.setup_optim()
25 | self.criterion = nn.CrossEntropyLoss().cuda()
26 | if self.check_is_primary():
27 | self.writer = SummaryWriter(os.path.dirname(self.save_path))
28 | # self.add_graph()
29 |
30 | def train(self, train_sampler=None):
31 | self.warm_up_lr()
32 | for epoch in range(self.args.epochs):
33 | if self.args.distributed:
34 | assert train_sampler != None
35 | train_sampler.set_epoch(epoch)
36 |
37 | self.model.train()
38 | if self.check_is_primary():
39 | logging.info("Training at Epoch: %d" % epoch)
40 | train_acc, train_loss = self.epoch(True)
41 |
42 | if self.check_is_primary():
43 | self.writer.add_scalar('train_acc', train_acc, epoch)
44 | self.writer.add_scalar('train_loss', train_loss, epoch)
45 |
46 | if self.lr_scheduler:
47 | self.lr_scheduler.step()
48 |
49 | if (epoch+1) % self.args.eval_epoch == 0:
50 | # evaluate every GPU, but we only show the results on a single one.!
51 | if self.check_is_primary():
52 | logging.info("Evaluation at Epoch: %d" % epoch)
53 | self.evaluate(True, epoch)
54 |
55 | if self.check_is_primary():
56 | self.save_model()
57 |
58 | def evaluate(self, is_train=False, epoch=None):
59 | self.model.eval()
60 | # NOTE: syncronizing the BN statistics
61 | if self.args.distributed:
62 | sync_bn_stat(self.model, self.args.world_size)
63 |
64 | if not is_train:
65 | self.load_model()
66 |
67 | with torch.no_grad():
68 | test_acc, test_loss = self.epoch(False)
69 |
70 | if is_train and epoch and self.check_is_primary():
71 | self.writer.add_scalar('test_acc', test_acc, epoch)
72 | self.writer.add_scalar('test_loss', test_loss, epoch)
73 | return test_acc, test_loss
74 |
75 | def finetune(self, train_sampler):
76 | self.load_model()
77 | self.evaluate()
78 |
79 | for epoch in range(self.args.epochs):
80 | if self.args.distributed:
81 | assert train_sampler != None
82 | train_sampler.set_epoch(epoch)
83 |
84 | self.model.train()
85 |
86 | # NOTE: use the preset learning rate for all epochs.
87 | ft_acc, ft_loss = self.epoch(True)
88 |
89 | if self.check_is_primary():
90 | self.writer.add_scalar('ft_acc', ft_acc, epoch)
91 | self.writer.add_scalar('ft_loss', ft_loss, epoch)
92 |
93 | # evaluate every k step
94 | if (epoch+1) % self.args.eval_epoch == 0:
95 | if self.check_is_primary():
96 | logging.info("Evaluation at Epoch: %d" % epoch)
97 | self.evaluate(True, epoch)
98 |
99 | # save the model
100 | if self.check_is_primary():
101 | self.save_model()
102 |
103 | def misc(self):
104 | raise NotImplementedError("Misc functions are implemented in sub classes")
105 |
def epoch(self, is_train):
    """Run one pass over the train or test loader and return averaged metrics.

    Rewrite this function if necessary in the sub-classes.

    Args:
        is_train: if True, iterate ``self.train_loader`` and take an optimizer
            step on every batch. Otherwise iterate ``self.test_loader`` and only
            compute metrics (``evaluate`` wraps that call in ``torch.no_grad()``).

    Returns:
        ``(top1.avg, losses.avg)``: running averages, weighted by the local
        batch size, of the cross-rank-averaged top-1 accuracy and loss.
    """

    loader = self.train_loader if is_train else self.test_loader

    # setup statistics
    batch_time = AverageMeter('Time', ':3.3f')
    # data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':3.3f')
    top5 = AverageMeter('Acc@5', ':3.3f')
    metrics = [batch_time, top1, top5, losses]

    loader_len = len(loader)
    progress = ProgressMeter(loader_len, *metrics, prefix='Job id: %s, ' % self.args.job_id)
    end = time.time()

    for idx, (X, y) in enumerate(loader):

        # data_time.update(time.time() - end)
        X, y = X.to(self.device), y.to(self.device)
        yp = self.model(X)
        # Pre-divide by world_size: after the (default SUM) all-reduce below,
        # the logged loss is the mean over ranks. Note that this same scaled
        # loss is also the one backpropagated.
        loss = self.criterion(yp, y) / self.args.world_size

        acc1, acc5 = accuracy(yp, y, topk=(1, 5))

        # Work on copies: dist.all_reduce operates in place, and ``loss`` is
        # still needed (unreduced) for backward().
        reduced_loss = loss.data.clone()
        reduced_acc1 = acc1.clone() / self.args.world_size
        reduced_acc5 = acc5.clone() / self.args.world_size

        if self.args.distributed:
            # Collectives: every rank must issue these in the same order.
            dist.all_reduce(reduced_loss)
            dist.all_reduce(reduced_acc1)
            dist.all_reduce(reduced_acc5)

        if is_train:
            self.opt.zero_grad()
            loss.backward()
            if self.args.distributed:
                # NOTE(review): the loss was already divided by world_size. If
                # average_gradients (defined elsewhere) also divides, gradients
                # end up scaled by 1/world_size**2. Confirm it sums.
                average_gradients(self.model) # NOTE: important
            self.opt.step()

        # update statistics
        top1.update(reduced_acc1[0].item(), X.shape[0])
        top5.update(reduced_acc5[0].item(), X.shape[0])
        losses.update(reduced_loss.item(), X.shape[0])
        batch_time.update(time.time() - end)
        end = time.time()

        # show the training/evaluating statistics every print_freq batches
        # and on the last batch of the loader
        if self.check_is_primary() and ((idx % self.args.print_freq == 0) or (idx+1) % loader_len == 0):
            progress.show(idx)

    return top1.avg, losses.avg
160 |
def setup_optim(self):
    """Create ``self.opt`` (SGD) and ``self.lr_scheduler`` for the model family.

    The family is selected by the prefix of ``args.model_type``:
      * ``model_*``            -> MultiStepLR at 50% / 75% of the epochs
      * ``resnet_*``/``vgg_*`` -> CosineAnnealingLR or MultiStepLR (50% / 75%),
                                  chosen by ``args.lr_decy_type``
      * ``wideresnet*``        -> MultiStepLR at 30% / 60% / 80%, gamma=0.2
      * ``mobilenet_v2*``      -> CosineAnnealingLR
    Every family uses the same SGD hyper-parameters taken from ``args``.

    NOTE(review): the cosine schedules use ``eta_min=0``; ``args.lr_min`` is
    not read here.

    Raises:
        ValueError: for an unknown model family (before ``self.opt`` is set),
            or, for resnet/vgg, an unknown ``lr_decy_type`` (after ``self.opt``
            has been assigned).
    """
    args = self.args
    family = args.model_type
    supported = ('model_', 'resnet_', 'vgg_', 'wideresnet', 'mobilenet_v2')
    if not family.startswith(supported):
        raise ValueError("Unknown model, failed to initalize optim")

    self.opt = optim.SGD(self.model.parameters(), lr=args.lr,
                         momentum=args.momentum, nesterov=args.nesterov,
                         weight_decay=args.weight_decay)

    def milestones_at(*fractions):
        # Epoch indices at the given fractions of the training run.
        return [int(args.epochs * frac) for frac in fractions]

    schedulers = optim.lr_scheduler
    if family.startswith('model_'):
        self.lr_scheduler = schedulers.MultiStepLR(
            self.opt, milestones=milestones_at(0.5, 0.75))
    elif family.startswith(('resnet_', 'vgg_')):
        decay = args.lr_decy_type
        if decay == 'cosine':
            self.lr_scheduler = schedulers.CosineAnnealingLR(self.opt, args.epochs, eta_min=0)
        elif decay == 'multi_step':
            self.lr_scheduler = schedulers.MultiStepLR(
                self.opt, milestones=milestones_at(0.5, 0.75))
        else:
            raise ValueError("Unknown decy type")
    elif family.startswith('wideresnet'):
        self.lr_scheduler = schedulers.MultiStepLR(
            self.opt, milestones=milestones_at(0.3, 0.6, 0.8), gamma=0.2)
    else:
        # mobilenet_v2 default: 150 epochs, 5e-2 lr with cosine, 4e-5 wd, no dropout
        self.lr_scheduler = schedulers.CosineAnnealingLR(self.opt, args.epochs, eta_min=0)
193 |
def add_graph(self, input_shape=(3, 32, 32)):
    """Trace ``self.model`` with a random batch and write its graph to ``self.writer``.

    Args:
        input_shape: per-sample input shape ``(C, H, W)``. It defaults to the
            CIFAR 3x32x32 layout. The dummy batch has ``self.args.batch_size``
            samples.

    The dummy batch is created on ``self.device``. A CPU tensor fed to a model
    that was moved with ``.cuda()`` makes tracing fail with a device mismatch.
    The writer is flushed rather than closed, because the same writer is
    reused for later ``add_scalar`` calls.
    """
    x = torch.randn(self.args.batch_size, *input_shape, device=self.device)
    self.writer.add_graph(self.model, (x,))
    self.writer.flush()
199 |
def __build_path(self):
    """Derive ``self.load_path`` / ``self.save_path`` from ``args.exec_mode``.

    * train:    both point to ``<save_path>/<model_type>_<learner>/<job_id>/model.pt``
    * finetune: load from ``args.load_path``; save ``model_ft.pt`` next to it
    * other:    load from and save to ``args.load_path``
    """
    mode = self.args.exec_mode
    if mode == 'train':
        run_dir = os.path.join(self.args.save_path,
                               '_'.join([self.args.model_type, self.args.learner]),
                               self.args.job_id)
        self.save_path = os.path.join(run_dir, 'model.pt')
        self.load_path = self.save_path
        return

    self.load_path = self.args.load_path
    if mode == 'finetune':
        self.save_path = os.path.join(os.path.dirname(self.load_path), 'model_ft.pt')
    else:
        self.save_path = self.load_path
210 |
def warm_up_lr(self):
    """Set every param group's ``lr`` to ``0.1 * args.lr`` for ``*110`` / ``*1202`` models.

    Models whose ``model_type`` ends in ``110`` or ``1202`` (presumably
    ResNet-110 / ResNet-1202) get the reduced rate. All other models are
    left untouched.
    """
    if not self.args.model_type.endswith(('1202', '110')):
        return
    warm_lr = 0.1 * self.args.lr
    for group in self.opt.param_groups:
        group['lr'] = warm_lr
217 |
def check_is_primary(self):
    """Return True on the process responsible for logging and checkpointing.

    This is every non-distributed run, or rank 0 of a distributed run.
    """
    if not self.args.distributed:
        return True
    return bool(self.args.rank == 0)
224 |
def save_model(self):
    """Write the model and optimizer state dicts to ``self.save_path``.

    The checkpoint is a dict with keys ``'state_dict'`` and ``'optimizer'``,
    which is the layout ``load_model`` reads back.
    """
    checkpoint = dict(state_dict=self.model.state_dict(),
                      optimizer=self.opt.state_dict())
    torch.save(checkpoint, self.save_path)
    logging.info("Model stored at: " + self.save_path)
230 |
def load_model(self):
    """Restore model and optimizer state from the checkpoint at ``self.load_path``.

    In distributed mode, the checkpoint tensors are mapped onto this process's
    current CUDA device. Afterwards, the model parameters are passed to
    ``broadcast_params`` (defined elsewhere), presumably so all ranks hold
    identical weights.
    """
    distributed = self.args.distributed
    map_location = None
    if distributed:
        # read parameters to each GPU separately
        map_location = 'cuda:{}'.format(torch.cuda.current_device())
    checkpoint = torch.load(self.load_path, map_location=map_location)

    self.model.load_state_dict(checkpoint['state_dict'])
    self.opt.load_state_dict(checkpoint['optimizer'])
    logging.info("Model succesfully restored from %s" % self.load_path)

    if distributed:
        broadcast_params(self.model)
245 |
246 |
--------------------------------------------------------------------------------
/Cifar/main.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | import argparse
3 | import torch
4 | import torch.backends.cudnn as cudnn
5 | from nets.model_helper import ModelHelper
6 | from learners.basic_learner import BasicLearner
7 | from data_loader import *
8 | from utils.helper import *
9 | import torch.multiprocessing as mp
10 | import torch.distributed as dist
11 | from utils.dist_utils import *
12 | import os
13 | import numpy as np
14 | import random
15 |
# Command-line interface for the CIFAR entry point. The registration order
# below determines the order of options in ``--help``.
parser = argparse.ArgumentParser()
parser.add_argument('--fix_random', action='store_true', help='fix randomness')
parser.add_argument('--seed', default=1234, type=int)
parser.add_argument('--gpu_id', default='0', help='gpu id')
parser.add_argument('--job_id', default='', help='generated automatically')
parser.add_argument('--exec_mode', default='', choices=['train', 'eval', 'finetune', 'misc'])
parser.add_argument('--print_freq', default=30, type=int, help='training printing frequency')
parser.add_argument('--model_type', default='', help='what model do you want to use')
parser.add_argument('--data_path', default='TODO', help='path to the dataset')
# NOTE(review): store_false means args.data_aug defaults to True and passing
# --data_aug turns augmentation OFF; the help text reads as if it enables it.
parser.add_argument('--data_aug', action='store_false', help='whether to use data augmentation')
parser.add_argument('--save_path', default='./models', help='save path for models')
parser.add_argument('--load_path', default='', help='path to load the model')
parser.add_argument('--load_path_log', default='record.log', help='suffix to name record.log')
parser.add_argument('--eval_epoch', type=int, default=1, help='perform eval at intervals')

# multi-process gpu
# argparse exposes dashed option names with underscores
# (args.dist_url, args.world_size, args.dist_backend).
parser.add_argument('--dist-url', default='tcp://127.0.0.1', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--num_workers', type=int, default=4, help='num of workers data loaders')
parser.add_argument('--port', default=6669, type=int, help='port of server')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=0, type=int,
                    help='node rank for distributed training')
# No default: args.local_rank is None unless given.
parser.add_argument('--local_rank', type=int,
                    help='local node rank for distributed learning (useless for now)')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
# NOTE: main() overwrites args.distributed from the visible GPU count. Also,
# type=bool turns any non-empty string (even "False") into True.
parser.add_argument('--distributed', default=False, type=bool, help='use distributed or not. Do NOT SET THIS')

# learner config
parser_learner = parser.add_argument_group('Learner')
# Only 'vanilla' is handled by set_loader / set_learner in this script.
parser_learner.add_argument('--learner', default='', choices=['vanilla', 'trick', 'chann_cifar', 'chann_ilsvrc']) #'chann_search' is excluded for simplicity
parser_learner.add_argument('--lr', type=float, default=1e-1, help='initial learning rate')
parser_learner.add_argument('--lr_decy_type', default='multi_step', choices=['multi_step', 'cosine'], help='lr decy type for model params')
# NOTE(review): BasicLearner.setup_optim builds its cosine schedules with
# eta_min=0, so this value is not used there.
parser_learner.add_argument('--lr_min', type=float, default=2e-4, help='minimal learning rate for cosine decy')
parser_learner.add_argument('--momentum', type=float, default=0.9, help='momentum value')
parser_learner.add_argument('--nesterov', action='store_true')
parser_learner.add_argument('--weight_decay', type=float, default=2e-4, help='weight decay for resnet')
parser_learner.add_argument('--dropout_rate', default=0.0, type=float, help='dropout rate for mobilenet_v2, default: 0')

# NOTE(review): registered on the top-level parser, so --gamma is listed in the
# general section of --help rather than under 'Learner'.
parser.add_argument('--gamma', type=float, default=0.99, help='time decay for baseline update')
parser_learner.add_argument('--flops', action='store_true', help='add flops regularization')
parser_learner.add_argument('--orthg_weight', type=float, default=0.001, help='orthognal regularization weight')
parser_learner.add_argument('--ft_schedual' , default='', choices=['follow_meta','fixed'])
parser_learner.add_argument('--ft_proj_lr' , type=float, default=0.001, help='lr in fixed ft_schedual')
parser_learner.add_argument('--norm_constraint' , default='', choices=['regularization','constraint'])
parser_learner.add_argument('--norm_weight' , type=float, default=0.01, help='norm 1 regularization weight')
parser_learner.add_argument('--updt_proj', type=int, default=10, help='intervals to update orthogonal loss')
parser_learner.add_argument('--ortho_type', default='l1', choices=['l1', 'l2'], help='type of norm for orthogonal regularization')

parser_arc = parser.add_argument_group('Architecture')
parser_arc.add_argument('--cfg', default='', help='decoding cfgs obtained by one-shot model')

# dataset params
parser_dataset = parser.add_argument_group('Dataset')
# Only 'cifar10' is accepted by set_loader in this script.
parser_dataset.add_argument('--dataset', choices=['cifar10', 'cifar100', 'ilsvrc_12'], default='cifar10', help='dataset to use')
parser_dataset.add_argument('--batch_size', type=int, default=128, help='batch_size for dataset')
parser_dataset.add_argument('--epochs', type=int, default=200, help='resnet:200')
parser_dataset.add_argument('--warmup_epochs', type=int, default=-1, help='epochs training weights only')
parser_dataset.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser_dataset.add_argument('--total_portion', type=float, default=1.0, help='portion of data will be used in searching')

# trick params
parser_trick = parser.add_argument_group("Tricks")
parser_trick.add_argument('--mixup', action='store_true')
parser_trick.add_argument('--mixup_alpha', type=float, default=1.0)
parser_trick.add_argument('--label_smooth', action='store_true')
parser_trick.add_argument('--label_smooth_eps', type=float, default=0.1)
parser_trick.add_argument('--se', action='store_true')
parser_trick.add_argument('--se_reduction', type=int, default=-1)
87 |
def set_loader(args):
    """Build the data loaders for ``args.dataset`` / ``args.learner``.

    Returns:
        ``(loaders, sampler)``: the value returned by ``cifar10_loader_train``
        and a sampler that is always ``None`` here.

    Raises:
        ValueError: for any dataset other than ``cifar10``, or any learner
            other than ``vanilla``.
    """
    # NOTE(review): no distributed sampler is ever built. BasicLearner.finetune
    # requires one when args.distributed is set (i.e. >1 visible GPU).
    sampler = None

    if args.dataset != 'cifar10':
        raise ValueError("Unknown dataset")
    # Without this check an unsupported learner left ``loaders`` unbound.
    if args.learner not in ['vanilla']:
        raise ValueError("Unknown learner")

    # Honour --num_workers (default 4) instead of a hard-coded worker count.
    loaders = cifar10_loader_train(args, num_workers=args.num_workers)
    return loaders, sampler
98 |
99 |
def set_learner(args):
    """Return the learner class selected by ``args.learner``.

    Only ``'vanilla'`` (``BasicLearner``) is supported by this script.

    Raises:
        ValueError: for any other learner name.
    """
    if args.learner != 'vanilla':
        raise ValueError("Unknown learner")
    return BasicLearner
105 |
106 |
def _fix_randomness(seed):
    """Seed Python, NumPy and torch (CPU and all GPUs), and force deterministic cuDNN."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels make runs reproducible, at the cost of speed.
    cudnn.deterministic = True
    cudnn.benchmark = False
    print('Warning: You have chosen to seed training. '
          'This will turn on the CUDNN deterministic setting, '
          'which can slow down your training considerably! '
          'You may see unexpected behavior when restarting '
          'from checkpoints.')


def _init_job_logging(args):
    """Choose the log file for ``args.exec_mode`` and initialise logging.

    ``train`` / ``misc`` / ``finetune`` runs get a freshly generated ``args.job_id``.
    ``train`` / ``misc`` log under ``<save_path>/<model_type>_<learner>/<job_id>/``.
    The other modes log next to the checkpoint given by ``args.load_path``.
    """
    if args.exec_mode in ['train', 'misc']:
        args.job_id = generate_job_id()
        log_file = os.path.join(args.save_path, '_'.join([args.model_type, args.learner]),
                                args.job_id, 'record.log')
    elif args.exec_mode == 'finetune':
        args.job_id = generate_job_id()
        log_file = os.path.join(os.path.dirname(args.load_path), 'ft_record.log')
    else:
        log_file = os.path.join(os.path.dirname(args.load_path), args.load_path_log)
    init_logging(log_file)


def main():
    """Command-line entry point.

    Parses the CLI and enables distributed mode when more than one GPU is
    visible. It then builds the model, loaders and learner, and dispatches on
    ``args.exec_mode`` (train / eval / finetune / misc).

    Returns:
        0 on success, or 1 if an ``Exception`` was raised (its traceback is
        printed). ``KeyboardInterrupt`` / ``SystemExit`` propagate instead of
        being swallowed. In every case an initialised process group is torn
        down.
    """
    ##################### Preliminaries ##################
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ngpus_per_node = torch.cuda.device_count()
    # Distributed mode is inferred from the GPU count; any --distributed value
    # given on the command line is overwritten here.
    args.distributed = ngpus_per_node > 1
    #####################################################

    try:
        if args.distributed:
            rank, world_size = init_dist(backend=args.dist_backend, master_ip=args.dist_url, port=args.port)
            args.rank = rank
            args.world_size = world_size
            print("Distributed Enabled. Rank %d initalized" % args.rank)
        else:
            print("Single model training...")

        if args.fix_random:
            # init seed within each process
            _fix_randomness(args.seed)
        else:
            cudnn.benchmark = True

        if args.rank == 0:
            # either single GPU or multi GPU with rank = 0
            print("+++++++++++++++++++++++++")
            print("torch version:", torch.__version__)
            print("+++++++++++++++++++++++++")

        _init_job_logging(args)

        print_args(vars(args))
        logging.info("Using GPU: "+args.gpu_id)

        # create model
        model_helper = ModelHelper()
        model = model_helper.get_model(args)
        model.cuda()

        if args.distributed:
            # share the same initialization
            broadcast_params(model)

        loaders, sampler = set_loader(args)
        learner_fn = set_learner(args)
        learner = learner_fn(model, loaders, args, device)

        if args.exec_mode == 'train':
            learner.train(sampler)
        elif args.exec_mode == 'eval':
            learner.evaluate()
        elif args.exec_mode == 'finetune':
            learner.finetune(sampler)
        elif args.exec_mode == 'misc':
            learner.misc()
        return 0

    except Exception:
        traceback.print_exc()
        return 1

    finally:
        # Tear down on every exit path (success, error, Ctrl-C), but only if a
        # process group actually exists. Destroying an uninitialised group
        # would itself raise and mask the original error.
        if args.distributed and dist.is_initialized():
            dist.destroy_process_group()
195 |
if __name__ == '__main__':
    # Propagate main()'s status (0 = success, 1 = failure) as the process exit
    # code, so launch scripts and schedulers can detect failed runs.
    raise SystemExit(main())
198 |
--------------------------------------------------------------------------------